Commit 100551f7 authored by Paul
Browse files

Format

parent c7a8a31f
...@@ -554,10 +554,10 @@ struct find_gemm_pointwise ...@@ -554,10 +554,10 @@ struct find_gemm_pointwise
auto matcher() const auto matcher() const
{ {
auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm"); auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm");
auto binary_op = match::all_of(match::nargs(3), auto binary_op = match::all_of(
match::nargs(3),
match::either_arg(0, 1)( match::either_arg(0, 1)(
match::any_of(match::standard_shape(), match::is_constant()).bind("c"), match::any_of(match::standard_shape(), match::is_constant()).bind("c"), gemm_op));
gemm_op));
auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op)); auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op));
return precompile_name("pointwise")(match::any_of(binary_op, unary_op)); return precompile_name("pointwise")(match::any_of(binary_op, unary_op));
} }
...@@ -604,16 +604,15 @@ struct find_gemm_pointwise ...@@ -604,16 +604,15 @@ struct find_gemm_pointwise
{ {
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
if (names.size() == 1) if(names.size() == 1)
{ {
auto mr = match::match_instruction( auto mr = match::match_instruction(*pm, std::prev(pm->end()), match_mul(names[input]));
*pm, std::prev(pm->end()), match_mul(names[input]));
if(mr.result == pm->end()) if(mr.result == pm->end())
return false; return false;
gemm.alpha *= get_float(mr.instructions["alpha"]); gemm.alpha *= get_float(mr.instructions["alpha"]);
return true; return true;
} }
else if (names.size() == 2) else if(names.size() == 2)
{ {
unsigned output = input == 0 ? 1 : 0; unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction( auto mr = match::match_instruction(
...@@ -647,7 +646,7 @@ struct find_gemm_pointwise ...@@ -647,7 +646,7 @@ struct find_gemm_pointwise
// Already fused gemm // Already fused gemm
if(not float_equal(gemm.beta, 0)) if(not float_equal(gemm.beta, 0))
return; return;
if (ins->inputs().size() == 3) if(ins->inputs().size() == 3)
gemm.beta = 1; gemm.beta = 1;
if(not update_gemm( if(not update_gemm(
...@@ -657,7 +656,7 @@ struct find_gemm_pointwise ...@@ -657,7 +656,7 @@ struct find_gemm_pointwise
auto inputs = gemm_ins->inputs(); auto inputs = gemm_ins->inputs();
inputs.pop_back(); inputs.pop_back();
if (ins->inputs().size() == 3) if(ins->inputs().size() == 3)
{ {
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
// const-fold input if not standard shape since rocblas can't handle it // const-fold input if not standard shape since rocblas can't handle it
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment