Commit 504a12b8 authored by zhuwenwen's avatar zhuwenwen
Browse files

add kvcache fp8

parent d074a953
...@@ -13,46 +13,50 @@ namespace vllm { ...@@ -13,46 +13,50 @@ namespace vllm {
#ifdef USE_ROCM #ifdef USE_ROCM
namespace fp8 { namespace fp8 {
#ifdef ENABLE_FP8 // #ifdef ENABLE_FP8
// Use hardware cvt instruction for fp8 on rocm // KV-CACHE int8
template <typename fp8_type> static inline __device__ float fp8_to_float(uint8_t input) {
__device__ __forceinline__ fp8_type cvt_c10(float const r) { const uint32_t w = (uint32_t)input << 24;
return {}; const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
uint32_t renorm_shift = __clz(nonsign);
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23));
return c10::detail::fp32_from_bits(result);
} }
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro // float -> fp8
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes static inline __device__ uint8_t float_to_fp8(float f) {
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
// the new HW cvt with something reasonable that doesn't rely on the uint32_t f_bits = c10::detail::fp32_to_bits(f);
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer. uint8_t result = 0u;
template <> const uint32_t sign = f_bits & UINT32_C(0x80000000);
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) { f_bits ^= sign;
#if HIP_FP8_TYPE_OCP if (f_bits >= fp8_max) {
return c10::Float8_e4m3fn( result = 0x7f;
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation, } else {
__hip_fp8_e4m3::__default_interpret), if (f_bits < (UINT32_C(121) << 23)) {
c10::Float8_e4m3fn::from_bits()); f_bits =
#else c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits) + c10::detail::fp32_from_bits(denorm_mask));
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt. result = static_cast<uint8_t>(f_bits - denorm_mask);
// HW cvt above is faster when it is available (ROCm 6.3 or newer). } else {
return static_cast<c10::Float8_e4m3fn>(r); uint8_t mant_odd = (f_bits >> 20) & 1;
#endif f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
} f_bits += mant_odd;
result = static_cast<uint8_t>(f_bits >> 20);
}
}
template <> result |= static_cast<uint8_t>(sign >> 24);
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) { return result;
return c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
__hip_fp8_e4m3_fnuz::__default_interpret),
c10::Float8_e4m3fnuz::from_bits());
} }
template <typename Tout, typename Tin> // template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) { // __inline__ __device__ Tout vec_conversion(const Tin& x) {
return x; // return x;
} // }
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
...@@ -60,271 +64,271 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x, ...@@ -60,271 +64,271 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
return x; return x;
} }
#if HIP_FP8_TYPE_OCP // #if HIP_FP8_TYPE_OCP
using fp8_type = __hip_fp8_e4m3; // using fp8_type = __hip_fp8_e4m3;
using fp8x2_type = __hip_fp8x2_e4m3; // using fp8x2_type = __hip_fp8x2_e4m3;
#else // #else
using fp8_type = __hip_fp8_e4m3_fnuz; // using fp8_type = __hip_fp8_e4m3_fnuz;
using fp8x2_type = __hip_fp8x2_e4m3_fnuz; // using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
#endif // #endif
// fp8 -> half // // fp8 -> half
template <> // template <>
__inline__ __device__ uint16_t // __inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) { // vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x; // return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
} // }
// fp8x2 -> half2 // // fp8x2 -> half2
template <> // template <>
__inline__ __device__ uint32_t // __inline__ __device__ uint32_t
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) { // vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
union { // union {
__half2_raw h2r; // __half2_raw h2r;
uint32_t ui32; // uint32_t ui32;
} tmp; // } tmp;
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); // tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
return tmp.ui32; // return tmp.ui32;
} // }
// fp8x4 -> half2x2 // // fp8x4 -> half2x2
template <> // template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) { // __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
union { // union {
uint2 u32x2; // uint2 u32x2;
uint32_t u32[2]; // uint32_t u32[2];
} tmp; // } tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a); // tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U)); // tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
return tmp.u32x2; // return tmp.u32x2;
} // }
// fp8x8 -> half2x4 // // fp8x8 -> half2x4
template <> // template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) { // __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
union { // union {
uint4 u64x2; // uint4 u64x2;
uint2 u64[2]; // uint2 u64[2];
} tmp; // } tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x); // tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y); // tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
return tmp.u64x2; // return tmp.u64x2;
} // }
using __nv_bfloat16 = __hip_bfloat16; // using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16 // // fp8 -> __nv_bfloat16
template <> // template <>
__inline__ __device__ __nv_bfloat16 // __inline__ __device__ __nv_bfloat16
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { // vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
fp8_type f8; // fp8_type f8;
f8.__x = a; // f8.__x = a;
return __float2bfloat16(static_cast<float>(f8)); // return __float2bfloat16(static_cast<float>(f8));
} // }
using __nv_bfloat162 = __hip_bfloat162; // using __nv_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162 // // fp8x2 -> __nv_bfloat162
template <> // template <>
__inline__ __device__ __nv_bfloat162 // __inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { // vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
__nv_bfloat162 res; // __nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); // res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); // res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
return res; // return res;
} // }
// fp8x4 -> bf16_4_t // // fp8x4 -> bf16_4_t
template <> // template <>
__inline__ __device__ bf16_4_t // __inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) { // vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
bf16_4_t res; // bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); // res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); // res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
return res; // return res;
} // }
// fp8x8 -> bf16_8_t // // fp8x8 -> bf16_8_t
template <> // template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) { // __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
bf16_4_t tmp1, tmp2; // bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x); // tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y); // tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
bf16_8_t res; // bf16_8_t res;
res.x = tmp1.x; // res.x = tmp1.x;
res.y = tmp1.y; // res.y = tmp1.y;
res.z = tmp2.x; // res.z = tmp2.x;
res.w = tmp2.y; // res.w = tmp2.y;
return res; // return res;
} // }
// fp8 -> float // // fp8 -> float
template <> // template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) { // __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
fp8_type f8; // fp8_type f8;
f8.__x = a; // f8.__x = a;
return static_cast<float>(f8); // return static_cast<float>(f8);
} // }
// fp8x2 -> float2 // // fp8x2 -> float2
template <> // template <>
__inline__ __device__ float2 // __inline__ __device__ float2
vec_conversion<float2, uint16_t>(const uint16_t& a) { // vec_conversion<float2, uint16_t>(const uint16_t& a) {
fp8x2_type f8x2; // fp8x2_type f8x2;
f8x2.__x = a; // f8x2.__x = a;
return static_cast<float2>(f8x2); // return static_cast<float2>(f8x2);
} // }
// fp8x4 -> float4 // // fp8x4 -> float4
template <> // template <>
__inline__ __device__ Float4_ // __inline__ __device__ Float4_
vec_conversion<Float4_, uint32_t>(const uint32_t& a) { // vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
Float4_ res; // Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a); // res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U)); // res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
return res; // return res;
} // }
// fp8x4 -> float4 // // fp8x4 -> float4
template <> // template <>
__inline__ __device__ float4 // __inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) { // vec_conversion<float4, uint32_t>(const uint32_t& a) {
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a); // Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); // float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res; // return res;
} // }
// fp8x8 -> float8 // // fp8x8 -> float8
template <> // template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) { // __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
Float4_ tmp1, tmp2; // Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x); // tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y); // tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
Float8_ res; // Float8_ res;
res.x = tmp1.x; // res.x = tmp1.x;
res.y = tmp1.y; // res.y = tmp1.y;
res.z = tmp2.x; // res.z = tmp2.x;
res.w = tmp2.y; // res.w = tmp2.y;
return res; // return res;
} // }
// half -> fp8 // // half -> fp8
template <> // template <>
__inline__ __device__ uint8_t // __inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) { // vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
__half_raw tmp; // __half_raw tmp;
tmp.x = a; // tmp.x = a;
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation, // return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
fp8_type::__default_interpret); // fp8_type::__default_interpret);
} // }
template <> // template <>
__inline__ __device__ uint16_t // __inline__ __device__ uint16_t
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) { // vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
union { // union {
uint32_t ui32; // uint32_t ui32;
__half2_raw h2r; // __half2_raw h2r;
} tmp; // } tmp;
tmp.ui32 = a; // tmp.ui32 = a;
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation, // return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
fp8_type::__default_interpret); // fp8_type::__default_interpret);
} // }
// bf16 -> fp8 // // bf16 -> fp8
template <> // template <>
__inline__ __device__ uint8_t // __inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) { // vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
return __hip_cvt_float_to_fp8(__bfloat162float(a), // return __hip_cvt_float_to_fp8(__bfloat162float(a),
fp8_type::__default_saturation, // fp8_type::__default_saturation,
fp8_type::__default_interpret); // fp8_type::__default_interpret);
} // }
// float -> fp8 // // float -> fp8
template <> // template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) { // __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation, // return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
fp8_type::__default_interpret); // fp8_type::__default_interpret);
} // }
// float2 -> half2 // // float2 -> half2
template <> // template <>
__inline__ __device__ uint32_t // __inline__ __device__ uint32_t
vec_conversion<uint32_t, float2>(const float2& a) { // vec_conversion<uint32_t, float2>(const float2& a) {
union { // union {
half2 float16; // half2 float16;
uint32_t uint32; // uint32_t uint32;
}; // };
float16 = __float22half2_rn(a); // float16 = __float22half2_rn(a);
return uint32; // return uint32;
} // }
// Float4 -> half2x2 // // Float4 -> half2x2
template <> // template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) { // __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
uint2 b; // uint2 b;
float2 val; // float2 val;
val.x = a.x.x; // val.x = a.x.x;
val.y = a.x.y; // val.y = a.x.y;
b.x = vec_conversion<uint32_t, float2>(val); // b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.y.x; // val.x = a.y.x;
val.y = a.y.y; // val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val); // b.y = vec_conversion<uint32_t, float2>(val);
return b; // return b;
} // }
// Float4 -> float4 // // Float4 -> float4
template <> // template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) { // __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
float4 b; // float4 b;
b.x = a.x.x; // b.x = a.x.x;
b.y = a.x.y; // b.y = a.x.y;
b.z = a.y.x; // b.z = a.y.x;
b.w = a.y.y; // b.w = a.y.y;
return b; // return b;
} // }
// Float8 -> half2x4 // // Float8 -> half2x4
template <> // template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) { // __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
uint4 b; // uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x); // b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y); // b.y = vec_conversion<uint32_t, float2>(a.y);
b.z = vec_conversion<uint32_t, float2>(a.z); // b.z = vec_conversion<uint32_t, float2>(a.z);
b.w = vec_conversion<uint32_t, float2>(a.w); // b.w = vec_conversion<uint32_t, float2>(a.w);
return b; // return b;
} // }
// float2 -> bfloat162 // // float2 -> bfloat162
template <> // template <>
__inline__ __device__ __nv_bfloat162 // __inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, float2>(const float2& a) { // vec_conversion<__nv_bfloat162, float2>(const float2& a) {
__nv_bfloat162 b = __float22bfloat162_rn(a); // __nv_bfloat162 b = __float22bfloat162_rn(a);
return b; // return b;
} // }
// Float4 -> bfloat162x2 // // Float4 -> bfloat162x2
template <> // template <>
__inline__ __device__ bf16_4_t // __inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) { // vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
bf16_4_t b; // bf16_4_t b;
b.x = __float22bfloat162_rn(a.x); // b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y); // b.y = __float22bfloat162_rn(a.y);
return b; // return b;
} // }
// Float8 -> bfloat162x4 // // Float8 -> bfloat162x4
template <> // template <>
__inline__ __device__ bf16_8_t // __inline__ __device__ bf16_8_t
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) { // vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
bf16_8_t b; // bf16_8_t b;
b.x = __float22bfloat162_rn(a.x); // b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y); // b.y = __float22bfloat162_rn(a.y);
b.z = __float22bfloat162_rn(a.z); // b.z = __float22bfloat162_rn(a.z);
b.w = __float22bfloat162_rn(a.w); // b.w = __float22bfloat162_rn(a.w);
return b; // return b;
} // }
/* Scaled and vectorized conversions, for data exchange between high and low /* Scaled and vectorized conversions, for data exchange between high and low
precision domains precision domains
...@@ -341,9 +345,11 @@ using __nv_bfloat16 = __hip_bfloat16; ...@@ -341,9 +345,11 @@ using __nv_bfloat16 = __hip_bfloat16;
template <> template <>
__inline__ __device__ __nv_bfloat16 __inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
fp8_type f8;
f8.__x = a; return __float2bfloat16(fp8_to_float(a) * scale);
return __float2bfloat16(static_cast<float>(f8) * scale); // fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8) * scale);
} }
// fp8x2 -> __nv_bfloat162 // fp8x2 -> __nv_bfloat162
...@@ -388,18 +394,24 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) { ...@@ -388,18 +394,24 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
template <> template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>( __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) { const uint8_t& a, float scale) {
fp8_type f8; return fp8_to_float(a) * scale;
f8.__x = a; // fp8_type f8;
return static_cast<float>(f8) * scale; // f8.__x = a;
// return static_cast<float>(f8) * scale;
} }
// fp8x2 -> float2 // fp8x2 -> float2
template <> template <>
__inline__ __device__ float2 __inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
fp8x2_type f8x2; float2 f2r;
f8x2.__x = a; f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale);
return static_cast<float2>(f8x2) * scale; f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale);
return f2r;
// [[maybe_unused]]
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2) * scale;
} }
// fp8x4 -> float4 // fp8x4 -> float4
...@@ -439,25 +451,34 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) { ...@@ -439,25 +451,34 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) { scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
__half_raw res; float res = fp8_to_float(a) * scale;
res.data = scaled_vec_conversion<float, uint8_t>(a, scale); return float_to_half(res);
return res.x; // __half_raw res;
// res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
// return res.x;
} }
// fp8x2 -> half2 // fp8x2 -> half2
template <> template <>
__inline__ __device__ uint32_t __inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
[[maybe_unused]] __half2_raw h2r =
__hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
union { union {
__half2_raw h2r; uint16_t u16[2];
uint32_t ui32; uint32_t u32;
} tmp; } res;
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale);
tmp.h2r.x.data *= scale; res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale);
tmp.h2r.y.data *= scale; return res.u32;
return tmp.ui32; // [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// tmp.h2r.x.data *= scale;
// tmp.h2r.y.data *= scale;
// return tmp.ui32;
} }
// fp8x4 -> half2x2 // fp8x4 -> half2x2
...@@ -469,8 +490,7 @@ scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) { ...@@ -469,8 +490,7 @@ scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
uint32_t u32[2]; uint32_t u32[2];
} tmp; } tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale); tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] = tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
return tmp.u32x2; return tmp.u32x2;
} }
...@@ -491,11 +511,13 @@ __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, ...@@ -491,11 +511,13 @@ __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
__half_raw tmp; float res_f = half_to_float(a) / scale;
tmp.x = a; return float_to_fp8(res_f);
tmp.data /= scale; // __half_raw tmp;
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation, // tmp.x = a;
fp8_type::__default_interpret); // tmp.data /= scale;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
} }
// halfx2 -> fp8x2 // halfx2 -> fp8x2
...@@ -503,14 +525,26 @@ template <> ...@@ -503,14 +525,26 @@ template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) { scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
union { union {
uint32_t ui32; uint8_t ui8[2];
__half2_raw h2r; uint16_t ui16;
} tmp; } tmp;
tmp.ui32 = a; union {
tmp.h2r.x.data /= scale; uint32_t ui32;
tmp.h2r.y.data /= scale; half2 h2r;
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation, } tmp_a;
fp8_type::__default_interpret); tmp_a.ui32 = a;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale);
return tmp.ui16;
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// tmp.h2r.x.data /= scale;
// tmp.h2r.y.data /= scale;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
} }
// half2x2 -> fp8x4 // half2x2 -> fp8x4
...@@ -545,9 +579,11 @@ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a, ...@@ -545,9 +579,11 @@ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>( __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) { const __nv_bfloat16& a, float scale) {
return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale, float res_f = (static_cast<float>(a)) / scale;
fp8_type::__default_saturation, return float_to_fp8(res_f);
fp8_type::__default_interpret); // return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
} }
// bf16x2 -> fp8x2 // bf16x2 -> fp8x2
...@@ -590,16 +626,24 @@ scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) { ...@@ -590,16 +626,24 @@ scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) { scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation, return float_to_fp8(a / scale);
fp8_type::__default_interpret); // return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
} }
// floatx2 -> fp8x2 // floatx2 -> fp8x2
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) { scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation, union {
fp8_type::__default_interpret); uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale);
return tmp.ui16;
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
} }
// floatx4 -> fp8x4 // floatx4 -> fp8x4
...@@ -614,27 +658,27 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) { ...@@ -614,27 +658,27 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale); tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
return tmp.ui32; return tmp.ui32;
} }
#endif // ENABLE_FP8 // #endif // ENABLE_FP8
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> // template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) { // __inline__ __device__ Tout convert(const Tin& x) {
#ifdef ENABLE_FP8 // #ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { // if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x); // return vec_conversion<Tout, Tin>(x);
} // }
#endif // #endif
assert(false); // assert(false);
return {}; // Squash missing return statement warning // return {}; // Squash missing return statement warning
} // }
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8 // #ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { // if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale); return scaled_vec_conversion<Tout, Tin>(x, scale);
} // }
#endif // #endif
assert(false); // assert(false);
return {}; // Squash missing return statement warning return {}; // Squash missing return statement warning
} }
...@@ -682,4 +726,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -682,4 +726,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} // namespace fp8 } // namespace fp8
#endif // USE_ROCM #endif // USE_ROCM
} // namespace vllm } // namespace vllm
\ No newline at end of file
#pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>
#include "../../../attention/attention_dtypes.h"
namespace vllm {
#ifdef USE_ROCM
namespace fp8 {
#ifdef ENABLE_FP8
// Use hardware cvt instruction for fp8 on rocm
template <typename fp8_type>
__device__ __forceinline__ fp8_type cvt_c10(float const r) {
return {};
}
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
// the new HW cvt with something reasonable that doesn't rely on the
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
template <>
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
#if HIP_FP8_TYPE_OCP
return c10::Float8_e4m3fn(
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
__hip_fp8_e4m3::__default_interpret),
c10::Float8_e4m3fn::from_bits());
#else
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
return static_cast<c10::Float8_e4m3fn>(r);
#endif
}
template <>
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
return c10::Float8_e4m3fnuz(
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
__hip_fp8_e4m3_fnuz::__default_interpret),
c10::Float8_e4m3fnuz::from_bits());
}
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) {
return x;
}
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
const float scale) {
return x;
}
#if HIP_FP8_TYPE_OCP
using fp8_type = __hip_fp8_e4m3;
using fp8x2_type = __hip_fp8x2_e4m3;
#else
using fp8_type = __hip_fp8_e4m3_fnuz;
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
#endif
// fp8 -> half
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
return tmp.ui32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
return tmp.u64x2;
}
using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
fp8_type f8;
f8.__x = a;
return __float2bfloat16(static_cast<float>(f8));
}
using __nv_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
__nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
fp8_type f8;
f8.__x = a;
return static_cast<float>(f8);
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2
vec_conversion<float2, uint16_t>(const uint16_t& a) {
fp8x2_type f8x2;
f8x2.__x = a;
return static_cast<float2>(f8x2);
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
__half_raw tmp;
tmp.x = a;
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
union {
uint32_t ui32;
__half2_raw h2r;
} tmp;
tmp.ui32 = a;
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
return __hip_cvt_float_to_fp8(__bfloat162float(a),
fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// float2 -> half2
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, float2>(const float2& a) {
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(a);
return uint32;
}
// Float4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
uint2 b;
float2 val;
val.x = a.x.x;
val.y = a.x.y;
b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.y.x;
val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val);
return b;
}
// Float4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
float4 b;
b.x = a.x.x;
b.y = a.x.y;
b.z = a.y.x;
b.w = a.y.y;
return b;
}
// Float8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y);
b.z = vec_conversion<uint32_t, float2>(a.z);
b.w = vec_conversion<uint32_t, float2>(a.w);
return b;
}
// float2 -> bfloat162
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
__nv_bfloat162 b = __float22bfloat162_rn(a);
return b;
}
// Float4 -> bfloat162x2
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
bf16_4_t b;
b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y);
return b;
}
// Float8 -> bfloat162x4
template <>
__inline__ __device__ bf16_8_t
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
bf16_8_t b;
b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y);
b.z = __float22bfloat162_rn(a.z);
b.w = __float22bfloat162_rn(a.w);
return b;
}
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains
Convention of the scale in API, e.g: FP8_data = Quantization(
High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) *
scale => HP
*/
using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
fp8_type f8;
f8.__x = a;
return __float2bfloat16(static_cast<float>(f8) * scale);
}
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
float scale) {
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
res.y =
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
scale);
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) {
fp8_type f8;
f8.__x = a;
return static_cast<float>(f8) * scale;
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
fp8x2_type f8x2;
f8x2.__x = a;
return static_cast<float2>(f8x2) * scale;
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
return res;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
return {res.x.x, res.x.y, res.y.x, res.y.y};
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
__half_raw res;
res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
return res.x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
[[maybe_unused]] __half2_raw h2r =
__hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
tmp.h2r.x.data *= scale;
tmp.h2r.y.data *= scale;
return tmp.ui32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] =
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
float scale) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
return tmp.u64x2;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
__half_raw tmp;
tmp.x = a;
tmp.data /= scale;
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// halfx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
union {
uint32_t ui32;
__half2_raw h2r;
} tmp;
tmp.ui32 = a;
tmp.h2r.x.data /= scale;
tmp.h2r.y.data /= scale;
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// half2x2 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
return tmp.ui32;
}
// half2x4 -> fp8x8
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
float scale) {
union {
uint2 ui2[2];
uint4 ui4;
} tmp;
tmp.ui4 = a;
uint2 res;
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
return res;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) {
return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// bf16x2 -> fp8x2
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
const __nv_bfloat162& a, float scale) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
return tmp.ui16;
}
// bf16x4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
return tmp.ui32;
}
// bf16x8 -> fp8x8
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
uint2 res;
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
return res;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// floatx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
fp8_type::__default_interpret);
}
// floatx4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
return tmp.ui32;
}
#endif // ENABLE_FP8
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x);
}
#endif
assert(false);
return {}; // Squash missing return statement warning
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale);
}
#endif
assert(false);
return {}; // Squash missing return statement warning
}
// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. The FN is a macro that calls a
// function with template<typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "int8") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else { \
TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
} // namespace fp8
#endif // USE_ROCM
} // namespace vllm
...@@ -172,7 +172,7 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -172,7 +172,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half, "half": torch.half,
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
# "fp8": torch.uint8, "fp8": torch.uint8,
# "fp8_e4m3": torch.uint8, # "fp8_e4m3": torch.uint8,
# "fp8_e5m2": torch.uint8, # "fp8_e5m2": torch.uint8,
"int8": torch.int8, "int8": torch.int8,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment