#pragma once #ifndef USE_ROCM #include #endif #include #include #include #include "../../../attention/attention_dtypes.h" namespace vllm { #ifdef USE_ROCM namespace fp8 { // #ifdef ENABLE_FP8 // KV-CACHE int8 static inline __device__ float fp8_to_float(uint8_t input) { const uint32_t w = (uint32_t)input << 24; const uint32_t sign = w & UINT32_C(0x80000000); const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); uint32_t renorm_shift = __clz(nonsign); renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0; uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)); return c10::detail::fp32_from_bits(result); } // float -> fp8 static inline __device__ uint8_t float_to_fp8_e4m3(float f) { constexpr uint32_t fp8_max = UINT32_C(1087) << 20; constexpr uint32_t denorm_mask = UINT32_C(141) << 23; uint32_t f_bits = c10::detail::fp32_to_bits(f); uint8_t result = 0u; const uint32_t sign = f_bits & UINT32_C(0x80000000); f_bits ^= sign; if (f_bits >= fp8_max) { result = 0x7f; } else { if (f_bits < (UINT32_C(121) << 23)) { f_bits = c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits) + c10::detail::fp32_from_bits(denorm_mask)); result = static_cast(f_bits - denorm_mask); } else { uint8_t mant_odd = (f_bits >> 20) & 1; f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; f_bits += mant_odd; result = static_cast(f_bits >> 20); } } result |= static_cast(sign >> 24); return result; } static inline __device__ uint8_t float_to_fp8_e5m2(float f) { constexpr uint32_t fp32_inf = UINT32_C(255) << 23; constexpr uint32_t fp8_max = UINT32_C(143) << 23; constexpr uint32_t denorm_mask = UINT32_C(134) << 23; uint32_t f_bits = c10::detail::fp32_to_bits(f); uint8_t result = 0u; const uint32_t sign = f_bits & UINT32_C(0x80000000); f_bits ^= sign; if (f_bits >= fp8_max) { result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C); } else { if (f_bits < (UINT32_C(113) << 23)) { f_bits = c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits) + c10::detail::fp32_from_bits(denorm_mask)); result = static_cast(f_bits - denorm_mask); } else { uint32_t mant_odd = (f_bits >> 21) & 1; f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF; f_bits += mant_odd; result = static_cast(f_bits >> 21); } } result |= static_cast(sign >> 24); return result; } // template // __inline__ __device__ Tout vec_conversion(const Tin& x) { // return x; // } template __inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale, Fp8KVCacheDataType kv_type) { return x; } // #if HIP_FP8_TYPE_OCP // using fp8_type = __hip_fp8_e4m3; // using fp8x2_type = __hip_fp8x2_e4m3; // #else // using fp8_type = __hip_fp8_e4m3_fnuz; // using fp8x2_type = __hip_fp8x2_e4m3_fnuz; // #endif // // fp8 -> half // template <> // __inline__ __device__ uint16_t // vec_conversion(const uint8_t& a) { // return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x; // } // // fp8x2 -> half2 // template <> // __inline__ __device__ uint32_t // vec_conversion(const uint16_t& a) { // union { // __half2_raw h2r; // uint32_t ui32; // } tmp; // tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); // return tmp.ui32; // } // // fp8x4 -> half2x2 // template <> // __inline__ __device__ uint2 vec_conversion(const uint32_t& a) { // union { // uint2 u32x2; // uint32_t u32[2]; // } tmp; // tmp.u32[0] = vec_conversion((uint16_t)a); // tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); // return tmp.u32x2; // } // // fp8x8 -> half2x4 // template <> // __inline__ __device__ uint4 vec_conversion(const uint2& a) { // union { // uint4 u64x2; // uint2 u64[2]; // } tmp; // tmp.u64[0] = vec_conversion(a.x); // tmp.u64[1] = vec_conversion(a.y); // return tmp.u64x2; // } // using __nv_bfloat16 = __hip_bfloat16; // // fp8 -> __nv_bfloat16 // template <> // __inline__ __device__ __nv_bfloat16 // vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { // fp8_type f8; // f8.__x = a; // return __float2bfloat16(static_cast(f8)); // } // using __nv_bfloat162 = __hip_bfloat162; // // fp8x2 -> __nv_bfloat162 // template <> // __inline__ __device__ __nv_bfloat162 // vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { // __nv_bfloat162 res; // res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); // res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); // return res; // } // // fp8x4 -> bf16_4_t // template <> // __inline__ __device__ bf16_4_t // vec_conversion(const uint32_t& a) { // bf16_4_t res; // res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); // res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); // return res; // } // // fp8x8 -> bf16_8_t // template <> // __inline__ __device__ bf16_8_t vec_conversion(const uint2& a) { // bf16_4_t tmp1, tmp2; // tmp1 = vec_conversion(a.x); // tmp2 = vec_conversion(a.y); // bf16_8_t res; // res.x = tmp1.x; // res.y = tmp1.y; // res.z = tmp2.x; // res.w = tmp2.y; // return res; // } // // fp8 -> float // template <> // __inline__ __device__ float vec_conversion(const uint8_t& a) { // fp8_type f8; // f8.__x = a; // return static_cast(f8); // } // // fp8x2 -> float2 // template <> // __inline__ __device__ float2 // vec_conversion(const uint16_t& a) { // fp8x2_type f8x2; // f8x2.__x = a; // return static_cast(f8x2); // } // // fp8x4 -> float4 // template <> // __inline__ __device__ Float4_ // vec_conversion(const uint32_t& a) { // Float4_ res; // res.x = vec_conversion((uint16_t)a); // res.y = vec_conversion((uint16_t)(a >> 16U)); // return res; // } // // fp8x4 -> float4 // template <> // __inline__ __device__ float4 // vec_conversion(const uint32_t& a) { // Float4_ tmp = vec_conversion(a); // float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); // return res; // } // // fp8x8 -> float8 // template <> // __inline__ __device__ Float8_ vec_conversion(const uint2& a) { // Float4_ tmp1, tmp2; // tmp1 = vec_conversion(a.x); // tmp2 = vec_conversion(a.y); // Float8_ res; // res.x = tmp1.x; // res.y = tmp1.y; // res.z = tmp2.x; // res.w = tmp2.y; // return res; // } // // half -> fp8 // template <> // __inline__ __device__ uint8_t // vec_conversion(const uint16_t& a) { // __half_raw tmp; // tmp.x = a; // return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation, // fp8_type::__default_interpret); // } // template <> // __inline__ __device__ uint16_t // vec_conversion(const uint32_t& a) { // union { // uint32_t ui32; // __half2_raw h2r; // } tmp; // tmp.ui32 = a; // return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation, // fp8_type::__default_interpret); // } // // bf16 -> fp8 // template <> // __inline__ __device__ uint8_t // vec_conversion(const __nv_bfloat16& a) { // return __hip_cvt_float_to_fp8(__bfloat162float(a), // fp8_type::__default_saturation, // fp8_type::__default_interpret); // } // // float -> fp8 // template <> // __inline__ __device__ uint8_t vec_conversion(const float& a) { // return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation, // fp8_type::__default_interpret); // } // // float2 -> half2 // template <> // __inline__ __device__ uint32_t // vec_conversion(const float2& a) { // union { // half2 float16; // uint32_t uint32; // }; // float16 = __float22half2_rn(a); // return uint32; // } // // Float4 -> half2x2 // template <> // __inline__ __device__ uint2 vec_conversion(const Float4_& a) { // uint2 b; // float2 val; // val.x = a.x.x; // val.y = a.x.y; // b.x = vec_conversion(val); // val.x = a.y.x; // val.y = a.y.y; // b.y = vec_conversion(val); // return b; // } // // Float4 -> float4 // template <> // __inline__ __device__ float4 vec_conversion(const Float4_& a) { // float4 b; // b.x = a.x.x; // b.y = a.x.y; // b.z = a.y.x; // b.w = a.y.y; // return b; // } // // Float8 -> half2x4 // template <> // __inline__ __device__ uint4 vec_conversion(const Float8_& a) { // uint4 b; // b.x = vec_conversion(a.x); // b.y = vec_conversion(a.y); // b.z = vec_conversion(a.z); // b.w = vec_conversion(a.w); // return b; // } // // float2 -> bfloat162 // template <> // __inline__ __device__ __nv_bfloat162 // vec_conversion<__nv_bfloat162, float2>(const float2& a) { // __nv_bfloat162 b = __float22bfloat162_rn(a); // return b; // } // // Float4 -> bfloat162x2 // template <> // __inline__ __device__ bf16_4_t // vec_conversion(const Float4_& a) { // bf16_4_t b; // b.x = __float22bfloat162_rn(a.x); // b.y = __float22bfloat162_rn(a.y); // return b; // } // // Float8 -> bfloat162x4 // template <> // __inline__ __device__ bf16_8_t // vec_conversion(const Float8_& a) { // bf16_8_t b; // b.x = __float22bfloat162_rn(a.x); // b.y = __float22bfloat162_rn(a.y); // b.z = __float22bfloat162_rn(a.z); // b.w = __float22bfloat162_rn(a.w); // return b; // } /* Scaled and vectorized conversions, for data exchange between high and low precision domains Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) * scale => HP */ using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> __inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) { if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) { assert(false); } return __float2bfloat16(fp8_to_float(a) * scale); // fp8_type f8; // f8.__x = a; // return __float2bfloat16(static_cast(f8) * scale); } // fp8x2 -> __nv_bfloat162 template <> __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) { __nv_bfloat162 res; res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, kv_type); res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale, kv_type); return res; } // fp8x4 -> bf16_4_t template <> __inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) { bf16_4_t res; res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, kv_type); res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale, kv_type); return res; } // fp8x8 -> bf16_8_t template <> __inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, float scale, Fp8KVCacheDataType kv_type) { bf16_4_t tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale, kv_type); tmp2 = scaled_vec_conversion(a.y, scale, kv_type); bf16_8_t res; res.x = tmp1.x; res.y = tmp1.y; res.z = tmp2.x; res.w = tmp2.y; return res; } // fp8 -> float template <> __inline__ __device__ float scaled_vec_conversion( const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) { if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) { assert(false); } return fp8_to_float(a) * scale; // fp8_type f8; // f8.__x = a; // return static_cast(f8) * scale; } // fp8x2 -> float2 template <> __inline__ __device__ float2 scaled_vec_conversion(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) { float2 f2r; f2r.x = scaled_vec_conversion((uint8_t)a, scale, kv_type); f2r.y = scaled_vec_conversion((uint8_t)(a >> 8U), scale, kv_type); return f2r; // [[maybe_unused]] // fp8x2_type f8x2; // f8x2.__x = a; // return static_cast(f8x2) * scale; } // fp8x4 -> float4 template <> __inline__ __device__ Float4_ scaled_vec_conversion(const uint32_t& a, const float scale, Fp8KVCacheDataType kv_type) { Float4_ res; res.x = scaled_vec_conversion((uint16_t)a, scale, kv_type); res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale, kv_type); return res; } // fp8x4 -> float4 template <> __inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) { Float4_ res = scaled_vec_conversion(a, scale, kv_type); return {res.x.x, res.x.y, res.y.x, res.y.y}; } // fp8x8 -> float8 template <> __inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, float scale, Fp8KVCacheDataType kv_type) { Float4_ tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale, kv_type); tmp2 = scaled_vec_conversion(a.y, scale, kv_type); Float8_ res; res.x = tmp1.x; res.y = tmp1.y; res.z = tmp2.x; res.w = tmp2.y; return res; } // fp8 -> half template <> __inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) { if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) { assert(false); } float res = fp8_to_float(a) * scale; return float_to_half(res); // __half_raw res; // res.data = scaled_vec_conversion(a, scale); // return res.x; } // fp8x2 -> half2 template <> __inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) { union { uint16_t u16[2]; uint32_t u32; } res; res.u16[0] = scaled_vec_conversion((uint8_t)a, scale, kv_type); res.u16[1] = scaled_vec_conversion((uint8_t)(a >> 8U), scale, kv_type); return res.u32; // [[maybe_unused]] __half2_raw h2r = // __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); // union { // __half2_raw h2r; // uint32_t ui32; // } tmp; // tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); // tmp.h2r.x.data *= scale; // tmp.h2r.y.data *= scale; // return tmp.ui32; } // fp8x4 -> half2x2 template <> __inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) { union { uint2 u32x2; uint32_t u32[2]; } tmp; tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale, kv_type); tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), scale, kv_type); return tmp.u32x2; } // fp8x8 -> half2x4 template <> __inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, float scale, Fp8KVCacheDataType kv_type) { union { uint4 u64x2; uint2 u64[2]; } tmp; tmp.u64[0] = scaled_vec_conversion(a.x, scale, kv_type); tmp.u64[1] = scaled_vec_conversion(a.y, scale, kv_type); return tmp.u64x2; } // half -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) { float res_f = half_to_float(a) / scale; if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) { return float_to_fp8_e4m3(res_f); } else { return float_to_fp8_e5m2(res_f); } // __half_raw tmp; // tmp.x = a; // tmp.data /= scale; // return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation, // fp8_type::__default_interpret); } // halfx2 -> fp8x2 template <> __inline__ __device__ uint16_t scaled_vec_conversion(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) { union { uint8_t ui8[2]; uint16_t ui16; } tmp; union { uint32_t ui32; half2 h2r; } tmp_a; tmp_a.ui32 = a; tmp.ui8[0] = scaled_vec_conversion(tmp_a.h2r.data[0], scale, kv_type); tmp.ui8[1] = scaled_vec_conversion(tmp_a.h2r.data[1], scale, kv_type); return tmp.ui16; // union { // uint32_t ui32; // __half2_raw h2r; // } tmp; // tmp.ui32 = a; // tmp.h2r.x.data /= scale; // tmp.h2r.y.data /= scale; // return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation, // fp8_type::__default_interpret); } // half2x2 -> fp8x4 template <> __inline__ __device__ uint32_t scaled_vec_conversion(const uint2& a, float scale, Fp8KVCacheDataType kv_type) { union { uint16_t ui16[2]; uint32_t ui32; } tmp; tmp.ui16[0] = scaled_vec_conversion(a.x, scale, kv_type); tmp.ui16[1] = scaled_vec_conversion(a.y, scale, kv_type); return tmp.ui32; } // half2x4 -> fp8x8 template <> __inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, float scale, Fp8KVCacheDataType kv_type) { union { uint2 ui2[2]; uint4 ui4; } tmp; tmp.ui4 = a; uint2 res; res.x = scaled_vec_conversion(tmp.ui2[0], scale, kv_type); res.y = scaled_vec_conversion(tmp.ui2[1], scale, kv_type); return res; } // bf16 -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( const __nv_bfloat16& a, float scale, Fp8KVCacheDataType kv_type) { float res_f = (static_cast(a)) / scale; if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) { return float_to_fp8_e4m3(res_f); } else { return float_to_fp8_e5m2(res_f); } // return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale, // fp8_type::__default_saturation, // fp8_type::__default_interpret); } // bf16x2 -> fp8x2 template <> __inline__ __device__ uint16_t scaled_vec_conversion( const __nv_bfloat162& a, float scale, Fp8KVCacheDataType kv_type) { union { uint8_t ui8[2]; uint16_t ui16; } tmp; tmp.ui8[0] = scaled_vec_conversion(a.x, scale, kv_type); tmp.ui8[1] = scaled_vec_conversion(a.y, scale, kv_type); return tmp.ui16; } // bf16x4 -> fp8x4 template <> __inline__ __device__ uint32_t scaled_vec_conversion(const bf16_4_t& a, float scale, Fp8KVCacheDataType kv_type) { union { uint16_t ui16[2]; uint32_t ui32; } tmp; tmp.ui16[0] = scaled_vec_conversion(a.x, scale, kv_type); tmp.ui16[1] = scaled_vec_conversion(a.y, scale, kv_type); return tmp.ui32; } // bf16x8 -> fp8x8 template <> __inline__ __device__ uint2 scaled_vec_conversion(const bf16_8_t& a, float scale, Fp8KVCacheDataType kv_type) { uint2 res; res.x = scaled_vec_conversion({a.x, a.y}, scale, kv_type); res.y = scaled_vec_conversion({a.z, a.w}, scale, kv_type); return res; } // float -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion(const float& a, float scale, Fp8KVCacheDataType kv_type) { if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) { return float_to_fp8_e4m3(a / scale); } else { return float_to_fp8_e5m2(a / scale); } // return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation, // fp8_type::__default_interpret); } // floatx2 -> fp8x2 template <> __inline__ __device__ uint16_t scaled_vec_conversion(const float2& a, float scale, Fp8KVCacheDataType kv_type) { union { uint8_t ui8[2]; uint16_t ui16; } tmp; tmp.ui8[0] = scaled_vec_conversion(a.x, scale, kv_type); tmp.ui8[1] = scaled_vec_conversion(a.y, scale, kv_type); return tmp.ui16; // return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation, // fp8_type::__default_interpret); } // floatx4 -> fp8x4 template <> __inline__ __device__ uint32_t scaled_vec_conversion(const float4& a, float scale, Fp8KVCacheDataType kv_type) { union { uint16_t ui16[2]; uint32_t ui32; } tmp; tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale, kv_type); tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale, kv_type); return tmp.ui32; } // #endif // ENABLE_FP8 // template // __inline__ __device__ Tout convert(const Tin& x) { // #ifdef ENABLE_FP8 // if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { // return vec_conversion(x); // } // #endif // assert(false); // return {}; // Squash missing return statement warning // } template __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { // #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3 || kv_dt == Fp8KVCacheDataType::kFp8E5M2) { return scaled_vec_conversion(x, scale, kv_dt); } // #endif assert(false); return {}; // Squash missing return statement warning } // The following macro is used to dispatch the conversion function based on // the data type of the key and value cache. The FN is a macro that calls a // function with template. #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ if (KV_DTYPE == "auto") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ } else { \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ } else if (KV_DTYPE == "int8") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \ } else { \ TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ } else { \ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ } else { \ TORCH_CHECK(false, \ "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ } \ else if (KV_DTYPE == "fp8_e5m2") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ } else { \ TORCH_CHECK(false, \ "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ } else { \ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ } \ } } // namespace fp8 #endif // USE_ROCM } // namespace vllm