Commit aae03c3e authored by Paul's avatar Paul
Browse files

Replace non-fused gemms as well

parent de7bab05
...@@ -58,7 +58,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -58,7 +58,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return true; return true;
} }
struct find_ck_gemm struct find_ck_gemm_pointwise
{ {
// Find a gemm followed by a pointwise operation. // Find a gemm followed by a pointwise operation.
auto matcher() const auto matcher() const
...@@ -101,9 +101,24 @@ struct find_ck_gemm ...@@ -101,9 +101,24 @@ struct find_ck_gemm
} }
}; };
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs());
}
};
} // namespace } // namespace
void fuse_ck::apply(module_pass_manager& mpm) const { match::find_matches(mpm, find_ck_gemm{}); } void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment