Commit f7838bc8 authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into ck-elementwise

parents fea58a7b d78bcdfb
......@@ -31,7 +31,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for ReverseSequence ONNX operator.
// Parser for ReverseSequence ONNX operator.
/*!
Reverses the data along the time axis for the batches along the batch axis.
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
......
......@@ -29,7 +29,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(module& m) const
{
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
if(not enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{
memory_coloring_impl opt(&m, allocation_op, verify);
opt.run();
......
......@@ -42,7 +42,7 @@ void memory_coloring_impl::run()
{
MIGRAPHX_DEBUG(dump_intervals());
// Coloring
while(!alloc_queue.empty())
while(not alloc_queue.empty())
{
interval_ptr interval = alloc_queue.top();
allocate(interval);
......@@ -96,7 +96,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
}
std::size_t offset = 0;
while(!conflict_queue.empty())
while(not conflict_queue.empty())
{
live_range* range = conflict_queue.top();
std::size_t iter_offset = range->offset;
......@@ -149,7 +149,7 @@ void memory_coloring_impl::build()
{
def_interval = instr2_live[p_iter];
bool is_lit = is_literal(iter);
if(is_allocate(iter) || is_lit)
if(is_allocate(iter) or is_lit)
{
live_range& range = def_interval->segment;
def_interval->result = iter->get_shape();
......@@ -157,12 +157,12 @@ void memory_coloring_impl::build()
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
if(!is_lit || unify_literals)
if(not is_lit or unify_literals)
alloc_queue.push(def_interval);
live_set.erase(range.vn);
}
}
else if(!is_param(iter) && !is_outline(iter) && !is_check_context(iter))
else if(not is_param(iter) && not is_outline(iter) && not is_check_context(iter))
{
is_dead = true;
}
......@@ -179,7 +179,7 @@ void memory_coloring_impl::build()
if(not p_mod->has_instruction(arg))
continue;
if(is_param(arg) || is_outline(arg))
if(is_param(arg) or is_outline(arg))
{
if(is_output_param(arg))
is_dead = false;
......@@ -235,7 +235,7 @@ void memory_coloring_impl::rewrite()
if(interval->get_begin() == invalid_offset)
continue;
if(!unify_literals && interval->is_literal)
if(not unify_literals && interval->is_literal)
continue;
std::size_t offset = 0;
......@@ -272,7 +272,7 @@ void memory_coloring_impl::verify()
if(segment.begin == invalid_offset)
{
// if(!interval.is_live_on_entry)
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue;
}
......@@ -290,7 +290,7 @@ void memory_coloring_impl::verify()
live_range* range = live_ranges[iter];
if(range->offset == invalid_offset)
continue;
if(!is_disjoin(*range, segment))
if(not is_disjoin(*range, segment))
MIGRAPHX_THROW("range and segment is not disjoined");
}
}
......
......@@ -125,11 +125,11 @@ struct memory_coloring_impl
static bool is_disjoin(const live_range& range1, const live_range& range2)
{
if((range1.size == 0) || (range2.size == 0))
if((range1.size == 0) or (range2.size == 0))
return false;
auto end1 = range1.offset + range1.size - 1;
auto end2 = range2.offset + range2.size - 1;
return ((end1 < range2.offset) || (end2 < range1.offset));
return ((end1 < range2.offset) or (end2 < range1.offset));
}
void verify();
#ifdef MIGRAPHX_DEBUG_OPT
......
......@@ -50,7 +50,7 @@ int exec(const std::string& cmd, const std::function<void(const char*)>& std_out
{
// TODO: Use execve instead of popen
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
if(!pipe)
if(not pipe)
MIGRAPHX_THROW("popen() failed: " + cmd);
std::array<char, 128> buffer;
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
......
......@@ -78,11 +78,11 @@ program& program::operator=(program p)
void program::assign(const program& p)
{
if(!impl)
if(not impl)
{
impl = std::make_unique<program_impl>();
}
else if(!impl->modules.empty())
else if(not impl->modules.empty())
{
impl->modules.clear();
}
......
......@@ -83,7 +83,7 @@ void visit_py(T x, F f)
{
f(x.template cast<bool>());
}
else if(py::isinstance<py::int_>(x) || py::hasattr(x, "__index__"))
else if(py::isinstance<py::int_>(x) or py::hasattr(x, "__index__"))
{
f(x.template cast<int>());
}
......
......@@ -70,7 +70,7 @@ void quantize_int8(program& prog,
{
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(!std::includes(
if(not std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct find_gelu_erf
{
auto matcher() const { return match::gelu_erf(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
if(x->get_shape().type() != migraphx::shape::half_type)
return;
auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}});
auto mul = insert_common_op(m, ins, make_op("mul"), {x, lit});
auto sig = m.insert_instruction(ins, make_op("neg"), mul);
sig = m.insert_instruction(ins, make_op("exp"), sig);
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
sig = insert_common_op(m, ins, make_op("add"), {sig, one});
sig = m.insert_instruction(ins, make_op("div"), x, sig);
m.replace_instruction(ins, sig);
}
};
void rewrite_gelu::apply(module& m) const { match::find_matches(m, find_gelu_erf{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -47,12 +47,12 @@ void rewrite_pooling::apply(module& m) const
if(not s.standard())
continue;
auto&& op = any_cast<op::pooling>(ins->get_operator());
if(!std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
continue;
if(!std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
if(not std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
continue;
auto lens = s.lens();
if(!std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
continue;
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
......
......@@ -214,7 +214,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
ih = m.add_literal(migraphx::literal{ih_shape, data});
}
if(!is_forward and variable_seq_len)
if(not is_forward and variable_seq_len)
{
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
......@@ -520,7 +520,7 @@ void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
ih = m.add_literal(migraphx::literal{ih_shape, data});
}
if(!is_forward and variable_seq_len)
if(not is_forward and variable_seq_len)
{
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
......@@ -977,7 +977,7 @@ void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
pph = args[7];
}
if(!is_forward and variable_seq_len)
if(not is_forward and variable_seq_len)
{
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
......@@ -1294,11 +1294,11 @@ bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens
std::vector<int64_t> vec_lens;
arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
int64_t l = 0;
if(!vec_lens.empty())
if(not vec_lens.empty())
{
l = vec_lens[0];
}
if(!std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
if(not std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
{
is_var_lens = true;
}
......@@ -1318,7 +1318,7 @@ rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref
bool is_var_lens = is_variable_seq_lens(m, seq_lens);
auto input_shape = input->get_shape();
auto length = input_shape.lens()[0];
if(!is_var_lens and seq_lens != m.end())
if(not is_var_lens and seq_lens != m.end())
{
auto arg_len = seq_lens->eval();
std::vector<std::size_t> vec_lens;
......@@ -1387,7 +1387,7 @@ void rewrite_rnn::replace_last_cell_output(module& m,
if(variable_seq_len)
{
if(!ins_outputs.empty())
if(not ins_outputs.empty())
{
cell_outputs = m.insert_instruction(
std::next(ins),
......
......@@ -477,7 +477,7 @@ bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
return !(x == y);
return not(x == y);
}
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
......@@ -497,7 +497,7 @@ bool operator==(const shape& x, const shape& y)
x.strides() == y.strides() and x.sub_shapes() == y.sub_shapes());
}
bool operator!=(const shape& x, const shape& y) { return !(x == y); }
bool operator!=(const shape& x, const shape& y) { return not(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x)
{
......
......@@ -208,6 +208,42 @@ struct find_mul_add
}
};
struct find_dot_add
{
auto matcher() const
{
return match::name("dot")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(match::any().bind("x"),
match::any_of(match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
const bool flipped = a_ins == ins->inputs().back();
auto insert_dot = [&](auto x, auto y) {
if(flipped)
return m.insert_instruction(ins, make_op("dot"), y, x);
else
return m.insert_instruction(ins, make_op("dot"), x, y);
};
auto ax_ins = insert_dot(a_ins, x_ins);
auto ab_ins = insert_dot(a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
struct find_add_lit_broadcast
{
auto matcher() const
......@@ -267,28 +303,26 @@ struct find_double_add_lit_broadcast
struct find_inner_broadcast
{
auto matcher() const
{
return pointwise(
match::nargs(2),
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
}
auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
if(xbroadcast.axis != ybroadcast.axis)
auto ins = r.result;
auto broadcasts = ins->inputs();
if(broadcasts.empty())
return;
std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(),
broadcasts.end(),
std::back_inserter(inputs),
[](auto i) { return i->inputs().front(); });
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape() != inputs.front()->get_shape();
}))
return;
auto op = m.insert_instruction(
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
m.replace_instruction(ins, xbroadcast, op);
auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
};
......@@ -416,8 +450,9 @@ struct find_splits
{
auto matcher() const
{
return match::any(match::any_of[match::outputs()](match::name("slice")(
match::any_of[match::outputs()](match::pointwise(), reduction()))));
return match::any(
match::any_of[match::outputs()](match::name("slice")(match::any_of[match::outputs()](
match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction()))));
}
static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
......@@ -580,10 +615,9 @@ struct find_splits
auto outputs = i->outputs();
for(auto output : outputs)
{
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
if(output->name() != "reshape")
continue;
auto x =
m.insert_instruction(output, make_op("contiguous"), output->inputs());
auto x = m.insert_instruction(output, make_op("contiguous"), i);
m.replace_instruction(output, output->get_operator(), x);
}
......@@ -753,7 +787,7 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return !(dots < 2 and convs < 2);
return not(dots < 2 and convs < 2);
}
struct find_conv_dot_horiz_fusion
......@@ -773,7 +807,7 @@ struct find_conv_dot_horiz_fusion
auto y = j->inputs()[1]->get_shape().lens();
if(x.size() != y.size())
return false;
// Check that non-axises match
// Check that non-axes match
int axis = 1;
if(i->name() == "dot")
{
......@@ -809,13 +843,22 @@ struct find_conv_dot_horiz_fusion
for(auto arg : args)
m.move_instructions(arg, input);
// TODO: Check if axises match
// TODO: Check if axes match
auto concat =
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = m.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0;
for(auto arg : range(start, last))
{
auto outputs = arg->outputs();
for(auto output : outputs)
{
if(output->name() != "reshape")
continue;
auto x = m.insert_instruction(output, make_op("contiguous"), arg);
m.replace_instruction(output, output->get_operator(), x);
}
int64_t len = arg->get_shape().lens()[axis];
m.replace_instruction(
arg,
......@@ -926,7 +969,7 @@ struct find_split_reshape
// all outputs are reshape and of the same shape
auto dims = any_cast<op::reshape>(rsp->get_operator()).dims;
if(!same_ops(vec_rsp))
if(not same_ops(vec_rsp))
{
return;
}
......@@ -958,7 +1001,11 @@ struct find_split_reshape
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction
// insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard())
{
input = m.insert_instruction(std::next(input), make_op("contiguous"), input);
}
auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
......@@ -1005,7 +1052,7 @@ struct find_split_transpose
// all transpose are the same
auto perm = any_cast<op::transpose>(trans->get_operator()).dims;
if(!same_ops(vec_trans))
if(not same_ops(vec_trans))
{
return;
}
......@@ -1048,6 +1095,7 @@ void simplify_algebra::apply(module& m) const
find_mul_conv{},
find_mul_slice_conv{},
find_mul_add{},
find_dot_add{},
find_div_const{},
find_sub_const{},
find_rsqrt{},
......
......@@ -99,7 +99,7 @@ struct find_reshaper
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
assert(!reshapes.back()->inputs().empty());
assert(not reshapes.back()->inputs().empty());
assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
......@@ -288,7 +288,7 @@ struct find_concat_transpose
auto permutation = find_permutation(s);
// permutation should be the same for all inputs
if(!std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
if(not std::all_of(trans_inputs.begin(), trans_inputs.end(), [&](auto in) {
return (find_permutation(in->get_shape()) == permutation);
}))
{
......
......@@ -49,7 +49,7 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
auto s0 = inputs.at(0);
auto s1 = inputs.at(1);
auto r = s0;
if(s0 != s1 or !s0.packed())
if(s0 != s1 or not s0.packed())
{
r = shape{s0.type(), s0.lens()};
}
......
......@@ -95,7 +95,7 @@ void subgraph::apply(module_pass_manager& mpm) const
for(auto it : iterator_for(mod))
{
// assuming we want all the params/literals as inputs to the FPGA submodule
if(migraphx::starts_with(it->name(), "@param") ||
if(migraphx::starts_with(it->name(), "@param") or
migraphx::starts_with(it->name(), "@literal"))
{
literal_inputs.push_back(it);
......
......@@ -51,7 +51,8 @@ code_object_op::compute(context& ctx, const shape&, const std::vector<argument>&
std::vector<void*> kargs(args.size());
std::transform(
args.begin(), args.end(), kargs.begin(), [](const argument& a) { return a.data(); });
k.launch(ctx.get_stream().get(), global, local, std::move(kargs));
auto [start, stop] = ctx.get_perf_events();
k.launch(ctx.get_stream().get(), global, local, std::move(kargs), start, stop);
return args[get_output_arg(args.size())];
}
void code_object_op::finalize(context&, const shape&, const std::vector<shape>&)
......
......@@ -61,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
[&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis];
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
if(not contains({0, 1}, stride))
return 1;
if(len == 1 and input.elements() > sizes.front())
return sizes.front();
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
auto it = std::find_if(sizes.begin(), sizes.end(), [&](auto vsize) {
// The len is divisible by the size and all the strides are divisible by
// the size
return (len % vsize) == 0 and
std::all_of(
input.strides().begin(), input.strides().end(), [&](auto i) {
return contains({0, 1}, i) or i % vsize == 0;
});
});
if(it != sizes.end())
return *it;
return 1;
......
......@@ -131,7 +131,7 @@ struct hip_array
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator!=(const hip_array& x, const hip_array& y)
{
return !(x == y);
return not(x == y);
}
// This uses the product order rather than lexical order
friend MIGRAPHX_DEVICE_CONSTEXPR bool operator<(const hip_array& x, const hip_array& y)
......
......@@ -117,12 +117,13 @@ template <class V, class F, class... Ts>
void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
std::initializer_list<migraphx::shape::type_t> types = {get_shape(xs).type()...};
if(!std::all_of(
if(not std::all_of(
types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
std::initializer_list<index_int> ranks = {
static_cast<index_int>(get_shape(xs).lens().size())...};
if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
if(not std::all_of(
ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) {
s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
......@@ -134,7 +135,8 @@ void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
std::initializer_list<index_int> ranks = {
static_cast<index_int>(get_shape(xs).lens().size())...};
if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
if(not std::all_of(
ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment