Unverified Commit c4e53a33 authored by Shucai Xiao, committed by GitHub

Add parsing split operator (#460)



* change parse operator function signature

* clang format

* add parsing the split operator

* clang format

* add parsing split operator

* make squeeze/unsqueeze inputs to standard shape

* add unit tests for the split operator

* clang format

* fix cppcheck error

* clang format

* update tests for multiple program outputs

* clang format

* fix cppcheck error

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
Co-authored-by: Paul Fultz II <pfultz2@yahoo.com>
parent 330224aa
@@ -24,9 +24,14 @@ inline namespace MIGRAPHX_INLINE_NS {
struct onnx_parser
{
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
struct node_info
{
attribute_map attributes{};
std::size_t num_outputs = 1;
};
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func =
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
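A minimal standalone sketch of what this signature change enables (plain C++14; the struct below uses toy stand-ins for the MIGraphX types, with plain strings in place of onnx::AttributeProto): parser callbacks now receive the node's attributes together with its declared output count, so a multi-output operator can size its result accordingly.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-ins for the MIGraphX types in the hunk above; illustrative only.
struct node_info
{
    std::unordered_map<std::string, std::string> attributes{};
    std::size_t num_outputs = 1;
};

using op_func = std::function<std::vector<int>(node_info, std::vector<int>)>;

int main()
{
    // A Split-like callback can size its result vector from the output
    // count declared on the ONNX node instead of assuming a single output.
    op_func split = [](node_info info, std::vector<int> args) {
        return std::vector<int>(info.num_outputs, args.front());
    };
    std::cout << split({{}, 3}, {42}).size() << " outputs\n"; // prints "3 outputs"
}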
@@ -121,6 +126,7 @@ struct onnx_parser
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
add_mem_op("Split", &onnx_parser::parse_split);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
@@ -166,15 +172,15 @@ struct onnx_parser
template <class T>
void add_binary_op(std::string name, T x)
{
add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast") and contains(attributes, "axis"))
if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0)
{
uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
args[1]);
return prog.add_instruction(x, args[0], l);
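For readers unfamiliar with the legacy ONNX broadcast attribute handled above: args[1] is broadcast onto args[0]'s shape starting at the given axis. A toy sketch of the implied index arithmetic (plain C++, no MIGraphX; the shapes and values are hypothetical, and a 1-D second operand of matching length is assumed):

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // a has shape {2, 3}; b has shape {3} and is broadcast along axis 1,
    // mirroring what op::broadcast{axis, lens} arranges in the hunk above.
    std::vector<std::vector<int>> a = {{1, 2, 3}, {4, 5, 6}};
    std::vector<int> b = {10, 20, 30};
    for(std::size_t i = 0; i < a.size(); ++i)
        for(std::size_t j = 0; j < a[i].size(); ++j)
            std::cout << a[i][j] + b[j] << (j + 1 == a[i].size() ? "\n" : " ");
}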
@@ -266,7 +272,7 @@ struct onnx_parser
template <class T>
void add_generic_op(std::string name, T x)
{
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args);
});
}
@@ -274,7 +280,7 @@ struct onnx_parser
template <class T>
void add_variadic_op(std::string name, T x)
{
add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
return std::accumulate(std::next(args.begin()),
args.end(),
args.front(),
@@ -321,51 +327,48 @@ struct onnx_parser
}
}
instruction_ref parse_clip(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
{
op::clip op;
if(contains(attributes, "max"))
if(contains(info.attributes, "max"))
{
op.max_val = parse_value(attributes.at("max")).at<float>();
op.max_val = parse_value(info.attributes.at("max")).at<float>();
}
if(contains(attributes, "min"))
if(contains(info.attributes, "min"))
{
op.min_val = parse_value(attributes.at("min")).at<float>();
op.min_val = parse_value(info.attributes.at("min")).at<float>();
}
return prog.add_instruction(op, std::move(args));
}
template <class Op>
instruction_ref parse_softmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
{
int64_t axis = 1;
if(contains(attributes, "axis"))
if(contains(info.attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
axis = parse_value(info.attributes.at("axis")).at<int>();
}
return prog.add_instruction(Op{axis}, std::move(args));
}
template <class Op>
instruction_ref parse_arg_op(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(attributes, "axis"))
if(contains(info.attributes, "axis"))
{
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
if(contains(info.attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
}
if(keep_dims == 0)
@@ -381,16 +384,16 @@ struct onnx_parser
template <class Op>
instruction_ref process_auto_pad_attribute(instruction_ref ins,
attribute_map& attributes,
node_info info,
Op& op,
const std::vector<std::size_t>& in_lens)
{
if(!contains(attributes, "auto_pad"))
if(!contains(info.attributes, "auto_pad"))
{
return ins;
}
auto auto_pad = attributes["auto_pad"].s();
auto auto_pad = info.attributes["auto_pad"].s();
if(auto_pad.find("SAME") != std::string::npos)
{
// calculate the padding
@@ -440,41 +443,41 @@ struct onnx_parser
template <class Op>
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
{
Op op;
auto l0 = args[0];
auto weights = args[1];
if(contains(attributes, "pads"))
if(contains(info.attributes, "pads"))
{
if(contains(attributes, "auto_pad"))
if(contains(info.attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
auto s = info.attributes["auto_pad"].s();
if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
}
std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding));
copy(info.attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
check_asym_padding(l0, padding, op);
}
if(contains(attributes, "strides"))
if(contains(info.attributes, "strides"))
{
copy(attributes["strides"].ints(), op.stride.begin());
copy(info.attributes["strides"].ints(), op.stride.begin());
}
if(contains(attributes, "dilations"))
if(contains(info.attributes, "dilations"))
{
copy(attributes["dilations"].ints(), op.dilation.begin());
copy(info.attributes["dilations"].ints(), op.dilation.begin());
}
if(contains(attributes, "auto_pad"))
if(contains(info.attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
auto s = info.attributes["auto_pad"].s();
if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
@@ -496,34 +499,33 @@ struct onnx_parser
check_asym_padding(l0, padding, op);
}
}
if(contains(attributes, "group"))
if(contains(info.attributes, "group"))
{
op.group = parse_value(attributes.at("group")).at<int>();
op.group = parse_value(info.attributes.at("group")).at<int>();
}
auto l1 = prog.add_instruction(op, l0, args[1]);
return add_bias(args, l1, 1);
}
instruction_ref parse_conv_transpose(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
{
op::deconvolution op;
auto l0 = args[0];
std::vector<std::int64_t> padding;
bool asymm_padding = false;
if(contains(attributes, "pads"))
if(contains(info.attributes, "pads"))
{
if(contains(attributes, "auto_pad"))
if(contains(info.attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
auto s = info.attributes["auto_pad"].s();
if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
}
copy(attributes["pads"].ints(), std::back_inserter(padding));
copy(info.attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
@@ -538,18 +540,18 @@ struct onnx_parser
op.padding[1] = padding[1];
}
}
if(contains(attributes, "strides"))
if(contains(info.attributes, "strides"))
{
copy(attributes["strides"].ints(), op.stride.begin());
copy(info.attributes["strides"].ints(), op.stride.begin());
}
if(contains(attributes, "dilations"))
if(contains(info.attributes, "dilations"))
{
copy(attributes["dilations"].ints(), op.dilation.begin());
copy(info.attributes["dilations"].ints(), op.dilation.begin());
}
if(contains(attributes, "auto_pad"))
if(contains(info.attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
auto s = info.attributes["auto_pad"].s();
if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
@@ -560,9 +562,9 @@ struct onnx_parser
}
}
if(contains(attributes, "group"))
if(contains(info.attributes, "group"))
{
op.group = parse_value(attributes.at("group")).at<int>();
op.group = parse_value(info.attributes.at("group")).at<int>();
}
auto l1 = prog.add_instruction(op, l0, args[1]);
@@ -579,18 +581,18 @@ struct onnx_parser
l1 = prog.add_instruction(slice_op, l1);
}
if(contains(attributes, "output_padding"))
if(contains(info.attributes, "output_padding"))
{
std::vector<int64_t> output_padding;
copy(attributes["output_padding"].ints(), std::back_inserter(output_padding));
copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
l1 = prog.add_instruction(op::pad{output_padding}, l1);
}
if(contains(attributes, "output_shape"))
if(contains(info.attributes, "output_shape"))
{
std::vector<int64_t> output_shape;
copy(attributes["output_shape"].ints(), std::back_inserter(output_shape));
copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
dims = to_int64_vector(l1->get_shape().lens());
curr_shape = {dims[2], dims[3]};
if(curr_shape != output_shape)
@@ -610,9 +612,8 @@ struct onnx_parser
return add_bias(args, l1, 1);
}
instruction_ref parse_pooling(const std::string& name,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
{
op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
auto l0 = args[0];
@@ -622,11 +623,11 @@ struct onnx_parser
op.lengths = {lens[2], lens[3]};
}
if(contains(attributes, "pads"))
if(contains(info.attributes, "pads"))
{
if(contains(attributes, "auto_pad"))
if(contains(info.attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
auto s = info.attributes["auto_pad"].s();
if(to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW(
@@ -635,7 +636,7 @@ struct onnx_parser
}
std::vector<std::int64_t> padding;
copy(attributes["pads"].ints(), std::back_inserter(padding));
copy(info.attributes["pads"].ints(), std::back_inserter(padding));
if(padding.size() != 4)
{
MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
@@ -646,31 +647,31 @@ struct onnx_parser
check_asym_padding(l0, padding, op, pad_val);
}
if(contains(attributes, "strides"))
if(contains(info.attributes, "strides"))
{
copy(attributes["strides"].ints(), op.stride.begin());
copy(info.attributes["strides"].ints(), op.stride.begin());
}
if(contains(attributes, "kernel_shape"))
if(contains(info.attributes, "kernel_shape"))
{
copy(attributes["kernel_shape"].ints(), op.lengths.begin());
copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
}
if(contains(attributes, "auto_pad"))
if(contains(info.attributes, "auto_pad"))
{
auto in_lens = args[0]->get_shape().lens();
l0 = process_auto_pad_attribute(l0, attributes, op, in_lens);
l0 = process_auto_pad_attribute(l0, info, op, in_lens);
}
return prog.add_instruction(op, l0);
}
instruction_ref
parse_reshape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
{
op::reshape op;
if(args.size() == 1)
{
literal s = parse_value(attributes.at("shape"));
literal s = parse_value(info.attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
if(args.size() == 2)
@@ -684,55 +685,55 @@ struct onnx_parser
}
instruction_ref
parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
{
int64_t axis = 1;
if(contains(attributes, "axis"))
if(contains(info.attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
axis = parse_value(info.attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::flatten{axis}, args[0]);
}
instruction_ref
parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
{
op::squeeze op;
literal s = parse_value(attributes.at("axes"));
literal s = parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]);
return prog.add_instruction(op, make_contiguous(args[0]));
}
instruction_ref
parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
{
op::unsqueeze op;
literal s = parse_value(attributes.at("axes"));
literal s = parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]);
return prog.add_instruction(op, make_contiguous(args[0]));
}
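The make_contiguous calls added to parse_squeeze and parse_unsqueeze correspond to the "make squeeze/unsqueeze inputs to standard shape" bullet in the commit message. The diff does not spell out the reasoning, but squeeze/unsqueeze are reshape-style ops, and a non-packed view such as a transpose does not have the standard strides such ops assume. A standalone stride-check sketch (plain C++; the shapes are hypothetical):

#include <cstddef>
#include <iostream>
#include <vector>

// Packed (row-major, "standard") strides for the given lengths.
std::vector<std::size_t> packed_strides(const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> strides(lens.size(), 1);
    for(std::size_t i = lens.size() - 1; i > 0; --i)
        strides[i - 1] = strides[i] * lens[i];
    return strides;
}

int main()
{
    // A transpose of a {10, 15} tensor is a {15, 10} view with strides {1, 15},
    // not the packed {10, 1} that a reshape-style op assumes.
    std::vector<std::size_t> view_strides = {1, 15};
    bool standard = (view_strides == packed_strides({15, 10}));
    std::cout << (standard ? "standard" : "needs a contiguous copy") << "\n";
}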
instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
{
// change to handle negative axis values
if(!contains(attributes, "axis"))
if(!contains(info.attributes, "axis"))
{
MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
}
int axis = parse_value(attributes.at("axis")).at<int>();
int axis = parse_value(info.attributes.at("axis")).at<int>();
op::concat op{axis};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
{
int axis = 0;
if(contains(attributes, "axis"))
if(contains(info.attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
axis = parse_value(info.attributes.at("axis")).at<int>();
}
op::gather op{axis};
@@ -740,14 +741,14 @@ struct onnx_parser
}
instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
{
op::slice op;
std::vector<size_t> dims = args[0]->get_shape().lens();
size_t num_dims = dims.size();
if(contains(attributes, "axes"))
if(contains(info.attributes, "axes"))
{
literal s = parse_value(attributes.at("axes"));
literal s = parse_value(info.attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
}
else
@@ -756,30 +757,29 @@ struct onnx_parser
std::iota(op.axes.begin(), op.axes.end(), 0);
}
if(contains(attributes, "ends"))
if(contains(info.attributes, "ends"))
{
op.ends = get_indices(attributes.at("ends"));
op.ends = get_indices(info.attributes.at("ends"));
}
if(contains(attributes, "starts"))
if(contains(info.attributes, "starts"))
{
literal s = parse_value(attributes.at("starts"));
literal s = parse_value(info.attributes.at("starts"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
}
return prog.add_instruction(op, args[0]);
}
instruction_ref parse_constant(const std::string&,
attribute_map attributes,
const std::vector<instruction_ref>&)
instruction_ref
parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
literal v = parse_value(info.attributes.at("value"));
// return empty literal
if(v.get_shape().elements() == 0)
{
return prog.add_literal(literal{});
}
auto dim_size = attributes.at("value").t().dims_size();
auto dim_size = info.attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
{
@@ -791,27 +791,27 @@ struct onnx_parser
}
instruction_ref
parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
{
float alpha = 1.0f;
float beta = 1.0f;
bool transa = false;
bool transb = false;
if(contains(attributes, "alpha"))
if(contains(info.attributes, "alpha"))
{
alpha = parse_value(attributes.at("alpha")).at<float>();
alpha = parse_value(info.attributes.at("alpha")).at<float>();
}
if(contains(attributes, "beta"))
if(contains(info.attributes, "beta"))
{
beta = parse_value(attributes.at("beta")).at<float>();
beta = parse_value(info.attributes.at("beta")).at<float>();
}
if(contains(attributes, "transA"))
if(contains(info.attributes, "transA"))
{
transa = parse_value(attributes.at("transA")).at<bool>();
transa = parse_value(info.attributes.at("transA")).at<bool>();
}
if(contains(attributes, "transB"))
if(contains(info.attributes, "transB"))
{
transb = parse_value(attributes.at("transB")).at<bool>();
transb = parse_value(info.attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
@@ -842,7 +842,7 @@ struct onnx_parser
template <class Op>
instruction_ref
parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
{
auto l0 = args[0];
auto l1 = args[1];
@@ -905,22 +905,22 @@ struct onnx_parser
}
instruction_ref
parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
{
float epsilon = 1e-5f;
float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
if(contains(attributes, "epsilon"))
if(contains(info.attributes, "epsilon"))
{
epsilon = parse_value(attributes.at("epsilon")).at<float>();
epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
}
if(contains(attributes, "momentum"))
if(contains(info.attributes, "momentum"))
{
momentum = parse_value(attributes.at("momentum")).at<float>();
momentum = parse_value(info.attributes.at("momentum")).at<float>();
}
if(contains(attributes, "spatial"))
if(contains(info.attributes, "spatial"))
{
bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0)
bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
? op::batch_norm_inference::spatial
: op::batch_norm_inference::per_activation;
}
@@ -928,18 +928,17 @@ struct onnx_parser
return prog.add_instruction(op, std::move(args));
}
instruction_ref parse_instancenorm(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
{
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({H, W}, x)
// variance = reduce_mean({H, W}, (x - mean)^2)
float epsilon = 1e-5f;
if(contains(attributes, "epsilon"))
if(contains(info.attributes, "epsilon"))
{
epsilon = parse_value(attributes.at("epsilon")).at<float>();
epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
}
auto x = args[0];
auto scale = args[1];
@@ -964,64 +963,60 @@ struct onnx_parser
return prog.add_instruction(op::add{}, l5, bias_bcast);
}
instruction_ref parse_leaky_relu(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
{
float alpha = 0.01; // default alpha val for leaky relu
if(contains(attributes, "alpha"))
if(contains(info.attributes, "alpha"))
{
alpha = parse_value(attributes.at("alpha")).at<float>();
alpha = parse_value(info.attributes.at("alpha")).at<float>();
}
op::leaky_relu op{alpha};
return prog.add_instruction(op, args.front());
}
instruction_ref
parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
{
float alpha = 1.0; // default alpha val for elu
if(contains(attributes, "alpha"))
if(contains(info.attributes, "alpha"))
{
alpha = parse_value(attributes.at("alpha")).at<float>();
alpha = parse_value(info.attributes.at("alpha")).at<float>();
}
op::elu op{alpha};
return prog.add_instruction(op, args.front());
}
instruction_ref
parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
{
float alpha = 0.0001;
float beta = 0.75;
float bias = 1.0;
int size = 1;
if(contains(attributes, "alpha"))
alpha = parse_value(attributes.at("alpha")).at<float>();
if(contains(attributes, "beta"))
beta = parse_value(attributes.at("beta")).at<float>();
if(contains(attributes, "bias"))
bias = parse_value(attributes.at("bias")).at<float>();
if(contains(attributes, "size"))
size = parse_value(attributes.at("size")).at<int>();
if(contains(info.attributes, "alpha"))
alpha = parse_value(info.attributes.at("alpha")).at<float>();
if(contains(info.attributes, "beta"))
beta = parse_value(info.attributes.at("beta")).at<float>();
if(contains(info.attributes, "bias"))
bias = parse_value(info.attributes.at("bias")).at<float>();
if(contains(info.attributes, "size"))
size = parse_value(info.attributes.at("size")).at<int>();
op::lrn op{alpha, beta, bias, size};
return prog.add_instruction(op, args.front());
}
instruction_ref parse_imagescaler(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
{
float scale = 1.0;
std::vector<float> bias{};
if(contains(attributes, "scale"))
if(contains(info.attributes, "scale"))
{
scale = parse_value(attributes.at("scale")).at<float>();
scale = parse_value(info.attributes.at("scale")).at<float>();
}
if(contains(attributes, "bias"))
if(contains(info.attributes, "bias"))
{
auto&& bias_floats = attributes["bias"].floats();
auto&& bias_floats = info.attributes["bias"].floats();
bias = std::vector<float>(bias_floats.begin(), bias_floats.end());
}
auto input_shape = args.front()->get_shape();
@@ -1038,25 +1033,24 @@ struct onnx_parser
}
instruction_ref
parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
{
std::vector<int64_t> perm{};
if(contains(attributes, "perm"))
if(contains(info.attributes, "perm"))
{
auto&& perm_vals = attributes["perm"].ints();
auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
}
instruction_ref
parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
{
std::vector<int64_t> pads{};
float value = 0.0f;
if(contains(attributes, "pads"))
if(contains(info.attributes, "pads"))
{
auto&& pad_vals = attributes["pads"].ints();
auto&& pad_vals = info.attributes["pads"].ints();
pads = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
}
// check if padding is actually being done (at least one value is nonzero)
@@ -1064,13 +1058,13 @@ struct onnx_parser
{
return prog.add_instruction(migraphx::op::identity{}, args.front());
}
if(contains(attributes, "value"))
if(contains(info.attributes, "value"))
{
value = parse_value(attributes.at("value")).at<float>();
value = parse_value(info.attributes.at("value")).at<float>();
}
if(contains(attributes, "mode"))
if(contains(info.attributes, "mode"))
{
auto mode = attributes.at("mode").s();
auto mode = info.attributes.at("mode").s();
if(mode != "constant")
MIGRAPHX_THROW("migraphx currently only supports constant padding");
}
@@ -1079,7 +1073,7 @@ struct onnx_parser
// Use a literal instruction to replace the shape, since the output of the
// shape operator is a literal in migraphx
instruction_ref
parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
{
if(args.size() != 1)
MIGRAPHX_THROW("Shape: operator should have 1 operand");
@@ -1095,31 +1089,30 @@ struct onnx_parser
// Use a literal instruction to replace the constantFill operator. In RNN, input shape
// and value are fixed, so no need to do the actual computation for the constantFill
// operator
instruction_ref parse_constant_fill(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
{
int input_as_shape = 0;
int dtype = 1;
float value = 0.0f;
if(contains(attributes, "dtype"))
if(contains(info.attributes, "dtype"))
{
dtype = parse_value(attributes.at("dtype")).at<int>();
dtype = parse_value(info.attributes.at("dtype")).at<int>();
}
shape::type_t type = get_type(dtype);
if(contains(attributes, "input_as_shape"))
if(contains(info.attributes, "input_as_shape"))
{
input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
}
if(contains(attributes, "value"))
if(contains(info.attributes, "value"))
{
value = parse_value(attributes.at("value")).at<float>();
value = parse_value(info.attributes.at("value")).at<float>();
}
if(contains(attributes, "extra_shape"))
if(contains(info.attributes, "extra_shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
}
@@ -1131,7 +1124,7 @@ struct onnx_parser
MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
}
if(contains(attributes, "shape"))
if(contains(info.attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
"at the same time");
@@ -1148,12 +1141,12 @@ struct onnx_parser
}
else if(input_as_shape == 0)
{
if(!contains(attributes, "shape"))
if(!contains(info.attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
}
literal ls = parse_value(attributes.at("shape"));
literal ls = parse_value(info.attributes.at("shape"));
std::vector<std::size_t> dims;
ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
migraphx::shape s{type, dims};
@@ -1166,14 +1159,13 @@ struct onnx_parser
}
}
instruction_ref parse_constant_of_shape(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
{
literal l_val{};
if(contains(attributes, "value"))
if(contains(info.attributes, "value"))
{
l_val = parse_value(attributes.at("value"));
l_val = parse_value(info.attributes.at("value"));
if(l_val.get_shape().elements() != 1)
{
MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
@@ -1222,7 +1214,7 @@ struct onnx_parser
}
instruction_ref
parse_expand(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
@@ -1234,14 +1226,14 @@ struct onnx_parser
}
std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[1]->get_shape().lens()[1];
if(contains(attributes, "hidden_size"))
if(contains(info.attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
@@ -1250,9 +1242,9 @@ struct onnx_parser
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
if(contains(info.attributes, "direction"))
{
direction = attributes.at("direction").s();
direction = info.attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
@@ -1266,9 +1258,9 @@ struct onnx_parser
}
std::vector<std::string> vec_names{"tanh"};
if(contains(attributes, "activations"))
if(contains(info.attributes, "activations"))
{
auto names = attributes.at("activations").strings();
auto names = info.attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
@@ -1304,9 +1296,9 @@ struct onnx_parser
// To be added later
float clip = 0.0;
if(contains(attributes, "clip"))
if(contains(info.attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
clip = parse_value(info.attributes.at("clip")).at<float>();
}
// if the number of arguments is less than 6, append
......@@ -1328,14 +1320,14 @@ struct onnx_parser
}
std::vector<instruction_ref>
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
if(contains(info.attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
@@ -1344,9 +1336,9 @@ struct onnx_parser
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
if(contains(info.attributes, "direction"))
{
direction = attributes.at("direction").s();
direction = info.attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
@@ -1360,9 +1352,9 @@ struct onnx_parser
}
std::vector<std::string> vec_names = {"sigmoid", "tanh"};
if(contains(attributes, "activations"))
if(contains(info.attributes, "activations"))
{
auto names = attributes.at("activations").strings();
auto names = info.attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
@@ -1420,15 +1412,15 @@ struct onnx_parser
[&](const auto& name) { return map_actv_funcs[name]; });
float clip = 0.0;
if(contains(attributes, "clip"))
if(contains(info.attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
clip = parse_value(info.attributes.at("clip")).at<float>();
}
int linear_before_reset = 0;
if(contains(attributes, "linear_before_reset"))
if(contains(info.attributes, "linear_before_reset"))
{
linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
}
// append undefined operator to make 6 arguments
@@ -1450,14 +1442,14 @@ struct onnx_parser
}
std::vector<instruction_ref>
parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
{
migraphx::shape input_shape = args[0]->get_shape();
std::size_t hidden_size = args[2]->get_shape().lens()[2];
if(contains(attributes, "hidden_size"))
if(contains(info.attributes, "hidden_size"))
{
std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
if(hidden_size != hidden_size_att)
{
MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
@@ -1466,9 +1458,9 @@ struct onnx_parser
// Handling of direction to be added later
std::string direction{"forward"};
if(contains(attributes, "direction"))
if(contains(info.attributes, "direction"))
{
direction = attributes.at("direction").s();
direction = info.attributes.at("direction").s();
}
op::rnn_direction dirct = op::rnn_direction::forward;
@@ -1490,9 +1482,9 @@ struct onnx_parser
}
std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
if(contains(attributes, "activations"))
if(contains(info.attributes, "activations"))
{
auto names = attributes.at("activations").strings();
auto names = info.attributes.at("activations").strings();
vec_names.clear();
vec_names.resize(names.size());
std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
@@ -1593,15 +1585,15 @@ struct onnx_parser
[&](const auto& name) { return map_actv_funcs[name]; });
float clip = 0.0;
if(contains(attributes, "clip"))
if(contains(info.attributes, "clip"))
{
clip = parse_value(attributes.at("clip")).at<float>();
clip = parse_value(info.attributes.at("clip")).at<float>();
}
int input_forget = 0;
if(contains(attributes, "input_forget"))
if(contains(info.attributes, "input_forget"))
{
input_forget = parse_value(attributes.at("input_forget")).at<int>();
input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
}
// append undefined operator to make 6 arguments
@@ -1625,26 +1617,25 @@ struct onnx_parser
}
template <class T>
instruction_ref parse_reduce_oper(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
if(contains(info.attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
auto&& attr_axes = info.attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
if(contains(info.attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
@@ -1659,59 +1650,108 @@ struct onnx_parser
}
instruction_ref
parse_reduce_l1(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
{
auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
return parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {abs_ins});
return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
}
instruction_ref
parse_reduce_l2(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
{
auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {square_ins});
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
return prog.add_instruction(op::sqrt{}, sum_ins);
}
instruction_ref parse_reduce_log_sum(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
{
auto sum_ins =
parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), std::move(args));
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
return prog.add_instruction(op::log{}, sum_ins);
}
instruction_ref parse_reduce_log_sum_exp(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
{
auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {exp_ins});
auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
return prog.add_instruction(op::log{}, sum_ins);
}
instruction_ref parse_reduce_sum_square(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
instruction_ref
parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
{
auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
return parse_reduce_oper<op::reduce_sum>({}, std::move(attributes), {square_ins});
return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
}
instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
{
if(!contains(attributes, "to"))
if(!contains(info.attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
int to_type = parse_value(attributes.at("to")).at<int>();
int to_type = parse_value(info.attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return prog.add_instruction(op::convert{type}, std::move(args));
}
std::vector<instruction_ref>
parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(info.attributes, "axis"))
{
axis = parse_value(info.attributes.at("axis")).at<int>();
}
auto lens = args[0]->get_shape().lens();
int64_t n_rank = static_cast<int64_t>(lens.size());
if((axis < -n_rank) || (axis >= n_rank))
{
MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
}
int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
std::vector<int64_t> vec_splits;
if(contains(info.attributes, "split"))
{
literal s = parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
{
MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
}
}
// no split attribute, input is equally divided
else
{
if((lens[tuned_axis] % info.num_outputs) != 0)
{
MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
to_string(info.num_outputs) + " splits!");
}
auto dl = lens[tuned_axis] / info.num_outputs;
vec_splits.resize(info.num_outputs, dl);
}
std::vector<instruction_ref> ret_ins;
int64_t start = 0;
for(auto sl : vec_splits)
{
ret_ins.push_back(
prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
start += sl;
}
return ret_ins;
}
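A quick standalone check of the slicing arithmetic in parse_split (plain C++, no MIGraphX dependencies): an explicit split attribute becomes consecutive [start, start + len) windows along the split axis, and the attribute-less default divides the axis length evenly across the node's declared outputs. The numbers match the two unit tests added below.

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    // Explicit split attribute, as in split_test below: axis length 15, split=[7,4,4]
    std::vector<int64_t> splits = {7, 4, 4};
    int64_t start = 0;
    for(auto len : splits)
    {
        // each entry becomes op::slice{{axis}, {start}, {start + len}}
        std::cout << "slice [" << start << ", " << start + len << ")\n";
        start += len; // -> [0,7), [7,11), [11,15)
    }

    // Default path, as in split_test_default: axis length 10, 2 declared outputs
    int64_t dim = 10, num_outputs = 2;
    if(dim % num_outputs == 0)
        std::cout << "equal slices of length " << dim / num_outputs << "\n"; // 5 and 5
}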
void parse_from(std::istream& is)
{
onnx::ModelProto model;
@@ -1824,7 +1864,8 @@ struct onnx_parser
}
else
{
result = ops[node.op_type()](get_attributes(node), args);
std::size_t output_num = static_cast<std::size_t>(node.output().size());
result = ops[node.op_type()]({get_attributes(node), output_num}, args);
}
// Even nodes with no declared outputs produce an output in migraphx
if(node.output().empty() and result.size() == 1)
......
@@ -1771,6 +1771,37 @@ def softmax_test():
return ([node], [x], [y])
@onnx_test
def split_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4])
node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=1,
split=[7, 4, 4])
return ([node], [x], [y1, y2, y3])
@onnx_test
def split_test_default():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [5, 15])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [5, 15])
node = onnx.helper.make_node(
'Split',
inputs=['x'],
outputs=['y1', 'y2'],
)
return ([node], [x], [y1, y2])
@onnx_test
def sqrt_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
......
@@ -1356,6 +1356,31 @@ TEST_CASE(softmax_test)
EXPECT(p == prog);
}
TEST_CASE(split_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = p.add_instruction(migraphx::op::slice{{1}, {0}, {7}}, input);
auto r2 = p.add_instruction(migraphx::op::slice{{1}, {7}, {11}}, input);
auto r3 = p.add_instruction(migraphx::op::slice{{1}, {11}, {15}}, input);
p.add_return({r1, r2, r3});
auto prog = migraphx::parse_onnx("split_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(split_test_default)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = p.add_instruction(migraphx::op::slice{{0}, {0}, {5}}, input);
auto r2 = p.add_instruction(migraphx::op::slice{{0}, {5}, {10}}, input);
p.add_return({r1, r2});
auto prog = migraphx::parse_onnx("split_test_default.onnx");
EXPECT(p == prog);
}
TEST_CASE(sqrt_test)
{
migraphx::program p;
......

Binary ONNX protobuf files added for the new tests, split_test.onnx and split_test_default.onnx (content not human-readable).