Commit 079ccd40 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments

parent 2b8daf9c
......@@ -14,7 +14,7 @@ namespace op {
struct reduce_mean
{
std::vector<int64_t> axes{};
std::vector<std::size_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -31,7 +31,7 @@ struct reduce_mean
auto lens = s.lens();
for(auto axis : axes)
{
if(axis < 0 or axis >= lens.size())
if(axis >= lens.size())
MIGRAPHX_THROW("REDUCE_MEAN: axis out of range");
lens[axis] = 1;
}
......
......@@ -14,7 +14,7 @@ namespace op {
struct reduce_sum
{
std::vector<int64_t> axes{};
std::vector<std::size_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -31,7 +31,7 @@ struct reduce_sum
auto lens = s.lens();
for(auto axis : axes)
{
if(axis < 0 or axis >= lens.size())
if(axis >= lens.size())
MIGRAPHX_THROW("REDUCE_SUM: axis out of range");
lens[axis] = 1;
}
......
......@@ -95,8 +95,8 @@ struct onnx_parser
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_mean);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
// init the activation function map
init_actv_func();
......@@ -1288,20 +1288,21 @@ struct onnx_parser
return {hidden_states, last_output, last_cell_output};
}
instruction_ref parse_reduce_sum(const std::string&,
template<class T>
instruction_ref parse_reduce_oper(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<int64_t> axes(n_dim);
std::vector<std::size_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
......@@ -1312,45 +1313,13 @@ struct onnx_parser
if(keep_dims == 1)
{
return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
return prog.add_instruction(T{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
instruction_ref parse_reduce_mean(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
{
return prog.add_instruction(op::reduce_mean{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(op::reduce_mean{axes}, std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
auto ins = prog.add_instruction(T{axes}, std::move(args));
std::vector<int64_t> sq_axes(axes.begin(), axes.end());
return prog.add_instruction(op::squeeze{sq_axes}, ins);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment