Unverified Commit 9f283810 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Some perf improvements to bert (#627)



* Fuse gemm in fuse ops

* Formatting

* Add const ref

* Remove assert

* Skip already fused gemms

* Skip already fused gemm

* Formatting

* Use float_equal

* Avoid non-standard shapes for inputs

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e2cbb01e
......@@ -8,6 +8,7 @@
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/add.hpp>
#include <migraphx/gpu/mul.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/device/layernorm.hpp>
#include <migraphx/gpu/device/gelu.hpp>
#include <migraphx/gpu/device/mul_add.hpp>
......@@ -536,10 +537,9 @@ struct find_triadd
auto input_ins = r.instructions["input"];
auto ins = r.result;
auto args = add_ins->inputs();
assert(add_ins != input_ins);
auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
if(std::count_if(args.begin(), args.end(), is_broadcasted) > 2)
return;
args.insert(args.begin(), input_ins);
move_standard_front(args);
......@@ -743,6 +743,68 @@ struct find_conv_bias_relu
}
};
struct find_gemm_add
{
auto matcher() const
{
return match::name("gpu::add")(
match::all_of[match::inputs()](match::standard_shape()),
match::either_arg(0, 1)(match::used_once().bind("c"),
match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto gemm_ins = r.instructions["gemm"];
auto c_ins = r.instructions["c"];
auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
// Already fused gemm
if(not float_equal(gemm.op.beta, 0))
return;
if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
return not i->get_shape().standard();
}))
return;
auto inputs = gemm_ins->inputs();
inputs.pop_back();
auto copy_ins = c_ins;
// Insert copy
if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
{
copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
}
inputs.push_back(copy_ins);
inputs.push_back(copy_ins);
gemm.op.beta = 1;
p.replace_instruction(ins, gemm, inputs);
}
};
struct find_commutative_broadcast
{
auto matcher() const
{
return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape()));
}
void apply(program& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto args = ins->inputs();
move_broadcasted_back(args);
p.replace_instruction(ins, ins->get_operator(), args);
}
};
void fuse_ops::apply(program& p) const
{
match::find_matches(p, find_gelu{}, find_gelu_new{});
......@@ -760,7 +822,7 @@ void fuse_ops::apply(program& p) const
find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
find_add_clip{});
// clang-format on
match::find_matches(p, find_gemm_add{}, find_commutative_broadcast{});
}
} // namespace gpu
......
......@@ -64,8 +64,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_reshapes{},
propagate_constant{},
dead_code_elimination{},
remap{},
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
eliminate_contiguous{},
dead_code_elimination{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment