Commit 791addfb authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 4a066e44
...@@ -205,15 +205,16 @@ struct find_ck_gemm ...@@ -205,15 +205,16 @@ struct find_ck_gemm
static auto is_mul_module(module& m) static auto is_mul_module(module& m)
{ {
auto is_mul = match::arg(0)(match::name("mul")(match::all_of[match::inputs()](match::name("@param")))); auto is_mul =
match::arg(0)(match::name("mul")(match::all_of[match::inputs()](match::name("@param"))));
return match_instruction(m, std::prev(m.end()), is_mul).result != m.end(); return match_instruction(m, std::prev(m.end()), is_mul).result != m.end();
} }
MIGRAPHX_PRED_MATCHER(is_pointwise_scale, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_pointwise_scale, instruction_ref ins)
{ {
if (ins->name() != "pointwise") if(ins->name() != "pointwise")
return false; return false;
if (ins->module_inputs().size() != 1) if(ins->module_inputs().size() != 1)
return false; return false;
return is_mul_module(*ins->module_inputs().front()); return is_mul_module(*ins->module_inputs().front());
} }
...@@ -222,8 +223,10 @@ struct find_ck_gemm_softmax_gemm ...@@ -222,8 +223,10 @@ struct find_ck_gemm_softmax_gemm
{ {
auto matcher() const auto matcher() const
{ {
auto gemm1 = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1"))); auto gemm1 =
auto mul = match::name("pointwise")(match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1))(is_pointwise_scale()); match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul = match::name("pointwise")(match::either_arg(0, 1)(
match::is_constant().bind("scale"), gemm1))(is_pointwise_scale());
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax"); auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax)); return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
...@@ -243,9 +246,8 @@ struct find_ck_gemm_softmax_gemm ...@@ -243,9 +246,8 @@ struct find_ck_gemm_softmax_gemm
double scale = 1.0; double scale = 1.0;
scale_lit->get_literal().visit([&](const auto s) { scale_lit->get_literal().visit([&](const auto s) {
// CK only supports single-valued scale // CK only supports single-valued scale
if(std::all_of(s.begin() + 1, s.end(), [&](auto v) { if(std::all_of(
return float_equal(v, s.front()); s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
}))
scale = s.front(); scale = s.front();
else else
return; return;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment