"src/targets/vscode:/vscode.git/clone" did not exist on "82b60de915f91a939f0f4e11e39f6e0279ce2dce"
Commit 1dda8943 authored by Paul's avatar Paul
Browse files

Format

parent 297572f5
......@@ -958,38 +958,34 @@ struct find_gemm_pointwise
static auto match_param(const std::string& name)
{
return match::make_basic_pred_matcher([=](auto ins) {
if (ins->name() != "@param")
if(ins->name() != "@param")
return false;
auto p = any_cast<builtin::param>(ins->get_operator());
return p.parameter == name;
});
}
template<class M>
template <class M>
static auto match_mul_const(M m, const std::string& var)
{
return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m)).bind(var+"_mul");
return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
.bind(var + "_mul");
}
static auto match_add(const std::string& input, const std::string& output)
{
auto param = match::name("@param");
auto add = match::name("add")(match::args(param, param));
auto inner_mul = match::any_of(
match_mul_const(match_param(input), "alpha"),
match_mul_const(match_param(output), "beta")
);
auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
match_mul_const(match_param(output), "beta"));
auto mul_add = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
auto add_mul = match_mul_const(add, "gamma");
return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
}
static float get_float(instruction_ref ins)
{
return ins->get_literal().at<float>();
}
static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
template<class Gemm>
template <class Gemm>
static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
{
auto names = pm->get_parameter_names();
......@@ -997,14 +993,15 @@ struct find_gemm_pointwise
return false;
std::sort(names.begin(), names.end());
unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction(*pm, std::prev(pm->end()), match_add(names[input], names[output]));
if (mr.result == pm->end())
auto mr = match::match_instruction(
*pm, std::prev(pm->end()), match_add(names[input], names[output]));
if(mr.result == pm->end())
return false;
if (contains(mr.instructions, "alpha_mul"))
if(contains(mr.instructions, "alpha_mul"))
gemm.alpha *= get_float(mr.instructions["alpha"]);
else if (contains(mr.instructions, "beta_mul"))
else if(contains(mr.instructions, "beta_mul"))
gemm.beta *= get_float(mr.instructions["beta"]);
else if (contains(mr.instructions, "gamma_mul"))
else if(contains(mr.instructions, "gamma_mul"))
{
gemm.alpha *= get_float(mr.instructions["gamma"]);
gemm.beta *= get_float(mr.instructions["gamma"]);
......@@ -1025,7 +1022,8 @@ struct find_gemm_pointwise
return;
gemm.beta = 1;
if (not update_gemm(gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
if(not update_gemm(
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return;
auto inputs = gemm_ins->inputs();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment