"tools/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "b2c43ffd4ce8db4cf8c6516c89775239c28a5464"
Unverified commit 8f455615, authored by Chao Liu, committed by GitHub

Fast GeLU using built-in function (#587)



* clean up

* fast gelu using builtin function

* clean

* clean

* clean

* clean

* clean

* fix compilation

* clean

* clean

---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
parent 209baee2
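
In short: the hand-rolled `GetFastGeLU` helpers inside `AddFastGelu` and `AddAddFastGelu` are removed, and both operators now delegate to the shared `FastGelu` element-wise functor, which gains separate host and device paths (accurate `exp`/division on the host; `__expf` plus a hardware reciprocal on the device, with a workaround macro for `__frcp_rn`). The approximation both paths implement is the standard tanh-based fast GeLU, rewritten in sigmoid form so that a single exp and a single reciprocal suffice:

```latex
% Fast GeLU: tanh form, and the sigmoid rewrite the code uses,
% via tanh(u/2) = 2/(1 + e^{-u}) - 1.
% Note 0.035677 \approx 0.044715\,\sqrt{2/\pi} and 0.797885 \approx \sqrt{2/\pi}.
y = \tfrac{1}{2}\,x\left(1 + \tanh\!\big(\sqrt{2/\pi}\,(x + 0.044715\,x^3)\big)\right)
  = x\left[\tfrac{1}{2} + \tfrac{1}{2}\left(\frac{2}{1 + e^{-u}} - 1\right)\right],
\qquad u = 2x\,(0.035677\,x^2 + 0.797885).
```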
Example harness, parse_cmd_args:

```diff
@@ -62,7 +62,7 @@ struct ExecutionConfig final
 };
 
 inline bool
-parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig config)
+parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
 {
     if(argc == 1)
     {
```
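
Aside from the GeLU work, this first hunk fixes a real bug: `parse_cmd_args` took `ExecutionConfig` by value, so every option it parsed mutated a copy that was then thrown away. A minimal, self-contained illustration (hypothetical `Config` type, not CK's `ExecutionConfig`):

```cpp
// Why the `&` matters: a by-value parameter mutates a copy,
// so the caller never sees the parsed options.
#include <cassert>

struct Config { bool do_verification = false; };

void parse_by_value(Config c) { c.do_verification = true; }  // change is lost
void parse_by_ref(Config& c) { c.do_verification = true; }   // change is visible

int main()
{
    Config cfg;
    parse_by_value(cfg);
    assert(!cfg.do_verification); // the copy was modified, not cfg
    parse_by_ref(cfg);
    assert(cfg.do_verification);  // now the caller's object changed
}
```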
BF16 example:

```diff
@@ -7,10 +7,11 @@ using ADataType = BF16;
 using BDataType        = BF16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
+using CDataType        = F32; // C matrix doesn't exist in GPU memory; this is used for host verification
 using D0DataType       = BF16;
 using D1DataType       = BF16;
 using DsDataType       = ck::Tuple<D0DataType, D1DataType>;
 using EDataType        = BF16;
 
 using ALayout = Row;
 using BLayout = Col;
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                         BDataType,
-                                                                        AccDataType,
+                                                                        CDataType,
                                                                         AccDataType,
                                                                         AElementOp,
                                                                         BElementOp,
```
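
The new `CDataType` names the host-only type of the unfused C = A·B result: the GPU kernel keeps C in registers (as `AccDataType`) and feeds it straight into the fused epilogue, but the host reference needs a materialized, full-precision C buffer before replaying that epilogue. The identical one-line change is applied to the F16, F32, int4, and int8 variants below. A small self-contained sketch of the rounding issue a full-precision C avoids, with BF16 emulated by mantissa truncation (not CK's `bhalf_t`):

```cpp
// Truncating an F32 product to BF16 before the epilogue (which the GPU never
// does, since C stays in registers as F32) would change the verified value.
#include <cstdint>
#include <cstdio>
#include <cstring>

static float to_bf16(float x) // truncation toward zero, for illustration only
{
    uint32_t u;
    std::memcpy(&u, &x, sizeof(u));
    u &= 0xffff0000u; // keep BF16's 8 mantissa bits
    std::memcpy(&x, &u, sizeof(x));
    return x;
}

int main()
{
    const float c = 3.14159265f; // imagine: one output of A*B in F32
    std::printf("epilogue input, F32 path : %.8f\n", c);
    std::printf("epilogue input, BF16 trip: %.8f\n", to_bf16(c));
}
```

Verifying against a C that had been squeezed through BF16 would bake an extra rounding into the expected values.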
F16 example (same change):

```diff
@@ -7,10 +7,11 @@ using ADataType = F16;
 using BDataType        = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
+using CDataType        = F32; // C matrix doesn't exist in GPU memory; this is used for host verification
 using D0DataType       = F16;
 using D1DataType       = F16;
 using DsDataType       = ck::Tuple<D0DataType, D1DataType>;
 using EDataType        = F16;
 
 using ALayout = Row;
 using BLayout = Col;
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                         BDataType,
-                                                                        AccDataType,
+                                                                        CDataType,
                                                                         AccDataType,
                                                                         AElementOp,
                                                                         BElementOp,
```
F32 example (same change, plus a duplicated license line removed at the top of the file):

```diff
-// SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "common.hpp"
@@ -7,10 +6,11 @@ using ADataType = F32;
 using BDataType        = F32;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
+using CDataType        = F32; // C matrix doesn't exist in GPU memory; this is used for host verification
 using D0DataType       = F32;
 using D1DataType       = F32;
 using DsDataType       = ck::Tuple<D0DataType, D1DataType>;
 using EDataType        = F32;
 
 using ALayout = Row;
 using BLayout = Col;
@@ -36,7 +36,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                         BDataType,
-                                                                        AccDataType,
+                                                                        CDataType,
                                                                         AccDataType,
                                                                         AElementOp,
                                                                         BElementOp,
```
int4 example (same change):

```diff
@@ -11,10 +11,11 @@ using ADataType = I4;
 using BDataType        = I4;
 using AccDataType      = I32;
 using CShuffleDataType = I32;
+using CDataType        = I32; // C matrix doesn't exist in GPU memory; this is used for host verification
 using D0DataType       = I4;
 using D1DataType       = I4;
 using DsDataType       = ck::Tuple<D0DataType, D1DataType>;
 using EDataType        = I4;
 
 using KernelADataType = I8;
 using KernelBDataType = I8;
@@ -47,7 +48,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                         BDataType,
-                                                                        AccDataType,
+                                                                        CDataType,
                                                                         AccDataType,
                                                                         AElementOp,
                                                                         BElementOp,
```
int8 example (same change):

```diff
@@ -7,10 +7,11 @@ using ADataType = I8;
 using BDataType        = I8;
 using AccDataType      = I32;
 using CShuffleDataType = I32;
+using CDataType        = I32; // C matrix doesn't exist in GPU memory; this is used for host verification
 using D0DataType       = I8;
 using D1DataType       = I8;
 using DsDataType       = ck::Tuple<D0DataType, D1DataType>;
 using EDataType        = I8;
 
 using ALayout = Row;
 using BLayout = Col;
@@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                         BDataType,
-                                                                        AccDataType,
+                                                                        CDataType,
                                                                         AccDataType,
                                                                         AElementOp,
                                                                         BElementOp,
```
Example run function (host verification):

```diff
@@ -124,7 +124,7 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
     if(config.do_verification)
     {
-        Tensor<AccDataType> c_m_n({M, N});
+        Tensor<CDataType> c_m_n({M, N});
 
         auto ref_gemm    = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
```
Config header:

```diff
@@ -168,6 +168,9 @@
 // tuning parameter
 #define CK_WORKAROUND_SWDEV_325164 0
 
+// workaround: compiler not emitting reciprocal instruction from __frcp_rn()
+#define CK_WORKAROUND_SWDEV_383542 1
+
 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0
```
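
This macro guards the device reciprocal used in the `FastGelu` hunk at the end of this diff. A condensed restatement of the pattern, assuming HIP and the ROCm device library that provides the OCML builtin:

```cpp
// HIP device code; compile with hipcc. __ocml_native_recip_f32 comes from
// ROCm's OCML device library; the forward declaration makes it callable
// without pulling in a header.
#include <hip/hip_runtime.h>

#define CK_WORKAROUND_SWDEV_383542 1

#if CK_WORKAROUND_SWDEV_383542
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif

__device__ float fast_reciprocal(float a)
{
#if !CK_WORKAROUND_SWDEV_383542
    return __frcp_rn(a);               // intended builtin; the compiler was not
                                       // emitting the reciprocal instruction for it
#else
    return __ocml_native_recip_f32(a); // workaround: call the OCML native reciprocal
#endif
}
```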
AddFastGelu:

```diff
@@ -4,6 +4,7 @@
 #pragma once
 
 #include "ck/utility/data_type.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
 
 namespace ck {
 namespace tensor_operation {
@@ -280,43 +281,42 @@ struct AddHardswish
     }
 };
 
+// C = A * B
 // E = FastGelu(C + D)
 struct AddFastGelu
 {
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-    template <typename T>
-    static inline constexpr bool is_valid_param_type_v =
-        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
-
     template <typename E, typename C, typename D>
-    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
+    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<float, float, float>(float& e, const float& c, const float& d) const
     {
-        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
-                      is_valid_param_type_v<D>);
+        const float x = c + d;
 
-        const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
+        FastGelu{}.template operator()<float, float>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
+    {
+        const half_t x = c + d;
 
-        e = type_convert<E>(y);
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
     }
 
-    template <typename D>
-    __host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
     {
-        static_assert(is_valid_param_type_v<D>);
+        const float x0_f = c + d;
+
+        float x1_f = 0;
+
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
+                                                                                         x0_f);
 
-        e = GetFastGeLU(c + type_convert<float>(d));
+        e = type_convert<half_t>(x1_f);
     }
 };
```
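
Structurally, `AddFastGelu` no longer owns any GeLU math: it forms `x = c + d` and delegates to `FastGelu`, which is why the binary header now includes the unary one. A simplified, self-contained stand-in for that shape (host-only plain C++, not CK's API; names mirror the diff for readability):

```cpp
// The binary op composes "add" with a single, shared FastGelu functor,
// so the approximation is defined in exactly one place.
#include <cmath>
#include <cstdio>

struct FastGelu
{
    void operator()(float& y, const float& x) const
    {
        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float emu = std::exp(-u);
        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
        y               = x * cdf;
    }
};

struct AddFastGelu
{
    void operator()(float& e, const float& c, const float& d) const
    {
        const float x = c + d;
        FastGelu{}(e, x); // single source of truth for the GeLU formula
    }
};

int main()
{
    float e = 0.f;
    AddFastGelu{}(e, 1.0f, 0.5f);
    std::printf("FastGelu(1.5) = %f\n", e);
}
```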
Element-wise header comment (the declare-then-specialize pattern):

```diff
@@ -16,7 +16,7 @@ namespace element_wise {
 // Need to ensure compiler will fail if there is no matching candidate, instead of compiler
 // silently doing implicit type conversion
 //
-// Method 1:
+// Example:
 //
 // struct ExampleElementwiseOp
 // {
@@ -30,19 +30,6 @@ namespace element_wise {
 //     {
 //     }
 // };
-//
-// Method 2:
-//
-// template <typename Y, typename X>
-// struct ExampleElementwiseOp;
-//
-// template <>
-// struct ExampleElementwiseOp<float, ck::bhalf_t>
-// {
-//     __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
-//     {
-//     }
-// };
 
 struct AddReluAdd
 {
```
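
The comment that survives this hunk documents the idiom both rewritten operators follow: declare the primary `operator()` template without a definition and define only the supported type combinations, so an unsupported combination fails to build instead of silently type-converting. A self-contained demo (note: in-class explicit specializations lean on CWG 727, which clang, and therefore hipcc, accepts; GCC may reject this form):

```cpp
// Not CK code: a minimal version of the declare-then-specialize pattern.
#include <cstdio>

struct ExampleElementwiseOp
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const; // declared, never defined

    template <>
    void operator()<float, float>(float& y, const float& x) const
    {
        y = x;
    }
};

int main()
{
    ExampleElementwiseOp op;
    float y = 0.f;
    op(y, 1.5f);    // OK: matches the float/float specialization
    // double d = 0.0;
    // op(d, 1.5f); // would compile but fail to link: no definition exists
    std::printf("%f\n", y);
}
```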
AddAddFastGelu:

```diff
@@ -208,41 +195,74 @@ struct AddMultiply
     }
 };
 
+// C = A * B
 // E = FastGelu(C + D0 + D1)
 struct AddAddFastGelu
 {
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ __device__ static constexpr float GetFastGeLU(float x)
-    {
-        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
-        const float emu = exp(-u);
-        const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
-    }
-
-    template <typename T>
-    static inline constexpr bool is_valid_param_type_v =
-        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
-#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-        || std::is_same_v<T, ck::int4_t>
-#endif
-        ;
-
     template <typename E, typename C, typename D0, typename D1>
     __host__ __device__ constexpr void
-    operator()(E& e, const C& c, const D0& d0, const D1& d1) const
+    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
+                                                                              const float& c,
+                                                                              const float& d0,
+                                                                              const float& d1) const
     {
-        static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
-                      is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
+        const float x = c + d0 + d1;
 
-        const float y =
-            GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
+        FastGelu{}.template operator()<float, float>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
+        half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
+    {
+        const half_t x = c + d0 + d1;
+
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
+        half_t& e, const float& c, const half_t& d0, const half_t& d1) const
+    {
+        const float x0_f = c + d0 + d1;
+
+        float x1_f = 0;
+
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
+                                                                                         x0_f);
+
+        e = type_convert<half_t>(x1_f);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
+        bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
+    {
+        const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
+
+        float x1_f = 0;
+
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
+                                                                                         x0_f);
+
+        e = type_convert<bhalf_t>(x1_f);
+    }
+
+    template <>
+    __host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t, int8_t>(
+        int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const
+    {
+        const float x0_f =
+            type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
+
+        float x1_f = 0;
+
+        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
+                                                                                         x0_f);
 
-        e = type_convert<E>(y);
+        e = type_convert<int8_t>(x1_f);
     }
 };
```
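
Every specialization above funnels into the float `FastGelu` path, so a single host-side check of the underlying approximation covers them all. A runnable comparison against the exact GeLU, 0.5·x·(1 + erf(x/√2)), using only formulas already present in this diff:

```cpp
// Compare the diff's exp/reciprocal formulation against exact GeLU on the host.
#include <cmath>
#include <cstdio>

static float fast_gelu(float x) // formula from the diff's host path
{
    const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
    const float emu = std::exp(-u);
    const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
    return x * cdf;
}

static float exact_gelu(float x) { return 0.5f * x * (1.f + std::erf(x * 0.70710678f)); }

int main()
{
    for(float x : {-3.f, -1.f, -0.1f, 0.f, 0.5f, 1.f, 3.f})
        std::printf("x=% .2f  fast=% .6f  exact=% .6f  diff=% .2e\n",
                    x, fast_gelu(x), exact_gelu(x), fast_gelu(x) - exact_gelu(x));
}
```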
FastGelu (unary element-wise header):

```diff
@@ -11,6 +11,10 @@ namespace ck {
 namespace tensor_operation {
 namespace element_wise {
 
+#if CK_WORKAROUND_SWDEV_383542
+extern "C" __device__ float __ocml_native_recip_f32(float);
+#endif
+
 struct PassThrough
 {
     template <typename Y, typename X>
@@ -200,36 +204,83 @@ struct Relu
     }
 };
 
-// Y = FastGelu(X)
+// Fast GeLU
+// https://paperswithcode.com/method/gelu
+// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
+// host code uses the higher-accuracy "exp" and "div"
+// gpu code uses the lower-accuracy "__expf" and "rcp" functions
 struct FastGelu
 {
-    // Fast GeLU
-    // https://paperswithcode.com/method/gelu
-    // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
-    __host__ __device__ static constexpr float GetFastGeLU(float x)
+    template <typename Y, typename X>
+    __host__ void operator()(Y& y, const X& x) const;
+
+    template <typename Y, typename X>
+    __device__ void operator()(Y& y, const X& x) const;
+
+    template <>
+    __host__ void operator()<float, float>(float& y, const float& x) const
     {
         const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
         const float emu = exp(-u);
         const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
-        return x * cdf;
+
+        y = x * cdf;
     }
 
-    template <typename T>
-    static inline constexpr bool is_valid_param_type_v =
-        std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
-#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-        || std::is_same_v<T, ck::int4_t>
-#endif
-        ;
-
-    template <typename Y, typename X>
-    __host__ __device__ void operator()(Y& y, const X& x) const
+    // device code, uses the lower-precision "__expf" and "rcp"
+    template <>
+    __device__ void operator()<float, float>(float& y, const float& x) const
     {
-        static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);
+        const float u   = 2.f * x * (0.035677f * x * x + 0.797885f);
+        const float emu = __expf(-u);
+#if !CK_WORKAROUND_SWDEV_383542
+        const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
+#else
+        const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
+#endif
 
-        const float tmp_y = GetFastGeLU(type_convert<float>(x));
-        y                 = type_convert<Y>(tmp_y);
+        y = x * cdf;
+    }
+
+    template <>
+    __host__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        float y_f;
+
+        this->operator()<float, float>(y_f, type_convert<float>(x));
+
+        y = type_convert<half_t>(y_f);
+    }
+
+    template <>
+    __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
+    {
+        float y_f;
+
+        this->operator()<float, float>(y_f, type_convert<float>(x));
+
+        y = type_convert<half_t>(y_f);
+    }
+
+    template <>
+    __host__ void operator()<half_t, float>(half_t& y, const float& x) const
+    {
+        float y_f;
+
+        this->operator()<float, float>(y_f, x);
+
+        y = type_convert<half_t>(y_f);
+    }
+
+    template <>
+    __device__ void operator()<half_t, float>(half_t& y, const float& x) const
+    {
+        float y_f;
+
+        this->operator()<float, float>(y_f, x);
+
+        y = type_convert<half_t>(y_f);
     }
 };
```
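
The host and device `operator()` specializations above share one signature and differ only in execution space. That relies on target-based (host/device) overloading, a clang extension that HIP inherits; nvcc-style CUDA would reject a same-signature pair like this. A minimal sketch of the mechanism, assuming compilation with hipcc:

```cpp
// HIP/clang host-device overloading: two functions with identical signatures,
// distinguished only by execution-space attributes. The host overload keeps
// the accurate exp()/division; the device overload may use fast-math builtins.
#include <hip/hip_runtime.h>
#include <cmath>

__host__ float sigmoid(float x) { return 1.f / (1.f + std::exp(-x)); } // accurate host path
__device__ float sigmoid(float x) { return 1.f / (1.f + __expf(-x)); } // fast device path
```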