Commit 751e8f37 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'reduce_mean' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into test_bert

parents e7d26442 3003844f
...@@ -14,7 +14,7 @@ namespace op { ...@@ -14,7 +14,7 @@ namespace op {
struct reduce_mean struct reduce_mean
{ {
std::vector<std::size_t> axes{}; std::vector<std::int64_t> axes{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -24,15 +24,41 @@ struct reduce_mean ...@@ -24,15 +24,41 @@ struct reduce_mean
std::string name() const { return "reduce_mean"; } std::string name() const { return "reduce_mean"; }
// Normalize the `axes` attribute for an input of rank `n_dim`:
// an empty axes list means "reduce over all dimensions", and a
// negative axis counts from the last dimension (ONNX convention).
// Returns the resolved, non-negative axis list.
// Throws MIGRAPHX_THROW if any axis is outside [-n_dim, n_dim).
std::vector<int64_t> tune_axes(std::size_t n_dim) const
{
    auto tuned_axes = axes;
    if(tuned_axes.empty())
    {
        // Default: reduce over every dimension [0, n_dim).
        tuned_axes.resize(n_dim);
        std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
        return tuned_axes;
    }
    // Hoisted loop invariant; also keeps the arithmetic below all-signed.
    const auto s_dim = static_cast<int64_t>(n_dim);
    for(auto& axis : tuned_axes)
    {
        if(axis >= s_dim or axis < -s_dim)
        {
            MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
        }
        if(axis < 0)
        {
            // Map a negative axis to its positive equivalent.
            axis += s_dim;
        }
    }
    return tuned_axes;
}
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
auto lens = s.lens(); auto lens = s.lens();
for(auto axis : axes) auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes)
{ {
if(axis >= lens.size())
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
lens[axis] = 1; lens[axis] = 1;
} }
...@@ -42,13 +68,14 @@ struct reduce_mean ...@@ -42,13 +68,14 @@ struct reduce_mean
template <class T> template <class T>
void calc_mean(tensor_view<T>& input, void calc_mean(tensor_view<T>& input,
shape& batch_shape, shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx, std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const tensor_view<T>& output) const
{ {
auto data_idx = out_idx; auto data_idx = out_idx;
T val = T{0}; T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) { shape_for_each(batch_shape, [&](auto b_idx) {
for(auto axis : axes) for(auto axis : tuned_axes)
{ {
data_idx[axis] = b_idx[axis]; data_idx[axis] = b_idx[axis];
} }
...@@ -61,9 +88,10 @@ struct reduce_mean ...@@ -61,9 +88,10 @@ struct reduce_mean
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto arg_lens = args.front().get_shape().lens(); auto arg_lens = args.front().get_shape().lens();
auto tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1); std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for(auto axis : axes) for(auto axis : tuned_axes)
{ {
batch_lens[axis] = arg_lens[axis]; batch_lens[axis] = arg_lens[axis];
} }
...@@ -71,7 +99,7 @@ struct reduce_mean ...@@ -71,7 +99,7 @@ struct reduce_mean
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i); auto out_idx = output_shape.multi(i);
this->calc_mean(input, batch_shape, out_idx, output); this->calc_mean(input, batch_shape, tuned_axes, out_idx, output);
}); });
}); });
......
...@@ -14,7 +14,7 @@ namespace op { ...@@ -14,7 +14,7 @@ namespace op {
struct reduce_sum struct reduce_sum
{ {
std::vector<std::size_t> axes{}; std::vector<int64_t> axes{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -24,30 +24,58 @@ struct reduce_sum ...@@ -24,30 +24,58 @@ struct reduce_sum
std::string name() const { return "reduce_sum"; } std::string name() const { return "reduce_sum"; }
// Normalize the `axes` attribute for an input of rank `n_dim`:
// an empty axes list means "reduce over all dimensions", and a
// negative axis counts from the last dimension (ONNX convention).
// Returns the resolved, non-negative axis list.
// Throws MIGRAPHX_THROW if any axis is outside [-n_dim, n_dim).
std::vector<int64_t> tune_axes(std::size_t n_dim) const
{
    auto tuned_axes = axes;
    if(tuned_axes.empty())
    {
        // Default: reduce over every dimension [0, n_dim).
        tuned_axes.resize(n_dim);
        std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
        return tuned_axes;
    }
    // Hoisted loop invariant; also keeps the arithmetic below all-signed.
    const auto s_dim = static_cast<int64_t>(n_dim);
    for(auto& axis : tuned_axes)
    {
        if(axis >= s_dim or axis < -s_dim)
        {
            MIGRAPHX_THROW("REDUCE_SUM: axis out of range");
        }
        if(axis < 0)
        {
            // Map a negative axis to its positive equivalent.
            axis += s_dim;
        }
    }
    return tuned_axes;
}
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0); auto s = inputs.at(0);
auto lens = s.lens(); auto lens = s.lens();
for(auto axis : axes) auto tuned_axes = tune_axes(lens.size());
for(auto axis : tuned_axes)
{ {
if(axis >= lens.size())
MIGRAPHX_THROW("REDUCE_SUM: axis out of range");
lens[axis] = 1; lens[axis] = 1;
} }
return {s.type(), lens}; return {s.type(), lens};
} }
template <class T> template <class T>
void calc_sum(tensor_view<T>& input, void calc_sum(tensor_view<T>& input,
shape& batch_shape, shape& batch_shape,
std::vector<int64_t>& tuned_axes,
std::vector<std::size_t>& out_idx, std::vector<std::size_t>& out_idx,
tensor_view<T>& output) const tensor_view<T>& output) const
{ {
auto data_idx = out_idx; auto data_idx = out_idx;
T val = T{0}; T val = T{0};
shape_for_each(batch_shape, [&](auto b_idx) { shape_for_each(batch_shape, [&](auto b_idx) {
for(auto axis : axes) for(auto axis : tuned_axes)
{ {
data_idx[axis] = b_idx[axis]; data_idx[axis] = b_idx[axis];
} }
...@@ -60,9 +88,10 @@ struct reduce_sum ...@@ -60,9 +88,10 @@ struct reduce_sum
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto arg_lens = args.front().get_shape().lens(); auto arg_lens = args.front().get_shape().lens();
std::vector<int64_t> tuned_axes = tune_axes(arg_lens.size());
std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1); std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
for(auto axis : axes) for(auto axis : tuned_axes)
{ {
batch_lens[axis] = arg_lens[axis]; batch_lens[axis] = arg_lens[axis];
} }
...@@ -70,7 +99,7 @@ struct reduce_sum ...@@ -70,7 +99,7 @@ struct reduce_sum
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
auto out_idx = output_shape.multi(i); auto out_idx = output_shape.multi(i);
this->calc_sum(input, batch_shape, out_idx, output); this->calc_sum(input, batch_shape, tuned_axes, out_idx, output);
}); });
}); });
......
...@@ -1381,13 +1381,13 @@ struct onnx_parser ...@@ -1381,13 +1381,13 @@ struct onnx_parser
std::size_t n_dim = args.front()->get_shape().lens().size(); std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions // default to reduce over all dimensions
std::vector<std::size_t> axes(n_dim); std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes")) if(contains(attributes, "axes"))
{ {
axes.clear(); axes.clear();
auto&& attr_axes = attributes["axes"].ints(); auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end()); axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
} }
int keep_dims = 1; int keep_dims = 1;
...@@ -1403,8 +1403,7 @@ struct onnx_parser ...@@ -1403,8 +1403,7 @@ struct onnx_parser
else else
{ {
auto ins = prog.add_instruction(T{axes}, std::move(args)); auto ins = prog.add_instruction(T{axes}, std::move(args));
std::vector<int64_t> sq_axes(axes.begin(), axes.end()); return prog.add_instruction(op::squeeze{axes}, ins);
return prog.add_instruction(op::squeeze{sq_axes}, ins);
} }
} }
......
...@@ -1718,7 +1718,7 @@ TEST_CASE(clip_test) ...@@ -1718,7 +1718,7 @@ TEST_CASE(clip_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(reduce_sum_test0) TEST_CASE(reduce_sum_axis0)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
...@@ -1733,7 +1733,7 @@ TEST_CASE(reduce_sum_test0) ...@@ -1733,7 +1733,7 @@ TEST_CASE(reduce_sum_test0)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(reduce_sum_test1) TEST_CASE(reduce_sum_axis1)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
...@@ -1748,7 +1748,7 @@ TEST_CASE(reduce_sum_test1) ...@@ -1748,7 +1748,7 @@ TEST_CASE(reduce_sum_test1)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(reduce_sum_test2) TEST_CASE(reduce_sum_axis2)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
...@@ -1763,7 +1763,7 @@ TEST_CASE(reduce_sum_test2) ...@@ -1763,7 +1763,7 @@ TEST_CASE(reduce_sum_test2)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(reduce_sum_test02) TEST_CASE(reduce_sum_axis02)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
...@@ -1778,7 +1778,7 @@ TEST_CASE(reduce_sum_test02) ...@@ -1778,7 +1778,7 @@ TEST_CASE(reduce_sum_test02)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(reduce_sum_test12) TEST_CASE(reduce_sum_axis12)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
...@@ -1793,7 +1793,7 @@ TEST_CASE(reduce_sum_test12) ...@@ -1793,7 +1793,7 @@ TEST_CASE(reduce_sum_test12)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(reduce_mean_test1) TEST_CASE(reduce_mean_axis1)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
...@@ -1808,7 +1808,7 @@ TEST_CASE(reduce_mean_test1) ...@@ -1808,7 +1808,7 @@ TEST_CASE(reduce_mean_test1)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(reduce_mean_test2) TEST_CASE(reduce_mean_axis2)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
...@@ -1823,7 +1823,7 @@ TEST_CASE(reduce_mean_test2) ...@@ -1823,7 +1823,7 @@ TEST_CASE(reduce_mean_test2)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(reduce_mean_test02) TEST_CASE(reduce_mean_axis02)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
...@@ -1838,7 +1838,7 @@ TEST_CASE(reduce_mean_test02) ...@@ -1838,7 +1838,7 @@ TEST_CASE(reduce_mean_test02)
EXPECT(results_vector == gold); EXPECT(results_vector == gold);
} }
TEST_CASE(reduce_mean_test12) TEST_CASE(reduce_mean_axis12)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}}; migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
......
...@@ -465,6 +465,16 @@ TEST_CASE(test_argmin) ...@@ -465,6 +465,16 @@ TEST_CASE(test_argmin)
template <class T> template <class T>
void test_reduce_ops() void test_reduce_ops()
{ {
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
}
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
...@@ -473,6 +483,10 @@ void test_reduce_ops() ...@@ -473,6 +483,10 @@ void test_reduce_ops()
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input); expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
} }
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
}
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input); throws_shape(T{{4}}, input);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment