Commit 4a39a0f7 authored by Shucai Xiao

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
 #include <migraphx/onnx/onnx_parser.hpp>
+#include <migraphx/onnx/op_parser.hpp>
 #include <iostream>
 #include <fstream>
 #include <unordered_map>
...
@@ -20,6 +21,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
     parser.map_input_dims         = options.map_input_dims;
     parser.default_dim_value      = options.default_dim_value;
     parser.skip_unknown_operators = options.skip_unknown_operators;
+    parser.max_loop_iterations    = options.max_loop_iterations;
     if(options.print_program_on_error)
     {
...
@@ -57,5 +59,7 @@ program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options
     return parse_onnx_from(options, data, size);
 }

+std::vector<std::string> get_onnx_operators() { return onnx::get_op_parsers(); }
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -6,6 +6,7 @@
 #include <migraphx/ranges.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/pad_calc.hpp>
+#include <migraphx/common.hpp>
 #include <migraphx/type_traits.hpp>
 #include <migraphx/float_equal.hpp>
 #include <migraphx/file_buffer.hpp>
...
@@ -84,73 +85,18 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
     if(args.size() == 3)
     {
         auto bias_bcast = mod->add_instruction(
-            make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}),
+            make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
             args[2]);
         return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
     }
     return curr_ins;
 }
-std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
-                                                  std::vector<std::size_t> s1)
-{
-    // Example:
-    // s0 = (3,2,4,5) and s1 = (2,1,1)
-    //
-    // In this case we need to broadcast (:,1,1) portion of
-    // s1 plus broadcast the 1st dimension of s1
-    // giving output_lens = (3,2,4,5)
-    //
-    // Another example:
-    // s0 = (3,2,1,5) and s1 = (2,7,5)
-    // In this case we need to broadcast the (:,:,1:,:) axis
-    // of s0 plus the 1st dimension of s1 giving
-    // output_lens = (3,2,7,5)
-    if(s0.size() > s1.size())
-    {
-        s0.swap(s1);
-    }
-
-    std::vector<std::size_t> out_lens(s1);
-    auto offset = s1.size() - s0.size();
-    std::transform(
-        s0.begin(), s0.end(), s1.begin() + offset, out_lens.begin() + offset, [&](auto a, auto b) {
-            if(a != b and a != 1 and b != 1)
-            {
-                MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
-                               to_string_range(s1) + "} mismatch!");
-            }
-            return std::max(a, b);
-        });
-
-    return out_lens;
-}
 instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::string& op_name,
                                                                     instruction_ref arg0,
                                                                     instruction_ref arg1) const
 {
-    if(arg0->get_shape().lens() != arg1->get_shape().lens())
-    {
-        // Get lengths for both arguments
-        auto s0       = arg0->get_shape().lens();
-        auto s1       = arg1->get_shape().lens();
-        auto out_lens = compute_broadcasted_lens(s0, s1);
-
-        auto l0 = arg0;
-        if(arg0->get_shape().lens() != out_lens)
-            l0 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg0);
-
-        auto l1 = arg1;
-        if(arg1->get_shape().lens() != out_lens)
-            l1 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg1);
-
-        return add_instruction(make_op(op_name), l0, l1);
-    }
-    else
-    {
-        return add_instruction(make_op(op_name), {arg0, arg1});
-    }
+    return add_common_op(*mod, make_op(op_name), {arg0, arg1});
 }
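The deleted compute_broadcasted_lens helper and the hand-rolled multibroadcast logic are folded into add_common_op from the newly included migraphx/common.hpp, which presumably performs the same shape inference and inserts the multibroadcasts itself. For reference, a minimal standalone sketch of the broadcasting rule, lifted from the removed code (only the error handling is simplified):

    #include <algorithm>
    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // NumPy-style broadcasting: align the two shapes at their trailing
    // dimensions; each aligned pair must match or contain a 1.
    std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> s0,
                                            std::vector<std::size_t> s1)
    {
        if(s0.size() > s1.size())
            s0.swap(s1); // make s1 the longer shape
        std::vector<std::size_t> out(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(), s0.end(), s1.begin() + offset, out.begin() + offset,
                       [](std::size_t a, std::size_t b) {
                           if(a != b and a != 1 and b != 1)
                               throw std::runtime_error("shapes do not broadcast");
                           return std::max(a, b); // the non-1 length wins
                       });
        return out; // e.g. {3,2,1,5} vs {2,7,5} -> {3,2,7,5}
    }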
 instruction_ref
...
@@ -278,28 +224,42 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
 void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
 {
+    std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
     {
-        instructions[f.name()] = mod->add_literal(parse_tensor(f));
+        // backup instructions in parent mod
+        mod_insts[f.name()] = mod->add_literal(parse_tensor(f));
     }

     for(auto&& input : graph.input())
     {
         const std::string& name = input.name();
         // input not in initializer_data, so it is a real input
-        if(!contains(instructions, name))
+        if(!contains(mod_insts, name))
         {
+            // The ONNX specification does not say how to handle a nested
+            // subgraph that declares a parameter with the same name as one
+            // in its parent graph. The current implementation of MIGraphX
+            // throws an exception for that case.
+            if(contains(instructions, name))
+            {
+                MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
+                               "\" existing in parent graph!");
+            }
+
             std::vector<std::size_t> dims;
             if(map_input_dims.count(name) > 0)
             {
                 dims = map_input_dims.at(name);
             }

             shape s = parse_type(input.type(), dims);
-            instructions[name] = mod->add_parameter(name, s);
+            mod_insts[name] = mod->add_parameter(name, s);
         }
     }
+
+    std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));

     for(auto&& node : graph.node())
     {
         std::vector<instruction_ref> args;
...
@@ -363,6 +323,9 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
     // add the return instruction
     mod->add_return(output_ins);
+
+    // remove instructions added in this mod
+    erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
 }
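With this change each nested graph collects its literals and parameters in a module-local mod_insts map, merges them into the parser-wide instructions table for the duration of node parsing, and erases them again once the module's return is added, so a name defined in one subgraph can no longer leak into a sibling. erase_if here is a MIGraphX utility; a minimal equivalent for associative containers (an assumption, mirroring C++20 std::erase_if for maps) would be:

    // Erase every entry of an associative container that satisfies pred.
    template <class Map, class Pred>
    void erase_if(Map& m, Pred pred)
    {
        for(auto it = m.begin(); it != m.end();)
        {
            if(pred(*it))
                it = m.erase(it); // erase returns the next valid iterator
            else
                ++it;
        }
    }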
 literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
...
...
@@ -126,7 +126,7 @@ void check_asym_padding(const onnx_parser::node_info& info,
     auto left_pad_it  = padding.begin();
     auto right_pad_it = left_pad_it + pad_ndims;

-    if(is_asym_padding(padding) or count_include_pad == 1)
+    if(count_include_pad == 1)
     {
         std::vector<int64_t> asym_pads{0, 0, 0, 0}; // don't pad N and C
         // add left pads
...
@@ -134,10 +134,19 @@ void check_asym_padding(const onnx_parser::node_info& info,
         // add right pads
         asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
         ins = info.add_instruction(make_op("pad", {{"pads", asym_pads}, {"value", pad_val}}), ins);
-    }
-    else
-    {
-        v["padding"] = std::vector<size_t>(left_pad_it, right_pad_it);
+
+        std::vector<size_t> new_padding(padding.size());
+        // subtract asym padding originally found from parsing the operator
+        std::transform(padding.begin(),
+                       left_pad_it,
+                       asym_pads.begin() + 2,
+                       new_padding.begin(),
+                       std::minus<size_t>());
+        std::transform(right_pad_it,
+                       padding.end(),
+                       asym_pads.begin() + pad_ndims + 4,
+                       new_padding.begin() + pad_ndims,
+                       std::minus<size_t>());
+        v["padding"] = new_padding;
     }
 }
...
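A numeric sketch of the new bookkeeping, assuming pad_ndims = 2 and that the collapsed lines copy the full left pads into asym_pads (values purely illustrative):

    padding     = {1, 1, 2, 2}              // {top, left, bottom, right}
    asym_pads   = {0, 0, 1, 1, 0, 0, 2, 2}  // explicit pad op; N and C stay unpadded
    new_padding = {1-1, 1-1, 2-2, 2-2} = {0, 0, 0, 0}

Whatever the explicit pad instruction absorbed is subtracted out, so the operator's own "padding" attribute keeps only the remainder instead of being cleared wholesale as before.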
...
@@ -36,7 +36,8 @@ struct parse_binary_op : op_parser<parse_binary_op>
     {
         uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
         auto l        = info.add_instruction(
-            make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}),
+            make_op("broadcast",
+                    {{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
             args[1]);
         return info.add_instruction(make_op(opd.op_name), args[0], l);
     }
...
...
@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip>
         if(min_used)
         {
-            min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
-                                           min_arg);
+            min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
+                                           min_arg);
         }

         if(max_used)
         {
-            max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
-                                           max_arg);
+            max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
+                                           max_arg);
         }
...
...
@@ -73,7 +73,7 @@ struct parse_convolution : op_parser<parse_convolution>
                 values["padding_mode"] = to_value(op::padding_mode_t::same);
             }
         }
-        check_asym_padding(info, l0, padding, values);
+        values["padding"] = std::vector<size_t>(padding.begin(), padding.end());
         if(contains(info.attributes, "group"))
         {
...
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <cmath>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_depthtospace : op_parser<parse_depthtospace>
{
    std::vector<op_desc> operators() const { return {{"DepthToSpace"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto s = args[0]->get_shape();

        // mode attribute of DepthToSpace: DCR (depth-column-row, the default)
        // or CRD (column-row-depth)
        auto mode = std::string("DCR");
        if(contains(info.attributes, "mode"))
        {
            mode = info.attributes.at("mode").s();
        }

        // blocksize attribute of DepthToSpace
        int blocksize = 0;
        if(contains(info.attributes, "blocksize"))
        {
            blocksize = info.attributes.at("blocksize").i();
        }
        if(blocksize < 1)
        {
            MIGRAPHX_THROW("DepthToSpace: blocksize is less than 1");
        }

        // calculate dimensions
        auto lens1            = s.lens();
        auto lens2            = s.lens();
        unsigned long divisor = std::pow(blocksize, 2);
        if((lens2[1] % divisor) == 0)
            lens2[1] = lens2[1] / divisor;
        else
            MIGRAPHX_THROW("DepthToSpace: channel count not divisible by blocksize squared");
        lens1.push_back(lens1[2]);
        lens1.push_back(lens1[3]);
        lens2[2] = lens2[2] * blocksize;
        lens2[3] = lens2[3] * blocksize;
        lens1[2] = blocksize;

        std::vector<int64_t> perm;
        if(mode == "DCR")
        {
            lens1[3] = lens1[1] / divisor;
            lens1[1] = blocksize;
            perm     = {0, 3, 4, 1, 5, 2};
        }
        else if(mode == "CRD")
        {
            lens1[1] = lens1[1] / divisor;
            lens1[3] = blocksize;
            perm     = {0, 1, 4, 2, 5, 3};
        }
        else
            MIGRAPHX_THROW("DepthToSpace: mode attribute cannot be read.");

        auto temp1 = info.add_instruction(make_op("reshape", {{"dims", lens1}}), args[0]);
        auto temp2 = info.add_instruction(make_op("transpose", {{"permutation", perm}}), temp1);
        return info.add_instruction(make_op("reshape", {{"dims", lens2}}),
                                    info.make_contiguous(temp2));
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
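A worked example of the reshape -> transpose -> reshape pipeline above, assuming an input of N=1, C=8, H=2, W=3 with blocksize b=2 (divisor b*b = 4):

    input:       [N, C, H, W]          = [1, 8, 2, 3]
    DCR reshape: [N, b, b, C/4, H, W]  = [1, 2, 2, 2, 2, 3], perm = {0, 3, 4, 1, 5, 2}
    CRD reshape: [N, C/4, b, b, H, W]  = [1, 2, 2, 2, 2, 3], perm = {0, 1, 4, 2, 5, 3}
    both orders: -> [N, C/4, H, b, W, b] -> output [N, C/4, H*b, W*b] = [1, 2, 4, 6]

DCR interleaves the block factors from the front of the channel axis, CRD from the back; the two modes differ only in where the blocksize factors sit in the first reshape and in the matching permutation.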
...
@@ -15,46 +15,49 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
     instruction_ref parse(const op_desc& opd,
                           const onnx_parser& /*parser*/,
                           const onnx_parser::node_info& info,
-                          std::vector<instruction_ref> args) const
+                          const std::vector<instruction_ref>& args) const
     {
         int axis = 1;
         if(contains(info.attributes, "axis"))
             axis = info.attributes.at("axis").i();

         auto input_lens = args[0]->get_shape().lens();
-        int n_dim       = static_cast<int>(input_lens.size());
+        auto n_dim      = input_lens.size();

-        auto sub_zero_point = args[0];
+        instruction_ref x_scale;
+        if(args[1]->get_shape().elements() != 1)
+        {
+            auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
+            x_scale         = info.add_instruction(
+                make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
+        }
+        else
+        {
+            x_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
+                                           args[1]);
+        }
+
         if(args.size() == 3)
         {
-            auto zero_point = args[2];
-            if(not(zero_point->get_shape().elements() == 1))
+            auto x_zero_point = args[2];
+            if(x_zero_point->get_shape().elements() != 1)
             {
-                axis       = tune_axis(n_dim, axis, opd.op_name);
-                zero_point = info.add_instruction(
-                    make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point);
+                auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
+                x_zero_point    = info.add_instruction(
+                    make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
+                    x_zero_point);
+            }
+            else
+            {
+                x_zero_point = info.add_instruction(
+                    make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point);
             }

-            auto zero_point_int32 = info.add_instruction(
-                make_op("convert", {{"target_type", shape::int32_type}}), zero_point);
-            auto sub_zero_point_int32 = info.add_instruction(
-                make_op("convert", {{"target_type", shape::int32_type}}), sub_zero_point);
-            sub_zero_point =
-                info.add_broadcastable_binary_op("sub", sub_zero_point_int32, zero_point_int32);
+            return info.add_instruction(
+                make_op("dequantizelinear"), args[0], x_scale, x_zero_point);
         }

-        auto dequant_input = info.add_instruction(
-            make_op("convert", {{"target_type", shape::float_type}}), sub_zero_point);
-
-        auto scale = args[1];
-        if(not(scale->get_shape().elements() == 1))
-        {
-            axis  = tune_axis(n_dim, axis, opd.op_name);
-            scale = info.add_instruction(
-                make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale);
-        }
-        return info.add_broadcastable_binary_op("mul", dequant_input, scale);
+        return info.add_instruction(make_op("dequantizelinear"), args[0], x_scale);
     }
 };
...
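The hand-built convert/sub/mul chain is replaced by the dedicated dequantizelinear op. Per the ONNX DequantizeLinear definition, both forms compute

    y = (x - x_zero_point) * x_scale

so, for example, an int8 value x = 5 with x_zero_point = 1 and x_scale = 0.25 dequantizes to (5 - 1) * 0.25 = 1.0. The parser's remaining job is only to broadcast the scale and zero point to the input shape, per-axis when they are non-scalar.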
...
@@ -2,6 +2,7 @@
 #include <migraphx/onnx/checks.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/common.hpp>
 #include <migraphx/make_op.hpp>

 namespace migraphx {
...
@@ -23,8 +24,7 @@ struct parse_expand : op_parser<parse_expand>
         std::vector<std::size_t> dims;
         arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
         auto out_lens = compute_broadcasted_lens(in_lens, dims);
-        return info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
-                                    args[0]);
+        return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
     }
 };
...
...
@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
             info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
         auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
         auto l_stride  = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
-        l_stride = info.add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}),
-                                        l_stride);
+        l_stride =
+            info.add_instruction(make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);

         auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
         auto delta    = info.add_instruction(make_op("mul"), dim_diff, l_stride);
         auto ind      = info.add_instruction(make_op("add"), l_shape_idx, delta);
...
...
@@ -42,13 +42,30 @@ struct parse_gemm : op_parser<parse_gemm>
         // swap the last two elements
         std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
-        auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
-                           : args[0];
-        auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
-                           : args[1];
+
+        auto l1       = args[0];
+        auto dot_type = l1->get_shape().type();
+
+        if(alpha != 1.0f)
+        {
+            auto alpha_literal = info.add_literal(alpha);
+            l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
+            if(l1->get_shape().type() != dot_type)
+            {
+                l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1);
+            }
+        }
+
+        l1 =
+            (transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
+        auto l2 = (transb)
+                      ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
+                      : args[1];
+
+        auto ret = info.add_instruction(make_op("dot"), l1, l2);
         if(args.size() == 3)
         {
-            if(beta != 0.f && args[2]->get_shape().elements() > 0)
+            if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
             {
                 auto out_lens   = l1->get_shape().lens();
                 out_lens.back() = l2->get_shape().lens().back();
...
@@ -56,15 +73,22 @@ struct parse_gemm : op_parser<parse_gemm>
                 auto l3_lens = l3->get_shape().lens();
                 if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                 {
-                    l3 = info.add_instruction(
-                        make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
+                    l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
+                                              args[2]);
                 }
-                return info.add_instruction(
-                    make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
+                auto beta_literal = info.add_literal(beta);
+                auto beta_l3      = info.add_broadcastable_binary_op("mul", l3, beta_literal);
+                if(beta_l3->get_shape().type() != dot_type)
+                {
+                    beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
+                                                   beta_l3);
+                }
+
+                return info.add_instruction(make_op("add"), ret, beta_l3);
             }
         }
-        return info.add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
+        return ret;
     }
 };
...
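The net computation is unchanged; the parser still realizes the ONNX Gemm definition

    Y = alpha * op_A(A) * op_B(B) + beta * C

where op_X transposes the last two axes when transA/transB is set. The difference is that alpha is now folded into A (before the transpose) and beta into C with explicit elementwise muls, converting back to the dot type when a literal promotes it, rather than passing alpha/beta attributes to the dot op.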
...
@@ -35,6 +35,8 @@ struct parse_generic_op : op_parser<parse_generic_op>
                 {"Reciprocal", "recip"},
                 {"Relu", "relu"},
                 {"Round", "round"},
+                {"Scatter", "scatter"},
+                {"ScatterElements", "scatter"},
                 {"Sigmoid", "sigmoid"},
                 {"Sign", "sign"},
                 {"Sin", "sin"},
...
@@ -47,7 +49,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
     bool needs_contiguous(const std::string& op_name) const
     {
-        return contains({"gather"}, op_name);
+        return contains({"flatten", "gather", "scatter"}, op_name);
     }

     instruction_ref parse(const op_desc& opd,
...
+#include <migraphx/instruction_ref.hpp>
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/onnx/onnx_parser.hpp>
 #include <migraphx/onnx/checks.hpp>
...
@@ -26,59 +27,41 @@ struct parse_if : op_parser<parse_if>
             MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
         }

-        migraphx::argument cond_arg = args.front()->eval();
-        // cond is not constant, need to create sub_modules
-        if(cond_arg.empty())
-        {
-            std::string then_name = info.name + "_if";
-            module_ref then_mdl   = parser.prog.create_module(then_name);
-            std::string else_name = info.name + "_else";
-            module_ref else_mdl   = parser.prog.create_module(else_name);
-
-            // parse the then sub_graph
-            parser.parse_graph(then_mdl, then_graph);
-            // parse the else sub_graph
-            parser.parse_graph(else_mdl, else_graph);
-
-            auto then_out_shapes = then_mdl->get_output_shapes();
-            auto else_out_shapes = else_mdl->get_output_shapes();
-            if(not std::equal(then_out_shapes.begin(),
-                              then_out_shapes.end(),
-                              else_out_shapes.begin(),
-                              else_out_shapes.end()))
-            {
-                MIGRAPHX_THROW("PARSE_IF: then and else subgraphs must have same output shapes!");
-            }
-
-            auto ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
-            return {ret};
-        }
-        else
-        {
-            auto* mod = info.mod;
-            // then branch
-            if(cond_arg.at<bool>())
-            {
-                parser.parse_graph(mod, then_graph);
-            }
-            // else branch
-            else
-            {
-                parser.parse_graph(mod, else_graph);
-            }
-
-            // inputs of the return instruction are the outputs of the
-            // if instruction
-            instruction_ref ret_ins = std::prev(mod->end());
-            auto outputs            = ret_ins->inputs();
-            assert(ret_ins->name() == "@return");
-            mod->remove_instruction(ret_ins);
-
-            return outputs;
-        }
+        std::string then_name = info.name + "_if";
+        module_ref then_mdl   = parser.prog.create_module(then_name);
+
+        std::string else_name = info.name + "_else";
+        module_ref else_mdl   = parser.prog.create_module(else_name);
+
+        // parse the then sub_graph
+        parser.parse_graph(then_mdl, then_graph);
+
+        // parse the else sub_graph
+        parser.parse_graph(else_mdl, else_graph);
+
+        auto then_out_shapes = then_mdl->get_output_shapes();
+        auto else_out_shapes = else_mdl->get_output_shapes();
+        if(not std::equal(then_out_shapes.begin(),
+                          then_out_shapes.end(),
+                          else_out_shapes.begin(),
+                          else_out_shapes.end()))
+        {
+            MIGRAPHX_THROW("PARSE_IF: then and else subgraphs must have same output shapes!");
+        }
+
+        auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
+        auto out_s  = if_ret->get_shape();
+        assert(out_s.type() == shape::tuple_type);
+
+        const auto& vec_shapes = out_s.sub_shapes();
+        std::vector<instruction_ref> out_inss;
+        for(std::size_t i = 0; i < vec_shapes.size(); ++i)
+        {
+            auto ret = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), if_ret);
+            out_inss.push_back(ret);
+        }
+
+        return out_inss;
     }
 };
...
...
@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
         auto img_scaled =
             info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
         auto bias_bcast = info.add_instruction(
-            migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals);
+            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", input_lens}}), bias_vals);
         return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
     }
 };
...
...
@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
         auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
         auto mean_bcast =
-            info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean);
+            info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
         auto l0       = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
         auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
         auto l1       = info.add_instruction(make_op("sub"), x, mean_bcast);

         auto epsilon_literal = info.add_literal(epsilon);
-        auto epsilon_bcast   = info.add_instruction(
-            make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal);
+        auto epsilon_bcast =
+            info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
         auto variance_bcast =
-            info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance);
+            info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance);
         auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
         auto l3 = info.add_instruction(make_op("rsqrt"), l2);
         auto l4 = info.add_instruction(make_op("mul"), l1, l3);

         auto scale_bcast =
-            info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale);
+            info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
         auto bias_bcast =
-            info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias);
+            info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
         auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
         return info.add_instruction(make_op("add"), l5, bias_bcast);
     }
...
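The instruction chain is the textbook InstanceNormalization definition, applied per sample and per channel:

    y = scale * (x - mean) * rsqrt(variance + epsilon) + bias

with mean and variance reduced over the spatial axes and scale/bias broadcast along axis 1; only the attribute spellings change in this hunk.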
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_loop : op_parser<parse_loop>
{
    std::vector<op_desc> operators() const { return {{"Loop"}}; }

    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        // default value of the max_iter_num
        int64_t max_iterations = parser.max_loop_iterations;

        // iteration input is empty
        if(args.at(0)->name() == "undefined")
        {
            shape iter_s{shape::int64_type};
            args[0] = info.add_literal(literal(iter_s, {max_iterations}));
        }
        else
        {
            auto arg_iters = args.at(0)->eval();
            if(not arg_iters.empty())
            {
                max_iterations = arg_iters.at<int64_t>();
            }
        }

        // condition input is empty
        if(args.at(1)->name() == "undefined")
        {
            shape cond_s{shape::bool_type};
            args[1] = info.add_literal(literal(cond_s, {true}));
        }

        // retrieve the subgraph
        const auto& sub_graph = info.attributes.at("body").g();
        std::string mod_name  = info.name + "_loop";
        module_ref sub_mod    = parser.prog.create_module(mod_name);

        // parse the sub_graph
        parser.parse_graph(sub_mod, sub_graph);

        auto ret = info.add_instruction(
            make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});

        auto out_s = ret->get_shape();
        assert(out_s.type() == shape::tuple_type);

        const auto& vec_shapes = out_s.sub_shapes();
        std::vector<instruction_ref> out_inss;
        for(std::size_t i = 0; i < vec_shapes.size(); ++i)
        {
            auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), ret);
            out_inss.push_back(r);
        }

        return out_inss;
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
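For reference, the control flow that the loop op presumably realizes, per the ONNX Loop spec: run the body at most max_iterations times while the condition holds, threading loop-carried values through each trip. A host-side sketch (the body function and concrete values are hypothetical, not MIGraphX API):

    #include <cstdint>
    #include <iostream>
    #include <tuple>
    #include <utility>

    // Hypothetical loop body: doubles the carried value and asks to stop
    // once it exceeds 100. Returns {keep_going, new_carried_value}.
    static std::pair<bool, int64_t> body(int64_t /*iter*/, int64_t carried)
    {
        int64_t next = carried * 2;
        return {next <= 100, next};
    }

    int main()
    {
        const int64_t max_iterations = 10; // trip-count input (or the parser's default)
        bool cond       = true;            // condition input (true when "undefined")
        int64_t carried = 1;               // loop-carried dependency
        for(int64_t iter = 0; iter < max_iterations and cond; ++iter)
            std::tie(cond, carried) = body(iter, carried);
        std::cout << carried << "\n"; // prints 128: the first value over 100 stops the loop
    }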
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/instruction.hpp>
+#include <migraphx/common.hpp>
 #include <migraphx/make_op.hpp>

 namespace migraphx {
...
@@ -57,18 +58,16 @@ struct parse_matmul : op_parser<parse_matmul>
             if(l0_lens != l0_broadcasted_lens)
             {
                 bl0 = info.add_instruction(
-                    make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0);
+                    make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
             }
             if(l1_lens != l1_broadcasted_lens)
             {
                 bl1 = info.add_instruction(
-                    make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1);
+                    make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
             }
         }
-        auto dot_res =
-            info.add_instruction(make_op(opd.op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
-        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
+        instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
+        int64_t num_axis        = static_cast<int64_t>(dot_res->get_shape().lens().size());
         if(is_a_prepended)
         {
             dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
...
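The squeezes in the MatMul parser above undo the 1-D promotion done earlier, following ONNX MatMul (NumPy matmul) rules: a 1-D A gets a 1 prepended, a 1-D B gets a 1 appended, and the corresponding axis is removed from the result. For example, A of shape {4} times B of shape {4, 5} becomes {1, 4} x {4, 5} = {1, 5}, and the squeeze on axis num_axis - 2 restores the expected {5}.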
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <algorithm>
#include <chrono>
#include <random>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_multinomial : op_parser<parse_multinomial>
{
    std::vector<op_desc> operators() const { return {{"Multinomial"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        int dtype = 6;
        if(contains(info.attributes, "dtype"))
            dtype = info.attributes.at("dtype").i();
        shape::type_t output_type = get_type(dtype);

        size_t sample_size = 1;
        if(contains(info.attributes, "sample_size"))
            sample_size = info.attributes.at("sample_size").i();

        float seed = static_cast<float>(
            std::chrono::high_resolution_clock::now().time_since_epoch().count());
        if(contains(info.attributes, "seed"))
            seed = info.attributes.at("seed").f();

        // Subtract the per-batch maximum log-probability, making the per-batch max 0
        auto maxes =
            info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]);
        auto mb_maxes = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}),
            maxes);
        auto cdf = info.add_instruction(migraphx::make_op("sub"), args[0], mb_maxes);

        // Take the element-wise exponent to get probabilities in the range (0, 1]
        cdf = info.add_instruction(migraphx::make_op("exp"), cdf);

        // Compute the cumulative distribution function
        cdf = info.add_instruction(
            migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);

        // Pre-compute the random distribution
        std::mt19937 gen(seed);
        std::uniform_real_distribution<> dis(0.0, 1.0);
        size_t batch_size = args[0]->get_shape().lens().front();
        migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}};
        std::vector<float> random_dist(batch_size * sample_size);
        std::generate(random_dist.begin(), random_dist.end(), [&]() { return dis(gen); });
        auto dist_lit = info.add_literal(migraphx::literal{dist_shape, random_dist});

        return info.add_instruction(
            migraphx::make_op("multinomial", {{"dtype", output_type}}), cdf, dist_lit);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
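The parser leaves only the sampling itself to the multinomial op: it ships an unnormalized inclusive CDF per batch row plus a batch of precomputed uniform samples. The op then presumably performs standard inverse-transform sampling; a standalone sketch of that selection step:

    #include <cstddef>
    #include <vector>

    // Given one row of an unnormalized inclusive CDF and a uniform sample
    // u in [0, 1), return the sampled class index.
    static std::size_t sample_class(const std::vector<float>& cdf, float u)
    {
        float target = u * cdf.back(); // rescale u by the total mass
        for(std::size_t i = 0; i < cdf.size(); ++i)
            if(cdf[i] > target)
                return i;
        return cdf.size() - 1; // guard against floating-point round-off
    }

Shifting by the row maximum before exp is the usual max-trick: it scales every probability by a common factor that cancels in the division by cdf.back(), while keeping exp from overflowing.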
...
@@ -9,7 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace onnx {

 template <class T>
-std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
+static std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
 {
     std::vector<std::size_t> indices;
     for(std::size_t i = 0; i < data.size(); ++i)
...
@@ -31,30 +31,35 @@ struct parse_nonzero : op_parser<parse_nonzero>
                           std::vector<instruction_ref> args) const
     {
         migraphx::argument data_arg = args.back()->eval();
-        check_arg_empty(data_arg, "PARSE_NONZERO: cannot support non-constant input!");
-
-        std::vector<std::size_t> indices;
-        data_arg.visit([&](auto val) {
-            using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
-            std::vector<val_type> vec_data;
-            vec_data.assign(val.begin(), val.end());
-            indices = nonzero_indices(vec_data);
-        });
-
-        shape in_s = args[0]->get_shape();
-        shape out_s{shape::int64_type, {in_s.lens().size(), indices.size()}};
-
-        std::vector<int64_t> out_data(out_s.elements());
-        for(std::size_t i = 0; i < indices.size(); ++i)
-        {
-            auto idx = in_s.multi(indices[i]);
-            for(std::size_t j = 0; j < in_s.lens().size(); ++j)
-            {
-                out_data[out_s.index({j, i})] = idx[j];
-            }
-        }
-
-        return info.add_literal(literal(out_s, out_data));
+        if(data_arg.empty())
+        {
+            return info.add_instruction(make_op("nonzero"), args);
+        }
+        else
+        {
+            std::vector<std::size_t> indices;
+            data_arg.visit([&](auto val) {
+                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
+                std::vector<val_type> vec_data;
+                vec_data.assign(val.begin(), val.end());
+                indices = nonzero_indices(vec_data);
+            });

+            shape in_s = args[0]->get_shape();
+            shape out_s{shape::int64_type, {in_s.lens().size(), indices.size()}};

+            std::vector<int64_t> out_data(out_s.elements());
+            for(std::size_t i = 0; i < indices.size(); ++i)
+            {
+                auto idx = in_s.multi(indices[i]);
+                for(std::size_t j = 0; j < in_s.lens().size(); ++j)
+                {
+                    out_data[out_s.index({j, i})] = idx[j];
+                }
+            }

+            return info.add_literal(literal(out_s, out_data));
+        }
     }
 };
...
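With a constant input the parser folds NonZero into a literal laid out as the ONNX spec requires: shape {rank, count}, one row per input axis. A worked example for the 2x2 input {{1, 0}, {1, 1}}, whose nonzero positions are (0,0), (1,0), (1,1):

    out_s = {int64_type, {2, 3}}
    row 0 = 0 1 1   // coordinates along dim 0
    row 1 = 0 0 1   // coordinates along dim 1

Non-constant inputs, which previously threw, now simply emit the nonzero op for the target to evaluate at run time.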
...
@@ -45,8 +45,9 @@ struct parse_onehot : op_parser<parse_onehot>
         std::vector<int64_t> perm(n_rank - 1);
         std::iota(perm.begin(), perm.end(), 0);
         perm.insert(perm.begin() + tuned_axis, n_rank - 1);
-        auto tr_out = info.add_instruction(make_op("transpose", {{"dims", perm}}), gather_out);
-        auto lens   = tr_out->get_shape().lens();
+        auto tr_out =
+            info.add_instruction(make_op("transpose", {{"permutation", perm}}), gather_out);
+        auto lens = tr_out->get_shape().lens();

         auto off_val = info.add_instruction(
             make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
...
@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot>
             make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
         auto diff = info.add_instruction(make_op("sub"), on_val, off_val);
         auto unsq_off_val =
-            info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), off_val);
+            info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), off_val);
         auto unsq_diff_val =
-            info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), diff);
+            info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), diff);
         auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
         return info.add_instruction(make_op("add"), l_mul, unsq_off_val);
     }
...
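Since tr_out holds what is presumably a 0/1 one-hot mask after the gather and transpose, the final arithmetic selects between the two user-supplied values:

    output = tr_out * (on_value - off_value) + off_value

so positions along the hot axis produce on_value and everything else off_value, with both scalars multibroadcast to the output shape.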