Commit 33189188 authored by Paul's avatar Paul
Browse files

Merge branch 'dot-add' into bert-opt

parents 5d6880e5 1dda8943
...@@ -185,6 +185,42 @@ struct find_mul_add ...@@ -185,6 +185,42 @@ struct find_mul_add
} }
}; };
struct find_dot_add
{
auto matcher() const
{
return match::name("dot")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(match::any().bind("x"),
match::any_of(match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
const bool flipped = a_ins == ins->inputs().back();
auto insert_dot = [&](auto x, auto y) {
if(flipped)
return m.insert_instruction(ins, make_op("dot"), y, x);
else
return m.insert_instruction(ins, make_op("dot"), x, y);
};
auto ax_ins = insert_dot(a_ins, x_ins);
auto ab_ins = insert_dot(a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
struct find_add_lit_broadcast struct find_add_lit_broadcast
{ {
auto matcher() const auto matcher() const
...@@ -246,26 +282,35 @@ struct find_inner_broadcast ...@@ -246,26 +282,35 @@ struct find_inner_broadcast
{ {
auto matcher() const auto matcher() const
{ {
return pointwise( return pointwise(match::all_of[match::inputs()](
match::nargs(2), match::broadcast_shape(), match::name("broadcast", "multibroadcast")));
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto inputs = ins->inputs();
auto y_ins = r.instructions["y"]; if(inputs.empty())
return;
auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator()); std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator()); if(contains({"broadcast", "multibroadcast"}, i->name()))
return i->inputs().front();
else
return i;
});
if(xbroadcast.axis != ybroadcast.axis) if(not std::all_of(inputs.begin(), inputs.end(), [&](auto& x) {
return x->get_shape() == inputs.front()->get_shape();
}))
return; return;
auto op = m.insert_instruction( auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front()); auto bop = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
m.replace_instruction(ins, xbroadcast, op); return contains({"broadcast", "multibroadcast"}, i->name());
});
assert(bop != ins->inputs().end());
m.replace_instruction(ins, (*bop)->get_operator(), op);
} }
}; };
...@@ -1025,6 +1070,7 @@ void simplify_algebra::apply(module& m) const ...@@ -1025,6 +1070,7 @@ void simplify_algebra::apply(module& m) const
find_mul_conv{}, find_mul_conv{},
find_mul_slice_conv{}, find_mul_slice_conv{},
find_mul_add{}, find_mul_add{},
find_dot_add{},
find_div_const{}, find_div_const{},
find_sub_const{}, find_sub_const{},
find_rsqrt{}, find_rsqrt{},
......
...@@ -256,6 +256,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm) ...@@ -256,6 +256,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm> struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
{ {
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
return inputs[0];
}
// Empty finalize to skip dimension reduction // Empty finalize to skip dimension reduction
void finalize(context&, const shape&, const std::vector<shape>&) {} void finalize(context&, const shape&, const std::vector<shape>&) {}
}; };
...@@ -943,13 +948,68 @@ struct find_gemm_pointwise ...@@ -943,13 +948,68 @@ struct find_gemm_pointwise
{ {
auto matcher() const auto matcher() const
{ {
return pointwise_name("add")( return precompile_name("pointwise")(
match::nargs(3), match::nargs(3),
match::all_of[match::inputs()](match::standard_shape()), match::all_of[match::inputs()](match::standard_shape()),
match::either_arg(0, 1)(match::used_once().bind("c"), match::either_arg(0, 1)(match::used_once().bind("c"),
match::name("gpu::gemm")(match::nargs(3)).bind("gemm"))); match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
} }
// TODO: Move to matcher.hpp
static auto match_param(const std::string& name)
{
return match::make_basic_pred_matcher([=](auto ins) {
if(ins->name() != "@param")
return false;
auto p = any_cast<builtin::param>(ins->get_operator());
return p.parameter == name;
});
}
template <class M>
static auto match_mul_const(M m, const std::string& var)
{
return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
.bind(var + "_mul");
}
static auto match_add(const std::string& input, const std::string& output)
{
auto param = match::name("@param");
auto add = match::name("add")(match::args(param, param));
auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
match_mul_const(match_param(output), "beta"));
auto mul_add = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
auto add_mul = match_mul_const(add, "gamma");
return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
}
static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
template <class Gemm>
static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
{
auto names = pm->get_parameter_names();
if(names.size() != 2)
return false;
std::sort(names.begin(), names.end());
unsigned output = input == 0 ? 1 : 0;
auto mr = match::match_instruction(
*pm, std::prev(pm->end()), match_add(names[input], names[output]));
if(mr.result == pm->end())
return false;
if(contains(mr.instructions, "alpha_mul"))
gemm.alpha *= get_float(mr.instructions["alpha"]);
else if(contains(mr.instructions, "beta_mul"))
gemm.beta *= get_float(mr.instructions["beta"]);
else if(contains(mr.instructions, "gamma_mul"))
{
gemm.alpha *= get_float(mr.instructions["gamma"]);
gemm.beta *= get_float(mr.instructions["gamma"]);
}
return true;
}
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -961,6 +1021,11 @@ struct find_gemm_pointwise ...@@ -961,6 +1021,11 @@ struct find_gemm_pointwise
// Already fused gemm // Already fused gemm
if(not float_equal(gemm.beta, 0)) if(not float_equal(gemm.beta, 0))
return; return;
gemm.beta = 1;
if(not update_gemm(
gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
return;
auto inputs = gemm_ins->inputs(); auto inputs = gemm_ins->inputs();
inputs.pop_back(); inputs.pop_back();
...@@ -968,7 +1033,6 @@ struct find_gemm_pointwise ...@@ -968,7 +1033,6 @@ struct find_gemm_pointwise
inputs.push_back(c_ins); inputs.push_back(c_ins);
inputs.push_back(ins->inputs().back()); inputs.push_back(ins->inputs().back());
gemm.beta = 1;
m.replace_instruction(ins, gemm, inputs); m.replace_instruction(ins, gemm, inputs);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment