Commit 7dc6e3ae authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents f94d77fc a275f590
...@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip> ...@@ -47,13 +47,13 @@ struct parse_clip : op_parser<parse_clip>
if(min_used) if(min_used)
{ {
min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg); min_arg);
} }
if(max_used) if(max_used)
{ {
max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg); max_arg);
} }
......
...@@ -29,11 +29,11 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear> ...@@ -29,11 +29,11 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction( x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]); make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
} }
else else
{ {
x_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), x_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]); args[1]);
} }
...@@ -44,13 +44,13 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear> ...@@ -44,13 +44,13 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_zero_point = info.add_instruction( x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
x_zero_point); x_zero_point);
} }
else else
{ {
x_zero_point = info.add_instruction( x_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), x_zero_point); make_op("multibroadcast", {{"out_lens", input_lens}}), x_zero_point);
} }
return info.add_instruction( return info.add_instruction(
......
...@@ -24,8 +24,7 @@ struct parse_expand : op_parser<parse_expand> ...@@ -24,8 +24,7 @@ struct parse_expand : op_parser<parse_expand>
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims); auto out_lens = compute_broadcasted_lens(in_lens, dims);
return info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), return info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}), args[0]);
args[0]);
} }
}; };
......
...@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements> ...@@ -63,8 +63,8 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end())); info.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end())); auto l_dim_idx = info.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}}); auto l_stride = info.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = info.add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}), l_stride =
l_stride); info.add_instruction(make_op("multibroadcast", {{"out_lens", ind_s.lens()}}), l_stride);
auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx); auto dim_diff = info.add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride); auto delta = info.add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta); auto ind = info.add_instruction(make_op("add"), l_shape_idx, delta);
......
...@@ -55,13 +55,17 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -55,13 +55,17 @@ struct parse_gemm : op_parser<parse_gemm>
} }
} }
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1; l1 =
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1]) (transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
: args[1]; auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2);
if(args.size() == 3) if(args.size() == 3)
{ {
if(beta != 0.0f && args[2]->get_shape().elements() > 0) if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
{ {
auto out_lens = l1->get_shape().lens(); auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back(); out_lens.back() = l2->get_shape().lens().back();
...@@ -69,8 +73,8 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -69,8 +73,8 @@ struct parse_gemm : op_parser<parse_gemm>
auto l3_lens = l3->get_shape().lens(); auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end())) if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{ {
l3 = info.add_instruction( l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]); args[2]);
} }
auto beta_literal = info.add_literal(beta); auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal); auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
...@@ -80,12 +84,11 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -80,12 +84,11 @@ struct parse_gemm : op_parser<parse_gemm>
beta_l3); beta_l3);
} }
return info.add_instruction( return info.add_instruction(make_op("add"), ret, beta_l3);
make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, beta_l3);
} }
} }
return info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2); return ret;
} }
}; };
......
...@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar> ...@@ -40,7 +40,7 @@ struct parse_imagescalar : op_parser<parse_imagescalar>
auto img_scaled = auto img_scaled =
info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor); info.add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = info.add_instruction( auto bias_bcast = info.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", input_lens}}), bias_vals);
return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); return info.add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
} }
}; };
......
...@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -38,23 +38,23 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast = auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean); info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast); auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0); auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast); auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = info.add_literal(epsilon); auto epsilon_literal = info.add_literal(epsilon);
auto epsilon_bcast = info.add_instruction( auto epsilon_bcast =
make_op("multibroadcast", {{"output_lens", dims}}), epsilon_literal); info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal);
auto variance_bcast = auto variance_bcast =
info.add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance); info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance);
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast); auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = info.add_instruction(make_op("rsqrt"), l2); auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3); auto l4 = info.add_instruction(make_op("mul"), l1, l3);
auto scale_bcast = auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale); info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
; ;
auto bias_bcast = auto bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias); info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast); auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast); return info.add_instruction(make_op("add"), l5, bias_bcast);
} }
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parser for the ONNX "Loop" operator: a generic loop with a trip-count
// limit, a loop-carried condition, and loop-carried dependencies, whose
// body is a nested ONNX graph compiled into a MIGraphX submodule.
struct parse_loop : op_parser<parse_loop>
{
    std::vector<op_desc> operators() const { return {{"Loop"}}; }

    // Returns one instruction per loop output (loop-carried values followed
    // by scan outputs), extracted from the tuple produced by the "loop" op.
    // Note: takes `parser` by non-const reference because it creates and
    // parses a new submodule inside `parser.prog`.
    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        // default value of the max_iter_num
        int64_t max_iterations = parser.max_loop_iterations;

        // iteration input is empty
        if(args.at(0)->name() == "undefined")
        {
            // No trip count given: substitute a literal with the parser's cap
            // so the "loop" op always has a concrete first argument.
            shape iter_s{shape::int64_type};
            args[0] = info.add_literal(literal(iter_s, {max_iterations}));
        }
        else
        {
            // If the trip count is a compile-time constant, use it as the
            // static max_iterations; otherwise keep the parser's default cap.
            auto arg_iters = args.at(0)->eval();
            if(not arg_iters.empty())
            {
                max_iterations = arg_iters.at<int64_t>();
            }
        }

        // condition input is empty
        if(args.at(1)->name() == "undefined")
        {
            // Missing condition means "run until trip count": seed with true.
            shape cond_s{shape::bool_type};
            args[1] = info.add_literal(literal(cond_s, {true}));
        }

        // retrieve the subgraph
        const auto& sub_graph = info.attributes.at("body").g();

        // The submodule name is derived from the node name to keep it unique
        // within the program.
        std::string mod_name = info.name + "_loop";
        module_ref sub_mod   = parser.prog.create_module(mod_name);

        // parse the sub_graph
        parser.parse_graph(sub_mod, sub_graph);

        auto ret = info.add_instruction(
            make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});

        // The "loop" op returns a tuple; unpack each element into its own
        // instruction so downstream nodes can reference outputs individually.
        auto out_s = ret->get_shape();
        assert(out_s.type() == shape::tuple_type);

        const auto& vec_shapes = out_s.sub_shapes();
        std::vector<instruction_ref> out_inss;
        for(std::size_t i = 0; i < vec_shapes.size(); ++i)
        {
            auto r = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), ret);
            out_inss.push_back(r);
        }

        return out_inss;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -58,12 +58,12 @@ struct parse_matmul : op_parser<parse_matmul> ...@@ -58,12 +58,12 @@ struct parse_matmul : op_parser<parse_matmul>
if(l0_lens != l0_broadcasted_lens) if(l0_lens != l0_broadcasted_lens)
{ {
bl0 = info.add_instruction( bl0 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0); make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
} }
if(l1_lens != l1_broadcasted_lens) if(l1_lens != l1_broadcasted_lens)
{ {
bl1 = info.add_instruction( bl1 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1); make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
} }
} }
......
...@@ -45,8 +45,9 @@ struct parse_onehot : op_parser<parse_onehot> ...@@ -45,8 +45,9 @@ struct parse_onehot : op_parser<parse_onehot>
std::vector<int64_t> perm(n_rank - 1); std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0); std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1); perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = info.add_instruction(make_op("transpose", {{"dims", perm}}), gather_out); auto tr_out =
auto lens = tr_out->get_shape().lens(); info.add_instruction(make_op("transpose", {{"permutation", perm}}), gather_out);
auto lens = tr_out->get_shape().lens();
auto off_val = info.add_instruction( auto off_val = info.add_instruction(
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]); make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
...@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot> ...@@ -54,9 +55,9 @@ struct parse_onehot : op_parser<parse_onehot>
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]); make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
auto diff = info.add_instruction(make_op("sub"), on_val, off_val); auto diff = info.add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val = auto unsq_off_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), off_val); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), off_val);
auto unsq_diff_val = auto unsq_diff_val =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), diff); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), diff);
auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val); auto l_mul = info.add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return info.add_instruction(make_op("add"), l_mul, unsq_off_val); return info.add_instruction(make_op("add"), l_mul, unsq_off_val);
} }
......
...@@ -29,11 +29,11 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -29,11 +29,11 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_scale = info.add_instruction( y_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]); make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), args[1]);
} }
else else
{ {
y_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), y_scale = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
args[1]); args[1]);
} }
...@@ -44,13 +44,13 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -44,13 +44,13 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{ {
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
y_zero_point = info.add_instruction( y_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}),
y_zero_point); y_zero_point);
} }
else else
{ {
y_zero_point = info.add_instruction( y_zero_point = info.add_instruction(
make_op("multibroadcast", {{"output_lens", input_lens}}), y_zero_point); make_op("multibroadcast", {{"out_lens", input_lens}}), y_zero_point);
} }
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point); return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
......
...@@ -35,9 +35,9 @@ struct parse_selu : op_parser<parse_selu> ...@@ -35,9 +35,9 @@ struct parse_selu : op_parser<parse_selu>
if(lens != std::vector<std::size_t>{1}) if(lens != std::vector<std::size_t>{1})
{ {
l_alpha = l_alpha =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_alpha);
l_gamma = l_gamma =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), l_gamma);
} }
auto sign_x = info.add_instruction(make_op("sign"), args[0]); auto sign_x = info.add_instruction(make_op("sign"), args[0]);
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parser for the ONNX "ThresholdedRelu" operator:
//   y = x when x > alpha, otherwise 0.
struct parse_thresholdedrelu : op_parser<parse_thresholdedrelu>
{
    std::vector<op_desc> operators() const { return {{"ThresholdedRelu"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // Threshold defaults to 1.0 per the ONNX spec; the "alpha"
        // attribute overrides it when present.
        float alpha = 1.0;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
        }

        auto in_shape = args[0]->get_shape();

        // Scalar literals (in the input's type) for the fill value and the
        // threshold, then broadcast each to the full input shape so the
        // elementwise compare/select below has matching operands.
        auto zero_lit =
            info.add_literal(migraphx::literal{migraphx::shape{in_shape.type()}, {0}});
        auto alpha_lit =
            info.add_literal(migraphx::literal{migraphx::shape{in_shape.type()}, {alpha}});
        auto zero_bcast = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", in_shape.lens()}}), zero_lit);
        auto alpha_bcast = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", in_shape.lens()}}), alpha_lit);

        // Select x where x > alpha, and 0 everywhere else.
        auto above_thresh =
            info.add_instruction(migraphx::make_op("greater"), args[0], alpha_bcast);
        return info.add_instruction(
            migraphx::make_op("where"), above_thresh, args[0], zero_bcast);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parser for the ONNX "TopK" operator: returns the k largest (or smallest)
// elements along an axis, plus their indices.
struct parse_topk : op_parser<parse_topk>
{
    std::vector<op_desc> operators() const { return {{"TopK"}}; }

    // Returns two instructions: {values, indices}, unpacked from the tuple
    // produced by the "topk" op.
    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const onnx_parser& parser,
                                       onnx_parser::node_info info,
                                       std::vector<instruction_ref> args) const
    {
        // k: opset >= 10 passes it as a constant second input; opset 1 uses
        // the "k" attribute. The ONNX spec declares both as int64, so read
        // them as int64_t to avoid silently narrowing large values.
        int64_t k = 0;
        if(args.size() == 2)
        {
            auto arg_k = args.at(1)->eval();
            check_arg_empty(arg_k, "PARSE_TopK: k input must be constant");
            k = arg_k.at<int64_t>();
        }
        else if(contains(info.attributes, "k"))
        {
            k = info.attributes.at("k").i();
        }

        // Whether to return the largest (default) or smallest elements.
        bool largest = true;
        if(contains(info.attributes, "largest"))
        {
            largest = static_cast<bool>(info.attributes.at("largest").i());
        }

        // Axis to operate on; defaults to the last dimension. Also an int64
        // attribute in the ONNX spec.
        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
        }

        auto topk_ret = info.add_instruction(
            make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0));

        // Unpack the tuple result into separate value/index instructions.
        auto ret_val = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), topk_ret);
        auto ret_ind = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), topk_ret);

        return {ret_val, ret_ind};
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -21,7 +22,21 @@ struct parse_transpose : op_parser<parse_transpose> ...@@ -21,7 +22,21 @@ struct parse_transpose : op_parser<parse_transpose>
auto&& perm_vals = info.attributes["perm"].ints(); auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end()); perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
} }
return info.add_instruction(make_op("transpose", {{"dims", perm}}), args.front());
// if perm is empty, use the default value
auto n_dim = args.front()->get_shape().lens().size();
if(perm.empty())
{
perm.resize(n_dim);
std::iota(perm.rbegin(), perm.rend(), 0);
}
if(perm.size() != n_dim)
{
MIGRAPHX_THROW("PARSE_TRANSPOSE: perm and input have diffferent number of dims!");
}
return info.add_instruction(make_op("transpose", {{"permutation", perm}}), args.front());
} }
}; };
......
...@@ -17,45 +17,28 @@ struct parse_where : op_parser<parse_where> ...@@ -17,45 +17,28 @@ struct parse_where : op_parser<parse_where>
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto cond = auto lens =
info.add_instruction(make_op("convert", {{"target_type", shape::int32_type}}), args[0]); compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
auto lens = compute_broadcasted_lens(cond->get_shape().lens(), args[1]->get_shape().lens()); lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens()); if(args[0]->get_shape().lens() != lens)
if(cond->get_shape().lens() != lens)
{ {
cond = info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), cond); args[0] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
} }
if(args[1]->get_shape().lens() != lens) if(args[1]->get_shape().lens() != lens)
{ {
args[1] = args[1] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[1]); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
} }
if(args[2]->get_shape().lens() != lens) if(args[2]->get_shape().lens() != lens)
{ {
args[2] = args[2] =
info.add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), args[2]); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
} }
// compute index return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
auto elem_num = args[1]->get_shape().elements();
// concatenation of input data
auto concat_data = info.add_instruction(make_op("concat", {{"axis", 0}}), args[2], args[1]);
std::vector<int64_t> dims = {static_cast<int64_t>(2 * elem_num)};
auto rsp_data = info.add_instruction(make_op("reshape", {{"dims", dims}}), concat_data);
std::vector<int> ind(elem_num);
std::iota(ind.begin(), ind.end(), 0);
shape ind_s{shape::int32_type, lens};
auto l_ind = info.add_literal(literal(ind_s, ind));
std::vector<int> offset(elem_num, elem_num);
auto l_offset = info.add_literal(literal({shape::int32_type, lens}, offset));
auto ins_offset = info.add_instruction(make_op("mul"), l_offset, cond);
auto ins_ind = info.add_instruction(make_op("add"), ins_offset, l_ind);
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp_data, ins_ind);
} }
}; };
......
...@@ -40,7 +40,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval) ...@@ -40,7 +40,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
std::size_t size = s.bytes(); std::size_t size = s.bytes();
if(size == 0) if(size == 0)
return false; return false;
std::size_t element_size = size / s.elements(); std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
live_range& segment = interval->segment; live_range& segment = interval->segment;
int vn = segment.vn; int vn = segment.vn;
std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue; std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
......
...@@ -32,14 +32,6 @@ void validate_pass(module& mod, const pass& p, tracer trace) ...@@ -32,14 +32,6 @@ void validate_pass(module& mod, const pass& p, tracer trace)
trace(); trace();
#endif #endif
} }
void run_pass(module& mod, const pass& p, tracer trace)
{
trace("Module: ", mod.name(), ", Pass: ", p.name());
assert(mod.validate() == mod.end());
p.apply(mod);
trace(mod);
validate_pass(mod, p, trace);
}
void run_pass(program& prog, const pass& p, tracer trace) void run_pass(program& prog, const pass& p, tracer trace)
{ {
trace("Pass: ", p.name()); trace("Pass: ", p.name());
...@@ -47,11 +39,52 @@ void run_pass(program& prog, const pass& p, tracer trace) ...@@ -47,11 +39,52 @@ void run_pass(program& prog, const pass& p, tracer trace)
trace(prog); trace(prog);
} }
struct module_pm : module_pass_manager
{
module* mod;
program* prog;
tracer* t;
module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr)
: mod(pmod), prog(pprog), t(pt)
{
}
template <class... Ts>
void trace(Ts&&... xs) const
{
assert(t);
(*t)(xs...);
}
virtual module& get_module() override
{
assert(mod);
return *mod;
}
virtual module* create_module(const std::string& name) override
{
assert(prog);
return prog->create_module(name);
}
virtual void run_pass(const pass& p) override
{
assert(mod);
trace("Module: ", mod->name(), ", Pass: ", p.name());
assert(mod->validate() == mod->end());
p.apply(*this);
trace(*mod);
validate_pass(*mod, p, *t);
}
};
module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace) void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{ {
for(const auto& p : passes) for(const auto& p : passes)
{ {
run_pass(mod, p, trace); module_pm{&mod, nullptr, &trace}.run_pass(p);
} }
} }
...@@ -62,7 +95,7 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace) ...@@ -62,7 +95,7 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
auto mods = prog.get_modules(); auto mods = prog.get_modules();
for(const auto& mod : reverse(mods)) for(const auto& mod : reverse(mods))
{ {
run_pass(*mod, p, trace); module_pm{mod, &prog, &trace}.run_pass(p);
} }
run_pass(prog, p, trace); run_pass(prog, p, trace);
} }
......
...@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,6 +10,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void preallocate_param::apply(module& m) const void preallocate_param::apply(module& m) const
{ {
auto last = std::prev(m.end());
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != "@param") if(ins->name() != "@param")
...@@ -19,7 +20,9 @@ void preallocate_param::apply(module& m) const ...@@ -19,7 +20,9 @@ void preallocate_param::apply(module& m) const
std::string id = m.name() + ":" + param; std::string id = m.name() + ":" + param;
auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id)); auto r = m.insert_instruction(ins, model.preallocate(ins->get_shape(), id));
m.replace_instruction(ins, r); m.replace_instruction(ins, r);
m.move_instruction(ins, m.end());
} }
m.remove_instructions(std::next(last), m.end());
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -184,14 +184,16 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -184,14 +184,16 @@ std::vector<argument> generic_eval(const module* mod,
context& ctx, context& ctx,
std::unordered_map<std::string, argument> params, std::unordered_map<std::string, argument> params,
std::unordered_map<instruction_ref, argument> results, std::unordered_map<instruction_ref, argument> results,
F trace) F make_trace)
{ {
assert(mod->validate() == mod->end()); assert(mod->validate() == mod->end());
results.reserve(mod->size() * 2); results.reserve(mod->size() * 2);
std::vector<argument> values; std::vector<argument> values;
values.reserve(16); values.reserve(16);
auto trace = make_trace(mod);
for(auto ins : iterator_for(*mod)) for(auto ins : iterator_for(*mod))
{ {
assert(results.find(ins) == results.end());
const auto& name = ins->name(); const auto& name = ins->name();
if(name == "@literal") if(name == "@literal")
{ {
...@@ -240,7 +242,8 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -240,7 +242,8 @@ std::vector<argument> generic_eval(const module* mod,
const auto& mod_args = ins->module_inputs(); const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod, auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) { const std::unordered_map<std::string, argument>& inputs) {
return generic_eval(smod, ctx, inputs, results, trace); auto ssctx = ctx;
return generic_eval(smod, ssctx, inputs, results, make_trace);
}; };
results.emplace(ins, trace(ins, [&] { results.emplace(ins, trace(ins, [&] {
...@@ -249,6 +252,7 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -249,6 +252,7 @@ std::vector<argument> generic_eval(const module* mod,
})); }));
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
assert(results.at(ins).get_shape() == ins->get_shape());
} }
return {results.at(std::prev(mod->end()))}; return {results.at(std::prev(mod->end()))};
} }
...@@ -257,50 +261,67 @@ template <class F> ...@@ -257,50 +261,67 @@ template <class F>
std::vector<argument> generic_eval(const program& p, std::vector<argument> generic_eval(const program& p,
context& ctx, context& ctx,
std::unordered_map<std::string, argument> params, std::unordered_map<std::string, argument> params,
F trace) F make_trace)
{ {
const module* mm = p.get_main_module(); const module* mm = p.get_main_module();
return generic_eval(mm, ctx, params, {}, trace); return generic_eval(mm, ctx, params, {}, make_trace);
} }
std::vector<argument> program::eval(parameter_map params) const std::vector<argument> program::eval(parameter_map params) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->ctx;
#ifndef NDEBUG #ifndef NDEBUG
auto sctx = ctx; auto with_check_context = [&](auto f) {
auto check_context = [&](auto f) { return [=, &ctx](auto&&) {
assert(is_shared(ctx, sctx)); auto sctx = std::make_shared<context>(ctx);
auto x = f(); auto check_context = [=, &ctx](auto g) {
sctx = ctx; assert(is_shared(ctx, *sctx));
return x; auto x = g();
*sctx = ctx;
return x;
};
return [=](auto&&... xs) { return f(xs..., check_context); };
};
}; };
#else #else
auto check_context = [](auto f) { return f(); }; auto with_check_context = [](auto f) {
return [=](auto&&) {
return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
};
};
#endif #endif
auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{}); auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
if(trace_level > 0) if(trace_level > 0)
{ {
return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) { return generic_eval(*this,
ctx.finish(); ctx,
std::cout << "Run instruction: "; std::move(params),
this->debug_print(ins); with_check_context([&](auto& ins, auto f, auto&& check_context) {
timer t{}; ctx.finish();
auto result = check_context(f); std::cout << "Run instruction: ";
double t1 = t.record<milliseconds>(); this->debug_print(ins);
ctx.finish(); timer t{};
double t2 = t.record<milliseconds>(); auto result = check_context(f);
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl; double t1 = t.record<milliseconds>();
if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load") ctx.finish();
std::cout << "Output: " << result << std::endl; double t2 = t.record<milliseconds>();
return result; std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
}); if(trace_level > 1 and ins->name().front() != '@' and
ins->name() != "load")
std::cout << "Output: " << result << std::endl;
return result;
}));
} }
else else
{ {
return generic_eval( return generic_eval(*this,
*this, ctx, std::move(params), [&](auto&, auto f) { return check_context(f); }); ctx,
std::move(params),
with_check_context([&](auto&, auto f, auto&& check_context) {
return check_context(f);
}));
} }
} }
...@@ -502,21 +523,22 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -502,21 +523,22 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
std::sort(total_vec.begin(), total_vec.end()); std::sort(total_vec.begin(), total_vec.end());
std::unordered_map<instruction_ref, std::vector<double>> ins_vec; std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
// Fill the map // Fill the map
generic_eval(*this, ctx, params, [&](auto ins, auto) { generic_eval(*this, ctx, params, always([&](auto ins, auto) {
ins_vec[ins].reserve(n); ins_vec[ins].reserve(n);
return argument{}; return argument{ins->get_shape(), nullptr};
}); }));
// Run and time each instruction // Run and time each instruction
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
generic_eval(*this, ctx, params, [&](auto ins, auto f) { generic_eval(*this, ctx, params, always([&](auto ins, auto f) {
argument result; argument result;
ins_vec[ins].push_back(time<milliseconds>([&] { ins_vec[ins].push_back(time<milliseconds>([&] {
result = f(); result = f();
ctx.finish(); ctx.finish();
})); }));
return result; return result;
}); }));
} }
for(auto&& p : ins_vec) for(auto&& p : ins_vec)
std::sort(p.second.begin(), p.second.end()); std::sort(p.second.begin(), p.second.end());
...@@ -645,7 +667,9 @@ void program::print_cpp(std::ostream& os) const ...@@ -645,7 +667,9 @@ void program::print_cpp(std::ostream& os) const
void program::dry_run(std::unordered_map<std::string, argument> params) const void program::dry_run(std::unordered_map<std::string, argument> params) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->ctx;
generic_eval(*this, ctx, std::move(params), [](auto&&...) { return argument{}; }); generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) {
return argument{ins->get_shape(), nullptr};
}));
} }
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
...@@ -745,6 +769,22 @@ void program::remove_module(const std::string& name) ...@@ -745,6 +769,22 @@ void program::remove_module(const std::string& name)
impl->modules.at(name).end(), impl->modules.at(name).end(),
[&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) && [&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) &&
"Instruction referenced in another module"); "Instruction referenced in another module");
// if an instruction has an input out side of the current module, need to remove
// the instruction from its input's outputs
auto& mod = impl->modules.at(name);
for(auto ins : iterator_for(mod))
{
auto inputs = ins->inputs();
for(auto in : inputs)
{
if(not mod.has_instruction(in))
{
in->remove_output(ins);
}
}
}
impl->modules.erase(name); impl->modules.erase(name);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment