Commit 0246f32b authored by Shucai Xiao
Browse files

Merge branch 'reduce_mean' into test_bert

parents 4e49ad18 0525939c
// Diff hunk: op::reduce_mean — the axes member changes from signed to unsigned.
......@@ -14,7 +14,7 @@ namespace op {
struct reduce_mean
{
// removed: signed axes could represent negative (Python-style) axis values
std::vector<int64_t> axes{};
// added: unsigned axes — NOTE(review): negative ONNX axes can no longer be
// represented at this level; confirm callers normalize them beforehand
std::vector<std::size_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -31,7 +31,7 @@ struct reduce_mean
auto lens = s.lens();
for(auto axis : axes)
{
// removed: lower-bound check was needed while axis was signed
if(axis < 0 or axis >= lens.size())
// added: axis is unsigned now, so only the upper bound needs checking
if(axis >= lens.size())
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
lens[axis] = 1;
}
......
// Diff hunk: op::reduce_sum — identical change to reduce_mean's: unsigned axes.
......@@ -14,7 +14,7 @@ namespace op {
struct reduce_sum
{
// removed: signed axes could represent negative (Python-style) axis values
std::vector<int64_t> axes{};
// added: unsigned axes — NOTE(review): negative ONNX axes can no longer be
// represented at this level; confirm callers normalize them beforehand
std::vector<std::size_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -31,7 +31,7 @@ struct reduce_sum
auto lens = s.lens();
for(auto axis : axes)
{
// removed: lower-bound check was needed while axis was signed
if(axis < 0 or axis >= lens.size())
// added: axis is unsigned now, so only the upper bound needs checking
if(axis >= lens.size())
MIGRAPHX_THROW("REDUCE_SUM: axis out of range");
lens[axis] = 1;
}
......
// Diff hunk: onnx_parser constructor — ReduceSum/ReduceMean now register the
// shared template parser instead of two separate member functions.
......@@ -100,8 +100,8 @@ struct onnx_parser
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
// removed: one dedicated parse_* member per reduce op
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_mean);
// added: single template instantiated per operation type
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
// init the activation function map
init_actv_func();
// Diff hunk: onnx_parser — parse_reduce_sum and parse_reduce_mean (which were
// duplicates except for the op type) are merged into one template,
// parse_reduce_oper<T>.
......@@ -1373,53 +1373,21 @@ struct onnx_parser
return {hidden_states, last_output, last_cell_output};
}
// removed: op-specific parser, duplicated verbatim for ReduceMean below
instruction_ref parse_reduce_sum(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
{
return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
instruction_ref parse_reduce_mean(const std::string&,
// added: generic parser; T is the reduce operation (op::reduce_sum or
// op::reduce_mean), selected at registration time
template <class T>
instruction_ref parse_reduce_oper(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<int64_t> axes(n_dim);
// added: axes are std::size_t to match the ops' new member type
std::vector<std::size_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
// added: NOTE(review): a negative ONNX axis would wrap to a huge
// unsigned value here and only be caught by the op's range check —
// confirm negative axes are normalized or rejected upstream
axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
......@@ -1430,12 +1398,13 @@ struct onnx_parser
if(keep_dims == 1)
{
return prog.add_instruction(op::reduce_mean{axes}, std::move(args));
return prog.add_instruction(T{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(op::reduce_mean{axes}, std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
auto ins = prog.add_instruction(T{axes}, std::move(args));
// added: squeeze still takes signed axes, so convert explicitly
std::vector<int64_t> sq_axes(axes.begin(), axes.end());
return prog.add_instruction(op::squeeze{sq_axes}, ins);
}
}
......
// Diff hunk: device-side functor rename — 'scale' becomes 'mean'.
......@@ -28,7 +28,7 @@ struct id
}
};
// removed: old name for the post-reduction functor
struct scale
// added: renamed to describe its role in reduce_mean (divides the reduced
// sum by the element count); item_num holds that count
struct mean
{
size_t item_num = 1;
template <class T>
......
// Diff hunk: device reduce_mean — updated to use the renamed functor.
......@@ -9,7 +9,7 @@ namespace device {
void reduce_mean(hipStream_t stream, const argument& result, const argument& arg)
{
// item_num = how many input elements collapse into each output element
std::size_t item_num = arg.get_shape().elements() / result.get_shape().elements();
// removed: referenced the functor under its old name
reduce(stream, result, arg, sum{}, 0, id{}, scale{item_num});
// added: same computation, functor now called 'mean'
reduce(stream, result, arg, sum{}, 0, id{}, mean{item_num});
}
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment