Commit 413036a1 authored by Shucai Xiao's avatar Shucai Xiao

merge changes from develop branch

parents f06f6aa3 a96d3a95
#ifndef MIGRAPHX_GUARD_OPERATORS_MEAN_HPP
#define MIGRAPHX_GUARD_OPERATORS_MEAN_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <numeric> // for std::iota
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct reduce_mean
{
    std::vector<std::int64_t> axes{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"));
    }

    std::string name() const { return "reduce_mean"; }

    // Normalize the axes list: an empty list means reduce over all
    // dimensions; negative axes count from the end; out-of-range axes throw.
    std::vector<int64_t> tune_axes(std::size_t n_dim) const
    {
        auto tuned_axes = axes;
        if(tuned_axes.empty())
        {
            tuned_axes.resize(n_dim);
            std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
        }
        else
        {
            for(auto& axis : tuned_axes)
            {
                int64_t s_dim = static_cast<int64_t>(n_dim);
                if(axis >= s_dim or axis < -s_dim)
                {
                    MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
                }
                if(axis < 0)
                {
                    axis += n_dim;
                }
            }
        }
        return tuned_axes;
    }
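    // Example (hypothetical values): with n_dim = 4, an axis of -1 normalizes
    // to 3; an empty axes list expands to {0, 1, 2, 3}; {4} or {-5} throw.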
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto s          = inputs.at(0);
        auto lens       = s.lens();
        auto tuned_axes = tune_axes(lens.size());
        for(auto axis : tuned_axes)
        {
            lens[axis] = 1;
        }

        return {s.type(), lens};
    }

    template <class T>
    void calc_mean(tensor_view<T>& input,
                   shape& batch_shape,
                   std::vector<int64_t>& tuned_axes,
                   std::vector<std::size_t>& out_idx,
                   tensor_view<T>& output) const
    {
        auto data_idx = out_idx;
        T val         = T{0};
        shape_for_each(batch_shape, [&](auto b_idx) {
            for(auto axis : tuned_axes)
            {
                data_idx[axis] = b_idx[axis];
            }
            val += input(data_idx.begin(), data_idx.end());
        });

        output(out_idx.begin(), out_idx.end()) = val / batch_shape.elements();
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        auto arg_lens   = args.front().get_shape().lens();
        auto tuned_axes = tune_axes(arg_lens.size());
        std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
        for(auto axis : tuned_axes)
        {
            batch_lens[axis] = arg_lens[axis];
        }
        shape batch_shape{output_shape.type(), batch_lens};

        visit_all(result, args[0])([&](auto output, auto input) {
            par_for(output_shape.elements(), [&](auto i) {
                auto out_idx = output_shape.multi(i);
                this->calc_mean(input, batch_shape, tuned_axes, out_idx, output);
            });
        });

        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
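For reference, a minimal standalone sketch (not part of the commit; plain C++ with hand-rolled flat indexing) that reproduces the semantics above: mean over axis 1 of a {3, 2, 2} tensor, matching the reduce_mean_axis1 test later in this commit.

// Standalone sketch: mean over axis 1 of a {3, 2, 2} tensor, keeping the
// reduced dimension as size 1, so the output shape is {3, 1, 2}.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; // {3, 2, 2}
    std::size_t d0 = 3, d1 = 2, d2 = 2;
    std::vector<float> out(d0 * d2, 0.0f); // {3, 1, 2}
    for(std::size_t i = 0; i < d0; i++)
        for(std::size_t j = 0; j < d1; j++)
            for(std::size_t k = 0; k < d2; k++)
                out[i * d2 + k] += in[(i * d1 + j) * d2 + k];
    for(auto& v : out)
        v /= d1; // divide each sum by the number of reduced elements
    for(auto v : out)
        std::cout << v << " "; // prints: 2 3 6 7 10 11
    std::cout << "\n";
    return 0;
}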
@@ -4,6 +4,7 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/shape_for_each.hpp>
+#include <migraphx/par_for.hpp>
 #include <migraphx/config.hpp>
 #include <vector>
@@ -13,7 +14,7 @@ namespace op {
 struct reduce_sum
 {
-    std::vector<std::size_t> axes;
+    std::vector<int64_t> axes{};
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -23,25 +24,82 @@ struct reduce_sum
     std::string name() const { return "reduce_sum"; }
 
+    std::vector<int64_t> tune_axes(std::size_t n_dim) const
+    {
+        auto tuned_axes = axes;
+        if(tuned_axes.empty())
+        {
+            tuned_axes.resize(n_dim);
+            std::iota(tuned_axes.begin(), tuned_axes.end(), 0);
+        }
+        else
+        {
+            for(auto& axis : tuned_axes)
+            {
+                int64_t s_dim = static_cast<int64_t>(n_dim);
+                if(axis >= s_dim or axis < -s_dim)
+                {
+                    MIGRAPHX_THROW("REDUCE_SUM: axis out of range");
+                }
+                if(axis < 0)
+                {
+                    axis += n_dim;
+                }
+            }
+        }
+        return tuned_axes;
+    }
+
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.has(1);
         auto s    = inputs.at(0);
         auto lens = s.lens();
-        for(auto axis : axes)
+        auto tuned_axes = tune_axes(lens.size());
+        for(auto axis : tuned_axes)
+        {
             lens[axis] = 1;
+        }
         return {s.type(), lens};
     }
 
+    template <class T>
+    void calc_sum(tensor_view<T>& input,
+                  shape& batch_shape,
+                  std::vector<int64_t>& tuned_axes,
+                  std::vector<std::size_t>& out_idx,
+                  tensor_view<T>& output) const
+    {
+        auto data_idx = out_idx;
+        T val         = T{0};
+        shape_for_each(batch_shape, [&](auto b_idx) {
+            for(auto axis : tuned_axes)
+            {
+                data_idx[axis] = b_idx[axis];
+            }
+            val += input(data_idx.begin(), data_idx.end());
+        });
+
+        output(out_idx.begin(), out_idx.end()) = val;
+    }
+
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
+        auto arg_lens                   = args.front().get_shape().lens();
+        std::vector<int64_t> tuned_axes = tune_axes(arg_lens.size());
+        std::vector<std::size_t> batch_lens(output_shape.lens().size(), 1);
+        for(auto axis : tuned_axes)
+        {
+            batch_lens[axis] = arg_lens[axis];
+        }
+        shape batch_shape{output_shape.type(), batch_lens};
         visit_all(result, args[0])([&](auto output, auto input) {
-            shape_for_each(input.get_shape(), [&](auto&& in_idx) {
-                auto out_idx = in_idx;
-                for(auto axis : axes)
-                    out_idx[axis] = 0;
-                output(out_idx.begin(), out_idx.end()) += input(in_idx.begin(), in_idx.end());
+            par_for(output_shape.elements(), [&](auto i) {
+                auto out_idx = output_shape.multi(i);
+                this->calc_sum(input, batch_shape, tuned_axes, out_idx, output);
             });
         });
...
@@ -48,6 +48,7 @@
 #include <migraphx/op/quant_convolution.hpp>
 #include <migraphx/op/quant_dot.hpp>
 #include <migraphx/op/reduce_sum.hpp>
+#include <migraphx/op/reduce_mean.hpp>
 #include <migraphx/op/relu.hpp>
 #include <migraphx/op/reshape.hpp>
 #include <migraphx/op/rnn.hpp>
...
@@ -96,7 +96,8 @@ struct onnx_parser
         add_mem_op("GRU", &onnx_parser::parse_gru);
         add_mem_op("LSTM", &onnx_parser::parse_lstm);
         add_mem_op("Pad", &onnx_parser::parse_pad);
-        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);
+        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
+        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
 
         // init the activation function map
         init_actv_func();
@@ -1288,20 +1289,21 @@ struct onnx_parser
         return {hidden_states, last_output, last_cell_output};
     }
 
-    instruction_ref parse_reduce_sum(const std::string&,
-                                     attribute_map attributes,
-                                     std::vector<instruction_ref> args)
+    template <class T>
+    instruction_ref parse_reduce_oper(const std::string&,
+                                      attribute_map attributes,
+                                      std::vector<instruction_ref> args)
     {
         std::size_t n_dim = args.front()->get_shape().lens().size();
         // default to reduce over all dimensions
-        std::vector<std::size_t> axes(n_dim);
+        std::vector<int64_t> axes(n_dim);
         std::iota(axes.begin(), axes.end(), 0);
         if(contains(attributes, "axes"))
         {
             axes.clear();
             auto&& attr_axes = attributes["axes"].ints();
-            axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
+            axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
         }
 
         int keep_dims = 1;
@@ -1312,13 +1314,12 @@
         if(keep_dims == 1)
         {
-            return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
+            return prog.add_instruction(T{axes}, std::move(args));
         }
         else
         {
-            auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
-            std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
-            return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
+            auto ins = prog.add_instruction(T{axes}, std::move(args));
+            return prog.add_instruction(op::squeeze{axes}, ins);
         }
     }
...
@@ -10,8 +10,8 @@ inline namespace MIGRAPHX_INLINE_NS {
 bool skip_propogate(instruction_ref ins)
 {
-    if(ins->name() == "@literal")
-        return true;
+    if(ins->name() == "contiguous")
+        return skip_propogate(ins->inputs().front());
     auto&& s = ins->get_shape();
     if(s.broadcasted() and not s.scalar())
         return true;
@@ -33,7 +33,7 @@ void propagate_constant::apply(program& p) const
                              ins->outputs().end());
     for(auto child : children)
     {
-        if(skip_propogate(child))
+        if(child->name() == "@literal" or skip_propogate(child))
         {
             self(child);
             continue;
...
@@ -40,6 +40,7 @@ add_library(migraphx_device
     device/pack.cpp
     device/clip.cpp
    device/reduce_sum.cpp
+    device/reduce_mean.cpp
 )
 set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
 rocm_clang_tidy_check(migraphx_device)
@@ -80,6 +81,7 @@ add_library(migraphx_gpu
     adjust_allocation.cpp
     clip.cpp
     reduce_sum.cpp
+    reduce_mean.cpp
 )
 set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
 rocm_clang_tidy_check(migraphx_gpu)
...
@@ -28,6 +28,16 @@ struct id
     }
 };
 
+struct mean
+{
+    size_t item_num = 1;
+
+    template <class T>
+    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
+    {
+        return static_cast<T>(x / item_num);
+    }
+};
+
 struct max
 {
     template <class T, class U>
...
#include <migraphx/gpu/device/reduce_mean.hpp>
#include <migraphx/gpu/device/reduce.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void reduce_mean(hipStream_t stream, const argument& result, const argument& arg)
{
    std::size_t item_num = arg.get_shape().elements() / result.get_shape().elements();
    reduce(stream, result, arg, sum{}, 0, id{}, mean{item_num});
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
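The divisor follows from the shapes alone: item_num = input elements / output elements, i.e. the number of values folded into each output. A host-side sketch of the sum-then-finalize composition (an assumption about how the device reduce applies its functors, not the actual kernel):

#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Stand-in for the device-side mean functor above.
struct mean_post_op
{
    std::size_t item_num = 1;
    template <class T>
    T operator()(T x) const
    {
        return static_cast<T>(x / item_num);
    }
};

int main()
{
    std::vector<float> vals = {1, 2, 3, 4}; // one reduction group
    float acc = std::accumulate(vals.begin(), vals.end(), 0.0f); // sum{} from init 0
    assert(mean_post_op{vals.size()}(acc) == 2.5f); // mean{} finalizes each sum
    return 0;
}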
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_MEAN_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_MEAN_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REDUCE_MEAN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REDUCE_MEAN_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/reflect.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_reduce_mean
{
    op::reduce_mean op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "gpu::reduce_mean"; }
    shape compute_shape(std::vector<shape> inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -53,6 +53,7 @@
 #include <migraphx/gpu/convert.hpp>
 #include <migraphx/gpu/clip.hpp>
 #include <migraphx/gpu/reduce_sum.hpp>
+#include <migraphx/gpu/reduce_mean.hpp>
 #include <utility>
 #include <functional>
 #include <algorithm>
@@ -115,6 +116,7 @@ struct miopen_apply
         add_extend_op<hip_convert, op::convert>("convert");
         add_extend_op<hip_clip, op::clip>("clip");
         add_extend_op<hip_reduce_sum, op::reduce_sum>("reduce_sum");
+        add_extend_op<hip_reduce_mean, op::reduce_mean>("reduce_mean");
 
         add_lrn_op();
         add_convolution_op();
...
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/reduce_mean.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_reduce_mean::compute_shape(std::vector<shape> inputs) const
{
    inputs.pop_back();
    return op.compute_shape(inputs);
}

argument
hip_reduce_mean::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    device::reduce_mean(ctx.get_stream().get(), args.back(), args.front());
    return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
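A note on the calling convention visible above: output_alias() returning shapes.size() - 1 marks the last argument as the op's preallocated output buffer, which is why compute_shape() pops the trailing shape before delegating to the reference op, and why compute() writes through args.back() and returns it.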
@@ -79,7 +79,8 @@ struct tf_parser
         return result;
     }
 
-    std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
+    std::vector<size_t>
+    parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
     {
         auto attrs = attributes.at(s).list().i();
         std::vector<size_t> axes;
@@ -87,14 +88,14 @@
         if(is_nhwc)
         {
             std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
-                return parse_axis(axis);
+                return parse_axis(axis, num_dims);
             });
         }
         return axes;
     }
 
     template <class T>
-    std::vector<T> parse_axes(std::vector<T> axes) const
+    std::vector<T> parse_axes(std::vector<T> axes, const size_t num_dims) const
     {
         if(is_nhwc)
         {
@@ -102,7 +103,7 @@ struct tf_parser
             std::transform(axes.begin(),
                            axes.end(),
                            std::back_inserter(new_axes),
-                           [&](size_t axis) { return parse_axis(axis); });
+                           [&](size_t axis) { return parse_axis(axis, num_dims); });
             return new_axes;
         }
         return axes;
@@ -117,17 +118,17 @@
         std::vector<T> new_data(prev_data.size());
         for(size_t i = 0; i < new_data.size(); i++)
         {
-            auto new_idx = parse_axis(i);
+            auto new_idx         = parse_axis(i, new_data.size());
             new_data.at(new_idx) = prev_data.at(i);
         }
         prev_data = new_data;
     }
 
     template <class T>
-    T parse_axis(const T& dim) const
+    T parse_axis(const T& dim, const size_t num_dims) const
     {
         T new_dim = dim;
-        if(is_nhwc)
+        if(is_nhwc and num_dims >= 4)
         {
             switch(dim)
             {
@@ -165,6 +166,7 @@ struct tf_parser
         add_mem_op("Const", &tf_parser::parse_constant);
         add_mem_op("Conv2D", &tf_parser::parse_conv);
         add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
+        add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
         add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
         add_mem_op("MatMul", &tf_parser::parse_matmul, false);
         add_mem_op("MaxPool", &tf_parser::parse_pooling);
@@ -490,6 +492,25 @@ struct tf_parser
         return prog.add_instruction(op, {l0, new_weights});
     }
 
+    instruction_ref
+    parse_expanddims(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    {
+        std::vector<size_t> input_dims = args[0]->get_shape().lens();
+        std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
+        size_t num_dims = input_dims.size();
+        int32_t dim     = args[1]->eval().at<int32_t>();
+        if(dim < 0)
+        {
+            new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1);
+        }
+        else
+        {
+            new_dims.insert(new_dims.begin() + dim, 1);
+        }
+        return prog.add_instruction(op::reshape{new_dims}, args[0]);
+    }
+
     instruction_ref
     parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
@@ -519,11 +540,12 @@
     instruction_ref
     parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
-        auto axes      = parse_axes(args[1]->eval().get<int32_t>().to_vector());
         bool keep_dims = attributes.at("keep_dims").b();
         std::vector<int32_t> hw_axes{2, 3};
         // check if conditions for GlobalAvgPool are met
         auto lens = args[0]->get_shape().lens();
+        auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector(), lens.size());
         if(axes == hw_axes and lens.size() == 4)
         {
             op::pooling op{"average"};
@@ -694,14 +716,15 @@ struct tf_parser
                   std::vector<instruction_ref> args)
     {
         op::squeeze op;
+        auto input_dims = args[0]->get_shape().lens();
         auto axes = attributes.at("squeeze_dims").list().i();
         copy(axes, std::back_inserter(op.axes));
-        auto args0_dims = args[0]->get_shape().lens();
         if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
         {
-            for(size_t i = 0; i < args0_dims.size(); i++)
+            for(size_t i = 0; i < input_dims.size(); i++)
             {
-                if(args0_dims.at(i) == 1)
+                if(input_dims.at(i) == 1)
                 {
                     op.axes.push_back(i);
                 }
...
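The negative-axis arithmetic in parse_expanddims is easy to misread; here is a standalone sketch of just the insertion logic (hypothetical values, mirroring TensorFlow's ExpandDims semantics where dim = -1 appends a new trailing axis):

// ExpandDims index arithmetic: for a rank-3 input and dim = -1, the new
// size-1 axis lands after the last axis, hence the (num_dims + dim + 1).
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    std::vector<int64_t> new_dims = {2, 3, 4}; // input shape
    std::size_t num_dims          = new_dims.size();
    int32_t dim                   = -1;
    if(dim < 0)
        new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1); // -> {2, 3, 4, 1}
    else
        new_dims.insert(new_dims.begin() + dim, 1);
    assert((new_dims == std::vector<int64_t>{2, 3, 4, 1}));
    return 0;
}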
@@ -1804,7 +1804,7 @@ TEST_CASE(clip_test)
     EXPECT(migraphx::verify_range(results_vector, gold));
 }
 
-TEST_CASE(reduce_sum_test0)
+TEST_CASE(reduce_sum_axis0)
 {
     migraphx::program p;
     migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
@@ -1819,7 +1819,7 @@ TEST_CASE(reduce_sum_test0)
     EXPECT(results_vector == gold);
 }
 
-TEST_CASE(reduce_sum_test1)
+TEST_CASE(reduce_sum_axis1)
 {
     migraphx::program p;
     migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
@@ -1834,7 +1834,7 @@ TEST_CASE(reduce_sum_test1)
     EXPECT(results_vector == gold);
 }
 
-TEST_CASE(reduce_sum_test2)
+TEST_CASE(reduce_sum_axis2)
 {
     migraphx::program p;
     migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
@@ -1849,7 +1849,7 @@ TEST_CASE(reduce_sum_test2)
     EXPECT(results_vector == gold);
 }
 
-TEST_CASE(reduce_sum_test02)
+TEST_CASE(reduce_sum_axis02)
 {
     migraphx::program p;
     migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
@@ -1864,7 +1864,7 @@ TEST_CASE(reduce_sum_test02)
     EXPECT(results_vector == gold);
 }
 
-TEST_CASE(reduce_sum_test12)
+TEST_CASE(reduce_sum_axis12)
 {
     migraphx::program p;
     migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
@@ -1879,4 +1879,79 @@ TEST_CASE(reduce_sum_test12)
     EXPECT(results_vector == gold);
 }
 
+TEST_CASE(reduce_mean_axis1)
+{
+    migraphx::program p;
+    migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
+    auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
+    auto l0    = p.add_literal(input);
+    p.add_instruction(migraphx::op::reduce_mean{{1}}, l0);
+    p.compile(migraphx::cpu::target{});
+    auto result = p.eval({});
+    std::vector<float> results_vector;
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{2, 3, 6, 7, 10, 11};
+    EXPECT(results_vector == gold);
+}
+
+TEST_CASE(reduce_mean_axis2)
+{
+    migraphx::program p;
+    migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
+    auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
+    auto l0    = p.add_literal(input);
+    p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
+    p.compile(migraphx::cpu::target{});
+    auto result = p.eval({});
+    std::vector<float> results_vector;
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{1.5f, 3.5f, 5.5f, 7.5f, 9.5f, 11.5f};
+    EXPECT(results_vector == gold);
+}
+
+TEST_CASE(reduce_mean_axis02)
+{
+    migraphx::program p;
+    migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
+    auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
+    auto l0    = p.add_literal(input);
+    p.add_instruction(migraphx::op::reduce_mean{{0, 2}}, l0);
+    p.compile(migraphx::cpu::target{});
+    auto result = p.eval({});
+    std::vector<float> results_vector;
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{5.5, 7.5};
+    EXPECT(results_vector == gold);
+}
+
+TEST_CASE(reduce_mean_axis12)
+{
+    migraphx::program p;
+    migraphx::shape s{migraphx::shape::float_type, {3, 2, 2}};
+    auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
+    auto l0    = p.add_literal(input);
+    p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
+    p.compile(migraphx::cpu::target{});
+    auto result = p.eval({});
+    std::vector<float> results_vector;
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{2.5f, 6.5f, 10.5f};
+    EXPECT(results_vector == gold);
+}
+
+TEST_CASE(reduce_mean_int)
+{
+    migraphx::program p;
+    migraphx::shape s{migraphx::shape::int32_type, {3, 2, 2}};
+    auto input = migraphx::literal{s, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}};
+    auto l0    = p.add_literal(input);
+    p.add_instruction(migraphx::op::reduce_mean{{1, 2}}, l0);
+    p.compile(migraphx::cpu::target{});
+    auto result = p.eval({});
+    std::vector<int> results_vector;
+    result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
+    std::vector<int> gold{2, 6, 10};
+    EXPECT(results_vector == gold);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
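The integer test above relies on truncation: the exact means over axes {1, 2} are 2.5, 6.5, and 10.5, but the division happens in int32 arithmetic, giving {2, 6, 10}. A one-line sanity check of that behavior (standalone, not part of the commit):

#include <cassert>

int main()
{
    int sum = 1 + 2 + 3 + 4; // first {2, 2} block of the reduce_mean_int input
    assert(sum / 4 == 2);    // exact mean is 2.5; integer division truncates
    return 0;
}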
@@ -3631,4 +3631,76 @@ struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub>
     };
 };
 
+struct test_reduce_sum : verify_program<test_reduce_sum>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::float_type, {3, 1026, 4, 3}};
+        auto x = p.add_parameter("x", s);
+        p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
+        return p;
+    };
+};
+
+struct test_reduce_sum_int : verify_program<test_reduce_sum_int>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::int32_type, {3, 4, 8, 8}};
+        auto x = p.add_parameter("x", s);
+        p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
+        return p;
+    };
+};
+
+struct test_reduce_sum_half : verify_program<test_reduce_sum_half>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::half_type, {3, 4, 8, 8}};
+        auto x = p.add_parameter("x", s);
+        p.add_instruction(migraphx::op::reduce_sum{{1}}, x);
+        return p;
+    };
+};
+
+struct test_reduce_mean : verify_program<test_reduce_mean>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::float_type, {3, 9, 4, 3}};
+        auto x = p.add_parameter("x", s);
+        p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
+        return p;
+    };
+};
+
+struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::int32_type, {3, 1024, 8, 8}};
+        auto x = p.add_parameter("x", s);
+        p.add_instruction(migraphx::op::reduce_mean{{1}}, x);
+        return p;
+    };
+};
+
+struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::half_type, {3, 1024, 8, 8}};
+        auto x = p.add_parameter("x", s);
+        p.add_instruction(migraphx::op::reduce_mean{{2}}, x);
+        return p;
+    };
+};
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -858,6 +858,27 @@ TEST_CASE(reducesum_test3)
     EXPECT(p == prog);
 }
 
+TEST_CASE(reducemean_test1)
+{
+    migraphx::program p;
+    auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
+    auto l1 = p.add_instruction(migraphx::op::reduce_mean{{2, 3}}, l0);
+    p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
+    auto prog = migraphx::parse_onnx("reducemean_test1.onnx");
+    EXPECT(p == prog);
+}
+
+TEST_CASE(reducemean_test2)
+{
+    migraphx::program p;
+    auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
+    p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
+    auto prog = migraphx::parse_onnx("reducemean_test2.onnx");
+    EXPECT(p == prog);
+}
+
 TEST_CASE(clip_test)
 {
     migraphx::program p;
...
(binary ONNX protobuf, not rendered: the "test_reducemean" model used by the parser tests above, a single ReduceMean node with "axes" and "keepdims" attributes mapping input x to output y)
@@ -410,37 +410,111 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
 TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
 
-template <class T>
-void test_argop_var()
+TEST_CASE(test_argmax)
 {
     {
         migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
-        expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0}, input);
+        expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
+                     migraphx::op::argmax{0},
+                     input);
+    }
+    {
+        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
+        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
+                     migraphx::op::argmax{1},
+                     input);
+    }
+    {
+        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
+        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
+                     migraphx::op::argmax{2},
+                     input);
+    }
+    {
+        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
+        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
+                     migraphx::op::argmax{3},
+                     input);
+    }
+    {
+        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
+        throws_shape(migraphx::op::argmax{4}, input);
     }
+}
 
+TEST_CASE(test_argmin)
+{
     {
         migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
-        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1}, input);
+        expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
+                     migraphx::op::argmin{0},
+                     input);
     }
     {
         migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
-        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2}, input);
+        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
+                     migraphx::op::argmin{1},
+                     input);
     }
     {
         migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
-        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3}, input);
+        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
+                     migraphx::op::argmin{2},
+                     input);
     }
+    {
+        migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
+        expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
+                     migraphx::op::argmin{3},
+                     input);
+    }
+    {
+        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
+        throws_shape(migraphx::op::argmin{4}, input);
+    }
+}
+
+template <class T>
+void test_reduce_ops()
+{
+    {
+        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
+        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
+    }
+    {
+        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
+        expect_shape(
+            migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
+    }
+    {
+        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
+        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
+    }
+    {
+        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
+        expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
+    }
+    {
+        migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
+        expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
+    }
     {
         migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
-        throws_shape(T{4}, input);
+        throws_shape(T{{4}}, input);
     }
 }
 
-TEST_CASE(argmax) { test_argop_var<migraphx::op::argmax>(); }
-TEST_CASE(argmin) { test_argop_var<migraphx::op::argmin>(); }
+TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
+TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
 
 // 2 inputs arguments
 TEST_CASE(matmul)
...