Commit 297572f5 authored by Paul
Browse files

Improve gemm fusion

parent 4ac2919f
......@@ -947,13 +947,71 @@ struct find_gemm_pointwise
{
auto matcher() const
{
return pointwise_name("add")(
return precompile_name("pointwise")(
match::nargs(3),
match::all_of[match::inputs()](match::standard_shape()),
match::either_arg(0, 1)(match::used_once().bind("c"),
match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
}
// TODO: Move to matcher.hpp
// Builds a predicate matcher that accepts only an "@param" instruction whose
// builtin::param operator carries the parameter name `name`.
static auto match_param(const std::string& name)
{
    return match::make_basic_pred_matcher([=](auto ins) {
        const bool is_param = (ins->name() == "@param");
        // Short-circuit so the operator is only cast for real "@param" instructions.
        return is_param and any_cast<builtin::param>(ins->get_operator()).parameter == name;
    });
}
// Matches a "mul" whose arguments 0/1 are, in either order, an "@literal" and
// something accepted by `m`. The literal is bound under `var` and the "mul"
// instruction itself under `var + "_mul"`.
template <class M>
static auto match_mul_const(M m, const std::string& var)
{
    auto constant = match::name("@literal").bind(var);
    auto operands = match::either_arg(0, 1)(constant, m);
    return match::name("mul")(operands).bind(var + "_mul");
}
// Matches a "@return" whose argument is one of three shapes:
//   1. add(param, param)
//   2. add(literal * param, param), in either operand order, where the scaled
//      parameter is named `input` (binds "alpha"/"alpha_mul") or `output`
//      (binds "beta"/"beta_mul")
//   3. literal * add(param, param) (binds "gamma"/"gamma_mul")
static auto match_add(const std::string& input, const std::string& output)
{
    auto any_param = match::name("@param");
    auto plain_add = match::name("add")(match::args(any_param, any_param));

    auto scaled_term = match::any_of(match_mul_const(match_param(input), "alpha"),
                                     match_mul_const(match_param(output), "beta"));
    auto add_of_scaled = match::name("add")(match::either_arg(0, 1)(scaled_term, any_param));

    auto scaled_add = match_mul_const(plain_add, "gamma");

    return match::name("@return")(
        match::args(match::any_of(plain_add, add_of_scaled, scaled_add)));
}
// Reads the literal attached to `ins` and returns its value as a float.
static float get_float(instruction_ref ins)
{
    const auto& lit = ins->get_literal();
    return lit.at<float>();
}
// Folds the scalar factors found in the pointwise module `pm` into `gemm`.
//
// `pm` must have exactly two parameters; their names are sorted and `input`
// (0 or 1) selects which sorted name is the one matched for the "alpha" scale,
// the remaining name being the one matched for the "beta" scale. The module's
// last instruction is matched against match_add(); on a match:
//   - an "alpha_mul" binding multiplies gemm.alpha by the "alpha" literal,
//   - otherwise a "beta_mul" binding multiplies gemm.beta by the "beta" literal,
//   - otherwise a "gamma_mul" binding multiplies both by the "gamma" literal.
//
// Returns false (leaving `gemm` untouched) when the parameter count is not two
// or the pattern does not match; returns true otherwise.
template <class Gemm>
static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
{
    auto param_names = pm->get_parameter_names();
    if(param_names.size() != 2)
        return false;
    std::sort(param_names.begin(), param_names.end());

    // Index of the parameter that is not `input`.
    const unsigned other = (input == 0) ? 1 : 0;

    auto found = match::match_instruction(
        *pm, std::prev(pm->end()), match_add(param_names[input], param_names[other]));
    if(found.result == pm->end())
        return false;

    if(contains(found.instructions, "alpha_mul"))
    {
        gemm.alpha *= get_float(found.instructions["alpha"]);
    }
    else if(contains(found.instructions, "beta_mul"))
    {
        gemm.beta *= get_float(found.instructions["beta"]);
    }
    else if(contains(found.instructions, "gamma_mul"))
    {
        // The whole sum is scaled, so both coefficients pick up the factor.
        const float gamma = get_float(found.instructions["gamma"]);
        gemm.alpha *= gamma;
        gemm.beta *= gamma;
    }
    return true;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -965,6 +1023,10 @@ struct find_gemm_pointwise
// Already fused gemm
if(not float_equal(gemm.beta, 0))
return;
gemm.beta = 1;
if (not update_gemm(gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return;
auto inputs = gemm_ins->inputs();
inputs.pop_back();
......@@ -972,7 +1034,6 @@ struct find_gemm_pointwise
inputs.push_back(c_ins);
inputs.push_back(ins->inputs().back());
gemm.beta = 1;
m.replace_instruction(ins, gemm, inputs);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment