Commit 0905762b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Fixed review comments and added support for negative axis input

parent 3f0e74b4
......@@ -14,7 +14,7 @@ namespace op {
struct reduce_mean
{
std::vector<std::size_t> axes{};
std::vector<std::int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -24,15 +24,41 @@ struct reduce_mean
std::string name() const { return "reduce_mean"; }
/// Normalize the reduction axes to non-negative indices in [0, n_dim).
///
/// An empty `axes` member means "reduce over all dimensions" and expands to
/// {0, 1, ..., n_dim - 1}. Negative axes count from the last dimension
/// (e.g. -1 is the final axis), matching ONNX semantics.
///
/// \param n_dim  rank of the input tensor
/// \return       the normalized axis list
/// \throws       MIGRAPHX_THROW if any axis is outside [-n_dim, n_dim)
std::vector<int64_t> tune_axes(std::size_t n_dim) const
{
    auto tuned_axes = axes;
    if(tuned_axes.empty())
    {
        // Default: reduce over every dimension of the input.
        tuned_axes.resize(n_dim);
        std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
    }
    else
    {
        // Hoist the signed rank out of the loop; it is loop-invariant.
        const auto s_dim = static_cast<int64_t>(n_dim);
        for(auto& axis : tuned_axes)
        {
            if(axis >= s_dim or axis < -s_dim)
            {
                MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
            }
            if(axis < 0)
            {
                // Use the signed rank here to avoid mixed signed/unsigned
                // arithmetic (the original added the unsigned n_dim).
                axis += s_dim;
            }
        }
    }
    return tuned_axes;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes)
{
if(axis >= lens.size())
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
lens[axis] = 1;
}
......@@ -42,13 +68,14 @@ struct reduce_mean
template <class T>
void calc_mean(tensor_view<T>& input,
shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) {
for(auto axis : axes)
for(auto axis : tuned_axes)
{
data_idx[axis] = b_idx[axis];
}
......@@ -62,8 +89,9 @@ struct reduce_mean
{
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
auto tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for(auto axis : axes)
for(auto axis : tuned_axes)
{
batch_lens[axis] = arg_lens[axis];
}
......@@ -71,7 +99,7 @@ struct reduce_mean
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i);
this->calc_mean(input, batch_shape, out_idx, output);
this->calc_mean(input, batch_shape, tuned_axes, out_idx, output);
});
});
......
......@@ -14,7 +14,7 @@ namespace op {
struct reduce_sum
{
std::vector<std::size_t> axes{};
std::vector<int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -24,30 +24,58 @@ struct reduce_sum
std::string name() const { return "reduce_sum"; }
/// Normalize the reduction axes to non-negative indices in [0, n_dim).
///
/// An empty `axes` member means "reduce over all dimensions" and expands to
/// {0, 1, ..., n_dim - 1}. Negative axes count from the last dimension
/// (e.g. -1 is the final axis), matching ONNX semantics.
///
/// \param n_dim  rank of the input tensor
/// \return       the normalized axis list
/// \throws       MIGRAPHX_THROW if any axis is outside [-n_dim, n_dim)
std::vector<int64_t> tune_axes(std::size_t n_dim) const
{
    auto tuned_axes = axes;
    if(tuned_axes.empty())
    {
        // Default: reduce over every dimension of the input.
        tuned_axes.resize(n_dim);
        std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
    }
    else
    {
        // Hoist the signed rank out of the loop; it is loop-invariant.
        const auto s_dim = static_cast<int64_t>(n_dim);
        for(auto& axis : tuned_axes)
        {
            if(axis >= s_dim or axis < -s_dim)
            {
                MIGRAPHX_THROW("REDUCE_SUM: axis out of range");
            }
            if(axis < 0)
            {
                // Use the signed rank here to avoid mixed signed/unsigned
                // arithmetic (the original added the unsigned n_dim).
                axis += s_dim;
            }
        }
    }
    return tuned_axes;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes)
{
if(axis >= lens.size())
MIGRAPHX_THROW("REDUCE_SUM: axis out of range");
lens[axis] = 1;
}
return {s.type(), lens};
}
template <class T>
void calc_sum(tensor_view<T>& input,
shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const
{
auto data_idx = out_idx;
T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) {
for(auto axis : axes)
for(auto axis : tuned_axes)
{
data_idx[axis] = b_idx[axis];
}
......@@ -61,8 +89,9 @@ struct reduce_sum
{
argument result{output_shape};
auto arg_lens = args.front().get_shape().lens();
std::vector<int64_t> tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for(auto axis : axes)
for(auto axis : tuned_axes)
{
batch_lens[axis] = arg_lens[axis];
}
......@@ -70,7 +99,7 @@ struct reduce_sum
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i);
this->calc_sum(input, batch_shape, out_idx, output);
this->calc_sum(input, batch_shape, tuned_axes, out_idx, output);
});
});
......
......@@ -1296,13 +1296,13 @@ struct onnx_parser
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<std::size_t> axes(n_dim);
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
......@@ -1318,8 +1318,7 @@ struct onnx_parser
else
{
auto ins = prog.add_instruction(T{axes}, std::move(args));
std::vector<int64_t> sq_axes(axes.begin(), axes.end());
return prog.add_instruction(op::squeeze{sq_axes}, ins);
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
......
......@@ -1688,7 +1688,7 @@ TEST_CASE(clip_test)
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(reduce_sum_test0)
TEST_CASE(reduce_sum_axis0)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1703,7 +1703,7 @@ TEST_CASE(reduce_sum_test0)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test1)
TEST_CASE(reduce_sum_axis1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1718,7 +1718,7 @@ TEST_CASE(reduce_sum_test1)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test2)
TEST_CASE(reduce_sum_axis2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1733,7 +1733,7 @@ TEST_CASE(reduce_sum_test2)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test02)
TEST_CASE(reduce_sum_axis02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1748,7 +1748,7 @@ TEST_CASE(reduce_sum_test02)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_sum_test12)
TEST_CASE(reduce_sum_axis12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1763,7 +1763,7 @@ TEST_CASE(reduce_sum_test12)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test1)
TEST_CASE(reduce_mean_axis1)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1778,7 +1778,7 @@ TEST_CASE(reduce_mean_test1)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test2)
TEST_CASE(reduce_mean_axis2)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1793,7 +1793,7 @@ TEST_CASE(reduce_mean_test2)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test02)
TEST_CASE(reduce_mean_axis02)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......@@ -1808,7 +1808,7 @@ TEST_CASE(reduce_mean_test02)
EXPECT(results_vector == gold);
}
TEST_CASE(reduce_mean_test12)
TEST_CASE(reduce_mean_axis12)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......
......@@ -465,6 +465,15 @@ TEST_CASE(test_argmin)
template <class T>
void test_reduce_ops()
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
......@@ -473,6 +482,10 @@ void test_reduce_ops()
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment