Unverified Commit 671f24be authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Prefuse layernorm for gpu (#1190)

Fuse layernorm and added triadd_layernorm fusion.  This is a prep performance booster
parent 4ec8209f
...@@ -9,26 +9,6 @@ ...@@ -9,26 +9,6 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Range, class Iterator>
std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
{
auto start_forward = start;
auto start_backwards = start;
std::size_t n = 0;
while(start_forward != last and start_backwards != last)
{
n++;
if(start_forward != r.end())
start_forward++;
if(start_backwards != r.begin())
start_backwards--;
}
if(start_forward == last)
return n;
else
return -n;
}
void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); } void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); }
void dead_code_elimination::apply(module& m) const void dead_code_elimination::apply(module& m) const
...@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const ...@@ -48,17 +28,21 @@ void dead_code_elimination::apply(module& m) const
if(i->get_shape().elements() == 0 and i->name().front() != '@' and if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity") i->name() != "undefined" and i->name() != "identity")
continue; continue;
assert(bidistance(m, i, last) > 0); assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
fix([&](auto self, auto leaf) { fix([&](auto self, auto leaf) {
if(not m.has_instruction(leaf)) if(not m.has_instruction(leaf))
return; return;
if(leaf->outputs().empty()) if(leaf->outputs().empty())
{ {
// Dont visit inputs twice
if(not visited.insert(leaf).second)
return;
std::unordered_set<instruction_ref> args(leaf->inputs().begin(), std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end()); leaf->inputs().end());
leaf->clear_arguments(); leaf->clear_arguments();
assert(bidistance(m, last, leaf) < 0); assert(std::distance(m.begin(), leaf) < std::distance(m.begin(), last));
assert(leaf != ins); assert(leaf != ins);
if(leaf->name() != "@param") if(leaf->name() != "@param")
m.move_instruction(leaf, m.end()); m.move_instruction(leaf, m.end());
......
...@@ -9,7 +9,19 @@ namespace migraphx { ...@@ -9,7 +9,19 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name); operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v); operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template <class Value>
operation make_op(const std::string& name, const Value& v)
{
return make_op_from_value(name, v);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -156,6 +156,19 @@ struct id_matcher ...@@ -156,6 +156,19 @@ struct id_matcher
} }
}; };
// Forward declare class and constructors
template <class M>
struct basic_matcher;
template <class M>
basic_matcher<M> make_basic_matcher(M m);
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher /// The basic matcher provides the all_of composability of the matcher
template <class M> template <class M>
struct basic_matcher struct basic_matcher
...@@ -167,7 +180,7 @@ struct basic_matcher ...@@ -167,7 +180,7 @@ struct basic_matcher
{ {
// Copy m because we cant capture `this` by value // Copy m because we cant capture `this` by value
auto mm = m; auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> { instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins); auto result = mm.match(ctx, ins);
if(result) if(result)
...@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base ...@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct matcher_result struct matcher_result
{ {
std::unordered_map<std::string, instruction_ref> instructions; struct instruction_container
{
instruction_container() = default;
instruction_container(std::unordered_map<std::string, instruction_ref> x)
: ins_map(std::move(x))
{
}
instruction_ref operator[](const std::string& name) const
{
auto it = ins_map.find(name);
if(it == ins_map.end())
MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
return it->second;
}
auto find(const std::string& name) const { return ins_map.find(name); }
auto begin() const { return ins_map.cbegin(); }
auto end() const { return ins_map.cend(); }
bool has_instructions_in(const module& mod) const
{
return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
return mod.has_instruction(p.second);
});
}
private:
std::unordered_map<std::string, instruction_ref> ins_map;
};
instruction_container instructions;
instruction_ref result; instruction_ref result;
}; };
...@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m) ...@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{ {
result.result = ins; result.result = ins;
result.instructions = ctx.instructions; result.instructions = ctx.instructions;
assert(result.instructions.has_instructions_in(mod));
} }
else else
{ {
...@@ -533,6 +579,18 @@ auto skip_output(Ms... ms) ...@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
}); });
} }
inline auto var(std::string s)
{
return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s);
if(it == ctx.instructions.end())
return nullopt;
return it->second;
});
}
inline auto name(std::string s) inline auto name(std::string s)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher(
......
...@@ -5,20 +5,41 @@ namespace migraphx { ...@@ -5,20 +5,41 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); } operation make_op(const std::string& name) { return load_op(name); }
operation make_op(const std::string& name, const value& v)
template <class F>
operation make_op_generic(const std::string& name, F for_each)
{ {
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object");
auto op = load_op(name); auto op = load_op(name);
// Merge values // Merge values
value w = op.to_value(); value w = op.to_value();
for(auto&& x : v) for_each([&](const auto& key, const auto& x) {
{ if(not w.contains(key))
w.at(x.get_key()) = x.without_key(); // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
} MIGRAPHX_THROW("No key '" + key + "' in " + name);
w.at(key) = x;
});
op.from_value(w); op.from_value(w);
return op; return op;
} }
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v)
{
return make_op_generic(name, [&](auto f) {
for(auto&& [key, x] : v)
f(key, x);
});
}
operation make_op_from_value(const std::string& name, const value& v)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object for make_op: " + name);
return make_op_generic(name, [&](auto f) {
for(auto&& x : v)
f(x.get_key(), x.without_key());
});
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -42,7 +42,7 @@ struct find_mul_conv ...@@ -42,7 +42,7 @@ struct find_mul_conv
match::name("broadcast").bind("a"))); match::name("broadcast").bind("a")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
...@@ -80,7 +80,7 @@ struct find_mul_slice_conv ...@@ -80,7 +80,7 @@ struct find_mul_slice_conv
match::name("broadcast")(match::is_constant()).bind("a"))); match::name("broadcast")(match::is_constant()).bind("a")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto slice_ins = r.instructions["slice"]; auto slice_ins = r.instructions["slice"];
...@@ -171,7 +171,7 @@ struct find_mul_add ...@@ -171,7 +171,7 @@ struct find_mul_add
match::is_constant().bind("a"))); match::is_constant().bind("a")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
...@@ -193,7 +193,7 @@ struct find_add_lit_broadcast ...@@ -193,7 +193,7 @@ struct find_add_lit_broadcast
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b"))); match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -213,7 +213,7 @@ struct find_double_add_lit_broadcast ...@@ -213,7 +213,7 @@ struct find_double_add_lit_broadcast
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y"))); match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -251,7 +251,7 @@ struct find_inner_broadcast ...@@ -251,7 +251,7 @@ struct find_inner_broadcast
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y"))); match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -664,7 +664,7 @@ struct find_add_convs ...@@ -664,7 +664,7 @@ struct find_add_convs
return x.stride[0] / y.stride[0]; return x.stride[0] / y.stride[0];
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto a_conv = r.instructions["a"]; auto a_conv = r.instructions["a"];
...@@ -815,7 +815,7 @@ struct find_div_const ...@@ -815,7 +815,7 @@ struct find_div_const
return match::name("div")(match::arg(1)(match::is_constant().bind("c"))); return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
...@@ -835,7 +835,7 @@ struct find_sub_const ...@@ -835,7 +835,7 @@ struct find_sub_const
return match::name("sub")(match::arg(1)(match::is_constant().bind("c"))); return match::name("sub")(match::arg(1)(match::is_constant().bind("c")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
...@@ -856,7 +856,7 @@ struct find_rsqrt ...@@ -856,7 +856,7 @@ struct find_rsqrt
match::name("sqrt")(match::used_once(), match::args(match::any().bind("x"))))); match::name("sqrt")(match::used_once(), match::args(match::any().bind("x")))));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -881,7 +881,7 @@ struct find_split_reshape ...@@ -881,7 +881,7 @@ struct find_split_reshape
.bind("reshape"); .bind("reshape");
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"]; auto rsp = r.instructions["reshape"];
...@@ -962,7 +962,7 @@ struct find_split_transpose ...@@ -962,7 +962,7 @@ struct find_split_transpose
.bind("trans"); .bind("trans");
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto trans = r.instructions["trans"]; auto trans = r.instructions["trans"];
......
...@@ -53,7 +53,7 @@ struct match_find_quantizable_ops ...@@ -53,7 +53,7 @@ struct match_find_quantizable_ops
match::arg(1)(dequantizelinear_op("x2", "scale2"))); match::arg(1)(dequantizelinear_op("x2", "scale2")));
} }
void apply(module& m, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto qop = r.result; auto qop = r.result;
auto q1 = r.instructions["x1"]; auto q1 = r.instructions["x1"];
......
...@@ -329,7 +329,7 @@ struct find_resize ...@@ -329,7 +329,7 @@ struct find_resize
match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind"))); match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto ins_rsp = r.instructions["data"]; auto ins_rsp = r.instructions["data"];
...@@ -436,7 +436,7 @@ struct find_where_op ...@@ -436,7 +436,7 @@ struct find_where_op
match::is_constant().bind("ind"))); match::is_constant().bind("ind")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto concat = r.instructions["data"]; auto concat = r.instructions["data"];
...@@ -496,7 +496,7 @@ struct find_reshape_cont ...@@ -496,7 +496,7 @@ struct find_reshape_cont
match::any())); match::any()));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto ins_cont = r.instructions["cont"]; auto ins_cont = r.instructions["cont"];
...@@ -564,7 +564,7 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -564,7 +564,7 @@ struct find_transpose_contiguous_reshaper_unary
match::args(match_transpose_contiguous_reshaper())); match::args(match_transpose_contiguous_reshaper()));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"]; auto reshaper_ins = r.instructions["reshaper_ins"];
......
...@@ -352,7 +352,7 @@ struct cpu_apply ...@@ -352,7 +352,7 @@ struct cpu_apply
std::transform(bind_inputs.begin(), std::transform(bind_inputs.begin(),
bind_inputs.end(), bind_inputs.end(),
std::back_inserter(inputs), std::back_inserter(inputs),
[&](const auto& s) { return r.instructions.at(s); }); [&](const auto& s) { return r.instructions[s]; });
inputs.push_back(this->insert_allocation(ins, ins->get_shape())); inputs.push_back(this->insert_allocation(ins, ins->get_shape()));
modl->replace_instruction(ins, op, inputs); modl->replace_instruction(ins, op, inputs);
}); });
......
...@@ -158,6 +158,7 @@ add_library(migraphx_gpu ...@@ -158,6 +158,7 @@ add_library(migraphx_gpu
nonzero.cpp nonzero.cpp
pack_args.cpp pack_args.cpp
pack_int8_args.cpp pack_int8_args.cpp
prefuse_ops.cpp
pad.cpp pad.cpp
pooling.cpp pooling.cpp
quant_convolution.cpp quant_convolution.cpp
......
...@@ -316,7 +316,7 @@ struct find_layernorm ...@@ -316,7 +316,7 @@ struct find_layernorm
{ {
auto matcher() const { return match::layernorm(&gpu_name); } auto matcher() const { return match::layernorm(&gpu_name); }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -355,7 +355,7 @@ struct find_gelu ...@@ -355,7 +355,7 @@ struct find_gelu
{ {
auto matcher() const { return match::gelu_erf(&gpu_name); } auto matcher() const { return match::gelu_erf(&gpu_name); }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -372,7 +372,7 @@ struct find_add_gelu ...@@ -372,7 +372,7 @@ struct find_add_gelu
return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add"))); return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -391,7 +391,7 @@ struct find_gelu_new ...@@ -391,7 +391,7 @@ struct find_gelu_new
auto matcher() const { return match::gelu_tanh(&gpu_name); } auto matcher() const { return match::gelu_tanh(&gpu_name); }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -411,7 +411,7 @@ struct find_add_gelu_new ...@@ -411,7 +411,7 @@ struct find_add_gelu_new
return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add"))); return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -435,7 +435,7 @@ struct find_add_clip ...@@ -435,7 +435,7 @@ struct find_add_clip
.bind("add"))); .bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -470,7 +470,7 @@ struct find_add_unary ...@@ -470,7 +470,7 @@ struct find_add_unary
.bind("add"))); .bind("add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto ins = r.result; auto ins = r.result;
...@@ -498,7 +498,7 @@ struct find_triadd ...@@ -498,7 +498,7 @@ struct find_triadd
.bind("input"))); .bind("input")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto add_ins = r.instructions["add"]; auto add_ins = r.instructions["add"];
auto input_ins = r.instructions["input"]; auto input_ins = r.instructions["input"];
...@@ -525,7 +525,7 @@ struct find_mul_add ...@@ -525,7 +525,7 @@ struct find_mul_add
match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b"))); match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto mul_ins = r.instructions["mul"]; auto mul_ins = r.instructions["mul"];
auto b_ins = r.instructions["b"]; auto b_ins = r.instructions["b"];
...@@ -550,7 +550,7 @@ struct find_mul_add_relu ...@@ -550,7 +550,7 @@ struct find_mul_add_relu
match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add"))); match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto mul_add_ins = r.instructions["mul_add"]; auto mul_add_ins = r.instructions["mul_add"];
auto ins = r.result; auto ins = r.result;
...@@ -783,7 +783,7 @@ auto conv_bias(Ms... ms) ...@@ -783,7 +783,7 @@ auto conv_bias(Ms... ms)
} }
template <class Op> template <class Op>
void apply_conv_bias(context& ctx, module& p, match::matcher_result r) void apply_conv_bias(context& ctx, module& p, const match::matcher_result& r)
{ {
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
auto bias_ins = r.instructions["bias"]; auto bias_ins = r.instructions["bias"];
...@@ -829,9 +829,9 @@ struct find_conv_bias ...@@ -829,9 +829,9 @@ struct find_conv_bias
match::output(match::name(std::unordered_set<std::string>{"gpu::relu"})))); match::output(match::name(std::unordered_set<std::string>{"gpu::relu"}))));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r)); apply_conv_bias<miopen_conv_bias>(*ctx, p, r);
} }
}; };
...@@ -840,9 +840,9 @@ struct find_conv_bias_relu ...@@ -840,9 +840,9 @@ struct find_conv_bias_relu
context* ctx = nullptr; context* ctx = nullptr;
auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); } auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r)); apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, r);
} }
}; };
...@@ -857,7 +857,7 @@ struct find_conv_pointwise ...@@ -857,7 +857,7 @@ struct find_conv_pointwise
fusable_conv(match::used_once()).bind("conv"))); fusable_conv(match::used_once()).bind("conv")));
} }
void apply(module& m, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
auto bias_ins = r.instructions["bias"]; auto bias_ins = r.instructions["bias"];
...@@ -896,7 +896,7 @@ struct find_gemm_add ...@@ -896,7 +896,7 @@ struct find_gemm_add
match::name("gpu::gemm")(match::nargs(3)).bind("gemm"))); match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto gemm_ins = r.instructions["gemm"]; auto gemm_ins = r.instructions["gemm"];
......
#ifndef MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct prefuse_ops
{
std::string name() const { return "gpu::prefuse_ops"; }
void apply(module& m) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace {
struct find_layernorm
{
auto matcher() const { return match::layernorm(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
if(not x_ins->get_shape().standard())
x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
}
};
struct find_triaddlayernorm
{
auto matcher() const
{
auto add1 =
match::name("add")(match::none_of(match::is_constant()),
match::args(match::any().bind("z1"), match::any().bind("z2")));
auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
return match::layernorm()(match::var("x")(add2));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["z1"];
auto y_ins = r.instructions["z2"];
auto z_ins = r.instructions["z3"];
for(auto* pins : {&x_ins, &y_ins, &z_ins})
{
if(not(*pins)->get_shape().standard())
*pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
}
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
}
};
} // namespace
void prefuse_ops::apply(module& m) const
{
match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/eliminate_workspace.hpp> #include <migraphx/gpu/eliminate_workspace.hpp>
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/mlir_conv.hpp> #include <migraphx/gpu/mlir_conv.hpp>
#include <migraphx/gpu/pack_int8_args.hpp> #include <migraphx/gpu/pack_int8_args.hpp>
...@@ -96,6 +97,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -96,6 +97,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_algebra{}, simplify_algebra{},
simplify_reshapes{}, simplify_reshapes{},
simplify_algebra{}, simplify_algebra{},
prefuse_ops{},
dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
simplify_reshapes{}, simplify_reshapes{},
propagate_constant{}, propagate_constant{},
......
...@@ -180,6 +180,40 @@ TEST_CASE(duplicate_args3) ...@@ -180,6 +180,40 @@ TEST_CASE(duplicate_args3)
EXPECT(result == migraphx::literal{0}); EXPECT(result == migraphx::literal{0});
} }
TEST_CASE(reused_twice)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 2, 2};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, dims});
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, dims});
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto add2 = mm->add_instruction(migraphx::make_op("add"), add1, z);
auto epsilon = mm->add_literal(1e-12f);
auto exponent = mm->add_literal(2.0f);
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), add2);
auto mean_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = mm->add_instruction(migraphx::make_op("sub"), add2, mean_mbcast);
auto exponent_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = mm->add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), pow);
auto epsilon_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, dims.at(1), 1}}}), epsilon);
auto add_epsilon = mm->add_instruction(migraphx::make_op("add"), var, epsilon_mbcast);
mm->add_instruction(migraphx::make_op("sqrt"), add_epsilon);
mm->add_instruction(migraphx::make_op("add"), x, y);
auto count = std::distance(mm->begin(), mm->end());
run_pass(p);
p.debug_print();
EXPECT(std::distance(mm->begin(), mm->end()) != count);
EXPECT(std::distance(mm->begin(), mm->end()) == 4);
}
TEST_CASE(unused_module) TEST_CASE(unused_module)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -332,7 +332,7 @@ TEST_CASE(match_either_args_any1) ...@@ -332,7 +332,7 @@ TEST_CASE(match_either_args_any1)
match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y"))); match::name("sum")(match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1}); EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
} }
TEST_CASE(match_either_args_any2) TEST_CASE(match_either_args_any2)
...@@ -347,7 +347,7 @@ TEST_CASE(match_either_args_any2) ...@@ -347,7 +347,7 @@ TEST_CASE(match_either_args_any2)
match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y"))); match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1}); EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
} }
TEST_CASE(match_either_args_any3) TEST_CASE(match_either_args_any3)
...@@ -362,7 +362,7 @@ TEST_CASE(match_either_args_any3) ...@@ -362,7 +362,7 @@ TEST_CASE(match_either_args_any3)
match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y"))); match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum1}); EXPECT(bool{r.result == sum1});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
} }
TEST_CASE(match_either_args_any4) TEST_CASE(match_either_args_any4)
...@@ -377,7 +377,7 @@ TEST_CASE(match_either_args_any4) ...@@ -377,7 +377,7 @@ TEST_CASE(match_either_args_any4)
match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y"))); match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
} }
TEST_CASE(match_either_args_any5) TEST_CASE(match_either_args_any5)
...@@ -392,7 +392,7 @@ TEST_CASE(match_either_args_any5) ...@@ -392,7 +392,7 @@ TEST_CASE(match_either_args_any5)
match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y"))); match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")}); EXPECT(bool{r.instructions["x"] != r.instructions["y"]});
} }
TEST_CASE(match_all_of1) TEST_CASE(match_all_of1)
...@@ -747,10 +747,10 @@ TEST_CASE(match_bind1) ...@@ -747,10 +747,10 @@ TEST_CASE(match_bind1)
match::standard_shape()) match::standard_shape())
.bind("pass"); .bind("pass");
auto r = find_match(mm, m); auto r = find_match(mm, m);
EXPECT(bool{r.instructions.at("one") == one}); EXPECT(bool{r.instructions["one"] == one});
EXPECT(bool{r.instructions.at("two") == two}); EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions.at("sum") == sum}); EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions.at("pass") == pass}); EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
...@@ -795,9 +795,9 @@ TEST_CASE(match_bind_modules2) ...@@ -795,9 +795,9 @@ TEST_CASE(match_bind_modules2)
match::standard_shape()) match::standard_shape())
.bind("pass"); .bind("pass");
auto r = find_match(*child, m); auto r = find_match(*child, m);
EXPECT(bool{r.instructions.at("two") == two}); EXPECT(bool{r.instructions["two"] == two});
EXPECT(bool{r.instructions.at("sum") == sum}); EXPECT(bool{r.instructions["sum"] == sum});
EXPECT(bool{r.instructions.at("pass") == pass}); EXPECT(bool{r.instructions["pass"] == pass});
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment