"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "bba1580b518d8a9f5d88aafb2140b458e4d79d24"
Commit 0246f32b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'reduce_mean' into test_bert

parents 4e49ad18 0525939c
...@@ -14,7 +14,7 @@ namespace op { ...@@ -14,7 +14,7 @@ namespace op {
struct reduce_mean struct reduce_mean
{ {
std::vector<int64_t> axes{}; std::vector<std::size_t> axes{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -31,7 +31,7 @@ struct reduce_mean ...@@ -31,7 +31,7 @@ struct reduce_mean
auto lens = s.lens(); auto lens = s.lens();
for(auto axis : axes) for(auto axis : axes)
{ {
if(axis < 0 or axis >= lens.size()) if(axis >= lens.size())
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range"); MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
lens[axis] = 1; lens[axis] = 1;
} }
......
...@@ -14,7 +14,7 @@ namespace op { ...@@ -14,7 +14,7 @@ namespace op {
struct reduce_sum struct reduce_sum
{ {
std::vector<int64_t> axes{}; std::vector<std::size_t> axes{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -31,7 +31,7 @@ struct reduce_sum ...@@ -31,7 +31,7 @@ struct reduce_sum
auto lens = s.lens(); auto lens = s.lens();
for(auto axis : axes) for(auto axis : axes)
{ {
if(axis < 0 or axis >= lens.size()) if(axis >= lens.size())
MIGRAPHX_THROW("REDUCE_SUM: axis out of range"); MIGRAPHX_THROW("REDUCE_SUM: axis out of range");
lens[axis] = 1; lens[axis] = 1;
} }
......
...@@ -100,8 +100,8 @@ struct onnx_parser ...@@ -100,8 +100,8 @@ struct onnx_parser
add_mem_op("GRU", &onnx_parser::parse_gru); add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm); add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad); add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum); add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_mean); add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
// init the activation function map // init the activation function map
init_actv_func(); init_actv_func();
...@@ -1373,53 +1373,21 @@ struct onnx_parser ...@@ -1373,53 +1373,21 @@ struct onnx_parser
return {hidden_states, last_output, last_cell_output}; return {hidden_states, last_output, last_cell_output};
} }
instruction_ref parse_reduce_sum(const std::string&, template <class T>
attribute_map attributes, instruction_ref parse_reduce_oper(const std::string&,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
{
return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
instruction_ref parse_reduce_mean(const std::string&,
attribute_map attributes, attribute_map attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
std::size_t n_dim = args.front()->get_shape().lens().size(); std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions // default to reduce over all dimensions
std::vector<int64_t> axes(n_dim); std::vector<std::size_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes")) if(contains(attributes, "axes"))
{ {
axes.clear(); axes.clear();
auto&& attr_axes = attributes["axes"].ints(); auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end()); axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
} }
int keep_dims = 1; int keep_dims = 1;
...@@ -1430,12 +1398,13 @@ struct onnx_parser ...@@ -1430,12 +1398,13 @@ struct onnx_parser
if(keep_dims == 1) if(keep_dims == 1)
{ {
return prog.add_instruction(op::reduce_mean{axes}, std::move(args)); return prog.add_instruction(T{axes}, std::move(args));
} }
else else
{ {
auto ins = prog.add_instruction(op::reduce_mean{axes}, std::move(args)); auto ins = prog.add_instruction(T{axes}, std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins); std::vector<int64_t> sq_axes(axes.begin(), axes.end());
return prog.add_instruction(op::squeeze{sq_axes}, ins);
} }
} }
......
...@@ -28,7 +28,7 @@ struct id ...@@ -28,7 +28,7 @@ struct id
} }
}; };
struct scale struct mean
{ {
size_t item_num = 1; size_t item_num = 1;
template <class T> template <class T>
......
...@@ -9,7 +9,7 @@ namespace device { ...@@ -9,7 +9,7 @@ namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg) void reduce_mean(hipStream_t stream, const argument& result, const argument& arg)
{ {
std::size_t item_num = arg.get_shape().elements() / result.get_shape().elements(); std::size_t item_num = arg.get_shape().elements() / result.get_shape().elements();
reduce(stream, result, arg, sum{}, 0, id{}, scale{item_num}); reduce(stream, result, arg, sum{}, 0, id{}, mean{item_num});
} }
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment