Commit c30653cb authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Added a shape test. The op should return the same dynamic shape as the input.

parent d4663c02
...@@ -64,9 +64,9 @@ struct prefix_scan_op : op_name<Derived> ...@@ -64,9 +64,9 @@ struct prefix_scan_op : op_name<Derived>
auto s = inputs.front(); auto s = inputs.front();
if(s.dynamic()) if(s.dynamic())
{ {
return {s.type(), s.max_lens()}; return s;
} }
if(s.broadcasted()) else if(s.broadcasted())
{ {
return {s.type(), s.lens()}; return {s.type(), s.lens()};
} }
......
...@@ -2084,6 +2084,19 @@ TEST_CASE(prefix_scan_sum) ...@@ -2084,6 +2084,19 @@ TEST_CASE(prefix_scan_sum)
} }
} }
TEST_CASE(prefix_scan_sum_dyn)
{
{
std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}};
migraphx::shape s{migraphx::shape::float_type, dd};
expect_shape(
s,
migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", 0}, {"reverse", 0}}),
s);
}
}
TEST_CASE(quant_convolution_shape) TEST_CASE(quant_convolution_shape)
{ {
migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}}; migraphx::shape output{migraphx::shape::int32_type, {4, 4, 1, 1}};
......
...@@ -5543,19 +5543,15 @@ TEST_CASE(prefix_scan_sum_dyn_1d) ...@@ -5543,19 +5543,15 @@ TEST_CASE(prefix_scan_sum_dyn_1d)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<migraphx::shape::dynamic_dimension> dd{{5, 6}}; std::vector<migraphx::shape::dynamic_dimension> dd{{5, 8}};
migraphx::shape s{migraphx::shape::float_type, dd}; migraphx::shape s{migraphx::shape::float_type, dd};
auto input = mm->add_parameter("X", s); auto input = mm->add_parameter("X", s);
mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}), input); mm->add_instruction(migraphx::make_op("prefix_scan_sum", {{"axis", 0}, {"exclusive", false}}), input);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
std::vector<float> a = {1, 2, 3, 4, 5, 6}; std::vector<float> a = {1, 2, 3, 4, 5, 6};
migraphx::parameter_map params0;
// auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6}};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {6}}; migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {6}};
migraphx::parameter_map params0;
params0["X"] = migraphx::argument(input_fixed_shape0, a.data()); params0["X"] = migraphx::argument(input_fixed_shape0, a.data());
auto result = p.eval(params0).back(); auto result = p.eval(params0).back();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment