Commit 1dda8943 authored by Paul's avatar Paul
Browse files

Format

parent 297572f5
...@@ -958,38 +958,34 @@ struct find_gemm_pointwise ...@@ -958,38 +958,34 @@ struct find_gemm_pointwise
static auto match_param(const std::string& name) static auto match_param(const std::string& name)
{ {
return match::make_basic_pred_matcher([=](auto ins) { return match::make_basic_pred_matcher([=](auto ins) {
if (ins->name() != "@param") if(ins->name() != "@param")
return false; return false;
auto p = any_cast<builtin::param>(ins->get_operator()); auto p = any_cast<builtin::param>(ins->get_operator());
return p.parameter == name; return p.parameter == name;
}); });
} }
template<class M> template <class M>
static auto match_mul_const(M m, const std::string& var) static auto match_mul_const(M m, const std::string& var)
{ {
return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m)).bind(var+"_mul"); return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
.bind(var + "_mul");
} }
static auto match_add(const std::string& input, const std::string& output) static auto match_add(const std::string& input, const std::string& output)
{ {
auto param = match::name("@param"); auto param = match::name("@param");
auto add = match::name("add")(match::args(param, param)); auto add = match::name("add")(match::args(param, param));
auto inner_mul = match::any_of( auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
match_mul_const(match_param(input), "alpha"), match_mul_const(match_param(output), "beta"));
match_mul_const(match_param(output), "beta")
);
auto mul_add = match::name("add")(match::either_arg(0, 1)(inner_mul, param)); auto mul_add = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
auto add_mul = match_mul_const(add, "gamma"); auto add_mul = match_mul_const(add, "gamma");
return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul))); return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
} }
static float get_float(instruction_ref ins) static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
{
return ins->get_literal().at<float>();
}
template<class Gemm> template <class Gemm>
static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input) static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
{ {
auto names = pm->get_parameter_names(); auto names = pm->get_parameter_names();
...@@ -997,14 +993,15 @@ struct find_gemm_pointwise ...@@ -997,14 +993,15 @@ struct find_gemm_pointwise
return false; return false;
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
unsigned output = input == 0 ? 1 : 0; unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction(*pm, std::prev(pm->end()), match_add(names[input], names[output])); auto mr = match::match_instruction(
if (mr.result == pm->end()) *pm, std::prev(pm->end()), match_add(names[input], names[output]));
if(mr.result == pm->end())
return false; return false;
if (contains(mr.instructions, "alpha_mul")) if(contains(mr.instructions, "alpha_mul"))
gemm.alpha *= get_float(mr.instructions["alpha"]); gemm.alpha *= get_float(mr.instructions["alpha"]);
else if (contains(mr.instructions, "beta_mul")) else if(contains(mr.instructions, "beta_mul"))
gemm.beta *= get_float(mr.instructions["beta"]); gemm.beta *= get_float(mr.instructions["beta"]);
else if (contains(mr.instructions, "gamma_mul")) else if(contains(mr.instructions, "gamma_mul"))
{ {
gemm.alpha *= get_float(mr.instructions["gamma"]); gemm.alpha *= get_float(mr.instructions["gamma"]);
gemm.beta *= get_float(mr.instructions["gamma"]); gemm.beta *= get_float(mr.instructions["gamma"]);
...@@ -1025,7 +1022,8 @@ struct find_gemm_pointwise ...@@ -1025,7 +1022,8 @@ struct find_gemm_pointwise
return; return;
gemm.beta = 1; gemm.beta = 1;
if (not update_gemm(gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1)) if(not update_gemm(
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return; return;
auto inputs = gemm_ins->inputs(); auto inputs = gemm_ins->inputs();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment