Unverified Commit 177eb1b0 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Make JIT and pointwise work with zero input args (#1587)

Ensure that we don't have empty inputs when computing shape for pointwise function
parent 1a41c9e9
......@@ -238,6 +238,8 @@ std::string cpp_generator::create_function(const cpp_generator::function& f)
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '(';
if(f.params.empty())
impl->fs << delim;
for(auto&& p : f.params)
{
impl->fs << delim << p.type << " " << p.name;
......
......@@ -74,6 +74,7 @@ static void create_pointwise_modules(module_pass_manager& mpm)
std::unordered_map<instruction_ref, instruction_ref> param_map;
std::vector<instruction_ref> pointwise_inputs;
std::size_t i = 0;
for(auto input : ins->inputs())
{
if(contains(param_map, input))
......@@ -92,6 +93,10 @@ static void create_pointwise_modules(module_pass_manager& mpm)
}
}
// Don't create pointwise module if no inputs are detected
if(pointwise_inputs.empty())
continue;
std::vector<instruction_ref> inputs;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
......
......@@ -46,13 +46,14 @@ struct pointwise
MIGRAPHX_THROW("should have one submodule.");
}
auto* pm = mods.front();
if(pm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("pointwise should have only one output.");
if(inputs.empty())
MIGRAPHX_THROW("pointwise should have at least one input");
auto pnames = pm->get_parameter_names();
std::sort(pnames.begin(), pnames.end());
check_shapes{inputs, *this}.has(pnames.size()).same_dims();
if(pm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("submodule should have only one output.");
auto type = pm->get_output_shapes().front().type();
// Scalar output if all inputs are scalar
......
......@@ -329,4 +329,36 @@ TEST_CASE(all_scalar_input)
EXPECT(p1 == p2);
}
TEST_CASE(no_input)
{
migraphx::program p;
{
auto* mm = p.get_main_module();
migraphx::shape g_shape{migraphx::shape::int64_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type, {3}};
std::vector<int> indices{3, 800, 800};
auto a0 = mm->add_literal(migraphx::literal{s_indices, indices});
auto a1 = mm->add_literal(migraphx::literal{g_shape, {1}});
int axis = 0;
auto out = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1);
mm->add_return({out});
}
run_pass(p);
// This should NOT create a pointwise module if there are no inputs here.
migraphx::program p2;
{
auto* mm = p2.get_main_module();
migraphx::shape g_shape{migraphx::shape::int64_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type, {3}};
std::vector<int> indices{3, 800, 800};
auto a0 = mm->add_literal(migraphx::literal{s_indices, indices});
auto a1 = mm->add_literal(migraphx::literal{g_shape, {1}});
int axis = 0;
auto out = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1);
mm->add_return({out});
}
EXPECT(p == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -1822,6 +1822,33 @@ TEST_CASE(pad_dyn_shape1)
expect_shape(output, migraphx::make_op("pad", {{"pads", {0, 0, 1, 1, 0, 0, 1, 1}}}), input);
}
TEST_CASE(pointwise_no_module)
{
migraphx::shape input{migraphx::shape::float_type, {0}, {0}};
throws_shape(migraphx::make_op("pointwise"), input);
}
TEST_CASE(pointwise_no_input)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::module m;
std::vector<migraphx::instruction_ref> args{};
auto output = migraphx::shape(migraphx::shape::float_type, {1}, {0});
auto l = m.add_literal(migraphx::literal(output, {1}));
m.add_return({l});
EXPECT(test::throws([&] { mm->add_instruction(migraphx::make_op("pointwise"), args, {&m}); }));
}
TEST_CASE(pointwise_no_output)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::module m;
std::vector<migraphx::instruction_ref> args{};
EXPECT(test::throws([&] { mm->add_instruction(migraphx::make_op("pointwise"), args, {&m}); }));
}
TEST_CASE(pooling_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {4, 3, 3, 3}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment