Commit 3a4d36cf authored by charlie

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents 6bec381f e19f78ae
@@ -56,7 +56,7 @@ const auto& get_nearest_op(const std::string& mode)
return static_cast<std::size_t>(std::ceil((val)));
}}};
-if(!contains(nearest_ops, mode))
+if(not contains(nearest_ops, mode))
{
MIGRAPHX_THROW("PARSE_RESIZE: nearest_mode " + mode + " not supported!");
}
@@ -86,7 +86,7 @@ const auto& get_original_idx_op(const std::string& mode)
return (idx + 0.5) / scale;
}}};
-if(!contains(idx_ops, mode))
+if(not contains(idx_ops, mode))
{
MIGRAPHX_THROW("PARSE_RESIZE: coordinate_transformation_mode " + mode + " not supported!");
}
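Note: most of this merge swaps C++'s punctuation operators for the equivalent alternative tokens. `not`, `and`, and `or` are part of the core C++ language (no header required, unlike C's `<iso646.h>`), so the substitution is purely stylistic. A minimal sanity check:

```cpp
// The alternative tokens are exact synonyms of !, && and ||.
#include <cassert>

int main()
{
    bool a = true;
    bool b = false;
    assert((not a) == (!a));
    assert((a and b) == (a && b));
    assert((a or b) == (a || b));
}
```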
@@ -31,7 +31,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
-//! Parser for ReverseSequence ONNX operator.
+// Parser for ReverseSequence ONNX operator.
/*!
Reverses the data along the time axis for the batches along the batch axis.
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
@@ -29,7 +29,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(module& m) const
{
-if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
+if(not enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{
memory_coloring_impl opt(&m, allocation_op, verify);
opt.run();
@@ -42,7 +42,7 @@ void memory_coloring_impl::run()
{
MIGRAPHX_DEBUG(dump_intervals());
// Coloring
-while(!alloc_queue.empty())
+while(not alloc_queue.empty())
{
interval_ptr interval = alloc_queue.top();
allocate(interval);
@@ -72,7 +72,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
if(conflict_table.find(vn) != conflict_table.end())
{
-std::set<int>& vn_set = conflict_table[vn];
+const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
@@ -96,7 +96,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
}
std::size_t offset = 0;
-while(!conflict_queue.empty())
+while(not conflict_queue.empty())
{
live_range* range = conflict_queue.top();
std::size_t iter_offset = range->offset;
@@ -149,7 +149,7 @@ void memory_coloring_impl::build()
{
def_interval = instr2_live[p_iter];
bool is_lit = is_literal(iter);
-if(is_allocate(iter) || is_lit)
+if(is_allocate(iter) or is_lit)
{
live_range& range = def_interval->segment;
def_interval->result = iter->get_shape();
@@ -157,12 +157,12 @@ void memory_coloring_impl::build()
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
-if(!is_lit || unify_literals)
+if(not is_lit or unify_literals)
alloc_queue.push(def_interval);
live_set.erase(range.vn);
}
}
-else if(!is_param(iter) && !is_outline(iter) && !is_check_context(iter))
+else if(not is_param(iter) && not is_outline(iter) && not is_check_context(iter))
{
is_dead = true;
}
@@ -179,7 +179,7 @@ void memory_coloring_impl::build()
if(not p_mod->has_instruction(arg))
continue;
-if(is_param(arg) || is_outline(arg))
+if(is_param(arg) or is_outline(arg))
{
if(is_output_param(arg))
is_dead = false;
@@ -235,7 +235,7 @@ void memory_coloring_impl::rewrite()
if(interval->get_begin() == invalid_offset)
continue;
-if(!unify_literals && interval->is_literal)
+if(not unify_literals && interval->is_literal)
continue;
std::size_t offset = 0;
@@ -267,12 +267,12 @@ void memory_coloring_impl::verify()
{
for(int i = 0; i < num_of_lives; ++i)
{
-live_interval& interval = live_intervals[i];
-live_range& segment = interval.segment;
+const live_interval& interval = live_intervals[i];
+const live_range& segment = interval.segment;
if(segment.begin == invalid_offset)
{
-// if(!interval.is_live_on_entry)
+// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue;
}
@@ -284,13 +284,13 @@ void memory_coloring_impl::verify()
int vn = segment.vn;
if(conflict_table.find(vn) != conflict_table.end())
{
-std::set<int>& vn_set = conflict_table[vn];
+const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
if(range->offset == invalid_offset)
continue;
-if(!is_disjoin(*range, segment))
+if(not is_disjoin(*range, segment))
MIGRAPHX_THROW("range and segment is not disjoined");
}
}
@@ -319,8 +319,8 @@ void memory_coloring_impl::dump_intervals()
{
std::cout << " segment:" << i;
std::cout << " =>";
-std::set<int>& table = conflict_table[i];
-for(auto& iter : table)
+const std::set<int>& table = conflict_table[i];
+for(const auto& iter : table)
{
std::cout << (iter) << ",";
}
@@ -357,7 +357,7 @@ void live_interval::dump()
std::cout << "id:" << id;
segment.dump();
std::cout << " uses:";
-for(auto& iter : use_points)
+for(const auto& iter : use_points)
{
std::cout << " " << get_ins_enum(iter) << ",";
}
@@ -125,11 +125,11 @@ struct memory_coloring_impl
static bool is_disjoin(const live_range& range1, const live_range& range2)
{
-if((range1.size == 0) || (range2.size == 0))
+if((range1.size == 0) or (range2.size == 0))
return false;
auto end1 = range1.offset + range1.size - 1;
auto end2 = range2.offset + range2.size - 1;
-return ((end1 < range2.offset) || (end2 < range1.offset));
+return ((end1 < range2.offset) or (end2 < range1.offset));
}
void verify();
#ifdef MIGRAPHX_DEBUG_OPT
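Note: the check above treats each live range as the closed byte span [offset, offset + size - 1], and never reports an empty range as disjoint. A self-contained sketch of the same logic (the struct here is a stand-in, not the real `live_range`):

```cpp
#include <cassert>
#include <cstddef>

struct range_t
{
    std::size_t offset;
    std::size_t size;
};

bool is_disjoin(const range_t& r1, const range_t& r2)
{
    if((r1.size == 0) or (r2.size == 0))
        return false;
    auto end1 = r1.offset + r1.size - 1; // last byte of r1
    auto end2 = r2.offset + r2.size - 1; // last byte of r2
    return (end1 < r2.offset) or (end2 < r1.offset);
}

int main()
{
    assert(is_disjoin({0, 4}, {4, 4}));     // [0,3] vs [4,7]: disjoint
    assert(not is_disjoin({0, 4}, {3, 4})); // [0,3] vs [3,6]: share byte 3
}
```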
@@ -60,7 +60,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
{
std::vector<std::size_t> padding;
padding.resize(2 * k_lens.size());
-for(size_t i = 0; i < padding.size() / 2; i++)
+for(std::size_t i = 0; i < padding.size() / 2; i++)
{
std::ptrdiff_t input_dim = tensor_lens[i];
std::ptrdiff_t stride = strides[i];
@@ -50,7 +50,7 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
{
// TODO: Use execve instead of popen
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
-if(!pipe)
+if(not pipe)
MIGRAPHX_THROW("popen() failed: " + cmd);
std::array<char, 128> buffer;
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
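Note: the surrounding function pairs POSIX `popen` with a `unique_ptr` whose custom deleter calls `pclose`, so the pipe is closed on every exit path, including the throw. A standalone sketch of the pattern (assumes a POSIX system; `capture` is an illustrative name):

```cpp
#include <array>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>

std::string capture(const std::string& cmd)
{
    auto closer = [](FILE* f) { pclose(f); };
    std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer);
    if(not pipe)
        throw std::runtime_error("popen() failed: " + cmd);
    std::array<char, 128> buffer{};
    std::string out;
    while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
        out += buffer.data();
    return out; // the deleter invokes pclose when pipe goes out of scope
}
```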
@@ -37,6 +37,7 @@
#include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp>
+#include <migraphx/supported_segments.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
@@ -77,11 +78,11 @@ program& program::operator=(program p)
void program::assign(const program& p)
{
-if(!impl)
+if(not impl)
{
impl = std::make_unique<program_impl>();
}
-else if(!impl->modules.empty())
+else if(not impl->modules.empty())
{
impl->modules.clear();
}
@@ -167,13 +168,37 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
target_assignments p;
const auto* mod = get_main_module();
-for(auto it : iterator_for(*mod))
+std::vector<std::pair<target, supported_segments>> target_subgraphs;
+target_subgraphs.reserve(targets.size());
+std::transform(targets.begin(),
+targets.end(),
+std::back_inserter(target_subgraphs),
+[&](const auto& t) { return std::make_pair(t, t.find_supported(mod, m)); });
+for(const auto ins : iterator_for(*mod))
{
-auto t = std::max_element(
-targets.begin(), targets.end(), [it, m](const target& lhs, const target& rhs) {
-return lhs.is_supported(it, m) < rhs.is_supported(it, m);
-});
-p.add_assignment(it, t->name());
+if(contains(p, ins))
+{
+continue;
+}
+for(const auto& [target, subgraph] : target_subgraphs)
+{
+// can't pass a structured binding into a lambda in C++17, so create a variable for it
+const auto& t = target;
+for(const auto& segment : subgraph)
+{
+const auto& instructions = segment.instructions;
+if(not contains(instructions, ins))
+{
+continue;
+}
+std::transform(instructions.begin(),
+instructions.end(),
+std::inserter(p, p.end()),
+[&](auto instr) { return std::make_pair(instr, t.name()); });
+}
+}
}
return p;
}
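Note: the new loop assigns every instruction of a supported segment in one pass with `std::transform` through a map inserter. A minimal illustration of that idiom with placeholder types (not the real `target_assignments`):

```cpp
#include <algorithm>
#include <iterator>
#include <map>
#include <string>
#include <vector>

int main()
{
    std::vector<int> instructions = {10, 11, 12}; // stand-ins for instruction refs
    std::map<int, std::string> assignments;
    // Each transformed pair is inserted at the hint position, i.e. appended
    // to the container as the transform walks the range.
    std::transform(instructions.begin(),
                   instructions.end(),
                   std::inserter(assignments, assignments.end()),
                   [](int ins) { return std::make_pair(ins, std::string{"gpu"}); });
    return assignments.size() == 3 ? 0 : 1;
}
```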
@@ -40,6 +40,7 @@
#include <migraphx/register_target.hpp>
#include <migraphx/json.hpp>
#include <migraphx/make_op.hpp>
+#include <migraphx/op/common.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
@@ -82,7 +83,7 @@ void visit_py(T x, F f)
{
f(x.template cast<bool>());
}
-else if(py::isinstance<py::int_>(x))
+else if(py::isinstance<py::int_>(x) or py::hasattr(x, "__index__"))
{
f(x.template cast<int>());
}
@@ -263,12 +264,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
-.def("__init__",
-[](migraphx::argument& x, py::buffer b) {
-py::buffer_info info = b.request();
-new(&x) migraphx::argument(to_shape(info), info.ptr);
-})
+.def(py::init([](py::buffer b) {
+py::buffer_info info = b.request();
+return migraphx::argument(to_shape(info), info.ptr);
+}))
.def("get_shape", &migraphx::argument::get_shape)
.def("data_ptr",
[](migraphx::argument& x) { return reinterpret_cast<std::uintptr_t>(x.data()); })
.def("tolist",
[](migraphx::argument& x) {
py::list l{x.get_shape().elements()};
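Note: the `argument` binding above moves from the legacy placement-new `__init__` lambda to a `py::init` factory, the pattern pybind11 has recommended since 2.2 (the placement-new form is deprecated). A minimal sketch with a toy type standing in for `migraphx::argument`:

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct holder
{
    void* ptr;
};

PYBIND11_MODULE(example, m)
{
    py::class_<holder>(m, "holder")
        // The factory returns by value; pybind11 moves the result into
        // Python-managed storage, so no placement new is required.
        .def(py::init([](py::buffer b) {
            py::buffer_info info = b.request();
            return holder{info.ptr};
        }));
}
```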
@@ -324,6 +326,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("get_parameter_names", &migraphx::program::get_parameter_names)
.def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
.def("get_output_shapes", &migraphx::program::get_output_shapes)
+.def("is_compiled", &migraphx::program::is_compiled)
.def(
"compile",
[](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {
@@ -358,18 +361,35 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
-py::class_<migraphx::operation>(m, "op")
-.def(py::init([](const std::string& name, py::kwargs kwargs) {
-migraphx::value v = migraphx::value::object{};
-if(kwargs)
-{
-v = migraphx::to_value(kwargs);
-}
-return migraphx::make_op(name, v);
-}))
+py::class_<migraphx::operation> op(m, "op");
+op.def(py::init([](const std::string& name, py::kwargs kwargs) {
+migraphx::value v = migraphx::value::object{};
+if(kwargs)
+{
+v = migraphx::to_value(kwargs);
+}
+return migraphx::make_op(name, v);
+}))
.def("name", &migraphx::operation::name);
+py::enum_<migraphx::op::pooling_mode>(op, "pooling_mode")
+.value("average", migraphx::op::pooling_mode::average)
+.value("max", migraphx::op::pooling_mode::max)
+.value("lpnorm", migraphx::op::pooling_mode::lpnorm);
+py::enum_<migraphx::op::rnn_direction>(op, "rnn_direction")
+.value("forward", migraphx::op::rnn_direction::forward)
+.value("reverse", migraphx::op::rnn_direction::reverse)
+.value("bidirectional", migraphx::op::rnn_direction::bidirectional);
+m.def(
+"argument_from_pointer",
+[](const migraphx::shape shape, const int64_t address) {
+return migraphx::argument(shape, reinterpret_cast<void*>(address));
+},
+py::arg("shape"),
+py::arg("address"));
m.def(
"parse_tf",
[](const std::string& filename,
@@ -70,7 +70,7 @@ void quantize_int8(program& prog,
{
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
-if(!std::includes(
+if(not std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
@@ -73,7 +73,7 @@ void insert_submod_allocations(instruction_ref ins, module& mod, const allocatio
name_shapes.insert(ps.begin(), ps.end());
}
-for(auto& pn : name_shapes)
+for(const auto& pn : name_shapes)
{
const auto& s = pn.second;
instruction_ref output{};
@@ -21,44 +21,39 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
-#ifndef MIGRAPHX_GUARD_RTGLIB_WHERE_HPP
-#define MIGRAPHX_GUARD_RTGLIB_WHERE_HPP
-#include <migraphx/gpu/oper.hpp>
-#include <migraphx/gpu/device/where.hpp>
+#include <migraphx/rewrite_gelu.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/matcher.hpp>
+#include <migraphx/match/gelu_erf.hpp>
+#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {
-struct hip_where : ternary_device<hip_where, device::where>
+struct find_gelu_erf
{
-shape compute_shape(const std::vector<shape>& inputs) const
+auto matcher() const { return match::gelu_erf(); }
+void apply(module& m, const match::matcher_result& r) const
{
-check_shapes{inputs, *this}.has(4).same_dims();
-auto s1 = inputs.at(1);
-auto s2 = inputs.at(2);
-if(s1 == s2 and s1.packed())
-{
-return s1;
-}
-else if(s1.packed() != s2.packed())
-{
-return s1.packed() ? s1 : s2;
-}
-else if(s1.broadcasted() != s2.broadcasted())
-{
-return s1.broadcasted() ? s2.with_lens(s1.lens()) : s1.with_lens(s1.lens());
-}
-else
-{
-return {s1.type(), s1.lens()};
-}
+auto ins = r.result;
+auto x = r.instructions["x"];
+if(x->get_shape().type() != migraphx::shape::half_type)
+return;
+auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}});
+auto mul = insert_common_op(m, ins, make_op("mul"), {x, lit});
+auto sig = m.insert_instruction(ins, make_op("neg"), mul);
+sig = m.insert_instruction(ins, make_op("exp"), sig);
+auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
+sig = insert_common_op(m, ins, make_op("add"), {sig, one});
+sig = m.insert_instruction(ins, make_op("div"), x, sig);
+m.replace_instruction(ins, sig);
}
};
-} // namespace gpu
+void rewrite_gelu::apply(module& m) const { match::find_matches(m, find_gelu_erf{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
-#endif
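Note: the new pass substitutes the sigmoid approximation of GELU for the erf-based form when the input is half precision. The inserted instruction chain (mul by 1.702, neg, exp, add 1, div) computes exactly

    x / (1 + exp(-1.702*x)) = x * sigmoid(1.702*x) ≈ x * Φ(x) = GELU(x),

which is Hendrycks and Gimpel's fast approximation of GELU(x) = (x/2)*(1 + erf(x/√2)).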
@@ -47,12 +47,12 @@ void rewrite_pooling::apply(module& m) const
if(not s.standard())
continue;
auto&& op = any_cast<op::pooling>(ins->get_operator());
-if(!std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
+if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
continue;
-if(!std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
+if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
continue;
auto lens = s.lens();
-if(!std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
+if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
continue;
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
@@ -214,7 +214,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
ih = m.add_literal(migraphx::literal{ih_shape, data});
}
-if(!is_forward and variable_seq_len)
+if(not is_forward and variable_seq_len)
{
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
......@@ -520,7 +520,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
ih = m.add_literal(migraphx::literal{ih_shape, data});
}
if(!is_forward and variable_seq_len)
if(not is_forward and variable_seq_len)
{
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
......@@ -977,7 +977,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
pph = args[7];
}
if(!is_forward and variable_seq_len)
if(not is_forward and variable_seq_len)
{
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
......@@ -1294,11 +1294,11 @@ bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens
std::vector<int64_t> vec_lens;
arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
int64_t l = 0;
if(!vec_lens.empty())
if(not vec_lens.empty())
{
l = vec_lens[0];
}
if(!std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
if(not std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
{
is_var_lens = true;
}
@@ -1318,7 +1318,7 @@ rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref
bool is_var_lens = is_variable_seq_lens(m, seq_lens);
auto input_shape = input->get_shape();
auto length = input_shape.lens()[0];
-if(!is_var_lens and seq_lens != m.end())
+if(not is_var_lens and seq_lens != m.end())
{
auto arg_len = seq_lens->eval();
std::vector<std::size_t> vec_lens;
@@ -1387,7 +1387,7 @@ void rewrite_rnn::replace_last_cell_output(module& m,
if(variable_seq_len)
{
-if(!ins_outputs.empty())
+if(not ins_outputs.empty())
{
cell_outputs = m.insert_instruction(
std::next(ins),
@@ -477,7 +477,7 @@ bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
-return !(x == y);
+return not(x == y);
}
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
@@ -497,7 +497,7 @@ bool operator==(const shape& x, const shape& y)
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
}
-bool operator!=(const shape& x, const shape& y) { return !(x == y); }
+bool operator!=(const shape& x, const shape& y) { return not(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x)
{
@@ -57,12 +57,14 @@ auto conv_const_weights()
auto reduction() { return match::name_contains("reduce"); }
// conv(x, w) * a => conv(x, a * w)
struct find_mul_conv
{
auto matcher() const
{
-return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"),
-match::name("broadcast").bind("a")));
+return match::name("mul")(
+match::either_arg(0, 1)(conv_const_weights().bind("conv"),
+match::name("broadcast", "multibroadcast").bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
@@ -72,14 +74,35 @@ struct find_mul_conv
auto a_ins = r.instructions["a"];
auto w_ins = r.instructions["w"];
-auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
-if(broadcast_op.axis != 1)
+const auto& a_input_lens = a_ins->inputs().front()->get_shape().lens();
+std::size_t num_not_one_dims = std::count_if(
+a_input_lens.cbegin(), a_input_lens.cend(), [](auto dim) { return dim != 1; });
+if(num_not_one_dims > 1)
return;
+// check broadcasted along channels
+const auto& a_lens = a_ins->get_shape().lens();
+const auto& a_strides = a_ins->get_shape().strides();
+auto is_broadcasted_axis = [](auto len, auto stride) { return len == 1 or stride == 0; };
+if(a_strides.at(1) != 1)
+return;
+if(not is_broadcasted_axis(a_lens.front(), a_strides.front()))
+return;
+if(not std::equal(a_lens.begin() + 2,
+a_lens.end(),
+a_strides.begin() + 2,
+a_strides.end(),
+is_broadcasted_axis))
+return;
+auto sq = m.insert_instruction(ins, make_op("squeeze"), a_ins->inputs().front());
auto new_a = m.insert_instruction(
-ins,
-make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
-a_ins->inputs().front());
+ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}), sq);
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
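Note: the new stride checks make the rewrite's precondition explicit. Convolution is linear in its weights, so a multiplier that varies only along the output-channel axis can be folded into them:

    a[c] * conv(x, w)[n, c, ...] = conv(x, a ⊙ w)[n, c, ...]

This is why `a` must be broadcast (length 1 or stride 0) on every axis except channels; the squeeze-then-broadcast sequence then re-expresses `a` along axis 0 of the weight tensor, its output-channel axis.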
@@ -208,6 +231,42 @@ struct find_mul_add
}
};
+struct find_dot_add
+{
+auto matcher() const
+{
+return match::name("dot")(match::either_arg(0, 1)(
+match::name("add")(
+match::either_arg(0, 1)(match::any().bind("x"),
+match::any_of(match::is_constant()).bind("b")),
+match::none_of(match::args(match::is_constant(), match::is_constant())),
+match::used_once()),
+match::is_constant().bind("a")));
+}
+void apply(module& m, const match::matcher_result& r) const
+{
+auto ins = r.result;
+auto a_ins = r.instructions["a"];
+auto b_ins = r.instructions["b"];
+auto x_ins = r.instructions["x"];
+assert(x_ins != b_ins);
+const bool flipped = a_ins == ins->inputs().back();
+auto insert_dot = [&](auto x, auto y) {
+if(flipped)
+return m.insert_instruction(ins, make_op("dot"), y, x);
+else
+return m.insert_instruction(ins, make_op("dot"), x, y);
+};
+auto ax_ins = insert_dot(a_ins, x_ins);
+auto ab_ins = insert_dot(a_ins, b_ins);
+m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
+}
+};
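Note: `find_dot_add` applies the distributive law

    A(X + B) = AX + AB    (the `flipped` branch covers (X + B)A = XA + BA),

so with `a` and `b` both constant, the `AB` product is itself constant and can be evaluated once by constant folding. The guards (`used_once`, and rejecting an add of two constants) limit the match to cases where splitting the dot pays off.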
struct find_add_lit_broadcast
{
auto matcher() const
@@ -267,28 +326,26 @@ struct find_double_add_lit_broadcast
struct find_inner_broadcast
{
-auto matcher() const
-{
-return pointwise(
-match::nargs(2),
-match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
-}
+auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
-auto x_ins = r.instructions["x"];
-auto y_ins = r.instructions["y"];
-auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
-auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
-if(xbroadcast.axis != ybroadcast.axis)
+auto broadcasts = ins->inputs();
+if(broadcasts.empty())
return;
+std::vector<instruction_ref> inputs;
+std::transform(broadcasts.begin(),
+broadcasts.end(),
+std::back_inserter(inputs),
+[](auto i) { return i->inputs().front(); });
+if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
+return i->get_shape() != inputs.front()->get_shape();
+}))
+return;
-auto op = m.insert_instruction(
-ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
-m.replace_instruction(ins, xbroadcast, op);
+auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
+m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
};
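Note: the generalized `find_inner_broadcast` hoists any elementwise op over identical broadcasts of same-shaped inputs, so the op runs once on the small pre-broadcast shape. A numeric sanity check of the equivalence (plain vectors standing in for tensors):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const double x = 3.0;
    const double y = 4.0;
    const std::size_t n = 5;

    // before: mul(broadcast(x), broadcast(y))
    std::vector<double> bx(n, x);
    std::vector<double> by(n, y);
    std::vector<double> before(n);
    for(std::size_t i = 0; i < n; ++i)
        before[i] = bx[i] * by[i];

    // after: broadcast(mul(x, y))
    std::vector<double> after(n, x * y);

    assert(before == after);
}
```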
@@ -378,6 +435,24 @@ struct find_concat_op
}
};
+void move_instructions_back(module& m, instruction_ref pos, std::vector<instruction_ref> inss)
+{
+auto start = range(m.begin(), pos);
+for(auto ins : iterator_for(start))
+{
+auto it = std::find(inss.begin(), inss.end(), ins);
+if(it != inss.end())
+inss.erase(it);
+}
+for(auto ins : inss)
+{
+if(not m.has_instruction(ins))
+continue;
+move_instructions_back(m, pos, ins->inputs());
+m.move_instruction(ins, pos);
+}
+}
std::vector<instruction_ref> get_splits(instruction_ref ins)
{
std::vector<instruction_ref> result;
@@ -416,8 +491,9 @@ struct find_splits
{
auto matcher() const
{
-return match::any(match::any_of[match::outputs()](match::name("slice")(
-match::any_of[match::outputs()](match::pointwise(), reduction()))));
+return match::any(
+match::any_of[match::outputs()](match::name("slice")(match::any_of[match::outputs()](
+match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction()))));
}
static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
@@ -552,8 +628,7 @@ struct find_splits
}))
return;
-for(auto data : data_args)
-m.move_instructions(data, ins);
+move_instructions_back(m, ins, data_args);
auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
assert(not slice_op.axes.empty());
@@ -580,10 +655,9 @@ struct find_splits
auto outputs = i->outputs();
for(auto output : outputs)
{
-if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
+if(output->name() != "reshape")
continue;
-auto x =
-m.insert_instruction(output, make_op("contiguous"), output->inputs());
+auto x = m.insert_instruction(output, make_op("contiguous"), i);
m.replace_instruction(output, output->get_operator(), x);
}
@@ -753,7 +827,7 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
-return !(dots < 2 and convs < 2);
+return not(dots < 2 and convs < 2);
}
struct find_conv_dot_horiz_fusion
@@ -773,7 +847,7 @@ struct find_conv_dot_horiz_fusion
auto y = j->inputs()[1]->get_shape().lens();
if(x.size() != y.size())
return false;
-// Check that non-axises match
+// Check that non-axes match
int axis = 1;
if(i->name() == "dot")
{
@@ -807,15 +881,23 @@ struct find_conv_dot_horiz_fusion
concat_axis = axis;
}
-for(auto arg : args)
-m.move_instructions(arg, input);
-// TODO: Check if axises match
+move_instructions_back(m, input, args);
+// TODO: Check if axes match
auto concat =
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = m.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0;
for(auto arg : range(start, last))
{
+auto outputs = arg->outputs();
+for(auto output : outputs)
+{
+if(output->name() != "reshape")
+continue;
+auto x = m.insert_instruction(output, make_op("contiguous"), arg);
+m.replace_instruction(output, output->get_operator(), x);
+}
int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
arg,
@@ -926,7 +1008,7 @@ struct find_split_reshape
// all outputs are reshape and of the same shape
auto dims = any_cast<op::reshape>(rsp->get_operator()).dims;
-if(!same_ops(vec_rsp))
+if(not same_ops(vec_rsp))
{
return;
}
@@ -942,23 +1024,42 @@ struct find_split_reshape
auto rsp_lens = rsp->get_shape().lens();
auto rsp_strides = rsp->get_shape().strides();
rsp_strides.insert(rsp_strides.begin(), rsp_strides[0] * rsp_lens[0]);
auto ait = std::find(rsp_strides.begin(), rsp_strides.end(), slc_dim_size);
+int rsp_axis = -1;
if(ait == rsp_strides.end())
{
return;
}
-int rsp_axis = std::distance(rsp_strides.begin(), ait);
+else if(ait == rsp_strides.end() - 1)
+{
+// edge case: when slc_dim_size == 1 it can also match the trailing stride of 1;
+// accumulate lengths from the last dim in that case, discounting 1 to avoid
+// going out of bounds.
+assert(slc_dim_size == 1);
+rsp_axis = std::distance(rsp_strides.begin(), ait) - 1;
+}
+else
+{
+rsp_axis = std::distance(rsp_strides.begin(), ait);
+}
// calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) {
return is->get_shape().lens()[rsp_axis];
});
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
-// insert the reshape instruction
+// insert the reshape instruction and add contiguous if needed
+if(not input->get_shape().standard())
+{
+input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
+}
auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
@@ -1005,7 +1106,7 @@ struct find_split_transpose
// all transpose are the same
auto perm = any_cast<op::transpose>(trans->get_operator()).dims;
-if(!same_ops(vec_trans))
+if(not same_ops(vec_trans))
{
return;
}
@@ -1048,6 +1149,7 @@ void simplify_algebra::apply(module& m) const
find_mul_conv{},
find_mul_slice_conv{},
find_mul_add{},
+find_dot_add{},
find_div_const{},
find_sub_const{},
find_rsqrt{},
@@ -99,7 +99,7 @@ struct find_reshaper
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
-assert(!reshapes.back()->inputs().empty());
+assert(not reshapes.back()->inputs().empty());
assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
@@ -151,8 +151,11 @@ struct find_transpose
{
auto matcher() const
{
-return match::name("transpose")(match::none_of(
-match::skip_output(match::name("contiguous"))(match::name("transpose"))));
+auto output_not_transpose =
+match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose")));
+auto input_has_transpose =
+match::args(match::skip(match::name("contiguous"))(match::name("transpose")));
+return match::name("transpose")(output_not_transpose, input_has_transpose);
}
void apply(module& m, const match::matcher_result& mr) const
@@ -268,6 +271,44 @@ struct find_nested_slice
}
};
+struct find_concat_multibroadcasts
+{
+auto matcher() const
+{
+return match::name("concat")(match::all_of[match::inputs()](match::name("multibroadcast")));
+}
+void apply(module& m, const match::matcher_result& mr) const
+{
+auto ins = mr.result;
+auto op = any_cast<op::concat>(ins->get_operator());
+auto out_lens = ins->get_shape().lens();
+auto inputs = ins->inputs();
+auto in_strides = inputs.front()->get_shape().strides();
+// Only apply when concat axis is not a broadcasted dimension
+if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
+return i->get_shape().strides()[op.axis] == 0;
+}))
+{
+return;
+}
+// Use inputs of multibroadcast ops as inputs to new concat op
+std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) {
+return i->inputs().front();
+});
+// Reduce axis by number of leading broadcasted dimensions
+if(inputs.front()->get_shape().lens().size() < out_lens.size())
+op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0);
+auto concat = m.insert_instruction(ins, op, inputs);
+m.replace_instruction(
+ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat);
+}
+};
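Note: a worked example of the axis adjustment above. Concatenating two inputs of shape {4, 8}, each multibroadcast to {2, 4, 8} (strides {0, 8, 1}), on axis 2 becomes a concat of the originals on axis 1, since one leading broadcast dimension precedes the axis; the {4, 16} result is then rebroadcast to {2, 4, 16}:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    std::vector<std::size_t> in_strides = {0, 8, 1}; // multibroadcast output strides
    std::int64_t axis = 2;                           // original concat axis
    // Subtract the broadcast (stride 0) dimensions in front of the axis.
    axis -= std::count(in_strides.begin(), in_strides.begin() + axis, std::size_t{0});
    assert(axis == 1);
}
```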
struct find_concat_transpose
{
auto matcher() const
@@ -285,7 +326,7 @@ struct find_concat_transpose
auto permutation = find_permutation(s);
// permutation should be the same for all inputs
-if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
+if(not std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
return (find_permutation(in->get_shape()) == permutation);
}))
{
@@ -664,9 +705,94 @@ struct find_slice_transpose
}
};
+struct find_transpose_slice
+{
+auto matcher() const
+{
+return match::name("transpose")(match::all_of[match::outputs()](match::name("slice")));
+}
+static std::vector<int64_t> slice_distance(const op::slice& op)
+{
+assert(op.starts.size() == op.ends.size());
+std::vector<int64_t> result(op.starts.size());
+std::transform(
+op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{});
+return result;
+}
+void apply(module& m, const match::matcher_result& r) const
+{
+auto ins = r.result;
+auto slices = ins->outputs();
+if(slices.empty())
+return;
+auto slice = any_cast<op::slice>(slices.front()->get_operator());
+auto sdistance = slice_distance(slice);
+// Check all distances and axes are the same
+if(std::any_of(slices.begin(), slices.end(), [&](auto sins) {
+auto s = any_cast<op::slice>(sins->get_operator());
+return s.axes != slice.axes or slice_distance(s) != sdistance;
+}))
+return;
+// Check distances are divisible by lens of corresponding axes
+auto mod_by_distance = [&](const auto& v, auto f) {
+return std::inner_product(v.begin(),
+v.end(),
+sdistance.begin(),
+0,
+std::plus<>{},
+[&](auto x, auto d) -> uint64_t {
+if(d == 0)
+return 1;
+return f(x) % d;
+});
+};
+if(mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or
+mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0)
+return;
+// TODO: Handle multiple axes
+if(sdistance.size() != 1)
+return;
+auto axis = slice.axes.front();
+// Skip if axis would be packed
+if(std::all_of(ins->get_shape().lens().begin(),
+ins->get_shape().lens().begin() + axis,
+[](auto x) { return x == 1; }))
+return;
+// Compute axis before transpose to use for unsqueeze
+auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
+auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
+// Make unsqueeze
+auto unsqueeze = m.insert_instruction(
+ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
+// Make transpose
+std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
+if(i > preaxis)
+return i + 1;
+return i;
+});
+perm.insert(perm.begin(), preaxis + 1);
+auto transpose =
+m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
+// Slice and squeeze
+for(auto s : slices)
+{
+auto op = any_cast<op::slice>(s->get_operator());
+op.axes = {0};
+op.starts = {op.starts.front() / sdistance.front()};
+op.ends = {op.ends.front() / sdistance.front()};
+auto slice_ins = m.insert_instruction(ins, op, transpose);
+auto squeeze =
+m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins);
+m.replace_instruction(s, squeeze);
+}
+}
+};
void simplify_reshapes::apply(module& m) const
{
-for(int i = 0; i < 2; i++)
+for(int i = 0; i < 4; i++)
{
match::find_matches(m,
find_where_op{},
@@ -676,9 +802,11 @@ void simplify_reshapes::apply(module& m) const
find_reshaper{},
find_transpose{},
find_concat_transpose{},
+find_concat_multibroadcasts{},
find_nested_convert{},
find_nested_slice{},
find_nested_concat{},
+find_transpose_slice{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m);
@@ -35,6 +35,7 @@ add_library(migraphx_cpu
dnnl.cpp
eltwise.cpp
erf.cpp
+fmod.cpp
fuse_ops.cpp
gather.cpp
gemm.cpp
@@ -42,6 +43,7 @@ add_library(migraphx_cpu
logsoftmax.cpp
lowering.cpp
lrn.cpp
+mod.cpp
preallocate.cpp
pooling.cpp
reduction.cpp
@@ -49,7 +49,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
auto r = s0;
-if(s0 != s1 or !s0.packed())
+if(s0 != s1 or not s0.packed())
{
r = shape{s0.type(), s0.lens()};
}
@@ -21,16 +21,16 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
-#include <migraphx/target_assignments.hpp>
-#include <migraphx/config.hpp>
+#include <migraphx/cpu/pointwise.hpp>
+#include <migraphx/op/fmod.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
-void target_assignments::add_assignment(instruction_ref ins, const std::string& target)
-{
-assignments.emplace(ins, target);
-}
+namespace cpu {
+template struct cpu_binary<op::fmod>;
+} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx