Unverified Commit c4e53a33 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Add parsing split operator (#460)



* change parse operator function signature

* clang format

* add parsing the split operator

* clang format

* add parsing split operator

* make squeeze/unsqueeze inputs to standard shape

* add unit tests for the split operator

* clang format

* fix cppcheck error

* clang format

* update tests for multiple program outputs

* clang format

* fix cppcheck error

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: Paul Fultz II <pfultz2@yahoo.com>
parent 330224aa
...@@ -24,9 +24,14 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -24,9 +24,14 @@ inline namespace MIGRAPHX_INLINE_NS {
struct onnx_parser struct onnx_parser
{ {
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>; using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
struct node_info
{
attribute_map attributes{};
std::size_t num_outputs = 1;
};
using node_map = std::unordered_map<std::string, onnx::NodeProto>; using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func = using op_func =
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>; std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
...@@ -121,6 +126,7 @@ struct onnx_parser ...@@ -121,6 +126,7 @@ struct onnx_parser
add_mem_op("Shape", &onnx_parser::parse_shape); add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("Slice", &onnx_parser::parse_slice); add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>); add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
add_mem_op("Split", &onnx_parser::parse_split);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze); add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Transpose", &onnx_parser::parse_transpose); add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze); add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
...@@ -166,15 +172,15 @@ struct onnx_parser ...@@ -166,15 +172,15 @@ struct onnx_parser
template <class T> template <class T>
void add_binary_op(std::string name, T x) void add_binary_op(std::string name, T x)
{ {
add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) { add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
if(args.size() != 2) if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands"); MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast") and contains(attributes, "axis")) if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
{ {
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>(); uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0) if(broadcasted != 0)
{ {
uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>(); uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
args[1]); args[1]);
return prog.add_instruction(x, args[0], l); return prog.add_instruction(x, args[0], l);
...@@ -266,7 +272,7 @@ struct onnx_parser ...@@ -266,7 +272,7 @@ struct onnx_parser
template <class T> template <class T>
void add_generic_op(std::string name, T x) void add_generic_op(std::string name, T x)
{ {
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) { add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
}); });
} }
...@@ -274,7 +280,7 @@ struct onnx_parser ...@@ -274,7 +280,7 @@ struct onnx_parser
template <class T> template <class T>
void add_variadic_op(std::string name, T x) void add_variadic_op(std::string name, T x)
{ {
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) { add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
return std::accumulate(std::next(args.begin()), return std::accumulate(std::next(args.begin()),
args.end(), args.end(),
args.front(), args.front(),
...@@ -321,51 +327,48 @@ struct onnx_parser ...@@ -321,51 +327,48 @@ struct onnx_parser
} }
} }
instruction_ref parse_clip(const std::string&, instruction_ref
const attribute_map& attributes, parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
op::clip op; op::clip op;
if(contains(attributes, "max")) if(contains(info.attributes, "max"))
{ {
op.max_val = parse_value(attributes.at("max")).at<float>(); op.max_val = parse_value(info.attributes.at("max")).at<float>();
} }
if(contains(attributes, "min")) if(contains(info.attributes, "min"))
{ {
op.min_val = parse_value(attributes.at("min")).at<float>(); op.min_val = parse_value(info.attributes.at("min")).at<float>();
} }
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
template <class Op> template <class Op>
instruction_ref parse_softmax(const std::string&, instruction_ref
const attribute_map& attributes, parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
int64_t axis = 1; int64_t axis = 1;
if(contains(attributes, "axis")) if(contains(info.attributes, "axis"))
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(info.attributes.at("axis")).at<int>();
} }
return prog.add_instruction(Op{axis}, std::move(args)); return prog.add_instruction(Op{axis}, std::move(args));
} }
template <class Op> template <class Op>
instruction_ref parse_arg_op(const std::string&, instruction_ref
const attribute_map& attributes, parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
int64_t axis = 0; int64_t axis = 0;
if(contains(attributes, "axis")) if(contains(info.attributes, "axis"))
{ {
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>()); axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
} }
int keep_dims = 1; int keep_dims = 1;
if(contains(attributes, "keepdims")) if(contains(info.attributes, "keepdims"))
{ {
keep_dims = parse_value(attributes.at("keepdims")).at<int>(); keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
} }
if(keep_dims == 0) if(keep_dims == 0)
...@@ -381,16 +384,16 @@ struct onnx_parser ...@@ -381,16 +384,16 @@ struct onnx_parser
template <class Op> template <class Op>
instruction_ref process_auto_pad_attribute(instruction_ref ins, instruction_ref process_auto_pad_attribute(instruction_ref ins,
attribute_map& attributes, node_info info,
Op& op, Op& op,
const std::vector<std::size_t>& in_lens) const std::vector<std::size_t>& in_lens)
{ {
if(!contains(attributes, "auto_pad")) if(!contains(info.attributes, "auto_pad"))
{ {
return ins; return ins;
} }
auto auto_pad = attributes["auto_pad"].s(); auto auto_pad = info.attributes["auto_pad"].s();
if(auto_pad.find("SAME") != std::string::npos) if(auto_pad.find("SAME") != std::string::npos)
{ {
// calculate the padding // calculate the padding
...@@ -440,41 +443,41 @@ struct onnx_parser ...@@ -440,41 +443,41 @@ struct onnx_parser
template <class Op> template <class Op>
instruction_ref instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
Op op; Op op;
auto l0 = args[0]; auto l0 = args[0];
auto weights = args[1]; auto weights = args[1];
if(contains(attributes, "pads")) if(contains(info.attributes, "pads"))
{ {
if(contains(attributes, "auto_pad")) if(contains(info.attributes, "auto_pad"))
{ {
auto s = attributes["auto_pad"].s(); auto s = info.attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET") if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{ {
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously"); MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
} }
} }
std::vector<std::int64_t> padding; std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding)); copy(info.attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4) if(padding.size() != 4)
{ {
MIGRAPHX_THROW("padding should have 4 values"); MIGRAPHX_THROW("padding should have 4 values");
} }
check_asym_padding(l0, padding, op); check_asym_padding(l0, padding, op);
} }
if(contains(attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
copy(attributes["strides"].ints(), op.stride.begin()); copy(info.attributes["strides"].ints(), op.stride.begin());
} }
if(contains(attributes, "dilations")) if(contains(info.attributes, "dilations"))
{ {
copy(attributes["dilations"].ints(), op.dilation.begin()); copy(info.attributes["dilations"].ints(), op.dilation.begin());
} }
if(contains(attributes, "auto_pad")) if(contains(info.attributes, "auto_pad"))
{ {
auto s = attributes["auto_pad"].s(); auto s = info.attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET") if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{ {
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously"); MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
} }
...@@ -496,34 +499,33 @@ struct onnx_parser ...@@ -496,34 +499,33 @@ struct onnx_parser
check_asym_padding(l0, padding, op); check_asym_padding(l0, padding, op);
} }
} }
if(contains(attributes, "group")) if(contains(info.attributes, "group"))
{ {
op.group = parse_value(attributes.at("group")).at<int>(); op.group = parse_value(info.attributes.at("group")).at<int>();
} }
auto l1 = prog.add_instruction(op, l0, args[1]); auto l1 = prog.add_instruction(op, l0, args[1]);
return add_bias(args, l1, 1); return add_bias(args, l1, 1);
} }
instruction_ref parse_conv_transpose(const std::string&, instruction_ref
attribute_map attributes, parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
op::deconvolution op; op::deconvolution op;
auto l0 = args[0]; auto l0 = args[0];
std::vector<std::int64_t> padding; std::vector<std::int64_t> padding;
bool asymm_padding = false; bool asymm_padding = false;
if(contains(attributes, "pads")) if(contains(info.attributes, "pads"))
{ {
if(contains(attributes, "auto_pad")) if(contains(info.attributes, "auto_pad"))
{ {
auto s = attributes["auto_pad"].s(); auto s = info.attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET") if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{ {
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously"); MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
} }
} }
copy(attributes["pads"].ints(), std::back_inserter(padding)); copy(info.attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4) if(padding.size() != 4)
{ {
MIGRAPHX_THROW("padding should have 4 values"); MIGRAPHX_THROW("padding should have 4 values");
...@@ -538,18 +540,18 @@ struct onnx_parser ...@@ -538,18 +540,18 @@ struct onnx_parser
op.padding[1] = padding[1]; op.padding[1] = padding[1];
} }
} }
if(contains(attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
copy(attributes["strides"].ints(), op.stride.begin()); copy(info.attributes["strides"].ints(), op.stride.begin());
} }
if(contains(attributes, "dilations")) if(contains(info.attributes, "dilations"))
{ {
copy(attributes["dilations"].ints(), op.dilation.begin()); copy(info.attributes["dilations"].ints(), op.dilation.begin());
} }
if(contains(attributes, "auto_pad")) if(contains(info.attributes, "auto_pad"))
{ {
auto s = attributes["auto_pad"].s(); auto s = info.attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET") if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{ {
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously"); MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
} }
...@@ -560,9 +562,9 @@ struct onnx_parser ...@@ -560,9 +562,9 @@ struct onnx_parser
} }
} }
if(contains(attributes, "group")) if(contains(info.attributes, "group"))
{ {
op.group = parse_value(attributes.at("group")).at<int>(); op.group = parse_value(info.attributes.at("group")).at<int>();
} }
auto l1 = prog.add_instruction(op, l0, args[1]); auto l1 = prog.add_instruction(op, l0, args[1]);
...@@ -579,18 +581,18 @@ struct onnx_parser ...@@ -579,18 +581,18 @@ struct onnx_parser
l1 = prog.add_instruction(slice_op, l1); l1 = prog.add_instruction(slice_op, l1);
} }
if(contains(attributes, "output_padding")) if(contains(info.attributes, "output_padding"))
{ {
std::vector<int64_t> output_padding; std::vector<int64_t> output_padding;
copy(attributes["output_padding"].ints(), std::back_inserter(output_padding)); copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]}; output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
l1 = prog.add_instruction(op::pad{output_padding}, l1); l1 = prog.add_instruction(op::pad{output_padding}, l1);
} }
if(contains(attributes, "output_shape")) if(contains(info.attributes, "output_shape"))
{ {
std::vector<int64_t> output_shape; std::vector<int64_t> output_shape;
copy(attributes["output_shape"].ints(), std::back_inserter(output_shape)); copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
dims = to_int64_vector(l1->get_shape().lens()); dims = to_int64_vector(l1->get_shape().lens());
curr_shape = {dims[2], dims[3]}; curr_shape = {dims[2], dims[3]};
if(curr_shape != output_shape) if(curr_shape != output_shape)
...@@ -610,9 +612,8 @@ struct onnx_parser ...@@ -610,9 +612,8 @@ struct onnx_parser
return add_bias(args, l1, 1); return add_bias(args, l1, 1);
} }
instruction_ref parse_pooling(const std::string& name, instruction_ref
attribute_map attributes, parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"}; op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
auto l0 = args[0]; auto l0 = args[0];
...@@ -622,11 +623,11 @@ struct onnx_parser ...@@ -622,11 +623,11 @@ struct onnx_parser
op.lengths = {lens[2], lens[3]}; op.lengths = {lens[2], lens[3]};
} }
if(contains(attributes, "pads")) if(contains(info.attributes, "pads"))
{ {
if(contains(attributes, "auto_pad")) if(contains(info.attributes, "auto_pad"))
{ {
auto s = attributes["auto_pad"].s(); auto s = info.attributes["auto_pad"].s();
if(to_upper(s) != "NOTSET") if(to_upper(s) != "NOTSET")
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW(
...@@ -635,7 +636,7 @@ struct onnx_parser ...@@ -635,7 +636,7 @@ struct onnx_parser
} }
std::vector<std::int64_t> padding; std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding)); copy(info.attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4) if(padding.size() != 4)
{ {
MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values"); MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
...@@ -646,31 +647,31 @@ struct onnx_parser ...@@ -646,31 +647,31 @@ struct onnx_parser
check_asym_padding(l0, padding, op, pad_val); check_asym_padding(l0, padding, op, pad_val);
} }
if(contains(attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
copy(attributes["strides"].ints(), op.stride.begin()); copy(info.attributes["strides"].ints(), op.stride.begin());
} }
if(contains(attributes, "kernel_shape")) if(contains(info.attributes, "kernel_shape"))
{ {
copy(attributes["kernel_shape"].ints(), op.lengths.begin()); copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
} }
if(contains(attributes, "auto_pad")) if(contains(info.attributes, "auto_pad"))
{ {
auto in_lens = args[0]->get_shape().lens(); auto in_lens = args[0]->get_shape().lens();
l0 = process_auto_pad_attribute(l0, attributes, op, in_lens); l0 = process_auto_pad_attribute(l0, info, op, in_lens);
} }
return prog.add_instruction(op, l0); return prog.add_instruction(op, l0);
} }
instruction_ref instruction_ref
parse_reshape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
op::reshape op; op::reshape op;
if(args.size() == 1) if(args.size() == 1)
{ {
literal s = parse_value(attributes.at("shape")); literal s = parse_value(info.attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
} }
if(args.size() == 2) if(args.size() == 2)
...@@ -684,55 +685,55 @@ struct onnx_parser ...@@ -684,55 +685,55 @@ struct onnx_parser
} }
instruction_ref instruction_ref
parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
int64_t axis = 1; int64_t axis = 1;
if(contains(attributes, "axis")) if(contains(info.attributes, "axis"))
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(info.attributes.at("axis")).at<int>();
} }
return prog.add_instruction(op::flatten{axis}, args[0]); return prog.add_instruction(op::flatten{axis}, args[0]);
} }
instruction_ref instruction_ref
parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
op::squeeze op; op::squeeze op;
literal s = parse_value(attributes.at("axes")); literal s = parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, make_contiguous(args[0]));
} }
instruction_ref instruction_ref
parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
op::unsqueeze op; op::unsqueeze op;
literal s = parse_value(attributes.at("axes")); literal s = parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, make_contiguous(args[0]));
} }
instruction_ref instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
// change to handle axis to be negative values // change to handle axis to be negative values
if(!contains(attributes, "axis")) if(!contains(info.attributes, "axis"))
{ {
MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!"); MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
} }
int axis = parse_value(attributes.at("axis")).at<int>(); int axis = parse_value(info.attributes.at("axis")).at<int>();
op::concat op{axis}; op::concat op{axis};
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
instruction_ref instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
int axis = 0; int axis = 0;
if(contains(attributes, "axis")) if(contains(info.attributes, "axis"))
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(info.attributes.at("axis")).at<int>();
} }
op::gather op{axis}; op::gather op{axis};
...@@ -740,14 +741,14 @@ struct onnx_parser ...@@ -740,14 +741,14 @@ struct onnx_parser
} }
instruction_ref instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
op::slice op; op::slice op;
std::vector<size_t> dims = args[0]->get_shape().lens(); std::vector<size_t> dims = args[0]->get_shape().lens();
size_t num_dims = dims.size(); size_t num_dims = dims.size();
if(contains(attributes, "axes")) if(contains(info.attributes, "axes"))
{ {
literal s = parse_value(attributes.at("axes")); literal s = parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
} }
else else
...@@ -756,30 +757,29 @@ struct onnx_parser ...@@ -756,30 +757,29 @@ struct onnx_parser
std::iota(op.axes.begin(), op.axes.end(), 0); std::iota(op.axes.begin(), op.axes.end(), 0);
} }
if(contains(attributes, "ends")) if(contains(info.attributes, "ends"))
{ {
op.ends = get_indices(attributes.at("ends")); op.ends = get_indices(info.attributes.at("ends"));
} }
if(contains(attributes, "starts")) if(contains(info.attributes, "starts"))
{ {
literal s = parse_value(attributes.at("starts")); literal s = parse_value(info.attributes.at("starts"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
} }
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, args[0]);
} }
instruction_ref parse_constant(const std::string&, instruction_ref
attribute_map attributes, parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
const std::vector<instruction_ref>&)
{ {
literal v = parse_value(attributes.at("value")); literal v = parse_value(info.attributes.at("value"));
// return empty literal // return empty literal
if(v.get_shape().elements() == 0) if(v.get_shape().elements() == 0)
{ {
return prog.add_literal(literal{}); return prog.add_literal(literal{});
} }
auto dim_size = attributes.at("value").t().dims_size(); auto dim_size = info.attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar // if dim_size is 0, it is a scalar
if(dim_size == 0) if(dim_size == 0)
{ {
...@@ -791,27 +791,27 @@ struct onnx_parser ...@@ -791,27 +791,27 @@ struct onnx_parser
} }
instruction_ref instruction_ref
parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
bool transa = false; bool transa = false;
bool transb = false; bool transb = false;
if(contains(attributes, "alpha")) if(contains(info.attributes, "alpha"))
{ {
alpha = parse_value(attributes.at("alpha")).at<float>(); alpha = parse_value(info.attributes.at("alpha")).at<float>();
} }
if(contains(attributes, "beta")) if(contains(info.attributes, "beta"))
{ {
beta = parse_value(attributes.at("beta")).at<float>(); beta = parse_value(info.attributes.at("beta")).at<float>();
} }
if(contains(attributes, "transA")) if(contains(info.attributes, "transA"))
{ {
transa = parse_value(attributes.at("transA")).at<bool>(); transa = parse_value(info.attributes.at("transA")).at<bool>();
} }
if(contains(attributes, "transB")) if(contains(info.attributes, "transB"))
{ {
transb = parse_value(attributes.at("transB")).at<bool>(); transb = parse_value(info.attributes.at("transB")).at<bool>();
} }
std::vector<int64_t> perm(args[0]->get_shape().lens().size()); std::vector<int64_t> perm(args[0]->get_shape().lens().size());
...@@ -842,7 +842,7 @@ struct onnx_parser ...@@ -842,7 +842,7 @@ struct onnx_parser
template <class Op> template <class Op>
instruction_ref instruction_ref
parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
{ {
auto l0 = args[0]; auto l0 = args[0];
auto l1 = args[1]; auto l1 = args[1];
...@@ -905,22 +905,22 @@ struct onnx_parser ...@@ -905,22 +905,22 @@ struct onnx_parser
} }
instruction_ref instruction_ref
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
float epsilon = 1e-5f; float epsilon = 1e-5f;
float momentum = 0.9f; float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial; op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
if(contains(attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
{ {
epsilon = parse_value(attributes.at("epsilon")).at<float>(); epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
} }
if(contains(attributes, "momentum")) if(contains(info.attributes, "momentum"))
{ {
momentum = parse_value(attributes.at("momentum")).at<float>(); momentum = parse_value(info.attributes.at("momentum")).at<float>();
} }
if(contains(attributes, "spatial")) if(contains(info.attributes, "spatial"))
{ {
bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0) bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
? op::batch_norm_inference::spatial ? op::batch_norm_inference::spatial
: op::batch_norm_inference::per_activation; : op::batch_norm_inference::per_activation;
} }
...@@ -928,18 +928,17 @@ struct onnx_parser ...@@ -928,18 +928,17 @@ struct onnx_parser
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
instruction_ref parse_instancenorm(const std::string&, instruction_ref
attribute_map attributes, parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({H, W}, x) // mean = reduce_mean({H, W}, x)
// variance = reduce_mean({H, W}, (x - mean)^2) // variance = reduce_mean({H, W}, (x - mean)^2)
float epsilon = 1e-5f; float epsilon = 1e-5f;
if(contains(attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
{ {
epsilon = parse_value(attributes.at("epsilon")).at<float>(); epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
} }
auto x = args[0]; auto x = args[0];
auto scale = args[1]; auto scale = args[1];
...@@ -964,64 +963,60 @@ struct onnx_parser ...@@ -964,64 +963,60 @@ struct onnx_parser
return prog.add_instruction(op::add{}, l5, bias_bcast); return prog.add_instruction(op::add{}, l5, bias_bcast);
} }
instruction_ref parse_leaky_relu(const std::string&, instruction_ref
attribute_map attributes, parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
float alpha = 0.01; // default alpha val for leaky relu float alpha = 0.01; // default alpha val for leaky relu
if(contains(attributes, "alpha")) if(contains(info.attributes, "alpha"))
{ {
alpha = parse_value(attributes.at("alpha")).at<float>(); alpha = parse_value(info.attributes.at("alpha")).at<float>();
} }
op::leaky_relu op{alpha}; op::leaky_relu op{alpha};
return prog.add_instruction(op, args.front()); return prog.add_instruction(op, args.front());
} }
instruction_ref instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
float alpha = 1.0; // default alpha val for elu float alpha = 1.0; // default alpha val for elu
if(contains(attributes, "alpha")) if(contains(info.attributes, "alpha"))
{ {
alpha = parse_value(attributes.at("alpha")).at<float>(); alpha = parse_value(info.attributes.at("alpha")).at<float>();
} }
op::elu op{alpha}; op::elu op{alpha};
return prog.add_instruction(op, args.front()); return prog.add_instruction(op, args.front());
} }
instruction_ref instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
float alpha = 0.0001; float alpha = 0.0001;
float beta = 0.75; float beta = 0.75;
float bias = 1.0; float bias = 1.0;
int size = 1; int size = 1;
if(contains(attributes, "alpha")) if(contains(info.attributes, "alpha"))
alpha = parse_value(attributes.at("alpha")).at<float>(); alpha = parse_value(info.attributes.at("alpha")).at<float>();
if(contains(attributes, "beta")) if(contains(info.attributes, "beta"))
beta = parse_value(attributes.at("beta")).at<float>(); beta = parse_value(info.attributes.at("beta")).at<float>();
if(contains(attributes, "bias")) if(contains(info.attributes, "bias"))
bias = parse_value(attributes.at("bias")).at<float>(); bias = parse_value(info.attributes.at("bias")).at<float>();
if(contains(attributes, "size")) if(contains(info.attributes, "size"))
size = parse_value(attributes.at("size")).at<int>(); size = parse_value(info.attributes.at("size")).at<int>();
op::lrn op{alpha, beta, bias, size}; op::lrn op{alpha, beta, bias, size};
return prog.add_instruction(op, args.front()); return prog.add_instruction(op, args.front());
} }
instruction_ref parse_imagescaler(const std::string&, instruction_ref
attribute_map attributes, parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
float scale = 1.0; float scale = 1.0;
std::vector<float> bias{}; std::vector<float> bias{};
if(contains(attributes, "scale")) if(contains(info.attributes, "scale"))
{ {
scale = parse_value(attributes.at("scale")).at<float>(); scale = parse_value(info.attributes.at("scale")).at<float>();
} }
if(contains(attributes, "bias")) if(contains(info.attributes, "bias"))
{ {
auto&& bias_floats = attributes["bias"].floats(); auto&& bias_floats = info.attributes["bias"].floats();
bias = std::vector<float>(bias_floats.begin(), bias_floats.end()); bias = std::vector<float>(bias_floats.begin(), bias_floats.end());
} }
auto input_shape = args.front()->get_shape(); auto input_shape = args.front()->get_shape();
...@@ -1038,25 +1033,24 @@ struct onnx_parser ...@@ -1038,25 +1033,24 @@ struct onnx_parser
} }
instruction_ref instruction_ref
parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
std::vector<int64_t> perm{}; std::vector<int64_t> perm{};
if(contains(attributes, "perm")) if(contains(info.attributes, "perm"))
{ {
auto&& perm_vals = attributes["perm"].ints(); auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end()); perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
} }
return prog.add_instruction(migraphx::op::transpose{perm}, args.front()); return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
} }
instruction_ref instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
std::vector<int64_t> pads{}; std::vector<int64_t> pads{};
float value = 0.0f; float value = 0.0f;
if(contains(attributes, "pads")) if(contains(info.attributes, "pads"))
{ {
auto&& pad_vals = attributes["pads"].ints(); auto&& pad_vals = info.attributes["pads"].ints();
pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end()); pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
} }
// check if padding is actually being done (at least one value is nonzero) // check if padding is actually being done (at least one value is nonzero)
...@@ -1064,13 +1058,13 @@ struct onnx_parser ...@@ -1064,13 +1058,13 @@ struct onnx_parser
{ {
return prog.add_instruction(migraphx::op::identity{}, args.front()); return prog.add_instruction(migraphx::op::identity{}, args.front());
} }
if(contains(attributes, "value")) if(contains(info.attributes, "value"))
{ {
value = parse_value(attributes.at("value")).at<float>(); value = parse_value(info.attributes.at("value")).at<float>();
} }
if(contains(attributes, "mode")) if(contains(info.attributes, "mode"))
{ {
auto mode = attributes.at("mode").s(); auto mode = info.attributes.at("mode").s();
if(mode != "constant") if(mode != "constant")
MIGRAPHX_THROW("migraphx currently only supports constant padding"); MIGRAPHX_THROW("migraphx currently only supports constant padding");
} }
...@@ -1079,7 +1073,7 @@ struct onnx_parser ...@@ -1079,7 +1073,7 @@ struct onnx_parser
// Use a literal instruction to replace the shape since, output of // Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx // shape operator are literals in migraphx
instruction_ref instruction_ref
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
{ {
if(args.size() != 1) if(args.size() != 1)
MIGRAPHX_THROW("Shape: operator should have 1 operand"); MIGRAPHX_THROW("Shape: operator should have 1 operand");
...@@ -1095,31 +1089,30 @@ struct onnx_parser ...@@ -1095,31 +1089,30 @@ struct onnx_parser
// Use a literal instruction to replace the constantFill operator. In RNN, input shape // Use a literal instruction to replace the constantFill operator. In RNN, input shape
// and value are fixed, so no need to do the actual computation for the constantFill // and value are fixed, so no need to do the actual computation for the constantFill
// operator // operator
instruction_ref parse_constant_fill(const std::string&, instruction_ref
attribute_map attributes, parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
int input_as_shape = 0; int input_as_shape = 0;
int dtype = 1; int dtype = 1;
float value = 0.0f; float value = 0.0f;
if(contains(attributes, "dtype")) if(contains(info.attributes, "dtype"))
{ {
dtype = parse_value(attributes.at("dtype")).at<int>(); dtype = parse_value(info.attributes.at("dtype")).at<int>();
} }
shape::type_t type = get_type(dtype); shape::type_t type = get_type(dtype);
if(contains(attributes, "input_as_shape")) if(contains(info.attributes, "input_as_shape"))
{ {
input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>(); input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
} }
if(contains(attributes, "value")) if(contains(info.attributes, "value"))
{ {
value = parse_value(attributes.at("value")).at<float>(); value = parse_value(info.attributes.at("value")).at<float>();
} }
if(contains(attributes, "extra_shape")) if(contains(info.attributes, "extra_shape"))
{ {
MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute"); MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
} }
...@@ -1131,7 +1124,7 @@ struct onnx_parser ...@@ -1131,7 +1124,7 @@ struct onnx_parser
MIGRAPHX_THROW("ConstantFill: need an input argument as output shape"); MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
} }
if(contains(attributes, "shape")) if(contains(info.attributes, "shape"))
{ {
MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input " MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
"at the same time"); "at the same time");
...@@ -1148,12 +1141,12 @@ struct onnx_parser ...@@ -1148,12 +1141,12 @@ struct onnx_parser
} }
else if(input_as_shape == 0) else if(input_as_shape == 0)
{ {
if(!contains(attributes, "shape")) if(!contains(info.attributes, "shape"))
{ {
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed"); MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
} }
literal ls = parse_value(attributes.at("shape")); literal ls = parse_value(info.attributes.at("shape"));
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); }); ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims}; migraphx::shape s{type, dims};
...@@ -1166,14 +1159,13 @@ struct onnx_parser ...@@ -1166,14 +1159,13 @@ struct onnx_parser
} }
} }
instruction_ref parse_constant_of_shape(const std::string&, instruction_ref
attribute_map attributes, parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
literal l_val{}; literal l_val{};
if(contains(attributes, "value")) if(contains(info.attributes, "value"))
{ {
l_val = parse_value(attributes.at("value")); l_val = parse_value(info.attributes.at("value"));
if(l_val.get_shape().elements() != 1) if(l_val.get_shape().elements() != 1)
{ {
MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!"); MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
...@@ -1222,7 +1214,7 @@ struct onnx_parser ...@@ -1222,7 +1214,7 @@ struct onnx_parser
} }
instruction_ref instruction_ref
parse_expand(const std::string&, const attribute_map&, std::vector<instruction_ref> args) parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
{ {
auto in_lens = args[0]->get_shape().lens(); auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval(); migraphx::argument arg_s = args[1]->eval();
...@@ -1234,14 +1226,14 @@ struct onnx_parser ...@@ -1234,14 +1226,14 @@ struct onnx_parser
} }
std::vector<instruction_ref> std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
migraphx::shape input_shape = args[0]->get_shape(); migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1]; std::size_t hidden_size = args[1]->get_shape().lens()[1];
if(contains(attributes, "hidden_size")) if(contains(info.attributes, "hidden_size"))
{ {
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>(); std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att) if(hidden_size != hidden_size_att)
{ {
MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute"); MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
...@@ -1250,9 +1242,9 @@ struct onnx_parser ...@@ -1250,9 +1242,9 @@ struct onnx_parser
// Handling of direction to be added later // Handling of direction to be added later
std::string direction{"forward"}; std::string direction{"forward"};
if(contains(attributes, "direction")) if(contains(info.attributes, "direction"))
{ {
direction = attributes.at("direction").s(); direction = info.attributes.at("direction").s();
} }
op::rnn_direction dirct = op::rnn_direction::forward; op::rnn_direction dirct = op::rnn_direction::forward;
...@@ -1266,9 +1258,9 @@ struct onnx_parser ...@@ -1266,9 +1258,9 @@ struct onnx_parser
} }
std::vector<std::string> vec_names{"tanh"}; std::vector<std::string> vec_names{"tanh"};
if(contains(attributes, "activations")) if(contains(info.attributes, "activations"))
{ {
auto names = attributes.at("activations").strings(); auto names = info.attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
vec_names.resize(names.size()); vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
...@@ -1304,9 +1296,9 @@ struct onnx_parser ...@@ -1304,9 +1296,9 @@ struct onnx_parser
// To be added later // To be added later
float clip = 0.0; float clip = 0.0;
if(contains(attributes, "clip")) if(contains(info.attributes, "clip"))
{ {
clip = parse_value(attributes.at("clip")).at<float>(); clip = parse_value(info.attributes.at("clip")).at<float>();
} }
// if the number of arguments is less than 6, append // if the number of arguments is less than 6, append
...@@ -1328,14 +1320,14 @@ struct onnx_parser ...@@ -1328,14 +1320,14 @@ struct onnx_parser
} }
std::vector<instruction_ref> std::vector<instruction_ref>
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
migraphx::shape input_shape = args[0]->get_shape(); migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2]; std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size")) if(contains(info.attributes, "hidden_size"))
{ {
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>(); std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att) if(hidden_size != hidden_size_att)
{ {
MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute"); MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
...@@ -1344,9 +1336,9 @@ struct onnx_parser ...@@ -1344,9 +1336,9 @@ struct onnx_parser
// Handling of direction to be added later // Handling of direction to be added later
std::string direction{"forward"}; std::string direction{"forward"};
if(contains(attributes, "direction")) if(contains(info.attributes, "direction"))
{ {
direction = attributes.at("direction").s(); direction = info.attributes.at("direction").s();
} }
op::rnn_direction dirct = op::rnn_direction::forward; op::rnn_direction dirct = op::rnn_direction::forward;
...@@ -1360,9 +1352,9 @@ struct onnx_parser ...@@ -1360,9 +1352,9 @@ struct onnx_parser
} }
std::vector<std::string> vec_names = {"sigmoid", "tanh"}; std::vector<std::string> vec_names = {"sigmoid", "tanh"};
if(contains(attributes, "activations")) if(contains(info.attributes, "activations"))
{ {
auto names = attributes.at("activations").strings(); auto names = info.attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
vec_names.resize(names.size()); vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
...@@ -1420,15 +1412,15 @@ struct onnx_parser ...@@ -1420,15 +1412,15 @@ struct onnx_parser
[&](const auto& name) { return map_actv_funcs[name]; }); [&](const auto& name) { return map_actv_funcs[name]; });
float clip = 0.0; float clip = 0.0;
if(contains(attributes, "clip")) if(contains(info.attributes, "clip"))
{ {
clip = parse_value(attributes.at("clip")).at<float>(); clip = parse_value(info.attributes.at("clip")).at<float>();
} }
int linear_before_reset = 0; int linear_before_reset = 0;
if(contains(attributes, "linear_before_reset")) if(contains(info.attributes, "linear_before_reset"))
{ {
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>(); linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
} }
// append undefined opeator to make 6 arguments // append undefined opeator to make 6 arguments
...@@ -1450,14 +1442,14 @@ struct onnx_parser ...@@ -1450,14 +1442,14 @@ struct onnx_parser
} }
std::vector<instruction_ref> std::vector<instruction_ref>
parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
migraphx::shape input_shape = args[0]->get_shape(); migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2]; std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size")) if(contains(info.attributes, "hidden_size"))
{ {
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>(); std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att) if(hidden_size != hidden_size_att)
{ {
MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute"); MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
...@@ -1466,9 +1458,9 @@ struct onnx_parser ...@@ -1466,9 +1458,9 @@ struct onnx_parser
// Handling of direction to be added later // Handling of direction to be added later
std::string direction{"forward"}; std::string direction{"forward"};
if(contains(attributes, "direction")) if(contains(info.attributes, "direction"))
{ {
direction = attributes.at("direction").s(); direction = info.attributes.at("direction").s();
} }
op::rnn_direction dirct = op::rnn_direction::forward; op::rnn_direction dirct = op::rnn_direction::forward;
...@@ -1490,9 +1482,9 @@ struct onnx_parser ...@@ -1490,9 +1482,9 @@ struct onnx_parser
} }
std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"}; std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
if(contains(attributes, "activations")) if(contains(info.attributes, "activations"))
{ {
auto names = attributes.at("activations").strings(); auto names = info.attributes.at("activations").strings();
vec_names.clear(); vec_names.clear();
vec_names.resize(names.size()); vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) { std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
...@@ -1593,15 +1585,15 @@ struct onnx_parser ...@@ -1593,15 +1585,15 @@ struct onnx_parser
[&](const auto& name) { return map_actv_funcs[name]; }); [&](const auto& name) { return map_actv_funcs[name]; });
float clip = 0.0; float clip = 0.0;
if(contains(attributes, "clip")) if(contains(info.attributes, "clip"))
{ {
clip = parse_value(attributes.at("clip")).at<float>(); clip = parse_value(info.attributes.at("clip")).at<float>();
} }
int input_forget = 0; int input_forget = 0;
if(contains(attributes, "input_forget")) if(contains(info.attributes, "input_forget"))
{ {
input_forget = parse_value(attributes.at("input_forget")).at<int>(); input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
} }
// append undefined opeator to make 6 arguments // append undefined opeator to make 6 arguments
...@@ -1625,26 +1617,25 @@ struct onnx_parser ...@@ -1625,26 +1617,25 @@ struct onnx_parser
} }
template <class T> template <class T>
instruction_ref parse_reduce_oper(const std::string&, instruction_ref
attribute_map attributes, parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
std::size_t n_dim = args.front()->get_shape().lens().size(); std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions // default to reduce over all dimensions
std::vector<int64_t> axes(n_dim); std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes")) if(contains(info.attributes, "axes"))
{ {
axes.clear(); axes.clear();
auto&& attr_axes = attributes["axes"].ints(); auto&& attr_axes = info.attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end()); axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
} }
int keep_dims = 1; int keep_dims = 1;
if(contains(attributes, "keepdims")) if(contains(info.attributes, "keepdims"))
{ {
keep_dims = parse_value(attributes.at("keepdims")).at<int>(); keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
} }
if(keep_dims == 1) if(keep_dims == 1)
...@@ -1659,59 +1650,108 @@ struct onnx_parser ...@@ -1659,59 +1650,108 @@ struct onnx_parser
} }
instruction_ref instruction_ref
parse_reduce_l1(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
auto abs_ins = prog.add_instruction(op::abs{}, args[0]); auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
return parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {abs_ins}); return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
} }
instruction_ref instruction_ref
parse_reduce_l2(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]); auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {square_ins}); auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
return prog.add_instruction(op::sqrt{}, sum_ins); return prog.add_instruction(op::sqrt{}, sum_ins);
} }
instruction_ref parse_reduce_log_sum(const std::string&, instruction_ref
attribute_map attributes, parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
auto sum_ins = auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), std::move(args));
return prog.add_instruction(op::log{}, sum_ins); return prog.add_instruction(op::log{}, sum_ins);
} }
instruction_ref parse_reduce_log_sum_exp(const std::string&, instruction_ref
attribute_map attributes, parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
auto exp_ins = prog.add_instruction(op::exp{}, args[0]); auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {exp_ins}); auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
return prog.add_instruction(op::log{}, sum_ins); return prog.add_instruction(op::log{}, sum_ins);
} }
instruction_ref parse_reduce_sum_square(const std::string&, instruction_ref
attribute_map attributes, parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
std::vector<instruction_ref> args)
{ {
auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]); auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
return parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {square_ins}); return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
} }
instruction_ref instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
{ {
if(!contains(attributes, "to")) if(!contains(info.attributes, "to"))
{ {
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!"); MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
} }
int to_type = parse_value(attributes.at("to")).at<int>(); int to_type = parse_value(info.attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type); shape::type_t type = get_type(to_type);
return prog.add_instruction(op::convert{type}, std::move(args)); return prog.add_instruction(op::convert{type}, std::move(args));
} }
std::vector<instruction_ref>
parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(info.attributes, "axis"))
{
axis = parse_value(info.attributes.at("axis")).at<int>();
}
auto lens = args[0]->get_shape().lens();
int64_t n_rank = static_cast<int64_t>(lens.size());
if((axis < -n_rank) || (axis >= n_rank))
{
MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
}
int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
std::vector<int64_t> vec_splits;
if(contains(info.attributes, "split"))
{
literal s = parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
{
MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
}
}
// no split attribute, input is equally divided
else
{
if((lens[tuned_axis] % info.num_outputs) != 0)
{
MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
to_string(info.num_outputs) + " splits!");
}
auto dl = lens[tuned_axis] / info.num_outputs;
vec_splits.resize(info.num_outputs, dl);
}
std::vector<instruction_ref> ret_ins;
int64_t start = 0;
for(auto sl : vec_splits)
{
ret_ins.push_back(
prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
start += sl;
}
return ret_ins;
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
onnx::ModelProto model; onnx::ModelProto model;
...@@ -1824,7 +1864,8 @@ struct onnx_parser ...@@ -1824,7 +1864,8 @@ struct onnx_parser
} }
else else
{ {
result = ops[node.op_type()](get_attributes(node), args); std::size_t output_num = static_cast<std::size_t>(node.output().size());
result = ops[node.op_type()]({get_attributes(node), output_num}, args);
} }
// Even no output nodes produce output in migraphx // Even no output nodes produce output in migraphx
if(node.output().empty() and result.size() == 1) if(node.output().empty() and result.size() == 1)
......
...@@ -1771,6 +1771,37 @@ def softmax_test(): ...@@ -1771,6 +1771,37 @@ def softmax_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def split_test():
    # Split a (10, 15) tensor along axis 1 into chunks of 7, 4 and 4 columns.
    in_x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
    outs = [
        helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)
        for name, dims in (('y1', [10, 7]), ('y2', [10, 4]), ('y3', [10, 4]))
    ]

    node = onnx.helper.make_node('Split',
                                 inputs=['x'],
                                 outputs=['y1', 'y2', 'y3'],
                                 axis=1,
                                 split=[7, 4, 4])

    return ([node], [in_x], outs)
@onnx_test
def split_test_default():
    # Split with no "split" attribute: (10, 15) is divided evenly along
    # axis 0 (the default axis) into as many pieces as there are outputs.
    in_x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
    out_1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [5, 15])
    out_2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [5, 15])

    node = onnx.helper.make_node('Split', inputs=['x'], outputs=['y1', 'y2'])

    return ([node], [in_x], [out_1, out_2])
@onnx_test @onnx_test
def sqrt_test(): def sqrt_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15]) x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
......
...@@ -1356,6 +1356,31 @@ TEST_CASE(softmax_test) ...@@ -1356,6 +1356,31 @@ TEST_CASE(softmax_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
// Parsing Split with an explicit "split" attribute lowers to one slice
// instruction per requested output, walking along the given axis.
TEST_CASE(split_test)
{
    migraphx::program p;
    auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});

    // expected lowering for split sizes [7, 4, 4] along axis 1
    std::vector<migraphx::instruction_ref> slices;
    int64_t begin = 0;
    for(int64_t size : {7, 4, 4})
    {
        slices.push_back(
            p.add_instruction(migraphx::op::slice{{1}, {begin}, {begin + size}}, x));
        begin += size;
    }
    p.add_return(slices);

    auto prog = migraphx::parse_onnx("split_test.onnx");

    EXPECT(p == prog);
}
// Split without a "split" attribute divides the default axis (0) evenly:
// a dim of 10 with two outputs becomes two slices of 5 rows each.
TEST_CASE(split_test_default)
{
    migraphx::program p;
    auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
    auto top    = p.add_instruction(migraphx::op::slice{{0}, {0}, {5}}, x);
    auto bottom = p.add_instruction(migraphx::op::slice{{0}, {5}, {10}}, x);
    p.add_return({top, bottom});

    auto prog = migraphx::parse_onnx("split_test_default.onnx");

    EXPECT(p == prog);
}
TEST_CASE(sqrt_test) TEST_CASE(sqrt_test)
{ {
migraphx::program p; migraphx::program p;
......

split_test:
5
xy1y2y3"Split*
axis*
split@@@
split_testZ
x


b
y1


b
y2


b
y3


B
\ No newline at end of file
split_test_default:i

xy1y2"Splitsplit_test_defaultZ
x


b
y1


b
y2


B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment