"vscode:/vscode.git/clone" did not exist on "0fd3a3e860a56490a52b330ac49de4f7a8615e78"
Commit 9d2cdf25 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into develop

parents 66d50268 ed2c73ac
@@ -29,7 +29,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 void memory_coloring::apply(module& m) const
 {
-    if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
+    if(not enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
     {
         memory_coloring_impl opt(&m, allocation_op, verify);
         opt.run();

@@ -42,7 +42,7 @@ void memory_coloring_impl::run()
 {
     MIGRAPHX_DEBUG(dump_intervals());
     // Coloring
-    while(!alloc_queue.empty())
+    while(not alloc_queue.empty())
     {
         interval_ptr interval = alloc_queue.top();
         allocate(interval);

@@ -96,7 +96,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
     }
     std::size_t offset = 0;
-    while(!conflict_queue.empty())
+    while(not conflict_queue.empty())
     {
         live_range* range = conflict_queue.top();
         std::size_t iter_offset = range->offset;

@@ -149,7 +149,7 @@ void memory_coloring_impl::build()
         {
             def_interval = instr2_live[p_iter];
             bool is_lit = is_literal(iter);
-            if(is_allocate(iter) || is_lit)
+            if(is_allocate(iter) or is_lit)
             {
                 live_range& range = def_interval->segment;
                 def_interval->result = iter->get_shape();

@@ -157,12 +157,12 @@ void memory_coloring_impl::build()
                 range.begin = cur_points;
                 def_interval->def_point = cur_points;
                 range.size = (iter->get_shape()).bytes();
-                if(!is_lit || unify_literals)
+                if(not is_lit or unify_literals)
                     alloc_queue.push(def_interval);
                 live_set.erase(range.vn);
             }
         }
-        else if(!is_param(iter) && !is_outline(iter) && !is_check_context(iter))
+        else if(not is_param(iter) && not is_outline(iter) && not is_check_context(iter))
        {
            is_dead = true;
        }

@@ -179,7 +179,7 @@ void memory_coloring_impl::build()
            if(not p_mod->has_instruction(arg))
                continue;
-            if(is_param(arg) || is_outline(arg))
+            if(is_param(arg) or is_outline(arg))
            {
                if(is_output_param(arg))
                    is_dead = false;

@@ -235,7 +235,7 @@ void memory_coloring_impl::rewrite()
        if(interval->get_begin() == invalid_offset)
            continue;
-        if(!unify_literals && interval->is_literal)
+        if(not unify_literals && interval->is_literal)
            continue;
        std::size_t offset = 0;

@@ -272,7 +272,7 @@ void memory_coloring_impl::verify()
        if(segment.begin == invalid_offset)
        {
-            // if(!interval.is_live_on_entry)
+            // if(not interval.is_live_on_entry)
            //     MIGRAPHX_THROW("interval is not live on entry");
            continue;
        }

@@ -290,7 +290,7 @@ void memory_coloring_impl::verify()
            live_range* range = live_ranges[iter];
            if(range->offset == invalid_offset)
                continue;
-            if(!is_disjoin(*range, segment))
+            if(not is_disjoin(*range, segment))
                MIGRAPHX_THROW("range and segment is not disjoined");
        }
    }

@@ -125,11 +125,11 @@ struct memory_coloring_impl
     static bool is_disjoin(const live_range& range1, const live_range& range2)
     {
-        if((range1.size == 0) || (range2.size == 0))
+        if((range1.size == 0) or (range2.size == 0))
             return false;
         auto end1 = range1.offset + range1.size - 1;
         auto end2 = range2.offset + range2.size - 1;
-        return ((end1 < range2.offset) || (end2 < range1.offset));
+        return ((end1 < range2.offset) or (end2 < range1.offset));
     }

     void verify();
 #ifdef MIGRAPHX_DEBUG_OPT

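As a side note on the logic above: live ranges use inclusive end points, so two ranges are disjoint exactly when one ends before the other begins. A minimal standalone sketch (not MIGraphX code; the field names mirror the diff):

```cpp
// Standalone check of the disjointness test above; offset and size are in bytes.
#include <cassert>
#include <cstddef>

struct live_range { std::size_t offset; std::size_t size; };

static bool is_disjoin(const live_range& r1, const live_range& r2)
{
    if((r1.size == 0) or (r2.size == 0))
        return false;
    auto end1 = r1.offset + r1.size - 1; // inclusive last byte of r1
    auto end2 = r2.offset + r2.size - 1; // inclusive last byte of r2
    return ((end1 < r2.offset) or (end2 < r1.offset));
}

int main()
{
    assert(is_disjoin({0, 256}, {256, 128}));     // bytes 0..255 vs 256..383
    assert(not is_disjoin({0, 256}, {255, 128})); // byte 255 is shared
}
```
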
@@ -60,7 +60,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
 {
     std::vector<std::size_t> padding;
     padding.resize(2 * k_lens.size());
-    for(size_t i = 0; i < padding.size() / 2; i++)
+    for(std::size_t i = 0; i < padding.size() / 2; i++)
     {
         std::ptrdiff_t input_dim = tensor_lens[i];
         std::ptrdiff_t stride = strides[i];

@@ -50,7 +50,7 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
 {
     // TODO: Use execve instead of popen
     std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
-    if(!pipe)
+    if(not pipe)
         MIGRAPHX_THROW("popen() failed: " + cmd);
     std::array<char, 128> buffer;
     while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)

@@ -78,11 +78,11 @@ program& program::operator=(program p)
 void program::assign(const program& p)
 {
-    if(!impl)
+    if(not impl)
     {
         impl = std::make_unique<program_impl>();
     }
-    else if(!impl->modules.empty())
+    else if(not impl->modules.empty())
     {
         impl->modules.clear();
     }

@@ -83,7 +83,7 @@ void visit_py(T x, F f)
     {
         f(x.template cast<bool>());
     }
-    else if(py::isinstance<py::int_>(x) || py::hasattr(x, "__index__"))
+    else if(py::isinstance<py::int_>(x) or py::hasattr(x, "__index__"))
     {
         f(x.template cast<int>());
     }

@@ -70,7 +70,7 @@ void quantize_int8(program& prog,
 {
     std::set<std::string> op_names = {"convolution", "dot"};
     std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
-    if(!std::includes(
+    if(not std::includes(
            op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
     {
         MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");

/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct find_gelu_erf
{
auto matcher() const { return match::gelu_erf(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
if(x->get_shape().type() != migraphx::shape::half_type)
return;
auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}});
auto mul = insert_common_op(m, ins, make_op("mul"), {x, lit});
auto sig = m.insert_instruction(ins, make_op("neg"), mul);
sig = m.insert_instruction(ins, make_op("exp"), sig);
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
sig = insert_common_op(m, ins, make_op("add"), {sig, one});
sig = m.insert_instruction(ins, make_op("div"), x, sig);
m.replace_instruction(ins, sig);
}
};
void rewrite_gelu::apply(module& m) const { match::find_matches(m, find_gelu_erf{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
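
For reference (not part of the commit): the pass above replaces the erf-based GELU, 0.5·x·(1 + erf(x/√2)), with the cheaper sigmoid form x / (1 + e^(-1.702·x)) for half-precision inputs, built from the mul, neg, exp, add, and div instructions shown. A minimal standalone comparison of the two formulas:

```cpp
// Numeric sketch (plain C++, not MIGraphX code) showing that the erf-based GELU
// and the sigmoid approximation x / (1 + exp(-1.702 * x)) roughly agree.
#include <cmath>
#include <cstdio>

static double gelu_erf(double x) { return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0))); }
static double gelu_sigmoid(double x) { return x / (1.0 + std::exp(-1.702 * x)); }

int main()
{
    for(double x : {-3.0, -1.0, 0.0, 0.5, 2.0})
        std::printf("x=%5.2f  erf=%.5f  sigmoid=%.5f\n", x, gelu_erf(x), gelu_sigmoid(x));
}
```
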
@@ -47,12 +47,12 @@ void rewrite_pooling::apply(module& m) const
         if(not s.standard())
             continue;
         auto&& op = any_cast<op::pooling>(ins->get_operator());
-        if(!std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
+        if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
             continue;
-        if(!std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
+        if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
             continue;
         auto lens = s.lens();
-        if(!std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
+        if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
             continue;
         std::int64_t n = s.lens()[0];
         std::int64_t c = s.lens()[1];

@@ -214,7 +214,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
         ih = m.add_literal(migraphx::literal{ih_shape, data});
     }
-    if(!is_forward and variable_seq_len)
+    if(not is_forward and variable_seq_len)
     {
         args[0] =
             m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);

@@ -520,7 +520,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
         ih = m.add_literal(migraphx::literal{ih_shape, data});
     }
-    if(!is_forward and variable_seq_len)
+    if(not is_forward and variable_seq_len)
     {
         args[0] =
             m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);

@@ -977,7 +977,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
         pph = args[7];
     }
-    if(!is_forward and variable_seq_len)
+    if(not is_forward and variable_seq_len)
     {
         args[0] =
             m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);

@@ -1294,11 +1294,11 @@ bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens
     std::vector<int64_t> vec_lens;
     arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
     int64_t l = 0;
-    if(!vec_lens.empty())
+    if(not vec_lens.empty())
     {
         l = vec_lens[0];
     }
-    if(!std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
+    if(not std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
     {
         is_var_lens = true;
     }

@@ -1318,7 +1318,7 @@ rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref
     bool is_var_lens = is_variable_seq_lens(m, seq_lens);
     auto input_shape = input->get_shape();
     auto length = input_shape.lens()[0];
-    if(!is_var_lens and seq_lens != m.end())
+    if(not is_var_lens and seq_lens != m.end())
     {
         auto arg_len = seq_lens->eval();
         std::vector<std::size_t> vec_lens;

@@ -1387,7 +1387,7 @@ void rewrite_rnn::replace_last_cell_output(module& m,
     if(variable_seq_len)
     {
-        if(!ins_outputs.empty())
+        if(not ins_outputs.empty())
        {
            cell_outputs = m.insert_instruction(
                std::next(ins),

@@ -477,7 +477,7 @@ bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
 bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
 {
-    return !(x == y);
+    return not(x == y);
 }

 std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
 {

@@ -497,7 +497,7 @@ bool operator==(const shape& x, const shape& y)
            x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
 }
-bool operator!=(const shape& x, const shape& y) { return !(x == y); }
+bool operator!=(const shape& x, const shape& y) { return not(x == y); }

 std::ostream& operator<<(std::ostream& os, const shape& x)
 {

@@ -208,6 +208,42 @@ struct find_mul_add
     }
 };

+struct find_dot_add
+{
+    auto matcher() const
+    {
+        return match::name("dot")(match::either_arg(0, 1)(
+            match::name("add")(
+                match::either_arg(0, 1)(match::any().bind("x"),
+                                        match::any_of(match::is_constant()).bind("b")),
+                match::none_of(match::args(match::is_constant(), match::is_constant())),
+                match::used_once()),
+            match::is_constant().bind("a")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+        auto x_ins = r.instructions["x"];
+        assert(x_ins != b_ins);
+
+        const bool flipped = a_ins == ins->inputs().back();
+
+        auto insert_dot = [&](auto x, auto y) {
+            if(flipped)
+                return m.insert_instruction(ins, make_op("dot"), y, x);
+            else
+                return m.insert_instruction(ins, make_op("dot"), x, y);
+        };
+        auto ax_ins = insert_dot(a_ins, x_ins);
+        auto ab_ins = insert_dot(a_ins, b_ins);
+        m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
+    }
+};
+
 struct find_add_lit_broadcast
 {
     auto matcher() const
@@ -267,28 +303,26 @@ struct find_double_add_lit_broadcast
 struct find_inner_broadcast
 {
-    auto matcher() const
-    {
-        return pointwise(
-            match::nargs(2),
-            match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
-    }
+    auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }

     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
-        auto x_ins = r.instructions["x"];
-        auto y_ins = r.instructions["y"];
-
-        auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
-        auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
-
-        if(xbroadcast.axis != ybroadcast.axis)
+        auto broadcasts = ins->inputs();
+        if(broadcasts.empty())
+            return;
+        std::vector<instruction_ref> inputs;
+        std::transform(broadcasts.begin(),
+                       broadcasts.end(),
+                       std::back_inserter(inputs),
+                       [](auto i) { return i->inputs().front(); });
+        if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
+               return i->get_shape() != inputs.front()->get_shape();
+           }))
             return;
-        auto op = m.insert_instruction(
-            ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
-        m.replace_instruction(ins, xbroadcast, op);
+        auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
+        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
     }
 };

@@ -416,8 +450,9 @@ struct find_splits
 {
     auto matcher() const
     {
-        return match::any(match::any_of[match::outputs()](match::name("slice")(
-            match::any_of[match::outputs()](match::pointwise(), reduction()))));
+        return match::any(
+            match::any_of[match::outputs()](match::name("slice")(match::any_of[match::outputs()](
+                match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction()))));
     }

     static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
@@ -580,10 +615,9 @@ struct find_splits
             auto outputs = i->outputs();
             for(auto output : outputs)
             {
-                if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
+                if(output->name() != "reshape")
                     continue;
-                auto x =
-                    m.insert_instruction(output, make_op("contiguous"), output->inputs());
+                auto x = m.insert_instruction(output, make_op("contiguous"), i);
                 m.replace_instruction(output, output->get_operator(), x);
             }

@@ -753,7 +787,7 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
     };
     auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
     auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
-    return !(dots < 2 and convs < 2);
+    return not(dots < 2 and convs < 2);
 }

 struct find_conv_dot_horiz_fusion
@@ -773,7 +807,7 @@ struct find_conv_dot_horiz_fusion
         auto y = j->inputs()[1]->get_shape().lens();
         if(x.size() != y.size())
             return false;
-        // Check that non-axises match
+        // Check that non-axes match
         int axis = 1;
         if(i->name() == "dot")
         {
@@ -809,13 +843,22 @@ struct find_conv_dot_horiz_fusion
         for(auto arg : args)
             m.move_instructions(arg, input);
-        // TODO: Check if axises match
+        // TODO: Check if axes match
         auto concat =
             m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
         auto fused = m.insert_instruction(std::next(input), op, input, concat);
         int64_t offset = 0;
         for(auto arg : range(start, last))
         {
+            auto outputs = arg->outputs();
+            for(auto output : outputs)
+            {
+                if(output->name() != "reshape")
+                    continue;
+                auto x = m.insert_instruction(output, make_op("contiguous"), arg);
+                m.replace_instruction(output, output->get_operator(), x);
+            }
+
             int64_t len = arg->get_shape().lens()[axis];
             m.replace_instruction(
                 arg,
@@ -926,7 +969,7 @@ struct find_split_reshape
         // all outputs are reshape and of the same shape
         auto dims = any_cast<op::reshape>(rsp->get_operator()).dims;
-        if(!same_ops(vec_rsp))
+        if(not same_ops(vec_rsp))
         {
             return;
         }
@@ -958,7 +1001,11 @@ struct find_split_reshape
         std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
         rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});

-        // insert the reshape instruction
+        // insert the reshape instruction and add contiguous if needed
+        if(not input->get_shape().standard())
+        {
+            input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
+        }
         auto rsp_ins = m.insert_instruction(
             std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
@@ -1005,7 +1052,7 @@ struct find_split_transpose
         // all transpose are the same
         auto perm = any_cast<op::transpose>(trans->get_operator()).dims;
-        if(!same_ops(vec_trans))
+        if(not same_ops(vec_trans))
         {
             return;
         }
@@ -1048,6 +1095,7 @@ void simplify_algebra::apply(module& m) const
             find_mul_conv{},
             find_mul_slice_conv{},
             find_mul_add{},
+            find_dot_add{},
             find_div_const{},
             find_sub_const{},
             find_rsqrt{},

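The new find_dot_add pass distributes a dot over a constant-involving add: dot(add(x, b), a) becomes add(dot(x, a), dot(b, a)), with argument order preserved via the flipped check, so the dot between the two constants can later be folded. A minimal numeric check of the underlying identity (plain C++, not MIGraphX code):

```cpp
// Verifies (X + B)A == XA + BA for small 2x2 matrices, the algebraic identity
// that makes find_dot_add a safe rewrite.
#include <array>
#include <cassert>
#include <cstddef>

using mat2 = std::array<std::array<double, 2>, 2>;

static mat2 dot(const mat2& p, const mat2& q)
{
    mat2 r{};
    for(std::size_t i = 0; i < 2; i++)
        for(std::size_t j = 0; j < 2; j++)
            for(std::size_t k = 0; k < 2; k++)
                r[i][j] += p[i][k] * q[k][j];
    return r;
}

static mat2 add(const mat2& p, const mat2& q)
{
    mat2 r{};
    for(std::size_t i = 0; i < 2; i++)
        for(std::size_t j = 0; j < 2; j++)
            r[i][j] = p[i][j] + q[i][j];
    return r;
}

int main()
{
    mat2 x = {{{1, 2}, {3, 4}}};
    mat2 b = {{{5, 6}, {7, 8}}};
    mat2 a = {{{9, 1}, {2, 3}}};
    assert(dot(add(x, b), a) == add(dot(x, a), dot(b, a)));
}
```
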
@@ -99,7 +99,7 @@ struct find_reshaper
         std::vector<instruction_ref> reshapes{ins};
         while(is_reshaper(reshapes.back()))
         {
-            assert(!reshapes.back()->inputs().empty());
+            assert(not reshapes.back()->inputs().empty());
             assert(m.has_instruction(reshapes.back()->inputs().front()));
             auto input = reshapes.back()->inputs().front();
             reshapes.push_back(input);

@@ -288,7 +288,7 @@ struct find_concat_transpose
         auto permutation = find_permutation(s);
         // permutation should be the same for all inputs
-        if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
+        if(not std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
               return (find_permutation(in->get_shape()) == permutation);
           }))
        {

@@ -49,7 +49,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
         auto s0 = inputs.at(0);
         auto s1 = inputs.at(1);
         auto r = s0;
-        if(s0 != s1 or !s0.packed())
+        if(s0 != s1 or not s0.packed())
         {
             r = shape{s0.type(), s0.lens()};
         }

@@ -95,7 +95,7 @@ void subgraph::apply(module_pass_manager& mpm) const
     for(auto it : iterator_for(mod))
     {
         // assuming we want all the params/literals as inputs to the FPGA submodule
-        if(migraphx::starts_with(it->name(), "@param") ||
+        if(migraphx::starts_with(it->name(), "@param") or
            migraphx::starts_with(it->name(), "@literal"))
         {
             literal_inputs.push_back(it);

@@ -51,7 +51,8 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>&
     std::vector<void*> kargs(args.size());
     std::transform(
         args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); });
-    k.launch(ctx.get_stream().get(), global, local, std::move(kargs));
+    auto [start, stop] = ctx.get_perf_events();
+    k.launch(ctx.get_stream().get(), global, local, std::move(kargs), start, stop);
     return args[get_output_arg(args.size())];
 }
 void code_object_op::finalize(context&, const shape&, const std::vector<shape>&)

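The new start/stop arguments thread timing events through the kernel launch. For background, HIP times GPU work by recording an event pair around it; a generic sketch (assuming ctx.get_perf_events() hands back such a hipEvent_t pair, which this diff implies but does not show):

```cpp
// Generic HIP event timing, not MIGraphX code.
#include <hip/hip_runtime.h>
#include <cstdio>

int main()
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);
    hipEventRecord(start, nullptr); // record before enqueuing work on the stream
    // ... enqueue kernels here ...
    hipEventRecord(stop, nullptr);  // record after the work
    hipEventSynchronize(stop);      // wait until the stop event has completed
    float ms = 0;
    hipEventElapsedTime(&ms, start, stop);
    std::printf("elapsed: %f ms\n", ms);
    hipEventDestroy(start);
    hipEventDestroy(stop);
}
```
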
@@ -61,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
         [&](const auto& input) -> std::size_t {
             auto stride = input.strides()[axis];
             auto len = input.lens()[axis];
-            if(stride != 0 and stride != 1)
+            if(not contains({0, 1}, stride))
                 return 1;
             if(len == 1 and input.elements() > sizes.front())
                 return sizes.front();
-            auto it = std::find_if(
-                sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
+            auto it = std::find_if(sizes.begin(), sizes.end(), [&](auto vsize) {
+                // The len is divisible by the size and all the strides are divisible by
+                // the size
+                return (len % vsize) == 0 and
+                       std::all_of(
+                           input.strides().begin(), input.strides().end(), [&](auto i) {
+                               return contains({0, 1}, i) or i % vsize == 0;
+                           });
+            });
             if(it != sizes.end())
                 return *it;
             return 1;

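The reworked lambda accepts a vector size only if it divides both the axis length and every stride that is not 0 or 1, so a vectorized access never straddles the element layout. A standalone sketch of the selection rule (candidate sizes {4, 2} are assumed for illustration; not MIGraphX code):

```cpp
// Picks the largest candidate size dividing the axis length and all strides.
#include <algorithm>
#include <cstdio>
#include <initializer_list>
#include <vector>

static bool contains(std::initializer_list<std::size_t> l, std::size_t x)
{
    return std::find(l.begin(), l.end(), x) != l.end();
}

static std::size_t pick_vector_size(std::size_t len, const std::vector<std::size_t>& strides)
{
    const std::vector<std::size_t> sizes = {4, 2}; // assumed candidates
    auto it = std::find_if(sizes.begin(), sizes.end(), [&](auto vsize) {
        return (len % vsize) == 0 and
               std::all_of(strides.begin(), strides.end(), [&](auto i) {
                   return contains({0, 1}, i) or i % vsize == 0;
               });
    });
    return it != sizes.end() ? *it : 1;
}

int main()
{
    std::printf("%zu\n", pick_vector_size(8, {0, 1, 16})); // 4: 8 and 16 divisible by 4
    std::printf("%zu\n", pick_vector_size(8, {0, 1, 6}));  // 2: stride 6 rules out 4
}
```
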
@@ -131,7 +131,7 @@ struct hip_array
     friend MIGRAPHX_DEVICE_CONSTEXPR bool operator!=(const hip_array& x, const hip_array& y)
     {
-        return !(x == y);
+        return not(x == y);
     }

     // This uses the product order rather than lexical order
     friend MIGRAPHX_DEVICE_CONSTEXPR bool operator<(const hip_array& x, const hip_array& y)

@@ -117,12 +117,13 @@ template <class V, class F, class... Ts>
 void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
     std::initializer_list<migraphx::shape::type_t> types = {get_shape(xs).type()...};
-    if(!std::all_of(
+    if(not std::all_of(
            types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
         MIGRAPHX_THROW("Types must be the same");
     std::initializer_list<index_int> ranks = {
         static_cast<index_int>(get_shape(xs).lens().size())...};
-    if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+    if(not std::all_of(
+           ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
         MIGRAPHX_THROW("Ranks must be the same");
     visit_tensor_size(s.lens().size(), [&](auto ndim) {
         s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));

@@ -134,7 +135,8 @@ void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
     std::initializer_list<index_int> ranks = {
         static_cast<index_int>(get_shape(xs).lens().size())...};
-    if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+    if(not std::all_of(
+           ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
         MIGRAPHX_THROW("Ranks must be the same");
     visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
 }