Unverified Commit 0079028a authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Add gelu optimization (#521)



* fix pad calc

* bert tf passes correctness

* formatting

* add test

* formatting

* remove comment

* add inline

* formatting

* fix order for literal

* formatting

* add test for gelu

* formatting

* added add_gelu fusion

* add files

* formatting

* remove layernorm opt

* revert reduce file

* add gelu_fn and tests

* formatting

* fix matcher, remove extra tests

* formatting

* fix matcher

* add used_once

* formatting

* start on new gelu

* formatting

* add matchers in fuse_ops

* formatting

* add dce to fix add_gelu

* add simplify_rsqrt and test

* formatting

* debugging value for matcher

* formatting

* add more to matchers

* formatting

* fix errors

* remove onnx gen

* add any_arg, change matchers to use either_arg

* formatting

* formatting

* add used_once

* formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 294a0e66
...@@ -518,6 +518,11 @@ inline auto either_arg(std::size_t i, std::size_t j) ...@@ -518,6 +518,11 @@ inline auto either_arg(std::size_t i, std::size_t j)
}; };
} }
inline auto any_arg(std::size_t i, std::size_t j)
{
return [=](auto m) { return match::any_of(arg(i)(m), arg(j)(m)); };
}
template <class M> template <class M>
auto same_shape(M m) auto same_shape(M m)
{ {
...@@ -535,6 +540,21 @@ auto same_shape(Ms... ms) ...@@ -535,6 +540,21 @@ auto same_shape(Ms... ms)
return all_of(same_shape(ms)...); return all_of(same_shape(ms)...);
} }
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
return make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->get_shape().elements() != 1)
return false;
auto l = ins->get_literal();
if(l.empty())
return false;
bool b = false;
l.visit([&](auto v) { b = v.front() - x < tolerance; });
return b;
});
}
} // namespace match } // namespace match
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -36,6 +36,7 @@ add_library(migraphx_device ...@@ -36,6 +36,7 @@ add_library(migraphx_device
device/exp.cpp device/exp.cpp
device/floor.cpp device/floor.cpp
device/gather.cpp device/gather.cpp
device/gelu.cpp
device/int8_gemm_pack.cpp device/int8_gemm_pack.cpp
device/log.cpp device/log.cpp
device/logsoftmax.cpp device/logsoftmax.cpp
......
#include <migraphx/gpu/device/gelu.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <cmath>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// x * 0.5 * (1.0 + erf(x / sqrt(2.0)))
template <class T>
auto gelu_fn(T x) __device__
{
return x * 0.5 * (1 + ::erf(x * M_SQRT1_2));
}
// 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))
template <class T>
auto gelu_fn_new(T x) __device__
{
return 0.5 * x * (1 + tanh(sqrt(M_2_PI) * (x + 0.044715 * x * x * x)));
}
void gelu(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) __device__ { return gelu_fn(to_hip_type(x)); });
}
void gelu_new(hipStream_t stream, const argument& result, const argument& arg)
{
nary(stream, result, arg)([](auto x) __device__ { return gelu_fn_new(to_hip_type(x)); });
}
void add_gelu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ {
auto sum = to_hip_type(x + y);
return gelu_fn(sum);
});
}
void add_gelu_new(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ {
auto sum = to_hip_type(x + y);
return gelu_fn(sum);
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_ops.hpp> #include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/clip.hpp> #include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/oper.hpp> #include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/gelu.hpp>
#include <migraphx/gpu/device/mul_add.hpp> #include <migraphx/gpu/device/mul_add.hpp>
#include <migraphx/gpu/device/add_clip.hpp> #include <migraphx/gpu/device/add_clip.hpp>
#include <migraphx/gpu/device/add_relu.hpp> #include <migraphx/gpu/device/add_relu.hpp>
...@@ -14,12 +17,14 @@ ...@@ -14,12 +17,14 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/array.hpp> #include <migraphx/array.hpp>
#include <migraphx/op/clip.hpp> #include <migraphx/op/clip.hpp>
#include <cmath>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FAST_GELU)
struct fusion struct fusion
{ {
...@@ -252,6 +257,22 @@ struct hip_add_tanh : binary_device<hip_add_tanh, &device::add_tanh> ...@@ -252,6 +257,22 @@ struct hip_add_tanh : binary_device<hip_add_tanh, &device::add_tanh>
{ {
}; };
struct hip_gelu : unary_device<hip_gelu, &device::gelu>
{
};
struct hip_add_gelu : binary_device<hip_add_gelu, &device::add_gelu>
{
};
struct hip_gelu_new : unary_device<hip_gelu_new, &device::gelu_new>
{
};
struct hip_add_gelu_new : binary_device<hip_add_gelu_new, &device::add_gelu_new>
{
};
struct hip_mul_add struct hip_mul_add
{ {
std::string name() const { return "hip::mul_add"; } std::string name() const { return "hip::mul_add"; }
...@@ -311,6 +332,121 @@ void move_standard_front(std::vector<instruction_ref>& args) ...@@ -311,6 +332,121 @@ void move_standard_front(std::vector<instruction_ref>& args)
std::swap(*it, args.front()); std::swap(*it, args.front());
} }
struct find_gelu
{
static auto erf_fn()
{
return match::name("gpu::erf")(
match::used_once(),
match::arg(0)(match::used_once(),
match::name("gpu::mul")(match::either_arg(0, 1)(
match::none_of(match::has_value(M_SQRT1_2)).bind("x"),
match::has_value(M_SQRT1_2)))));
}
auto matcher() const
{
return match::name("gpu::mul")(match::either_arg(0, 1)(
match::name("gpu::mul")(match::any_arg(0, 1)(match::args(match::has_value(0.5f)))),
match::name("gpu::add")(
match::used_once(),
match::either_arg(0, 1)(erf_fn(), match::args(match::has_value(1.0f))))));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto args = ins->inputs();
p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
}
};
struct find_add_gelu
{
auto matcher() const
{
return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add")));
}
void apply(program& p, match::matcher_result r) const
{
auto add_ins = r.instructions["add"];
auto ins = r.result;
auto args = add_ins->inputs();
move_standard_front(args);
move_broadcasted_back(args);
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_add_gelu{}, args);
}
};
struct find_gelu_new
{
static auto pow_fn()
{
return match::name("gpu::pow")(match::used_once(),
match::arg(1)(match::args(match::has_value(3.0f))));
}
static auto tanh_fn()
{
return match::name("gpu::tanh")(
match::used_once(),
match::arg(0)(match::name("gpu::mul")(match::either_arg(0, 1)(
match::args(match::has_value(sqrt(M_2_PI))),
match::name("gpu::add")(
match::any_arg(0, 1)(match::name("gpu::mul")(match::either_arg(0, 1)(
match::args(match::has_value(0.044715f)), pow_fn()))))))));
}
auto matcher() const
{
return match::name("gpu::mul")(
match::used_once(),
match::either_arg(0, 1)(
match::any().bind("x"),
match::name("gpu::add")(match::any_arg(0, 1)(match::name("gpu::mul")(
match::either_arg(0, 1)(match::args(match::has_value(0.5f)), tanh_fn()))))));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto args = ins->inputs();
if(enabled(MIGRAPHX_DISABLE_FAST_GELU{}))
p.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
else
p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
}
};
struct find_add_gelu_new
{
auto matcher() const
{
return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add")));
}
void apply(program& p, match::matcher_result r) const
{
auto add_ins = r.instructions["add"];
auto ins = r.result;
auto args = add_ins->inputs();
move_standard_front(args);
move_broadcasted_back(args);
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_add_gelu_new{}, args);
}
};
struct find_add_clip struct find_add_clip
{ {
auto matcher() const auto matcher() const
...@@ -602,10 +738,14 @@ struct find_conv_bias_relu ...@@ -602,10 +738,14 @@ struct find_conv_bias_relu
void fuse_ops::apply(program& p) const void fuse_ops::apply(program& p) const
{ {
// clang-format off // clang-format off
match::find_matches(p, find_gelu{}, find_gelu_new{});
run_passes(p, {dead_code_elimination{}});
match::find_matches(p, find_triadd{}); match::find_matches(p, find_triadd{});
match::find_matches(p, match::find_matches(p,
find_conv_bias_relu{ctx}, find_conv_bias_relu{ctx},
find_conv_bias{ctx}, find_conv_bias{ctx},
find_add_gelu{},
find_add_gelu_new{},
find_mul_add{}, find_mul_add{},
find_mul_add_relu{}, find_mul_add_relu{},
find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}}, find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}},
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_GELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_GELU_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void gelu(hipStream_t stream, const argument& result, const argument& arg);
void gelu_new(hipStream_t stream, const argument& result, const argument& arg);
void add_gelu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2);
void add_gelu_new(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <future> #include <future>
#include <thread> #include <thread>
#include <cmath>
#include <numeric> #include <numeric>
#include <test.hpp> #include <test.hpp>
...@@ -723,6 +724,52 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast> ...@@ -723,6 +724,52 @@ struct test_triadd_broadcast : verify_program<test_triadd_broadcast>
} }
}; };
struct test_gelu : verify_program<test_gelu>
{
migraphx::program create_program() const
{
migraphx::program p;
std::vector<size_t> input_lens{1, 1, 5};
auto x = p.add_parameter("x", {migraphx::shape::float_type, input_lens});
auto half = p.add_literal(0.5f);
auto one = p.add_literal(1.0f);
auto sqrt2 = p.add_literal(static_cast<float>(M_SQRT2));
auto half_mbcast = p.add_instruction(migraphx::op::multibroadcast{input_lens}, half);
auto mul_half = p.add_instruction(migraphx::op::mul{}, x, half_mbcast);
auto sqrt2_mbcast = p.add_instruction(migraphx::op::multibroadcast{input_lens}, sqrt2);
auto div = p.add_instruction(migraphx::op::div{}, x, sqrt2_mbcast);
auto erf = p.add_instruction(migraphx::op::erf{}, div);
auto one_mbcast = p.add_instruction(migraphx::op::multibroadcast{input_lens}, one);
auto add_one = p.add_instruction(migraphx::op::add{}, erf, one_mbcast);
p.add_instruction(migraphx::op::mul{}, mul_half, add_one);
return p;
}
};
struct test_add_gelu : verify_program<test_add_gelu>
{
migraphx::program create_program() const
{
migraphx::program p;
std::vector<size_t> input_lens{1, 1, 5};
auto x = p.add_parameter("x", {migraphx::shape::float_type, input_lens});
auto y = p.add_parameter("y", {migraphx::shape::float_type, input_lens});
auto half = p.add_literal(0.5f);
auto one = p.add_literal(1.0f);
auto sqrt2 = p.add_literal(static_cast<float>(M_SQRT2));
auto add = p.add_instruction(migraphx::op::add{}, x, y);
auto half_mbcast = p.add_instruction(migraphx::op::multibroadcast{input_lens}, half);
auto mul_half = p.add_instruction(migraphx::op::mul{}, add, half_mbcast);
auto sqrt2_mbcast = p.add_instruction(migraphx::op::multibroadcast{input_lens}, sqrt2);
auto div = p.add_instruction(migraphx::op::div{}, add, sqrt2_mbcast);
auto erf = p.add_instruction(migraphx::op::erf{}, div);
auto one_mbcast = p.add_instruction(migraphx::op::multibroadcast{input_lens}, one);
auto add_one = p.add_instruction(migraphx::op::add{}, erf, one_mbcast);
p.add_instruction(migraphx::op::mul{}, mul_half, add_one);
return p;
}
};
struct test_sub : verify_program<test_sub> struct test_sub : verify_program<test_sub>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment