Commit 7a81bc31 authored by zhuwenwen's avatar zhuwenwen
Browse files

update fp8 native implementation

parent 98a011e9
...@@ -13,7 +13,41 @@ namespace vllm { ...@@ -13,7 +13,41 @@ namespace vllm {
#ifdef USE_ROCM #ifdef USE_ROCM
namespace fp8 { namespace fp8 {
// #ifdef ENABLE_FP8 #ifdef ENABLE_FP8
// Use hardware cvt instruction for fp8 on rocm
template <typename fp8_type>
__device__ __forceinline__ fp8_type cvt_c10(float const r) {
return {};
}
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
// the new HW cvt with something reasonable that doesn't rely on the
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
template <>
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
#if HIP_FP8_TYPE_OCP
return c10::Float8_e4m3fn(
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
__hip_fp8_e4m3::__default_interpret),
c10::Float8_e4m3fn::from_bits());
#else
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
return static_cast<c10::Float8_e4m3fn>(r);
#endif
}
template <>
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
return c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
__hip_fp8_e4m3_fnuz::__default_interpret),
c10::Float8_e4m3fnuz::from_bits());
}
// KV-CACHE int8 // KV-CACHE int8
static inline __device__ float fp8_to_float(uint8_t input) { static inline __device__ float fp8_to_float(uint8_t input) {
...@@ -53,10 +87,11 @@ static inline __device__ uint8_t float_to_fp8(float f) { ...@@ -53,10 +87,11 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return result; return result;
} }
// template <typename Tout, typename Tin>
// __inline__ __device__ Tout vec_conversion(const Tin& x) { template <typename Tout, typename Tin>
// return x; __inline__ __device__ Tout vec_conversion(const Tin& x) {
// } return x;
}
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
...@@ -64,271 +99,271 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x, ...@@ -64,271 +99,271 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
return x; return x;
} }
// #if HIP_FP8_TYPE_OCP #if HIP_FP8_TYPE_OCP
// using fp8_type = __hip_fp8_e4m3; using fp8_type = __hip_fp8_e4m3;
// using fp8x2_type = __hip_fp8x2_e4m3; using fp8x2_type = __hip_fp8x2_e4m3;
// #else #else
// using fp8_type = __hip_fp8_e4m3_fnuz; using fp8_type = __hip_fp8_e4m3_fnuz;
// using fp8x2_type = __hip_fp8x2_e4m3_fnuz; using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
// #endif #endif
// // fp8 -> half // fp8 -> half
// template <> template <>
// __inline__ __device__ uint16_t __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint8_t>(const uint8_t& a) { vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
// return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x; return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
// } }
// // fp8x2 -> half2 // fp8x2 -> half2
// template <> template <>
// __inline__ __device__ uint32_t __inline__ __device__ uint32_t
// vec_conversion<uint32_t, uint16_t>(const uint16_t& a) { vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
// union { union {
// __half2_raw h2r; __half2_raw h2r;
// uint32_t ui32; uint32_t ui32;
// } tmp; } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// return tmp.ui32; return tmp.ui32;
// } }
// // fp8x4 -> half2x2 // fp8x4 -> half2x2
// template <> template <>
// __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) { __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
// union { union {
// uint2 u32x2; uint2 u32x2;
// uint32_t u32[2]; uint32_t u32[2];
// } tmp; } tmp;
// tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a); tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
// tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U)); tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
// return tmp.u32x2; return tmp.u32x2;
// } }
// // fp8x8 -> half2x4 // fp8x8 -> half2x4
// template <> template <>
// __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) { __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
// union { union {
// uint4 u64x2; uint4 u64x2;
// uint2 u64[2]; uint2 u64[2];
// } tmp; } tmp;
// tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x); tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
// tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y); tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
// return tmp.u64x2; return tmp.u64x2;
// } }
// using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16;
// // fp8 -> __nv_bfloat16 // fp8 -> __nv_bfloat16
// template <> template <>
// __inline__ __device__ __nv_bfloat16 __inline__ __device__ __nv_bfloat16
// vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
// fp8_type f8; fp8_type f8;
// f8.__x = a; f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8)); return __float2bfloat16(static_cast<float>(f8));
// } }
// using __nv_bfloat162 = __hip_bfloat162; using __nv_bfloat162 = __hip_bfloat162;
// // fp8x2 -> __nv_bfloat162 // fp8x2 -> __nv_bfloat162
// template <> template <>
// __inline__ __device__ __nv_bfloat162 __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
// __nv_bfloat162 res; __nv_bfloat162 res;
// res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
// res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
// return res; return res;
// } }
// // fp8x4 -> bf16_4_t // fp8x4 -> bf16_4_t
// template <> template <>
// __inline__ __device__ bf16_4_t __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) { vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
// bf16_4_t res; bf16_4_t res;
// res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
// res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
// return res; return res;
// } }
// // fp8x8 -> bf16_8_t // fp8x8 -> bf16_8_t
// template <> template <>
// __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) { __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
// bf16_4_t tmp1, tmp2; bf16_4_t tmp1, tmp2;
// tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x); tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
// tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y); tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
// bf16_8_t res; bf16_8_t res;
// res.x = tmp1.x; res.x = tmp1.x;
// res.y = tmp1.y; res.y = tmp1.y;
// res.z = tmp2.x; res.z = tmp2.x;
// res.w = tmp2.y; res.w = tmp2.y;
// return res; return res;
// } }
// // fp8 -> float // fp8 -> float
// template <> template <>
// __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) { __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
// fp8_type f8; fp8_type f8;
// f8.__x = a; f8.__x = a;
// return static_cast<float>(f8); return static_cast<float>(f8);
// } }
// // fp8x2 -> float2 // fp8x2 -> float2
// template <> template <>
// __inline__ __device__ float2 __inline__ __device__ float2
// vec_conversion<float2, uint16_t>(const uint16_t& a) { vec_conversion<float2, uint16_t>(const uint16_t& a) {
// fp8x2_type f8x2; fp8x2_type f8x2;
// f8x2.__x = a; f8x2.__x = a;
// return static_cast<float2>(f8x2); return static_cast<float2>(f8x2);
// } }
// // fp8x4 -> float4 // fp8x4 -> float4
// template <> template <>
// __inline__ __device__ Float4_ __inline__ __device__ Float4_
// vec_conversion<Float4_, uint32_t>(const uint32_t& a) { vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
// Float4_ res; Float4_ res;
// res.x = vec_conversion<float2, uint16_t>((uint16_t)a); res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
// res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U)); res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
// return res; return res;
// } }
// // fp8x4 -> float4 // fp8x4 -> float4
// template <> template <>
// __inline__ __device__ float4 __inline__ __device__ float4
// vec_conversion<float4, uint32_t>(const uint32_t& a) { vec_conversion<float4, uint32_t>(const uint32_t& a) {
// Float4_ tmp = vec_conversion<Float4_, uint32_t>(a); Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
// float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
// return res; return res;
// } }
// // fp8x8 -> float8 // fp8x8 -> float8
// template <> template <>
// __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) { __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
// Float4_ tmp1, tmp2; Float4_ tmp1, tmp2;
// tmp1 = vec_conversion<Float4_, uint32_t>(a.x); tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
// tmp2 = vec_conversion<Float4_, uint32_t>(a.y); tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
// Float8_ res; Float8_ res;
// res.x = tmp1.x; res.x = tmp1.x;
// res.y = tmp1.y; res.y = tmp1.y;
// res.z = tmp2.x; res.z = tmp2.x;
// res.w = tmp2.y; res.w = tmp2.y;
// return res; return res;
// } }
// // half -> fp8 // half -> fp8
// template <> template <>
// __inline__ __device__ uint8_t __inline__ __device__ uint8_t
// vec_conversion<uint8_t, uint16_t>(const uint16_t& a) { vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
// __half_raw tmp; __half_raw tmp;
// tmp.x = a; tmp.x = a;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation, return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret); fp8_type::__default_interpret);
// } }
// template <> template <>
// __inline__ __device__ uint16_t __inline__ __device__ uint16_t
// vec_conversion<uint16_t, uint32_t>(const uint32_t& a) { vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
// union { union {
// uint32_t ui32; uint32_t ui32;
// __half2_raw h2r; __half2_raw h2r;
// } tmp; } tmp;
// tmp.ui32 = a; tmp.ui32 = a;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation, return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret); fp8_type::__default_interpret);
// } }
// // bf16 -> fp8 // bf16 -> fp8
// template <> template <>
// __inline__ __device__ uint8_t __inline__ __device__ uint8_t
// vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) { vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
// return __hip_cvt_float_to_fp8(__bfloat162float(a), return __hip_cvt_float_to_fp8(__bfloat162float(a),
// fp8_type::__default_saturation, fp8_type::__default_saturation,
// fp8_type::__default_interpret); fp8_type::__default_interpret);
// } }
// // float -> fp8 // float -> fp8
// template <> template <>
// __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) { __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
// return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation, return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
// fp8_type::__default_interpret); fp8_type::__default_interpret);
// } }
// // float2 -> half2 // float2 -> half2
// template <> template <>
// __inline__ __device__ uint32_t __inline__ __device__ uint32_t
// vec_conversion<uint32_t, float2>(const float2& a) { vec_conversion<uint32_t, float2>(const float2& a) {
// union { union {
// half2 float16; half2 float16;
// uint32_t uint32; uint32_t uint32;
// }; };
// float16 = __float22half2_rn(a); float16 = __float22half2_rn(a);
// return uint32; return uint32;
// } }
// // Float4 -> half2x2 // Float4 -> half2x2
// template <> template <>
// __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) { __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
// uint2 b; uint2 b;
// float2 val; float2 val;
// val.x = a.x.x; val.x = a.x.x;
// val.y = a.x.y; val.y = a.x.y;
// b.x = vec_conversion<uint32_t, float2>(val); b.x = vec_conversion<uint32_t, float2>(val);
// val.x = a.y.x; val.x = a.y.x;
// val.y = a.y.y; val.y = a.y.y;
// b.y = vec_conversion<uint32_t, float2>(val); b.y = vec_conversion<uint32_t, float2>(val);
// return b; return b;
// } }
// // Float4 -> float4 // Float4 -> float4
// template <> template <>
// __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) { __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
// float4 b; float4 b;
// b.x = a.x.x; b.x = a.x.x;
// b.y = a.x.y; b.y = a.x.y;
// b.z = a.y.x; b.z = a.y.x;
// b.w = a.y.y; b.w = a.y.y;
// return b; return b;
// } }
// // Float8 -> half2x4 // Float8 -> half2x4
// template <> template <>
// __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) { __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
// uint4 b; uint4 b;
// b.x = vec_conversion<uint32_t, float2>(a.x); b.x = vec_conversion<uint32_t, float2>(a.x);
// b.y = vec_conversion<uint32_t, float2>(a.y); b.y = vec_conversion<uint32_t, float2>(a.y);
// b.z = vec_conversion<uint32_t, float2>(a.z); b.z = vec_conversion<uint32_t, float2>(a.z);
// b.w = vec_conversion<uint32_t, float2>(a.w); b.w = vec_conversion<uint32_t, float2>(a.w);
// return b; return b;
// } }
// // float2 -> bfloat162 // float2 -> bfloat162
// template <> template <>
// __inline__ __device__ __nv_bfloat162 __inline__ __device__ __nv_bfloat162
// vec_conversion<__nv_bfloat162, float2>(const float2& a) { vec_conversion<__nv_bfloat162, float2>(const float2& a) {
// __nv_bfloat162 b = __float22bfloat162_rn(a); __nv_bfloat162 b = __float22bfloat162_rn(a);
// return b; return b;
// } }
// // Float4 -> bfloat162x2 // Float4 -> bfloat162x2
// template <> template <>
// __inline__ __device__ bf16_4_t __inline__ __device__ bf16_4_t
// vec_conversion<bf16_4_t, Float4_>(const Float4_& a) { vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
// bf16_4_t b; bf16_4_t b;
// b.x = __float22bfloat162_rn(a.x); b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y); b.y = __float22bfloat162_rn(a.y);
// return b; return b;
// } }
// // Float8 -> bfloat162x4 // Float8 -> bfloat162x4
// template <> template <>
// __inline__ __device__ bf16_8_t __inline__ __device__ bf16_8_t
// vec_conversion<bf16_8_t, Float8_>(const Float8_& a) { vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
// bf16_8_t b; bf16_8_t b;
// b.x = __float22bfloat162_rn(a.x); b.x = __float22bfloat162_rn(a.x);
// b.y = __float22bfloat162_rn(a.y); b.y = __float22bfloat162_rn(a.y);
// b.z = __float22bfloat162_rn(a.z); b.z = __float22bfloat162_rn(a.z);
// b.w = __float22bfloat162_rn(a.w); b.w = __float22bfloat162_rn(a.w);
// return b; return b;
// } }
/* Scaled and vectorized conversions, for data exchange between high and low /* Scaled and vectorized conversions, for data exchange between high and low
precision domains precision domains
...@@ -345,11 +380,10 @@ using __nv_bfloat16 = __hip_bfloat16; ...@@ -345,11 +380,10 @@ using __nv_bfloat16 = __hip_bfloat16;
template <> template <>
__inline__ __device__ __nv_bfloat16 __inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
fp8_type f8;
return __float2bfloat16(fp8_to_float(a) * scale); f8.__x = a;
// fp8_type f8; return __float2bfloat16(static_cast<float>(f8) * scale);
// f8.__x = a; // return __float2bfloat16(fp8_to_float(a) * scale);
// return __float2bfloat16(static_cast<float>(f8) * scale);
} }
// fp8x2 -> __nv_bfloat162 // fp8x2 -> __nv_bfloat162
...@@ -394,24 +428,24 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) { ...@@ -394,24 +428,24 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
template <> template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>( __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) { const uint8_t& a, float scale) {
return fp8_to_float(a) * scale; fp8_type f8;
// fp8_type f8; f8.__x = a;
// f8.__x = a; return static_cast<float>(f8) * scale;
// return static_cast<float>(f8) * scale; // return fp8_to_float(a) * scale;
} }
// fp8x2 -> float2 // fp8x2 -> float2
template <> template <>
__inline__ __device__ float2 __inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
float2 f2r;
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale);
return f2r;
// [[maybe_unused]] // [[maybe_unused]]
// fp8x2_type f8x2; fp8x2_type f8x2;
// f8x2.__x = a; f8x2.__x = a;
// return static_cast<float2>(f8x2) * scale; return static_cast<float2>(f8x2) * scale;
// float2 f2r;
// f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale);
// f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale);
// return f2r;
} }
// fp8x4 -> float4 // fp8x4 -> float4
...@@ -451,34 +485,35 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) { ...@@ -451,34 +485,35 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) { scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
float res = fp8_to_float(a) * scale; __half_raw res;
return float_to_half(res); res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
// __half_raw res; return res.x;
// res.data = scaled_vec_conversion<float, uint8_t>(a, scale); // float res = fp8_to_float(a) * scale;
// return res.x; // return float_to_half(res);
} }
// fp8x2 -> half2 // fp8x2 -> half2
template <> template <>
__inline__ __device__ uint32_t __inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
union {
uint16_t u16[2];
uint32_t u32;
} res;
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale);
return res.u32;
// [[maybe_unused]] __half2_raw h2r = // [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); // __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
tmp.h2r.x.data *= scale;
tmp.h2r.y.data *= scale;
return tmp.ui32;
// union { // union {
// __half2_raw h2r; // uint16_t u16[2];
// uint32_t ui32; // uint32_t u32;
// } tmp; // } res;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); // res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale);
// tmp.h2r.x.data *= scale; // res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale);
// tmp.h2r.y.data *= scale; // return res.u32;
// return tmp.ui32;
} }
// fp8x4 -> half2x2 // fp8x4 -> half2x2
...@@ -490,7 +525,8 @@ scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) { ...@@ -490,7 +525,8 @@ scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
uint32_t u32[2]; uint32_t u32[2];
} tmp; } tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale); tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale); tmp.u32[1] =
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
return tmp.u32x2; return tmp.u32x2;
} }
...@@ -511,40 +547,40 @@ __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, ...@@ -511,40 +547,40 @@ __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
float res_f = half_to_float(a) / scale; __half_raw tmp;
return float_to_fp8(res_f); tmp.x = a;
// __half_raw tmp; tmp.data /= scale;
// tmp.x = a; return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// tmp.data /= scale; fp8_type::__default_interpret);
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation, // float res_f = half_to_float(a) / scale;
// fp8_type::__default_interpret); // return float_to_fp8(res_f);
} }
// halfx2 -> fp8x2 // halfx2 -> fp8x2
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) { scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
union { union {
uint32_t ui32; uint32_t ui32;
half2 h2r; __half2_raw h2r;
} tmp_a; } tmp;
tmp_a.ui32 = a; tmp.ui32 = a;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale); tmp.h2r.x.data /= scale;
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale); tmp.h2r.y.data /= scale;
return tmp.ui16; return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
fp8_type::__default_interpret);
// union { // union {
// uint32_t ui32; // uint8_t ui8[2];
// __half2_raw h2r; // uint16_t ui16;
// } tmp; // } tmp;
// tmp.ui32 = a; // union {
// tmp.h2r.x.data /= scale; // uint32_t ui32;
// tmp.h2r.y.data /= scale; // half2 h2r;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation, // } tmp_a;
// fp8_type::__default_interpret); // tmp_a.ui32 = a;
// tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale);
// tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale);
// return tmp.ui16;
} }
// half2x2 -> fp8x4 // half2x2 -> fp8x4
...@@ -579,11 +615,11 @@ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a, ...@@ -579,11 +615,11 @@ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>( __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) { const __nv_bfloat16& a, float scale) {
float res_f = (static_cast<float>(a)) / scale; return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
return float_to_fp8(res_f); fp8_type::__default_saturation,
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale, fp8_type::__default_interpret);
// fp8_type::__default_saturation, // float res_f = (static_cast<float>(a)) / scale;
// fp8_type::__default_interpret); // return float_to_fp8(res_f);
} }
// bf16x2 -> fp8x2 // bf16x2 -> fp8x2
...@@ -626,24 +662,24 @@ scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) { ...@@ -626,24 +662,24 @@ scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) { scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
return float_to_fp8(a / scale); return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation, fp8_type::__default_interpret);
// fp8_type::__default_interpret); // return float_to_fp8(a / scale);
} }
// floatx2 -> fp8x2 // floatx2 -> fp8x2
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) { scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
union { return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
uint8_t ui8[2]; fp8_type::__default_interpret);
uint16_t ui16; // union {
} tmp; // uint8_t ui8[2];
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale); // uint16_t ui16;
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale); // } tmp;
return tmp.ui16; // tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale);
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation, // tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale);
// fp8_type::__default_interpret); // return tmp.ui16;
} }
// floatx4 -> fp8x4 // floatx4 -> fp8x4
...@@ -658,27 +694,27 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) { ...@@ -658,27 +694,27 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale); tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
return tmp.ui32; return tmp.ui32;
} }
// #endif // ENABLE_FP8 #endif // ENABLE_FP8
// template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
// __inline__ __device__ Tout convert(const Tin& x) { __inline__ __device__ Tout convert(const Tin& x) {
// #ifdef ENABLE_FP8 #ifdef ENABLE_FP8
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
// return vec_conversion<Tout, Tin>(x); return vec_conversion<Tout, Tin>(x);
// } }
// #endif #endif
// assert(false); assert(false);
// return {}; // Squash missing return statement warning return {}; // Squash missing return statement warning
// } }
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// #ifdef ENABLE_FP8 #ifdef ENABLE_FP8
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale); return scaled_vec_conversion<Tout, Tin>(x, scale);
// } }
// #endif #endif
// assert(false); assert(false);
return {}; // Squash missing return statement warning return {}; // Squash missing return statement warning
} }
......
...@@ -132,8 +132,8 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -132,8 +132,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
"fp8": torch.uint8, "fp8": torch.uint8,
# "fp8_e4m3": torch.uint8, "fp8_e4m3": torch.uint8,
# "fp8_e5m2": torch.uint8, "fp8_e5m2": torch.uint8,
"int8": torch.int8, "int8": torch.int8,
"fp8_inc": torch.float8_e4m3fn, "fp8_inc": torch.float8_e4m3fn,
"fp8_ds_mla": torch.uint8, "fp8_ds_mla": torch.uint8,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment