Commit 4a066e44 authored by Alan Turner
Browse files

Address fuse_ck review comments

parent 59a0b0ce
......@@ -203,67 +203,53 @@ struct find_ck_gemm
}
};
// Decides whether module `m` represents a multiplication.
// NOTE(review): this span appears to be a commit diff rendered without +/- markers, so
// two versions of the function are interleaved below (two signatures, then a loop-based
// body followed by a matcher-based body). Confirm against the repository before relying
// on either version in isolation.
static bool is_mul_module(const module& m)
static auto is_mul_module(module& m)
{
// Loop-based version: scan the instructions of `m` in order.
std::vector<std::string> result; // never read or written in the visible code
for(auto& ins : m)
{
// Instructions whose name starts with "@" are skipped.
if(starts_with(ins.name(), "@"))
continue;
// "multibroadcast" and "contiguous" are skipped as well.
if(contains({"multibroadcast", "contiguous"}, ins.name()))
continue;
if(ins.name() == "pointwise")
{
// A nested pointwise is decided by recursing into its first module input.
return is_mul_module(*ins.module_inputs().front());
}
else if(ins.name() == "mul")
{
return true;
}
// Any other instruction name falls through and the scan continues.
}
// No "mul" (directly or via a nested pointwise) was found.
return false;
// Matcher-based version: the instruction just before m.end() must have, as argument 0,
// a "mul" whose inputs are all named "@param".
// NOTE(review): presumably the last instruction is the module's return — confirm.
auto is_mul = match::arg(0)(match::name("mul")(match::all_of[match::inputs()](match::name("@param"))));
// The result is compared against m.end() to decide whether the match succeeded.
return match_instruction(m, std::prev(m.end()), is_mul).result != m.end();
}
// Predicate matcher: accepts an instruction only when it is named "pointwise", carries
// exactly one module input, and that module is accepted by is_mul_module.
MIGRAPHX_PRED_MATCHER(is_pointwise_scale, instruction_ref ins)
{
    if(ins->name() == "pointwise")
    {
        const auto& submodules = ins->module_inputs();
        // Only a pointwise with a single attached module is considered.
        if(submodules.size() == 1)
            return is_mul_module(*submodules.front());
    }
    return false;
}
struct find_ck_gemm_softmax_gemm
{
// Builds the matcher for the dot -> pointwise -> softmax -> dot chain.
// NOTE(review): this span appears to be a commit diff rendered without +/- markers, so
// two versions of the matcher chain are interleaved below (one ending at the first
// `return`, the other at the second). Confirm against the repository.
auto matcher() const
{
// First chain: each stage may feed ANY input of the next stage.
// gemm1: a "dot" passing is_ck_gemm (bound as "gemm1"), skipping over "contiguous".
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// Here the pointwise instruction itself is bound as "scale".
auto mul = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto softmax = match::name("softmax")(match::any_of[match::inputs()](mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax));
// Second chain: stricter, positional matching.
auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// The pointwise must take, as arguments 0 and 1 in either order, a constant (bound as
// "scale") and gemm1, and must also satisfy is_pointwise_scale.
auto mul = match::name("pointwise")(match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1))(is_pointwise_scale());
// softmax must consume the pointwise as argument 0; the final dot must consume softmax
// as argument 0.
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
std::cout << "Matched GSG" << std::endl;
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto gemm1_ins = r.instructions["gemm1"];
auto scale_ins = r.instructions["scale"];
auto scale_lit = r.instructions["scale"];
if(scale_ins->module_inputs().size() != 1 or
not is_mul_module(*scale_ins->module_inputs().front()))
return;
if(not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
return;
double scale = 1.0;
for(auto& in : scale_ins->inputs())
{
if(in->can_eval())
{
in->get_literal().visit([&](const auto s) {
if(std::all_of(s.begin() + 1, s.end(), [&](auto v) {
return float_equal(v, s.front());
}))
scale = s.front();
else
return;
});
}
}
scale_lit->get_literal().visit([&](const auto s) {
// CK only supports single-valued scale
if(std::all_of(s.begin() + 1, s.end(), [&](auto v) {
return float_equal(v, s.front());
}))
scale = s.front();
else
return;
});
auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment