Commit b0b02e63 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Refactored reduce_op::normalize_compute_shape() handling of dynamic shapes to...

Refactored reduce_op::normalize_compute_shape() handling of dynamic shapes to set reduced dimensions to {1,1} instead of removing them.
Updated shape tests for reduce ops.
Spurious change to reduce_mean.cpp reverted.
parent f25caab5
......@@ -56,7 +56,7 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens();
lens[axis] = 1;
......
......@@ -106,36 +106,33 @@ struct reduce_op : op_name<Derived>
return tuned_axes;
}
/**
* @brief returns a shape in which the axis or axes named
* for reduction by this op are set to size 1.
*
* @param inputs list of input shapes
* @return shape
*/
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto s = inputs.at(0);
if(s.dynamic())
{
std::vector<shape::dynamic_dimension> output_dyn_dims = {};
// create a dynamic dimensions vector that leaves out any axis named in this->axes.
for(size_t index = 0; index < s.dyn_dims().size(); index++)
{
auto name_it = std::find_if(this->axes.begin(), this->axes.end(), [&](auto& axis) {
return (axis == index); // if the dim is in this op's axes list, don't include
// it
});
if(name_it == this->axes.end())
auto output_dyn_dims = s.dyn_dims();
auto tuned_axes = tune_axes(output_dyn_dims.size());
for(const auto& axis : tuned_axes)
{
output_dyn_dims.push_back(s.dyn_dims().at(index));
}
output_dyn_dims[axis] = {1, 1};
}
// compare with what src/include/migraphx/op/convolution.hpp does:
return shape{s.type(), output_dyn_dims};
}
else
{
auto lens = s.lens();
auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes)
for(auto& axis : tuned_axes)
{
lens[axis] = 1;
}
......
......@@ -1396,24 +1396,42 @@ void test_reduce_ops()
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input);
}
// dynamic shape
}
// dynamic shape
// Shape-inference tests for reduce ops on dynamic input shapes.
// After this refactor, reduce_op::normalize_compute_shape() keeps the rank of a
// dynamic input and sets every reduced axis to the fixed dynamic_dimension {1, 1}
// (shown below as {1, 1, 0}: min = 1, max = 1, no optimal — TODO confirm the third
// field is the "optimals" slot) instead of removing the axis.
// NOTE(review): the pasted diff lost its +/- markers, leaving stale removed lines
// (duplicate `input` declarations, a stray `dd0`, duplicate expected vectors) that
// made this function ill-formed; this body is the reconstructed post-commit version.
template <class T>
void test_dyn_reduce_ops()
{
    {
        // axis -1 (last axis): second dim {2, 4, 4} collapses to {1, 1}
        migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
        expect_shape(migraphx::shape{migraphx::shape::float_type,
                                     std::vector<migraphx::shape::dynamic_dimension>(
                                         {{2, 3, 3}, {1, 1, 0}})},
                     T{{-1}},
                     input);
    }
    {
        // axis 0: first dim {2, 3, 3} collapses to {1, 1}
        migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
        expect_shape(migraphx::shape{migraphx::shape::float_type,
                                     std::vector<migraphx::shape::dynamic_dimension>(
                                         {{1, 1, 0}, {2, 4, 4}})},
                     T{{0}},
                     input);
    }
    {
        // axis 4 is out of range for a rank-2 input: normalization must throw
        migraphx::shape input{migraphx::shape::float_type, {{2, 3, 3}, {2, 4, 4}}};
        throws_shape(T{{4}}, input);
    }
}
// Static-shape shape-inference tests, one per reduction op.
TEST_CASE(reduce_max) { test_reduce_ops<migraphx::op::reduce_max>(); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_prod) { test_reduce_ops<migraphx::op::reduce_prod>(); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
// Dynamic-shape shape-inference tests: reduced axes are set to {1, 1}
// dynamic dimensions rather than being removed (see this commit's message).
TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_max>(); }
TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_prod>(); }
TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reshape_shape)
{
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment