"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "d8476818efc88188d0aa0a8a176024a0b82e7a1d"
Unverified Commit ee768148 authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

Replace the use of __expf with __ocml_exp_f32 to work around the test_softmax_rank4 failure (#1394)

parent 9cac2827
...@@ -431,7 +431,7 @@ struct Relu ...@@ -431,7 +431,7 @@ struct Relu
// https://paperswithcode.com/method/gelu // https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div" // host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "__expf" and "rcp" function // gpu code use lower accuracy "__ocml_exp_f32" and "rcp" function
struct FastGelu struct FastGelu
{ {
template <typename Y, typename X> template <typename Y, typename X>
...@@ -451,7 +451,7 @@ struct FastGelu ...@@ -451,7 +451,7 @@ struct FastGelu
y = x / (1.f + emu); y = x / (1.f + emu);
} }
// device code, use lower precision "__expf" and "rcp" // device code, use lower precision "__ocml_exp_f32" and "rcp"
template <> template <>
__device__ void operator()<float, float>(float& y, const float& x) const __device__ void operator()<float, float>(float& y, const float& x) const
{ {
...@@ -459,7 +459,7 @@ struct FastGelu ...@@ -459,7 +459,7 @@ struct FastGelu
const float c1 = -2.0 * 0.035677f; const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f; const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2); const float u = x * (c1 * x * x + c2);
const float emu = __expf(u); const float emu = __ocml_exp_f32(u);
y = x * ck::math::rcp(1.f + emu); y = x * ck::math::rcp(1.f + emu);
} }
......
...@@ -839,7 +839,7 @@ inline __device__ T rcp(T x) ...@@ -839,7 +839,7 @@ inline __device__ T rcp(T x)
template <typename T> template <typename T>
inline __device__ T exp(T x) inline __device__ T exp(T x)
{ {
return ck::type_convert<T>(__expf(ck::type_convert<float>(x))); return ck::type_convert<T>(__ocml_exp_f32(ck::type_convert<float>(x)));
}; };
template <> template <>
...@@ -851,7 +851,7 @@ inline __device__ half_t exp<half_t>(half_t x) ...@@ -851,7 +851,7 @@ inline __device__ half_t exp<half_t>(half_t x)
template <> template <>
inline __device__ float exp<float>(float x) inline __device__ float exp<float>(float x)
{ {
return __expf(x); return __ocml_exp_f32(x);
}; };
template <> template <>
......
...@@ -331,7 +331,10 @@ bfloat16_t sqrt(bfloat16_t x) ...@@ -331,7 +331,10 @@ bfloat16_t sqrt(bfloat16_t x)
}; };
CK_TILE_DEVICE CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); }; bfloat16_t exp(bfloat16_t x)
{
return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
};
CK_TILE_DEVICE CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); }; bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
......
...@@ -835,7 +835,7 @@ CK_TILE_DEVICE ...@@ -835,7 +835,7 @@ CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); }; fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
CK_TILE_DEVICE CK_TILE_DEVICE
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); }; fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__ocml_exp_f32(static_cast<float>(x))); };
CK_TILE_DEVICE CK_TILE_DEVICE
fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); }; fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
...@@ -860,7 +860,7 @@ CK_TILE_DEVICE ...@@ -860,7 +860,7 @@ CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); }; bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
CK_TILE_DEVICE CK_TILE_DEVICE
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); }; bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__ocml_exp_f32(static_cast<float>(x))); };
CK_TILE_DEVICE CK_TILE_DEVICE
bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); }; bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
......
...@@ -374,7 +374,7 @@ half_t sqrt(half_t x) ...@@ -374,7 +374,7 @@ half_t sqrt(half_t x)
}; };
CK_TILE_DEVICE CK_TILE_DEVICE
half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); }; half_t exp(half_t x) { return static_cast<half_t>(__ocml_exp_f32(static_cast<float>(x))); };
CK_TILE_DEVICE CK_TILE_DEVICE
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); }; half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };
......
...@@ -519,7 +519,7 @@ CK_TILE_DEVICE ...@@ -519,7 +519,7 @@ CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
CK_TILE_DEVICE CK_TILE_DEVICE
float exp(float x) { return __expf(x); }; float exp(float x) { return __ocml_exp_f32(x); };
CK_TILE_HOST CK_TILE_HOST
float exp(float x) { return std::expf(x); } float exp(float x) { return std::expf(x); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment