Commit 0bd5fcd2 authored by wangmin6
Browse files

Merge branch 'v0.15.1-dev-kvfp8-fuse' into 'v0.15.1-dev'

支持 kvcache fp8_e4m3 的 RMS_ROPE_CONCAT

See merge request dcutoolkit/deeplearing/vllm!531
parents c3d75cdf 442abc67
...@@ -2189,20 +2189,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2189,20 +2189,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:]
use_fused_rms_rope_concat = (
envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
and (fused_rms_norm_rope_contiguous is not None)
and (q_ori is not None)
and (key_normed is not None)
and (positions is not None)
and (weight is not None)
and (cos_sin_cache is not None)
and (not fp8_attention)
and (not getattr(layer, "calculate_kv_scales", False))
)
kv_cache_dtype_str: str | None = None kv_cache_dtype_str: str | None = None
if use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
else:
# q is q_pe (rope part) in this mode; q_ori is the full q tensor. # q is q_pe (rope part) in this mode; q_ori is the full q tensor.
q_ori = q_ori[:num_actual_toks, ...] q_ori = q_ori[:num_actual_toks, ...]
decode_q = q_ori[:num_decode_tokens] decode_q = q_ori[:num_decode_tokens]
...@@ -2218,34 +2211,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2218,34 +2211,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
else: else:
kv_cache_dtype_str = self.kv_cache_dtype kv_cache_dtype_str = self.kv_cache_dtype
# Phase-1: only enable for fp16/bf16 caches (non-fp8).
if kv_cache_dtype_str not in ("fp16", "bf16"):
use_fused_rms_rope_concat = False
if (
envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
and not use_fused_rms_rope_concat
):
raise RuntimeError(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT was requested, but the fused "
"path is not available for this configuration."
)
if not use_fused_rms_rope_concat:
decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
if use_fused_rms_rope_concat and kv_cache.numel() == 0:
# This mode relies on the fused op to produce kv_c_normed and apply
# RoPE; without KV cache allocated we'd compute with uninitialized
# buffers.
raise RuntimeError(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires a non-empty kv_cache."
)
# write the latent and rope to kv cache # write the latent and rope to kv cache
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
if not use_fused_rms_rope_concat: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
ops.concat_and_cache_mla( ops.concat_and_cache_mla(
k_c_normed, k_c_normed,
k_pe.squeeze(1), k_pe.squeeze(1),
...@@ -2283,7 +2251,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2283,7 +2251,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_cache = kv_cache.view(current_platform.fp8_dtype()) kv_cache = kv_cache.view(current_platform.fp8_dtype())
if has_prefill: if has_prefill:
if use_fused_rms_rope_concat: if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
# key_normed is filled by fused op above. # key_normed is filled by fused op above.
prefill_k_c_normed = key_normed[num_decode_tokens:] prefill_k_c_normed = key_normed[num_decode_tokens:]
self._forward_prefill( self._forward_prefill(
......
...@@ -162,22 +162,16 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -162,22 +162,16 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_cache_dtype = getattr(self.mla_attn, "kv_cache_dtype", "auto") kv_cache_dtype = getattr(self.mla_attn, "kv_cache_dtype", "auto")
calculate_kv_scales = getattr(self.mla_attn, "calculate_kv_scales", False) calculate_kv_scales = getattr(self.mla_attn, "calculate_kv_scales", False)
use_fused_rms_rope_concat = (
envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
and (self.rotary_emb is not None)
and (not self.is_sparse)
and (not calculate_kv_scales)
and (kv_cache_dtype in ("auto", "bfloat16"))
and (q.dtype in (torch.float16, torch.bfloat16))
)
if not use_fused_rms_rope_concat:
kv_c_normed = self.kv_a_layernorm(kv_c) kv_c_normed = self.kv_a_layernorm(kv_c)
q = q.view(-1, self.num_heads, self.qk_head_dim) q = q.view(-1, self.num_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe # Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1) k_pe = k_pe.unsqueeze(1)
if not use_fused_rms_rope_concat and self.rotary_emb is not None: # if not use_fused_rms_rope_concat and self.rotary_emb is not None:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT and self.rotary_emb is not None:
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe positions, q[..., self.qk_nope_head_dim:], k_pe
) )
...@@ -189,7 +183,8 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -189,7 +183,8 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None: if llama_4_scaling is not None:
q *= llama_4_scaling q *= llama_4_scaling
if not use_fused_rms_rope_concat: # if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
attn_out = self.mla_attn( attn_out = self.mla_attn(
q, q,
kv_c_normed, kv_c_normed,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment