Unverified Commit 8d21fdc9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Refactor to use make_op almost everywhere (#696)

* Load op when serializing

* Formatting

* Add missing clip field

* Use make_op almost everywhere

* Formatting

* More make ops for rnns

* Get rid of spaces

* Formatting

* Remove operators headers

* Formatting

* Remove unused op headers

* Increase line threshold
parent b5633c27
......@@ -18,7 +18,7 @@ CheckOptions:
- key: readability-function-size.BranchThreshold
value: '15'
- key: readability-function-size.LineThreshold
value: '300'
value: '350'
- key: readability-function-size.NestingThreshold
value: '5'
- key: readability-function-size.ParameterThreshold
......
......@@ -26,6 +26,7 @@ add_library(migraphx
load_save.cpp
make_op.cpp
msgpack.cpp
operation.cpp
program.cpp
quantization.cpp
reduce_dims.cpp
......
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
......@@ -14,7 +15,7 @@ void auto_contiguous::apply(module& p) const
shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0)
{
auto c = p.insert_instruction(std::next(ins), op::contiguous{}, ins);
auto c = p.insert_instruction(std::next(ins), make_op("contiguous"), ins);
p.replace_instruction(ins, c);
}
}
......
......@@ -9,6 +9,8 @@
#include <migraphx/op/dot.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/add.hpp>
namespace migraphx {
......@@ -26,17 +28,19 @@ struct find_dot_add
not contains({shape::float_type, shape::half_type, shape::double_type},
ins->get_shape().type()))
return;
auto dot_ins =
p.insert_instruction(ins, op::dot{dot.alpha, 0}, ins->inputs()[0], ins->inputs()[1]);
auto c_ins = ins->inputs()[2];
auto dot_ins = p.insert_instruction(ins,
make_op("dot", {{"alpha", dot.alpha}, {"beta", 0}}),
ins->inputs()[0],
ins->inputs()[1]);
auto c_ins = ins->inputs()[2];
if(not float_equal(dot.beta, 1))
{
auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast =
p.insert_instruction(ins, op::multibroadcast{ins->get_shape().lens()}, beta);
c_ins = p.insert_instruction(ins, op::mul{}, c_ins, beta_broadcast);
auto beta_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta);
c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
}
p.replace_instruction(ins, op::add{}, dot_ins, c_ins);
p.replace_instruction(ins, make_op("add"), dot_ins, c_ins);
}
};
......
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/load.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_config.hpp>
namespace migraphx {
......@@ -33,7 +36,8 @@ void eliminate_allocation::apply(module& p) const
auto ins = pp.first;
auto s = ins->get_shape();
auto offset = pp.second;
p.replace_instruction(ins, op::load{s, offset}, mem);
p.replace_instruction(
ins, make_op("load", {{"shape", to_value(s)}, {"offset", offset}}), mem);
}
}
}
......
......@@ -6,6 +6,8 @@
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
......@@ -77,7 +79,7 @@ void eliminate_concat::apply(module& p) const
}
std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraphx::op::identity{}, args);
p.replace_instruction(ins, migraphx::make_op("identity"), args);
}
}
}
......
......@@ -33,6 +33,7 @@ struct lstm
return pack(f(self.hidden_size, "hidden_size"),
f(self.actv_funcs, "actv_func"),
f(self.direction, "direction"),
f(self.clip, "clip"),
f(self.input_forget, "input_forget"));
}
......
......@@ -846,6 +846,9 @@ bool has_finalize(const T& x)
return detail::has_finalize_op(x);
}
void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op);
#endif
} // namespace MIGRAPHX_INLINE_NS
......
This diff is collapsed.
#include <migraphx/operation.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Serialize an operation into a value object.
// "name" holds the operator's identifying name (the key later used to
// reconstruct the op via make_op); "operator" holds the op's attribute
// payload produced by to_value().
// NOTE(review): entries are inserted name-then-operator; presumably
// migraphx::value preserves insertion order — confirm if order matters
// to the serialized format.
void migraphx_to_value(value& v, const operation& op)
{
// Record the identifier first, then the serialized attributes.
v["name"] = op.name();
v["operator"] = op.to_value();
}
// Deserialize an operation from a value object previously produced by
// migraphx_to_value: look the op up by its recorded "name" and rebuild
// it from the attributes stored under "operator".
// Throws (via value::at) if either key is missing.
void migraphx_from_value(const value& v, operation& op)
{
    const auto op_name = v.at("name").to<std::string>();
    op                 = make_op(op_name, v.at("operator"));
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/op/load.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace migraphx {
......@@ -208,7 +211,9 @@ void memory_coloring_impl::rewrite()
if(is_allocate(ins))
{
p_program->replace_instruction(
ins, op::load{ins->get_shape(), offset}, scratch_param);
ins,
make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", offset}}),
scratch_param);
}
}
}
......
......@@ -15,6 +15,8 @@
#include <algorithm>
#include <set>
#include <utility>
#include <migraphx/make_op.hpp>
#include <unordered_set>
namespace migraphx {
......@@ -212,7 +214,7 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
if(ins == std::prev(this->end()))
{
return replace_instruction(ins, op::identity{}, rep);
return replace_instruction(ins, make_op("identity"), rep);
}
// TODO: Should it be an error if the output is empty?
......
......@@ -19,6 +19,10 @@
#include <utility>
#include <set>
#include <iomanip>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <fstream>
#include <algorithm>
......@@ -58,12 +62,14 @@ instruction_ref insert_quant_ins(module& modl,
auto float_ins = scaled_ins;
if(scaled_ins->get_shape().type() != shape::float_type)
{
float_ins =
modl.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
float_ins = modl.insert_instruction(
insert_loc,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
scaled_ins);
}
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = modl.add_literal(literal(float_ins->get_shape(), vec_scale));
scaled_ins = modl.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
scaled_ins = modl.insert_instruction(insert_loc, make_op("mul"), l_scale, float_ins);
}
auto shifted_ins = scaled_ins;
......@@ -73,26 +79,32 @@ instruction_ref insert_quant_ins(module& modl,
if(shifted_ins->get_shape().type() != shape::float_type)
{
float_ins = modl.insert_instruction(
insert_loc, op::convert{shape::float_type}, shifted_ins);
insert_loc,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
shifted_ins);
}
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = modl.add_literal(literal(float_ins->get_shape(), vec_shift));
shifted_ins = modl.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
shifted_ins = modl.insert_instruction(insert_loc, make_op("add"), l_shift, float_ins);
}
auto rounded_ins = modl.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto rounded_ins = modl.insert_instruction(insert_loc, make_op("round"), shifted_ins);
auto rounded_lens = rounded_ins->get_shape().lens();
auto max_clip = modl.add_literal(127.0f);
auto min_clip = modl.add_literal(-128.0f);
max_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, max_clip);
min_clip = modl.insert_instruction(insert_loc, op::multibroadcast{rounded_lens}, min_clip);
max_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), max_clip);
min_clip = modl.insert_instruction(
insert_loc, make_op("multibroadcast", {{"output_lens", rounded_lens}}), min_clip);
auto clipped_ins =
modl.insert_instruction(insert_loc, op::clip{}, rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
modl.insert_instruction(insert_loc, make_op("clip"), rounded_ins, min_clip, max_clip);
quant_ins = modl.insert_instruction(
insert_loc, make_op("convert", {{"target_type", type}}), clipped_ins);
}
else
{
quant_ins = modl.insert_instruction(insert_loc, op::convert{type}, ins);
quant_ins =
modl.insert_instruction(insert_loc, make_op("convert", {{"target_type", type}}), ins);
}
map_ins[ins] = quant_ins;
......@@ -161,9 +173,9 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
if(ins_shape.type() != orig_type)
{
// check the dead code case to avoid assert
bool output_empty = ins->outputs().empty();
auto ins_orig_type =
mm->insert_instruction(std::next(ins), op::convert{orig_type}, ins);
bool output_empty = ins->outputs().empty();
auto ins_orig_type = mm->insert_instruction(
std::next(ins), make_op("convert", {{"target_type", orig_type}}), ins);
if(!output_empty)
{
mm->replace_instruction(ins, ins_orig_type);
......@@ -197,13 +209,18 @@ static void ins_quantize_int8(module& modl,
if(shape::int32_type == orig_type)
{
modl.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
}
else
{
auto quant_dot = modl.insert_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
modl.replace_instruction(ins, op::convert{orig_type}, quant_dot);
ins,
make_op("quant_dot", {{"alpha", quant_alpha}, {"beta", quant_beta}}),
converted_inputs);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), quant_dot);
}
}
// either alpha or beta cannot be quantized because of too big
......@@ -214,8 +231,10 @@ static void ins_quantize_int8(module& modl,
{
converted_inputs.pop_back();
}
auto q_dot = modl.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = modl.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto q_dot = modl.insert_instruction(
ins, make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), converted_inputs);
auto f_dot = modl.insert_instruction(
ins, make_op("convert", {{"target_type", to_value(shape::float_type)}}), q_dot);
auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha =
......@@ -223,42 +242,46 @@ static void ins_quantize_int8(module& modl,
if(inputs.size() == 3 and dot_op.beta != 0.0f)
{
auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta =
modl.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
instruction_ref beta_c{};
if(orig_type != shape::float_type)
{
auto fp32_c =
modl.insert_instruction(ins, op::convert{shape::float_type}, inputs.back());
beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, fp32_c);
auto fp32_c = modl.insert_instruction(
ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
inputs.back());
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, fp32_c);
}
else
{
beta_c = modl.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
beta_c = modl.insert_instruction(ins, make_op("mul"), l_beta, inputs.back());
}
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
modl.replace_instruction(ins, make_op("add"), alpha_ab, beta_c);
}
else
{
auto f_res = modl.insert_instruction(ins, op::add{}, alpha_ab, beta_c);
modl.replace_instruction(ins, op::convert{orig_type}, f_res);
auto f_res = modl.insert_instruction(ins, make_op("add"), alpha_ab, beta_c);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), f_res);
}
}
else
{
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, op::mul{}, l_alpha, f_dot);
modl.replace_instruction(ins, make_op("mul"), l_alpha, f_dot);
}
else
{
auto alpha_ab = modl.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
modl.replace_instruction(ins, op::convert{orig_type}, alpha_ab);
auto alpha_ab = modl.insert_instruction(ins, make_op("mul"), l_alpha, f_dot);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), alpha_ab);
}
}
}
......@@ -285,23 +308,27 @@ static void ins_quantize_int8(module& modl,
{
auto l_factor = modl.add_literal(
literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
modl.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
modl.replace_instruction(ins, make_op("mul"), quant_conv, l_factor);
}
// convert quant_conv output to float type, multiply the factor and
// convert back to the original type
else
{
auto float_conv =
modl.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto float_conv = modl.insert_instruction(
ins,
make_op("convert", {{"target_type", to_value(shape::float_type)}}),
quant_conv);
auto l_factor = modl.add_literal(literal(float_conv->get_shape(), vec_factor));
if(orig_type == shape::float_type)
{
modl.replace_instruction(ins, op::mul{}, l_factor, float_conv);
modl.replace_instruction(ins, make_op("mul"), l_factor, float_conv);
}
else
{
auto adjusted_conv = modl.insert_instruction(ins, op::mul{}, l_factor, float_conv);
modl.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
auto adjusted_conv =
modl.insert_instruction(ins, make_op("mul"), l_factor, float_conv);
modl.replace_instruction(
ins, make_op("convert", {{"target_type", to_value(orig_type)}}), adjusted_conv);
}
}
}
......
......@@ -7,6 +7,8 @@
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
......@@ -46,10 +48,10 @@ void rewrite_batchnorm::apply(module& p) const
auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = p.add_literal({a.get_shape(), a.data()});
auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto mul = p.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
auto add = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
auto add = p.insert_instruction(ins, make_op("add"), mul, b_broadcast);
p.replace_instruction(ins, add);
}
}
......
......@@ -5,6 +5,8 @@
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
......@@ -31,25 +33,26 @@ void rewrite_pooling::apply(module& prog) const
continue;
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
auto reshape =
prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
auto reshape = prog.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{};
// average pooling
if(op.mode == "average")
{
pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
pooling =
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
}
// max pooling
else
{
pooling = prog.insert_instruction(ins, op::reduce_max{{1}}, reshape);
pooling = prog.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
}
std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n;
rsp_lens[1] = c;
prog.replace_instruction(ins, op::reshape{rsp_lens}, pooling);
prog.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
}
}
......
This diff is collapsed.
#include <migraphx/schedule.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
......@@ -12,6 +11,8 @@
#include <queue>
#include <thread>
#include <mutex>
#include <migraphx/make_op.hpp>
#include <set>
#include <deque>
#include <chrono>
......@@ -556,7 +557,7 @@ void schedule::apply(module& p) const
std::vector<instruction_ref> args;
args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end());
p.insert_instruction(std::next(ip.first), op::identity{}, args);
p.insert_instruction(std::next(ip.first), make_op("identity"), args);
}
}
......
......@@ -16,6 +16,10 @@
#include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/algorithm.hpp>
namespace migraphx {
......@@ -62,8 +66,10 @@ struct find_mul_conv
return;
auto new_a = p.insert_instruction(
ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
ins,
make_op("broadcast", {{"axis", 0}, {"dims", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
p.replace_instruction(ins, new_conv);
......@@ -125,20 +131,27 @@ struct find_mul_slice_conv
auto slice_w_ins = p.insert_instruction(ins, w_slice_op, w_ins);
auto new_a = p.insert_instruction(
ins, op::broadcast{0, slice_w_ins->get_shape().lens()}, a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, op::mul{}, new_a, slice_w_ins);
ins,
make_op("broadcast", {{"axis", 0}, {"dims", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
std::vector<instruction_ref> sliced_weights;
if(slice_op.starts.front() != 0)
sliced_weights.push_back(
p.insert_instruction(ins, op::slice{{0}, {0}, slice_op.starts}, w_ins));
sliced_weights.push_back(p.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", slice_op.starts}}),
w_ins));
sliced_weights.push_back(new_mul);
int64_t end_axis = w_ins->get_shape().lens().at(0);
if(slice_op.ends.front() != end_axis)
sliced_weights.push_back(
p.insert_instruction(ins, op::slice{{0}, {slice_op.ends}, {end_axis}}, w_ins));
sliced_weights.push_back(p.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", slice_op.ends}, {"ends", {end_axis}}}),
w_ins));
auto new_weights = p.insert_instruction(ins, op::concat{0}, sliced_weights);
auto new_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights);
auto new_conv = p.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);
......@@ -177,9 +190,9 @@ struct find_mul_add
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins);
auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins);
p.replace_instruction(ins, op::add{}, ax_ins, ab_ins);
auto ax_ins = p.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = p.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
p.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
......@@ -198,8 +211,8 @@ struct find_add_lit_broadcast
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
p.replace_instruction(ins, op::add{}, x_ins, sumab);
auto sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins);
p.replace_instruction(ins, make_op("add"), x_ins, sumab);
}
};
......@@ -225,18 +238,18 @@ struct find_double_add_lit_broadcast
{
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return;
auto op = a_ins->get_operator();
auto presum =
p.insert_instruction(ins, op::add{}, a_ins->inputs().at(0), b_ins->inputs().at(0));
auto op = a_ins->get_operator();
auto presum = p.insert_instruction(
ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = p.insert_instruction(ins, op, presum);
}
else
{
sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins);
}
auto sumxy = p.insert_instruction(ins, op::add{}, x_ins, y_ins);
p.replace_instruction(ins, op::add{}, sumxy, sumab);
auto sumxy = p.insert_instruction(ins, make_op("add"), x_ins, y_ins);
p.replace_instruction(ins, make_op("add"), sumxy, sumab);
}
};
......@@ -327,7 +340,8 @@ struct find_concat_op
std::transform(start, last, std::back_inserter(inputs), [&](auto j) {
return j->inputs().at(i);
});
auto concat = p.insert_instruction(ins, op::concat{iaxis}, inputs);
auto concat =
p.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs);
concats.push_back(concat);
}
auto y = p.insert_instruction(ins, op, concats);
......@@ -349,7 +363,7 @@ struct find_concat_op
if(args.size() == 1)
p.replace_instruction(ins, args.front());
else
p.replace_instruction(ins, op::concat{axis}, args);
p.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args);
}
};
......@@ -482,7 +496,8 @@ struct find_splits
return;
auto concat_axis = slice_op.axes.front();
// TODO: Check if axes match
auto concat = p.insert_instruction(ins, op::concat{concat_axis}, data_args);
auto concat = p.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args);
std::vector<instruction_ref> args;
args.resize(2);
......@@ -501,7 +516,8 @@ struct find_splits
{
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
continue;
auto x = p.insert_instruction(output, op::contiguous{}, output->inputs());
auto x =
p.insert_instruction(output, make_op("contiguous"), output->inputs());
p.replace_instruction(output, output->get_operator(), x);
}
......@@ -648,7 +664,11 @@ struct find_add_convs
return;
new_op = a_op;
b_input = p.insert_instruction(
ins, op::as_shape{compute_stride_shape(b_input->get_shape(), n)}, b_input);
ins,
make_op(
"as_shape",
{{"shape", to_value(compute_stride_shape(b_input->get_shape(), n))}}),
b_input);
}
else if(b_op.stride < a_op.stride)
{
......@@ -657,7 +677,11 @@ struct find_add_convs
return;
new_op = b_op;
a_input = p.insert_instruction(
ins, op::as_shape{compute_stride_shape(a_input->get_shape(), n)}, a_input);
ins,
make_op(
"as_shape",
{{"shape", to_value(compute_stride_shape(a_input->get_shape(), n))}}),
a_input);
}
else
return;
......@@ -666,8 +690,10 @@ struct find_add_convs
return;
}
auto concat_input = p.insert_instruction(ins, op::concat{1}, a_input, b_input);
auto concat_weights = p.insert_instruction(ins, op::concat{1}, a_weights, b_weights);
auto concat_input =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input);
auto concat_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights);
p.replace_instruction(ins, new_op, concat_input, concat_weights);
}
};
......@@ -739,13 +765,18 @@ struct find_conv_dot_horiz_fusion
for(auto arg : args)
p.move_instructions(arg, input);
// TODO: Check if axes match
auto concat = p.insert_instruction(input, op::concat{concat_axis}, args);
auto concat =
p.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = p.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0;
for(auto arg : range(start, last))
{
int64_t len = arg->get_shape().lens()[axis];
p.replace_instruction(arg, op::slice{{axis}, {offset}, {offset + len}}, fused);
p.replace_instruction(
arg,
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
fused);
offset += len;
}
};
......@@ -767,11 +798,11 @@ struct find_div_const
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto recip = p.insert_instruction(std::next(c_ins), op::recip{}, c_ins);
auto recip = p.insert_instruction(std::next(c_ins), make_op("recip"), c_ins);
auto args = ins->inputs();
p.replace_instruction(ins, op::mul{}, args.front(), recip);
p.replace_instruction(ins, make_op("mul"), args.front(), recip);
}
};
......@@ -787,11 +818,11 @@ struct find_sub_const
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto neg = p.insert_instruction(std::next(c_ins), op::neg{}, c_ins);
auto neg = p.insert_instruction(std::next(c_ins), make_op("neg"), c_ins);
auto args = ins->inputs();
p.replace_instruction(ins, op::add{}, args.front(), neg);
p.replace_instruction(ins, make_op("add"), args.front(), neg);
}
};
......@@ -808,7 +839,7 @@ struct find_rsqrt
auto ins = r.result;
auto x_ins = r.instructions["x"];
p.replace_instruction(ins, op::rsqrt{}, x_ins);
p.replace_instruction(ins, make_op("rsqrt"), x_ins);
}
};
......@@ -883,14 +914,19 @@ struct find_split_reshape
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction
auto rsp_ins = p.insert_instruction(std::next(input), op::reshape{rsp_out_lens}, input);
auto rsp_ins = p.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace the original reshape with slice
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{
p.replace_instruction(
vec_rsp[i], op::slice{{rsp_axis}, {start}, {start + vec_dims[i]}}, rsp_ins);
vec_rsp[i],
make_op(
"slice",
{{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}),
rsp_ins);
start += vec_dims[i];
}
}
......@@ -930,7 +966,8 @@ struct find_split_transpose
}
// insert a transpose instruction
auto tr = p.insert_instruction(std::next(input), op::transpose{perm}, input);
auto tr =
p.insert_instruction(std::next(input), make_op("transpose", {{"dims", perm}}), input);
// compute the axis in the slice
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
......@@ -944,7 +981,10 @@ struct find_split_transpose
auto starts = oper.starts;
auto ends = oper.ends;
auto tr_orig = in->outputs().front();
p.replace_instruction(tr_orig, op::slice{{axis_new}, starts, ends}, tr);
p.replace_instruction(
tr_orig,
make_op("slice", {{"axes", {axis_new}}, {"starts", starts}, {"ends", ends}}),
tr);
}
}
};
......
......@@ -11,6 +11,8 @@
#include <migraphx/permutation.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <unordered_set>
#include <migraphx/make_op.hpp>
#include <map>
namespace migraphx {
......@@ -149,7 +151,7 @@ struct find_transpose
}
else
{
p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
p.replace_instruction(ins, make_op("transpose", {{"dims", dims}}), t->inputs().front());
}
}
};
......@@ -257,10 +259,10 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(ins, op::transpose{permutation}, i);
return p.insert_instruction(ins, make_op("transpose", {{"dims", permutation}}), i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
auto t = p.insert_instruction(ins, make_op("transpose", {{"dims", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t);
}
......
......@@ -17,6 +17,8 @@
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pad_calc.hpp>
namespace migraphx {
......@@ -49,20 +51,20 @@ struct tf_parser
instruction_ref to_nhwc(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins);
return ins;
}
instruction_ref to_nchw(instruction_ref ins) const
{
if(should_transpose(ins))
return mm->add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins);
return ins;
}
instruction_ref to_kcxy(instruction_ref ins) const
{
return mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins);
}
instruction_ref make_contiguous(instruction_ref ins) const
......@@ -70,7 +72,7 @@ struct tf_parser
if(ins->get_shape().standard())
return ins;
else
return mm->add_instruction(op::contiguous{}, ins);
return mm->add_instruction(make_op("contiguous"), ins);
}
std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
......@@ -176,20 +178,20 @@ struct tf_parser
tf_parser()
{
add_generic_op("All", op::identity{});
add_generic_op("Identity", op::identity{});
add_generic_op("LessEqual", op::identity{});
add_generic_op("Relu", op::relu{});
add_generic_op("Rsqrt", op::rsqrt{});
add_generic_op("Tanh", op::tanh{});
add_generic_op("StopGradient", op::identity{});
add_binary_op("Add", op::add{});
add_binary_op("AddV2", op::add{});
add_binary_op("Mul", op::mul{});
add_binary_op("Pow", op::pow{});
add_binary_op("SquaredDifference", op::sqdiff{});
add_binary_op("Sub", op::sub{});
add_generic_op("All", make_op("identity"));
add_generic_op("Identity", make_op("identity"));
add_generic_op("LessEqual", make_op("identity"));
add_generic_op("Relu", make_op("relu"));
add_generic_op("Rsqrt", make_op("rsqrt"));
add_generic_op("Tanh", make_op("tanh"));
add_generic_op("StopGradient", make_op("identity"));
add_binary_op("Add", make_op("add"));
add_binary_op("AddV2", make_op("add"));
add_binary_op("Mul", make_op("mul"));
add_binary_op("Pow", make_op("pow"));
add_binary_op("SquaredDifference", make_op("sqdiff"));
add_binary_op("Sub", make_op("sub"));
add_mem_op("ArgMax", &tf_parser::parse_arg_op<op::argmax>, false);
add_mem_op("ArgMin", &tf_parser::parse_arg_op<op::argmin>, false);
......@@ -310,8 +312,10 @@ struct tf_parser
output_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
auto l0 = mm->add_instruction(op::multibroadcast{output_lens}, arg0);
auto l1 = mm->add_instruction(op::multibroadcast{output_lens}, arg1);
auto l0 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}),
arg0);
auto l1 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}),
arg1);
return to_nhwc(mm->add_instruction(x, to_nchw(l0), to_nchw(l1)));
}
else
......@@ -337,7 +341,7 @@ struct tf_parser
int64_t axis = 0;
axis = args[1]->eval().at<int64_t>();
auto ins = mm->add_instruction(Op{axis}, args.front());
return mm->add_instruction(op::squeeze{{axis}}, ins);
return mm->add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
}
instruction_ref parse_batchnorm(const std::string&,
......@@ -359,8 +363,9 @@ struct tf_parser
parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
{
uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
auto l0 = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]);
return mm->add_instruction(op::add{}, args[0], l0);
auto l0 = mm->add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]);
return mm->add_instruction(make_op("add"), args[0], l0);
}
instruction_ref parse_cast(const std::string&,
......@@ -368,7 +373,7 @@ struct tf_parser
std::vector<instruction_ref> args) const
{
shape::type_t type = parse_type(attributes.at("DstT").type());
return mm->add_instruction(op::convert{type}, std::move(args));
return mm->add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
}
instruction_ref parse_concat(const std::string&,
......@@ -442,7 +447,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
l0 = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
}
else
{
......@@ -528,7 +533,7 @@ struct tf_parser
if(pads[0] != pads[2] || pads[1] != pads[3])
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
l0 = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
}
else
{
......@@ -553,8 +558,8 @@ struct tf_parser
new_weights_shape[0] = out_channels;
new_weights_shape[1] = 1;
// Make sure weights are contiguous before doing reshape
auto new_weights =
mm->add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
auto new_weights = mm->add_instruction(make_op("reshape", {{"dims", new_weights_shape}}),
make_contiguous(weights));
return mm->add_instruction(op, {l0, new_weights});
}
......@@ -576,7 +581,7 @@ struct tf_parser
{
new_dims.insert(new_dims.begin() + dim, 1);
}
return mm->add_instruction(op::reshape{new_dims}, args[0]);
return mm->add_instruction(make_op("reshape", {{"dims", new_dims}}), args[0]);
}
instruction_ref
......@@ -617,10 +622,12 @@ struct tf_parser
// swap the last two elements
std::iter_swap(perm.end() - 1, perm.end() - 2);
auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0];
auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1];
auto l1 = (transa) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
: args[0];
auto l2 = (transb) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
return mm->add_instruction(op::dot{}, l1, l2);
return mm->add_instruction(make_op("dot"), l1, l2);
}
instruction_ref parse_mean(const std::string&,
......@@ -632,12 +639,12 @@ struct tf_parser
if(keep_dims)
{
return mm->add_instruction(op::reduce_mean{axes}, args[0]);
return mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
}
else
{
auto ins = mm->add_instruction(op::reduce_mean{axes}, args[0]);
return mm->add_instruction(op::squeeze{axes}, ins);
auto ins = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
return mm->add_instruction(make_op("squeeze", {{"axes", axes}}), ins);
}
}
......@@ -663,7 +670,7 @@ struct tf_parser
{
shape s{shape::float_type, {depth, depth}};
auto l0 = mm->add_literal({s, depth_input});
return mm->add_instruction(op::gather{0}, {l0, args[0]});
return mm->add_instruction(make_op("gather", {{"axis", 0}}), {l0, args[0]});
}
MIGRAPHX_THROW("MIGraphX does not support axis != -1");
}
......@@ -688,8 +695,10 @@ struct tf_parser
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return mm->add_instruction(op::unsqueeze{{axis}}, arg); });
return to_nhwc(mm->add_instruction(op::concat{axis}, unsqueezed_args));
[&](instruction_ref arg) {
return mm->add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg);
});
return to_nhwc(mm->add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args));
}
instruction_ref
......@@ -765,7 +774,10 @@ struct tf_parser
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = mm->add_instruction(
migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
migraphx::make_op(
"pad",
{{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
l0);
}
else
{
......@@ -784,9 +796,11 @@ struct tf_parser
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(op::multibroadcast{input_lens}, min_val);
max_val = mm->add_instruction(op::multibroadcast{input_lens}, max_val);
return mm->add_instruction(op::clip{}, args.front(), min_val, max_val);
min_val =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
max_val =
mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
return mm->add_instruction(make_op("clip"), args.front(), min_val, max_val);
}
instruction_ref
......@@ -884,7 +898,8 @@ struct tf_parser
assert(num_outputs > 0);
if(num_outputs == 1)
return std::vector<instruction_ref>{mm->add_instruction(op::identity{}, input_arg)};
return std::vector<instruction_ref>{
mm->add_instruction(make_op("identity"), input_arg)};
auto lens = input_arg->get_shape().lens();
auto num_dims = lens.size();
......@@ -1012,7 +1027,7 @@ struct tf_parser
squeeze_axes.push_back(i);
}
return mm->add_instruction(op::squeeze{squeeze_axes}, l1);
return mm->add_instruction(make_op("squeeze", {{"axes", squeeze_axes}}), l1);
}
instruction_ref parse_transpose(const std::string&,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment