"...git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "207b0abdc178d968c4e43ddf364062759bba7f38"
Commit 83fe1e08 authored by Junhao's avatar Junhao
Browse files

fix RTN logic

parent ff46a782
...@@ -32,19 +32,7 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t ...@@ -32,19 +32,7 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
} }
// convert fp32 to bfp16 // convert fp32 to bfp16
#ifndef FLASH_ATTENTION_INTERNAL_USE_RTN #if FLASH_ATTENTION_INTERNAL_USE_RTN
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
return uint16_t(u.int32 >> 16);
}
#else
template <> template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x) inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{ {
...@@ -87,6 +75,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float ...@@ -87,6 +75,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
} }
#else
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
return uint16_t(u.int32 >> 16);
}
#endif #endif
// convert bfp16 to fp16 via fp32 // convert bfp16 to fp16 via fp32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment