Commit 100551f7 authored by Paul
Browse files

Format

parent c7a8a31f
...@@ -553,11 +553,11 @@ struct find_gemm_pointwise ...@@ -553,11 +553,11 @@ struct find_gemm_pointwise
{ {
auto matcher() const auto matcher() const
{ {
auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm"); auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm");
auto binary_op = match::all_of(match::nargs(3), auto binary_op = match::all_of(
match::nargs(3),
match::either_arg(0, 1)( match::either_arg(0, 1)(
match::any_of(match::standard_shape(), match::is_constant()).bind("c"), match::any_of(match::standard_shape(), match::is_constant()).bind("c"), gemm_op));
gemm_op));
auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op)); auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op));
return precompile_name("pointwise")(match::any_of(binary_op, unary_op)); return precompile_name("pointwise")(match::any_of(binary_op, unary_op));
} }
...@@ -604,16 +604,15 @@ struct find_gemm_pointwise ...@@ -604,16 +604,15 @@ struct find_gemm_pointwise
{ {
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
if (names.size() == 1) if(names.size() == 1)
{ {
auto mr = match::match_instruction( auto mr = match::match_instruction(*pm, std::prev(pm->end()), match_mul(names[input]));
*pm, std::prev(pm->end()), match_mul(names[input]));
if(mr.result == pm->end()) if(mr.result == pm->end())
return false; return false;
gemm.alpha *= get_float(mr.instructions["alpha"]); gemm.alpha *= get_float(mr.instructions["alpha"]);
return true; return true;
} }
else if (names.size() == 2) else if(names.size() == 2)
{ {
unsigned output = input == 0 ? 1 : 0; unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction( auto mr = match::match_instruction(
...@@ -647,9 +646,9 @@ struct find_gemm_pointwise ...@@ -647,9 +646,9 @@ struct find_gemm_pointwise
// Already fused gemm // Already fused gemm
if(not float_equal(gemm.beta, 0)) if(not float_equal(gemm.beta, 0))
return; return;
if (ins->inputs().size() == 3) if(ins->inputs().size() == 3)
gemm.beta = 1; gemm.beta = 1;
if(not update_gemm( if(not update_gemm(
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1)) gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return; return;
...@@ -657,9 +656,9 @@ struct find_gemm_pointwise ...@@ -657,9 +656,9 @@ struct find_gemm_pointwise
auto inputs = gemm_ins->inputs(); auto inputs = gemm_ins->inputs();
inputs.pop_back(); inputs.pop_back();
if (ins->inputs().size() == 3) if(ins->inputs().size() == 3)
{ {
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
// const-fold input if not standard shape since rocblas can't handle it // const-fold input if not standard shape since rocblas can't handle it
if(not c_ins->get_shape().standard()) if(not c_ins->get_shape().standard())
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment