Commit e807ec39 authored by zhuwenwen's avatar zhuwenwen
Browse files

perf(qwen3): 融合 q/k RMSNorm + RoPE

set fp8_e4m3 only supported on nmz and support q&kvcache fp8
set VLLM_PCIE_USE_CUSTOM_ALLREDUCE=1
parent cf4be8ff
...@@ -303,7 +303,7 @@ set(VLLM_EXT_SRC ...@@ -303,7 +303,7 @@ set(VLLM_EXT_SRC
"csrc/cuda_view.cu" "csrc/cuda_view.cu"
# "csrc/quantization/gptq/q_gemm.cu" # "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/w8a8/int8/scaled_quant.cu" "csrc/quantization/w8a8/int8/scaled_quant.cu"
# "csrc/quantization/w8a8/fp8/common.cu" "csrc/quantization/w8a8/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/gguf/gguf_kernel.cu"
# "csrc/quantization/activation_kernels.cu" # "csrc/quantization/activation_kernels.cu"
......
...@@ -357,9 +357,9 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, ...@@ -357,9 +357,9 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); // void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// void static_scaled_fp8_quant( void static_scaled_fp8_quant(
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale,
// std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt); std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, // void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale); // torch::Tensor& scale);
......
...@@ -27,7 +27,7 @@ static inline __device__ float fp8_to_float(uint8_t input) { ...@@ -27,7 +27,7 @@ static inline __device__ float fp8_to_float(uint8_t input) {
} }
// float -> fp8 // float -> fp8
static inline __device__ uint8_t float_to_fp8(float f) { static inline __device__ uint8_t float_to_fp8_e4m3(float f) {
constexpr uint32_t fp8_max = UINT32_C(1087) << 20; constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
constexpr uint32_t denorm_mask = UINT32_C(141) << 23; constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
uint32_t f_bits = c10::detail::fp32_to_bits(f); uint32_t f_bits = c10::detail::fp32_to_bits(f);
...@@ -53,6 +53,32 @@ static inline __device__ uint8_t float_to_fp8(float f) { ...@@ -53,6 +53,32 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return result; return result;
} }
// float -> fp8 (e5m2: 1 sign bit, 5 exponent bits with bias 15, 2 mantissa
// bits).
//
// Returns the 8-bit encoding of `f`:
//   * NaN inputs (magnitude bits above those of +inf) produce 0x7F.
//   * Any other magnitude whose bits are at or above (143 << 23), including
//     infinity, produces 0x7C.
//   * Magnitudes whose bits are below (113 << 23) are converted by adding a
//     magic float constant and reading back the low bits of the sum.
//   * Everything else has its exponent re-biased from 127 to 15 and is
//     rounded to nearest (ties to even) before the top byte is extracted.
// The sign of the input is carried over into bit 7 of the result in all cases.
static inline __device__ uint8_t float_to_fp8_e5m2(float f) {
  constexpr uint32_t kFp32InfBits = UINT32_C(255) << 23;
  constexpr uint32_t kOverflowBits = UINT32_C(143) << 23;
  constexpr uint32_t kDenormMagic = UINT32_C(134) << 23;
  constexpr uint32_t kMinNormalBits = UINT32_C(113) << 23;

  // Split the input into its sign bit and its magnitude bits.
  const uint32_t raw_bits = c10::detail::fp32_to_bits(f);
  const uint32_t sign_bit = raw_bits & UINT32_C(0x80000000);
  const uint8_t sign_byte = static_cast<uint8_t>(sign_bit >> 24);
  uint32_t magnitude = raw_bits ^ sign_bit;

  // Out of range: NaN keeps a NaN pattern, everything else becomes 0x7C.
  if (magnitude >= kOverflowBits) {
    const uint8_t special =
        magnitude > kFp32InfBits ? UINT8_C(0x7F) : UINT8_C(0x7C);
    return static_cast<uint8_t>(special | sign_byte);
  }

  // Small magnitudes: let the float adder do the shifting and rounding by
  // adding the magic constant, then strip the constant's bits back off.
  if (magnitude < kMinNormalBits) {
    const float shifted = c10::detail::fp32_from_bits(magnitude) +
                          c10::detail::fp32_from_bits(kDenormMagic);
    const uint8_t payload = static_cast<uint8_t>(
        c10::detail::fp32_to_bits(shifted) - kDenormMagic);
    return static_cast<uint8_t>(payload | sign_byte);
  }

  // Normal range: remember whether the lowest kept mantissa bit is odd, then
  // re-bias the exponent and add the rounding bias (0xFFFFF plus that parity
  // bit) before dropping the 21 discarded mantissa bits.
  const uint32_t lowest_kept_is_odd = (magnitude >> 21) & 1;
  magnitude += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
  magnitude += lowest_kept_is_odd;
  const uint8_t payload = static_cast<uint8_t>(magnitude >> 21);
  return static_cast<uint8_t>(payload | sign_byte);
}
// template <typename Tout, typename Tin> // template <typename Tout, typename Tin>
// __inline__ __device__ Tout vec_conversion(const Tin& x) { // __inline__ __device__ Tout vec_conversion(const Tin& x) {
// return x; // return x;
...@@ -60,7 +86,7 @@ static inline __device__ uint8_t float_to_fp8(float f) { ...@@ -60,7 +86,7 @@ static inline __device__ uint8_t float_to_fp8(float f) {
template <typename Tout, typename Tin> template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
const float scale) { const float scale, Fp8KVCacheDataType kv_type) {
return x; return x;
} }
...@@ -344,7 +370,10 @@ using __nv_bfloat16 = __hip_bfloat16; ...@@ -344,7 +370,10 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16 // fp8 -> __nv_bfloat16
template <> template <>
__inline__ __device__ __nv_bfloat16 __inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return __float2bfloat16(fp8_to_float(a) * scale); return __float2bfloat16(fp8_to_float(a) * scale);
// fp8_type f8; // fp8_type f8;
...@@ -356,32 +385,32 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { ...@@ -356,32 +385,32 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
template <> template <>
__inline__ __device__ __nv_bfloat162 __inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
float scale) { float scale, Fp8KVCacheDataType kv_type) {
__nv_bfloat162 res; __nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, kv_type);
res.y = res.y =
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res; return res;
} }
// fp8x4 -> bf16_4_t // fp8x4 -> bf16_4_t
template <> template <>
__inline__ __device__ bf16_4_t __inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) { scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t res; bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
scale); scale, kv_type);
return res; return res;
} }
// fp8x8 -> bf16_8_t // fp8x8 -> bf16_8_t
template <> template <>
__inline__ __device__ bf16_8_t __inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) { scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t tmp1, tmp2; bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale); tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale); tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, kv_type);
bf16_8_t res; bf16_8_t res;
res.x = tmp1.x; res.x = tmp1.x;
res.y = tmp1.y; res.y = tmp1.y;
...@@ -393,7 +422,10 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) { ...@@ -393,7 +422,10 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
// fp8 -> float // fp8 -> float
template <> template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>( __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) { const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return fp8_to_float(a) * scale; return fp8_to_float(a) * scale;
// fp8_type f8; // fp8_type f8;
// f8.__x = a; // f8.__x = a;
...@@ -403,10 +435,10 @@ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>( ...@@ -403,10 +435,10 @@ __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
// fp8x2 -> float2 // fp8x2 -> float2
template <> template <>
__inline__ __device__ float2 __inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float2 f2r; float2 f2r;
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale); f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale, kv_type);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale); f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return f2r; return f2r;
// [[maybe_unused]] // [[maybe_unused]]
// fp8x2_type f8x2; // fp8x2_type f8x2;
...@@ -417,28 +449,28 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) { ...@@ -417,28 +449,28 @@ scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ Float4_ __inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) { scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale, Fp8KVCacheDataType kv_type) {
Float4_ res; Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale); res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale); res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return res; return res;
} }
// fp8x4 -> float4 // fp8x4 -> float4
template <> template <>
__inline__ __device__ float4 __inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) { scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale); Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale, kv_type);
return {res.x.x, res.x.y, res.y.x, res.y.y}; return {res.x.x, res.x.y, res.y.x, res.y.y};
} }
// fp8x8 -> float8 // fp8x8 -> float8
template <> template <>
__inline__ __device__ Float8_ __inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) { scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ tmp1, tmp2; Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale); tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale); tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, kv_type);
Float8_ res; Float8_ res;
res.x = tmp1.x; res.x = tmp1.x;
res.y = tmp1.y; res.y = tmp1.y;
...@@ -450,7 +482,10 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) { ...@@ -450,7 +482,10 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
// fp8 -> half // fp8 -> half
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) { scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
float res = fp8_to_float(a) * scale; float res = fp8_to_float(a) * scale;
return float_to_half(res); return float_to_half(res);
// __half_raw res; // __half_raw res;
...@@ -461,13 +496,13 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) { ...@@ -461,13 +496,13 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
// fp8x2 -> half2 // fp8x2 -> half2
template <> template <>
__inline__ __device__ uint32_t __inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint16_t u16[2]; uint16_t u16[2];
uint32_t u32; uint32_t u32;
} res; } res;
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale); res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale, kv_type);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale); res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res.u32; return res.u32;
// [[maybe_unused]] __half2_raw h2r = // [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); // __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
...@@ -484,35 +519,39 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) { ...@@ -484,35 +519,39 @@ scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
// fp8x4 -> half2x2 // fp8x4 -> half2x2
template <> template <>
__inline__ __device__ uint2 __inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) { scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint2 u32x2; uint2 u32x2;
uint32_t u32[2]; uint32_t u32[2];
} tmp; } tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale); tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, kv_type);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale); tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return tmp.u32x2; return tmp.u32x2;
} }
// fp8x8 -> half2x4 // fp8x8 -> half2x4
template <> template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
float scale) { float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint4 u64x2; uint4 u64x2;
uint2 u64[2]; uint2 u64[2];
} tmp; } tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale); tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, kv_type);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale); tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, kv_type);
return tmp.u64x2; return tmp.u64x2;
} }
// half -> fp8 // half -> fp8
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) { scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = half_to_float(a) / scale; float res_f = half_to_float(a) / scale;
return float_to_fp8(res_f); if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
// __half_raw tmp; // __half_raw tmp;
// tmp.x = a; // tmp.x = a;
// tmp.data /= scale; // tmp.data /= scale;
...@@ -523,7 +562,7 @@ scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) { ...@@ -523,7 +562,7 @@ scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
// halfx2 -> fp8x2 // halfx2 -> fp8x2
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) { scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint8_t ui8[2]; uint8_t ui8[2];
uint16_t ui16; uint16_t ui16;
...@@ -533,8 +572,8 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) { ...@@ -533,8 +572,8 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
half2 h2r; half2 h2r;
} tmp_a; } tmp_a;
tmp_a.ui32 = a; tmp_a.ui32 = a;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale); tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale); tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale, kv_type);
return tmp.ui16; return tmp.ui16;
// union { // union {
// uint32_t ui32; // uint32_t ui32;
...@@ -550,37 +589,41 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) { ...@@ -550,37 +589,41 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
// half2x2 -> fp8x4 // half2x2 -> fp8x4
template <> template <>
__inline__ __device__ uint32_t __inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) { scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint16_t ui16[2]; uint16_t ui16[2];
uint32_t ui32; uint32_t ui32;
} tmp; } tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale); tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale); tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale, kv_type);
return tmp.ui32; return tmp.ui32;
} }
// half2x4 -> fp8x8 // half2x4 -> fp8x8
template <> template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a, __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
float scale) { float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint2 ui2[2]; uint2 ui2[2];
uint4 ui4; uint4 ui4;
} tmp; } tmp;
tmp.ui4 = a; tmp.ui4 = a;
uint2 res; uint2 res;
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale); res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale); res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale, kv_type);
return res; return res;
} }
// bf16 -> fp8 // bf16 -> fp8
template <> template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>( __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) { const __nv_bfloat16& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = (static_cast<float>(a)) / scale; float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8(res_f); if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale, // return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
// fp8_type::__default_saturation, // fp8_type::__default_saturation,
// fp8_type::__default_interpret); // fp8_type::__default_interpret);
...@@ -589,44 +632,48 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>( ...@@ -589,44 +632,48 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
// bf16x2 -> fp8x2 // bf16x2 -> fp8x2
template <> template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>( __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
const __nv_bfloat162& a, float scale) { const __nv_bfloat162& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint8_t ui8[2]; uint8_t ui8[2];
uint16_t ui16; uint16_t ui16;
} tmp; } tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale); tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale); tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale, kv_type);
return tmp.ui16; return tmp.ui16;
} }
// bf16x4 -> fp8x4 // bf16x4 -> fp8x4
template <> template <>
__inline__ __device__ uint32_t __inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) { scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint16_t ui16[2]; uint16_t ui16[2];
uint32_t ui32; uint32_t ui32;
} tmp; } tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale); tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale); tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale, kv_type);
return tmp.ui32; return tmp.ui32;
} }
// bf16x8 -> fp8x8 // bf16x8 -> fp8x8
template <> template <>
__inline__ __device__ uint2 __inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) { scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale, Fp8KVCacheDataType kv_type) {
uint2 res; uint2 res;
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale); res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale); res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale, kv_type);
return res; return res;
} }
// float -> fp8 // float -> fp8
template <> template <>
__inline__ __device__ uint8_t __inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) { scaled_vec_conversion<uint8_t, float>(const float& a, float scale, Fp8KVCacheDataType kv_type) {
return float_to_fp8(a / scale); if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(a / scale);
} else {
return float_to_fp8_e5m2(a / scale);
}
// return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation, // return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret); // fp8_type::__default_interpret);
} }
...@@ -634,13 +681,13 @@ scaled_vec_conversion<uint8_t, float>(const float& a, float scale) { ...@@ -634,13 +681,13 @@ scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
// floatx2 -> fp8x2 // floatx2 -> fp8x2
template <> template <>
__inline__ __device__ uint16_t __inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) { scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint8_t ui8[2]; uint8_t ui8[2];
uint16_t ui16; uint16_t ui16;
} tmp; } tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale); tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale); tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale, kv_type);
return tmp.ui16; return tmp.ui16;
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation, // return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret); // fp8_type::__default_interpret);
...@@ -649,13 +696,13 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) { ...@@ -649,13 +696,13 @@ scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
// floatx4 -> fp8x4 // floatx4 -> fp8x4
template <> template <>
__inline__ __device__ uint32_t __inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) { scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale, Fp8KVCacheDataType kv_type) {
union { union {
uint16_t ui16[2]; uint16_t ui16[2];
uint32_t ui32; uint32_t ui32;
} tmp; } tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale); tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale); tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale, kv_type);
return tmp.ui32; return tmp.ui32;
} }
// #endif // ENABLE_FP8 // #endif // ENABLE_FP8
...@@ -674,11 +721,11 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) { ...@@ -674,11 +721,11 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
// #ifdef ENABLE_FP8 // #ifdef ENABLE_FP8
// if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3 || kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion<Tout, Tin>(x, scale); return scaled_vec_conversion<Tout, Tin>(x, scale, kv_dt);
// } }
// #endif // #endif
// assert(false); assert(false);
return {}; // Squash missing return statement warning return {}; // Squash missing return statement warning
} }
...@@ -719,6 +766,18 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -719,6 +766,18 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \ TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \ "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} \
else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \ } \
......
...@@ -47,15 +47,19 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val, ...@@ -47,15 +47,19 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
x = val / scale; x = val / scale;
} }
float r = // float r =
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>)); // fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM #ifndef USE_ROCM
// Use hardware cvt instruction for fp8 on nvidia // Use hardware cvt instruction for fp8 on nvidia
// Currently only support fp8_type = c10::Float8_e4m3fn // Currently only support fp8_type = c10::Float8_e4m3fn
return fp8::vec_conversion<fp8_type, float>(r); return fp8::vec_conversion<fp8_type, float>(r);
#else #else
fp8_type *test;
uint8_t test_uint8 = fp8::float_to_fp8_e4m3(x);
test = (fp8_type*)(&test_uint8);
return *test;
// Use hardware cvt instruction for fp8 on rocm // Use hardware cvt instruction for fp8 on rocm
return fp8::cvt_c10<fp8_type>(r); // return fp8::cvt_c10<fp8_type>(r);
#endif #endif
} }
......
...@@ -619,10 +619,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -619,10 +619,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Supports per-tensor, per-channel, per-token, and arbitrary 2D group // Supports per-tensor, per-channel, per-token, and arbitrary 2D group
// scaling. Optional group_m/group_n specify the group shape explicitly; // scaling. Optional group_m/group_n specify the group shape explicitly;
// required for 1D scales to disambiguate per-channel vs per-token. // required for 1D scales to disambiguate per-channel vs per-token.
// ops.def( ops.def(
// "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, " "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, "
// "(int, int)? group_shape=None) -> ()"); "(int, int)? group_shape=None) -> ()");
// ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
// ops.def( // ops.def(
......
...@@ -1723,7 +1723,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1723,7 +1723,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_CUSTOM_CACHE": "VLLM_CUSTOM_CACHE":
lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "0"))), lambda: bool(int(os.environ.get("VLLM_CUSTOM_CACHE", "1"))),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX": "VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX":
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
"""Inference-only Qwen3 model compatible with HuggingFace weights.""" """Inference-only Qwen3 model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any from typing import Any, Optional
import torch import torch
from torch import nn from torch import nn
...@@ -51,6 +51,7 @@ from .qwen2 import Qwen2Model ...@@ -51,6 +51,7 @@ from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -136,6 +137,58 @@ class Qwen3Attention(nn.Module): ...@@ -136,6 +137,58 @@ class Qwen3Attention(nn.Module):
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Forward every argument, unchanged and in order, to lightop's kernel.

    The ``lightop`` package is imported lazily on each call, so it is only
    required when this function actually runs. Nothing is returned; any
    result is presumably written into ``query``/``key`` by the kernel
    (NOTE(review): confirm against the lightop implementation).
    """
    from lightop import rms_rotary_embedding_fuse as _lightop_rms_rope

    kernel_args = (
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox_style,
        q_weight,
        k_weight,
        q_bias,
        k_bias,
        epsilon,
    )
    _lightop_rms_rope(*kernel_args)
def rms_rotary_embedding_fuse_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Fake implementation with the same signature as the real op.

    Intentionally a no-op for graph tracing modes: it touches none of its
    arguments and returns nothing.
    """
    return None
# Register the fused RMSNorm + RoPE wrapper as a custom op, but only when
# torch.ops.vllm does not already expose an op with this name, so running
# this code a second time does not attempt a duplicate registration.
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
    direct_register_custom_op(
        fake_impl=rms_rotary_embedding_fuse_fake,
        # The op is declared as mutating these two arguments.
        mutates_args=["query", "key"],
        op_func=rms_rotary_embedding_fuse,
        op_name="rms_rotary_embedding_fuse",
    )
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -143,20 +196,47 @@ class Qwen3Attention(nn.Module): ...@@ -143,20 +196,47 @@ class Qwen3Attention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm if envs.VLLM_USE_FUSED_RMS_ROPE:
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) # Fused RMSNorm + RoPE path through custom op.
if envs.VLLM_USE_APEX_RN: cos_sin_cache = self.rotary_emb.cos_sin_cache
q_by_head = self.q_norm.forward_apex(q_by_head) if (cos_sin_cache.device != q.device
else: or cos_sin_cache.dtype != q.dtype):
q_by_head = self.q_norm.forward_cuda(q_by_head) cos_sin_cache = cos_sin_cache.to(q.device,
q = q_by_head.view(q.shape) dtype=q.dtype,
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) non_blocking=True)
if envs.VLLM_USE_APEX_RN: # Persist the converted cache so we don't re-copy/re-allocate
k_by_head = self.k_norm.forward_apex(k_by_head) # on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else: else:
k_by_head = self.k_norm.forward_cuda(k_by_head) # Add qk-norm
k = k_by_head.view(k.shape) q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q, k = self.rotary_emb(positions, q, k) if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
...@@ -189,7 +189,10 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -189,7 +189,10 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype: def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"): if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return torch.float8_e4m3fn
else:
raise ValueError(f"{kv_cache_dtype} only supported on nmz")
elif kv_cache_dtype in ("fp8_e5m2"): elif kv_cache_dtype in ("fp8_e5m2"):
return torch.float8_e5m2 return torch.float8_e5m2
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment