"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "627a739fe35c312841a5c6ffdc1b22a17516a014"
Commit 980b8835 authored by Junhao Zhang's avatar Junhao Zhang
Browse files

compiling time RTN

parent 5fb5335e
...@@ -51,19 +51,6 @@ struct PassThrough ...@@ -51,19 +51,6 @@ struct PassThrough
y = x; y = x;
} }
#if FLASH_ATTENTION_INTERNAL_USE_RTN
template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
y = bf16_convert_rtn<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
{
y = bf16_convert_rtn<bhalf_t>(x);
}
#else
template <> template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{ {
...@@ -75,7 +62,6 @@ struct PassThrough ...@@ -75,7 +62,6 @@ struct PassThrough
{ {
y = type_convert<bhalf_t>(x); y = type_convert<bhalf_t>(x);
} }
#endif
template <> template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
...@@ -128,6 +114,16 @@ struct PassThrough ...@@ -128,6 +114,16 @@ struct PassThrough
} }
}; };
// struct PassThroughRTN : public PassThrough
// {
// __host__ __device__ void operator()(bhalf_t& y, const float& x) const
// {
// y = bf16_convert_rtn<bhalf_t>(x);
// }
// using PassThrough::operator();
// };
struct UnaryConvert struct UnaryConvert
{ {
template <typename Y, typename X> template <typename Y, typename X>
......
...@@ -31,6 +31,51 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t ...@@ -31,6 +31,51 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return u.fp32; return u.fp32;
} }
#if FLASH_ATTENTION_INTERNAL_USE_RTN
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
#else
// convert fp32 to bfp16 // convert fp32 to bfp16
template <> template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x) inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
...@@ -43,6 +88,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float ...@@ -43,6 +88,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
} }
#endif
// convert bfp16 to fp16 via fp32 // convert bfp16 to fp16 via fp32
template <> template <>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment