Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
...@@ -18,7 +18,7 @@ CheckOptions: ...@@ -18,7 +18,7 @@ CheckOptions:
- key: readability-function-size.BranchThreshold - key: readability-function-size.BranchThreshold
value: '15' value: '15'
- key: readability-function-size.LineThreshold - key: readability-function-size.LineThreshold
value: '300' value: '350'
- key: readability-function-size.NestingThreshold - key: readability-function-size.NestingThreshold
value: '5' value: '5'
- key: readability-function-size.ParameterThreshold - key: readability-function-size.ParameterThreshold
......
...@@ -26,6 +26,7 @@ add_library(migraphx ...@@ -26,6 +26,7 @@ add_library(migraphx
load_save.cpp load_save.cpp
make_op.cpp make_op.cpp
msgpack.cpp msgpack.cpp
operation.cpp
program.cpp program.cpp
quantization.cpp quantization.cpp
reduce_dims.cpp reduce_dims.cpp
......
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/op/contiguous.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
namespace migraphx { namespace migraphx {
...@@ -14,7 +15,7 @@ void auto_contiguous::apply(module& p) const ...@@ -14,7 +15,7 @@ void auto_contiguous::apply(module& p) const
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0) if(not s.standard() and s.elements() != 0)
{ {
auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins); auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins);
p.replace_instruction(ins, c); p.replace_instruction(ins, c);
} }
} }
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/multibroadcast.hpp> #include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/mul.hpp> #include <migraphx/op/mul.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/add.hpp> #include <migraphx/op/add.hpp>
namespace migraphx { namespace migraphx {
...@@ -26,17 +28,19 @@ struct find_dot_add ...@@ -26,17 +28,19 @@ struct find_dot_add
not contains({shape::float_type, shape::half_type, shape::double_type}, not contains({shape::float_type, shape::half_type, shape::double_type},
ins->get_shape().type())) ins->get_shape().type()))
return; return;
auto dot_ins = auto dot_ins = p.insert_instruction(ins,
p.insert_instruction(ins, op::dot{dot.alpha, 0}, ins->inputs()[0], ins->inputs()[1]); make_op("dot", {{"alpha", dot.alpha}, {"beta", 0}}),
ins->inputs()[0],
ins->inputs()[1]);
auto c_ins = ins->inputs()[2]; auto c_ins = ins->inputs()[2];
if(not float_equal(dot.beta, 1)) if(not float_equal(dot.beta, 1))
{ {
auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}}); auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast = auto beta_broadcast = p.insert_instruction(
p.insert_instruction(ins, op::multibroadcast{ins->get_shape().lens()}, beta); ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta);
c_ins = p.insert_instruction(ins, op::mul{}, c_ins, beta_broadcast); c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
} }
p.replace_instruction(ins, op::add{}, dot_ins, c_ins); p.replace_instruction(ins, make_op("add"), dot_ins, c_ins);
} }
}; };
......
#include <migraphx/eliminate_allocation.hpp> #include <migraphx/eliminate_allocation.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_config.hpp> #include <migraphx/pass_config.hpp>
namespace migraphx { namespace migraphx {
...@@ -33,7 +36,8 @@ void eliminate_allocation::apply(module& p) const ...@@ -33,7 +36,8 @@ void eliminate_allocation::apply(module& p) const
auto ins = pp.first; auto ins = pp.first;
auto s = ins->get_shape(); auto s = ins->get_shape();
auto offset = pp.second; auto offset = pp.second;
p.replace_instruction(ins, op::load{s, offset}, mem); p.replace_instruction(
ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
} }
} }
} }
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
namespace migraphx { namespace migraphx {
...@@ -77,7 +79,7 @@ void eliminate_concat::apply(module& p) const ...@@ -77,7 +79,7 @@ void eliminate_concat::apply(module& p) const
} }
std::vector<instruction_ref> args = {super}; std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args)); std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraphx::op::identity{}, args); p.replace_instruction(ins, migraphx::make_op("identity"), args);
} }
} }
} }
......
...@@ -33,6 +33,7 @@ struct lstm ...@@ -33,6 +33,7 @@ struct lstm
return pack(f(self.hidden_size, "hidden_size"), return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"), f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"), f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.input_forget, "input_forget")); f(self.input_forget, "input_forget"));
} }
......
...@@ -846,6 +846,9 @@ bool has_finalize(const T& x) ...@@ -846,6 +846,9 @@ bool has_finalize(const T& x)
return detail::has_finalize_op(x); return detail::has_finalize_op(x);
} }
void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op);
#endif #endif
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -45,6 +45,10 @@ ...@@ -45,6 +45,10 @@
#include <migraphx/op/transpose.hpp> #include <migraphx/op/transpose.hpp>
#include <migraphx/op/undefined.hpp> #include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/unsqueeze.hpp> #include <migraphx/op/unsqueeze.hpp>
namespace migraphx { namespace migraphx {
...@@ -267,7 +271,9 @@ struct onnx_parser ...@@ -267,7 +271,9 @@ struct onnx_parser
if(broadcasted != 0) if(broadcasted != 0)
{ {
uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>(); uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, auto l = mm->add_instruction(
make_op("broadcast",
{{"axis", axis}, {"dims", args[0]->get_shape().lens()}}),
args[1]); args[1]);
return mm->add_instruction(make_op(op_name), args[0], l); return mm->add_instruction(make_op(op_name), args[0], l);
} }
...@@ -341,11 +347,13 @@ struct onnx_parser ...@@ -341,11 +347,13 @@ struct onnx_parser
auto l0 = arg0; auto l0 = arg0;
if(arg0->get_shape().lens() != out_lens) if(arg0->get_shape().lens() != out_lens)
l0 = mm->add_instruction(op::multibroadcast{out_lens}, arg0); l0 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
arg0);
auto l1 = arg1; auto l1 = arg1;
if(arg1->get_shape().lens() != out_lens) if(arg1->get_shape().lens() != out_lens)
l1 = mm->add_instruction(op::multibroadcast{out_lens}, arg1); l1 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
arg1);
return mm->add_instruction(make_op(name), l0, l1); return mm->add_instruction(make_op(name), l0, l1);
} }
...@@ -398,8 +406,9 @@ struct onnx_parser ...@@ -398,8 +406,9 @@ struct onnx_parser
{ {
if(args.size() == 3) if(args.size() == 3)
{ {
auto bias_bcast = auto bias_bcast = mm->add_instruction(
mm->add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]); make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}),
args[2]);
return mm->add_instruction(make_op("add"), curr_ins, bias_bcast); return mm->add_instruction(make_op("add"), curr_ins, bias_bcast);
} }
return curr_ins; return curr_ins;
...@@ -437,7 +446,8 @@ struct onnx_parser ...@@ -437,7 +446,8 @@ struct onnx_parser
asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it); asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it);
// add right pads // add right pads
asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end()); asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
ins = mm->add_instruction(op::pad{asym_pads, pad_val}, ins); ins =
mm->add_instruction(make_op("pad", {{"pads", asym_pads}, {"value", pad_val}}), ins);
} }
else else
{ {
...@@ -479,12 +489,14 @@ struct onnx_parser ...@@ -479,12 +489,14 @@ struct onnx_parser
if(min_used) if(min_used)
{ {
min_arg = mm->add_instruction(op::multibroadcast{input_lens}, min_arg); min_arg = mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
min_arg);
} }
if(max_used) if(max_used)
{ {
max_arg = mm->add_instruction(op::multibroadcast{input_lens}, max_arg); max_arg = mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
max_arg);
} }
if(min_used and max_used) if(min_used and max_used)
...@@ -525,7 +537,7 @@ struct onnx_parser ...@@ -525,7 +537,7 @@ struct onnx_parser
if(keep_dims == 0) if(keep_dims == 0)
{ {
auto ins = mm->add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args)); auto ins = mm->add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
return mm->add_instruction(op::squeeze{{axis}}, ins); return mm->add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
} }
else else
{ {
...@@ -593,7 +605,8 @@ struct onnx_parser ...@@ -593,7 +605,8 @@ struct onnx_parser
{ {
*starts_it = idx; *starts_it = idx;
*ends_it = *starts_it + 1; *ends_it = *starts_it + 1;
slices.push_back(mm->add_instruction(op::slice{axes, starts, ends}, input)); slices.push_back(mm->add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
} }
// when padding on the left side, the outermost pad should be at the beginning // when padding on the left side, the outermost pad should be at the beginning
std::reverse(slices.begin(), slices.end()); std::reverse(slices.begin(), slices.end());
...@@ -602,9 +615,10 @@ struct onnx_parser ...@@ -602,9 +615,10 @@ struct onnx_parser
{ {
*starts_it = *dims_it - idx - 1; *starts_it = *dims_it - idx - 1;
*ends_it = *starts_it + 1; *ends_it = *starts_it + 1;
slices.push_back(mm->add_instruction(op::slice{axes, starts, ends}, input)); slices.push_back(mm->add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
} }
input = mm->add_instruction(op::concat{axis}, slices); input = mm->add_instruction(make_op("concat", {{"axis", axis}}), slices);
} }
return input; return input;
} }
...@@ -841,7 +855,8 @@ struct onnx_parser ...@@ -841,7 +855,8 @@ struct onnx_parser
std::back_inserter(ends), std::back_inserter(ends),
[](auto curr_dim, auto pad_dim) { return curr_dim - pad_dim; }); [](auto curr_dim, auto pad_dim) { return curr_dim - pad_dim; });
l1 = mm->add_instruction(op::slice{axes, starts, ends}, l1); l1 = mm->add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), l1);
} }
if(contains(info.attributes, "output_padding")) if(contains(info.attributes, "output_padding"))
...@@ -852,7 +867,7 @@ struct onnx_parser ...@@ -852,7 +867,7 @@ struct onnx_parser
check_attr_sizes(kdims, check_attr_sizes(kdims,
output_padding.size() - non_kdims, output_padding.size() - non_kdims,
"PARSE_CONV_TRANSPOSE: inconsistent output padding"); "PARSE_CONV_TRANSPOSE: inconsistent output padding");
l1 = mm->add_instruction(op::pad{output_padding}, l1); l1 = mm->add_instruction(make_op("pad", {{"pads", output_padding}}), l1);
} }
if(contains(info.attributes, "output_shape")) if(contains(info.attributes, "output_shape"))
...@@ -871,7 +886,7 @@ struct onnx_parser ...@@ -871,7 +886,7 @@ struct onnx_parser
curr_shape.begin(), curr_shape.begin(),
std::back_inserter(target_padding), std::back_inserter(target_padding),
[](auto out_dim, auto curr_dim) { return out_dim - curr_dim; }); [](auto out_dim, auto curr_dim) { return out_dim - curr_dim; });
l1 = mm->add_instruction(op::pad{target_padding}, l1); l1 = mm->add_instruction(make_op("pad", {{"pads", target_padding}}), l1);
} }
} }
...@@ -1049,7 +1064,9 @@ struct onnx_parser ...@@ -1049,7 +1064,9 @@ struct onnx_parser
{ {
std::vector<int64_t> axes(kdims); std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2); std::iota(axes.begin(), axes.end(), 2);
l1 = mm->add_instruction(op::slice{axes, slice_start, slice_end}, l1); l1 = mm->add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}),
l1);
} }
return l1; return l1;
...@@ -1283,7 +1300,7 @@ struct onnx_parser ...@@ -1283,7 +1300,7 @@ struct onnx_parser
int64_t data_elem_num = static_cast<int64_t>(data_s.elements()); int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
// reshape the input data as one dimension and used as input data // reshape the input data as one dimension and used as input data
// to the gather operator // to the gather operator
arg_data = mm->add_instruction(op::reshape{{data_elem_num}}, arg_data); arg_data = mm->add_instruction(make_op("reshape", {{"dims", {data_elem_num}}}), arg_data);
std::size_t elem_num = ind_s.elements(); std::size_t elem_num = ind_s.elements();
std::vector<int> ind_index(elem_num); std::vector<int> ind_index(elem_num);
...@@ -1304,7 +1321,8 @@ struct onnx_parser ...@@ -1304,7 +1321,8 @@ struct onnx_parser
mm->add_literal(literal(ind_s, data_indices.begin(), data_indices.end())); mm->add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = mm->add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end())); auto l_dim_idx = mm->add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = mm->add_literal(literal{{ind_s.type(), {1}}, {axis_stride}}); auto l_stride = mm->add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = mm->add_instruction(op::multibroadcast{ind_s.lens()}, l_stride); l_stride = mm->add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}),
l_stride);
auto dim_diff = mm->add_instruction(make_op("sub"), arg_ind, l_dim_idx); auto dim_diff = mm->add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = mm->add_instruction(make_op("mul"), dim_diff, l_stride); auto delta = mm->add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = mm->add_instruction(make_op("add"), l_shape_idx, delta); auto ind = mm->add_instruction(make_op("add"), l_shape_idx, delta);
...@@ -1428,8 +1446,10 @@ struct onnx_parser ...@@ -1428,8 +1446,10 @@ struct onnx_parser
// swap the last two elements // swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1)); std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0]; auto l1 = (transa) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1]; : args[0];
auto l2 = (transb) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
if(args.size() == 3) if(args.size() == 3)
{ {
if(beta != 0.f && args[2]->get_shape().elements() > 0) if(beta != 0.f && args[2]->get_shape().elements() > 0)
...@@ -1440,7 +1460,8 @@ struct onnx_parser ...@@ -1440,7 +1460,8 @@ struct onnx_parser
auto l3_lens = l3->get_shape().lens(); auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end())) if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{ {
l3 = mm->add_instruction(op::multibroadcast{out_lens}, args[2]); l3 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
args[2]);
} }
return mm->add_instruction( return mm->add_instruction(
make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
...@@ -1466,7 +1487,7 @@ struct onnx_parser ...@@ -1466,7 +1487,7 @@ struct onnx_parser
{ {
is_a_prepended = true; is_a_prepended = true;
l0_lens.insert(l0_lens.begin(), 1); l0_lens.insert(l0_lens.begin(), 1);
l0 = mm->add_instruction(op::unsqueeze{{0}}, args[0]); l0 = mm->add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
} }
bool is_b_appended = false; bool is_b_appended = false;
...@@ -1474,7 +1495,7 @@ struct onnx_parser ...@@ -1474,7 +1495,7 @@ struct onnx_parser
{ {
is_b_appended = true; is_b_appended = true;
l1_lens.push_back(1); l1_lens.push_back(1);
l1 = mm->add_instruction(op::unsqueeze{{1}}, args[1]); l1 = mm->add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
} }
instruction_ref bl0 = l0; instruction_ref bl0 = l0;
...@@ -1492,11 +1513,13 @@ struct onnx_parser ...@@ -1492,11 +1513,13 @@ struct onnx_parser
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end()); l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if(l0_lens != l0_broadcasted_lens) if(l0_lens != l0_broadcasted_lens)
{ {
bl0 = mm->add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0); bl0 = mm->add_instruction(
make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0);
} }
if(l1_lens != l1_broadcasted_lens) if(l1_lens != l1_broadcasted_lens)
{ {
bl1 = mm->add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1); bl1 = mm->add_instruction(
make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1);
} }
} }
...@@ -1504,12 +1527,12 @@ struct onnx_parser ...@@ -1504,12 +1527,12 @@ struct onnx_parser
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size()); int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended) if(is_a_prepended)
{ {
dot_res = mm->add_instruction(op::squeeze{{num_axis - 2}}, dot_res); dot_res = mm->add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
--num_axis; --num_axis;
} }
if(is_b_appended) if(is_b_appended)
{ {
dot_res = mm->add_instruction(op::squeeze{{num_axis - 1}}, dot_res); dot_res = mm->add_instruction(make_op("squeeze", {{"axes", {num_axis - 1}}}), dot_res);
} }
return dot_res; return dot_res;
...@@ -1563,19 +1586,24 @@ struct onnx_parser ...@@ -1563,19 +1586,24 @@ struct onnx_parser
std::iota(axes.begin(), axes.end(), 2); std::iota(axes.begin(), axes.end(), 2);
auto mean = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); auto mean = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast = mm->add_instruction(op::multibroadcast{dims}, mean); auto mean_bcast =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean);
auto l0 = mm->add_instruction(make_op("sqdiff"), x, mean_bcast); auto l0 = mm->add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0); auto variance = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = mm->add_instruction(make_op("sub"), x, mean_bcast); auto l1 = mm->add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = mm->add_literal(epsilon); auto epsilon_literal = mm->add_literal(epsilon);
auto epsilon_bcast = mm->add_instruction(op::multibroadcast{dims}, epsilon_literal); auto epsilon_bcast = mm->add_instruction(make_op("multibroadcast", {{"output_lens", dims}}),
auto variance_bcast = mm->add_instruction(op::multibroadcast{dims}, variance); epsilon_literal);
auto variance_bcast =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance);
auto l2 = mm->add_instruction(make_op("add"), variance_bcast, epsilon_bcast); auto l2 = mm->add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(make_op("rsqrt"), l2); auto l3 = mm->add_instruction(make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(make_op("mul"), l1, l3); auto l4 = mm->add_instruction(make_op("mul"), l1, l3);
auto scale_bcast = mm->add_instruction(op::broadcast{1, dims}, scale); auto scale_bcast =
mm->add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale);
; ;
auto bias_bcast = mm->add_instruction(op::broadcast{1, dims}, bias); auto bias_bcast =
mm->add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias);
auto l5 = mm->add_instruction(make_op("mul"), l4, scale_bcast); auto l5 = mm->add_instruction(make_op("mul"), l4, scale_bcast);
return mm->add_instruction(make_op("add"), l5, bias_bcast); return mm->add_instruction(make_op("add"), l5, bias_bcast);
} }
...@@ -1645,9 +1673,11 @@ struct onnx_parser ...@@ -1645,9 +1673,11 @@ struct onnx_parser
auto scale_val = mm->add_literal(literal{shape{input_type}, {scale}}); auto scale_val = mm->add_literal(literal{shape{input_type}, {scale}});
auto bias_vals = mm->add_literal(literal{shape{input_type, {bias.size()}}, bias}); auto bias_vals = mm->add_literal(literal{shape{input_type, {bias.size()}}, bias});
auto scale_tensor = mm->add_instruction(migraphx::op::scalar{input_lens}, scale_val); auto scale_tensor = mm->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", input_lens}}), scale_val);
auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor); auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = mm->add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals); auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals);
return mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast); return mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
} }
...@@ -1660,7 +1690,7 @@ struct onnx_parser ...@@ -1660,7 +1690,7 @@ struct onnx_parser
auto&& perm_vals = info.attributes["perm"].ints(); auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end()); perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
} }
return mm->add_instruction(migraphx::op::transpose{perm}, args.front()); return mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), args.front());
} }
instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args) instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
...@@ -1721,7 +1751,8 @@ struct onnx_parser ...@@ -1721,7 +1751,8 @@ struct onnx_parser
value = parse_value(info.attributes.at("value")).at<float>(); value = parse_value(info.attributes.at("value")).at<float>();
} }
return mm->add_instruction(migraphx::op::pad{pads, value}, args.front()); return mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}),
args.front());
} }
instruction_ref instruction_ref
...@@ -1918,7 +1949,7 @@ struct onnx_parser ...@@ -1918,7 +1949,7 @@ struct onnx_parser
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims); auto out_lens = compute_broadcasted_lens(in_lens, dims);
return mm->add_instruction(op::multibroadcast{out_lens}, args[0]); return mm->add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), args[0]);
} }
std::vector<instruction_ref> std::vector<instruction_ref>
...@@ -2001,16 +2032,20 @@ struct onnx_parser ...@@ -2001,16 +2032,20 @@ struct onnx_parser
// undefined operator to have 6 arguments // undefined operator to have 6 arguments
if(args.size() < 6) if(args.size() < 6)
{ {
auto ins = mm->add_instruction(op::undefined{}); auto ins = mm->add_instruction(make_op("undefined"));
args.insert(args.end(), (6 - args.size()), ins); args.insert(args.end(), (6 - args.size()), ins);
} }
// first output for the concatenation of hidden states // first output for the concatenation of hidden states
auto hidden_states = auto hidden_states = mm->add_instruction(make_op("rnn",
mm->add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, std::move(args)); {{"hidden_size", hidden_size},
{"actv_func", to_value(vec_actv_funcs)},
{"direction", dirct},
{"clip", clip}}),
std::move(args));
// second output for the last hidden state // second output for the last hidden state
auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states); auto last_output = mm->add_instruction(make_op("rnn_last_hs_output"), hidden_states);
return {hidden_states, last_output}; return {hidden_states, last_output};
} }
...@@ -2122,17 +2157,22 @@ struct onnx_parser ...@@ -2122,17 +2157,22 @@ struct onnx_parser
// append undefined opeator to make 6 arguments // append undefined opeator to make 6 arguments
if(args.size() < 6) if(args.size() < 6)
{ {
auto ins = mm->add_instruction(op::undefined{}); auto ins = mm->add_instruction(make_op("undefined"));
args.insert(args.end(), 6 - args.size(), ins); args.insert(args.end(), 6 - args.size(), ins);
} }
// first output for concatenation of hidden states // first output for concatenation of hidden states
auto hidden_states = mm->add_instruction( auto hidden_states =
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset}, mm->add_instruction(make_op("gru",
{{"hidden_size", hidden_size},
{"actv_func", to_value(vec_actv_funcs)},
{"direction", dirct},
{"clip", clip},
{"linear_before_reset", linear_before_reset}}),
std::move(args)); std::move(args));
// second output for last gru output // second output for last gru output
auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states); auto last_output = mm->add_instruction(make_op("rnn_last_hs_output"), hidden_states);
return {hidden_states, last_output}; return {hidden_states, last_output};
} }
...@@ -2304,18 +2344,23 @@ struct onnx_parser ...@@ -2304,18 +2344,23 @@ struct onnx_parser
// append undefined opeator to make 6 arguments // append undefined opeator to make 6 arguments
if(args.size() < 8) if(args.size() < 8)
{ {
auto ins = mm->add_instruction(op::undefined{}); auto ins = mm->add_instruction(make_op("undefined"));
args.insert(args.end(), 8 - args.size(), ins); args.insert(args.end(), 8 - args.size(), ins);
} }
// first output for concatenation of hidden states // first output for concatenation of hidden states
auto hidden_states = mm->add_instruction( auto hidden_states = mm->add_instruction(make_op("lstm",
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args)); {{"hidden_size", hidden_size},
{"actv_func", to_value(vec_actv_funcs)},
{"direction", dirct},
{"clip", clip},
{"input_forget", input_forget}}),
std::move(args));
auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states); auto last_output = mm->add_instruction(make_op("rnn_last_hs_output"), hidden_states);
// third output for last cell output // third output for last cell output
auto last_cell_output = mm->add_instruction(op::rnn_last_cell_output{}, hidden_states); auto last_cell_output = mm->add_instruction(make_op("rnn_last_cell_output"), hidden_states);
return {hidden_states, last_output, last_cell_output}; return {hidden_states, last_output, last_cell_output};
} }
...@@ -2350,7 +2395,7 @@ struct onnx_parser ...@@ -2350,7 +2395,7 @@ struct onnx_parser
else else
{ {
auto ins = mm->add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args)); auto ins = mm->add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
return mm->add_instruction(op::squeeze{axes}, ins); return mm->add_instruction(make_op("squeeze", {{"axes", axes}}), ins);
} }
} }
...@@ -2452,8 +2497,9 @@ struct onnx_parser ...@@ -2452,8 +2497,9 @@ struct onnx_parser
int64_t start = 0; int64_t start = 0;
for(auto sl : vec_splits) for(auto sl : vec_splits)
{ {
ret_ins.push_back( ret_ins.push_back(mm->add_instruction(
mm->add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0])); make_op("slice", {{"axes", {axis}}, {"starts", {start}}, {"ends", {start + sl}}}),
args[0]));
start += sl; start += sl;
} }
...@@ -2482,7 +2528,7 @@ struct onnx_parser ...@@ -2482,7 +2528,7 @@ struct onnx_parser
auto type = args[2]->get_shape().type(); auto type = args[2]->get_shape().type();
shape s{type, {depth, depth}}; shape s{type, {depth, depth}};
auto l_val = mm->add_literal({s, depth_input}); auto l_val = mm->add_literal({s, depth_input});
auto gather_out = mm->add_instruction(op::gather{0}, {l_val, args[0]}); auto gather_out = mm->add_instruction(make_op("gather", {{"axis", 0}}), {l_val, args[0]});
// Finally, we need a transpose to move the inner most dim to the axis dim // Finally, we need a transpose to move the inner most dim to the axis dim
int n_rank = gather_out->get_shape().lens().size(); int n_rank = gather_out->get_shape().lens().size();
...@@ -2494,14 +2540,18 @@ struct onnx_parser ...@@ -2494,14 +2540,18 @@ struct onnx_parser
std::vector<int64_t> perm(n_rank - 1); std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0); std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1); perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = mm->add_instruction(op::transpose{perm}, gather_out); auto tr_out = mm->add_instruction(make_op("transpose", {{"dims", perm}}), gather_out);
auto lens = tr_out->get_shape().lens(); auto lens = tr_out->get_shape().lens();
auto off_val = mm->add_instruction(op::slice{{0}, {0}, {1}}, args[2]); auto off_val = mm->add_instruction(
auto on_val = mm->add_instruction(op::slice{{0}, {1}, {2}}, args[2]); make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto on_val = mm->add_instruction(
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
auto diff = mm->add_instruction(make_op("sub"), on_val, off_val); auto diff = mm->add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val = mm->add_instruction(op::multibroadcast{lens}, off_val); auto unsq_off_val =
auto unsq_diff_val = mm->add_instruction(op::multibroadcast{lens}, diff); mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), off_val);
auto unsq_diff_val =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), diff);
auto l_mul = mm->add_instruction(make_op("mul"), tr_out, unsq_diff_val); auto l_mul = mm->add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return mm->add_instruction(make_op("add"), l_mul, unsq_off_val); return mm->add_instruction(make_op("add"), l_mul, unsq_off_val);
} }
...@@ -2520,7 +2570,7 @@ struct onnx_parser ...@@ -2520,7 +2570,7 @@ struct onnx_parser
auto l1 = l0; auto l1 = l0;
for(int j = 1; j < repeats[i]; j++) for(int j = 1; j < repeats[i]; j++)
{ {
l0 = mm->add_instruction(op::concat{i}, l0, l1); l0 = mm->add_instruction(make_op("concat", {{"axis", i}}), l0, l1);
} }
} }
return l0; return l0;
...@@ -2585,7 +2635,7 @@ struct onnx_parser ...@@ -2585,7 +2635,7 @@ struct onnx_parser
reduce_mode = static_cast<reduce_mode_t>(info.attributes.at("mode").i()); reduce_mode = static_cast<reduce_mode_t>(info.attributes.at("mode").i());
} }
auto l0 = mm->add_instruction(op::gather{}, args[0], args[1]); auto l0 = mm->add_instruction(make_op("gather"), args[0], args[1]);
switch(reduce_mode) switch(reduce_mode)
{ {
case reduce_mode_t::sum: case reduce_mode_t::sum:
...@@ -2920,7 +2970,7 @@ struct onnx_parser ...@@ -2920,7 +2970,7 @@ struct onnx_parser
{ {
if(!contains(instructions, name)) if(!contains(instructions, name))
{ {
auto ins = mm->add_instruction(op::undefined{}); auto ins = mm->add_instruction(make_op("undefined"));
instructions[name] = ins; instructions[name] = ins;
} }
} }
......
#include <migraphx/operation.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Serialize an operation into a generic value tree so it can round-trip
// through save/load: "name" holds the operation's registered name and
// "operator" holds its attribute payload as produced by op.to_value().
// The inverse transformation is migraphx_from_value.
void migraphx_to_value(value& v, const operation& op)
{
v["name"] = op.name();
v["operator"] = op.to_value();
}
// Deserialize an operation from a value tree written by migraphx_to_value:
// looks up the stored "name" in the operation registry via make_op and
// restores the attributes from the "operator" entry. Throws (via value::at)
// if either key is missing, and via make_op if the name is not registered.
void migraphx_from_value(const value& v, operation& op)
{
op = make_op(v.at("name").to<std::string>(), v.at("operator"));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/op/load.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp" #include "memory_coloring_impl.hpp"
namespace migraphx { namespace migraphx {
...@@ -208,7 +211,9 @@ void memory_coloring_impl::rewrite() ...@@ -208,7 +211,9 @@ void memory_coloring_impl::rewrite()
if(is_allocate(ins)) if(is_allocate(ins))
{ {
p_program->replace_instruction( p_program->replace_instruction(
ins, op::load{ins->get_shape(), offset}, scratch_param); ins,
make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", offset}}),
scratch_param);
} }
} }
} }
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include <utility> #include <utility>
#include <migraphx/make_op.hpp>
#include <unordered_set> #include <unordered_set>
namespace migraphx { namespace migraphx {
...@@ -212,7 +214,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re ...@@ -212,7 +214,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
if(ins == std::prev(this->end())) if(ins == std::prev(this->end()))
{ {
return replace_instruction(ins, op::identity{}, rep); return replace_instruction(ins, make_op("identity"), rep);
} }
// TODO: Should it be an error if the output is empty? // TODO: Should it be an error if the output is empty?
......
...@@ -19,6 +19,10 @@ ...@@ -19,6 +19,10 @@
#include <utility> #include <utility>
#include <set> #include <set>
#include <iomanip> #include <iomanip>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <fstream> #include <fstream>
#include <algorithm> #include <algorithm>
...@@ -58,12 +62,14 @@ instruction_ref insert_quant_ins(module& modl, ...@@ -58,12 +62,14 @@ instruction_ref insert_quant_ins(module& modl,
auto float_ins = scaled_ins; auto float_ins = scaled_ins;
if(scaled_ins->get_shape().type() != shape::float_type) if(scaled_ins->get_shape().type() != shape::float_type)
{ {
float_ins = float_ins = modl.insert_instruction(
modl.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins); insert_loc,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
scaled_ins);
} }
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale); std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale)); auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = modl.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins); scaled_ins = modl.insert_instruction(insert_loc, make_op("mul"), l_scale, float_ins);
} }
auto shifted_ins = scaled_ins; auto shifted_ins = scaled_ins;
...@@ -73,26 +79,32 @@ instruction_ref insert_quant_ins(module& modl, ...@@ -73,26 +79,32 @@ instruction_ref insert_quant_ins(module& modl,
if(shifted_ins->get_shape().type() != shape::float_type) if(shifted_ins->get_shape().type() != shape::float_type)
{ {
float_ins = modl.insert_instruction( float_ins = modl.insert_instruction(
insert_loc, op::convert{shape::float_type}, shifted_ins); insert_loc,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
shifted_ins);
} }
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift); std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift)); auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = modl.insert_instruction(insert_loc, op::add{}, l_shift, float_ins); shifted_ins = modl.insert_instruction(insert_loc, make_op("add"), l_shift, float_ins);
} }
auto rounded_ins = modl.insert_instruction(insert_loc, op::round{}, shifted_ins); auto rounded_ins = modl.insert_instruction(insert_loc, make_op("round"), shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens(); auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = modl.add_literal(127.0f); auto max_clip = modl.add_literal(127.0f);
auto min_clip = modl.add_literal(-128.0f); auto min_clip = modl.add_literal(-128.0f);
max_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip); max_clip = modl.insert_instruction(
min_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip); insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), max_clip);
min_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), min_clip);
auto clipped_ins = auto clipped_ins =
modl.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip); modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, clipped_ins); quant_ins = modl.insert_instruction(
insert_loc, make_op("convert", {{"target_type", type}}), clipped_ins);
} }
else else
{ {
quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, ins); quant_ins =
modl.insert_instruction(insert_loc, make_op("convert", {{"target_type", type}}), ins);
} }
map_ins[ins] = quant_ins; map_ins[ins] = quant_ins;
...@@ -162,8 +174,8 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) ...@@ -162,8 +174,8 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{ {
// check the dead code case to avoid assert // check the dead code case to avoid assert
bool output_empty = ins->outputs().empty(); bool output_empty = ins->outputs().empty();
auto ins_orig_type = auto ins_orig_type = mm->insert_instruction(
mm->insert_instruction(std::next(ins), op::convert{orig_type}, ins); std::next(ins), make_op("convert", {{"target_type", orig_type}}), ins);
if(!output_empty) if(!output_empty)
{ {
mm->replace_instruction(ins, ins_orig_type); mm->replace_instruction(ins, ins_orig_type);
...@@ -197,13 +209,18 @@ static void ins_quantize_int8(module& modl, ...@@ -197,13 +209,18 @@ static void ins_quantize_int8(module& modl,
if(shape::int32_type == orig_type) if(shape::int32_type == orig_type)
{ {
modl.replace_instruction( modl.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
} }
else else
{ {
auto quant_dot = modl.insert_instruction( auto quant_dot = modl.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); ins,
modl.replace_instruction(ins, op::convert{orig_type}, quant_dot); make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), quant_dot);
} }
} }
// either alpha or beta cannot be quantized because of too big // either alpha or beta cannot be quantized because of too big
...@@ -214,8 +231,10 @@ static void ins_quantize_int8(module& modl, ...@@ -214,8 +231,10 @@ static void ins_quantize_int8(module& modl,
{ {
converted_inputs.pop_back(); converted_inputs.pop_back();
} }
auto q_dot = modl.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs); auto q_dot = modl.insert_instruction(
auto f_dot = modl.insert_instruction(ins, op::convert{shape::float_type}, q_dot); ins, make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), converted_inputs);
auto f_dot = modl.insert_instruction(
ins, make_op("convert", {{"target_type", to_value(shape::float_type)}}), q_dot);
auto c_shape = q_dot->get_shape(); auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha); std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha = auto l_alpha =
...@@ -223,42 +242,46 @@ static void ins_quantize_int8(module& modl, ...@@ -223,42 +242,46 @@ static void ins_quantize_int8(module& modl,
if(inputs.size() == 3 and dot_op.beta != 0.0f) if(inputs.size() == 3 and dot_op.beta != 0.0f)
{ {
auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot); auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta); std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta = auto l_beta =
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta)); modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{}; instruction_ref beta_c{};
if(orig_type != shape::float_type) if(orig_type != shape::float_type)
{ {
auto fp32_c = auto fp32_c = modl.insert_instruction(
modl.insert_instruction(ins, op::convert{shape::float_type}, inputs.back()); ins,
beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, fp32_c); make_op("convert", {{"target_type", to_value(shape::float_type)}}),
inputs.back());
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, fp32_c);
} }
else else
{ {
beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, inputs.back()); beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, inputs.back());
} }
if(orig_type == shape::float_type) if(orig_type == shape::float_type)
{ {
modl.replace_instruction(ins, op::add{}, alpha_ab, beta_c); modl.replace_instruction(ins, make_op("add"), alpha_ab, beta_c);
} }
else else
{ {
auto f_res = modl.insert_instruction(ins, op::add{}, alpha_ab, beta_c); auto f_res = modl.insert_instruction(ins, make_op("add"), alpha_ab, beta_c);
modl.replace_instruction(ins, op::convert{orig_type}, f_res); modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), f_res);
} }
} }
else else
{ {
if(orig_type == shape::float_type) if(orig_type == shape::float_type)
{ {
modl.replace_instruction(ins, op::mul{}, l_alpha, f_dot); modl.replace_instruction(ins, make_op("mul"), l_alpha, f_dot);
} }
else else
{ {
auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot); auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
modl.replace_instruction(ins, op::convert{orig_type}, alpha_ab); modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), alpha_ab);
} }
} }
} }
...@@ -285,23 +308,27 @@ static void ins_quantize_int8(module& modl, ...@@ -285,23 +308,27 @@ static void ins_quantize_int8(module& modl,
{ {
auto l_factor = modl.add_literal( auto l_factor = modl.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end())); literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
modl.replace_instruction(ins, op::mul{}, quant_conv, l_factor); modl.replace_instruction(ins, make_op("mul"), quant_conv, l_factor);
} }
// convert quant_conv output to float type, multiply the factor and // convert quant_conv output to float type, multiply the factor and
// conver back to original type // conver back to original type
else else
{ {
auto float_conv = auto float_conv = modl.insert_instruction(
modl.insert_instruction(ins, op::convert{shape::float_type}, quant_conv); ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
quant_conv);
auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor)); auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type) if(orig_type == shape::float_type)
{ {
modl.replace_instruction(ins, op::mul{}, l_factor, float_conv); modl.replace_instruction(ins, make_op("mul"), l_factor, float_conv);
} }
else else
{ {
auto adjusted_conv = modl.insert_instruction(ins, op::mul{}, l_factor, float_conv); auto adjusted_conv =
modl.replace_instruction(ins, op::convert{orig_type}, adjusted_conv); modl.insert_instruction(ins, make_op("mul"), l_factor, float_conv);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), adjusted_conv);
} }
} }
} }
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include <migraphx/op/mul.hpp> #include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
namespace migraphx { namespace migraphx {
...@@ -46,10 +48,10 @@ void rewrite_batchnorm::apply(module& p) const ...@@ -46,10 +48,10 @@ void rewrite_batchnorm::apply(module& p) const
auto broadcast = op::broadcast{1, ins->get_shape().lens()}; auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = p.add_literal({a.get_shape(), a.data()}); auto a_ins = p.add_literal({a.get_shape(), a.data()});
auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins); auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast); auto mul = p.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()}); auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins); auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast); auto add = p.insert_instruction(ins, make_op("add"), mul, b_broadcast);
p.replace_instruction(ins, add); p.replace_instruction(ins, add);
} }
} }
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp> #include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp> #include <migraphx/op/reduce_max.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
namespace migraphx { namespace migraphx {
...@@ -31,25 +33,26 @@ void rewrite_pooling::apply(module& prog) const ...@@ -31,25 +33,26 @@ void rewrite_pooling::apply(module& prog) const
continue; continue;
std::int64_t n = s.lens()[0]; std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1]; std::int64_t c = s.lens()[1];
auto reshape = auto reshape = prog.insert_instruction(
prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front()); ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{}; instruction_ref pooling{};
// average pooling // average pooling
if(op.mode == "average") if(op.mode == "average")
{ {
pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape); pooling =
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
} }
// max pooling // max pooling
else else
{ {
pooling = prog.insert_instruction(ins, op::reduce_max{{1}}, reshape); pooling = prog.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
} }
std::vector<int64_t> rsp_lens(lens.size(), 1); std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n; rsp_lens[0] = n;
rsp_lens[1] = c; rsp_lens[1] = c;
prog.replace_instruction(ins, op::reshape{rsp_lens}, pooling); prog.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
} }
} }
......
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp> #include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp> #include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -80,20 +82,26 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const ...@@ -80,20 +82,26 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
if(dirct == op::rnn_direction::bidirectional) if(dirct == op::rnn_direction::bidirectional)
{ {
// input weight matrix // input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]); auto w_forward = prog.insert_instruction(
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// hidden state weight matrix // hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]); auto r_forward = prog.insert_instruction(
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// process bias // process bias
instruction_ref bias_forward = prog.end(); instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end(); instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 && args[3]->name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
} }
// process intial hidden state, it could be the 6th argument // process intial hidden state, it could be the 6th argument
...@@ -102,8 +110,10 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const ...@@ -102,8 +110,10 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{}; instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined") if(args.size() == 6 && args[5]->name() != "undefined")
{ {
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); ih_forward = prog.insert_instruction(
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
} }
else else
{ {
...@@ -120,8 +130,8 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const ...@@ -120,8 +130,8 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
if(variable_seq_len) if(variable_seq_len)
{ {
args[0] = args[0] = prog.insert_instruction(
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens); ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
} }
auto ret_reverse = auto ret_reverse =
...@@ -131,24 +141,27 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const ...@@ -131,24 +141,27 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
{args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse}, {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
actv_funcs.at(1)); actv_funcs.at(1));
auto concat_output = auto concat_output = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output); last_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
// The following logic is to ensure the last instruction rewritten from // The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction // rnn operator is a concat instruction
// sequence len is 1 // sequence len is 1
if(ret_forward[0] == prog.end()) if(ret_forward[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); prog.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
} }
else else
{ {
ret_forward[0] = ret_forward[0] = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]); ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[0] = ret_reverse[0] = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]); ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]}); prog.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
} }
} }
else else
...@@ -180,26 +193,27 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const ...@@ -180,26 +193,27 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
if(!is_forward and variable_seq_len) if(!is_forward and variable_seq_len)
{ {
args[0] = args[0] = prog.insert_instruction(
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens); ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
} }
auto ret = vanilla_rnn_cell( auto ret = vanilla_rnn_cell(
is_forward, prog, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0)); is_forward, prog, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
// following logic is to ensure the last instruction is a // following logic is to ensure the last instruction is a
// concat instruction // concat instruction
// sequence len is 1 // sequence len is 1
if(ret[0] == prog.end()) if(ret[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{0}, ret[1]); prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
} }
else else
{ {
auto concat_arg0 = is_forward ? ret[0] : ret[1]; auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0]; auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); prog.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
} }
} }
...@@ -225,15 +239,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -225,15 +239,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// squeeze and transpose w // squeeze and transpose w
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw); auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
// squeeze and transpose r // squeeze and transpose r
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr); auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
// initial hidden state // initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
auto sih_lens = sih->get_shape().lens(); auto sih_lens = sih->get_shape().lens();
// bias // bias
...@@ -241,30 +255,36 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -241,30 +255,36 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
if(bias != prog.end()) if(bias != prog.end())
{ {
long hs = static_cast<long>(r->get_shape().lens()[2]); long hs = static_cast<long>(r->get_shape().lens()[2]);
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto wb = prog.insert_instruction(
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
auto wrb = prog.insert_instruction(ins, op::add{}, wb, rb); auto rb = prog.insert_instruction(
bb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb); ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
bb = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", sih_lens}}), wrb);
} }
instruction_ref hidden_out = prog.end(); instruction_ref hidden_out = prog.end();
instruction_ref last_out{}; instruction_ref last_out{};
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih); last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens)); long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
for(long i = 0; i < seq_len; i++) for(long i = 0; i < seq_len; i++)
{ {
long seq_index = is_forward ? i : (seq_len - 1 - i); long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq); auto xt = prog.insert_instruction(
auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt); ins,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt); make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw); seq);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr); auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_wi = prog.insert_instruction(ins, make_op("dot"), xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, make_op("dot"), sih, tran_sr);
if(bias != prog.end()) if(bias != prog.end())
{ {
xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb); xt_wi = prog.insert_instruction(ins, make_op("add"), xt_wi, bb);
} }
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri); auto xt_ht = prog.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);
// apply activation function // apply activation function
auto ht = prog.insert_instruction(ins, actv_func, xt_ht); auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
...@@ -272,7 +292,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -272,7 +292,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// add the dimensions of sequence length (axis 0 for sequence length, // add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions // axis 1 for num_directions
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht); last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
// concatenation for the last last_out is performed in the apply() // concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have // function to ensure the last instruction is concat, then we have
...@@ -281,17 +301,17 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -281,17 +301,17 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
{ {
if(is_forward) if(is_forward)
{ {
hidden_out = hidden_out = (seq_index == 0)
(seq_index == 0)
? last_out ? last_out
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out); : prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
} }
else else
{ {
hidden_out = hidden_out = (seq_index == seq_len - 1)
(seq_index == seq_len - 1)
? last_out ? last_out
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out); : prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
} }
} }
} }
...@@ -311,7 +331,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) ...@@ -311,7 +331,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
if(rnn_op.actv_funcs.empty()) if(rnn_op.actv_funcs.empty())
{ {
// default is tanh // default is tanh
return {op::tanh{}, op::tanh{}}; return {make_op("tanh"), make_op("tanh")};
} }
else if(rnn_op.actv_funcs.size() == 1) else if(rnn_op.actv_funcs.size() == 1)
{ {
...@@ -327,7 +347,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) ...@@ -327,7 +347,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
if(rnn_op.actv_funcs.empty()) if(rnn_op.actv_funcs.empty())
{ {
// default is tanh // default is tanh
return {op::tanh{}}; return {make_op("tanh")};
} }
else else
{ {
...@@ -369,20 +389,26 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const ...@@ -369,20 +389,26 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
if(dirct == op::rnn_direction::bidirectional) if(dirct == op::rnn_direction::bidirectional)
{ {
// w weight matrix // w weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]); auto w_forward = prog.insert_instruction(
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// r weight matrix // r weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]); auto r_forward = prog.insert_instruction(
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// bias // bias
instruction_ref bias_forward = prog.end(); instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end(); instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 && args[3]->name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
} }
// intial hidden state // intial hidden state
...@@ -390,8 +416,10 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const ...@@ -390,8 +416,10 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{}; instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined") if(args.size() == 6 && args[5]->name() != "undefined")
{ {
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); ih_forward = prog.insert_instruction(
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
} }
else else
{ {
...@@ -410,8 +438,8 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const ...@@ -410,8 +438,8 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
if(variable_seq_len) if(variable_seq_len)
{ {
args[0] = args[0] = prog.insert_instruction(
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens); ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
} }
auto ret_reverse = auto ret_reverse =
...@@ -423,23 +451,26 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const ...@@ -423,23 +451,26 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
actv_funcs.at(2), actv_funcs.at(2),
actv_funcs.at(3)); actv_funcs.at(3));
auto concat_output = auto concat_output = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output); last_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
// The following logic is to ensure the last instruction rewritten // The following logic is to ensure the last instruction rewritten
// from gru operator is a concat // from gru operator is a concat
if(ret_forward[0] == prog.end()) if(ret_forward[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); prog.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
} }
else else
{ {
ret_forward[0] = ret_forward[0] = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]); ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[0] = ret_reverse[0] = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]); ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]}); prog.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
} }
} }
else else
...@@ -469,8 +500,8 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const ...@@ -469,8 +500,8 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
if(!is_forward and variable_seq_len) if(!is_forward and variable_seq_len)
{ {
args[0] = args[0] = prog.insert_instruction(
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens); ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
} }
auto ret = gru_cell(is_forward, auto ret = gru_cell(is_forward,
...@@ -481,17 +512,18 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const ...@@ -481,17 +512,18 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(1)); actv_funcs.at(1));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
if(ret[0] == prog.end()) if(ret[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{0}, ret[1]); prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
} }
else else
{ {
auto concat_arg0 = is_forward ? ret[0] : ret[1]; auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0]; auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); prog.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
} }
} }
...@@ -529,19 +561,21 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -529,19 +561,21 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
// w matrix squeeze to 2-dim and do a transpose // w matrix squeeze to 2-dim and do a transpose
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw); auto tw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
// r slide to two part, zr and h // r slide to two part, zr and h
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr); auto rzr = prog.insert_instruction(
auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rzr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr); auto rh = prog.insert_instruction(
auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh); ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rh);
// initial states // initial states
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
size_t bs = ih->get_shape().lens()[1]; size_t bs = ih->get_shape().lens()[1];
// bias // bias
...@@ -550,77 +584,100 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -550,77 +584,100 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
instruction_ref brb_h{}; instruction_ref brb_h{};
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias); auto wb = prog.insert_instruction(
bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
bwb = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(3 * hs)}}}),
wb);
auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias); auto rb_zr = prog.insert_instruction(
auto rb_h = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias); ins,
make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
sbias);
auto rb_h = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
sbias);
brb_zr = prog.insert_instruction( brb_zr = prog.insert_instruction(
ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr); ins,
brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h); make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(2 * hs)}}}),
rb_zr);
brb_h = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(hs)}}}),
rb_h);
} }
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens)); long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
for(long i = 0; i < seq_len; i++) for(long i = 0; i < seq_len; i++)
{ {
long seq_index = is_forward ? i : (seq_len - 1 - i); long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq); auto xt = prog.insert_instruction(
auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt); ins,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt); make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw); auto xt_w = prog.insert_instruction(ins, make_op("dot"), xt, tw);
auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr); auto ih1_rzr = prog.insert_instruction(ins, make_op("dot"), sih, trzr);
if(bias != prog.end()) if(bias != prog.end())
{ {
xt_w = prog.insert_instruction(ins, op::add{}, xt_w, bwb); xt_w = prog.insert_instruction(ins, make_op("add"), xt_w, bwb);
ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr); ih1_rzr = prog.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
} }
auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w); auto xw_z = prog.insert_instruction(
auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w); ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w); auto xw_r = prog.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
auto xw_h = prog.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);
auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr); auto hr_z = prog.insert_instruction(
auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr); ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
auto hr_r = prog.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);
auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z); auto xw_hr_z = prog.insert_instruction(ins, make_op("add"), xw_z, hr_z);
auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z); auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z);
auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r); auto xw_hr_r = prog.insert_instruction(ins, make_op("add"), xw_r, hr_r);
auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r); auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r);
instruction_ref hr_h{}; instruction_ref hr_h{};
if(linear_before_reset == 0) if(linear_before_reset == 0)
{ {
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih); auto rt_ht1 = prog.insert_instruction(ins, make_op("mul"), rt, sih);
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh); hr_h = prog.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
if(bias != prog.end()) if(bias != prog.end())
{ {
hr_h = prog.insert_instruction(ins, op::add{}, hr_h, brb_h); hr_h = prog.insert_instruction(ins, make_op("add"), hr_h, brb_h);
} }
} }
else else
{ {
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh); auto ht1_rh = prog.insert_instruction(ins, make_op("dot"), sih, trh);
if(bias != prog.end()) if(bias != prog.end())
{ {
ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brb_h); ht1_rh = prog.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
} }
hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh); hr_h = prog.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
} }
auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h); auto xw_hr_h = prog.insert_instruction(ins, make_op("add"), xw_h, hr_h);
auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h); auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1 // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt); auto one_minus_zt = prog.insert_instruction(ins, make_op("sub"), l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht); auto one_minus_zt_ht = prog.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih); auto zt_ht1 = prog.insert_instruction(ins, make_op("mul"), zt, sih);
sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1); sih = prog.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih); last_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
if(i < seq_len - 1) if(i < seq_len - 1)
{ {
...@@ -629,14 +686,16 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -629,14 +686,16 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
hidden_states = hidden_states =
(seq_index == 0) (seq_index == 0)
? last_output ? last_output
: prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output); : prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
} }
else else
{ {
hidden_states = hidden_states =
(seq_index == seq_len - 1) (seq_index == seq_len - 1)
? last_output ? last_output
: prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states); : prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
} }
} }
} }
...@@ -654,7 +713,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const ...@@ -654,7 +713,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
if(gru_op.direction == op::rnn_direction::bidirectional) if(gru_op.direction == op::rnn_direction::bidirectional)
{ {
if(gru_op.actv_funcs.empty()) if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}}; return {make_op("sigmoid"), make_op("tanh"), make_op("sigmoid"), make_op("tanh")};
else if(gru_op.actv_funcs.size() == 1) else if(gru_op.actv_funcs.size() == 1)
return {gru_op.actv_funcs.at(0), return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0),
...@@ -676,7 +735,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const ...@@ -676,7 +735,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
else else
{ {
if(gru_op.actv_funcs.empty()) if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}}; return {make_op("sigmoid"), make_op("tanh")};
else if(gru_op.actv_funcs.size() == 1) else if(gru_op.actv_funcs.size() == 1)
return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)}; return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
else else
...@@ -720,20 +779,26 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const ...@@ -720,20 +779,26 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
{ {
// input weight matrix // input weight matrix
// input weight matrix // input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]); auto w_forward = prog.insert_instruction(
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// hidden state weight matrix // hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]); auto r_forward = prog.insert_instruction(
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// process bias // process bias
instruction_ref bias_forward = prog.end(); instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end(); instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined") if(args.size() >= 4 && args[3]->name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
} }
// process intial hidden state, it is the 6th argument // process intial hidden state, it is the 6th argument
...@@ -741,8 +806,10 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const ...@@ -741,8 +806,10 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{}; instruction_ref ih_reverse{};
if(args.size() >= 6 && args[5]->name() != "undefined") if(args.size() >= 6 && args[5]->name() != "undefined")
{ {
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); ih_forward = prog.insert_instruction(
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
} }
else else
{ {
...@@ -755,8 +822,10 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const ...@@ -755,8 +822,10 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref ic_reverse{}; instruction_ref ic_reverse{};
if(args.size() >= 7 && args[6]->name() != "undefined") if(args.size() >= 7 && args[6]->name() != "undefined")
{ {
ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]); ic_forward = prog.insert_instruction(
ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
ic_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
} }
else else
{ {
...@@ -769,8 +838,10 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const ...@@ -769,8 +838,10 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref pph_reverse = prog.end(); instruction_ref pph_reverse = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined") if(args.size() == 8 && args[7]->name() != "undefined")
{ {
pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]); pph_forward = prog.insert_instruction(
pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
pph_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
} }
auto ret_forward = lstm_cell(true, auto ret_forward = lstm_cell(true,
...@@ -790,8 +861,8 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const ...@@ -790,8 +861,8 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
if(variable_seq_len) if(variable_seq_len)
{ {
args[0] = args[0] = prog.insert_instruction(
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens); ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
} }
auto ret_reverse = lstm_cell(false, auto ret_reverse = lstm_cell(false,
prog, prog,
...@@ -808,12 +879,14 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const ...@@ -808,12 +879,14 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
actv_funcs.at(4), actv_funcs.at(4),
actv_funcs.at(5)); actv_funcs.at(5));
auto concat_hs_output = auto concat_hs_output = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
auto concat_cell_output = auto concat_cell_output = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{1}, ret_forward[3], ret_reverse[3]); ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
last_hs_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_hs_output); last_hs_output =
last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_cell_output); prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
last_cell_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);
// the following logic is to ensure the last instruction is a concat // the following logic is to ensure the last instruction is a concat
if(ret_forward[0] == prog.end()) if(ret_forward[0] == prog.end())
...@@ -822,21 +895,21 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const ...@@ -822,21 +895,21 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
} }
else else
{ {
ret_forward[1] = ret_forward[1] = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]); ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[1] = ret_reverse[1] = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]); ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
ret_forward[3] = ret_forward[3] = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_forward[3]); ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
ret_reverse[3] = ret_reverse[3] = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, ret_reverse[3], ret_reverse[2]); ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
cell_outputs = cell_outputs = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{1}, ret_forward[3], ret_reverse[3]); ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
} }
hidden_state = hidden_state = prog.replace_instruction(
prog.replace_instruction(ins, op::concat{1}, {ret_forward[1], ret_reverse[1]}); ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
} }
else else
{ {
...@@ -883,8 +956,8 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const ...@@ -883,8 +956,8 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
if(!is_forward and variable_seq_len) if(!is_forward and variable_seq_len)
{ {
args[0] = args[0] = prog.insert_instruction(
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens); ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
} }
auto ret = lstm_cell(is_forward, auto ret = lstm_cell(is_forward,
prog, prog,
...@@ -894,24 +967,26 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const ...@@ -894,24 +967,26 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(2)); actv_funcs.at(2));
last_hs_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); last_hs_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[3]); last_cell_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);
if(ret[0] == prog.end()) if(ret[0] == prog.end())
{ {
cell_outputs = ret[3]; cell_outputs = ret[3];
hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]); hidden_state = prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
} }
else else
{ {
auto concat_cell_arg0 = is_forward ? ret[2] : ret[3]; auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
auto concat_cell_arg1 = is_forward ? ret[3] : ret[2]; auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
cell_outputs = cell_outputs = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, concat_cell_arg0, concat_cell_arg1); ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
auto concat_arg0 = is_forward ? ret[0] : ret[1]; auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0]; auto concat_arg1 = is_forward ? ret[1] : ret[0];
hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); hidden_state = prog.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
} }
} }
...@@ -957,18 +1032,18 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -957,18 +1032,18 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose // w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw); auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
// r matrix, squeeze and transpose // r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr); auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
// initial hidden state // initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
// initial cell state // initial cell state
auto sic = prog.insert_instruction(ins, op::squeeze{{0}}, ic); auto sic = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
auto ic_lens = sic->get_shape().lens(); auto ic_lens = sic->get_shape().lens();
// bias // bias
...@@ -976,13 +1051,19 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -976,13 +1051,19 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias); auto ub_wb = prog.insert_instruction(
auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb); auto ub_rb = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
sbias);
auto ub_wrb = prog.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
wrb = prog.insert_instruction( wrb = prog.insert_instruction(
ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb); ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, 4 * static_cast<size_t>(hs)}}}),
ub_wrb);
} }
// peep hole // peep hole
...@@ -991,73 +1072,90 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -991,73 +1072,90 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref pphf_brcst{}; instruction_ref pphf_brcst{};
if(pph != prog.end()) if(pph != prog.end())
{ {
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); auto spph = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); auto pphi = prog.insert_instruction(
pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
pphi_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphi);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph); auto ppho = prog.insert_instruction(
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho); ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
ppho_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), ppho);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); auto pphf = prog.insert_instruction(
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf); ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
pphf_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphf);
} }
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens)); long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
{ {
long seq_index = is_forward ? i : (seq_len - 1 - i); long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq); auto xt = prog.insert_instruction(
auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt); ins,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt); make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw); auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr); xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
auto xt_tsw = prog.insert_instruction(ins, make_op("dot"), xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, make_op("dot"), sih, tsr);
auto xt_sih = prog.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
if(bias != prog.end()) if(bias != prog.end())
{ {
xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb); xt_sih = prog.insert_instruction(ins, make_op("add"), xt_sih, wrb);
} }
auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih); auto it_before_actv = prog.insert_instruction(
auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih); ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
auto ft_before_actv = auto ot_before_actv = prog.insert_instruction(
prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih); ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
auto ct_before_actv = auto ft_before_actv = prog.insert_instruction(
prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih); ins,
make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
xt_sih);
auto ct_before_actv = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
xt_sih);
if(pph != prog.end()) if(pph != prog.end())
{ {
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic); auto pphi_ct = prog.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct); it_before_actv = prog.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic); auto pphf_ct = prog.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct); ft_before_actv = prog.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
} }
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv); auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv); auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv); auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct // equation Ct = ft (.) Ct-1 + it (.) ct
auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic); auto ft_cell = prog.insert_instruction(ins, make_op("mul"), ft, sic);
auto it_ct = prog.insert_instruction(ins, op::mul{}, it, ct); auto it_ct = prog.insert_instruction(ins, make_op("mul"), it, ct);
auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct); auto cellt = prog.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
if(pph != prog.end()) if(pph != prog.end())
{ {
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt); auto ppho_cellt = prog.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt); ot_before_actv =
prog.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
} }
auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv); auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);
// Ht = ot (.) h(Ct) // Ht = ot (.) h(Ct)
auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt); auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
auto ht = prog.insert_instruction(ins, op::mul{}, ot, h_cellt); auto ht = prog.insert_instruction(ins, make_op("mul"), ot, h_cellt);
sic = cellt; sic = cellt;
sih = ht; sih = ht;
last_hs_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht); last_hs_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, cellt); last_cell_output =
prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
if(i < seq_len - 1) if(i < seq_len - 1)
{ {
...@@ -1070,13 +1168,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1070,13 +1168,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{ {
auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output; auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states; auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
hidden_states = hidden_states = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, concat_hs_arg0, concat_hs_arg1); ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);
auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output; auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs; auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
cell_outputs = cell_outputs = prog.insert_instruction(
prog.insert_instruction(ins, op::concat{0}, concat_cell_arg0, concat_cell_arg1); ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
} }
} }
} }
...@@ -1098,7 +1196,12 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1098,7 +1196,12 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
switch(num_actv_funcs) switch(num_actv_funcs)
{ {
case 0: case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}}; return {make_op("sigmoid"),
make_op("tanh"),
make_op("tanh"),
make_op("sigmoid"),
make_op("tanh"),
make_op("tanh")};
case 1: case 1:
return {actv_funcs.at(0), return {actv_funcs.at(0),
...@@ -1147,7 +1250,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1147,7 +1250,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{ {
switch(num_actv_funcs) switch(num_actv_funcs)
{ {
case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}}; case 0: return {make_op("sigmoid"), make_op("tanh"), make_op("tanh")};
case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)}; case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
...@@ -1215,7 +1318,11 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog, ...@@ -1215,7 +1318,11 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
if(variable_seq_len) if(variable_seq_len)
{ {
result_ins = prog.insert_instruction( result_ins = prog.insert_instruction(
std::next(ins), op::rnn_var_sl_shift_output{"hidden_states", dirct}, ins, seq_lens); std::next(ins),
make_op("rnn_var_sl_shift_output",
{{"output_name", "hidden_states"}, {"direction", dirct}}),
ins,
seq_lens);
prog.replace_instruction(ins, result_ins); prog.replace_instruction(ins, result_ins);
auto hs_outputs = find_all(result_ins->outputs(), auto hs_outputs = find_all(result_ins->outputs(),
[&](auto i) { return i->name() == "rnn_last_hs_output"; }); [&](auto i) { return i->name() == "rnn_last_hs_output"; });
...@@ -1223,8 +1330,10 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog, ...@@ -1223,8 +1330,10 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
for(auto& hs_out : hs_outputs) for(auto& hs_out : hs_outputs)
{ {
auto inputs = hs_out->inputs(); auto inputs = hs_out->inputs();
prog.replace_instruction( prog.replace_instruction(hs_out,
hs_out, op::rnn_var_sl_last_output{dirct}, inputs.front(), seq_lens); make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
inputs.front(),
seq_lens);
} }
} }
else else
...@@ -1258,16 +1367,20 @@ void rewrite_rnn::replace_last_cell_output(module& prog, ...@@ -1258,16 +1367,20 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
{ {
if(!ins_outputs.empty()) if(!ins_outputs.empty())
{ {
cell_outputs = cell_outputs = prog.insert_instruction(
prog.insert_instruction(std::next(ins), std::next(ins),
op::rnn_var_sl_shift_output{"cell_outputs", dirct}, make_op("rnn_var_sl_shift_output",
{{"output_name", "cell_outputs"}, {"direction", dirct}}),
cell_outputs, cell_outputs,
seq_lens); seq_lens);
} }
for(auto co : ins_outputs) for(auto co : ins_outputs)
{ {
prog.replace_instruction(co, op::rnn_var_sl_last_output{dirct}, cell_outputs, seq_lens); prog.replace_instruction(co,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
cell_outputs,
seq_lens);
} }
} }
// replace the rnn_last_cell_output with the last_cell_output. The while // replace the rnn_last_cell_output with the last_cell_output. The while
...@@ -1300,7 +1413,8 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog, ...@@ -1300,7 +1413,8 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
shape pad_s{s.type(), pad_lens}; shape pad_s{s.type(), pad_lens};
std::vector<float> pad_data(pad_s.elements(), 0.0f); std::vector<float> pad_data(pad_s.elements(), 0.0f);
auto pl = prog.add_literal(pad_s, pad_data.begin(), pad_data.end()); auto pl = prog.add_literal(pad_s, pad_data.begin(), pad_data.end());
hs_padded = prog.insert_instruction(std::next(hs), op::concat{0}, hs, pl); hs_padded =
prog.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
prog.replace_instruction(hs, hs_padded); prog.replace_instruction(hs, hs_padded);
} }
......
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
...@@ -12,6 +11,8 @@ ...@@ -12,6 +11,8 @@
#include <queue> #include <queue>
#include <thread> #include <thread>
#include <mutex> #include <mutex>
#include <migraphx/make_op.hpp>
#include <set> #include <set>
#include <deque> #include <deque>
#include <chrono> #include <chrono>
...@@ -556,7 +557,7 @@ void schedule::apply(module& p) const ...@@ -556,7 +557,7 @@ void schedule::apply(module& p) const
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
args.push_back(ip.first); args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end()); args.insert(args.end(), ip.second.begin(), ip.second.end());
p.insert_instruction(std::next(ip.first), op::identity{}, args); p.insert_instruction(std::next(ip.first), make_op("identity"), args);
} }
} }
......
...@@ -16,6 +16,10 @@ ...@@ -16,6 +16,10 @@
#include <migraphx/op/transpose.hpp> #include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/algorithm.hpp> #include <migraphx/algorithm.hpp>
namespace migraphx { namespace migraphx {
...@@ -62,8 +66,10 @@ struct find_mul_conv ...@@ -62,8 +66,10 @@ struct find_mul_conv
return; return;
auto new_a = p.insert_instruction( auto new_a = p.insert_instruction(
ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front()); ins,
auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, w_ins); make_op("broadcast", {{"axis", 0}, {"dims", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction( auto new_conv = p.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul); ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
p.replace_instruction(ins, new_conv); p.replace_instruction(ins, new_conv);
...@@ -125,20 +131,27 @@ struct find_mul_slice_conv ...@@ -125,20 +131,27 @@ struct find_mul_slice_conv
auto slice_w_ins = p.insert_instruction(ins, w_slice_op, w_ins); auto slice_w_ins = p.insert_instruction(ins, w_slice_op, w_ins);
auto new_a = p.insert_instruction( auto new_a = p.insert_instruction(
ins, op::broadcast{0, slice_w_ins->get_shape().lens()}, a_ins->inputs().front()); ins,
auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, slice_w_ins); make_op("broadcast", {{"axis", 0}, {"dims", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
std::vector<instruction_ref> sliced_weights; std::vector<instruction_ref> sliced_weights;
if(slice_op.starts.front() != 0) if(slice_op.starts.front() != 0)
sliced_weights.push_back( sliced_weights.push_back(p.insert_instruction(
p.insert_instruction(ins, op::slice{{0}, {0}, slice_op.starts}, w_ins)); ins,
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", slice_op.starts}}),
w_ins));
sliced_weights.push_back(new_mul); sliced_weights.push_back(new_mul);
int64_t end_axis = w_ins->get_shape().lens().at(0); int64_t end_axis = w_ins->get_shape().lens().at(0);
if(slice_op.ends.front() != end_axis) if(slice_op.ends.front() != end_axis)
sliced_weights.push_back( sliced_weights.push_back(p.insert_instruction(
p.insert_instruction(ins, op::slice{{0}, {slice_op.ends}, {end_axis}}, w_ins)); ins,
make_op("slice", {{"axes", {0}}, {"starts", slice_op.ends}, {"ends", {end_axis}}}),
w_ins));
auto new_weights = p.insert_instruction(ins, op::concat{0}, sliced_weights); auto new_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights);
auto new_conv = p.insert_instruction( auto new_conv = p.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights); ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);
...@@ -177,9 +190,9 @@ struct find_mul_add ...@@ -177,9 +190,9 @@ struct find_mul_add
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
assert(x_ins != b_ins); assert(x_ins != b_ins);
auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins); auto ax_ins = p.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins); auto ab_ins = p.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
p.replace_instruction(ins, op::add{}, ax_ins, ab_ins); p.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
} }
}; };
...@@ -198,8 +211,8 @@ struct find_add_lit_broadcast ...@@ -198,8 +211,8 @@ struct find_add_lit_broadcast
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"]; auto b_ins = r.instructions["b"];
auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins); auto sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins);
p.replace_instruction(ins, op::add{}, x_ins, sumab); p.replace_instruction(ins, make_op("add"), x_ins, sumab);
} }
}; };
...@@ -226,17 +239,17 @@ struct find_double_add_lit_broadcast ...@@ -226,17 +239,17 @@ struct find_double_add_lit_broadcast
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape()) if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return; return;
auto op = a_ins->get_operator(); auto op = a_ins->get_operator();
auto presum = auto presum = p.insert_instruction(
p.insert_instruction(ins, op::add{}, a_ins->inputs().at(0), b_ins->inputs().at(0)); ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = p.insert_instruction(ins, op, presum); sumab = p.insert_instruction(ins, op, presum);
} }
else else
{ {
sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins); sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins);
} }
auto sumxy = p.insert_instruction(ins, op::add{}, x_ins, y_ins); auto sumxy = p.insert_instruction(ins, make_op("add"), x_ins, y_ins);
p.replace_instruction(ins, op::add{}, sumxy, sumab); p.replace_instruction(ins, make_op("add"), sumxy, sumab);
} }
}; };
...@@ -327,7 +340,8 @@ struct find_concat_op ...@@ -327,7 +340,8 @@ struct find_concat_op
std::transform(start, last, std::back_inserter(inputs), [&](auto j) { std::transform(start, last, std::back_inserter(inputs), [&](auto j) {
return j->inputs().at(i); return j->inputs().at(i);
}); });
auto concat = p.insert_instruction(ins, op::concat{iaxis}, inputs); auto concat =
p.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs);
concats.push_back(concat); concats.push_back(concat);
} }
auto y = p.insert_instruction(ins, op, concats); auto y = p.insert_instruction(ins, op, concats);
...@@ -349,7 +363,7 @@ struct find_concat_op ...@@ -349,7 +363,7 @@ struct find_concat_op
if(args.size() == 1) if(args.size() == 1)
p.replace_instruction(ins, args.front()); p.replace_instruction(ins, args.front());
else else
p.replace_instruction(ins, op::concat{axis}, args); p.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args);
} }
}; };
...@@ -482,7 +496,8 @@ struct find_splits ...@@ -482,7 +496,8 @@ struct find_splits
return; return;
auto concat_axis = slice_op.axes.front(); auto concat_axis = slice_op.axes.front();
// TODO: Check if axises match // TODO: Check if axises match
auto concat = p.insert_instruction(ins, op::concat{concat_axis}, data_args); auto concat = p.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args);
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
args.resize(2); args.resize(2);
...@@ -501,7 +516,8 @@ struct find_splits ...@@ -501,7 +516,8 @@ struct find_splits
{ {
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name())) if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
continue; continue;
auto x = p.insert_instruction(output, op::contiguous{}, output->inputs()); auto x =
p.insert_instruction(output, make_op("contiguous"), output->inputs());
p.replace_instruction(output, output->get_operator(), x); p.replace_instruction(output, output->get_operator(), x);
} }
...@@ -648,7 +664,11 @@ struct find_add_convs ...@@ -648,7 +664,11 @@ struct find_add_convs
return; return;
new_op = a_op; new_op = a_op;
b_input = p.insert_instruction( b_input = p.insert_instruction(
ins, op::as_shape{compute_stride_shape(b_input->get_shape(), n)}, b_input); ins,
make_op(
"as_shape",
{{"shape", to_value(compute_stride_shape(b_input->get_shape(), n))}}),
b_input);
} }
else if(b_op.stride < a_op.stride) else if(b_op.stride < a_op.stride)
{ {
...@@ -657,7 +677,11 @@ struct find_add_convs ...@@ -657,7 +677,11 @@ struct find_add_convs
return; return;
new_op = b_op; new_op = b_op;
a_input = p.insert_instruction( a_input = p.insert_instruction(
ins, op::as_shape{compute_stride_shape(a_input->get_shape(), n)}, a_input); ins,
make_op(
"as_shape",
{{"shape", to_value(compute_stride_shape(a_input->get_shape(), n))}}),
a_input);
} }
else else
return; return;
...@@ -666,8 +690,10 @@ struct find_add_convs ...@@ -666,8 +690,10 @@ struct find_add_convs
return; return;
} }
auto concat_input = p.insert_instruction(ins, op::concat{1}, a_input, b_input); auto concat_input =
auto concat_weights = p.insert_instruction(ins, op::concat{1}, a_weights, b_weights); p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input);
auto concat_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights);
p.replace_instruction(ins, new_op, concat_input, concat_weights); p.replace_instruction(ins, new_op, concat_input, concat_weights);
} }
}; };
...@@ -739,13 +765,18 @@ struct find_conv_dot_horiz_fusion ...@@ -739,13 +765,18 @@ struct find_conv_dot_horiz_fusion
for(auto arg : args) for(auto arg : args)
p.move_instructions(arg, input); p.move_instructions(arg, input);
// TODO: Check if axises match // TODO: Check if axises match
auto concat = p.insert_instruction(input, op::concat{concat_axis}, args); auto concat =
p.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = p.insert_instruction(std::next(input), op, input, concat); auto fused = p.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0; int64_t offset = 0;
for(auto arg : range(start, last)) for(auto arg : range(start, last))
{ {
int64_t len = arg->get_shape().lens()[axis]; int64_t len = arg->get_shape().lens()[axis];
p.replace_instruction(arg, op::slice{{axis}, {offset}, {offset + len}}, fused); p.replace_instruction(
arg,
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
fused);
offset += len; offset += len;
} }
}; };
...@@ -767,11 +798,11 @@ struct find_div_const ...@@ -767,11 +798,11 @@ struct find_div_const
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
auto recip = p.insert_instruction(std::next(c_ins), op::recip{}, c_ins); auto recip = p.insert_instruction(std::next(c_ins), make_op("recip"), c_ins);
auto args = ins->inputs(); auto args = ins->inputs();
p.replace_instruction(ins, op::mul{}, args.front(), recip); p.replace_instruction(ins, make_op("mul"), args.front(), recip);
} }
}; };
...@@ -787,11 +818,11 @@ struct find_sub_const ...@@ -787,11 +818,11 @@ struct find_sub_const
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
auto neg = p.insert_instruction(std::next(c_ins), op::neg{}, c_ins); auto neg = p.insert_instruction(std::next(c_ins), make_op("neg"), c_ins);
auto args = ins->inputs(); auto args = ins->inputs();
p.replace_instruction(ins, op::add{}, args.front(), neg); p.replace_instruction(ins, make_op("add"), args.front(), neg);
} }
}; };
...@@ -808,7 +839,7 @@ struct find_rsqrt ...@@ -808,7 +839,7 @@ struct find_rsqrt
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
p.replace_instruction(ins, op::rsqrt{}, x_ins); p.replace_instruction(ins, make_op("rsqrt"), x_ins);
} }
}; };
...@@ -883,14 +914,19 @@ struct find_split_reshape ...@@ -883,14 +914,19 @@ struct find_split_reshape
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction // insert the reshape instruction
auto rsp_ins = p.insert_instruction(std::next(input), op::reshape{rsp_out_lens}, input); auto rsp_ins = p.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace the original reshape with slice // replace the original reshape with slice
int64_t start = 0; int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i) for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{ {
p.replace_instruction( p.replace_instruction(
vec_rsp[i], op::slice{{rsp_axis}, {start}, {start + vec_dims[i]}}, rsp_ins); vec_rsp[i],
make_op(
"slice",
{{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}),
rsp_ins);
start += vec_dims[i]; start += vec_dims[i];
} }
} }
...@@ -930,7 +966,8 @@ struct find_split_transpose ...@@ -930,7 +966,8 @@ struct find_split_transpose
} }
// insert an transpose instruction // insert an transpose instruction
auto tr = p.insert_instruction(std::next(input), op::transpose{perm}, input); auto tr =
p.insert_instruction(std::next(input), make_op("transpose", {{"dims", perm}}), input);
// compute the axis in the slice // compute the axis in the slice
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front(); auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
...@@ -944,7 +981,10 @@ struct find_split_transpose ...@@ -944,7 +981,10 @@ struct find_split_transpose
auto starts = oper.starts; auto starts = oper.starts;
auto ends = oper.ends; auto ends = oper.ends;
auto tr_orig = in->outputs().front(); auto tr_orig = in->outputs().front();
p.replace_instruction(tr_orig, op::slice{{axis_new}, starts, ends}, tr); p.replace_instruction(
tr_orig,
make_op("slice", {{"axes", {axis_new}}, {"starts", starts}, {"ends", ends}}),
tr);
} }
} }
}; };
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <unordered_set> #include <unordered_set>
#include <migraphx/make_op.hpp>
#include <map> #include <map>
namespace migraphx { namespace migraphx {
...@@ -149,7 +151,7 @@ struct find_transpose ...@@ -149,7 +151,7 @@ struct find_transpose
} }
else else
{ {
p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front()); p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front());
} }
} }
}; };
...@@ -257,10 +259,10 @@ struct find_concat_transpose ...@@ -257,10 +259,10 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, op::transpose{permutation}, i); return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i);
}); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutation}, concat); auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens()); assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t); p.replace_instruction(ins, t);
} }
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pad_calc.hpp> #include <migraphx/pad_calc.hpp>
namespace migraphx { namespace migraphx {
...@@ -49,20 +51,20 @@ struct tf_parser ...@@ -49,20 +51,20 @@ struct tf_parser
instruction_ref to_nhwc(instruction_ref ins) const instruction_ref to_nhwc(instruction_ref ins) const
{ {
if(should_transpose(ins)) if(should_transpose(ins))
return mm->add_instruction(op::transpose{{0, 2, 3, 1}}, ins); return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins);
return ins; return ins;
} }
instruction_ref to_nchw(instruction_ref ins) const instruction_ref to_nchw(instruction_ref ins) const
{ {
if(should_transpose(ins)) if(should_transpose(ins))
return mm->add_instruction(op::transpose{{0, 3, 1, 2}}, ins); return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins);
return ins; return ins;
} }
instruction_ref to_kcxy(instruction_ref ins) const instruction_ref to_kcxy(instruction_ref ins) const
{ {
return mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins); return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins);
} }
instruction_ref make_contiguous(instruction_ref ins) const instruction_ref make_contiguous(instruction_ref ins) const
...@@ -70,7 +72,7 @@ struct tf_parser ...@@ -70,7 +72,7 @@ struct tf_parser
if(ins->get_shape().standard()) if(ins->get_shape().standard())
return ins; return ins;
else else
return mm->add_instruction(op::contiguous{}, ins); return mm->add_instruction(make_op("contiguous"), ins);
} }
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args) std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
...@@ -176,20 +178,20 @@ struct tf_parser ...@@ -176,20 +178,20 @@ struct tf_parser
tf_parser() tf_parser()
{ {
add_generic_op("All", op::identity{}); add_generic_op("All", make_op("identity"));
add_generic_op("Identity", op::identity{}); add_generic_op("Identity", make_op("identity"));
add_generic_op("LessEqual", op::identity{}); add_generic_op("LessEqual", make_op("identity"));
add_generic_op("Relu", op::relu{}); add_generic_op("Relu", make_op("relu"));
add_generic_op("Rsqrt", op::rsqrt{}); add_generic_op("Rsqrt", make_op("rsqrt"));
add_generic_op("Tanh", op::tanh{}); add_generic_op("Tanh", make_op("tanh"));
add_generic_op("StopGradient", op::identity{}); add_generic_op("StopGradient", make_op("identity"));
add_binary_op("Add", op::add{}); add_binary_op("Add", make_op("add"));
add_binary_op("AddV2", op::add{}); add_binary_op("AddV2", make_op("add"));
add_binary_op("Mul", op::mul{}); add_binary_op("Mul", make_op("mul"));
add_binary_op("Pow", op::pow{}); add_binary_op("Pow", make_op("pow"));
add_binary_op("SquaredDifference", op::sqdiff{}); add_binary_op("SquaredDifference", make_op("sqdiff"));
add_binary_op("Sub", op::sub{}); add_binary_op("Sub", make_op("sub"));
add_mem_op("ArgMax", &tf_parser::parse_arg_op<op::argmax>, false); add_mem_op("ArgMax", &tf_parser::parse_arg_op<op::argmax>, false);
add_mem_op("ArgMin", &tf_parser::parse_arg_op<op::argmin>, false); add_mem_op("ArgMin", &tf_parser::parse_arg_op<op::argmin>, false);
...@@ -310,8 +312,10 @@ struct tf_parser ...@@ -310,8 +312,10 @@ struct tf_parser
output_lens.begin() + offset, output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); }); [](auto a, auto b) { return std::max(a, b); });
auto l0 = mm->add_instruction(op::multibroadcast{output_lens}, arg0); auto l0 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}),
auto l1 = mm->add_instruction(op::multibroadcast{output_lens}, arg1); arg0);
auto l1 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}),
arg1);
return to_nhwc(mm->add_instruction(x, to_nchw(l0), to_nchw(l1))); return to_nhwc(mm->add_instruction(x, to_nchw(l0), to_nchw(l1)));
} }
else else
...@@ -337,7 +341,7 @@ struct tf_parser ...@@ -337,7 +341,7 @@ struct tf_parser
int64_t axis = 0; int64_t axis = 0;
axis = args[1]->eval().at<int64_t>(); axis = args[1]->eval().at<int64_t>();
auto ins = mm->add_instruction(Op{axis}, args.front()); auto ins = mm->add_instruction(Op{axis}, args.front());
return mm->add_instruction(op::squeeze{{axis}}, ins); return mm->add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
} }
instruction_ref parse_batchnorm(const std::string&, instruction_ref parse_batchnorm(const std::string&,
...@@ -359,8 +363,9 @@ struct tf_parser ...@@ -359,8 +363,9 @@ struct tf_parser
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{ {
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel) uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]); auto l0 = mm->add_instruction(
return mm->add_instruction(op::add{}, args[0], l0); make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]);
return mm->add_instruction(make_op("add"), args[0], l0);
} }
instruction_ref parse_cast(const std::string&, instruction_ref parse_cast(const std::string&,
...@@ -368,7 +373,7 @@ struct tf_parser ...@@ -368,7 +373,7 @@ struct tf_parser
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
shape::type_t type = parse_type(attributes.at("DstT").type()); shape::type_t type = parse_type(attributes.at("DstT").type());
return mm->add_instruction(op::convert{type}, std::move(args)); return mm->add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
} }
instruction_ref parse_concat(const std::string&, instruction_ref parse_concat(const std::string&,
...@@ -442,7 +447,7 @@ struct tf_parser ...@@ -442,7 +447,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3]) if(pads[0] != pads[2] || pads[1] != pads[3])
{ {
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(migraphx::op::pad{padding}, l0); l0 = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
} }
else else
{ {
...@@ -528,7 +533,7 @@ struct tf_parser ...@@ -528,7 +533,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3]) if(pads[0] != pads[2] || pads[1] != pads[3])
{ {
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(migraphx::op::pad{padding}, l0); l0 = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
} }
else else
{ {
...@@ -553,8 +558,8 @@ struct tf_parser ...@@ -553,8 +558,8 @@ struct tf_parser
new_weights_shape[0] = out_channels; new_weights_shape[0] = out_channels;
new_weights_shape[1] = 1; new_weights_shape[1] = 1;
// Make sure weights are contiguous before doing reshape // Make sure weights are contiguous before doing reshape
auto new_weights = auto new_weights = mm->add_instruction(make_op("reshape", {{"dims", new_weights_shape}}),
mm->add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights)); make_contiguous(weights));
return mm->add_instruction(op, {l0, new_weights}); return mm->add_instruction(op, {l0, new_weights});
} }
...@@ -576,7 +581,7 @@ struct tf_parser ...@@ -576,7 +581,7 @@ struct tf_parser
{ {
new_dims.insert(new_dims.begin() + dim, 1); new_dims.insert(new_dims.begin() + dim, 1);
} }
return mm->add_instruction(op::reshape{new_dims}, args[0]); return mm->add_instruction(make_op("reshape", {{"dims", new_dims}}), args[0]);
} }
instruction_ref instruction_ref
...@@ -617,10 +622,12 @@ struct tf_parser ...@@ -617,10 +622,12 @@ struct tf_parser
// swap the last two elements // swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2); std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0]; auto l1 = (transa) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1]; : args[0];
auto l2 = (transb) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
return mm->add_instruction(op::dot{}, l1, l2); return mm->add_instruction(make_op("dot"), l1, l2);
} }
instruction_ref parse_mean(const std::string&, instruction_ref parse_mean(const std::string&,
...@@ -632,12 +639,12 @@ struct tf_parser ...@@ -632,12 +639,12 @@ struct tf_parser
if(keep_dims) if(keep_dims)
{ {
return mm->add_instruction(op::reduce_mean{axes}, args[0]); return mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
} }
else else
{ {
auto ins = mm->add_instruction(op::reduce_mean{axes}, args[0]); auto ins = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
return mm->add_instruction(op::squeeze{axes}, ins); return mm->add_instruction(make_op("squeeze", {{"axes", axes}}), ins);
} }
} }
...@@ -663,7 +670,7 @@ struct tf_parser ...@@ -663,7 +670,7 @@ struct tf_parser
{ {
shape s{shape::float_type, {depth, depth}}; shape s{shape::float_type, {depth, depth}};
auto l0 = mm->add_literal({s, depth_input}); auto l0 = mm->add_literal({s, depth_input});
return mm->add_instruction(op::gather{0}, {l0, args[0]}); return mm->add_instruction(make_op("gather", {{"axis", 0}}), {l0, args[0]});
} }
MIGRAPHX_THROW("MIGraphX does not support axis != -1"); MIGRAPHX_THROW("MIGraphX does not support axis != -1");
} }
...@@ -688,8 +695,10 @@ struct tf_parser ...@@ -688,8 +695,10 @@ struct tf_parser
args.begin(), args.begin(),
args.end(), args.end(),
std::back_inserter(unsqueezed_args), std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return mm->add_instruction(op::unsqueeze{{axis}}, arg); }); [&](instruction_ref arg) {
return to_nhwc(mm->add_instruction(op::concat{axis}, unsqueezed_args)); return mm->add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
return to_nhwc(mm->add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args));
} }
instruction_ref instruction_ref
...@@ -765,7 +774,10 @@ struct tf_parser ...@@ -765,7 +774,10 @@ struct tf_parser
{ {
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]}; std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction( l0 = mm->add_instruction(
migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0); migraphx::make_op(
"pad",
{{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
l0);
} }
else else
{ {
...@@ -784,9 +796,11 @@ struct tf_parser ...@@ -784,9 +796,11 @@ struct tf_parser
auto min_val = mm->add_literal(0.0f); auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f); auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(op::multibroadcast{input_lens}, min_val); min_val =
max_val = mm->add_instruction(op::multibroadcast{input_lens}, max_val); mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
return mm->add_instruction(op::clip{}, args.front(), min_val, max_val); max_val =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
return mm->add_instruction(make_op("clip"), args.front(), min_val, max_val);
} }
instruction_ref instruction_ref
...@@ -884,7 +898,8 @@ struct tf_parser ...@@ -884,7 +898,8 @@ struct tf_parser
assert(num_outputs > 0); assert(num_outputs > 0);
if(num_outputs == 1) if(num_outputs == 1)
return std::vector<instruction_ref>{mm->add_instruction(op::identity{}, input_arg)}; return std::vector<instruction_ref>{
mm->add_instruction(make_op("identity"), input_arg)};
auto lens = input_arg->get_shape().lens(); auto lens = input_arg->get_shape().lens();
auto num_dims = lens.size(); auto num_dims = lens.size();
...@@ -1012,7 +1027,7 @@ struct tf_parser ...@@ -1012,7 +1027,7 @@ struct tf_parser
squeeze_axes.push_back(i); squeeze_axes.push_back(i);
} }
return mm->add_instruction(op::squeeze{squeeze_axes}, l1); return mm->add_instruction(make_op("squeeze", {{"axes", squeeze_axes}}), l1);
} }
instruction_ref parse_transpose(const std::string&, instruction_ref parse_transpose(const std::string&,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment