Commit aae03c3e authored by Paul's avatar Paul
Browse files

Replace non-fused gemms as well

parent de7bab05
......@@ -58,7 +58,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return true;
}
struct find_ck_gemm
struct find_ck_gemm_pointwise
{
// Find a gemm followed by a pointwise operation.
auto matcher() const
......@@ -101,9 +101,24 @@ struct find_ck_gemm
}
};
struct find_ck_gemm
{
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs());
}
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const { match::find_matches(mpm, find_ck_gemm{}); }
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment