Commit 7f387483 authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/fuse-gemm-add-pointwise' into transformer_opts

parents f7fe5508 8354be18
...@@ -932,6 +932,53 @@ struct find_gemm_add ...@@ -932,6 +932,53 @@ struct find_gemm_add
} }
}; };
auto pointwise_name(const std::string& s)
{
return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
module_ref pm = ins->module_inputs().front();
auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
if(n != 1)
return false;
return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
return starts_with(i.name(), "@") or i.name() == s;
});
}));
}
struct find_gemm_pointwise
{
auto matcher() const
{
return pointwise_name("add")(
match::nargs(3),
match::all_of[match::inputs()](match::standard_shape()),
match::either_arg(0, 1)(match::used_once().bind("c"),
match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["gemm"];
auto c_ins = r.instructions["c"];
auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
// Already fused gemm
if(not float_equal(gemm.beta, 0))
return;
auto inputs = gemm_ins->inputs();
inputs.pop_back();
inputs.push_back(c_ins);
inputs.push_back(gemm_ins->inputs().back());
gemm.beta = 1;
m.replace_instruction(ins, gemm, inputs);
}
};
struct find_commutative_broadcast struct find_commutative_broadcast
{ {
auto matcher() const auto matcher() const
...@@ -1003,8 +1050,12 @@ void fuse_ops::apply(module& m) const ...@@ -1003,8 +1050,12 @@ void fuse_ops::apply(module& m) const
find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}}, find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
find_add_clip{}); find_add_clip{});
run_passes(m, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(m, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
match::find_matches(m, find_contiguous{}); match::find_matches(m, find_contiguous{});
match::find_matches(m,
find_triadd_layernorm{},
find_gemm_add{},
find_gemm_pointwise{},
find_commutative_broadcast{});
} }
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment