"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "e1e37208c5b19939acb3bedc07b3a3abc8daf355"
Commit 9bff4331 authored by Paul

Merge

parents 214b313f 94a7f6ee
@@ -113,7 +113,8 @@ struct onnx_parser
     void parse_from(std::istream& is, std::string name = "");
     void parse_from(const void* data, std::size_t size);
-    void parse_graph(module* mod, const onnx::GraphProto& graph);
+    std::vector<instruction_ref>
+    parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
     literal parse_value(const onnx::AttributeProto& attr) const;
     literal parse_tensor(const onnx::TensorProto& t) const;
     shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
...
@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_r
 {
     if(args.size() == 3)
     {
-        auto bias_bcast = mod->add_instruction(
-            make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
-            args[2]);
+        instruction_ref bias_bcast;
+        // if curr_ins has a dynamic output shape use 2 input broadcast
+        if(curr_ins->get_shape().dynamic())
+        {
+            bias_bcast =
+                mod->add_instruction(make_op("broadcast", {{"axis", axis}}), args[2], curr_ins);
+        }
+        else
+        {
+            bias_bcast = mod->add_instruction(
+                make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
+                args[2]);
+        }
         return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
     }
     return curr_ins;
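
For context, the hunk above picks between the two forms of MIGraphX's broadcast op. A minimal sketch of both forms through the public C++ API follows; the shapes and names are illustrative, not taken from this commit:

#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>

// Sketch: the static form bakes out_lens into the op at parse time; the
// two-input form (used above for dynamic shapes) takes its output dims
// from the second input when the program is evaluated.
void broadcast_forms()
{
    migraphx::program p;
    auto* mm  = p.get_main_module();
    auto x    = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}});
    auto bias = mm->add_parameter("bias", {migraphx::shape::float_type, {8}});
    // static: output lens known at parse time
    auto b1 = mm->add_instruction(
        migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 8, 4, 4}}}), bias);
    // dynamic-friendly: output dims deduced from x at evaluation time
    auto b2 = mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), bias, x);
    mm->add_instruction(migraphx::make_op("add"), x, b1);
    mm->add_instruction(migraphx::make_op("add"), x, b2);
}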
@@ -210,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
         if(model.has_graph())
         {
-            this->parse_graph(mm, model.graph());
+            (void)this->parse_graph(mm, model.graph());
         }
     }
     else
@@ -230,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
         if(model.has_graph())
         {
-            this->parse_graph(mm, model.graph());
+            (void)this->parse_graph(mm, model.graph());
         }
     }
     else
@@ -254,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
     return version;
 }
 
-void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
+std::vector<instruction_ref>
+onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
 {
     std::unordered_map<std::string, instruction_ref> mod_insts;
     for(auto&& f : graph.initializer())
@@ -362,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
                    std::back_inserter(output_ins),
                    [&](const auto& name) { return instructions[name]; });
-    // add the return instruction
-    mod->add_return(output_ins);
+    if(not inlining)
+    {
+        // add the return instruction
+        mod->add_return(output_ins);
 
-    // remove instructions added in this mod
-    erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
+        // Remove instructions added in module (this is turned off for subgraph inlining)
+        erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
+    }
+
+    return output_ins;
 }
 
 literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
@@ -393,18 +409,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
 literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
 {
     std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
-    if(not t.external_data().empty())
+    auto type = get_type(t.data_type());
+    shape tensor_shape(type, dims);
+    auto external_data = t.external_data();
+    if(not external_data.empty())
     {
-        const std::string& data_file = t.external_data().at(0).value();
-        auto raw_buffer              = read_buffer(path + "/" + data_file);
+        const std::string& data_file = external_data.at(0).value();
+        size_t num_data_fields       = external_data.size();
+        size_t offset                = 0;
+        size_t nbytes                = tensor_shape.bytes();
+        if(num_data_fields > 1) // if offset field is present
+        {
+            offset = std::stoul(t.external_data().at(1).value());
+        }
+        if(num_data_fields > 2) // if nbytes field is present
+        {
+            nbytes = std::stoul(t.external_data().at(2).value());
+        }
+        auto raw_buffer = read_buffer(path + "/" + data_file, offset, nbytes);
         std::string s(raw_buffer.begin(), raw_buffer.end());
-        auto type = get_type(t.data_type());
         return create_literal(type, dims, s.data());
     }
     if(t.has_raw_data())
     {
         const std::string& s = t.raw_data();
-        auto type            = get_type(t.data_type());
         return create_literal(type, dims, s.data());
     }
...
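
For reference, the ONNX spec stores externally-backed initializers as key/value entries in external_data, with keys "location", "offset" and "length"; the updated parser above reads them positionally as fields 0 through 2. A hedged sketch of what such a tensor looks like when built with the generated protobuf API (file name, sizes, and values are illustrative):

#include <onnx/onnx_pb.h> // generated from onnx.proto

// Sketch only: builds a TensorProto whose payload lives in an external file.
onnx::TensorProto make_external_tensor()
{
    onnx::TensorProto t;
    t.set_name("weight");
    t.set_data_type(onnx::TensorProto::FLOAT);
    t.add_dims(64);
    t.add_dims(128);
    t.set_data_location(onnx::TensorProto::EXTERNAL);
    // key/value entries; the parser above reads value() at indices 0..2
    auto* loc = t.add_external_data();
    loc->set_key("location");
    loc->set_value("weights.bin");
    auto* off = t.add_external_data();
    off->set_key("offset");
    off->set_value("4096"); // byte offset into weights.bin
    auto* len = t.add_external_data();
    len->set_key("length");
    len->set_value("32768"); // 64 * 128 * sizeof(float) bytes
    return t;
}

When offset and length are absent, the code above falls back to offset 0 and the full byte size implied by the tensor shape, which matches the spec's defaults.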
@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
                           onnx_parser::node_info info,
                           std::vector<instruction_ref> args) const
     {
-        float alpha = 1.0f;
-        float beta  = 1.0f;
-        bool transa = false;
-        bool transb = false;
+        auto a_arg = args[0];
+        auto b_arg = args[1];
+        if(a_arg->get_shape().ndim() != 2 or b_arg->get_shape().ndim() != 2)
+        {
+            MIGRAPHX_THROW("PARSE_GEMM: A and B should be rank 2, A is rank " +
+                           std::to_string(a_arg->get_shape().ndim()) + ", B is rank " +
+                           std::to_string(b_arg->get_shape().ndim()));
+        }
+
+        float alpha  = 1.0f;
+        float beta   = 1.0f;
+        bool trans_a = false;
+        bool trans_b = false;
         if(contains(info.attributes, "alpha"))
         {
             alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
@@ -53,65 +62,73 @@ struct parse_gemm : op_parser<parse_gemm>
         }
 
         if(contains(info.attributes, "transA"))
         {
-            transa = parser.parse_value(info.attributes.at("transA")).at<bool>();
+            trans_a = parser.parse_value(info.attributes.at("transA")).at<bool>();
         }
 
         if(contains(info.attributes, "transB"))
         {
-            transb = parser.parse_value(info.attributes.at("transB")).at<bool>();
+            trans_b = parser.parse_value(info.attributes.at("transB")).at<bool>();
        }
 
-        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
-        std::iota(perm.begin(), perm.end(), int64_t{0});
-        // swap the last two elements
-        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
-
-        auto l1       = args[0];
-        auto dot_type = l1->get_shape().type();
+        std::vector<int64_t> perm = {1, 0};
+        auto dot_type             = a_arg->get_shape().type();
         if(alpha != 1.0f)
         {
            auto alpha_literal = info.add_literal(alpha);
-            l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
-            if(l1->get_shape().type() != dot_type)
+            a_arg              = info.add_broadcastable_binary_op("mul", alpha_literal, a_arg);
+            if(a_arg->get_shape().type() != dot_type)
            {
-                l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1);
+                a_arg =
+                    info.add_instruction(make_op("convert", {{"target_type", dot_type}}), a_arg);
            }
        }
 
-        l1 =
-            (transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
-        auto l2 = (transb)
-                      ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
-                      : args[1];
+        a_arg = (trans_a)
+                    ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), a_arg)
+                    : a_arg;
+        b_arg = (trans_b)
+                    ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
+                    : args[1];
 
-        auto ret = info.add_instruction(make_op("dot"), l1, l2);
+        auto dot_ins = info.add_instruction(make_op("dot"), a_arg, b_arg);
         if(args.size() == 3)
         {
-            if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
+            if(not float_equal(beta, 0.0f))
             {
-                auto out_lens   = l1->get_shape().lens();
-                out_lens.back() = l2->get_shape().lens().back();
-                auto l3         = args[2];
-                auto l3_lens    = l3->get_shape().lens();
-                if(not std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
+                auto c_arg = args[2];
+                if(dot_ins->get_shape().dynamic())
                 {
-                    l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
-                                              args[2]);
+                    c_arg = info.add_instruction(make_op("multibroadcast"), args[2], dot_ins);
                 }
-                auto beta_literal = info.add_literal(beta);
-                auto beta_l3      = info.add_broadcastable_binary_op("mul", l3, beta_literal);
-                if(beta_l3->get_shape().type() != dot_type)
+                else
                 {
-                    beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
-                                                   beta_l3);
+                    auto out_lens   = a_arg->get_shape().lens();
+                    out_lens.back() = b_arg->get_shape().lens().back();
+                    auto c_lens     = c_arg->get_shape().lens();
+                    if(not std::equal(
+                           out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
+                    {
+                        c_arg = info.add_instruction(
+                            make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
+                    }
                 }
-                return info.add_instruction(make_op("add"), ret, beta_l3);
+                if(not float_equal(beta, 1.0f))
+                {
+                    auto beta_literal = info.add_literal(beta);
+                    c_arg = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
+                    if(c_arg->get_shape().type() != dot_type)
+                    {
+                        c_arg = info.add_instruction(
+                            make_op("convert", {{"target_type", dot_type}}), c_arg);
+                    }
+                }
+                return info.add_instruction(make_op("add"), dot_ins, c_arg);
             }
         }
 
-        return ret;
+        return dot_ins;
     }
 };
...
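
ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op() is controlled by transA/transB. As a rough sketch of what the rewritten parser emits for a static case with transA = 1 and alpha = beta = 1 (so both scalings are skipped), using the MIGraphX C++ API with illustrative shapes:

#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>

// Sketch only: A is (K, M) and transposed, B is (K, N), C is (N) and
// multibroadcast to (M, N); beta == 1 so the beta multiply is elided.
void gemm_lowering_sketch()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto a   = mm->add_parameter("A", {migraphx::shape::float_type, {8, 4}});
    auto b   = mm->add_parameter("B", {migraphx::shape::float_type, {8, 6}});
    auto c   = mm->add_parameter("C", {migraphx::shape::float_type, {6}});
    auto at  = mm->add_instruction(
        migraphx::make_op("transpose", {{"permutation", {1, 0}}}), a); // (4, 8)
    auto dot = mm->add_instruction(migraphx::make_op("dot"), at, b);   // (4, 6)
    auto cb  = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {4, 6}}}), c);
    mm->add_instruction(migraphx::make_op("add"), dot, cb);
}

Note the new float_equal(beta, 1.0f) check in the diff: it avoids emitting a pointless multiply-by-one for the common default of beta = 1.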
@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
                            " condition input can have only one element!");
         }
 
+        // Fold the If instruction when the condition is constant and can be
+        // evaluated prior to inference
+        if(args.front()->can_eval())
+        {
+            auto cond_arg = args.front()->eval();
+            auto* mod     = info.mod;
+            // then branch
+            if(cond_arg.at<bool>())
+            {
+                return parser.parse_graph(mod, then_graph, true);
+            }
+            // else branch
+            else
+            {
+                return parser.parse_graph(mod, else_graph, true);
+            }
+        }
+
         std::string then_name = info.name + "_if";
         module_ref then_mdl   = parser.prog.create_module(then_name);
 
@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
         module_ref else_mdl   = parser.prog.create_module(else_name);
 
         // parse the then sub_graph
-        parser.parse_graph(then_mdl, then_graph);
+        (void)parser.parse_graph(then_mdl, then_graph);
 
         // parse the else sub_graph
-        parser.parse_graph(else_mdl, else_graph);
+        (void)parser.parse_graph(else_mdl, else_graph);
 
         auto then_out_shapes = then_mdl->get_output_shapes();
         auto else_out_shapes = else_mdl->get_output_shapes();
...
@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
         module_ref sub_mod = parser.prog.create_module(mod_name);
 
         // parse the sub_graph
-        parser.parse_graph(sub_mod, sub_graph);
+        (void)parser.parse_graph(sub_mod, sub_graph);
 
         auto ret = info.add_instruction(
             make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
...
@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
                           const onnx_parser::node_info& info,
                           std::vector<instruction_ref> args) const
     {
-        auto l0      = args[0];
-        auto l1      = args[1];
-        auto l0_lens = l0->get_shape().lens();
-        auto l1_lens = l1->get_shape().lens();
-
-        // args[0] is a vector, prepend 1 to the shape
-        bool is_a_prepended = false;
-        if(l0_lens.size() == 1)
+        auto a0 = args[0];
+        auto a1 = args[1];
+        auto s0 = a0->get_shape();
+        auto s1 = a1->get_shape();
+
+        instruction_ref dot_res;
+        bool is_a_prepended = false;
+        bool is_b_appended  = false;
+        if(s0.ndim() == 1)
         {
             is_a_prepended = true;
-            l0_lens.insert(l0_lens.begin(), 1);
-            l0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
+            a0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
         }
-
-        bool is_b_appended = false;
-        if(l1_lens.size() == 1)
+        if(s1.ndim() == 1)
         {
             is_b_appended = true;
-            l1_lens.push_back(1);
-            l1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
+            a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
         }
 
-        instruction_ref bl0 = l0;
-        instruction_ref bl1 = l1;
-        if(not std::equal(
-               l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
+        if(s0.dynamic() or s1.dynamic())
         {
-            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
-            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
-            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
-            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
-            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
-            l0_broadcasted_lens = output_lens;
-            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
-            l1_broadcasted_lens = output_lens;
-            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
-            if(l0_lens != l0_broadcasted_lens)
+            if(opd.op_name == "quant_dot")
             {
-                bl0 = info.add_instruction(
-                    make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
+                MIGRAPHX_THROW("PARSE_MATMUL: dynamic MatMulInteger not supported");
             }
-            if(l1_lens != l1_broadcasted_lens)
+            auto s0_dds = a0->get_shape().to_dynamic().dyn_dims();
+            auto s1_dds = a1->get_shape().to_dynamic().dyn_dims();
+
+            // TODO: handling this case requires a new multibroadcast mode
+            if(not std::equal(
+                   s0_dds.rbegin() + 2, s0_dds.rend(), s1_dds.rbegin() + 2, s1_dds.rend()))
             {
-                bl1 = info.add_instruction(
-                    make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
+                MIGRAPHX_THROW("PARSE_MATMUL: dynamic shape broadcasting not supported");
             }
+
+            dot_res = info.add_instruction(make_op(opd.op_name), a0, a1);
         }
+        else
+        {
+            auto s0_lens        = a0->get_shape().lens();
+            auto s1_lens        = a1->get_shape().lens();
+            instruction_ref ba0 = a0;
+            instruction_ref ba1 = a1;
+            // try broadcasting if dimensions other than last two do not match
+            if(not std::equal(
+                   s0_lens.rbegin() + 2, s0_lens.rend(), s1_lens.rbegin() + 2, s1_lens.rend()))
+            {
+                auto l0_it = s0_lens.begin() + s0_lens.size() - 2;
+                std::vector<std::size_t> l0_broadcasted_lens(s0_lens.begin(), l0_it);
+                auto l1_it = s1_lens.begin() + s1_lens.size() - 2;
+                std::vector<std::size_t> l1_broadcasted_lens(s1_lens.begin(), l1_it);
+                auto output_lens =
+                    compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
+                l0_broadcasted_lens = output_lens;
+                l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, s0_lens.end());
+                l1_broadcasted_lens = output_lens;
+                l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, s1_lens.end());
+                if(s0_lens != l0_broadcasted_lens)
+                {
+                    ba0 = info.add_instruction(
+                        make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), a0);
+                }
+                if(s1_lens != l1_broadcasted_lens)
+                {
+                    ba1 = info.add_instruction(
+                        make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1);
+                }
+            }
+            dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
+        }
 
-        instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
-        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
+        // squeeze the appended or prepended dimensions
+        int64_t num_axis = dot_res->get_shape().ndim();
         if(is_a_prepended)
         {
             dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
...
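
The static branch above aligns batch dimensions from the right, NumPy-style, before the dot. A standalone sketch of that broadcasting rule (my own re-implementation for illustration, not the MIGraphX compute_broadcasted_lens helper itself):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

// Align dims from the right: each pair must match or one side must be 1.
std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> a, std::vector<std::size_t> b)
{
    if(a.size() < b.size())
        std::swap(a, b);
    auto offset = a.size() - b.size();
    for(std::size_t i = 0; i < b.size(); ++i)
    {
        auto& x = a[offset + i];
        if(x == b[i] or b[i] == 1)
            continue;
        if(x == 1)
            x = b[i];
        else
            throw std::runtime_error("incompatible dims");
    }
    return a;
}

int main()
{
    // batch dims of A: (2, 1), of B: (3) -> broadcast to (2, 3), so
    // A (2, 1, 4, 5) x B (3, 5, 6) yields a (2, 3, 4, 6) dot result
    for(auto d : broadcast_lens({2, 1}, {3}))
        std::cout << d << ' '; // prints: 2 3
}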
@@ -147,7 +147,13 @@ struct parse_pad : op_parser<parse_pad>
         {
             auto mode = info.attributes.at("mode").s();
             if(mode == "reflect")
+            {
+                if(args.front()->get_shape().dynamic())
+                {
+                    MIGRAPHX_THROW("PARSE_PAD: reflect padding with dynamic shape not supported");
+                }
                 return reflect_pad(info, pads, args.front());
+            }
             if(mode != "constant")
             {
                 MIGRAPHX_THROW(
...
@@ -68,8 +68,7 @@ instruction_ref parse_reduce_oper(const std::string& op_name,
         }
         else
         {
-            std::size_t n_dim = args.front()->get_shape().lens().size();
-            axes.resize(n_dim);
+            axes.resize(args.front()->get_shape().ndim());
             std::iota(axes.begin(), axes.end(), 0);
         }
     }
...
@@ -49,7 +49,7 @@ struct parse_reshape : op_parser<parse_reshape>
         if(args.size() == 2)
         {
             auto s = args[1]->eval();
-            check_arg_empty(s, "Reshape: dynamic shape is not supported");
+            check_arg_empty(s, "Reshape: non-constant shape input is not supported");
             s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
         }
...
@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
         std::vector<int64_t> steps;
 
         // slice can have up to 5 inputs, we first check the 5th one
-        // to decide whether MIGRAPHX can handle this slice
+        // to decide whether MIGRAPHX can handle this slice.
         if(args.size() == 5)
         {
             migraphx::argument step_arg = args.back()->eval();
@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
             s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
         }
 
+        // If axes arg is not given, the default is all of them.
         if(op.axes.empty())
         {
-            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
+            std::vector<int64_t> axes(args[0]->get_shape().ndim());
             std::iota(axes.begin(), axes.end(), int64_t{0});
             op.axes = axes;
         }
 
@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
         assert(op.axes.size() == op.starts.size());
         assert(op.axes.size() == op.ends.size());
 
+        // If any axes have negative step, prepare to add a "reverse" op
         for(auto i : range(steps.size()))
         {
             if(steps[i] >= 0)
@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
         auto ins = info.add_instruction(op, args[0]);
         if(not raxes.empty())
+        {
             ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
+        }
 
+        // If any steps are other than default 1, add a "steps" op
         if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
         {
             std::vector<int64_t> nsteps;
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_trilu : op_parser<parse_trilu>
{
std::vector<op_desc> operators() const { return {{"Trilu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto input_shape = args[0]->get_shape();
assert(input_shape.ndim() >= 2);
auto input_lens = input_shape.lens();
size_t num_rows = *(input_lens.rbegin() + 1);
size_t num_cols = input_lens.back();
int k = 0;
bool upper = true;
if(args.size() > 1)
{
auto arg_k = args[1]->eval();
check_arg_empty(arg_k, "PARSE_TRILU: dynamic k not supported");
k = arg_k.at<int>();
}
if(k < 0)
MIGRAPHX_THROW("PARSE_TRILU: negative k values not supported");
if(contains(info.attributes, "upper"))
{
upper = static_cast<bool>(info.attributes.at("upper").i());
}
shape::type_t output_type = args[0]->get_shape().type();
// when creating the mask, if upper == 1,
// the lower triangle (below the k-th diagonal) is set to 0
std::vector<bool> mask_mat(num_rows * num_cols, upper);
for(size_t i = 0; i < num_rows; i++)
{
for(size_t j = 0; j < std::min(k, static_cast<int>(num_cols)); j++)
{
mask_mat[i * num_cols + j] = not upper;
}
k++;
}
auto mask = info.add_literal(
migraphx::literal{migraphx::shape{output_type, {num_rows, num_cols}}, mask_mat});
return info.add_broadcastable_binary_op("mul", mask, args[0]);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
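
To make the mask construction above concrete, here is a small standalone program that reproduces it for a 3x3 input with k = 0 and upper = true (output shown in the trailing comment):

#include <algorithm>
#include <iostream>
#include <vector>

int main()
{
    const int rows   = 3;
    const int cols   = 3;
    const bool upper = true;
    int k = 0; // diagonal offset; the parser above rejects k < 0
    std::vector<bool> mask(rows * cols, upper);
    for(int i = 0; i < rows; i++)
    {
        // zero out columns left of the (shifting) diagonal
        for(int j = 0; j < std::min(k, cols); j++)
            mask[i * cols + j] = not upper;
        k++;
    }
    for(int i = 0; i < rows; i++)
    {
        for(int j = 0; j < cols; j++)
            std::cout << mask[i * cols + j] << ' ';
        std::cout << '\n';
    }
    // the kept upper triangle is 1, the zeroed lower triangle is 0:
    // 1 1 1
    // 0 1 1
    // 0 0 1
}

Multiplying this mask elementwise against the input (the add_broadcastable_binary_op("mul", ...) above) yields the Trilu result.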
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where>
                           const onnx_parser::node_info& info,
                           std::vector<instruction_ref> args) const
     {
-        auto lens =
-            compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
-        lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
-
-        if(args[0]->get_shape().lens() != lens)
+        // TODO: broadcasting for dynamic shapes is only implemented
+        // for binary ops at time of writing, not ternary ops.
+        // When it becomes available, add multibroadcasting steps in the dynamic shape case.
+        // For now for dynamic shapes, just insert the Where op. All shapes must be the
+        // same for it to succeed.
+        if(std::all_of(args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
         {
-            args[0] =
-                info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
+            return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
         }
-
-        if(args[1]->get_shape().lens() != lens)
+        else if(std::none_of(
+                    args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
         {
-            args[1] =
-                info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
-        }
+            // If shapes are static and any are broadcasted, insert multibroadcast ops
+            auto lens =
+                compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
+            lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
+            if(args[0]->get_shape().lens() != lens)
+            {
+                args[0] =
+                    info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
+            }
 
-        if(args[2]->get_shape().lens() != lens)
-        {
-            args[2] =
-                info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
-        }
+            if(args[1]->get_shape().lens() != lens)
+            {
+                args[1] =
+                    info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
+            }
 
+            if(args[2]->get_shape().lens() != lens)
+            {
+                args[2] =
+                    info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
+            }
             return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
+        }
+        else
+            MIGRAPHX_THROW("PARSE_WHERE: doesn't support mixed static and dynamic shape inputs");
     }
 };
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring_impl::run()
{
// calculate implicit dependencies
mod_implicit_deps = p_mod->calc_implicit_deps();
MIGRAPHX_DEBUG(dump("---Before memory coloring---"));
MIGRAPHX_DEBUG(dump_module());
build();
if(num_of_lives != 0)
{
MIGRAPHX_DEBUG(dump_intervals());
// Coloring
while(not alloc_queue.empty())
{
interval_ptr interval = alloc_queue.top();
allocate(interval);
alloc_queue.pop();
}
// rewrite happens after all modules are processed
rewrite();
if(enable_verify)
verify();
}
}
bool memory_coloring_impl::allocate(interval_ptr interval)
{
shape s = interval->result;
std::size_t size = s.bytes();
if(size == 0)
return false;
std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
live_range& segment = interval->segment;
int vn = segment.vn;
std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
std::unordered_map<long long, live_range*> offset2_live;
offset2_live.clear();
if(conflict_table.find(vn) != conflict_table.end())
{
const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
long long offset = range->offset;
if(offset != invalid_offset)
{
conflict_queue.push(range);
if(offset2_live.find(offset) == offset2_live.end())
{
offset2_live[offset] = range;
}
else
{
live_range* prev = offset2_live[offset];
assert(prev->offset == offset);
if(prev->size < range->size)
offset2_live[offset] = range;
}
}
}
}
std::size_t offset = 0;
while(not conflict_queue.empty())
{
live_range* range = conflict_queue.top();
std::size_t iter_offset = range->offset;
if(offset > iter_offset)
{
offset = std::max(offset, iter_offset + range->size);
}
else if(offset2_live[iter_offset] == range)
{
if((iter_offset > offset) && (iter_offset - offset) >= size)
{
break;
}
offset = iter_offset + range->size;
}
// alignment
if((offset % element_size) != 0)
offset += (element_size - (offset % element_size));
conflict_queue.pop();
}
// when int8 type is used, the offset could be any number
// if not 4-byte aligned, miopen int8 convolution can crash
offset = (offset + 3) / 4 * 4;
segment.offset = offset;
MIGRAPHX_DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size);
return true;
}
void memory_coloring_impl::build()
{
std::size_t num_of_instrs = p_mod->size();
if(num_of_instrs == 0)
return;
auto cur_points = num_of_instrs * 2;
instruction_ref iter = p_mod->end();
instruction_ref begin = p_mod->begin();
std::vector<instruction_ref> dead_instrs;
std::set<int> live_set;
// Build live intervals.
live_intervals.resize(num_of_instrs);
do
{
iter = std::prev(iter);
const instruction* p_iter = &(*iter);
interval_ptr def_interval = nullptr;
bool is_dead = false;
if(instr2_live.find(p_iter) != instr2_live.end())
{
def_interval = instr2_live[p_iter];
bool is_lit = is_literal(iter);
if(is_allocate(iter) or is_lit)
{
live_range& range = def_interval->segment;
def_interval->result = iter->get_shape();
def_interval->is_literal = is_lit;
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
if(not is_lit or unify_literals)
alloc_queue.push(def_interval);
live_set.erase(range.vn);
}
}
else if(not is_param(iter) && not is_outline(iter) && not is_check_context(iter))
{
is_dead = true;
}
auto inputs = iter->inputs();
if(contains(mod_implicit_deps, iter))
{
const auto& impl_deps = mod_implicit_deps.at(iter);
inputs.insert(inputs.end(), impl_deps.begin(), impl_deps.end());
}
for(auto&& arg : inputs)
{
if(not p_mod->has_instruction(arg))
continue;
if(is_param(arg) or is_outline(arg))
{
if(is_output_param(arg))
is_dead = false;
if(def_interval != nullptr)
{
def_interval->is_live_on_entry = true;
}
continue;
}
const instruction* p_arg = &(*instruction::get_output_alias(arg));
if(instr2_live.find(p_arg) == instr2_live.end())
{
// First time see a use, create a live interval.
int id = num_of_lives++;
interval_ptr interval = &(live_intervals[id]);
interval->id = id;
interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->add_use(cur_points);
instr2_live[p_arg] = interval;
add_conflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_ranges[max_value_number] = &(interval->segment);
earliest_end_point = cur_points;
if(latest_end_point == -1)
latest_end_point = cur_points;
}
else
{
interval_ptr interval = instr2_live[p_arg];
interval->add_use(cur_points);
assert(live_set.find(interval->id) != live_set.end());
}
}
if(is_dead)
dead_instrs.push_back(iter);
cur_points -= 2;
} while(iter != begin);
}
void memory_coloring_impl::rewrite()
{
std::vector<std::size_t> dims;
dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float));
shape s = {shape::float_type, dims};
instruction_ref scratch_param = p_mod->add_parameter("scratch", s);
for(auto ins : iterator_for(*p_mod))
{
const instruction* p_iter = &(*ins);
if(instr2_live.find(p_iter) != instr2_live.end())
{
interval_ptr interval = instr2_live[p_iter];
if(interval->get_begin() == invalid_offset)
continue;
if(not unify_literals && interval->is_literal)
continue;
std::size_t offset = 0;
if(interval->get_offset() != invalid_offset)
{
offset = interval->get_offset();
}
else
{
assert(interval->result.bytes() == 0);
}
if(is_allocate(ins))
{
p_mod->replace_instruction(
ins,
make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", offset}}),
scratch_param);
}
}
}
MIGRAPHX_DEBUG(dump("---After rewrite---"));
MIGRAPHX_DEBUG(dump_module());
}
void memory_coloring_impl::verify()
{
if(num_of_lives > 0)
{
for(int i = 0; i < num_of_lives; ++i)
{
const live_interval& interval = live_intervals[i];
const live_range& segment = interval.segment;
if(segment.begin == invalid_offset)
{
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue;
}
if(segment.offset == invalid_offset)
{
continue;
}
int vn = segment.vn;
if(conflict_table.find(vn) != conflict_table.end())
{
const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
if(range->offset == invalid_offset)
continue;
if(not is_disjoin(*range, segment))
MIGRAPHX_THROW("range and segment is not disjoined");
}
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
void memory_coloring_impl::dump_module() { std::cout << *p_mod << std::endl; }
void memory_coloring_impl::dump_intervals()
{
if(num_of_lives > 0)
{
std::cout << "---live intervals ---" << std::endl;
for(int i = 0; i < num_of_lives; ++i)
{
live_interval& interval = live_intervals[i];
interval.dump();
}
std::cout << "---conflict table---" << std::endl;
for(int i = 0; i <= max_value_number; ++i)
{
std::cout << " segment:" << i;
std::cout << " =>";
const std::set<int>& table = conflict_table[i];
for(const auto& iter : table)
{
std::cout << (iter) << ",";
}
}
std::cout << std::endl;
}
}
// map liveness tracking point to instruction enum.
static int get_ins_enum(int x)
{
if(x > 0)
{
return (x / 2) - 1;
}
else
return invalid_offset;
}
void live_range::dump()
{
std::cout << " segment:" << vn;
std::cout << " [" << get_ins_enum(begin) << ", " << get_ins_enum(end) << "]";
if(offset != invalid_offset)
{
std::cout << " mem:";
std::cout << " [" << offset << "," << offset + size - 1 << "]";
}
std::cout << std::endl;
}
void live_interval::dump()
{
std::cout << "id:" << id;
segment.dump();
std::cout << " uses:";
for(const auto& iter : use_points)
{
std::cout << " " << get_ins_enum(iter) << ",";
}
std::cout << " def:";
std::cout << " " << get_ins_enum(def_point);
if(is_literal)
std::cout << " literal";
std::cout << " " << result;
std::cout << std::endl;
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
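
The core of allocate() above is a first-fit scan over the already-placed live ranges that conflict with the interval being placed. A simplified, standalone sketch of that placement rule (ignoring alignment, value numbering, and the duplicate-offset bookkeeping):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

struct placed_range
{
    std::size_t offset;
    std::size_t size;
};

// First-fit: scan conflicts in offset order and return the lowest offset
// where `size` bytes fit before the next conflicting range begins.
std::size_t first_fit(std::vector<placed_range> conflicts, std::size_t size)
{
    std::sort(conflicts.begin(), conflicts.end(), [](const auto& a, const auto& b) {
        return a.offset < b.offset;
    });
    std::size_t offset = 0;
    for(const auto& r : conflicts)
    {
        if(r.offset >= offset + size)
            break; // found a gap large enough
        offset = std::max(offset, r.offset + r.size); // skip past this range
    }
    return offset;
}

int main()
{
    // ranges at [0, 16) and [32, 48): an 8-byte request lands in the gap at 16
    std::cout << first_fit({{0, 16}, {32, 16}}, 8) << '\n'; // prints 16
}

In the real pass the chosen offsets then drive rewrite(), which replaces each allocation with a "load" at that offset into the single "scratch" parameter.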
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/optimize_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void optimize_module::apply(module_pass_manager& mpm) const
{
for(int i = 0; i < 2; i++)
{
mpm.run_pass(simplify_reshapes{});
mpm.run_pass(simplify_algebra{});
mpm.run_pass(eliminate_common_subexpression{});
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(propagate_constant{});
mpm.run_pass(dead_code_elimination{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -329,15 +329,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         .def("is_compiled", &migraphx::program::is_compiled)
         .def(
             "compile",
-            [](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {
+            [](migraphx::program& p,
+               const migraphx::target& t,
+               bool offload_copy,
+               bool fast_math,
+               bool exhaustive_tune) {
                 migraphx::compile_options options;
                 options.offload_copy = offload_copy;
                 options.fast_math    = fast_math;
+                options.exhaustive_tune = exhaustive_tune;
                 p.compile(t, options);
             },
             py::arg("t"),
             py::arg("offload_copy") = true,
-            py::arg("fast_math") = true)
+            py::arg("fast_math") = true,
+            py::arg("exhaustive_tune") = false)
         .def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
         .def(
             "create_module",
...
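
On the C++ side the new Python keyword simply maps onto the existing compile_options struct, as the lambda above shows. A hedged sketch of the equivalent native call (the compile_options header name is assumed from the MIGraphX tree; obtaining a target is omitted):

#include <migraphx/compile_options.hpp>
#include <migraphx/program.hpp>
#include <migraphx/target.hpp>

void compile_with_exhaustive_tune(migraphx::program& p, const migraphx::target& t)
{
    migraphx::compile_options options;
    options.offload_copy    = true; // same defaults the Python binding uses
    options.fast_math       = true;
    options.exhaustive_tune = true; // Python: p.compile(t, exhaustive_tune=True)
    p.compile(t, options);
}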