Commit a82a91db authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'v0.15.1-dev_lightop_rms_rope_concat' into 'v0.15.1-dev'

feat(deepseek-mla): 接入 VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT 融合链路

See merge request dcutoolkit/deeplearing/vllm!486
parents 16f88a8a 9aabf7e7
...@@ -734,10 +734,31 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -734,10 +734,31 @@ class MLAAttention(nn.Module, AttentionLayerBase):
kv_c_normed: torch.Tensor, kv_c_normed: torch.Tensor,
k_pe: torch.Tensor, k_pe: torch.Tensor,
output_shape: torch.Size | None = None, output_shape: torch.Size | None = None,
q_ori: torch.Tensor | None = None,
key_normed: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
weight: torch.Tensor | None = None,
cos_sin_cache: torch.Tensor | None = None,
epsilon: float | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if self.calculate_kv_scales: # NOTE: fused path computes kv_c_normed inside the attention impl.
if self.calculate_kv_scales and q_ori is None:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
extra_kwargs: dict[str, object] = {}
if q_ori is not None:
extra_kwargs["q_ori"] = q_ori
if key_normed is not None:
extra_kwargs["key_normed"] = key_normed
if positions is not None:
extra_kwargs["positions"] = positions
if weight is not None:
extra_kwargs["weight"] = weight
if cos_sin_cache is not None:
extra_kwargs["cos_sin_cache"] = cos_sin_cache
if epsilon is not None:
extra_kwargs["epsilon"] = epsilon
if self.use_direct_call: if self.use_direct_call:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
...@@ -764,13 +785,26 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -764,13 +785,26 @@ class MLAAttention(nn.Module, AttentionLayerBase):
else: else:
if self.attn_backend.accept_output_buffer: if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device) output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output( if not extra_kwargs:
q, torch.ops.vllm.unified_mla_attention_with_output(
kv_c_normed, q, kv_c_normed, k_pe, output, self.layer_name
k_pe, )
output, else:
self.layer_name, torch.ops.vllm.unified_mla_attention_with_output(
) q,
kv_c_normed,
k_pe,
output,
self.layer_name,
None,
None,
q_ori,
key_normed,
positions,
weight,
cos_sin_cache,
epsilon,
)
return output return output
else: else:
return torch.ops.vllm.unified_mla_attention( return torch.ops.vllm.unified_mla_attention(
...@@ -1074,8 +1108,29 @@ def unified_mla_attention_with_output( ...@@ -1074,8 +1108,29 @@ def unified_mla_attention_with_output(
layer_name: str, layer_name: str,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
q_ori: torch.Tensor | None = None,
key_normed: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
weight: torch.Tensor | None = None,
cos_sin_cache: torch.Tensor | None = None,
epsilon: float | None = None,
) -> None: ) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name) attn_metadata, self, kv_cache = get_attention_context(layer_name)
extra_kwargs: dict[str, object] = {}
if q_ori is not None:
extra_kwargs["q_ori"] = q_ori
if key_normed is not None:
extra_kwargs["key_normed"] = key_normed
if positions is not None:
extra_kwargs["positions"] = positions
if weight is not None:
extra_kwargs["weight"] = weight
if cos_sin_cache is not None:
extra_kwargs["cos_sin_cache"] = cos_sin_cache
if epsilon is not None:
extra_kwargs["epsilon"] = epsilon
self.impl.forward( self.impl.forward(
self, self,
q, q,
...@@ -1086,6 +1141,7 @@ def unified_mla_attention_with_output( ...@@ -1086,6 +1141,7 @@ def unified_mla_attention_with_output(
output=output, output=output,
output_scale=output_scale, output_scale=output_scale,
output_block_scale=output_block_scale, output_block_scale=output_block_scale,
**extra_kwargs,
) )
...@@ -1097,6 +1153,12 @@ def unified_mla_attention_with_output_fake( ...@@ -1097,6 +1153,12 @@ def unified_mla_attention_with_output_fake(
layer_name: str, layer_name: str,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
q_ori: torch.Tensor | None = None,
key_normed: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
weight: torch.Tensor | None = None,
cos_sin_cache: torch.Tensor | None = None,
epsilon: float | None = None,
) -> None: ) -> None:
return return
......
...@@ -307,6 +307,7 @@ if TYPE_CHECKING: ...@@ -307,6 +307,7 @@ if TYPE_CHECKING:
VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE: bool = False VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE: bool = False
VLLM_USE_FUSED_DTBMM: bool = False # DOUBLE TRANS BMM FP8 VLLM_USE_FUSED_DTBMM: bool = False # DOUBLE TRANS BMM FP8
VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_FILL_MOE_ALIGN: bool = False
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False VLLM_USE_CUDA_GRAPH_SIZES: bool = False
...@@ -1920,6 +1921,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1920,6 +1921,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE": "VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE":
lambda: (os.environ.get("VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE", "False").lower() in lambda: (os.environ.get("VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE", "False").lower() in
("true", "1")), ("true", "1")),
# DeepSeek MLA: fused rmsnorm + contiguous + rope + concat_and_cache_mla
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT",
"False").lower() in ("true", "1")),
# DOUBLE TRANSPOSE BMM FP8 format use in NMZ DeepSeek models # DOUBLE TRANSPOSE BMM FP8 format use in NMZ DeepSeek models
"VLLM_USE_FUSED_DTBMM": "VLLM_USE_FUSED_DTBMM":
lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in
......
...@@ -237,6 +237,11 @@ from vllm.v1.attention.ops.common import cp_lse_ag_out_rs ...@@ -237,6 +237,11 @@ from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
try:
from lightop import fused_rms_norm_rope_contiguous # type: ignore
except Exception:
fused_rms_norm_rope_contiguous = None # type: ignore[assignment]
class QueryLenSupport(Enum): class QueryLenSupport(Enum):
"""Defines the level of query length support for an attention backend's """Defines the level of query length support for an attention backend's
...@@ -2119,6 +2124,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2119,6 +2124,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
q_ori: torch.Tensor | None = None,
key_normed: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
weight: torch.Tensor | None = None,
cos_sin_cache: torch.Tensor | None = None,
epsilon: float | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
...@@ -2170,27 +2181,105 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2170,27 +2181,105 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
use_fused_rms_rope_concat = (
envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
and (fused_rms_norm_rope_contiguous is not None)
and (q_ori is not None)
and (key_normed is not None)
and (positions is not None)
and (weight is not None)
and (cos_sin_cache is not None)
and (not fp8_attention)
and (not getattr(layer, "calculate_kv_scales", False))
)
kv_cache_dtype_str: str | None = None
if use_fused_rms_rope_concat:
# q is q_pe (rope part) in this mode; q_ori is the full q tensor.
q_ori = q_ori[:num_actual_toks, ...]
decode_q = q_ori[:num_decode_tokens]
prefill_q = q_ori[num_decode_tokens:]
if self.kv_cache_dtype == "auto":
if q.dtype == torch.float16:
kv_cache_dtype_str = "fp16"
elif q.dtype == torch.bfloat16:
kv_cache_dtype_str = "bf16"
elif self.kv_cache_dtype == "bfloat16":
kv_cache_dtype_str = "bf16"
else:
kv_cache_dtype_str = self.kv_cache_dtype
# Phase-1: only enable for fp16/bf16 caches (non-fp8).
if kv_cache_dtype_str not in ("fp16", "bf16"):
use_fused_rms_rope_concat = False
if (
envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
and not use_fused_rms_rope_concat
):
raise RuntimeError(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT was requested, but the fused "
"path is not available for this configuration."
)
if not use_fused_rms_rope_concat:
decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
if use_fused_rms_rope_concat and kv_cache.numel() == 0:
# This mode relies on the fused op to produce kv_c_normed and apply
# RoPE; without KV cache allocated we'd compute with uninitialized
# buffers.
raise RuntimeError(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires a non-empty kv_cache."
)
# write the latent and rope to kv cache # write the latent and rope to kv cache
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
ops.concat_and_cache_mla( if not use_fused_rms_rope_concat:
k_c_normed, ops.concat_and_cache_mla(
k_pe.squeeze(1), k_c_normed,
kv_cache, k_pe.squeeze(1),
attn_metadata.slot_mapping.flatten(), kv_cache,
kv_cache_dtype=self.kv_cache_dtype, attn_metadata.slot_mapping.flatten(),
scale=layer._k_scale, kv_cache_dtype=self.kv_cache_dtype,
) scale=layer._k_scale,
)
else:
assert kv_cache_dtype_str is not None
key_normed = key_normed[:num_actual_toks, ...]
positions = positions[:num_actual_toks, ...]
logger.info_once(
"Using lightop fused rmsnorm+rope+kv-cache update for MLA",
scope="local",
)
fused_rms_norm_rope_contiguous(
positions,
q,
k_pe.squeeze(1),
k_c_normed, # kv_c (not normed)
key_normed,
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache,
kv_cache_dtype_str,
1.0,
False,
epsilon if epsilon is not None else 1e-6,
)
if fp8_attention and get_gcn_arch_name() == "gfx938": if fp8_attention and get_gcn_arch_name() == "gfx938":
kv_cache = kv_cache.view(current_platform.fp8_dtype()) kv_cache = kv_cache.view(current_platform.fp8_dtype())
if has_prefill: if has_prefill:
if use_fused_rms_rope_concat:
# key_normed is filled by fused op above.
prefill_k_c_normed = key_normed[num_decode_tokens:]
self._forward_prefill( self._forward_prefill(
prefill_q, prefill_q,
prefill_k_c_normed, prefill_k_c_normed,
......
...@@ -6,9 +6,9 @@ import torch ...@@ -6,9 +6,9 @@ import torch
from vllm.attention.layer import MLAAttention from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig from vllm.config import CacheConfig
import vllm.envs as envs
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm import envs
@dataclass @dataclass
...@@ -160,15 +160,26 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -160,15 +160,26 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
q = self.q_proj(hidden_states)[0] q = self.q_proj(hidden_states)[0]
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c) kv_cache_dtype = getattr(self.mla_attn, "kv_cache_dtype", "auto")
calculate_kv_scales = getattr(self.mla_attn, "calculate_kv_scales", False)
use_fused_rms_rope_concat = (
envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
and (self.rotary_emb is not None)
and (not self.is_sparse)
and (not calculate_kv_scales)
and (kv_cache_dtype in ("auto", "bfloat16"))
and (q.dtype in (torch.float16, torch.bfloat16))
)
if not use_fused_rms_rope_concat:
kv_c_normed = self.kv_a_layernorm(kv_c)
q = q.view(-1, self.num_heads, self.qk_head_dim) q = q.view(-1, self.num_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe # Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1) k_pe = k_pe.unsqueeze(1)
if self.rotary_emb is not None: if not use_fused_rms_rope_concat and self.rotary_emb is not None:
q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim :], k_pe positions, q[..., self.qk_nope_head_dim:], k_pe
) )
if self.indexer and self.is_sparse: if self.indexer and self.is_sparse:
...@@ -178,12 +189,54 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -178,12 +189,54 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None: if llama_4_scaling is not None:
q *= llama_4_scaling q *= llama_4_scaling
if not use_fused_rms_rope_concat:
attn_out = self.mla_attn( attn_out = self.mla_attn(
q, q,
kv_c_normed, kv_c_normed,
k_pe, k_pe,
output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), output_shape=(hidden_states.shape[0],
) self.num_heads * self.v_head_dim),
)
else:
# Lightop fused path:
# - kv_c is passed as "unnormed" and written to kv_cache by the backend.
# - key_normed is an output buffer filled by the fused op and then
# used for the prefill path.
# Keep kv_c/k_pe as views into the original kv_lora buffer so they
# share the same row stride. The lightop fused op requires
# `kv_c.stride(0) == k_pe.stride(0)`, which is not true if we make
# kv_c individually contiguous.
key_normed = torch.empty_like(kv_c,
memory_format=torch.contiguous_format)
weight = getattr(self.kv_a_layernorm, "weight", None)
epsilon = getattr(self.kv_a_layernorm, "variance_epsilon", 1e-6)
if weight is None:
raise RuntimeError(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires kv_a_layernorm "
"to have a 'weight' parameter."
)
# Keep cos_sin_cache on the same device/dtype as q.
if hasattr(self.rotary_emb, "_match_cos_sin_cache_dtype"):
# type: ignore[attr-defined]
self.rotary_emb._match_cos_sin_cache_dtype(q)
cos_sin_cache = getattr(self.rotary_emb, "cos_sin_cache", None)
if cos_sin_cache is None:
raise RuntimeError(
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to "
"expose 'cos_sin_cache'."
)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
kv_c,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_heads * self.v_head_dim),
q_ori=q,
key_normed=key_normed,
positions=positions,
weight=weight,
cos_sin_cache=cos_sin_cache,
epsilon=epsilon,
)
return self.o_proj(attn_out)[0] return self.o_proj(attn_out)[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment