"tasks/zeroshot_gpt/detokenizer.py" did not exist on "39af7956439a39f27a4891eb2d2e7f2586afecbb"
Commit 4a39a0f7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into add-conv_bn_add-test

parents 5564172e bb827865
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <iostream>
#include <fstream>
#include <unordered_map>
......@@ -20,6 +21,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value;
parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
if(options.print_program_on_error)
{
......@@ -57,5 +59,7 @@ program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options
return parse_onnx_from(options, data, size);
}
std::vector<std::string> get_onnx_operators() { return onnx::get_op_parsers(); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -6,6 +6,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/common.hpp>
#include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/file_buffer.hpp>
......@@ -84,73 +85,18 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
if(args.size() == 3)
{
auto bias_bcast = mod->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}),
make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
args[2]);
return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
}
return curr_ins;
}
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                  std::vector<std::size_t> s1)
{
    // Compute the numpy-style broadcasted output shape of two shapes:
    // trailing dimensions are aligned, and each aligned pair must either
    // match or contain a 1 (which broadcasts up to the other value).
    //
    // Example: s0 = (3,2,4,5), s1 = (2,1,1)  ->  (3,2,4,5)
    // Example: s0 = (3,2,1,5), s1 = (2,7,5)  ->  (3,2,7,5)
    if(s0.size() > s1.size())
    {
        s0.swap(s1);
    }

    // s0 is now the shorter (or equal-length) shape; align it against the
    // tail of s1. Leading dimensions of s1 pass through unchanged.
    std::vector<std::size_t> result = s1;
    const auto offset               = s1.size() - s0.size();
    for(std::size_t i = 0; i < s0.size(); ++i)
    {
        const auto a = s0[i];
        const auto b = s1[i + offset];
        if(a != b and a != 1 and b != 1)
        {
            MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
                           to_string_range(s1) + "} mismatch!");
        }
        result[i + offset] = std::max(a, b);
    }
    return result;
}
instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0,
instruction_ref arg1) const
{
if(arg0->get_shape().lens() != arg1->get_shape().lens())
{
// Get lengths for both arguments
auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1);
auto l0 = arg0;
if(arg0->get_shape().lens() != out_lens)
l0 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg0);
auto l1 = arg1;
if(arg1->get_shape().lens() != out_lens)
l1 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg1);
return add_instruction(make_op(op_name), l0, l1);
}
else
{
return add_instruction(make_op(op_name), {arg0, arg1});
}
return add_common_op(*mod, make_op(op_name), {arg0, arg1});
}
instruction_ref
......@@ -278,28 +224,42 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{
std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
{
instructions[f.name()] = mod->add_literal(parse_tensor(f));
// backup instructions in parent mod
mod_insts[f.name()] = mod->add_literal(parse_tensor(f));
}
for(auto&& input : graph.input())
{
const std::string& name = input.name();
// input not in initializer_data, so it is a real input
if(!contains(instructions, name))
if(!contains(mod_insts, name))
{
// ONNX specification does not specify hwo to deal with the
// scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that.
if(contains(instructions, name))
{
MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
"\" existing in parent graph!");
}
std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0)
{
dims = map_input_dims.at(name);
}
shape s = parse_type(input.type(), dims);
instructions[name] = mod->add_parameter(name, s);
shape s = parse_type(input.type(), dims);
mod_insts[name] = mod->add_parameter(name, s);
}
}
std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));
for(auto&& node : graph.node())
{
std::vector<instruction_ref> args;
......@@ -363,6 +323,9 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
// add the return instuction
mod->add_return(output_ins);
// remove instructions added in this mod
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
}
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
......
......@@ -126,7 +126,7 @@ void check_asym_padding(const onnx_parser::node_info& info,
auto left_pad_it = padding.begin();
auto right_pad_it = left_pad_it + pad_ndims;
if(is_asym_padding(padding) or count_include_pad == 1)
if(count_include_pad == 1)
{
std::vector<int64_t> asym_pads{0, 0, 0, 0}; // don't pad N and C
// add left pads
......@@ -134,10 +134,19 @@ void check_asym_padding(const onnx_parser::node_info& info,
// add right pads
asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
ins = info.add_instruction(make_op("pad", {{"pads", asym_pads}, {"value", pad_val}}), ins);
}
else
{
v["padding"] = std::vector<size_t>(left_pad_it, right_pad_it);
std::vector<size_t> new_padding(padding.size());
// subtract asym padding originally found from parsing the operator
std::transform(padding.begin(),
left_pad_it,
asym_pads.begin() + 2,
new_padding.begin(),
std::minus<size_t>());
std::transform(right_pad_it,
padding.end(),
asym_pads.begin() + pad_ndims + 4,
new_padding.begin() + pad_ndims,
std::minus<size_t>());
v["padding"] = new_padding;
}
}
......
......@@ -36,7 +36,8 @@ struct parse_binary_op : op_parser<parse_binary_op>
{
uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}),
make_op("broadcast",
{{"axis", axis}, {"out_lens", args[0]->get_shape().lens()}}),
args[1]);
return info.add_instruction(make_op(opd.op_name), args[0], l);
}
......
......@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip>
if(min_used)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg);
}
......
......@@ -73,7 +73,7 @@ struct parse_convolution : op_parser<parse_convolution>
values["padding_mode"] = to_value(op::padding_mode_t::same);
}
}
check_asym_padding(info, l0, padding, values);
values["padding"] = std::vector<size_t>(padding.begin(), padding.end());
if(contains(info.attributes, "group"))
{
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_depthtospace : op_parser<parse_depthtospace>
{
    std::vector<op_desc> operators() const { return {{"DepthToSpace"}}; }

    // Parse the ONNX DepthToSpace operator: rearrange data from the channel
    // dimension into spatial blocks. Implemented as reshape -> transpose ->
    // reshape, supporting both "DCR" (depth-column-row, the default) and
    // "CRD" (column-row-depth) modes.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto s = args[0]->get_shape();
        // mode attribute of DepthToSpace ("DCR" when absent)
        auto mode = std::string("DCR");
        if(contains(info.attributes, "mode"))
        {
            mode = info.attributes.at("mode").s(); // DCR or CRD?
        }
        // blocksize attribute of DepthToSpace; required and must be >= 1.
        // Keep the attribute's native int64 width instead of narrowing to int.
        std::int64_t blocksize = 0;
        if(contains(info.attributes, "blocksize"))
        {
            blocksize = info.attributes.at("blocksize").i();
        }
        if(blocksize < 1)
        {
            MIGRAPHX_THROW("DepthToSpace: blocksize is less than 1");
        }
        // calculate dimensions; channels must be divisible by blocksize^2.
        // Use an exact integer square rather than std::pow, which computes in
        // floating point and can round for large values.
        auto lens1 = s.lens();
        auto lens2 = s.lens();
        const auto divisor =
            static_cast<std::size_t>(blocksize) * static_cast<std::size_t>(blocksize);
        if((lens2[1] % divisor) == 0)
            lens2[1] = lens2[1] / divisor;
        else
            MIGRAPHX_THROW("DepthToSpace: div by blocksize quotient not int ");
        // Build the expanded 6-D shape (lens1) and the final output shape
        // (lens2 = {N, C/b^2, H*b, W*b}).
        lens1.push_back(lens1[2]);
        lens1.push_back(lens1[3]);
        lens2[2] = lens2[2] * blocksize;
        lens2[3] = lens2[3] * blocksize;
        lens1[2] = blocksize;
        std::vector<int64_t> perm;
        if(mode == "DCR")
        {
            // Expanded shape is {N, b, b, C/b^2, H, W}
            lens1[3] = lens1[1] / divisor;
            lens1[1] = blocksize;
            perm     = {0, 3, 4, 1, 5, 2};
        }
        else if(mode == "CRD")
        {
            // Expanded shape is {N, C/b^2, b, b, H, W}
            lens1[1] = lens1[1] / divisor;
            lens1[3] = blocksize;
            perm     = {0, 1, 4, 2, 5, 3};
        }
        else
            MIGRAPHX_THROW("DepthToSpace: mode attribute cannot be read.");
        // reshape -> transpose -> (contiguous) -> reshape realizes the
        // block rearrangement.
        auto temp1 = info.add_instruction(make_op("reshape", {{"dims", lens1}}), args[0]);
        auto temp2 = info.add_instruction(make_op("transpose", {{"permutation", perm}}), temp1);
        return info.add_instruction(make_op("reshape", {{"dims", lens2}}),
                                    info.make_contiguous(temp2));
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -15,46 +15,49 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
const std::vector<instruction_ref>& args) const
{
int axis = 1;
if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i();
auto input_lens = args[0]->get_shape().lens();
int n_dim = static_cast<int>(input_lens.size());
auto n_dim = input_lens.size();
auto sub_zero_point = args[0];
instruction_ref x_scale;
if(args[1]->get_shape().elements() != 1)
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
}
else
{
x_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]);
}
if(args.size() == 3)
{
auto zero_point = args[2];
if(not(zero_point->get_shape().elements() == 1))
auto x_zero_point = args[2];
if(x_zero_point->get_shape().elements() != 1)
{
axis = tune_axis(n_dim, axis, opd.op_name);
zero_point = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point);
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
x_zero_point);
}
else
{
x_zero_point = info.add_instruction(
make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point);
}
auto zero_point_int32 = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), zero_point);
auto sub_zero_point_int32 = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), sub_zero_point);
sub_zero_point =
info.add_broadcastable_binary_op("sub", sub_zero_point_int32, zero_point_int32);
return info.add_instruction(
make_op("dequantizelinear"), args[0], x_scale, x_zero_point);
}
auto dequant_input = info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), sub_zero_point);
auto scale = args[1];
if(not(scale->get_shape().elements() == 1))
{
axis = tune_axis(n_dim, axis, opd.op_name);
scale = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale);
}
return info.add_broadcastable_binary_op("mul", dequant_input, scale);
return info.add_instruction(make_op("dequantizelinear"), args[0], x_scale);
}
};
......
......@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
......@@ -23,8 +24,7 @@ struct parse_expand : op_parser<parse_expand>
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
args[0]);
return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
}
};
......
......@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = info.add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}),
l_stride);
l_stride =
info.add_instruction(make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta);
......
......@@ -42,13 +42,30 @@ struct parse_gemm : op_parser<parse_gemm>
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
: args[0];
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
auto l1 = args[0];
auto dot_type = l1->get_shape().type();
if(alpha != 1.0f)
{
auto alpha_literal = info.add_literal(alpha);
l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
if(l1->get_shape().type() != dot_type)
{
l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1);
}
}
l1 =
(transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot"), l1, l2);
if(args.size() == 3)
{
if(beta != 0.f && args[2]->get_shape().elements() > 0)
if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
{
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
......@@ -56,15 +73,22 @@ struct parse_gemm : op_parser<parse_gemm>
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[2]);
}
return info.add_instruction(
make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
if(beta_l3->get_shape().type() != dot_type)
{
beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_l3);
}
return info.add_instruction(make_op("add"), ret, beta_l3);
}
}
return info.add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
return ret;
}
};
......
......@@ -35,6 +35,8 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Reciprocal", "recip"},
{"Relu", "relu"},
{"Round", "round"},
{"Scatter", "scatter"},
{"ScatterElements", "scatter"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"Sin", "sin"},
......@@ -47,7 +49,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
bool needs_contiguous(const std::string& op_name) const
{
return contains({"gather"}, op_name);
return contains({"flatten", "gather", "scatter"}, op_name);
}
instruction_ref parse(const op_desc& opd,
......
#include <migraphx/instruction_ref.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
......@@ -26,59 +27,41 @@ struct parse_if : op_parser<parse_if>
MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
}
migraphx::argument cond_arg = args.front()->eval();
// cond is not constant, need to create sub_modules
if(cond_arg.empty())
{
std::string then_name = info.name + "_if";
module_ref then_mdl = parser.prog.create_module(then_name);
std::string else_name = info.name + "_else";
module_ref else_mdl = parser.prog.create_module(else_name);
std::string then_name = info.name + "_if";
module_ref then_mdl = parser.prog.create_module(then_name);
// parse the then sub_graph
parser.parse_graph(then_mdl, then_graph);
std::string else_name = info.name + "_else";
module_ref else_mdl = parser.prog.create_module(else_name);
// parse_the else sub_graph
parser.parse_graph(else_mdl, else_graph);
// parse the then sub_graph
parser.parse_graph(then_mdl, then_graph);
auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes();
if(not std::equal(then_out_shapes.begin(),
then_out_shapes.end(),
else_out_shapes.begin(),
else_out_shapes.end()))
{
MIGRAPHX_THROW("PARSE_IF: then and else sub_grahps must have same output shapes!");
}
// parse_the else sub_graph
parser.parse_graph(else_mdl, else_graph);
auto ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
return {ret};
}
else
auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes();
if(not std::equal(then_out_shapes.begin(),
then_out_shapes.end(),
else_out_shapes.begin(),
else_out_shapes.end()))
{
auto* mod = info.mod;
// then branch
if(cond_arg.at<bool>())
{
parser.parse_graph(mod, then_graph);
}
// else branch
else
{
parser.parse_graph(mod, else_graph);
}
MIGRAPHX_THROW("PARSE_IF: then and else sub_grahps must have same output shapes!");
}
// inputs of the return instruction are that of the output of the
// if instruction
instruction_ref ret_ins = std::prev(mod->end());
auto outputs = ret_ins->inputs();
assert(ret_ins->name() == "@return");
mod->remove_instruction(ret_ins);
auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
auto out_s = if_ret->get_shape();
assert(out_s.type() == shape::tuple_type);
return outputs;
const auto& vec_shapes = out_s.sub_shapes();
std::vector<instruction_ref> out_inss;
for(std::size_t i = 0; i < vec_shapes.size(); ++i)
{
auto ret = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), if_ret);
out_inss.push_back(ret);
}
return out_inss;
}
};
......
......@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
auto img_scaled =
info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals);
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", input_lens}}), bias_vals);
return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
}
};
......
......@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean);
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(epsilon);
auto epsilon_bcast = info.add_instruction(
make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal);
auto epsilon_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance);
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale);
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
;
auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias);
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast);
}
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_loop : op_parser<parse_loop>
{
    std::vector<op_desc> operators() const { return {{"Loop"}}; }

    // Parse the ONNX Loop operator: parses the "body" attribute into its own
    // module, emits a "loop" instruction, and unpacks the resulting tuple
    // into one instruction per output.
    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        // Trip-count limit defaults to the parser-wide setting.
        int64_t iter_limit = parser.max_loop_iterations;

        if(args.at(0)->name() == "undefined")
        {
            // Missing trip-count input: substitute the default limit.
            shape iter_s{shape::int64_type};
            args[0] = info.add_literal(literal(iter_s, {iter_limit}));
        }
        else
        {
            // When the trip count is a compile-time constant, use its value.
            auto iters_arg = args.at(0)->eval();
            if(not iters_arg.empty())
            {
                iter_limit = iters_arg.at<int64_t>();
            }
        }

        // Missing termination-condition input: treat as always true.
        if(args.at(1)->name() == "undefined")
        {
            shape cond_s{shape::bool_type};
            args[1] = info.add_literal(literal(cond_s, {true}));
        }

        // Parse the loop body subgraph into a dedicated module.
        const auto& body_graph = info.attributes.at("body").g();
        module_ref body_mod    = parser.prog.create_module(info.name + "_loop");
        parser.parse_graph(body_mod, body_graph);

        auto loop_ins = info.add_instruction(
            make_op("loop", {{"max_iterations", iter_limit}}), args, {body_mod});

        // The loop yields a tuple; expose each element via get_tuple_elem.
        auto tuple_shape = loop_ins->get_shape();
        assert(tuple_shape.type() == shape::tuple_type);
        const auto& sub_shapes = tuple_shape.sub_shapes();

        std::vector<instruction_ref> outputs;
        for(std::size_t idx = 0; idx < sub_shapes.size(); ++idx)
        {
            outputs.push_back(
                info.add_instruction(make_op("get_tuple_elem", {{"index", idx}}), loop_ins));
        }
        return outputs;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
......@@ -57,18 +58,16 @@ struct parse_matmul : op_parser<parse_matmul>
if(l0_lens != l0_broadcasted_lens)
{
bl0 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0);
make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
}
if(l1_lens != l1_broadcasted_lens)
{
bl1 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1);
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
}
}
auto dot_res =
info.add_instruction(make_op(opd.op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
{
dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

#include <chrono>
#include <cstdint>
#include <random>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_multinomial : op_parser<parse_multinomial>
{
    std::vector<op_desc> operators() const { return {{"Multinomial"}}; }

    // Parse the ONNX Multinomial operator. Builds a cumulative density
    // function from the input log-probabilities and samples against a
    // pre-computed uniform random distribution literal.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // ONNX dtype 6 == int32, the spec default for Multinomial output.
        int dtype = 6;
        if(contains(info.attributes, "dtype"))
            dtype = info.attributes.at("dtype").i();
        shape::type_t output_type = get_type(dtype);

        size_t sample_size = 1;
        if(contains(info.attributes, "sample_size"))
            sample_size = info.attributes.at("sample_size").i();

        // Derive the RNG seed. The clock tick count can exceed what a float
        // (and uint32) can represent, so keep it integral: converting an
        // out-of-range float to an unsigned type is undefined behavior,
        // whereas unsigned integer narrowing is well-defined (modular).
        auto seed_value = static_cast<std::uint64_t>(
            std::chrono::high_resolution_clock::now().time_since_epoch().count());
        if(contains(info.attributes, "seed"))
        {
            // The ONNX attribute is a float; go through int64 so ordinary
            // (in-range) seed values map to the same engine state as before.
            seed_value = static_cast<std::uint64_t>(
                static_cast<std::int64_t>(info.attributes.at("seed").f()));
        }

        // Subtract the per-batch maximum log-probability, making the per-batch max 0
        auto maxes =
            info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]);
        auto mb_maxes = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", args[0]->get_shape().lens()}}),
            maxes);
        auto cdf = info.add_instruction(migraphx::make_op("sub"), args[0], mb_maxes);
        // Take the element-wise exponent to get probabilities in the range (0, 1]
        cdf = info.add_instruction(migraphx::make_op("exp"), cdf);
        // Compute the cumulative density function
        cdf = info.add_instruction(
            migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);

        // Pre-compute the random distribution as a literal so the sampling is
        // deterministic for a given seed.
        std::mt19937 gen(static_cast<std::mt19937::result_type>(seed_value));
        std::uniform_real_distribution<> dis(0.0, 1.0);
        size_t batch_size = args[0]->get_shape().lens().front();
        migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}};

        std::vector<float> random_dist(batch_size * sample_size);
        std::generate(random_dist.begin(), random_dist.end(), [&]() { return dis(gen); });
        auto dist_lit = info.add_literal(migraphx::literal{dist_shape, random_dist});

        return info.add_instruction(
            migraphx::make_op("multinomial", {{"dtype", output_type}}), cdf, dist_lit);
    }
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -9,7 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
template <class T>
std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
static std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
{
std::vector<std::size_t> indices;
for(std::size_t i = 0; i < data.size(); ++i)
......@@ -31,30 +31,35 @@ struct parse_nonzero : op_parser<parse_nonzero>
std::vector<instruction_ref> args) const
{
migraphx::argument data_arg = args.back()->eval();
check_arg_empty(data_arg, "PARSE_NONZERO: cannot support non-constant input!");
std::vector<std::size_t> indices;
data_arg.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
std::vector<val_type> vec_data;
vec_data.assign(val.begin(), val.end());
indices = nonzero_indices(vec_data);
});
if(data_arg.empty())
{
return info.add_instruction(make_op("nonzero"), args);
}
else
{
std::vector<std::size_t> indices;
data_arg.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
std::vector<val_type> vec_data;
vec_data.assign(val.begin(), val.end());
indices = nonzero_indices(vec_data);
});
shape in_s = args[0]->get_shape();
shape out_s{shape::int64_type, {in_s.lens().size(), indices.size()}};
shape in_s = args[0]->get_shape();
shape out_s{shape::int64_type, {in_s.lens().size(), indices.size()}};
std::vector<int64_t> out_data(out_s.elements());
for(std::size_t i = 0; i < indices.size(); ++i)
{
auto idx = in_s.multi(indices[i]);
for(std::size_t j = 0; j < in_s.lens().size(); ++j)
std::vector<int64_t> out_data(out_s.elements());
for(std::size_t i = 0; i < indices.size(); ++i)
{
out_data[out_s.index({j, i})] = idx[j];
auto idx = in_s.multi(indices[i]);
for(std::size_t j = 0; j < in_s.lens().size(); ++j)
{
out_data[out_s.index({j, i})] = idx[j];
}
}
}
return info.add_literal(literal(out_s, out_data));
return info.add_literal(literal(out_s, out_data));
}
}
};
......
......@@ -45,8 +45,9 @@ struct parse_onehot : op_parser<parse_onehot>
std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = info.add_instruction(make_op("transpose", {{"dims", perm}}), gather_out);
auto lens = tr_out->get_shape().lens();
auto tr_out =
info.add_instruction(make_op("transpose", {{"permutation", perm}}), gather_out);
auto lens = tr_out->get_shape().lens();
auto off_val = info.add_instruction(
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
......@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot>
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
auto diff = info.add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), off_val);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), off_val);
auto unsq_diff_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), diff);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), diff);
auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return info.add_instruction(make_op("add"), l_mul, unsq_off_val);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment