"driver/vscode:/vscode.git/clone" did not exist on "85ae70d3d39bccd1b090b3f5f1b2b29f6b6f65ca"
Commit b0b02e63 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Refactored reduce_op::normalize_compute_shape() handling of dynamic shapes to...

Refactored reduce_op::normalize_compute_shape() handling of dynamic shapes to set reduced dimensions to {1,1} instead of removing them.
Updated shape tests for reduce ops.
Spurious change to reduce_mean.cpp reverted.
parent f25caab5
...@@ -56,7 +56,7 @@ struct argmax ...@@ -56,7 +56,7 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
lens[axis] = 1; lens[axis] = 1;
......
...@@ -106,36 +106,33 @@ struct reduce_op : op_name<Derived> ...@@ -106,36 +106,33 @@ struct reduce_op : op_name<Derived>
return tuned_axes; return tuned_axes;
} }
/**
* @brief returns a shape in which the axis or axes named
* for reduction by this op are set to size 1.
*
* @param inputs list of input shapes
* @return shape
*/
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
if(s.dynamic()) if(s.dynamic())
{ {
std::vector<shape::dynamic_dimension> output_dyn_dims = {}; auto output_dyn_dims = s.dyn_dims();
auto tuned_axes = tune_axes(output_dyn_dims.size());
// create a dynamic dimensions vector that leaves out any axis named in this->axes. for(const auto& axis : tuned_axes)
for(size_t index = 0; index < s.dyn_dims().size(); index++)
{ {
auto name_it = std::find_if(this->axes.begin(), this->axes.end(), [&](auto& axis) { output_dyn_dims[axis] = {1, 1};
return (axis == index); // if the dim is in this op's axes list, don't include
// it
});
if(name_it == this->axes.end())
{
output_dyn_dims.push_back(s.dyn_dims().at(index));
}
} }
// compare with what src/include/migraphx/op/convolution.hpp does:
return shape{s.type(), output_dyn_dims}; return shape{s.type(), output_dyn_dims};
} }
else else
{ {
auto lens = s.lens(); auto lens = s.lens();
auto tuned_axes = tune_axes(lens.size()); auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes) for(auto& axis : tuned_axes)
{ {
lens[axis] = 1; lens[axis] = 1;
} }
......
...@@ -1396,24 +1396,42 @@ void test_reduce_ops() ...@@ -1396,24 +1396,42 @@ void test_reduce_ops()
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input); throws_shape(T{{4}}, input);
} }
// dynamic shape }
// dynamic shape
template <class T>
void test_dyn_reduce_ops()
{
{ {
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 4}, {2, 4, 4}}}; migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
migraphx::shape::dynamic_dimension dd0{2, 3, 4};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{2, 3, 4}})}, std::vector<migraphx::shape::dynamic_dimension>({{2, 3, 3}, {1, 1, 0}})},
T{{-1}}, T{{-1}},
input); input);
}
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{2, 4, 4}})}, std::vector<migraphx::shape::dynamic_dimension>({{1, 1, 0}, {2, 4, 4}})},
T{{0}}, T{{0}},
input); input);
} }
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
throws_shape(T{{4}}, input);
}
} }
TEST_CASE(reduce_max) { test_reduce_ops<migraphx::op::reduce_max>(); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); } TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_prod) { test_reduce_ops<migraphx::op::reduce_prod>(); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); } TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_max>(); }
TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_prod>(); }
TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reshape_shape) TEST_CASE(reshape_shape)
{ {
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}}; migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment