"...resnet50_tensorflow.git" did not exist on "4ad903b49f8e30ad7a5fdb3190f6bce724b405e1"
Unverified commit 9f283810 authored by Paul Fultz II, committed by GitHub

Some perf improvements to bert (#627)



* Fuse gemm in fuse ops

* Formatting

* Add const ref

* Remove assert

* Skip already fused gemms

* Skip already fused gemm

* Formatting

* Use float_equal

* Avoid non-standard shapes for inputs

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e2cbb01e
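
The central change is `find_gemm_add`, which folds a trailing elementwise add into the preceding rocBLAS GEMM. A BLAS-style GEMM computes D = alpha*A*B + beta*C, so `add(gemm(A, B), C)` with the GEMM's beta at 0 is equivalent to a single GEMM call with beta = 1 and C passed as the accumulator, saving a kernel launch and a full pass over the output. Below is a minimal, self-contained sketch of that equivalence using a naive CPU GEMM; it is illustrative C++ only, not MIGraphX or rocBLAS code.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Naive BLAS-style GEMM: d = alpha * a * b + beta * c
// a is m x k, b is k x n, c and d are m x n, all row-major.
void gemm(int m, int n, int k, float alpha, const std::vector<float>& a,
          const std::vector<float>& b, float beta, const std::vector<float>& c,
          std::vector<float>& d)
{
    for(int i = 0; i < m; i++)
        for(int j = 0; j < n; j++)
        {
            float acc = 0.0f;
            for(int l = 0; l < k; l++)
                acc += a[i * k + l] * b[l * n + j];
            d[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
}

int main()
{
    const int m = 2, n = 3, k = 4;
    std::vector<float> a(m * k, 1.0f), b(k * n, 2.0f), c(m * n, 3.0f);
    std::vector<float> zero(m * n, 0.0f), unfused(m * n), fused(m * n);

    // Unfused: beta = 0 GEMM followed by a separate elementwise add of c.
    gemm(m, n, k, 1.0f, a, b, 0.0f, zero, unfused);
    for(int i = 0; i < m * n; i++)
        unfused[i] += c[i];

    // Fused: a single beta = 1 GEMM with c as the accumulator input.
    gemm(m, n, k, 1.0f, a, b, 1.0f, c, fused);

    // Both paths produce the same result.
    for(int i = 0; i < m * n; i++)
        assert(std::fabs(unfused[i] - fused[i]) < 1e-5f);
    return 0;
}
```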
@@ -8,6 +8,7 @@
 #include <migraphx/gpu/oper.hpp>
 #include <migraphx/gpu/add.hpp>
 #include <migraphx/gpu/mul.hpp>
+#include <migraphx/gpu/gemm.hpp>
 #include <migraphx/gpu/device/layernorm.hpp>
 #include <migraphx/gpu/device/gelu.hpp>
 #include <migraphx/gpu/device/mul_add.hpp>
@@ -536,10 +537,9 @@ struct find_triadd
         auto input_ins = r.instructions["input"];
         auto ins       = r.result;
         auto args      = add_ins->inputs();
-        assert(add_ins != input_ins);
         auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
-        if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
+        if(std::count_if(args.begin(), args.end(), is_broadcasted) > 2)
             return;
         args.insert(args.begin(), input_ins);
         move_standard_front(args);
@@ -743,6 +743,68 @@ struct find_conv_bias_relu
     }
 };
+
+struct find_gemm_add
+{
+    auto matcher() const
+    {
+        return match::name("gpu::add")(
+            match::all_of[match::inputs()](match::standard_shape()),
+            match::either_arg(0, 1)(match::used_once().bind("c"),
+                                    match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins      = r.result;
+        auto gemm_ins = r.instructions["gemm"];
+        auto c_ins    = r.instructions["c"];
+
+        auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());
+
+        // Skip gemms that are already fused (beta != 0)
+        if(not float_equal(gemm.op.beta, 0))
+            return;
+
+        // Avoid non-standard shapes for inputs
+        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
+               return not i->get_shape().standard();
+           }))
+            return;
+
+        // Drop the gemm's output allocation; it is replaced below
+        auto inputs = gemm_ins->inputs();
+        inputs.pop_back();
+
+        auto copy_ins = c_ins;
+
+        // Insert a copy when c cannot safely be accumulated into in place
+        if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
+        {
+            copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
+        }
+
+        // Pass the same buffer as the C operand and the output allocation
+        inputs.push_back(copy_ins);
+        inputs.push_back(copy_ins);
+
+        gemm.op.beta = 1;
+        p.replace_instruction(ins, gemm, inputs);
+    }
+};
+
+struct find_commutative_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape()));
+    }
+
+    void apply(program& p, const match::matcher_result& r) const
+    {
+        auto ins  = r.result;
+        auto args = ins->inputs();
+        move_broadcasted_back(args);
+
+        p.replace_instruction(ins, ins->get_operator(), args);
+    }
+};
 void fuse_ops::apply(program& p) const
 {
     match::find_matches(p, find_gelu{}, find_gelu_new{});
@@ -760,7 +822,7 @@ void fuse_ops::apply(program& p) const
         find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
         find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
         find_add_clip{});
     // clang-format on
+    match::find_matches(p, find_gemm_add{}, find_commutative_broadcast{});
 }
 } // namespace gpu
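
A note on the double `inputs.push_back(copy_ins)` in `find_gemm_add` above: in this backend the last input of a gpu operation is its output allocation, so the same buffer is passed both as the GEMM's C operand and as the destination, and the beta = 1 GEMM accumulates into it in place. That is only safe if nothing else still reads C, which is why a `hip_copy` is inserted when `c_ins` has other readers or is a parameter/literal (`c_ins->inputs().empty()`). A small CPU-side sketch of the hazard, with `std::vector` standing in for device buffers (illustrative only, not MIGraphX code):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// In-place accumulation, like a beta = 1 GEMM writing into its C buffer.
void accumulate(std::vector<float>& d, const std::vector<float>& ab)
{
    for(std::size_t i = 0; i < d.size(); i++)
        d[i] += ab[i];
}

int main()
{
    std::vector<float> c  = {1, 2, 3};    // also read by another consumer
    std::vector<float> ab = {10, 10, 10}; // the A*B product

    // Accumulating directly into c would clobber the value its other
    // consumer sees. Copying c into the fusion's output buffer first
    // (the role hip_copy plays in the pass) keeps both readers correct.
    std::vector<float> out = c;
    accumulate(out, ab);

    assert(c[0] == 1.0f);    // original value preserved for the other reader
    assert(out[0] == 11.0f); // fused result
    return 0;
}
```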
...
@@ -64,8 +64,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         simplify_reshapes{},
         propagate_constant{},
         dead_code_elimination{},
-        remap{},
-        dead_code_elimination{},
         lowering{&ctx, options.offload_copy},
         eliminate_contiguous{},
         dead_code_elimination{},
...
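
Finally, `find_commutative_broadcast` in fuse_ops.cpp above canonicalizes argument order: because `gpu::add` and `gpu::mul` are commutative, a broadcasted operand can be moved to a fixed position so the backend only needs kernels for one argument layout. The sketch below shows one way a `move_broadcasted_back`-style reorder could work, using a hypothetical `arg` type; the real `move_broadcasted_back` helper in MIGraphX may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an instruction argument.
struct arg
{
    std::string name;
    bool broadcasted;
};

// Move broadcasted arguments after the non-broadcasted ones. A stable
// partition keeps the relative order within each group; reordering at all
// is only valid because the operation is commutative.
void move_broadcasted_back(std::vector<arg>& args)
{
    std::stable_partition(
        args.begin(), args.end(), [](const arg& a) { return not a.broadcasted; });
}

int main()
{
    std::vector<arg> args = {{"bias", true}, {"x", false}};
    move_broadcasted_back(args);
    assert(args.front().name == "x"); // non-broadcast operand now first
    assert(args.back().broadcasted);  // broadcasted operand moved back
    return 0;
}
```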