Commit 0e5a20b3 authored by xiabo
Browse files

支持 kvcache fp8_e4m3/fp8_e5m2

支持 kvcache fp8_e4m3/fp8_e5m2 的 RMS_ROPE_CONCAT
parent 06185134
......@@ -2189,20 +2189,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
prefill_k_pe = k_pe[num_decode_tokens:]
use_fused_rms_rope_concat = (
envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
and (fused_rms_norm_rope_contiguous is not None)
and (q_ori is not None)
and (key_normed is not None)
and (positions is not None)
and (weight is not None)
and (cos_sin_cache is not None)
and (not fp8_attention)
and (not getattr(layer, "calculate_kv_scales", False))
)
kv_cache_dtype_str: str | None = None
if use_fused_rms_rope_concat:
# if use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
else:
# q is q_pe (rope part) in this mode; q_ori is the full q tensor.
q_ori = q_ori[:num_actual_toks, ...]
decode_q = q_ori[:num_decode_tokens]
......@@ -2218,34 +2211,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
else:
kv_cache_dtype_str = self.kv_cache_dtype
# Phase-1: only enable for fp16/bf16 caches (non-fp8).
if kv_cache_dtype_str not in ("fp16", "bf16"):
use_fused_rms_rope_concat = False
if (
envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
and not use_fused_rms_rope_concat
):
raise RuntimeError(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT was requested, but the fused "
"path is not available for this configuration."
)
if not use_fused_rms_rope_concat:
decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
if use_fused_rms_rope_concat and kv_cache.numel() == 0:
# This mode relies on the fused op to produce kv_c_normed and apply
# RoPE; without KV cache allocated we'd compute with uninitialized
# buffers.
raise RuntimeError(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires a non-empty kv_cache."
)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
......@@ -2283,7 +2251,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_cache = kv_cache.view(current_platform.fp8_dtype())
if has_prefill:
if use_fused_rms_rope_concat:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
# key_normed is filled by fused op above.
prefill_k_c_normed = key_normed[num_decode_tokens:]
self._forward_prefill(
......
......@@ -162,22 +162,16 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_cache_dtype = getattr(self.mla_attn, "kv_cache_dtype", "auto")
calculate_kv_scales = getattr(self.mla_attn, "calculate_kv_scales", False)
use_fused_rms_rope_concat = (
envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
and (self.rotary_emb is not None)
and (not self.is_sparse)
and (not calculate_kv_scales)
and (kv_cache_dtype in ("auto", "bfloat16"))
and (q.dtype in (torch.float16, torch.bfloat16))
)
if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
kv_c_normed = self.kv_a_layernorm(kv_c)
q = q.view(-1, self.num_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
if not use_fused_rms_rope_concat and self.rotary_emb is not None:
# if not use_fused_rms_rope_concat and self.rotary_emb is not None:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT and self.rotary_emb is not None:
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe
)
......@@ -189,7 +183,8 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None:
q *= llama_4_scaling
if not use_fused_rms_rope_concat:
# if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
attn_out = self.mla_attn(
q,
kv_c_normed,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment