"vscode:/vscode.git/clone" did not exist on "dd6ec02965254291b7bf2c1a90f5eb9a5a5051d4"
Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
@@ -70,12 +70,14 @@ static literal from_repeated(shape::type_t t, const T& r)
instruction_ref onnx_parser::node_info::make_contiguous(instruction_ref ins) const
{
- if(ins->get_shape().standard())
+ auto attr = ins->get_operator().to_value();
+ std::string key = "require_std_shape";
+ if((attr.get(key, false)) or (not ins->get_shape().standard()))
{
- return ins;
+ return add_instruction(make_op("contiguous"), ins);
}
- return add_instruction(make_op("contiguous"), ins);
+ return ins;
}
instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref>& args,
@@ -85,7 +87,7 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
if(args.size() == 3)
{
auto bias_bcast = mod->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
args[2]);
return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
}
@@ -96,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
instruction_ref arg0,
instruction_ref arg1) const
{
- return add_common_op(*mod, make_op(op_name), {arg0, arg1});
+ return this->add_common_op(op_name, arg0, arg1);
}
+ instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
+                                                       std::vector<instruction_ref> inputs) const
+ {
+ return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs));
+ }
instruction_ref
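The new node_info::add_common_op forwards to the free function migraphx::add_common_op, which inserts any broadcast instructions needed to make the input shapes compatible before adding the named operator. A minimal call-site sketch, assuming the variadic overload that parse_clip uses later in this diff (its declaration is not shown in this hunk):

// hypothetical call site: broadcasts arg0 and arg1 to a common shape, then applies "mul"
auto prod = info.add_common_op("mul", arg0, arg1);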
@@ -224,28 +232,42 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
+ std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
{
- instructions[f.name()] = mod->add_literal(parse_tensor(f));
+ // backup instructions in parent mod
+ mod_insts[f.name()] = mod->add_literal(parse_tensor(f));
}
for(auto&& input : graph.input())
{
const std::string& name = input.name();
// input not in initializer_data, so it is a real input
- if(!contains(instructions, name))
+ if(!contains(mod_insts, name))
{
+ // The ONNX specification does not say how to handle the case where a
+ // nested subgraph contains a parameter whose name already exists in its
+ // parent graph. The current MIGraphX implementation throws an exception
+ // for that.
+ if(contains(instructions, name))
+ {
+ MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
+                "\" existing in parent graph!");
+ }
std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0)
{
dims = map_input_dims.at(name);
}
- shape s = parse_type(input.type(), dims);
- instructions[name] = mod->add_parameter(name, s);
+ shape s = parse_type(input.type(), dims);
+ mod_insts[name] = mod->add_parameter(name, s);
}
}
+ std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));
for(auto&& node : graph.node())
{
std::vector<instruction_ref> args;
@@ -309,6 +331,9 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
// add the return instruction
mod->add_return(output_ins);
+ // remove instructions added in this mod
+ erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
}
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
@@ -363,8 +388,7 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::UINT64:
return create_literal(shape::uint64_type, dims, t.uint64_data());
- case onnx::TensorProto::FLOAT16:
- {
+ case onnx::TensorProto::FLOAT16: {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
std::vector<half> data_half;
std::transform(data_uint16.begin(),
@@ -434,7 +458,8 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type;
case 12: return shape::uint32_type;
case 13: return shape::uint64_type;
- default: { MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
+ default: {
+ MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
}
}
}
......
@@ -3,6 +3,7 @@
#include <migraphx/pad_calc.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>
+ #include <migraphx/op/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -94,7 +95,7 @@ void tune_padding_size(const value& v,
std::vector<int64_t>& s_start)
{
// for max pooling, or when count_include_pad is 1, no change is required
- if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1)
+ if(v.at("mode").to<op::pooling_mode>() == op::pooling_mode::max or count_include_pad == 1)
{
return;
}
......
@@ -36,7 +36,8 @@ struct parse_binary_op : op_parser<parse_binary_op>
{
uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}),
make_op("broadcast",
{{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op(opd.op_name), args[0], l);
}
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/float_equal.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_celu : op_parser<parse_celu>
{
std::vector<op_desc> operators() const { return {{"Celu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float alpha = 1.0;
if(contains(info.attributes, "alpha"))
{
alpha = info.attributes.at("alpha").f();
}
if(float_equal(alpha, 0.0f))
{
MIGRAPHX_THROW("CELU: alpha is zero (division by zero)");
}
auto input_lens = args[0]->get_shape().lens();
auto input_type = args[0]->get_shape().type();
if(input_type != migraphx::shape::float_type)
{
MIGRAPHX_THROW("CELU: input tensor not float type");
}
auto zero_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
auto one_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
auto alpha_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
auto linear_part = info.add_instruction(migraphx::make_op("max"), zero_lit, args[0]);
auto divi = info.add_instruction(migraphx::make_op("div"), args[0], alpha_lit);
auto expo = info.add_instruction(migraphx::make_op("exp"), divi);
auto sub = info.add_instruction(migraphx::make_op("sub"), expo, one_lit);
auto mul = info.add_instruction(migraphx::make_op("mul"), alpha_lit, sub);
auto exp_part = info.add_instruction(migraphx::make_op("min"), zero_lit, mul);
return info.add_instruction(migraphx::make_op("add"), linear_part, exp_part);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
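The graph built above is the ONNX definition of CELU, celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)). As a sanity check of the formula itself, a minimal standalone sketch in plain C++ (independent of MIGraphX):

#include <cmath>
#include <cstdio>

// Reference CELU, mirroring the max/min decomposition used by parse_celu.
float celu_ref(float x, float alpha)
{
    float linear_part = std::fmax(0.0f, x);
    float exp_part    = std::fmin(0.0f, alpha * (std::exp(x / alpha) - 1.0f));
    return linear_part + exp_part;
}

int main()
{
    for(float x : {-2.0f, -0.5f, 0.0f, 1.5f})
        std::printf("celu(%g) = %g\n", x, celu_ref(x, 1.0f));
    return 0;
}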
@@ -16,7 +16,6 @@ struct parse_clip : op_parser<parse_clip>
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
- auto input_lens = args[0]->get_shape().lens();
instruction_ref min_arg;
instruction_ref max_arg;
bool min_used = false;
@@ -45,29 +44,17 @@ struct parse_clip : op_parser<parse_clip>
max_used = true;
}
- if(min_used)
- {
- min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
-                                min_arg);
- }
- if(max_used)
- {
- max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
-                                max_arg);
- }
if(min_used and max_used)
{
- return info.add_instruction(make_op("clip"), args[0], min_arg, max_arg);
+ return info.add_common_op("clip", args[0], min_arg, max_arg);
}
else if(max_used)
{
- return info.add_instruction(make_op("min"), args[0], max_arg);
+ return info.add_broadcastable_binary_op("min", args[0], max_arg);
}
else if(min_used)
{
- return info.add_instruction(make_op("max"), args[0], min_arg);
+ return info.add_broadcastable_binary_op("max", args[0], min_arg);
}
else
{
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_depthtospace : op_parser<parse_depthtospace>
{
std::vector<op_desc> operators() const { return {{"DepthToSpace"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto s = args[0]->get_shape();
// mode attribute of DepthToSpace
auto mode = std::string("DCR");
if(contains(info.attributes, "mode"))
{
mode = info.attributes.at("mode").s(); // DCR or CRD?
}
// blocksize attribute of DepthToSpace
int blocksize = 0;
if(contains(info.attributes, "blocksize"))
{
blocksize = info.attributes.at("blocksize").i();
}
if(blocksize < 1)
{
MIGRAPHX_THROW("DepthToSpace: blocksize is less than 1");
}
// calculate dimensions
auto lens1 = s.lens();
auto lens2 = s.lens();
unsigned long divisor = std::pow(blocksize, 2);
if((lens2[1] % divisor) == 0)
lens2[1] = lens2[1] / divisor;
else
MIGRAPHX_THROW("DepthToSpace: div by blocksize quotient not int ");
lens1.push_back(lens1[2]);
lens1.push_back(lens1[3]);
lens2[2] = lens2[2] * blocksize;
lens2[3] = lens2[3] * blocksize;
lens1[2] = blocksize;
std::vector<int64_t> perm;
if(mode == "DCR")
{
lens1[3] = lens1[1] / divisor;
lens1[1] = blocksize;
perm = {0, 3, 4, 1, 5, 2};
}
else if(mode == "CRD")
{
lens1[1] = lens1[1] / divisor;
lens1[3] = blocksize;
perm = {0, 1, 4, 2, 5, 3};
}
else
MIGRAPHX_THROW("DepthToSpace: mode attribute cannot be read.");
auto temp1 = info.add_instruction(make_op("reshape", {{"dims", lens1}}), args[0]);
auto temp2 = info.add_instruction(make_op("transpose", {{"permutation", perm}}), temp1);
return info.add_instruction(make_op("reshape", {{"dims", lens2}}),
info.make_contiguous(temp2));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
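A worked example of the shape bookkeeping above, for an NCHW input of (1, 8, 2, 3) with blocksize 2 (so divisor = 4): in DCR mode lens1 becomes (1, 2, 2, 2, 2, 3), the permutation {0, 3, 4, 1, 5, 2} reorders this to (1, 2, 2, 2, 3, 2), and the final reshape to lens2 yields (1, 2, 4, 6), i.e. (N, C / blocksize², H · blocksize, W · blocksize).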
@@ -29,11 +29,11 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
}
else
{
x_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
x_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
@@ -44,13 +44,13 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
x_zero_point);
}
else
{
x_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), x_zero_point);
make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point);
}
return info.add_instruction(
......
@@ -24,8 +24,7 @@ struct parse_expand : op_parser<parse_expand>
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
args[0]);
return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
}
};
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <cstdlib>
#include <sstream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_eyelike : op_parser<parse_eyelike>
{
std::vector<op_desc> operators() const { return {{"EyeLike"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto input_shape = args[0]->get_shape();
auto input_lens = input_shape.lens();
if(input_lens.size() != 2)
{
MIGRAPHX_THROW("EYELIKE: tensor input not of rank 2");
}
std::ptrdiff_t num_rows = input_lens.front();
std::ptrdiff_t num_cols = input_lens.back();
shape::type_t output_type = args[0]->get_shape().type();
if(contains(info.attributes, "dtype"))
{
output_type = get_type(info.attributes.at("dtype").i());
}
std::ptrdiff_t k = 0;
if(contains(info.attributes, "k"))
{
k = info.attributes.at("k").i();
}
if(k >= 0)
{
if(k >= num_cols)
{
std::ostringstream oss;
oss << "EYELIKE: positive k out of bounds, k = " << k << " num_cols = " << num_cols;
MIGRAPHX_THROW(oss.str());
}
}
else
{
if(std::abs(k) >= num_rows)
{
std::ostringstream oss;
oss << "EYELIKE: negative k out of bounds, k = " << k << " num_rows = " << num_cols;
MIGRAPHX_THROW(oss.str());
}
}
std::vector<char> eyelike_mat(num_rows * num_cols, 0);
for(std::ptrdiff_t i = 0; i < num_rows; ++i)
{
auto idx = i + k;
if(idx < num_cols and idx >= 0)
eyelike_mat[(num_cols + 1) * i + k] = char{1};
}
return info.add_literal(
migraphx::literal{migraphx::shape{output_type, input_lens}, eyelike_mat});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
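For example, given a (3, 4) input and k = 1, the literal built above places ones at (0, 1), (1, 2), and (2, 3): row i receives a one at column i + k whenever that column is in range, matching numpy.eye(3, 4, k=1).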
@@ -39,7 +39,7 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
int tuned_axis = tune_axis(n_rank, axis, opd.op_name);
auto axis_stride = data_s.strides()[tuned_axis];
- int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
+ int64_t data_elem_num = data_s.elements();
// reshape the input data as one dimension and used as input data
// to the gather operator
arg_data = info.add_instruction(make_op("reshape", {{"dims", {data_elem_num}}}), arg_data);
@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = info.add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}),
l_stride);
l_stride =
info.add_instruction(make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta);
......
@@ -55,13 +55,17 @@ struct parse_gemm : op_parser<parse_gemm>
}
}
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1;
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
l1 =
(transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot"), l1, l2);
if(args.size() == 3)
{
- if(beta != 0.0f && args[2]->get_shape().elements() > 0)
+ if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
{
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
@@ -69,8 +73,8 @@ struct parse_gemm : op_parser<parse_gemm>
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
- l3 = info.add_instruction(
-     make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
+ l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
+                           args[2]);
}
auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
@@ -80,12 +84,11 @@ struct parse_gemm : op_parser<parse_gemm>
beta_l3);
}
- return info.add_instruction(
-     make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, beta_l3);
+ return info.add_instruction(make_op("add"), ret, beta_l3);
}
}
return info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2);
return ret;
}
};
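Per the ONNX specification, Gemm computes Y = alpha · op(A) · op(B) + beta · C. After this change the beta term is realized as an explicit multibroadcast, mul, and add on top of the plain dot result instead of through dot's alpha and beta attributes; any alpha scaling is handled outside the lines shown here.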
......
@@ -10,6 +10,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
std::vector<op_desc> operators() const
{
+ // clang-format off
return {{"Abs", "abs"},
{"Acos", "acos"},
{"Acosh", "acosh"},
@@ -27,16 +28,17 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Flatten", "flatten"},
{"Floor", "floor"},
{"Gather", "gather"},
{"GatherND", "gathernd"},
{"Identity", "identity"},
{"IsNaN", "isnan"},
{"LeakyRelu", "leaky_relu"},
{"Log", "log"},
{"LRN", "lrn"},
{"Neg", "neg"},
{"NonMaxSuppression", "nonmaxsuppression"},
{"Reciprocal", "recip"},
{"Relu", "relu"},
{"Round", "round"},
{"Scatter", "scatter"},
{"ScatterElements", "scatter"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"Sin", "sin"},
@@ -45,11 +47,12 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Tan", "tan"},
{"Tanh", "tanh"},
{"Not", "not"}};
+ // clang-format on
}
bool needs_contiguous(const std::string& op_name) const
{
return contains({"flatten", "gather", "scatter"}, op_name);
return contains({"flatten", "gather", "nonmaxsuppression", "scatter"}, op_name);
}
instruction_ref parse(const op_desc& opd,
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_greaterorequal : op_parser<parse_greaterorequal>
{
std::vector<op_desc> operators() const { return {{"GreaterOrEqual"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto in_res = info.add_broadcastable_binary_op("less", args[0], args[1]);
if(in_res->get_shape().type() != shape::bool_type)
{
in_res = info.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}),
in_res);
}
return info.add_instruction(make_op("not"), in_res);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
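This parser relies on the element-wise identity a >= b ⟺ not(a < b); the intermediate convert only fires when the broadcast less result is not already bool_type, so the final not always operates on a boolean tensor.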
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_hardsigmoid : op_parser<parse_hardsigmoid>
{
std::vector<op_desc> operators() const { return {{"HardSigmoid"}, {"HardSwish"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float alpha = 0.2;
float beta = 0.5;
if(opd.onnx_name == "HardSwish")
{
alpha = 1.0 / 6.0;
}
else
{
if(contains(info.attributes, "alpha"))
alpha = info.attributes.at("alpha").f();
if(contains(info.attributes, "beta"))
beta = info.attributes.at("beta").f();
}
auto input_lens = args[0]->get_shape().lens();
auto input_type = args[0]->get_shape().type();
auto mb_alpha = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {alpha}}));
auto mb_beta = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {beta}}));
auto mb_zero = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0}}));
auto mb_one = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1}}));
auto mul = info.add_instruction(migraphx::make_op("mul"), mb_alpha, args[0]);
auto add = info.add_instruction(migraphx::make_op("add"), mb_beta, mul);
auto hardsigmoid = info.add_instruction(migraphx::make_op("clip"), add, mb_zero, mb_one);
if(opd.onnx_name == "HardSwish")
return info.add_instruction(migraphx::make_op("mul"), args[0], hardsigmoid);
return hardsigmoid;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
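Both operators reduce to hardsigmoid(x) = clip(alpha · x + beta, 0, 1), with HardSwish defined as x · hardsigmoid(x) at alpha = 1/6, beta = 0.5. A standalone sketch of the same reference math in plain C++ (independent of MIGraphX):

#include <algorithm>
#include <cstdio>

// Reference HardSigmoid/HardSwish, mirroring the clip-based graph above.
float hardsigmoid_ref(float x, float alpha = 0.2f, float beta = 0.5f)
{
    return std::clamp(alpha * x + beta, 0.0f, 1.0f); // requires C++17
}

float hardswish_ref(float x) { return x * hardsigmoid_ref(x, 1.0f / 6.0f, 0.5f); }

int main()
{
    for(float x : {-4.0f, -1.0f, 0.0f, 2.0f, 4.0f})
        std::printf("x = %g: hardsigmoid = %g, hardswish = %g\n",
                    x, hardsigmoid_ref(x), hardswish_ref(x));
    return 0;
}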
@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
auto img_scaled =
info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", input_lens}}), bias_vals);
return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
}
};
......
@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean);
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(epsilon);
- auto epsilon_bcast = info.add_instruction(
-     make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal);
+ auto epsilon_bcast =
+     info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance);
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale);
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
;
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias);
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast);
}
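Read end to end, this sequence implements y = scale · (x − mean) / sqrt(variance + ε) + bias, with mean and variance reduced over the spatial axes and every per-channel term broadcast back to the input shape before the element-wise ops.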
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_loop : op_parser<parse_loop>
{
std::vector<op_desc> operators() const { return {{"Loop"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/,
onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
// default value of the max_iter_num
int64_t max_iterations = parser.max_loop_iterations;
// iteration input is empty
if(args.at(0)->name() == "undefined")
{
shape iter_s{shape::int64_type};
args[0] = info.add_literal(literal(iter_s, {max_iterations}));
}
else
{
auto arg_iters = args.at(0)->eval();
if(not arg_iters.empty())
{
max_iterations = arg_iters.at<int64_t>();
}
}
// condition input is empty
if(args.at(1)->name() == "undefined")
{
shape cond_s{shape::bool_type};
args[1] = info.add_literal(literal(cond_s, {true}));
}
// retrieve the subgraph
const auto& sub_graph = info.attributes.at("body").g();
std::string mod_name = info.name + "_loop";
module_ref sub_mod = parser.prog.create_module(mod_name);
// parse the sub_graph
parser.parse_graph(sub_mod, sub_graph);
auto ret = info.add_instruction(
make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
auto out_s = ret->get_shape();
assert(out_s.type() == shape::tuple_type);
const auto& vec_shapes = out_s.sub_shapes();
std::vector<instruction_ref> out_inss;
for(std::size_t i = 0; i < vec_shapes.size(); ++i)
{
auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), ret);
out_inss.push_back(r);
}
return out_inss;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
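Note the defaults here: an absent trip-count input falls back to parser.max_loop_iterations, and an absent condition input becomes a constant true literal. Because the loop op returns a tuple, each element is unpacked with get_tuple_elem so downstream nodes consume ordinary instructions.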
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for LpNormalization ONNX operator.
/*!
Normalizes a tensor by the L1 or L2 norms along a given axis.
Norms that evaluate to 0 are changed to 1 to prevent division by zero.
*/
struct parse_lpnormalization : op_parser<parse_lpnormalization>
{
std::vector<op_desc> operators() const { return {{"LpNormalization"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int p = 2;
if(contains(info.attributes, "p"))
{
p = info.attributes.at("p").i();
}
if(p != 1 and p != 2)
{
MIGRAPHX_THROW("LPNORMALIZATION: only L1 and L2 norm supported");
}
auto input = args.front();
auto input_shape = input->get_shape();
const auto& input_lens = input_shape.lens();
auto input_type = input_shape.type();
std::ptrdiff_t num_axes = input_lens.size();
std::ptrdiff_t axis = -1;
if(contains(info.attributes, "axis"))
{
axis = info.attributes.at("axis").i();
if(axis < -num_axes or axis >= num_axes)
{
// handled in normalize_attributes but throwing here might be clearer
MIGRAPHX_THROW("LPNORMALIZATION: selected axis out of bounds");
}
}
migraphx::instruction_ref p_val;
if(p == 1)
{
p_val = info.add_instruction(migraphx::make_op("abs"), input);
}
else
{
p_val = info.add_instruction(migraphx::make_op("mul"), input, input);
}
// need to check for zeros from lp norm to prevent division by zero
// change them to 1 for the element-wise division
auto norms =
info.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {axis}}}), p_val);
if(p == 2)
{
norms = info.add_instruction(migraphx::make_op("sqrt"), norms);
}
// broadcast the norms back to the input shape; unidirectional broadcast cannot take a negative axis, so multibroadcast is used
norms = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), norms);
auto zero_mb = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {0.}}));
auto one_mb = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
info.add_literal(migraphx::literal{migraphx::shape{input_type}, {1.}}));
auto is_zero = info.add_instruction(migraphx::make_op("equal"), norms, zero_mb);
auto norms_zeros_to_one =
info.add_instruction(migraphx::make_op("where"), is_zero, one_mb, norms);
return info.add_instruction(migraphx::make_op("div"), input, norms_zeros_to_one);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
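In effect the parser computes y = x / ||x||_p along the chosen axis, where the norm is sum(|x|) for p = 1 or sqrt(sum(x²)) for p = 2; the equal/where pair rewrites any zero norm to 1 beforehand, so the final div can never divide by zero.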
@@ -58,18 +58,16 @@ struct parse_matmul : op_parser<parse_matmul>
if(l0_lens != l0_broadcasted_lens)
{
bl0 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0);
make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
}
if(l1_lens != l1_broadcasted_lens)
{
bl1 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1);
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
}
}
- auto dot_res =
-     info.add_instruction(make_op(opd.op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
- int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
+ instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
+ int64_t num_axis        = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
{
dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <numeric>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_mean : op_parser<parse_mean>
{
const std::set<shape::type_t> float_types = {
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"Mean"}}; }
/// Calculates the element-wise mean of n>=1 input tensors
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto num_data = args.size();
if(num_data == 1)
return args[0];
auto divisor = info.add_literal(
migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
if(contains(float_types, args[0]->get_shape().type()))
{
return std::accumulate(args.begin() + 1,
args.end(),
info.add_broadcastable_binary_op("div", args[0], divisor),
[&](auto mean, auto data_i) {
// Pre-divide each tensor element-wise by n to reduce risk of
// overflow during summation
auto div =
info.add_broadcastable_binary_op("div", data_i, divisor);
return info.add_broadcastable_binary_op("add", mean, div);
});
}
else
{
// Compute sum before division for integral types
auto sum = std::accumulate(
args.begin() + 1, args.end(), args[0], [&](auto accum, auto data_i) {
return info.add_broadcastable_binary_op("add", accum, data_i);
});
return info.add_broadcastable_binary_op("div", sum, divisor);
}
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
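The pre-division in the floating-point branch matters because summing first can overflow to infinity even when the true mean is representable. A minimal sketch of the effect in plain C++ (independent of MIGraphX):

#include <cstdio>

int main()
{
    float a = 3.0e38f, b = 3.0e38f, c = 3.0e38f; // each close to FLT_MAX
    float sum_first = (a + b + c) / 3.0f;             // overflows to inf before the divide
    float div_first = a / 3.0f + b / 3.0f + c / 3.0f; // stays finite, ~3.0e38
    std::printf("sum-first mean: %g, divide-first mean: %g\n", sum_first, div_first);
    return 0;
}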