Commit c7a8a31f authored by Paul's avatar Paul
Browse files

Fuse gemm multiplies by scalar

parent 1be4aed9
...@@ -553,11 +553,13 @@ struct find_gemm_pointwise ...@@ -553,11 +553,13 @@ struct find_gemm_pointwise
{ {
auto matcher() const auto matcher() const
{ {
return precompile_name("pointwise")( auto gemm_op = match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm");
match::nargs(3), auto binary_op = match::all_of(match::nargs(3),
match::either_arg(0, 1)( match::either_arg(0, 1)(
match::any_of(match::standard_shape(), match::is_constant()).bind("c"), match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm"))); gemm_op));
auto unary_op = match::all_of(match::nargs(2), match::arg(0)(gemm_op));
return precompile_name("pointwise")(match::any_of(binary_op, unary_op));
} }
// TODO: Move to matcher.hpp // TODO: Move to matcher.hpp
...@@ -589,15 +591,30 @@ struct find_gemm_pointwise ...@@ -589,15 +591,30 @@ struct find_gemm_pointwise
return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul))); return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
} }
static auto match_mul(const std::string& input)
{
auto mul = match_mul_const(match_param(input), "alpha");
return match::name("@return")(match::args(mul));
}
static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); } static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
template <class Gemm> template <class Gemm>
static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input) static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
{ {
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
if(names.size() != 2)
return false;
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
if (names.size() == 1)
{
auto mr = match::match_instruction(
*pm, std::prev(pm->end()), match_mul(names[input]));
if(mr.result == pm->end())
return false;
gemm.alpha *= get_float(mr.instructions["alpha"]);
return true;
}
else if (names.size() == 2)
{
unsigned output = input == 0 ? 1 : 0; unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction( auto mr = match::match_instruction(
*pm, std::prev(pm->end()), match_add(names[input], names[output])); *pm, std::prev(pm->end()), match_add(names[input], names[output]));
...@@ -614,24 +631,35 @@ struct find_gemm_pointwise ...@@ -614,24 +631,35 @@ struct find_gemm_pointwise
} }
return true; return true;
} }
else
{
return false;
}
}
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto gemm_ins = r.instructions["gemm"]; auto gemm_ins = r.instructions["gemm"];
auto c_ins = r.instructions["c"];
auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator()); auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
// Already fused gemm // Already fused gemm
if(not float_equal(gemm.beta, 0)) if(not float_equal(gemm.beta, 0))
return; return;
if (ins->inputs().size() == 3)
gemm.beta = 1; gemm.beta = 1;
if(not update_gemm( if(not update_gemm(
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1)) gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return; return;
auto inputs = gemm_ins->inputs();
inputs.pop_back();
if (ins->inputs().size() == 3)
{
auto c_ins = r.instructions["c"];
// const-fold input if not standard shape since rocblas can't handle it // const-fold input if not standard shape since rocblas can't handle it
if(not c_ins->get_shape().standard()) if(not c_ins->get_shape().standard())
{ {
...@@ -639,11 +667,9 @@ struct find_gemm_pointwise ...@@ -639,11 +667,9 @@ struct find_gemm_pointwise
auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()}); auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
c_ins = m.add_literal(l.get_shape(), l.data()); c_ins = m.add_literal(l.get_shape(), l.data());
} }
auto inputs = gemm_ins->inputs();
inputs.pop_back();
inputs.push_back(c_ins); inputs.push_back(c_ins);
}
inputs.push_back(ins->inputs().back()); inputs.push_back(ins->inputs().back());
m.replace_instruction(ins, gemm, inputs); m.replace_instruction(ins, gemm, inputs);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment