Commit 860208db authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add onnx support for three operatos--cast, constantofshape, and expand

parent cd5fd751
...@@ -65,11 +65,13 @@ struct onnx_parser ...@@ -65,11 +65,13 @@ struct onnx_parser
add_mem_op("ArgMax", &onnx_parser::parse_argmax); add_mem_op("ArgMax", &onnx_parser::parse_argmax);
add_mem_op("ArgMin", &onnx_parser::parse_argmin); add_mem_op("ArgMin", &onnx_parser::parse_argmin);
add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip); add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn); add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu); add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu); add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("Constant", &onnx_parser::parse_constant); add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv); add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling); add_mem_op("MaxPool", &onnx_parser::parse_pooling);
...@@ -90,6 +92,7 @@ struct onnx_parser ...@@ -90,6 +92,7 @@ struct onnx_parser
add_mem_op("Gather", &onnx_parser::parse_gather); add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape); add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill); add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Transpose", &onnx_parser::parse_transpose); add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn); add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru); add_mem_op("GRU", &onnx_parser::parse_gru);
...@@ -185,7 +188,13 @@ struct onnx_parser ...@@ -185,7 +188,13 @@ struct onnx_parser
s0.end(), s0.end(),
s1.begin() + offset, s1.begin() + offset,
out_lens.begin() + offset, out_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); }); [](auto a, auto b) {
if (a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTED_LEN: input shapes mismatch!");
}
return std::max(a, b);
});
return out_lens; return out_lens;
} }
...@@ -891,6 +900,61 @@ struct onnx_parser ...@@ -891,6 +900,61 @@ struct onnx_parser
} }
} }
instruction_ref parse_constant_of_shape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
literal l_val{};
if (contains(attributes, "value"))
{
l_val = parse_value(attributes.at("value"));
if (l_val.get_shape().elements() != 1)
{
MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
}
}
else
{
l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
}
// input is empty, output is a scalar
auto type = l_val.get_shape().type();
if (args.size() == 0)
{
return prog.add_literal(literal({type, {1}, {0}}, l_val.data()));
}
else
{
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW("ConstantOfShape: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims);
literal l_out;
l_val.visit([&](auto val) {
using type = std::remove_cv_t<typename decltype(val)::value_type>;
// l_val contains only one element
std::vector<type> out_vec(s.elements(), *val.begin());
l_out = literal(s, out_vec);
});
return prog.add_literal(l_out);
}
}
instruction_ref parse_expand(const std::string&, attribute_map, std::vector<instruction_ref> args)
{
auto in_lens = args[0]->get_shape().lens();
auto ex_lens = args[1]->get_shape().lens();
auto out_lens = compute_broadcasted_lens(in_lens, ex_lens);
return prog.add_instruction(op::multibroadcast{out_lens}, std::move(args[0]));
}
std::vector<instruction_ref> std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
...@@ -1313,6 +1377,18 @@ struct onnx_parser ...@@ -1313,6 +1377,18 @@ struct onnx_parser
} }
} }
instruction_ref parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
if (!contains(attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
int to_type = parse_value(attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return prog.add_instruction(op::convert{type}, std::move(args));
}
void parse_from(std::istream& is) void parse_from(std::istream& is)
{ {
onnx::ModelProto model; onnx::ModelProto model;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment