Commit 4a066e44 authored by Alan Turner's avatar Alan Turner
Browse files

Address fuse_ck review comments

parent 59a0b0ce
...@@ -203,58 +203,46 @@ struct find_ck_gemm ...@@ -203,58 +203,46 @@ struct find_ck_gemm
} }
}; };
static bool is_mul_module(const module& m) static auto is_mul_module(module& m)
{ {
std::vector<std::string> result; auto is_mul = match::arg(0)(match::name("mul")(match::all_of[match::inputs()](match::name("@param"))));
for(auto& ins : m) return match_instruction(m, std::prev(m.end()), is_mul).result != m.end();
{ }
if(starts_with(ins.name(), "@"))
continue; MIGRAPHX_PRED_MATCHER(is_pointwise_scale, instruction_ref ins)
if(contains({"multibroadcast", "contiguous"}, ins.name())) {
continue; if (ins->name() != "pointwise")
if(ins.name() == "pointwise") return false;
{ if (ins->module_inputs().size() != 1)
return is_mul_module(*ins.module_inputs().front());
}
else if(ins.name() == "mul")
{
return true;
}
}
return false; return false;
return is_mul_module(*ins->module_inputs().front());
} }
struct find_ck_gemm_softmax_gemm struct find_ck_gemm_softmax_gemm
{ {
auto matcher() const auto matcher() const
{ {
auto gemm1 = auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1"))); auto mul = match::name("pointwise")(match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1))(is_pointwise_scale());
auto mul = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale"); auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
auto softmax = match::name("softmax")(match::any_of[match::inputs()](mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))( return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
match::any_of[match::inputs()](softmax));
} }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
std::cout << "Matched GSG" << std::endl;
auto ins = r.result; auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"]; auto gemm2_ins = r.instructions["gemm2"];
auto gemm1_ins = r.instructions["gemm1"]; auto gemm1_ins = r.instructions["gemm1"];
auto scale_ins = r.instructions["scale"]; auto scale_lit = r.instructions["scale"];
if(scale_ins->module_inputs().size() != 1 or
not is_mul_module(*scale_ins->module_inputs().front()))
return;
if(not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type())) if(not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
return; return;
double scale = 1.0; double scale = 1.0;
for(auto& in : scale_ins->inputs()) scale_lit->get_literal().visit([&](const auto s) {
{ // CK only supports single-valued scale
if(in->can_eval())
{
in->get_literal().visit([&](const auto s) {
if(std::all_of(s.begin() + 1, s.end(), [&](auto v) { if(std::all_of(s.begin() + 1, s.end(), [&](auto v) {
return float_equal(v, s.front()); return float_equal(v, s.front());
})) }))
...@@ -262,8 +250,6 @@ struct find_ck_gemm_softmax_gemm ...@@ -262,8 +250,6 @@ struct find_ck_gemm_softmax_gemm
else else
return; return;
}); });
}
}
auto inputs = gemm1_ins->inputs(); // A, B auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1 inputs.push_back(gemm2_ins->inputs().back()); // B1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment