Commit 0dfb30d5 authored by zhuwenwen's avatar zhuwenwen
Browse files

support fa kvcache fp8, add VLLM_USE_QUERY_QUANT to not use q quant(todo)

[opt] 优化epsp代码, 零消耗添加epsp
update VLLM_USE_FUSED_RMS_ROPE=0 (default). for qwen3, VLLM_USE_FUSED_RMS_ROPE=1 (default)
feat(moe/marlin): Marlin W16A16 MoE 自动探测并预打包(去掉手动开关)
perf(qwen3): 融合 q/k RMSNorm + RoPE
fused_moe_fp8接入lmslim
parent 7624bd05
......@@ -27,7 +27,7 @@ static inline __device__ float fp8_to_float(uint8_t input) {
}
// float -> fp8
static inline __device__ uint8_t float_to_fp8(float f) {
static inline __device__ uint8_t float_to_fp8_e4m3(float f) {
constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
uint32_t f_bits = c10::detail::fp32_to_bits(f);
......@@ -53,10 +53,35 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return result;
}
// float -> fp8 (E5M2 layout: 1 sign bit, 5 exponent bits, 2 mantissa bits).
// Returns the raw 8-bit encoding of `f`:
//   * NaN input                       -> 0x7F, with the input sign bit OR-ed in
//   * |f| >= 2^16 (including +/-inf)  -> 0x7C (exponent all ones, mantissa
//                                        zero), with the input sign bit OR-ed in
//   * |f| <  2^-14                    -> handled by the subnormal path
//   * everything else                 -> re-biased and rounded to nearest-even
static inline __device__ uint8_t float_to_fp8_e5m2(float f) {
  // Bit pattern of fp32 +infinity (biased exponent 255, mantissa 0).
  constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
  // Bit pattern of 2^16 (biased exponent 143). Magnitudes at or above this
  // value take the overflow/NaN branch below.
  constexpr uint32_t fp8_max = UINT32_C(143) << 23;
  // Magic addend for the subnormal path:
  // biased exponent 134 = (127 - 15) + (23 - 2) + 1.
  constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
  uint32_t f_bits = c10::detail::fp32_to_bits(f);
  uint8_t result = 0u;
  // Isolate the sign (bit 31) and clear it so the comparisons below operate on
  // the magnitude only.
  const uint32_t sign = f_bits & UINT32_C(0x80000000);
  f_bits ^= sign;
  if (f_bits >= fp8_max) {
    // Anything strictly above the fp32 infinity pattern is a NaN.
    result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
  } else {
    if (f_bits < (UINT32_C(113) << 23)) {
      // Magnitude below 2^-14 (biased exponent 113): add the magic constant as
      // a float so the fp32 addition itself aligns and rounds the value, then
      // subtract the constant's bit pattern to leave the 8-bit result.
      f_bits = c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits)
                                         + c10::detail::fp32_from_bits(denorm_mask));
      result = static_cast<uint8_t>(f_bits - denorm_mask);
    } else {
      // Lowest mantissa bit that survives the 21-bit truncation; used as the
      // tie-breaker for round-to-nearest-even.
      uint32_t mant_odd = (f_bits >> 21) & 1;
      // Re-bias the exponent from 127 to 15 and add "half ULP minus one"
      // (0xFFFFF) for the 21 bits that are about to be dropped.
      f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
      // Completes the rounding bias: ties round up only when the kept
      // mantissa is odd.
      f_bits += mant_odd;
      // Keep the 5 exponent bits and the top 2 mantissa bits.
      result = static_cast<uint8_t>(f_bits >> 21);
    }
  }
  // Move the sign from bit 31 down to bit 7 of the byte.
  result |= static_cast<uint8_t>(sign >> 24);
  return result;
}
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
const float scale) {
const float scale, Fp8KVCacheDataType kv_type) {
return x;
}
......@@ -65,7 +90,10 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return __float2bfloat16(fp8_to_float(a) * scale);
}
......@@ -74,32 +102,32 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, kv_type);
res.y =
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
scale);
scale, kv_type);
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, kv_type);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
......@@ -111,45 +139,48 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
const uint8_t& a, float scale) {
const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
return fp8_to_float(a) * scale;
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float2 f2r;
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale);
f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale, kv_type);
f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return f2r;
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale, Fp8KVCacheDataType kv_type) {
Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, kv_type);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return res;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale, kv_type);
return {res.x.x, res.x.y, res.y.x, res.y.y};
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, kv_type);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, kv_type);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
......@@ -161,7 +192,10 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
// fp8 -> half
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
assert(false);
}
float res = fp8_to_float(a) * scale;
return float_to_half(res);
}
......@@ -169,54 +203,58 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t u16[2];
uint32_t u32;
} res;
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale);
res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale, kv_type);
res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
return res.u32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, kv_type);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, kv_type);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, kv_type);
return tmp.u64x2;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = half_to_float(a) / scale;
return float_to_fp8(res_f);
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
}
// halfx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
......@@ -226,113 +264,121 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
half2 h2r;
} tmp_a;
tmp_a.ui32 = a;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale);
tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale, kv_type);
return tmp.ui16;
}
// half2x2 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale, kv_type);
return tmp.ui32;
}
// half2x4 -> fp8x8
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
float scale) {
float scale, Fp8KVCacheDataType kv_type) {
union {
uint2 ui2[2];
uint4 ui4;
} tmp;
tmp.ui4 = a;
uint2 res;
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale, kv_type);
return res;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16& a, float scale) {
const __nv_bfloat16& a, float scale, Fp8KVCacheDataType kv_type) {
float res_f = (static_cast<float>(a)) / scale;
return float_to_fp8(res_f);
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(res_f);
} else {
return float_to_fp8_e5m2(res_f);
}
}
// bf16x2 -> fp8x2
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
const __nv_bfloat162& a, float scale) {
const __nv_bfloat162& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale, kv_type);
return tmp.ui16;
}
// bf16x4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale, kv_type);
return tmp.ui32;
}
// bf16x8 -> fp8x8
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale, Fp8KVCacheDataType kv_type) {
uint2 res;
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale, kv_type);
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale, kv_type);
return res;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
return float_to_fp8(a / scale);
scaled_vec_conversion<uint8_t, float>(const float& a, float scale, Fp8KVCacheDataType kv_type) {
if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
return float_to_fp8_e4m3(a / scale);
} else {
return float_to_fp8_e5m2(a / scale);
}
}
// floatx2 -> fp8x2
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint8_t ui8[2];
uint16_t ui16;
} tmp;
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale);
tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale, kv_type);
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale, kv_type);
return tmp.ui16;
}
// floatx4 -> fp8x4
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale, Fp8KVCacheDataType kv_type) {
union {
uint16_t ui16[2];
uint32_t ui32;
} tmp;
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale, kv_type);
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale, kv_type);
return tmp.ui32;
}
......@@ -433,9 +479,8 @@ scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return scaled_vec_conversion<Tout, Tin>(x, scale);
}
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3 || kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return scaled_vec_conversion<Tout, Tin>(x, scale, kv_dt);
else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout)==1){
return scaled_vec_conversion_to_e5m2<Tin>(x, scale);
}
......
......@@ -5,6 +5,7 @@ from typing import Optional
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
import torch
logger = init_logger(__name__)
......@@ -68,6 +69,8 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
def flash_attn_supports_fp8() -> bool:
    """Report whether the FlashAttention path may be used with FP8.

    Returns:
        True unconditionally on ROCm platforms. On any other platform, True
        only when the selected FlashAttention version is 3 and the device's
        compute-capability major number is 9.
    """
    if current_platform.is_rocm():
        return True
    # Evaluate the version first so the device capability is only queried
    # when FlashAttention 3 is actually selected.
    is_fa3 = get_flash_attn_version() == 3
    return is_fa3 and current_platform.get_device_capability().major == 9
......
......@@ -149,6 +149,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_ATTN_FP8: bool = False
VLLM_USE_QUERY_QUANT: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_FLASH_MLA_FP8: bool = False
VLLM_USE_OPT_OP: bool = False
......@@ -199,7 +200,6 @@ if TYPE_CHECKING:
VLLM_USE_V32_ENCODE: bool = False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_MOE_ROUTER_CAPTURE: bool = False
......@@ -1074,6 +1074,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "1"))),
# flag to control if vllm should use q quant
"VLLM_USE_QUERY_QUANT":
lambda: (os.environ.get("VLLM_USE_QUERY_QUANT", "False").lower() in
("true", "1")),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
......@@ -1307,11 +1312,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will use fused RMS + RoPE kernel
"VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in
("true", "1")),
# vLLM will use Marlin W16A16 kernel for MoE experts
"VLLM_USE_MARLIN_W16A16_MOE":
lambda: (os.environ.get("VLLM_USE_MARLIN_W16A16_MOE", "False").lower() in
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in
("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT":
......
......@@ -1696,93 +1696,88 @@ def fused_experts_impl(
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
# Optional fast path: use Marlin W16A16 fused MoE implementation when
# explicitly requested. When weights are pre-packed in the post-load hook,
# w1/w2 are already in Marlin layout and we can avoid first-run packing
# peaks during KV cache profiling.
if envs.VLLM_USE_MARLIN_W16A16_MOE and not use_nn_moe:
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is not None:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
if (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2)):
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
# Optional fast path: use Marlin W16A16 fused MoE implementation when the
# expert weights are already packed in Marlin layout.
if not use_nn_moe:
K = hidden_states.size(1)
def _is_marlin_w16a16_packed(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
if w1.dim() != 3 or w2.dim() != 3:
return False
if w1.size(0) != w2.size(0):
return False
k_div16 = w1.size(1)
if k_div16 * 16 != K:
return False
if w1.size(2) % 16 != 0:
return False
twoN = w1.size(2) // 16
if twoN % 2 != 0:
return False
N = twoN // 2
if w2.size(2) != K * 16:
return False
if w2.size(1) * 16 != N:
return False
return True
is_packed = (getattr(w1, "marlin_w16a16_packed", False)
or getattr(w2, "marlin_w16a16_packed", False)
or _is_marlin_w16a16_packed(w1, w2))
if is_packed:
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
fused_experts_impl_w16a16_marlin)
except Exception:
fused_experts_impl_w16a16_marlin = None # type: ignore
if fused_experts_impl_w16a16_marlin is None:
raise RuntimeError(
"Marlin W16A16 MoE weights are packed, but the Marlin kernel is unavailable. "
"Ensure lightop is installed and VLLM_USE_LIGHTOP=1."
)
if activation != "silu":
raise RuntimeError(
"Marlin W16A16 MoE only supports activation='silu'.")
if apply_router_weight_on_input:
raise RuntimeError(
"Marlin W16A16 MoE does not support apply_router_weight_on_input=True."
)
E = w1.size(0)
if global_num_experts == -1:
global_num_experts = E
# No fallback packing: require pre-packed weights when Marlin W16A16
# MoE is enabled. If weights are still in the original layout, fail
# fast to avoid packing-induced peak memory and unpredictable
# warmup/profiling behavior.
if (w1.dim() == 3 and w2.dim() == 3 and w1.size(0) == w2.size(0)
and w2.size(1) == K):
twoN = w1.size(1)
N = w2.size(2)
if (twoN == 2 * N and (K % 32 == 0) and (N % 16 == 0)
and (twoN % 32 == 0)):
raise RuntimeError(
"VLLM_USE_MARLIN_W16A16_MOE is enabled, but MoE weights "
"are not pre-packed in Marlin layout. Pre-pack weights "
"during the post-load hook or disable "
"VLLM_USE_MARLIN_W16A16_MOE."
)
twoN = w1.size(2) // 16
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num,
twoN,
K,
device=hidden_states.device,
dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(twoN, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1_marlin=w1,
w2_marlin=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output,
)
# Non-Marlin paths need the original weight shapes.
if use_nn_moe:
......@@ -1806,7 +1801,7 @@ def fused_experts_impl(
device=hidden_states.device,
dtype=hidden_states.dtype)
if use_int8_w8a8 is True:
if use_int8_w8a8 or use_fp8_w8a8:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
w2=w2,
......@@ -1816,8 +1811,8 @@ def fused_experts_impl(
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=False,
use_int8_w8a8=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=per_channel_quant,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import os
import importlib
......@@ -76,6 +77,65 @@ else:
logger = init_logger(__name__)
_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES: tuple[int, ...] = (1, 128)
@functools.lru_cache
def _is_marlin_w16a16_moe_supported(
    E: int,
    N: int,
    K: int,
    top_k: int,
    dtype: torch.dtype,
) -> bool:
    """Return True if lightop reports Marlin W16A16 MoE is supported.

    This is a best-effort probe used to decide whether we can safely pre-pack
    weights into Marlin layout (which would otherwise prevent fallback).
    The result is memoized per argument tuple by ``functools.lru_cache``.

    Args:
        E: Number of experts to probe for.
        N: Intermediate size; the probe passes ``2 * N`` and ``N`` to lightop.
        K: Hidden size; must be a multiple of 32.
        top_k: Number of experts selected per token.
        dtype: Activation/weight dtype; only float16 and bfloat16 qualify.

    Returns:
        True only when every batch size in
        ``_MARLIN_W16A16_MOE_PROBE_BATCH_SIZES`` yields a truthy status from
        ``get_moe_cuda_marlin_config_w16a16``. Any failed precondition, a
        falsy status, or any exception raised while probing yields False.
    """
    if not (current_platform.is_cuda_alike() and torch.cuda.is_available()):
        return False
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    if K % 32 != 0 or N % 16 != 0:
        return False
    if E <= 0 or N <= 0 or K <= 0 or top_k <= 0:
        return False
    if not envs.VLLM_USE_LIGHTOP:
        return False

    try:
        from lightop import get_moe_cuda_marlin_config_w16a16

        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        # Prefer the gcnArchName attribute (dropping everything after the
        # first ':'); fall back to the generic device name otherwise.
        arch_name = getattr(props, "gcnArchName", None)
        if isinstance(arch_name, str) and arch_name:
            arch_name = arch_name.split(":")[0]
        else:
            arch_name = getattr(props, "name", None)
        if not isinstance(arch_name, str) or not arch_name:
            return False

        arch_cu = props.multi_processor_count
        twoN = 2 * N
        for bs in _MARLIN_W16A16_MOE_PROBE_BATCH_SIZES:
            _, _, status = get_moe_cuda_marlin_config_w16a16(
                E,
                bs,
                twoN,
                K,
                K,
                N,
                top_k,
                arch_name,
                arch_cu,
                dtype,
            )
            if not status:
                return False
        return True
    except Exception as exc:
        # Stay best-effort (never raise), but leave a trace so it is possible
        # to find out why the Marlin path was not selected.
        logger.debug(
            "Marlin W16A16 MoE support probe failed; treating as "
            "unsupported: %s", exc)
        return False
# Global auxiliary stream for running operations in background streams.
# We have a single global auxiliary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
......@@ -407,12 +467,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
# If Marlin W16A16 MoE is enabled, pre-pack weights once during the
# If Marlin W16A16 MoE is supported, pre-pack weights once during the
# post-load hook and replace parameters with the packed layout.
#
# This avoids first-run packing peaks during KV cache profiling and
# keeps only one copy of weights resident on GPU in steady state.
if (envs.VLLM_USE_MARLIN_W16A16_MOE and current_platform.is_cuda_alike()
if (getattr(layer, "_marlin_w16a16_moe_enabled", False)
and current_platform.is_cuda_alike()
and not getattr(layer, "use_nn_moe", False)
and not getattr(layer, "_marlin_w16a16_moe_packed", False)):
w1 = layer.w13_weight
......@@ -421,12 +482,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
and w1.dtype in (torch.float16, torch.bfloat16)
and w2.dtype in (torch.float16, torch.bfloat16)):
try:
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import ( # noqa: E501
use_lightop as _use_lightop)
if not _use_lightop:
raise RuntimeError(
"Marlin W16A16 MoE kernel is disabled")
if w1.dim() != 3 or w2.dim() != 3 or w1.size(0) != w2.size(
0):
raise RuntimeError("Unexpected MoE weight shapes")
......@@ -991,9 +1046,25 @@ class FusedMoE(torch.nn.Module):
if quant_config is None:
# Not considering quant for now, temporarily
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
moe_in_dtype = model_dtype
self._marlin_w16a16_moe_enabled = (
params_dtype == moe_in_dtype and self.activation == "silu"
and not self.apply_router_weight_on_input
and _is_marlin_w16a16_moe_supported(
E=self.local_num_experts,
N=self.intermediate_size_per_partition,
K=self.hidden_size,
top_k=self.top_k,
dtype=moe_in_dtype,
))
self.use_nn_moe = int(os.environ.get("MOE_NN", 1)) == 1
# Marlin W16A16 MoE requires the non-NN weight layout.
if self._marlin_w16a16_moe_enabled:
self.use_nn_moe = False
else:
self.use_nn_moe = False
self._marlin_w16a16_moe_enabled = False
moe_quant_params = {
"num_experts": self.local_num_experts,
......
......@@ -21,6 +21,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
# Best-effort import of the FP8 quantization helpers from lmslim: any failure
# is reduced to a printed notice instead of propagating.
# NOTE(review): after a failed import this block leaves
# per_token_group_quant_fp8 and w8a8_block_fp8_matmul unbound - confirm that
# later uses are guarded or that the names are defined elsewhere.
try:
    from lmslim.layers.gemm.fp8_utils import per_token_group_quant_fp8,w8a8_block_fp8_matmul
except Exception:
    print("INFO: Please updata lmslim if you want to use fp8_utils.\n")
logger = init_logger(__name__)
......@@ -255,332 +259,6 @@ def block_quant_to_tensor_quant(
return x_q_tensor, scale
@triton.jit
def _per_token_group_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    # Num columns of y
    y_num_columns,
    y_row_stride,
    # Lower bound applied to the group's absmax to avoid dividing by zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """Quantize one group of ``group_size`` values per program instance.

    The program with id ``g_id = tl.program_id(0)`` reads the group at row
    ``g_id // groups_per_row`` and group index ``g_id % groups_per_row`` of
    the input, computes ``scale = max(absmax, eps) / fp8_max``, stores
    ``clamp(y / scale, fp8_min, fp8_max)`` cast to the element type of
    ``y_q_ptr`` at offset ``g_id * group_size``, and stores the scale at
    ``y_s_ptr + g_id``.
    """
    groups_per_row = y_num_columns // group_size

    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    row = g_id // groups_per_row
    row_g_id = g_id % groups_per_row

    # Ensure offset calculations use int64 to prevent overflow
    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
                                                        group_size)
    y_ptr += y_ptr_offset

    # The quantized output is addressed densely: one group after another.
    y_q_ptr_offset = g_id.to(tl.int64) * group_size
    y_q_ptr += y_q_ptr_offset
    y_s_ptr += g_id

    # BLOCK lanes are launched; lanes at or beyond group_size are masked off.
    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < group_size

    # Masked-off lanes load 0.0; the math is done in float32.
    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    # Num columns of y
    y_num_columns,
    y_row_stride,
    # Stride from one column to the next of y_s
    y_s_col_stride,
    # Lower bound applied to the group's absmax to avoid dividing by zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """Quantize one group of ``group_size`` values per program instance,
    writing the scale through an explicit column stride.

    The quantization is ``scale = max(absmax, eps) / fp8_max`` followed by
    ``clamp(y / scale, fp8_min, fp8_max)`` cast to the element type of
    ``y_q_ptr``. The scale for the group at (row, group-in-row) is stored at
    ``y_s_ptr + group_in_row * y_s_col_stride + row``.
    """
    groups_per_row = y_num_columns // group_size

    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    row = g_id // groups_per_row
    row_g_id = g_id % groups_per_row

    # Ensure offset calculations use int64 to prevent overflow
    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
                                                        group_size)
    y_ptr += y_ptr_offset

    # The quantized output is addressed densely: one group after another.
    y_q_ptr_offset = g_id.to(tl.int64) * group_size
    y_q_ptr += y_q_ptr_offset

    # Convert g_id, the flattened block coordinate, to 2D so we can index
    # into the output y_scales matrix
    blocks_per_row = y_num_columns // group_size
    scale_col = g_id % blocks_per_row
    scale_row = g_id // blocks_per_row
    # Ensure offset calculation uses int64 for y_s_ptr
    y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
        tl.int64)
    y_s_ptr += y_s_ptr_offset

    # BLOCK lanes are launched; lanes at or beyond group_size are masked off.
    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
    mask = cols < group_size

    # Masked-off lanes load 0.0; the math is done in float32.
    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
    out_q: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    Args:
        x: The input tensor with ndim >= 2. Its last dimension must be
            contiguous (stride 1) and divisible by `group_size`.
        group_size: The group size used for quantization.
        eps: The minimum to avoid dividing zero.
        dtype: The dtype of output tensor. Defaults to
            `current_platform.fp8_dtype()`. Note that only
            `torch.float8_e4m3fn` is supported for now.
        column_major_scales: Outputs scales in column major.
        out_q: Optional output tensor with the same shape as `x`. If not
            provided, function will create.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
        scaling factor for quantization.
    """
    dtype = current_platform.fp8_dtype() if dtype is None else dtype
    assert (x.shape[-1] % group_size == 0), (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}")
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    assert out_q is None or out_q.shape == x.shape
    x_q = out_q
    if x_q is None:
        x_q = torch.empty_like(x, device=x.device, dtype=dtype)

    # One kernel program per quantization group.
    M = x.numel() // group_size
    N = group_size
    groups_per_row = x.shape[-1] // group_size
    if column_major_scales:
        shape = (groups_per_row, ) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device,
                          dtype=torch.float32).permute(-1, -2)
        # NOTE(review): for 2-D `x` this yields an (M_rows, groups) view with
        # strides (1, M_rows). The layout for ndim > 2 only swaps the last two
        # dims of `shape` -- confirm with callers before relying on it.
    else:
        shape = x.shape[:-1] + (groups_per_row, )
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    # The kernels treat `x` as a stack of rows of length x.shape[-1]. The
    # column count and row stride must therefore come from the last two
    # dimensions; `x.shape[1]` / `x.stride(0)` are only correct for 2-D input
    # and give wrong group indexing for ndim > 2.
    num_columns = x.shape[-1]
    row_stride = x.stride(-2)

    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M, )](
            x,
            x_q,
            x_s,
            group_size,
            num_columns,
            row_stride,
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(M, )](
            x,
            x_q,
            x_s,
            group_size,
            num_columns,
            row_stride,
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    return x_q, x_s
@triton.jit
def _w8a8_block_fp8_matmul(
    # Input, output and scale pointers
    A,
    B,
    C,
    As,
    Bs,
    # Matmul problem size
    M,
    N,
    K,
    # Block size used by the block-wise quantization
    group_n,
    group_k,
    # Strides of inputs, output and scales
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Block-wise-quantized matmul: each program computes one
    BLOCK_SIZE_M x BLOCK_SIZE_N tile of `C` from `A` and `B`.

    Every K-step's partial product is multiplied by the matching scales from
    `As` (one per row of the tile) and `Bs` (one per column of the tile)
    before being added to a float32 accumulator. The result is stored as
    bfloat16 or float16 when `C` has that element type, float32 otherwise.
    """
    pid = tl.program_id(axis=0)

    # Program ids are assigned to tiles in bands of GROUP_SIZE_M tile-rows.
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tiles_per_band = GROUP_SIZE_M * tiles_n
    band = pid // tiles_per_band
    first_pid_m = band * GROUP_SIZE_M
    band_rows = min(tiles_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % band_rows)
    pid_n = (pid % tiles_per_band) // band_rows

    # Unwrapped tile coordinates (used for the masked store) and their
    # modulo-wrapped counterparts (used for the loads).
    tile_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    tile_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_am = tile_rows % M
    offs_bn = tile_cols % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one A-scale per tile row, one B-scale per tile column
    # (columns are mapped to their quantization block along N).
    As_ptrs = As + offs_am * stride_As_m
    Bs_ptrs = Bs + (offs_bn // group_n) * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        k_left = K - k_start

        # Lanes beyond the end of K contribute zeros.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_left, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_left, other=0.0)

        # Index of the quantization block along K for this step.
        scale_idx = k_start // group_k
        a_s = tl.load(As_ptrs + scale_idx * stride_As_k)
        b_s = tl.load(Bs_ptrs + scale_idx * stride_Bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Store only the part of the tile that lies inside the M x N output.
    c_ptrs = C + stride_cm * tile_rows[:, None] + stride_cn * tile_cols[None, :]
    in_bounds = (tile_rows[:, None] < M) & (tile_cols[None, :] < N)
    tl.store(c_ptrs, c, mask=in_bounds)
@functools.lru_cache
def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
                               block_k: int) -> Optional[dict[int, Any]]:
    """Look up tuned launch configurations for the w8a8 block fp8 kernel.

    The configuration file is searched for in the `configs` directory next to
    this module; its name encodes N, K, the device name and the block shape.

    Returns:
        A dict mapping batch sizes (int keys) to kernel configurations when
        the file exists, otherwise None. The batch sizes form an irregular
        grid: callers should pick the entry closest to their batch size.
        Results are memoized per argument tuple by `functools.lru_cache`.
    """
    device_name = current_platform.get_device_name().replace(" ", "_")
    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json"  # noqa: E501
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_file_path = os.path.join(module_dir, "configs", json_file_name)

    if not os.path.exists(config_file_path):
        # No tuned file for this shape/device: the caller falls back to its
        # default configuration.
        logger.warning(
            "Using default W8A8 Block FP8 kernel config. Performance might "
            "be sub-optimal! Config file not found at %s",
            config_file_path,
        )
        return None

    with open(config_file_path) as f:
        logger.info(
            "Using configuration from %s for W8A8 Block FP8 kernel.",
            config_file_path,
        )
        # JSON object keys are strings; convert them back to batch sizes.
        return {int(key): val for key, val in json.load(f).items()}
def hipblaslt_w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
......@@ -603,90 +281,3 @@ def hipblaslt_w8a8_block_fp8_matmul(
m, n, k, 'NN', output_dtype,
enum_block_size, None)
return d
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Matrix multiplication with block-wise quantization.

    Multiplies `A` by `B` using the scales `As` and `Bs` and returns the
    result in `output_dtype`.

    Args:
        A: The input tensor, e.g., activation. Must be contiguous.
        B: The input tensor, e.g., weight. Must be 2-D with shape (N, K).
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should
            be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul, shaped `A.shape[:-1] + (N,)`.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size

    # Shape / layout checks for the activation and its scales.
    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    # Shape checks for the weight and its scales.
    assert B.ndim == 2 and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C = A.new_empty(A.shape[:-1] + (N, ), dtype=output_dtype)

    tuned = get_w8a8_block_fp8_configs(N, K, block_n, block_k)
    if tuned:
        # Use the tuned entry whose batch size is closest to M.
        nearest_m = min(tuned, key=lambda bs: abs(bs - M))
        config = tuned[nearest_m]
    else:
        # Default config
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
        # BLOCK_SIZE_K must be divisible by block_size[1]
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 2,
        }

    # 1-D launch grid: one program per output tile.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(
        N, META["BLOCK_SIZE_N"]), )

    _w8a8_block_fp8_matmul[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        # B is stored as (N, K); the kernel's "bk"/"bn" strides are therefore
        # B's dim-1 and dim-0 strides respectively. Same for Bs below.
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C
......@@ -286,6 +286,8 @@ def get_model_architecture(
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"):
os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"):
......@@ -332,6 +334,8 @@ def get_model_architecture(
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if not envs.is_set("VLLM_USE_FUSED_RMS_ROPE"):
os.environ['VLLM_USE_FUSED_RMS_ROPE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"):
......
......@@ -1043,6 +1043,8 @@ class DeepseekV2DecoderLayer(nn.Module):
prefix=f"{prefix}.mlp",
)
self.enable_ep_sp = isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1
self.is_mtp_layer = False
if self.layer_idx == config.num_hidden_layers:
self.is_mtp_layer = True
......@@ -1205,24 +1207,20 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if not self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1 and \
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
if not self.is_mtp_layer and self.enable_ep_sp and \
self.layer_idx > self.config.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if not self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
if not self.is_mtp_layer and self.enable_ep_sp:
if self.layer_idx == self.config.first_k_dense_replace:
residual = residual.tensor_split(self.tp_size)[self.tp_rank]
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, dim=0)
if self.enable_dp_attention:
if self.tp_rank == 0:
......@@ -1249,15 +1247,13 @@ class DeepseekV2DecoderLayer(nn.Module):
residual = hidden_states[self.dp_rank*new_bs: (self.dp_rank+1)*new_bs, :]
hidden_states = self.post_attention_layernorm(hidden_states)
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states.contiguous(), [0, 0, 0, pad_size], value=0).contiguous()
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :].contiguous()
if self.is_mtp_layer and self.enable_ep_sp:
ori_bs = hidden_states.shape[0]
pad_size = (ori_bs + self.tp_size - 1) // self.tp_size * self.tp_size - ori_bs
if pad_size > 0:
hidden_states = torch.nn.functional.pad(hidden_states, [0, 0, 0, pad_size], value=0)
new_bs = (ori_bs+pad_size) // self.tp_size
hidden_states = hidden_states[self.tp_rank*new_bs: (self.tp_rank+1)*new_bs, :]
hidden_states = self.mlp(hidden_states)
......@@ -1265,11 +1261,9 @@ class DeepseekV2DecoderLayer(nn.Module):
if self.enable_dp_attention:
hidden_states = dp_reduce_scatter_tensor(hidden_states)
if self.is_mtp_layer:
if isinstance(self.mlp,
DeepseekV2MoE) and self.use_deepep and self.tp_size > 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if self.is_mtp_layer and self.enable_ep_sp:
hidden_states = tensor_model_parallel_all_gather(hidden_states, dim=0)
hidden_states = hidden_states[:ori_bs, :]
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
......@@ -52,6 +52,7 @@ from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
......@@ -129,6 +130,58 @@ class Qwen3Attention(nn.Module):
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Forward every argument, unchanged and in order, to
    `lightop.rms_rotary_embedding_fuse`.

    Returns nothing. Presumably the kernel updates `query` and `key` in
    place -- TODO confirm against the lightop kernel.
    """
    # Resolved at call time rather than at module import time.
    from lightop import rms_rotary_embedding_fuse as _lightop_fused_rms_rope

    kernel_args = (
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox_style,
        q_weight,
        k_weight,
        q_bias,
        k_bias,
        epsilon,
    )
    _lightop_fused_rms_rope(*kernel_args)
def rms_rotary_embedding_fuse_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Fake implementation of `rms_rotary_embedding_fuse`.

    Takes the same arguments as the real op, touches none of them and
    returns None. It is deliberately a no-op, intended for graph tracing
    modes.
    """
    return None
# Register the fused RMSNorm + RoPE wrapper as a custom op, unless
# `torch.ops.vllm` already exposes an attribute with that name.
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
# `query` and `key` are declared as the arguments the op mutates.
mutates_args=["query", "key"],
# The fake implementation is a no-op with the same signature.
fake_impl=rms_rotary_embedding_fuse_fake,
)
def forward(
self,
positions: torch.Tensor,
......@@ -136,22 +189,49 @@ class Qwen3Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
if envs.VLLM_USE_FUSED_RMS_ROPE:
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......
......@@ -1959,6 +1959,7 @@ class W8a8GetCacheJSON:
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
self.cache_json_data = {}
device_name =arch_name+'_'+str(arch_cu)+'cu'
self.device_name=device_name
self.topk=1
......@@ -2060,19 +2061,27 @@ class W8a8GetCacheJSON:
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeint8json_name(self,E,N1,N2,K,TOPK,
block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False):
block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False,use_int8_w8a8:Optional[bool]=False):
if use_int4_w4a8:
if block_size is not None:
return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
elif use_int8_w8a8:
if block_size is not None:
return self.triton_json_dir + f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir + f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir + f"/MOE_BLOCKFP8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
return self.triton_json_dir + f"/MOE_W8A8FP8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK):
if file_path in self.cache_json_data:
# 直接返回缓存数据,避免重复读取
return self.cache_json_data[file_path]
cache_json_file=file_path
if os.path.exists(file_path):
......@@ -2089,6 +2098,7 @@ class W8a8GetCacheJSON:
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
self.cache_json_data[file_path] = configs_dict
return configs_dict
# Adapted from: https://stackoverflow.com/a/47212782/5082708
......
......@@ -136,6 +136,17 @@ class FlashAttentionBackend(AttentionBackend):
raise ValueError(f"Unknown cache layout format {cache_layout}.")
return key_stride_order, value_stride_order
@staticmethod
def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return torch.float8_e4m3fn
else:
raise ValueError(f"Unsupported FP8 dtype: {kv_cache_dtype}")
elif kv_cache_dtype in ("fp8_e5m2"):
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@dataclass
class FlashAttentionMetadata:
......@@ -589,14 +600,19 @@ class FlashAttentionImpl(AttentionImpl):
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# key_cache = key_cache.view(torch.float8_e4m3fn)
# value_cache = value_cache.view(torch.float8_e4m3fn)
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype)
key_cache = key_cache.view(dtype)
value_cache = value_cache.view(dtype)
if envs.VLLM_USE_QUERY_QUANT:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`.
use_local_attn = \
......@@ -620,9 +636,10 @@ class FlashAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
# descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
if not current_platform.is_rocm():
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
......@@ -640,9 +657,12 @@ class FlashAttentionImpl(AttentionImpl):
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
num_splits=attn_metadata.max_num_splits,
)
else:
......@@ -729,6 +749,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
)
return output
......@@ -879,12 +902,12 @@ def cascade_attention(
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
is_prefix_cache=True,
)
......@@ -932,12 +955,12 @@ def cascade_attention(
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
q_descale=q_descale.expand(descale_shape)
if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape)
if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
is_prefix_cache=True,
)
......
......@@ -465,7 +465,7 @@ class V1ZeroModelRunner(GPUModelRunner):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
if self.ep_sp:
if self.ep_sp or self.enable_dp_attention:
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
if (self.use_cuda_graph
and num_input_tokens <= self.cudagraph_batch_sizes[-1]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment