Commit 5ec8f913 authored by Ted Themistokleous, committed by Ted Themistokleous

Merge branch 'develop' into simplify_1_mul_div_ops

parents 32d69e8e d78bcdfb
@@ -37,6 +37,7 @@
 #include <migraphx/output_iterator.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/marker.hpp>
+#include <migraphx/supported_segments.hpp>
 #include <iostream>
 #include <sstream>
 #include <algorithm>
@@ -77,11 +78,11 @@ program& program::operator=(program p)
 void program::assign(const program& p)
 {
-    if(!impl)
+    if(not impl)
     {
         impl = std::make_unique<program_impl>();
     }
-    else if(!impl->modules.empty())
+    else if(not impl->modules.empty())
     {
         impl->modules.clear();
     }
@@ -167,13 +168,37 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
     target_assignments p;
     const auto* mod = get_main_module();
-    for(auto it : iterator_for(*mod))
+    std::vector<std::pair<target, supported_segments>> target_subgraphs;
+    target_subgraphs.reserve(targets.size());
+    std::transform(targets.begin(),
+                   targets.end(),
+                   std::back_inserter(target_subgraphs),
+                   [&](const auto& t) { return std::make_pair(t, t.find_supported(mod, m)); });
+
+    for(const auto ins : iterator_for(*mod))
     {
-        auto t = std::max_element(
-            targets.begin(), targets.end(), [it, m](const target& lhs, const target& rhs) {
-                return lhs.is_supported(it, m) < rhs.is_supported(it, m);
-            });
-        p.add_assignment(it, t->name());
+        if(contains(p, ins))
+        {
+            continue;
+        }
+
+        for(const auto& [target, subgraph] : target_subgraphs)
+        {
+            // can't pass a structured binding into lambda in C++17 so create a variable for it
+            const auto& t = target;
+            for(const auto& segment : subgraph)
+            {
+                const auto& instructions = segment.instructions;
+                if(not contains(instructions, ins))
+                {
+                    continue;
+                }
+                std::transform(instructions.begin(),
+                               instructions.end(),
+                               std::inserter(p, p.end()),
+                               [&](auto instr) { return std::make_pair(instr, t.name()); });
+            }
+        }
     }
     return p;
 }
...
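The loop above asks every target up front which segments of the graph it supports, then hands each instruction to the first target (in the order given) whose reported segment contains it. A minimal sketch of what supported_segments presumably looks like, inferred only from its use here and in the target sources further down; the container choices and stand-in types are assumptions, not the contents of the real header:

    // Hypothetical sketch of <migraphx/supported_segments.hpp>, inferred from
    // segment.instructions, instrs.instructions.insert(ins) and instrs.metric.
    #include <unordered_set>
    #include <vector>

    struct instruction;                   // stand-in; MIGraphX uses a module iterator
    using instruction_ref = instruction*;

    struct supported_segment
    {
        std::unordered_set<instruction_ref> instructions; // instructions the target can run
        float metric = 0;                                 // how strongly it claims them
    };
    using supported_segments = std::vector<supported_segment>;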
@@ -40,6 +40,7 @@
 #include <migraphx/register_target.hpp>
 #include <migraphx/json.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/op/common.hpp>
 #ifdef HAVE_GPU
 #include <migraphx/gpu/hip.hpp>
@@ -82,7 +83,7 @@ void visit_py(T x, F f)
     {
         f(x.template cast<bool>());
     }
-    else if(py::isinstance<py::int_>(x))
+    else if(py::isinstance<py::int_>(x) or py::hasattr(x, "__index__"))
     {
         f(x.template cast<int>());
     }
@@ -324,6 +325,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         .def("get_parameter_names", &migraphx::program::get_parameter_names)
         .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
         .def("get_output_shapes", &migraphx::program::get_output_shapes)
+        .def("is_compiled", &migraphx::program::is_compiled)
         .def(
             "compile",
             [](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {
@@ -358,18 +360,35 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         .def("__ne__", std::not_equal_to<migraphx::program>{})
         .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });

-    py::class_<migraphx::operation>(m, "op")
-        .def(py::init([](const std::string& name, py::kwargs kwargs) {
+    py::class_<migraphx::operation> op(m, "op");
+    op.def(py::init([](const std::string& name, py::kwargs kwargs) {
              migraphx::value v = migraphx::value::object{};
              if(kwargs)
              {
                  v = migraphx::to_value(kwargs);
              }
              return migraphx::make_op(name, v);
          }))
         .def("name", &migraphx::operation::name);

+    py::enum_<migraphx::op::pooling_mode>(op, "pooling_mode")
+        .value("average", migraphx::op::pooling_mode::average)
+        .value("max", migraphx::op::pooling_mode::max)
+        .value("lpnorm", migraphx::op::pooling_mode::lpnorm);
+
+    py::enum_<migraphx::op::rnn_direction>(op, "rnn_direction")
+        .value("forward", migraphx::op::rnn_direction::forward)
+        .value("reverse", migraphx::op::rnn_direction::reverse)
+        .value("bidirectional", migraphx::op::rnn_direction::bidirectional);
+
+    m.def(
+        "argument_from_pointer",
+        [](const migraphx::shape shape, const int64_t address) {
+            return migraphx::argument(shape, reinterpret_cast<void*>(address));
+        },
+        py::arg("shape"),
+        py::arg("address"));
+
     m.def(
         "parse_tf",
         [](const std::string& filename,
...
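argument_from_pointer is a thin wrapper over the argument constructor used in the lambda above: it wraps memory at a caller-supplied address, presumably without copying or taking ownership. A hedged C++ sketch of the same idea; the buffer and shape here are illustrative, not part of the change:

    #include <cstdint>
    #include <vector>
    #include <migraphx/argument.hpp>

    // Wrap a caller-owned buffer in an argument, as the new binding does.
    // Since the argument presumably does not own the memory, the buffer
    // must outlive it.
    migraphx::argument wrap_buffer_example()
    {
        static std::vector<float> buffer(16, 0.0f);
        migraphx::shape s{migraphx::shape::float_type, {4, 4}};
        auto address = reinterpret_cast<int64_t>(buffer.data());
        return migraphx::argument(s, reinterpret_cast<void*>(address));
    }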
@@ -70,7 +70,7 @@ void quantize_int8(program& prog,
 {
     std::set<std::string> op_names = {"convolution", "dot"};
     std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
-    if(!std::includes(
+    if(not std::includes(
            op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
     {
         MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct find_gelu_erf
{
auto matcher() const { return match::gelu_erf(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
if(x->get_shape().type() != migraphx::shape::half_type)
return;
auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}});
auto mul = insert_common_op(m, ins, make_op("mul"), {x, lit});
auto sig = m.insert_instruction(ins, make_op("neg"), mul);
sig = m.insert_instruction(ins, make_op("exp"), sig);
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
sig = insert_common_op(m, ins, make_op("add"), {sig, one});
sig = m.insert_instruction(ins, make_op("div"), x, sig);
m.replace_instruction(ins, sig);
}
};
void rewrite_gelu::apply(module& m) const { match::find_matches(m, find_gelu_erf{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
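This new pass matches the erf-based GELU pattern and, for half-precision inputs only (note the early return on non-half types), replaces it with the cheaper sigmoid approximation. The inserted instructions (mul by 1.702, neg, exp, add 1, div) compute

    \operatorname{GELU}(x) = x\,\Phi(x) \approx x\,\sigma(1.702x) = \frac{x}{1 + e^{-1.702x}}

where \Phi is the standard normal CDF and \sigma the logistic sigmoid; 1.702 is the usual constant for this approximation.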
@@ -47,12 +47,12 @@ void rewrite_pooling::apply(module& m) const
         if(not s.standard())
             continue;
         auto&& op = any_cast<op::pooling>(ins->get_operator());
-        if(!std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
+        if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
             continue;
-        if(!std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
+        if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
             continue;
         auto lens = s.lens();
-        if(!std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
+        if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
             continue;
         std::int64_t n = s.lens()[0];
         std::int64_t c = s.lens()[1];
...
@@ -214,7 +214,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
         ih = m.add_literal(migraphx::literal{ih_shape, data});
     }

-    if(!is_forward and variable_seq_len)
+    if(not is_forward and variable_seq_len)
     {
         args[0] =
             m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
@@ -520,7 +520,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
         ih = m.add_literal(migraphx::literal{ih_shape, data});
     }

-    if(!is_forward and variable_seq_len)
+    if(not is_forward and variable_seq_len)
     {
         args[0] =
             m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
@@ -977,7 +977,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
         pph = args[7];
     }

-    if(!is_forward and variable_seq_len)
+    if(not is_forward and variable_seq_len)
     {
         args[0] =
             m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
@@ -1294,11 +1294,11 @@ bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens
     std::vector<int64_t> vec_lens;
     arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
     int64_t l = 0;
-    if(!vec_lens.empty())
+    if(not vec_lens.empty())
     {
         l = vec_lens[0];
     }
-    if(!std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
+    if(not std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
     {
         is_var_lens = true;
     }
@@ -1318,7 +1318,7 @@ rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref
     bool is_var_lens = is_variable_seq_lens(m, seq_lens);
     auto input_shape = input->get_shape();
     auto length      = input_shape.lens()[0];
-    if(!is_var_lens and seq_lens != m.end())
+    if(not is_var_lens and seq_lens != m.end())
     {
         auto arg_len = seq_lens->eval();
         std::vector<std::size_t> vec_lens;
@@ -1387,7 +1387,7 @@ void rewrite_rnn::replace_last_cell_output(module& m,
     if(variable_seq_len)
     {
-        if(!ins_outputs.empty())
+        if(not ins_outputs.empty())
         {
             cell_outputs = m.insert_instruction(
                 std::next(ins),
...
@@ -477,7 +477,7 @@ bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
 bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
 {
-    return !(x == y);
+    return not(x == y);
 }
 std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
 {
@@ -497,7 +497,7 @@ bool operator==(const shape& x, const shape& y)
            x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
 }
-bool operator!=(const shape& x, const shape& y) { return !(x == y); }
+bool operator!=(const shape& x, const shape& y) { return not(x == y); }
 std::ostream& operator<<(std::ostream& os, const shape& x)
 {
...
@@ -208,6 +208,42 @@ struct find_mul_add
     }
 };

+struct find_dot_add
+{
+    auto matcher() const
+    {
+        return match::name("dot")(match::either_arg(0, 1)(
+            match::name("add")(
+                match::either_arg(0, 1)(match::any().bind("x"),
+                                        match::any_of(match::is_constant()).bind("b")),
+                match::none_of(match::args(match::is_constant(), match::is_constant())),
+                match::used_once()),
+            match::is_constant().bind("a")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins   = r.result;
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+        auto x_ins = r.instructions["x"];
+        assert(x_ins != b_ins);
+
+        const bool flipped = a_ins == ins->inputs().back();
+
+        auto insert_dot = [&](auto x, auto y) {
+            if(flipped)
+                return m.insert_instruction(ins, make_op("dot"), y, x);
+            else
+                return m.insert_instruction(ins, make_op("dot"), x, y);
+        };
+        auto ax_ins = insert_dot(a_ins, x_ins);
+        auto ab_ins = insert_dot(a_ins, b_ins);
+        m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
+    }
+};
+
 struct find_add_lit_broadcast
 {
     auto matcher() const
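find_dot_add fires when one dot operand (bound to "a") is constant and the other is an add of a non-constant "x" with a constant "b"; the rewrite is the plain distributive law, applied in either order depending on which side the constant sits:

    A(X + B) = AX + AB \qquad\text{or}\qquad (X + B)A = XA + BA

Since a and b are both constants, the dot between them is itself constant, so a later constant-folding pass can presumably collapse it, leaving one dot plus a cheap add at runtime.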
@@ -267,28 +303,26 @@ struct find_double_add_lit_broadcast
 struct find_inner_broadcast
 {
-    auto matcher() const
-    {
-        return pointwise(
-            match::nargs(2),
-            match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
-    }
+    auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }

     void apply(module& m, const match::matcher_result& r) const
     {
-        auto ins   = r.result;
-        auto x_ins = r.instructions["x"];
-        auto y_ins = r.instructions["y"];
-
-        auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
-        auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
-
-        if(xbroadcast.axis != ybroadcast.axis)
+        auto ins        = r.result;
+        auto broadcasts = ins->inputs();
+        if(broadcasts.empty())
+            return;
+        std::vector<instruction_ref> inputs;
+        std::transform(broadcasts.begin(),
+                       broadcasts.end(),
+                       std::back_inserter(inputs),
+                       [](auto i) { return i->inputs().front(); });
+        if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
+               return i->get_shape() != inputs.front()->get_shape();
+           }))
             return;
-        auto op = m.insert_instruction(
-            ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
-        m.replace_instruction(ins, xbroadcast, op);
+        auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
+        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
     }
 };
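The generalized pass now handles any pointwise instruction all of whose inputs are broadcasts of same-shaped values, not just the two-argument case with matching axes. Sketched as before/after IR with illustrative shapes:

    // before: the pointwise op runs on the broadcasted (large) shape
    xb = broadcast(x)   // {64} -> {2, 64, 8}
    yb = broadcast(y)   // {64} -> {2, 64, 8}
    r  = add(xb, yb)    // elementwise over {2, 64, 8}

    // after: the op runs on the small shape and is broadcast once
    t = add(x, y)       // elementwise over {64}
    r = broadcast(t)    // {64} -> {2, 64, 8}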
@@ -416,8 +450,9 @@ struct find_splits
 {
     auto matcher() const
     {
-        return match::any(match::any_of[match::outputs()](match::name("slice")(
-            match::any_of[match::outputs()](match::pointwise(), reduction()))));
+        return match::any(
+            match::any_of[match::outputs()](match::name("slice")(match::any_of[match::outputs()](
+                match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction()))));
     }

     static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
@@ -580,10 +615,9 @@ struct find_splits
                 auto outputs = i->outputs();
                 for(auto output : outputs)
                 {
-                    if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
+                    if(output->name() != "reshape")
                         continue;
-                    auto x =
-                        m.insert_instruction(output, make_op("contiguous"), output->inputs());
+                    auto x = m.insert_instruction(output, make_op("contiguous"), i);
                     m.replace_instruction(output, output->get_operator(), x);
                 }
@@ -753,7 +787,7 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
     };
     auto dots  = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
     auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
-    return !(dots < 2 and convs < 2);
+    return not(dots < 2 and convs < 2);
 }

 struct find_conv_dot_horiz_fusion
@@ -773,7 +807,7 @@ struct find_conv_dot_horiz_fusion
         auto y = j->inputs()[1]->get_shape().lens();
         if(x.size() != y.size())
             return false;
-        // Check that non-axises match
+        // Check that non-axes match
         int axis = 1;
         if(i->name() == "dot")
         {
@@ -809,13 +843,22 @@ struct find_conv_dot_horiz_fusion
         for(auto arg : args)
             m.move_instructions(arg, input);

-        // TODO: Check if axises match
+        // TODO: Check if axes match
         auto concat =
             m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
         auto fused     = m.insert_instruction(std::next(input), op, input, concat);
         int64_t offset = 0;
         for(auto arg : range(start, last))
         {
+            auto outputs = arg->outputs();
+            for(auto output : outputs)
+            {
+                if(output->name() != "reshape")
+                    continue;
+                auto x = m.insert_instruction(output, make_op("contiguous"), arg);
+                m.replace_instruction(output, output->get_operator(), x);
+            }
+
             int64_t len = arg->get_shape().lens()[axis];
             m.replace_instruction(
                 arg,
@@ -993,7 +1036,7 @@ struct find_split_reshape
         // all outputs are reshape and of the same shape
         auto dims = any_cast<op::reshape>(rsp->get_operator()).dims;
-        if(!same_ops(vec_rsp))
+        if(not same_ops(vec_rsp))
         {
             return;
         }
@@ -1025,7 +1068,11 @@ struct find_split_reshape
         std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
         rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});

-        // insert the reshape instruction
+        // insert the reshape instruction and add contiguous if needed
+        if(not input->get_shape().standard())
+        {
+            input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
+        }
         auto rsp_ins = m.insert_instruction(
             std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
@@ -1072,7 +1119,7 @@ struct find_split_transpose
         // all transpose are the same
         auto perm = any_cast<op::transpose>(trans->get_operator()).dims;
-        if(!same_ops(vec_trans))
+        if(not same_ops(vec_trans))
         {
             return;
         }
@@ -1118,6 +1165,7 @@ void simplify_algebra::apply(module& m) const
                            find_unit_ops{},
                            find_neg_unit_ops{},
                            find_zero_ops{},
+                           find_dot_add{},
                            find_div_const{},
                            find_sub_const{},
                            find_rsqrt{},
...
@@ -99,7 +99,7 @@ struct find_reshaper
         std::vector<instruction_ref> reshapes{ins};
         while(is_reshaper(reshapes.back()))
         {
-            assert(!reshapes.back()->inputs().empty());
+            assert(not reshapes.back()->inputs().empty());
             assert(m.has_instruction(reshapes.back()->inputs().front()));
             auto input = reshapes.back()->inputs().front();
             reshapes.push_back(input);
@@ -151,8 +151,11 @@ struct find_transpose
 {
     auto matcher() const
     {
-        return match::name("transpose")(match::none_of(
-            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
+        auto output_not_transpose =
+            match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose")));
+        auto input_has_transpose =
+            match::args(match::skip(match::name("contiguous"))(match::name("transpose")));
+        return match::name("transpose")(output_not_transpose, input_has_transpose);
     }

     void apply(module& m, const match::matcher_result& mr) const
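The added input_has_transpose condition narrows the matcher: it now fires only when the transpose's input chain (looking through contiguous) already contains another transpose to merge with, rather than on any transpose that merely lacks a downstream one.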
@@ -285,7 +288,7 @@ struct find_concat_transpose
         auto permutation = find_permutation(s);

         // permutation should be the same for all inputs
-        if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
+        if(not std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
                return (find_permutation(in->get_shape()) == permutation);
            }))
         {
@@ -664,9 +667,94 @@ struct find_slice_transpose
     }
 };

+struct find_transpose_slice
+{
+    auto matcher() const
+    {
+        return match::name("transpose")(match::all_of[match::outputs()](match::name("slice")));
+    }
+
+    static std::vector<int64_t> slice_distance(const op::slice& op)
+    {
+        assert(op.starts.size() == op.ends.size());
+        std::vector<int64_t> result(op.starts.size());
+        std::transform(
+            op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{});
+        return result;
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins    = r.result;
+        auto slices = ins->outputs();
+        if(slices.empty())
+            return;
+        auto slice     = any_cast<op::slice>(slices.front()->get_operator());
+        auto sdistance = slice_distance(slice);
+        // Check all distances and axes are the same
+        if(std::any_of(slices.begin(), slices.end(), [&](auto sins) {
+               auto s = any_cast<op::slice>(sins->get_operator());
+               return s.axes != slice.axes or slice_distance(s) != sdistance;
+           }))
+            return;
+        // Check distances are divisible by lens of corresponding axes
+        auto mod_by_distance = [&](const auto& v, auto f) {
+            return std::inner_product(v.begin(),
+                                      v.end(),
+                                      sdistance.begin(),
+                                      0,
+                                      std::plus<>{},
+                                      [&](auto x, auto d) -> uint64_t {
+                                          if(d == 0)
+                                              return 1;
+                                          return f(x) % d;
+                                      });
+        };
+        if(mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or
+           mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0)
+            return;
+        // TODO: Handle multiple axes
+        if(sdistance.size() != 1)
+            return;
+        auto axis = slice.axes.front();
+        // Skip if axis would be packed
+        if(std::all_of(ins->get_shape().lens().begin(),
+                       ins->get_shape().lens().begin() + axis,
+                       [](auto x) { return x == 1; }))
+            return;
+        // Compute axis before transpose to use for unsqueeze
+        auto perm    = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
+        // Make unsqueeze
+        auto unsqueeze = m.insert_instruction(
+            ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
+        // Make transpose
+        std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
+            if(i > preaxis)
+                return i + 1;
+            return i;
+        });
+        perm.insert(perm.begin(), preaxis + 1);
+        auto transpose =
+            m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
+        // Slice and squeeze
+        for(auto s : slices)
+        {
+            auto op   = any_cast<op::slice>(s->get_operator());
+            op.axes   = {0};
+            op.starts = {op.starts.front() / sdistance.front()};
+            op.ends   = {op.ends.front() / sdistance.front()};
+            auto slice_ins = m.insert_instruction(ins, op, transpose);
+            auto squeeze =
+                m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins);
+            m.replace_instruction(s, squeeze);
+        }
+    }
+};
+
 void simplify_reshapes::apply(module& m) const
 {
-    for(int i = 0; i < 2; i++)
+    for(int i = 0; i < 4; i++)
     {
         match::find_matches(m,
                             find_where_op{},
@@ -679,6 +767,7 @@ void simplify_reshapes::apply(module& m) const
                             find_nested_convert{},
                             find_nested_slice{},
                             find_nested_concat{},
+                            find_transpose_slice{},
                             find_slice_transpose{},
                             find_transpose_contiguous_reshaper_unary{});
         dead_code_elimination{}.apply(m);
...
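A worked example of what find_transpose_slice does, with illustrative shapes: let x have shape {6, 2, 4}, transposed with permutation {1, 0, 2} to {2, 6, 4}, and suppose all three consumers are width-2 slices along axis 1. Then sdistance = {2} and the slice axis maps back to preaxis = 0; the rewrite yields (the intermediate layout follows MIGraphX's unsqueeze-with-steps semantics and is shown here as derived from the code, so treat the exact shapes as a sketch):

    // before
    t  = transpose(x, permutation={1, 0, 2})       // {2, 6, 4}
    s0 = slice(t, axes={1}, starts={0}, ends={2})  // {2, 2, 4}
    s1 = slice(t, axes={1}, starts={2}, ends={4})
    s2 = slice(t, axes={1}, starts={4}, ends={6})

    // after: split the length-6 axis into 3 groups of 2, transpose once with
    // the group axis in front, then each slice becomes a one-index selection
    // followed by a squeeze
    u  = unsqueeze(x, axes={0}, steps={2})
    t2 = transpose(u, permutation={1, 2, 0, 3})    // {3, 2, 2, 4}
    s0 = squeeze(slice(t2, axes={0}, starts={0}, ends={1}), axes={0})  // {2, 2, 4}
    s1 = squeeze(slice(t2, axes={0}, starts={1}, ends={2}), axes={0})
    s2 = squeeze(slice(t2, axes={0}, starts={2}, ends={3}), axes={0})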
@@ -49,7 +49,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
         auto s0 = inputs.at(0);
         auto s1 = inputs.at(1);
         auto r  = s0;
-        if(s0 != s1 or !s0.packed())
+        if(s0 != s1 or not s0.packed())
         {
             r = shape{s0.type(), s0.lens()};
         }
...
@@ -30,6 +30,7 @@
 #include <migraphx/compile_options.hpp>
 #include <migraphx/fpga/context.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/supported_segments.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -41,7 +42,7 @@ struct target
     std::string name() const;
     std::vector<pass> get_passes(migraphx::context& ctx, const compile_options&) const;
     migraphx::context get_context() const { return context{}; }
-    float is_supported(instruction_ref ins, support_metric m);
+    supported_segments find_supported(const_module_ref mod, support_metric m) const;
     argument copy_to(const argument& arg) const { return arg; }
     argument copy_from(const argument& arg) const { return arg; }
...
@@ -95,7 +95,7 @@ void subgraph::apply(module_pass_manager& mpm) const
     for(auto it : iterator_for(mod))
     {
         // assuming we want all the params/literals as inputs to the FPGA submodule
-        if(migraphx::starts_with(it->name(), "@param") ||
+        if(migraphx::starts_with(it->name(), "@param") or
            migraphx::starts_with(it->name(), "@literal"))
         {
             literal_inputs.push_back(it);
...
@@ -34,6 +34,7 @@
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/generate.hpp>
 #include <migraphx/normalize_ops.hpp>
+#include <migraphx/iterator_for.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -62,12 +63,17 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
 argument target::allocate(const shape& s) const { return fill_argument(s, 0); }

-float is_supported(instruction_ref ins, support_metric m)
+supported_segments target::find_supported(const_module_ref mod, support_metric m) const
 {
-    // for now, not using the ins and metric to return a value
-    (void)ins;
     (void)m;
-    return 1.0;
+
+    supported_segment instrs;
+    for(const auto ins : iterator_for(*mod))
+    {
+        instrs.instructions.insert(ins);
+    }
+    instrs.metric = 1; // arbitrary value
+    return {instrs};
 }

 MIGRAPHX_REGISTER_TARGET(target);
...
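This keeps the pre-change behavior for this target under the new interface: find_supported reports a single segment that claims every instruction in the module with a placeholder metric, so get_target_assignments will hand the whole graph to it, just as the old is_supported returning 1.0 for everything did.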
@@ -51,7 +51,8 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>&
     std::vector<void*> kargs(args.size());
     std::transform(
         args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); });
-    k.launch(ctx.get_stream().get(), global, local, std::move(kargs));
+    auto [start, stop] = ctx.get_perf_events();
+    k.launch(ctx.get_stream().get(), global, local, std::move(kargs), start, stop);
     return args[get_output_arg(args.size())];
 }

 void code_object_op::finalize(context&, const shape&, const std::vector<shape>&)
...
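The start/stop pair returned by get_perf_events is presumably a pair of device-side events that launch records around the kernel; together with the time_op change at the bottom of this diff, this is what allows reporting a device time next to the host time.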
@@ -25,6 +25,13 @@
 #include <migraphx/shape.hpp>
 #include <migraphx/permutation.hpp>
 #include <migraphx/stringutils.hpp>
+#include <migraphx/module.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/cpp_generator.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/ranges.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -54,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
                        [&](const auto& input) -> std::size_t {
                            auto stride = input.strides()[axis];
                            auto len    = input.lens()[axis];
-                           if(stride != 0 and stride != 1)
+                           if(not contains({0, 1}, stride))
                                return 1;
                            if(len == 1 and input.elements() > sizes.front())
                                return sizes.front();
-                           auto it = std::find_if(
-                               sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
+                           auto it = std::find_if(sizes.begin(), sizes.end(), [&](auto vsize) {
+                               // The len is divisible by the size and all the strides are
+                               // divisible by the size
+                               return (len % vsize) == 0 and
+                                      std::all_of(input.strides().begin(),
+                                                  input.strides().end(),
+                                                  [&](auto i) {
+                                                      return contains({0, 1}, i) or i % vsize == 0;
+                                                  });
+                           });
                            if(it != sizes.end())
                                return *it;
                            return 1;
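The new predicate requires not only that the axis length be divisible by the candidate vector size but that every stride of the input be 0, 1, or a multiple of it. A quick illustration of the per-input selection, assuming sizes = {4, 2} and axis = 1:

    // lens {2, 8}, strides {8, 1}: 8 % 4 == 0 and stride 8 % 4 == 0 -> vectorize by 4
    // lens {2, 8}, strides {6, 1}: stride 6 % 4 != 0, but 8 % 2 == 0 and 6 % 2 == 0 -> use 2
    // lens {2, 8}, strides {5, 1}: 5 is divisible by neither 4 nor 2 -> fall back to 1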
@@ -75,25 +89,25 @@ std::string vectorize::str() const
 preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
 {
     const std::size_t max_lds_bytes = 4096;
-    std::vector<bool> result;
-    std::transform(inputs.begin(),
-                   inputs.end(),
-                   std::back_inserter(result),
-                   [&](const shape& input) { return input.strides()[axis] == 0; });
-    auto bytes = std::inner_product(inputs.begin(),
-                                    inputs.end(),
-                                    result.begin(),
-                                    std::size_t{0},
-                                    std::plus<>{},
-                                    [](const shape& s, bool b) -> std::size_t {
-                                        if(b)
-                                            return s.bytes();
-                                        return 0;
-                                    });
-    if(bytes < max_lds_bytes)
-        return {result};
-    // TODO: Try to partially preload items
-    std::fill(result.begin(), result.end(), false);
+    std::vector<bool> result(inputs.size());
+    std::vector<std::size_t> preloaded;
+    auto idxs = range(inputs.size());
+    std::copy_if(idxs.begin(), idxs.end(), std::back_inserter(preloaded), [&](auto i) {
+        return inputs[i].strides()[axis] == 0;
+    });
+    std::sort(preloaded.begin(), preloaded.end(), by(std::less<>{}, [&](auto i) {
+        return inputs[i].bytes();
+    }));
+
+    std::size_t bytes = 0;
+    for(auto i : preloaded)
+    {
+        auto input = inputs[i];
+        bytes += input.bytes();
+        if(bytes > max_lds_bytes)
+            break;
+        result[i] = true;
+    }
     return {result};
 }
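preload::broadcasts now fills the 4096-byte LDS budget greedily instead of all-or-nothing: the indices of broadcast inputs (stride 0 along the axis) are sorted by byte size, smallest first, and marked for preloading until the running total would exceed max_lds_bytes. For example, with broadcast inputs of 1 KB, 2 KB and 3 KB, the first two fit (3 KB total), adding the third would reach 6 KB, so exactly the two smallest are preloaded, where the old code would have preloaded none of them.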
@@ -125,6 +139,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
     return join_strings(std::move(transformers), ", ");
 }

+std::string generate_pointwise(const module& pm, const std::string& name)
+{
+    module m = pm;
+    run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
+    cpp_generator g;
+    g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
+    g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
+    g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
+    g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
+    g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
+    g.add_point_op("less", "migraphx::abs(${0} < ${1})");
+    g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
+    g.add_point_op("not", "migraphx::abs(not ${0})");
+    // Add explicit conversions
+    g.fresult(
+        [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
+    g.create_function(
+        g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name));
+    return g.str();
+}
+
+static std::vector<std::string> get_op_names(const module& m)
+{
+    std::vector<std::string> result;
+    for(auto& ins : m)
+    {
+        if(starts_with(ins.name(), "@"))
+            continue;
+        result.push_back(ins.name());
+    }
+    return result;
+}
+
+std::string generate_name_from_ops(const module& m)
+{
+    auto op_names = get_op_names(m);
+    return join_strings(op_names, "_");
+}
+
 } // namespace gen
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
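generate_name_from_ops joins the non-internal op names with underscores, skipping instructions whose names start with "@" (parameters, literals and the like); a pointwise module computing mul, add, relu in that order would, for example, produce the kernel name mul_add_relu.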
@@ -131,7 +131,7 @@ struct hip_array
     friend MIGRAPHX_DEVICE_CONSTEXPR bool operator!=(const hip_array& x, const hip_array& y)
     {
-        return !(x == y);
+        return not(x == y);
     }

     // This uses the product order rather than lexical order
     friend MIGRAPHX_DEVICE_CONSTEXPR bool operator<(const hip_array& x, const hip_array& y)
...
@@ -117,12 +117,13 @@ template <class V, class F, class... Ts>
 void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
     std::initializer_list<migraphx::shape::type_t> types = {get_shape(xs).type()...};
-    if(!std::all_of(
+    if(not std::all_of(
            types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
         MIGRAPHX_THROW("Types must be the same");
     std::initializer_list<index_int> ranks = {
         static_cast<index_int>(get_shape(xs).lens().size())...};
-    if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+    if(not std::all_of(
+           ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
         MIGRAPHX_THROW("Ranks must be the same");
     visit_tensor_size(s.lens().size(), [&](auto ndim) {
         s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
@@ -134,7 +135,8 @@ void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
     std::initializer_list<index_int> ranks = {
         static_cast<index_int>(get_shape(xs).lens().size())...};
-    if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+    if(not std::all_of(
+           ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
         MIGRAPHX_THROW("Ranks must be the same");
     visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
 }
...
@@ -47,7 +47,7 @@ constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value)
         it   = first;
         step = count / 2;
         std::advance(it, step);
-        if(!(value < *it))
+        if(not(value < *it))
         {
             first = ++it;
             count -= step + 1;
...
@@ -38,8 +38,11 @@ struct compile_op : action<compile_op>
         context ctx;
         auto inputs = p.parse_shapes(v.at("inputs"));
         auto op     = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
-        double t    = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
-        std::cout << op << ": " << t << "ms" << std::endl;
+        auto [host_time, device_time] = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
+        std::cout << op << ": " << host_time << "ms";
+        if(device_time > 0)
+            std::cout << ", " << device_time << "ms";
+        std::cout << std::endl;
     }
 };
...
@@ -33,7 +33,8 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace driver {

-double time_op(context& ctx, operation op, const std::vector<shape>& inputs, int n = 100);
+std::pair<double, double>
+time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n = 100);

 } // namespace driver
 } // namespace gpu
...