Commit 87f44ead authored by Chao Liu's avatar Chao Liu
Browse files

clean

parent 4198de5a
...@@ -169,7 +169,7 @@ ...@@ -169,7 +169,7 @@
#define CK_WORKAROUND_SWDEV_325164 0 #define CK_WORKAROUND_SWDEV_325164 0
// workaround: compiler not emitting reciprocal instruction from __frcp_rn() // workaround: compiler not emitting reciprocal instruction from __frcp_rn()
#define CK_WORKAROUND_SWDEV_XXXXXX_FRCP_RN 1 #define CK_WORKAROUND_SWDEV_383542 1
// flag to enable (1) or disable (0) the debugging output in some kernels // flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0 #define DEBUG_LOG 0
......
...@@ -280,47 +280,6 @@ struct AddHardswish ...@@ -280,47 +280,6 @@ struct AddHardswish
}; };
}; };
#if 0
// C = A * B
// E = FastGelu(C + D)
struct AddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D>);
const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
e = type_convert<E>(y);
}
template <typename D>
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
{
static_assert(is_valid_param_type_v<D>);
e = GetFastGeLU(c + type_convert<float>(d));
}
};
#else
// E = FastGelu(C + D) // E = FastGelu(C + D)
struct AddFastGelu struct AddFastGelu
{ {
...@@ -345,7 +304,6 @@ struct AddFastGelu ...@@ -345,7 +304,6 @@ struct AddFastGelu
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x); ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
} }
}; };
#endif
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -195,45 +195,6 @@ struct AddMultiply ...@@ -195,45 +195,6 @@ struct AddMultiply
} }
}; };
#if 0
// C = A * B
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
#endif
;
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const
{
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
is_valid_param_type_v<D0> && is_valid_param_type_v<D1>);
const float y =
GetFastGeLU(type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1));
e = type_convert<E>(y);
}
};
#else
// E = FastGelu(C + D0 + D1) // E = FastGelu(C + D0 + D1)
struct AddAddFastGelu struct AddAddFastGelu
{ {
...@@ -261,7 +222,6 @@ struct AddAddFastGelu ...@@ -261,7 +222,6 @@ struct AddAddFastGelu
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x); ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
} }
}; };
#endif
struct Normalize struct Normalize
{ {
......
...@@ -11,7 +11,7 @@ namespace ck { ...@@ -11,7 +11,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
#if CK_WORKAROUND_SWDEV_XXXXXX_FRCP_RN #if CK_WORKAROUND_SWDEV_383542
extern "C" __device__ float __ocml_native_recip_f32(float); extern "C" __device__ float __ocml_native_recip_f32(float);
#endif #endif
...@@ -204,40 +204,6 @@ struct Relu ...@@ -204,40 +204,6 @@ struct Relu
} }
}; };
#if 0
// Y = FastGelu(X)
struct FastGelu
{
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
__host__ __device__ static constexpr float GetFastGeLU(float x)
{
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
return x * cdf;
}
template <typename T>
static inline constexpr bool is_valid_param_type_v =
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<T, ck::int4_t>
#endif
;
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
static_assert(is_valid_param_type_v<Y> && is_valid_param_type_v<X>);
const float tmp_y = GetFastGeLU(type_convert<float>(x));
y = type_convert<Y>(tmp_y);
}
};
#else
// Fast GeLU // Fast GeLU
// https://paperswithcode.com/method/gelu // https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
...@@ -275,24 +241,16 @@ struct FastGelu ...@@ -275,24 +241,16 @@ struct FastGelu
template <> template <>
__device__ void operator()<float, float>(float& y, const float& x) const __device__ void operator()<float, float>(float& y, const float& x) const
{ {
#if 0
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = exp(-u);
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
y = x * cdf;
#else
const float u = 2.f * x * (0.035677f * x * x + 0.797885f); const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float emu = __expf(-u); const float emu = __expf(-u);
#if !CK_WORKAROUND_SWDEV_XXXXXX_FRCP_RN #if !CK_WORKAROUND_SWDEV_383542
const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f); const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f);
#else #else
const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f); const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f);
#endif #endif
y = x * cdf; y = x * cdf;
#endif
} }
// device code, use lower precision "__expf" and "rcp" // device code, use lower precision "__expf" and "rcp"
...@@ -306,7 +264,6 @@ struct FastGelu ...@@ -306,7 +264,6 @@ struct FastGelu
y = type_convert<half_t>(y_f); y = type_convert<half_t>(y_f);
} }
}; };
#endif
// https://paperswithcode.com/method/gelu // https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2))) // y = 0.5*x*(1+erf(x/sqrt(2)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment