Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
......@@ -18,7 +18,7 @@ CheckOptions:
- key: readability-function-size.BranchThreshold
value: '15'
- key: readability-function-size.LineThreshold
value: '300'
value: '350'
- key: readability-function-size.NestingThreshold
value: '5'
- key: readability-function-size.ParameterThreshold
......
......@@ -26,6 +26,7 @@ add_library(migraphx
load_save.cpp
make_op.cpp
msgpack.cpp
operation.cpp
program.cpp
quantization.cpp
reduce_dims.cpp
......
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
......@@ -14,7 +15,7 @@ void auto_contiguous::apply(module& p) const
shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0)
{
auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins);
auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins);
p.replace_instruction(ins, c);
}
}
......
......@@ -9,6 +9,8 @@
#include <migraphx/op/dot.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/add.hpp>
namespace migraphx {
......@@ -26,17 +28,19 @@ struct find_dot_add
not contains({shape::float_type, shape::half_type, shape::double_type},
ins->get_shape().type()))
return;
auto dot_ins =
p.insert_instruction(ins, op::dot{dot.alpha, 0}, ins->inputs()[0], ins->inputs()[1]);
auto dot_ins = p.insert_instruction(ins,
make_op("dot", {{"alpha", dot.alpha}, {"beta", 0}}),
ins->inputs()[0],
ins->inputs()[1]);
auto c_ins = ins->inputs()[2];
if(not float_equal(dot.beta, 1))
{
auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast =
p.insert_instruction(ins, op::multibroadcast{ins->get_shape().lens()}, beta);
c_ins = p.insert_instruction(ins, op::mul{}, c_ins, beta_broadcast);
auto beta_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta);
c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
}
p.replace_instruction(ins, op::add{}, dot_ins, c_ins);
p.replace_instruction(ins, make_op("add"), dot_ins, c_ins);
}
};
......
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_config.hpp>
namespace migraphx {
......@@ -33,7 +36,8 @@ void eliminate_allocation::apply(module& p) const
auto ins = pp.first;
auto s = ins->get_shape();
auto offset = pp.second;
p.replace_instruction(ins, op::load{s, offset}, mem);
p.replace_instruction(
ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
}
}
}
......
......@@ -6,6 +6,8 @@
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
......@@ -77,7 +79,7 @@ void eliminate_concat::apply(module& p) const
}
std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraphx::op::identity{}, args);
p.replace_instruction(ins, migraphx::make_op("identity"), args);
}
}
}
......
......@@ -33,6 +33,7 @@ struct lstm
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.input_forget, "input_forget"));
}
......
......@@ -846,6 +846,9 @@ bool has_finalize(const T& x)
return detail::has_finalize_op(x);
}
void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op);
#endif
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -45,6 +45,10 @@
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/unsqueeze.hpp>
namespace migraphx {
......@@ -267,7 +271,9 @@ struct onnx_parser
if(broadcasted != 0)
{
uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
auto l = mm->add_instruction(
make_op("broadcast",
{{"axis", axis}, {"dims", args[0]->get_shape().lens()}}),
args[1]);
return mm->add_instruction(make_op(op_name), args[0], l);
}
......@@ -341,11 +347,13 @@ struct onnx_parser
auto l0 = arg0;
if(arg0->get_shape().lens() != out_lens)
l0 = mm->add_instruction(op::multibroadcast{out_lens}, arg0);
l0 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
arg0);
auto l1 = arg1;
if(arg1->get_shape().lens() != out_lens)
l1 = mm->add_instruction(op::multibroadcast{out_lens}, arg1);
l1 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
arg1);
return mm->add_instruction(make_op(name), l0, l1);
}
......@@ -398,8 +406,9 @@ struct onnx_parser
{
if(args.size() == 3)
{
auto bias_bcast =
mm->add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
auto bias_bcast = mm->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", curr_ins->get_shape().lens()}}),
args[2]);
return mm->add_instruction(make_op("add"), curr_ins, bias_bcast);
}
return curr_ins;
......@@ -437,7 +446,8 @@ struct onnx_parser
asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it);
// add right pads
asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
ins = mm->add_instruction(op::pad{asym_pads, pad_val}, ins);
ins =
mm->add_instruction(make_op("pad", {{"pads", asym_pads}, {"value", pad_val}}), ins);
}
else
{
......@@ -479,12 +489,14 @@ struct onnx_parser
if(min_used)
{
min_arg = mm->add_instruction(op::multibroadcast{input_lens}, min_arg);
min_arg = mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = mm->add_instruction(op::multibroadcast{input_lens}, max_arg);
max_arg = mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
max_arg);
}
if(min_used and max_used)
......@@ -525,7 +537,7 @@ struct onnx_parser
if(keep_dims == 0)
{
auto ins = mm->add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
return mm->add_instruction(op::squeeze{{axis}}, ins);
return mm->add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
}
else
{
......@@ -593,7 +605,8 @@ struct onnx_parser
{
*starts_it = idx;
*ends_it = *starts_it + 1;
slices.push_back(mm->add_instruction(op::slice{axes, starts, ends}, input));
slices.push_back(mm->add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
}
// when padding on the left side, the outermost pad should be at the beginning
std::reverse(slices.begin(), slices.end());
......@@ -602,9 +615,10 @@ struct onnx_parser
{
*starts_it = *dims_it - idx - 1;
*ends_it = *starts_it + 1;
slices.push_back(mm->add_instruction(op::slice{axes, starts, ends}, input));
slices.push_back(mm->add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
}
input = mm->add_instruction(op::concat{axis}, slices);
input = mm->add_instruction(make_op("concat", {{"axis", axis}}), slices);
}
return input;
}
......@@ -841,7 +855,8 @@ struct onnx_parser
std::back_inserter(ends),
[](auto curr_dim, auto pad_dim) { return curr_dim - pad_dim; });
l1 = mm->add_instruction(op::slice{axes, starts, ends}, l1);
l1 = mm->add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), l1);
}
if(contains(info.attributes, "output_padding"))
......@@ -852,7 +867,7 @@ struct onnx_parser
check_attr_sizes(kdims,
output_padding.size() - non_kdims,
"PARSE_CONV_TRANSPOSE: inconsistent output padding");
l1 = mm->add_instruction(op::pad{output_padding}, l1);
l1 = mm->add_instruction(make_op("pad", {{"pads", output_padding}}), l1);
}
if(contains(info.attributes, "output_shape"))
......@@ -871,7 +886,7 @@ struct onnx_parser
curr_shape.begin(),
std::back_inserter(target_padding),
[](auto out_dim, auto curr_dim) { return out_dim - curr_dim; });
l1 = mm->add_instruction(op::pad{target_padding}, l1);
l1 = mm->add_instruction(make_op("pad", {{"pads", target_padding}}), l1);
}
}
......@@ -1049,7 +1064,9 @@ struct onnx_parser
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
l1 = mm->add_instruction(op::slice{axes, slice_start, slice_end}, l1);
l1 = mm->add_instruction(
make_op("slice", {{"axes", axes}, {"starts", slice_start}, {"ends", slice_end}}),
l1);
}
return l1;
......@@ -1283,7 +1300,7 @@ struct onnx_parser
int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
// reshape the input data as one dimension and used as input data
// to the gather operator
arg_data = mm->add_instruction(op::reshape{{data_elem_num}}, arg_data);
arg_data = mm->add_instruction(make_op("reshape", {{"dims", {data_elem_num}}}), arg_data);
std::size_t elem_num = ind_s.elements();
std::vector<int> ind_index(elem_num);
......@@ -1304,7 +1321,8 @@ struct onnx_parser
mm->add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
auto l_dim_idx = mm->add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
auto l_stride = mm->add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
l_stride = mm->add_instruction(op::multibroadcast{ind_s.lens()}, l_stride);
l_stride = mm->add_instruction(make_op("multibroadcast", {{"output_lens", ind_s.lens()}}),
l_stride);
auto dim_diff = mm->add_instruction(make_op("sub"), arg_ind, l_dim_idx);
auto delta = mm->add_instruction(make_op("mul"), dim_diff, l_stride);
auto ind = mm->add_instruction(make_op("add"), l_shape_idx, delta);
......@@ -1428,8 +1446,10 @@ struct onnx_parser
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1];
auto l1 = (transa) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
: args[0];
auto l2 = (transb) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
if(args.size() == 3)
{
if(beta != 0.f && args[2]->get_shape().elements() > 0)
......@@ -1440,7 +1460,8 @@ struct onnx_parser
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = mm->add_instruction(op::multibroadcast{out_lens}, args[2]);
l3 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
args[2]);
}
return mm->add_instruction(
make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
......@@ -1466,7 +1487,7 @@ struct onnx_parser
{
is_a_prepended = true;
l0_lens.insert(l0_lens.begin(), 1);
l0 = mm->add_instruction(op::unsqueeze{{0}}, args[0]);
l0 = mm->add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
}
bool is_b_appended = false;
......@@ -1474,7 +1495,7 @@ struct onnx_parser
{
is_b_appended = true;
l1_lens.push_back(1);
l1 = mm->add_instruction(op::unsqueeze{{1}}, args[1]);
l1 = mm->add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
}
instruction_ref bl0 = l0;
......@@ -1492,11 +1513,13 @@ struct onnx_parser
l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
if(l0_lens != l0_broadcasted_lens)
{
bl0 = mm->add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
bl0 = mm->add_instruction(
make_op("multibroadcast", {{"output_lens", l0_broadcasted_lens}}), l0);
}
if(l1_lens != l1_broadcasted_lens)
{
bl1 = mm->add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
bl1 = mm->add_instruction(
make_op("multibroadcast", {{"output_lens", l1_broadcasted_lens}}), l1);
}
}
......@@ -1504,12 +1527,12 @@ struct onnx_parser
int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
if(is_a_prepended)
{
dot_res = mm->add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
dot_res = mm->add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
--num_axis;
}
if(is_b_appended)
{
dot_res = mm->add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
dot_res = mm->add_instruction(make_op("squeeze", {{"axes", {num_axis - 1}}}), dot_res);
}
return dot_res;
......@@ -1563,19 +1586,24 @@ struct onnx_parser
std::iota(axes.begin(), axes.end(), 2);
auto mean = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast = mm->add_instruction(op::multibroadcast{dims}, mean);
auto mean_bcast =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), mean);
auto l0 = mm->add_instruction(make_op("sqdiff"), x, mean_bcast);
auto variance = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
auto l1 = mm->add_instruction(make_op("sub"), x, mean_bcast);
auto epsilon_literal = mm->add_literal(epsilon);
auto epsilon_bcast = mm->add_instruction(op::multibroadcast{dims}, epsilon_literal);
auto variance_bcast = mm->add_instruction(op::multibroadcast{dims}, variance);
auto epsilon_bcast = mm->add_instruction(make_op("multibroadcast", {{"output_lens", dims}}),
epsilon_literal);
auto variance_bcast =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", dims}}), variance);
auto l2 = mm->add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
auto l3 = mm->add_instruction(make_op("rsqrt"), l2);
auto l4 = mm->add_instruction(make_op("mul"), l1, l3);
auto scale_bcast = mm->add_instruction(op::broadcast{1, dims}, scale);
auto scale_bcast =
mm->add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), scale);
;
auto bias_bcast = mm->add_instruction(op::broadcast{1, dims}, bias);
auto bias_bcast =
mm->add_instruction(make_op("broadcast", {{"axis", 1}, {"dims", dims}}), bias);
auto l5 = mm->add_instruction(make_op("mul"), l4, scale_bcast);
return mm->add_instruction(make_op("add"), l5, bias_bcast);
}
......@@ -1645,9 +1673,11 @@ struct onnx_parser
auto scale_val = mm->add_literal(literal{shape{input_type}, {scale}});
auto bias_vals = mm->add_literal(literal{shape{input_type, {bias.size()}}, bias});
auto scale_tensor = mm->add_instruction(migraphx::op::scalar{input_lens}, scale_val);
auto scale_tensor = mm->add_instruction(
migraphx::make_op("scalar", {{"scalar_bcst_dims", input_lens}}), scale_val);
auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
auto bias_bcast = mm->add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
auto bias_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"dims", input_lens}}), bias_vals);
return mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
}
......@@ -1660,7 +1690,7 @@ struct onnx_parser
auto&& perm_vals = info.attributes["perm"].ints();
perm = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
}
return mm->add_instruction(migraphx::op::transpose{perm}, args.front());
return mm->add_instruction(migraphx::make_op("transpose", {{"dims", perm}}), args.front());
}
instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
......@@ -1721,7 +1751,8 @@ struct onnx_parser
value = parse_value(info.attributes.at("value")).at<float>();
}
return mm->add_instruction(migraphx::op::pad{pads, value}, args.front());
return mm->add_instruction(migraphx::make_op("pad", {{"pads", pads}, {"value", value}}),
args.front());
}
instruction_ref
......@@ -1918,7 +1949,7 @@ struct onnx_parser
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return mm->add_instruction(op::multibroadcast{out_lens}, args[0]);
return mm->add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), args[0]);
}
std::vector<instruction_ref>
......@@ -2001,16 +2032,20 @@ struct onnx_parser
// undefined operator to have 6 arguments
if(args.size() < 6)
{
auto ins = mm->add_instruction(op::undefined{});
auto ins = mm->add_instruction(make_op("undefined"));
args.insert(args.end(), (6 - args.size()), ins);
}
// first output for the concatenation of hidden states
auto hidden_states =
mm->add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, std::move(args));
auto hidden_states = mm->add_instruction(make_op("rnn",
{{"hidden_size", hidden_size},
{"actv_func", to_value(vec_actv_funcs)},
{"direction", dirct},
{"clip", clip}}),
std::move(args));
// second output for the last hidden state
auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);
auto last_output = mm->add_instruction(make_op("rnn_last_hs_output"), hidden_states);
return {hidden_states, last_output};
}
......@@ -2122,17 +2157,22 @@ struct onnx_parser
// append undefined operator to make 6 arguments
if(args.size() < 6)
{
auto ins = mm->add_instruction(op::undefined{});
auto ins = mm->add_instruction(make_op("undefined"));
args.insert(args.end(), 6 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = mm->add_instruction(
op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
auto hidden_states =
mm->add_instruction(make_op("gru",
{{"hidden_size", hidden_size},
{"actv_func", to_value(vec_actv_funcs)},
{"direction", dirct},
{"clip", clip},
{"linear_before_reset", linear_before_reset}}),
std::move(args));
// second output for last gru output
auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);
auto last_output = mm->add_instruction(make_op("rnn_last_hs_output"), hidden_states);
return {hidden_states, last_output};
}
......@@ -2304,18 +2344,23 @@ struct onnx_parser
// append undefined operator to make 8 arguments
if(args.size() < 8)
{
auto ins = mm->add_instruction(op::undefined{});
auto ins = mm->add_instruction(make_op("undefined"));
args.insert(args.end(), 8 - args.size(), ins);
}
// first output for concatenation of hidden states
auto hidden_states = mm->add_instruction(
op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
auto hidden_states = mm->add_instruction(make_op("lstm",
{{"hidden_size", hidden_size},
{"actv_func", to_value(vec_actv_funcs)},
{"direction", dirct},
{"clip", clip},
{"input_forget", input_forget}}),
std::move(args));
auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);
auto last_output = mm->add_instruction(make_op("rnn_last_hs_output"), hidden_states);
// third output for last cell output
auto last_cell_output = mm->add_instruction(op::rnn_last_cell_output{}, hidden_states);
auto last_cell_output = mm->add_instruction(make_op("rnn_last_cell_output"), hidden_states);
return {hidden_states, last_output, last_cell_output};
}
......@@ -2350,7 +2395,7 @@ struct onnx_parser
else
{
auto ins = mm->add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
return mm->add_instruction(op::squeeze{axes}, ins);
return mm->add_instruction(make_op("squeeze", {{"axes", axes}}), ins);
}
}
......@@ -2452,8 +2497,9 @@ struct onnx_parser
int64_t start = 0;
for(auto sl : vec_splits)
{
ret_ins.push_back(
mm->add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
ret_ins.push_back(mm->add_instruction(
make_op("slice", {{"axes", {axis}}, {"starts", {start}}, {"ends", {start + sl}}}),
args[0]));
start += sl;
}
......@@ -2482,7 +2528,7 @@ struct onnx_parser
auto type = args[2]->get_shape().type();
shape s{type, {depth, depth}};
auto l_val = mm->add_literal({s, depth_input});
auto gather_out = mm->add_instruction(op::gather{0}, {l_val, args[0]});
auto gather_out = mm->add_instruction(make_op("gather", {{"axis", 0}}), {l_val, args[0]});
// Finally, we need a transpose to move the inner most dim to the axis dim
int n_rank = gather_out->get_shape().lens().size();
......@@ -2494,14 +2540,18 @@ struct onnx_parser
std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1);
auto tr_out = mm->add_instruction(op::transpose{perm}, gather_out);
auto tr_out = mm->add_instruction(make_op("transpose", {{"dims", perm}}), gather_out);
auto lens = tr_out->get_shape().lens();
auto off_val = mm->add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
auto on_val = mm->add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
auto off_val = mm->add_instruction(
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto on_val = mm->add_instruction(
make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
auto diff = mm->add_instruction(make_op("sub"), on_val, off_val);
auto unsq_off_val = mm->add_instruction(op::multibroadcast{lens}, off_val);
auto unsq_diff_val = mm->add_instruction(op::multibroadcast{lens}, diff);
auto unsq_off_val =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), off_val);
auto unsq_diff_val =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), diff);
auto l_mul = mm->add_instruction(make_op("mul"), tr_out, unsq_diff_val);
return mm->add_instruction(make_op("add"), l_mul, unsq_off_val);
}
......@@ -2520,7 +2570,7 @@ struct onnx_parser
auto l1 = l0;
for(int j = 1; j < repeats[i]; j++)
{
l0 = mm->add_instruction(op::concat{i}, l0, l1);
l0 = mm->add_instruction(make_op("concat", {{"axis", i}}), l0, l1);
}
}
return l0;
......@@ -2585,7 +2635,7 @@ struct onnx_parser
reduce_mode = static_cast<reduce_mode_t>(info.attributes.at("mode").i());
}
auto l0 = mm->add_instruction(op::gather{}, args[0], args[1]);
auto l0 = mm->add_instruction(make_op("gather"), args[0], args[1]);
switch(reduce_mode)
{
case reduce_mode_t::sum:
......@@ -2920,7 +2970,7 @@ struct onnx_parser
{
if(!contains(instructions, name))
{
auto ins = mm->add_instruction(op::undefined{});
auto ins = mm->add_instruction(make_op("undefined"));
instructions[name] = ins;
}
}
......
#include <migraphx/operation.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Serialize an operation into a value object: record the registered op
// name together with the op's own attribute payload so that the pair is
// sufficient to rebuild the operation later via make_op.
void migraphx_to_value(value& v, const operation& op)
{
    const std::string op_name = op.name();
    const value op_fields     = op.to_value();
    v["name"]                 = op_name;
    v["operator"]             = op_fields;
}
// Deserialize an operation from a value object: look up the recorded op
// name in the registry and initialize it from the stored attribute payload.
// Throws (via value::at) if either field is missing.
void migraphx_from_value(const value& v, operation& op)
{
    const auto op_name = v.at("name").to<std::string>();
    op                 = make_op(op_name, v.at("operator"));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/op/load.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace migraphx {
......@@ -208,7 +211,9 @@ void memory_coloring_impl::rewrite()
if(is_allocate(ins))
{
p_program->replace_instruction(
ins, op::load{ins->get_shape(), offset}, scratch_param);
ins,
make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", offset}}),
scratch_param);
}
}
}
......
......@@ -15,6 +15,8 @@
#include <algorithm>
#include <set>
#include <utility>
#include <migraphx/make_op.hpp>
#include <unordered_set>
namespace migraphx {
......@@ -212,7 +214,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
if(ins == std::prev(this->end()))
{
return replace_instruction(ins, op::identity{}, rep);
return replace_instruction(ins, make_op("identity"), rep);
}
// TODO: Should it be an error if the output is empty?
......
......@@ -19,6 +19,10 @@
#include <utility>
#include <set>
#include <iomanip>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <fstream>
#include <algorithm>
......@@ -58,12 +62,14 @@ instruction_ref insert_quant_ins(module& modl,
auto float_ins = scaled_ins;
if(scaled_ins->get_shape().type() != shape::float_type)
{
float_ins =
modl.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
float_ins = modl.insert_instruction(
insert_loc,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
scaled_ins);
}
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = modl.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
scaled_ins = modl.insert_instruction(insert_loc, make_op("mul"), l_scale, float_ins);
}
auto shifted_ins = scaled_ins;
......@@ -73,26 +79,32 @@ instruction_ref insert_quant_ins(module& modl,
if(shifted_ins->get_shape().type() != shape::float_type)
{
float_ins = modl.insert_instruction(
insert_loc, op::convert{shape::float_type}, shifted_ins);
insert_loc,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
shifted_ins);
}
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = modl.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
shifted_ins = modl.insert_instruction(insert_loc, make_op("add"), l_shift, float_ins);
}
auto rounded_ins = modl.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_ins = modl.insert_instruction(insert_loc, make_op("round"), shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = modl.add_literal(127.0f);
auto min_clip = modl.add_literal(-128.0f);
max_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip);
min_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip);
max_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), max_clip);
min_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), min_clip);
auto clipped_ins =
modl.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction(
insert_loc, make_op("convert", {{"target_type", type}}), clipped_ins);
}
else
{
quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, ins);
quant_ins =
modl.insert_instruction(insert_loc, make_op("convert", {{"target_type", type}}), ins);
}
map_ins[ins] = quant_ins;
......@@ -162,8 +174,8 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
// check the dead code case to avoid assert
bool output_empty = ins->outputs().empty();
auto ins_orig_type =
mm->insert_instruction(std::next(ins), op::convert{orig_type}, ins);
auto ins_orig_type = mm->insert_instruction(
std::next(ins), make_op("convert", {{"target_type", orig_type}}), ins);
if(!output_empty)
{
mm->replace_instruction(ins, ins_orig_type);
......@@ -197,13 +209,18 @@ static void ins_quantize_int8(module& modl,
if(shape::int32_type == orig_type)
{
modl.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
}
else
{
auto quant_dot = modl.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
modl.replace_instruction(ins, op::convert{orig_type}, quant_dot);
ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), quant_dot);
}
}
// either alpha or beta cannot be quantized because of too big
......@@ -214,8 +231,10 @@ static void ins_quantize_int8(module& modl,
{
converted_inputs.pop_back();
}
auto q_dot = modl.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = modl.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto q_dot = modl.insert_instruction(
ins, make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), converted_inputs);
auto f_dot = modl.insert_instruction(
ins, make_op("convert", {{"target_type", to_value(shape::float_type)}}), q_dot);
auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha =
......@@ -223,42 +242,46 @@ static void ins_quantize_int8(module& modl,
if(inputs.size() == 3 and dot_op.beta != 0.0f)
{
auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta =
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{};
if(orig_type != shape::float_type)
{
auto fp32_c =
modl.insert_instruction(ins, op::convert{shape::float_type}, inputs.back());
beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, fp32_c);
auto fp32_c = modl.insert_instruction(
ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
inputs.back());
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, fp32_c);
}
else
{
beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, inputs.back());
}
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
modl.replace_instruction(ins, make_op("add"), alpha_ab, beta_c);
}
else
{
auto f_res = modl.insert_instruction(ins, op::add{}, alpha_ab, beta_c);
modl.replace_instruction(ins, op::convert{orig_type}, f_res);
auto f_res = modl.insert_instruction(ins, make_op("add"), alpha_ab, beta_c);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), f_res);
}
}
else
{
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, op::mul{}, l_alpha, f_dot);
modl.replace_instruction(ins, make_op("mul"), l_alpha, f_dot);
}
else
{
auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
modl.replace_instruction(ins, op::convert{orig_type}, alpha_ab);
auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), alpha_ab);
}
}
}
......@@ -285,23 +308,27 @@ static void ins_quantize_int8(module& modl,
{
auto l_factor = modl.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
modl.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
modl.replace_instruction(ins, make_op("mul"), quant_conv, l_factor);
}
// convert quant_conv output to float type, multiply the factor and
// convert back to original type
else
{
auto float_conv =
modl.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto float_conv = modl.insert_instruction(
ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
quant_conv);
auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, op::mul{}, l_factor, float_conv);
modl.replace_instruction(ins, make_op("mul"), l_factor, float_conv);
}
else
{
auto adjusted_conv = modl.insert_instruction(ins, op::mul{}, l_factor, float_conv);
modl.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
auto adjusted_conv =
modl.insert_instruction(ins, make_op("mul"), l_factor, float_conv);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), adjusted_conv);
}
}
}
......
......@@ -7,6 +7,8 @@
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
......@@ -46,10 +48,10 @@ void rewrite_batchnorm::apply(module& p) const
auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = p.add_literal({a.get_shape(), a.data()});
auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
auto mul = p.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
auto add = p.insert_instruction(ins, make_op("add"), mul, b_broadcast);
p.replace_instruction(ins, add);
}
}
......
......@@ -5,6 +5,8 @@
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
......@@ -31,25 +33,26 @@ void rewrite_pooling::apply(module& prog) const
continue;
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
auto reshape =
prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
auto reshape = prog.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{};
// average pooling
if(op.mode == "average")
{
pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
pooling =
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
}
// max pooling
else
{
pooling = prog.insert_instruction(ins, op::reduce_max{{1}}, reshape);
pooling = prog.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
}
std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n;
rsp_lens[1] = c;
prog.replace_instruction(ins, op::reshape{rsp_lens}, pooling);
prog.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
}
}
......
......@@ -18,6 +18,8 @@
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/ranges.hpp>
......@@ -80,20 +82,26 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
if(dirct == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
auto w_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
auto r_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
bias_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
}
// process intial hidden state, it could be the 6th argument
......@@ -102,8 +110,10 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
ih_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
}
else
{
......@@ -120,8 +130,8 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
if(variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret_reverse =
......@@ -131,24 +141,27 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
{args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
actv_funcs.at(1));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
auto concat_output = prog.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
last_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
prog.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
ret_forward[0] = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[0] = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
}
}
else
......@@ -180,26 +193,27 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
if(!is_forward and variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret = vanilla_rnn_cell(
is_forward, prog, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
prog.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
}
}
......@@ -225,15 +239,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
// squeeze and transpose r
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
auto sih_lens = sih->get_shape().lens();
// bias
......@@ -241,30 +255,36 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
if(bias != prog.end())
{
long hs = static_cast<long>(r->get_shape().lens()[2]);
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wrb = prog.insert_instruction(ins, op::add{}, wb, rb);
bb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
auto rb = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
bb = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", sih_lens}}), wrb);
}
instruction_ref hidden_out = prog.end();
instruction_ref last_out{};
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
auto xt = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_wi = prog.insert_instruction(ins, make_op("dot"), xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, make_op("dot"), sih, tran_sr);
if(bias != prog.end())
{
xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
xt_wi = prog.insert_instruction(ins, make_op("add"), xt_wi, bb);
}
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
auto xt_ht = prog.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);
// apply activation function
auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
......@@ -272,7 +292,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
// concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have
......@@ -281,17 +301,17 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
{
if(is_forward)
{
hidden_out =
(seq_index == 0)
hidden_out = (seq_index == 0)
? last_out
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
: prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
}
else
{
hidden_out =
(seq_index == seq_len - 1)
hidden_out = (seq_index == seq_len - 1)
? last_out
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
: prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
}
}
}
......@@ -311,7 +331,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
if(rnn_op.actv_funcs.empty())
{
// default is tanh
return {op::tanh{}, op::tanh{}};
return {make_op("tanh"), make_op("tanh")};
}
else if(rnn_op.actv_funcs.size() == 1)
{
......@@ -327,7 +347,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
if(rnn_op.actv_funcs.empty())
{
// default is tanh
return {op::tanh{}};
return {make_op("tanh")};
}
else
{
......@@ -369,20 +389,26 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
if(dirct == op::rnn_direction::bidirectional)
{
// w weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
auto w_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// r weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
auto r_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
bias_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
}
// intial hidden state
......@@ -390,8 +416,10 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
ih_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
}
else
{
......@@ -410,8 +438,8 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
if(variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret_reverse =
......@@ -423,23 +451,26 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
actv_funcs.at(2),
actv_funcs.at(3));
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
auto concat_output = prog.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
last_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
// The following logic is to ensure the last instruction rewritten
// from gru operator is a concat
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
prog.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
ret_forward[0] = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[0] = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
}
}
else
......@@ -469,8 +500,8 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
if(!is_forward and variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret = gru_cell(is_forward,
......@@ -481,17 +512,18 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
actv_funcs.at(0),
actv_funcs.at(1));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
prog.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
}
}
......@@ -529,19 +561,21 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
// w matrix squeeze to 2-dim and do a transpose
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
// r slide to two part, zr and h
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto rzr = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rzr);
auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
auto rh = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rh);
// initial states
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
size_t bs = ih->get_shape().lens()[1];
// bias
......@@ -550,77 +584,100 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
instruction_ref brb_h{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
bwb = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(3 * hs)}}}),
wb);
auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
auto rb_h = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
auto rb_zr = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
sbias);
auto rb_h = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
sbias);
brb_zr = prog.insert_instruction(
ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(2 * hs)}}}),
rb_zr);
brb_h = prog.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(hs)}}}),
rb_h);
}
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
auto xt = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
auto xt_w = prog.insert_instruction(ins, make_op("dot"), xt, tw);
auto ih1_rzr = prog.insert_instruction(ins, make_op("dot"), sih, trzr);
if(bias != prog.end())
{
xt_w = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
xt_w = prog.insert_instruction(ins, make_op("add"), xt_w, bwb);
ih1_rzr = prog.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
}
auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);
auto xw_z = prog.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
auto xw_r = prog.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
auto xw_h = prog.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);
auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);
auto hr_z = prog.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
auto hr_r = prog.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);
auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
auto xw_hr_z = prog.insert_instruction(ins, make_op("add"), xw_z, hr_z);
auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z);
auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
auto xw_hr_r = prog.insert_instruction(ins, make_op("add"), xw_r, hr_r);
auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r);
instruction_ref hr_h{};
if(linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
auto rt_ht1 = prog.insert_instruction(ins, make_op("mul"), rt, sih);
hr_h = prog.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
if(bias != prog.end())
{
hr_h = prog.insert_instruction(ins, op::add{}, hr_h, brb_h);
hr_h = prog.insert_instruction(ins, make_op("add"), hr_h, brb_h);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
auto ht1_rh = prog.insert_instruction(ins, make_op("dot"), sih, trh);
if(bias != prog.end())
{
ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brb_h);
ht1_rh = prog.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
}
hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
hr_h = prog.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
}
auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
auto xw_hr_h = prog.insert_instruction(ins, make_op("add"), xw_h, hr_h);
auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih);
sih = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
auto one_minus_zt = prog.insert_instruction(ins, make_op("sub"), l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, make_op("mul"), zt, sih);
sih = prog.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
last_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
if(i < seq_len - 1)
{
......@@ -629,14 +686,16 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
hidden_states =
(seq_index == 0)
? last_output
: prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
: prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
}
else
{
hidden_states =
(seq_index == seq_len - 1)
? last_output
: prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
: prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
}
}
}
......@@ -654,7 +713,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
if(gru_op.direction == op::rnn_direction::bidirectional)
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
return {make_op("sigmoid"), make_op("tanh"), make_op("sigmoid"), make_op("tanh")};
else if(gru_op.actv_funcs.size() == 1)
return {gru_op.actv_funcs.at(0),
gru_op.actv_funcs.at(0),
......@@ -676,7 +735,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
else
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}};
return {make_op("sigmoid"), make_op("tanh")};
else if(gru_op.actv_funcs.size() == 1)
return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
else
......@@ -720,20 +779,26 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
{
// input weight matrix
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
auto w_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
auto r_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
bias_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
}
// process intial hidden state, it is the 6th argument
......@@ -741,8 +806,10 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{};
if(args.size() >= 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
ih_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
}
else
{
......@@ -755,8 +822,10 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref ic_reverse{};
if(args.size() >= 7 && args[6]->name() != "undefined")
{
ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
ic_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
ic_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
}
else
{
......@@ -769,8 +838,10 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref pph_reverse = prog.end();
if(args.size() == 8 && args[7]->name() != "undefined")
{
pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
pph_forward = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
pph_reverse = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
}
auto ret_forward = lstm_cell(true,
......@@ -790,8 +861,8 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
if(variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret_reverse = lstm_cell(false,
prog,
......@@ -808,12 +879,14 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
actv_funcs.at(4),
actv_funcs.at(5));
auto concat_hs_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
auto concat_cell_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[3], ret_reverse[3]);
last_hs_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_hs_output);
last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_cell_output);
auto concat_hs_output = prog.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
auto concat_cell_output = prog.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
last_hs_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
last_cell_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);
// the following logic is to ensure the last instruction is a concat
if(ret_forward[0] == prog.end())
......@@ -822,21 +895,21 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
}
else
{
ret_forward[1] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[1] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
ret_forward[1] = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[1] = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
ret_forward[3] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_forward[3]);
ret_reverse[3] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[3], ret_reverse[2]);
cell_outputs =
prog.insert_instruction(ins, op::concat{1}, ret_forward[3], ret_reverse[3]);
ret_forward[3] = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
ret_reverse[3] = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
cell_outputs = prog.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
}
hidden_state =
prog.replace_instruction(ins, op::concat{1}, {ret_forward[1], ret_reverse[1]});
hidden_state = prog.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
}
else
{
......@@ -883,8 +956,8 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
if(!is_forward and variable_seq_len)
{
args[0] =
prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret = lstm_cell(is_forward,
prog,
......@@ -894,24 +967,26 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
actv_funcs.at(1),
actv_funcs.at(2));
last_hs_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[3]);
last_hs_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
last_cell_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);
if(ret[0] == prog.end())
{
cell_outputs = ret[3];
hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]);
hidden_state = prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
}
else
{
auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
cell_outputs =
prog.insert_instruction(ins, op::concat{0}, concat_cell_arg0, concat_cell_arg1);
cell_outputs = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
hidden_state = prog.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
}
}
......@@ -957,18 +1032,18 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
// r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
// initial cell state
auto sic = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
auto sic = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
auto ic_lens = sic->get_shape().lens();
// bias
......@@ -976,13 +1051,19 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto ub_wb = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
auto ub_rb = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
sbias);
auto ub_wrb = prog.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
wrb = prog.insert_instruction(
ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
ins,
make_op("broadcast", {{"axis", 1}, {"dims", {bs, 4 * static_cast<size_t>(hs)}}}),
ub_wrb);
}
// peep hole
......@@ -991,73 +1072,90 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref pphf_brcst{};
if(pph != prog.end())
{
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi);
auto spph = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
auto pphi = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
pphi_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphi);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho);
auto ppho = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
ppho_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), ppho);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
auto pphf = prog.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
pphf_brcst = prog.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphf);
}
long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
for(long i = 0; i < seq_len; ++i)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
auto xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
auto xt = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_tsw = prog.insert_instruction(ins, make_op("dot"), xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, make_op("dot"), sih, tsr);
auto xt_sih = prog.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
if(bias != prog.end())
{
xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
xt_sih = prog.insert_instruction(ins, make_op("add"), xt_sih, wrb);
}
auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
auto ft_before_actv =
prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
auto ct_before_actv =
prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
auto it_before_actv = prog.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
auto ot_before_actv = prog.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
auto ft_before_actv = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
xt_sih);
auto ct_before_actv = prog.insert_instruction(
ins,
make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
xt_sih);
if(pph != prog.end())
{
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
auto pphi_ct = prog.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
auto pphf_ct = prog.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
}
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct
auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic);
auto it_ct = prog.insert_instruction(ins, op::mul{}, it, ct);
auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
auto ft_cell = prog.insert_instruction(ins, make_op("mul"), ft, sic);
auto it_ct = prog.insert_instruction(ins, make_op("mul"), it, ct);
auto cellt = prog.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
if(pph != prog.end())
{
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
auto ppho_cellt = prog.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
ot_before_actv =
prog.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
}
auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);
// Ht = ot (.) h(Ct)
auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
auto ht = prog.insert_instruction(ins, op::mul{}, ot, h_cellt);
auto ht = prog.insert_instruction(ins, make_op("mul"), ot, h_cellt);
sic = cellt;
sih = ht;
last_hs_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, cellt);
last_hs_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
last_cell_output =
prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
if(i < seq_len - 1)
{
......@@ -1070,13 +1168,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{
auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
hidden_states =
prog.insert_instruction(ins, op::concat{0}, concat_hs_arg0, concat_hs_arg1);
hidden_states = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);
auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
cell_outputs =
prog.insert_instruction(ins, op::concat{0}, concat_cell_arg0, concat_cell_arg1);
cell_outputs = prog.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
}
}
}
......@@ -1098,7 +1196,12 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
switch(num_actv_funcs)
{
case 0:
return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};
return {make_op("sigmoid"),
make_op("tanh"),
make_op("tanh"),
make_op("sigmoid"),
make_op("tanh"),
make_op("tanh")};
case 1:
return {actv_funcs.at(0),
......@@ -1147,7 +1250,7 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
switch(num_actv_funcs)
{
case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};
case 0: return {make_op("sigmoid"), make_op("tanh"), make_op("tanh")};
case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};
......@@ -1215,7 +1318,11 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
if(variable_seq_len)
{
result_ins = prog.insert_instruction(
std::next(ins), op::rnn_var_sl_shift_output{"hidden_states", dirct}, ins, seq_lens);
std::next(ins),
make_op("rnn_var_sl_shift_output",
{{"output_name", "hidden_states"}, {"direction", dirct}}),
ins,
seq_lens);
prog.replace_instruction(ins, result_ins);
auto hs_outputs = find_all(result_ins->outputs(),
[&](auto i) { return i->name() == "rnn_last_hs_output"; });
......@@ -1223,8 +1330,10 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
for(auto& hs_out : hs_outputs)
{
auto inputs = hs_out->inputs();
prog.replace_instruction(
hs_out, op::rnn_var_sl_last_output{dirct}, inputs.front(), seq_lens);
prog.replace_instruction(hs_out,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
inputs.front(),
seq_lens);
}
}
else
......@@ -1258,16 +1367,20 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
{
if(!ins_outputs.empty())
{
cell_outputs =
prog.insert_instruction(std::next(ins),
op::rnn_var_sl_shift_output{"cell_outputs", dirct},
cell_outputs = prog.insert_instruction(
std::next(ins),
make_op("rnn_var_sl_shift_output",
{{"output_name", "cell_outputs"}, {"direction", dirct}}),
cell_outputs,
seq_lens);
}
for(auto co : ins_outputs)
{
prog.replace_instruction(co, op::rnn_var_sl_last_output{dirct}, cell_outputs, seq_lens);
prog.replace_instruction(co,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
cell_outputs,
seq_lens);
}
}
// replace the rnn_last_cell_output with the last_cell_output. The while
......@@ -1300,7 +1413,8 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
shape pad_s{s.type(), pad_lens};
std::vector<float> pad_data(pad_s.elements(), 0.0f);
auto pl = prog.add_literal(pad_s, pad_data.begin(), pad_data.end());
hs_padded = prog.insert_instruction(std::next(hs), op::concat{0}, hs, pl);
hs_padded =
prog.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
prog.replace_instruction(hs, hs_padded);
}
......
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
......@@ -12,6 +11,8 @@
#include <queue>
#include <thread>
#include <mutex>
#include <migraphx/make_op.hpp>
#include <set>
#include <deque>
#include <chrono>
......@@ -556,7 +557,7 @@ void schedule::apply(module& p) const
std::vector<instruction_ref> args;
args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end());
p.insert_instruction(std::next(ip.first), op::identity{}, args);
p.insert_instruction(std::next(ip.first), make_op("identity"), args);
}
}
......
......@@ -16,6 +16,10 @@
#include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/algorithm.hpp>
namespace migraphx {
......@@ -62,8 +66,10 @@ struct find_mul_conv
return;
auto new_a = p.insert_instruction(
ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
ins,
make_op("broadcast", {{"axis", 0}, {"dims", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
p.replace_instruction(ins, new_conv);
......@@ -125,20 +131,27 @@ struct find_mul_slice_conv
auto slice_w_ins = p.insert_instruction(ins, w_slice_op, w_ins);
auto new_a = p.insert_instruction(
ins, op::broadcast{0, slice_w_ins->get_shape().lens()}, a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, slice_w_ins);
ins,
make_op("broadcast", {{"axis", 0}, {"dims", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
std::vector<instruction_ref> sliced_weights;
if(slice_op.starts.front() != 0)
sliced_weights.push_back(
p.insert_instruction(ins, op::slice{{0}, {0}, slice_op.starts}, w_ins));
sliced_weights.push_back(p.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", slice_op.starts}}),
w_ins));
sliced_weights.push_back(new_mul);
int64_t end_axis = w_ins->get_shape().lens().at(0);
if(slice_op.ends.front() != end_axis)
sliced_weights.push_back(
p.insert_instruction(ins, op::slice{{0}, {slice_op.ends}, {end_axis}}, w_ins));
sliced_weights.push_back(p.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", slice_op.ends}, {"ends", {end_axis}}}),
w_ins));
auto new_weights = p.insert_instruction(ins, op::concat{0}, sliced_weights);
auto new_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights);
auto new_conv = p.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);
......@@ -177,9 +190,9 @@ struct find_mul_add
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins);
auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins);
p.replace_instruction(ins, op::add{}, ax_ins, ab_ins);
auto ax_ins = p.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = p.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
p.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
......@@ -198,8 +211,8 @@ struct find_add_lit_broadcast
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
p.replace_instruction(ins, op::add{}, x_ins, sumab);
auto sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins);
p.replace_instruction(ins, make_op("add"), x_ins, sumab);
}
};
......@@ -226,17 +239,17 @@ struct find_double_add_lit_broadcast
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return;
auto op = a_ins->get_operator();
auto presum =
p.insert_instruction(ins, op::add{}, a_ins->inputs().at(0), b_ins->inputs().at(0));
auto presum = p.insert_instruction(
ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = p.insert_instruction(ins, op, presum);
}
else
{
sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins);
}
auto sumxy = p.insert_instruction(ins, op::add{}, x_ins, y_ins);
p.replace_instruction(ins, op::add{}, sumxy, sumab);
auto sumxy = p.insert_instruction(ins, make_op("add"), x_ins, y_ins);
p.replace_instruction(ins, make_op("add"), sumxy, sumab);
}
};
......@@ -327,7 +340,8 @@ struct find_concat_op
std::transform(start, last, std::back_inserter(inputs), [&](auto j) {
return j->inputs().at(i);
});
auto concat = p.insert_instruction(ins, op::concat{iaxis}, inputs);
auto concat =
p.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs);
concats.push_back(concat);
}
auto y = p.insert_instruction(ins, op, concats);
......@@ -349,7 +363,7 @@ struct find_concat_op
if(args.size() == 1)
p.replace_instruction(ins, args.front());
else
p.replace_instruction(ins, op::concat{axis}, args);
p.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args);
}
};
......@@ -482,7 +496,8 @@ struct find_splits
return;
auto concat_axis = slice_op.axes.front();
// TODO: Check if axises match
auto concat = p.insert_instruction(ins, op::concat{concat_axis}, data_args);
auto concat = p.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args);
std::vector<instruction_ref> args;
args.resize(2);
......@@ -501,7 +516,8 @@ struct find_splits
{
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
continue;
auto x = p.insert_instruction(output, op::contiguous{}, output->inputs());
auto x =
p.insert_instruction(output, make_op("contiguous"), output->inputs());
p.replace_instruction(output, output->get_operator(), x);
}
......@@ -648,7 +664,11 @@ struct find_add_convs
return;
new_op = a_op;
b_input = p.insert_instruction(
ins, op::as_shape{compute_stride_shape(b_input->get_shape(), n)}, b_input);
ins,
make_op(
"as_shape",
{{"shape", to_value(compute_stride_shape(b_input->get_shape(), n))}}),
b_input);
}
else if(b_op.stride < a_op.stride)
{
......@@ -657,7 +677,11 @@ struct find_add_convs
return;
new_op = b_op;
a_input = p.insert_instruction(
ins, op::as_shape{compute_stride_shape(a_input->get_shape(), n)}, a_input);
ins,
make_op(
"as_shape",
{{"shape", to_value(compute_stride_shape(a_input->get_shape(), n))}}),
a_input);
}
else
return;
......@@ -666,8 +690,10 @@ struct find_add_convs
return;
}
auto concat_input = p.insert_instruction(ins, op::concat{1}, a_input, b_input);
auto concat_weights = p.insert_instruction(ins, op::concat{1}, a_weights, b_weights);
auto concat_input =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input);
auto concat_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights);
p.replace_instruction(ins, new_op, concat_input, concat_weights);
}
};
......@@ -739,13 +765,18 @@ struct find_conv_dot_horiz_fusion
for(auto arg : args)
p.move_instructions(arg, input);
// TODO: Check if axises match
auto concat = p.insert_instruction(input, op::concat{concat_axis}, args);
auto concat =
p.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = p.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0;
for(auto arg : range(start, last))
{
int64_t len = arg->get_shape().lens()[axis];
p.replace_instruction(arg, op::slice{{axis}, {offset}, {offset + len}}, fused);
p.replace_instruction(
arg,
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
fused);
offset += len;
}
};
......@@ -767,11 +798,11 @@ struct find_div_const
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto recip = p.insert_instruction(std::next(c_ins), op::recip{}, c_ins);
auto recip = p.insert_instruction(std::next(c_ins), make_op("recip"), c_ins);
auto args = ins->inputs();
p.replace_instruction(ins, op::mul{}, args.front(), recip);
p.replace_instruction(ins, make_op("mul"), args.front(), recip);
}
};
......@@ -787,11 +818,11 @@ struct find_sub_const
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto neg = p.insert_instruction(std::next(c_ins), op::neg{}, c_ins);
auto neg = p.insert_instruction(std::next(c_ins), make_op("neg"), c_ins);
auto args = ins->inputs();
p.replace_instruction(ins, op::add{}, args.front(), neg);
p.replace_instruction(ins, make_op("add"), args.front(), neg);
}
};
......@@ -808,7 +839,7 @@ struct find_rsqrt
auto ins = r.result;
auto x_ins = r.instructions["x"];
p.replace_instruction(ins, op::rsqrt{}, x_ins);
p.replace_instruction(ins, make_op("rsqrt"), x_ins);
}
};
......@@ -883,14 +914,19 @@ struct find_split_reshape
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction
auto rsp_ins = p.insert_instruction(std::next(input), op::reshape{rsp_out_lens}, input);
auto rsp_ins = p.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace the original reshape with slice
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{
p.replace_instruction(
vec_rsp[i], op::slice{{rsp_axis}, {start}, {start + vec_dims[i]}}, rsp_ins);
vec_rsp[i],
make_op(
"slice",
{{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}),
rsp_ins);
start += vec_dims[i];
}
}
......@@ -930,7 +966,8 @@ struct find_split_transpose
}
// insert an transpose instruction
auto tr = p.insert_instruction(std::next(input), op::transpose{perm}, input);
auto tr =
p.insert_instruction(std::next(input), make_op("transpose", {{"dims", perm}}), input);
// compute the axis in the slice
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
......@@ -944,7 +981,10 @@ struct find_split_transpose
auto starts = oper.starts;
auto ends = oper.ends;
auto tr_orig = in->outputs().front();
p.replace_instruction(tr_orig, op::slice{{axis_new}, starts, ends}, tr);
p.replace_instruction(
tr_orig,
make_op("slice", {{"axes", {axis_new}}, {"starts", starts}, {"ends", ends}}),
tr);
}
}
};
......
......@@ -11,6 +11,8 @@
#include <migraphx/permutation.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <unordered_set>
#include <migraphx/make_op.hpp>
#include <map>
namespace migraphx {
......@@ -149,7 +151,7 @@ struct find_transpose
}
else
{
p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front());
}
}
};
......@@ -257,10 +259,10 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, op::transpose{permutation}, i);
return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t);
}
......
......@@ -17,6 +17,8 @@
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pad_calc.hpp>
namespace migraphx {
......@@ -49,20 +51,20 @@ struct tf_parser
instruction_ref to_nhwc(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins);
return ins;
}
instruction_ref to_nchw(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins);
return ins;
}
instruction_ref to_kcxy(instruction_ref ins) const
{
return mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins);
}
instruction_ref make_contiguous(instruction_ref ins) const
......@@ -70,7 +72,7 @@ struct tf_parser
if(ins->get_shape().standard())
return ins;
else
return mm->add_instruction(op::contiguous{}, ins);
return mm->add_instruction(make_op("contiguous"), ins);
}
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
......@@ -176,20 +178,20 @@ struct tf_parser
tf_parser()
{
add_generic_op("All", op::identity{});
add_generic_op("Identity", op::identity{});
add_generic_op("LessEqual", op::identity{});
add_generic_op("Relu", op::relu{});
add_generic_op("Rsqrt", op::rsqrt{});
add_generic_op("Tanh", op::tanh{});
add_generic_op("StopGradient", op::identity{});
add_binary_op("Add", op::add{});
add_binary_op("AddV2", op::add{});
add_binary_op("Mul", op::mul{});
add_binary_op("Pow", op::pow{});
add_binary_op("SquaredDifference", op::sqdiff{});
add_binary_op("Sub", op::sub{});
add_generic_op("All", make_op("identity"));
add_generic_op("Identity", make_op("identity"));
add_generic_op("LessEqual", make_op("identity"));
add_generic_op("Relu", make_op("relu"));
add_generic_op("Rsqrt", make_op("rsqrt"));
add_generic_op("Tanh", make_op("tanh"));
add_generic_op("StopGradient", make_op("identity"));
add_binary_op("Add", make_op("add"));
add_binary_op("AddV2", make_op("add"));
add_binary_op("Mul", make_op("mul"));
add_binary_op("Pow", make_op("pow"));
add_binary_op("SquaredDifference", make_op("sqdiff"));
add_binary_op("Sub", make_op("sub"));
add_mem_op("ArgMax", &tf_parser::parse_arg_op<op::argmax>, false);
add_mem_op("ArgMin", &tf_parser::parse_arg_op<op::argmin>, false);
......@@ -310,8 +312,10 @@ struct tf_parser
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = mm->add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = mm->add_instruction(op::multibroadcast{output_lens}, arg1);
auto l0 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}),
arg0);
auto l1 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}),
arg1);
return to_nhwc(mm->add_instruction(x, to_nchw(l0), to_nchw(l1)));
}
else
......@@ -337,7 +341,7 @@ struct tf_parser
int64_t axis = 0;
axis = args[1]->eval().at<int64_t>();
auto ins = mm->add_instruction(Op{axis}, args.front());
return mm->add_instruction(op::squeeze{{axis}}, ins);
return mm->add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
}
instruction_ref parse_batchnorm(const std::string&,
......@@ -359,8 +363,9 @@ struct tf_parser
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]);
return mm->add_instruction(op::add{}, args[0], l0);
auto l0 = mm->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]);
return mm->add_instruction(make_op("add"), args[0], l0);
}
instruction_ref parse_cast(const std::string&,
......@@ -368,7 +373,7 @@ struct tf_parser
std::vector<instruction_ref> args) const
{
shape::type_t type = parse_type(attributes.at("DstT").type());
return mm->add_instruction(op::convert{type}, std::move(args));
return mm->add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
}
instruction_ref parse_concat(const std::string&,
......@@ -442,7 +447,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
l0 = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
}
else
{
......@@ -528,7 +533,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
l0 = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
}
else
{
......@@ -553,8 +558,8 @@ struct tf_parser
new_weights_shape[0] = out_channels;
new_weights_shape[1] = 1;
// Make sure weights are contiguous before doing reshape
auto new_weights =
mm->add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
auto new_weights = mm->add_instruction(make_op("reshape", {{"dims", new_weights_shape}}),
make_contiguous(weights));
return mm->add_instruction(op, {l0, new_weights});
}
......@@ -576,7 +581,7 @@ struct tf_parser
{
new_dims.insert(new_dims.begin() + dim, 1);
}
return mm->add_instruction(op::reshape{new_dims}, args[0]);
return mm->add_instruction(make_op("reshape", {{"dims", new_dims}}), args[0]);
}
instruction_ref
......@@ -617,10 +622,12 @@ struct tf_parser
// swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1];
auto l1 = (transa) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
: args[0];
auto l2 = (transb) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
return mm->add_instruction(op::dot{}, l1, l2);
return mm->add_instruction(make_op("dot"), l1, l2);
}
instruction_ref parse_mean(const std::string&,
......@@ -632,12 +639,12 @@ struct tf_parser
if(keep_dims)
{
return mm->add_instruction(op::reduce_mean{axes}, args[0]);
return mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
}
else
{
auto ins = mm->add_instruction(op::reduce_mean{axes}, args[0]);
return mm->add_instruction(op::squeeze{axes}, ins);
auto ins = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
return mm->add_instruction(make_op("squeeze", {{"axes", axes}}), ins);
}
}
......@@ -663,7 +670,7 @@ struct tf_parser
{
shape s{shape::float_type, {depth, depth}};
auto l0 = mm->add_literal({s, depth_input});
return mm->add_instruction(op::gather{0}, {l0, args[0]});
return mm->add_instruction(make_op("gather", {{"axis", 0}}), {l0, args[0]});
}
MIGRAPHX_THROW("MIGraphX does not support axis != -1");
}
......@@ -688,8 +695,10 @@ struct tf_parser
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return mm->add_instruction(op::unsqueeze{{axis}}, arg); });
return to_nhwc(mm->add_instruction(op::concat{axis}, unsqueezed_args));
[&](instruction_ref arg) {
return mm->add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
return to_nhwc(mm->add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args));
}
instruction_ref
......@@ -765,7 +774,10 @@ struct tf_parser
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(
migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
migraphx::make_op(
"pad",
{{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
l0);
}
else
{
......@@ -784,9 +796,11 @@ struct tf_parser
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(op::multibroadcast{input_lens}, min_val);
max_val = mm->add_instruction(op::multibroadcast{input_lens}, max_val);
return mm->add_instruction(op::clip{}, args.front(), min_val, max_val);
min_val =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
max_val =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
return mm->add_instruction(make_op("clip"), args.front(), min_val, max_val);
}
instruction_ref
......@@ -884,7 +898,8 @@ struct tf_parser
assert(num_outputs > 0);
if(num_outputs == 1)
return std::vector<instruction_ref>{mm->add_instruction(op::identity{}, input_arg)};
return std::vector<instruction_ref>{
mm->add_instruction(make_op("identity"), input_arg)};
auto lens = input_arg->get_shape().lens();
auto num_dims = lens.size();
......@@ -1012,7 +1027,7 @@ struct tf_parser
squeeze_axes.push_back(i);
}
return mm->add_instruction(op::squeeze{squeeze_axes}, l1);
return mm->add_instruction(make_op("squeeze", {{"axes", squeeze_axes}}), l1);
}
instruction_ref parse_transpose(const std::string&,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment