"vscode:/vscode.git/clone" did not exist on "e225b9abe68421a6a37b693227282ebafd9fee87"
Commit b3f0f482 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

dynamic shape support for reduce_XXX operations. One test in ref_ops_test;...

Dynamic shape support for reduce_XXX operations. One test in ref_ops_test; one test in op_shape_test.
parent 5696ac5f
...@@ -110,6 +110,29 @@ struct reduce_op : op_name<Derived> ...@@ -110,6 +110,29 @@ struct reduce_op : op_name<Derived>
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
// Dynamic-shape path: build the output shape from the input's dynamic
// dimensions, skipping every dimension whose index appears in this->axes.
// NOTE(review): unlike the static branch below (which keeps each reduced
// axis as a size-1 dimension via with_lens), this drops the reduced axes
// entirely, lowering the output rank — confirm this rank mismatch between
// the static and dynamic paths is intended.
if(s.dynamic())
{
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
// create a dynamic dimensions vector that leaves out any axis named in this->axes.
// NOTE(review): `axis == index` compares a (signed) stored axis against a
// size_t index, so this presumes this->axes has already been normalized to
// non-negative values (e.g. -1 -> rank-1) before compute_shape runs — the
// op_shape_test case with T{{-1}} relies on that; verify normalization happens.
for(size_t index = 0; index < s.dyn_dims().size(); index++)
{
auto name_it = std::find_if(this->axes.begin(), this->axes.end(), [&](auto& axis) {
return (axis == index); // if the dim is in this op's axes list, don't include
// it
});
if(name_it == this->axes.end())
{
// This dimension is not reduced; carry its dynamic range through unchanged.
output_dyn_dims.push_back(s.dyn_dims().at(index));
}
}
// compare with what src/include/migraphx/op/convolution.hpp does:
return shape{s.type(), output_dyn_dims};
}
else
{
auto lens = s.lens(); auto lens = s.lens();
auto tuned_axes = tune_axes(lens.size()); auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes) for(auto axis : tuned_axes)
...@@ -119,6 +142,7 @@ struct reduce_op : op_name<Derived> ...@@ -119,6 +142,7 @@ struct reduce_op : op_name<Derived>
return inputs[0].with_lens(lens); return inputs[0].with_lens(lens);
} }
}
template <class T> template <class T>
void tune_dims(const std::vector<int64_t>& tuned_axes, void tune_dims(const std::vector<int64_t>& tuned_axes,
...@@ -154,7 +178,7 @@ struct reduce_op : op_name<Derived> ...@@ -154,7 +178,7 @@ struct reduce_op : op_name<Derived>
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ // todo: what should be different about the computation, if anything? {
argument result{dyn_out.computed_shape}; argument result{dyn_out.computed_shape};
auto arg_lens = args.front().get_shape().lens(); auto arg_lens = args.front().get_shape().lens();
auto tuned_axes = tune_axes(arg_lens.size()); auto tuned_axes = tune_axes(arg_lens.size());
......
...@@ -1396,6 +1396,19 @@ void test_reduce_ops() ...@@ -1396,6 +1396,19 @@ void test_reduce_ops()
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input); throws_shape(T{{4}}, input);
} }
// dynamic shape: compute_shape on a rank-2 dynamic input should drop the
// reduced axis and pass the surviving dynamic_dimension through unchanged.
// Dynamic dims are written {min, max, optimal}.
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 4}, {2, 4, 4}}};
// NOTE(review): dd0 is never used below — consider removing it.
migraphx::shape::dynamic_dimension dd0{2, 3, 4};
// axes {-1} (presumably normalized to the last axis, index 1) removes the
// second dimension, leaving only {2, 3, 4}.
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{2, 3, 4}})},
T{{-1}},
input);
// axes {0} removes the first dimension, leaving only {2, 4, 4}.
expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{2, 4, 4}})},
T{{0}},
input);
}
} }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); } TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
......
...@@ -4978,6 +4978,50 @@ TEST_CASE(reduce_max_axis0) ...@@ -4978,6 +4978,50 @@ TEST_CASE(reduce_max_axis0)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
// Verify that reduce_max compiles against a dynamic input shape and produces
// correct results when evaluated with a concrete argument.
TEST_CASE(reduce_max_dynamic_axis0)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    // Dynamic dims are {min, max, optimal}: rank-2 input with dim0 in [2, 4]
    // and dim1 in [3, 5].
    migraphx::shape s{migraphx::shape::float_type, {{2, 4, 2}, {3, 5, 3}}};
    auto input         = mm->add_parameter("X", s);
    auto reduce_max_op = migraphx::make_op("reduce_max", {{"axes", {0}}});
    mm->add_instruction(reduce_max_op, input);
    p.compile(migraphx::ref::target{});

    // Evaluate with a concrete 2x5 argument (10 elements); reduce_max over
    // axis 0 yields the column-wise maxima, i.e. the second row here.
    std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    migraphx::parameter_map params0;
    migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {2, 5}};
    params0["X"] = migraphx::argument(input_fixed_shape0, input_data.data());
    auto result = p.eval(params0).back();
    std::vector<float> results_vector;
    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
    std::vector<float> gold = {6, 7, 8, 9, 10};
    EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(reduce_max_axis01) TEST_CASE(reduce_max_axis01)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment