Commit f38f6c1e authored by zhangqha's avatar zhangqha
Browse files

Merge branch 'v0.15.1-dev-wm' into 'v0.15.1-dev'

[perf]添加Module支持split qkv+rmsnorm+rope+kvcache融合算子,GLM4_MOE完成适配

See merge request dcutoolkit/deeplearing/vllm!465
parents 0786df31 f1a7696f
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer.""" """Attention layer."""
from typing import cast from typing import cast, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -195,6 +195,8 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -195,6 +195,8 @@ class Attention(nn.Module, AttentionLayerBase):
block_size = 64 if envs.VLLM_USE_FLASH_ATTN_PA and envs.VLLM_USE_FLASH_MLA else 16 block_size = 64 if envs.VLLM_USE_FLASH_ATTN_PA and envs.VLLM_USE_FLASH_MLA else 16
calculate_kv_scales = False calculate_kv_scales = False
self.block_size = block_size
# llm-compressor mdls need to set cache_dtype to "fp8" manually. # llm-compressor mdls need to set cache_dtype to "fp8" manually.
if getattr(quant_config, "kv_cache_scheme", None) is not None: if getattr(quant_config, "kv_cache_scheme", None) is not None:
kv_cache_dtype = "fp8" kv_cache_dtype = "fp8"
...@@ -494,6 +496,101 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -494,6 +496,101 @@ class Attention(nn.Module, AttentionLayerBase):
dtype=self.kv_cache_torch_dtype, dtype=self.kv_cache_torch_dtype,
) )
class FusedQkvSplitRmsNormRopeAttention(Attention):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
use_alibi_sqrt: bool | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
logits_soft_cap: float | None = None,
per_layer_sliding_window: int | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None,
head_size_v: int | None = None,
**extra_impl_args,
) -> None:
super().__init__(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
use_alibi_sqrt, cache_config,
quant_config, logits_soft_cap,
per_layer_sliding_window,
prefix, attn_type,
kv_sharing_target_layer_name,
attn_backend,
head_size_v,
**extra_impl_args)
def forward(
self,
qkv: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
weight_q_norm: torch.Tensor,
weight_k_norm: torch.Tensor,
epsilon: float,
# For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: torch.Size | None = None,
is_neox: bool = False,
) -> torch.Tensor:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
Attention metadata (`attn_metadata`) is set using a context manager in
the model runner's `execute_model` method. It is accessed via forward
context using
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
output_dtype = qkv.dtype
num_tokens = qkv.shape[0]
if output_shape is None:
# Handle both 2D [num_tokens, hidden] and
# 3D [num_tokens, heads, head_dim] query
output_shape = torch.Size(
(num_tokens, self.num_heads * self.head_size_v)
)
output = torch.empty(output_shape, dtype=output_dtype, device=qkv.device)
output = output.view(-1, self.num_heads, self.head_size_v)
hidden_size = output_shape[-1]
q_size = self.num_heads * self.head_size
kv_size = self.num_kv_heads * self.head_size
query, key, value = torch.ops.vllm.fused_qkv_split_rmsnorm_rope_kv_store(qkv=qkv,
positions=positions,
layer_name=self.layer_name,
kv_cache_dtype=self.kv_cache_dtype,
cos_sin_cache=cos_sin_cache,
weight_q_norm=weight_q_norm,
weight_k_norm=weight_k_norm,
epsilon=epsilon,
head_size=self.head_size,
head_size_v=self.head_size_v,
q_size=q_size,
kv_size=kv_size,
block_size=self.block_size,
is_neox=is_neox)
kv_cache_dummy_dep = None
torch.ops.vllm.unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output.view(-1, hidden_size)
class MLAAttention(nn.Module, AttentionLayerBase): class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer. """Multi-Head Latent Attention layer.
...@@ -995,3 +1092,107 @@ direct_register_custom_op( ...@@ -995,3 +1092,107 @@ direct_register_custom_op(
fake_impl=unified_mla_attention_with_output_fake, fake_impl=unified_mla_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
def fused_qkv_split_rmsnorm_rope_kv_store_impl(
qkv: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
kv_cache_dtype: str,
cos_sin_cache: torch.Tensor,
weight_q_norm: torch.Tensor,
weight_k_norm: torch.Tensor,
epsilon: float,
head_size: int,
head_size_v: int,
q_size: int,
kv_size: int,
block_size: int,
is_neox: bool = False)-> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
num_tokens = qkv.shape[0]
forward_context = get_forward_context()
slot_mapping = forward_context.slot_mapping
layer_slot_mapping = slot_mapping.get(layer_name)
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
if layer_slot_mapping is not None:
if current_platform.is_rocm():
key_cache, value_cache = kv_cache
else:
key_cache, value_cache = kv_cache.unbind(0)
if kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
kv_cache_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
kv_cache_dtype
)
key_cache = key_cache.view(kv_cache_dtype)
value_cache = value_cache.view(kv_cache_dtype)
else:
key_cache = torch.empty([0], device=qkv.device, dtype=qkv.dtype)
value_cache = torch.empty([0], device=qkv.device, dtype=qkv.dtype)
from lightop import split_qkv_rms_rotary_embedding_fuse_with_kv_store_quant
q, k, v = split_qkv_rms_rotary_embedding_fuse_with_kv_store_quant(positions,
qkv.contiguous(),
q_size,
kv_size,
cos_sin_cache,
head_dim=head_size,
page_size=block_size,
k_buffer=key_cache,
v_buffer=value_cache,
kv_cache_loc=layer_slot_mapping,
is_neox=is_neox,
weight_q=weight_q_norm,
weight_k=weight_k_norm,
output_dtype=qkv.dtype,
kv_cache_dtype=kv_cache_dtype,
epsilon=epsilon,
residual_q=None,
residual_k=None,
k_scale=None,
v_scale=None,
)
q = q.contiguous().view(num_tokens, q_size//head_size, head_size)
k = k.contiguous().view(num_tokens, kv_size//head_size_v, head_size_v)
v = v.contiguous().view(num_tokens, kv_size//head_size_v, head_size_v)
return q, k ,v
def fused_qkv_split_rmsnorm_rope_kv_store_fake(
qkv: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
kv_cache_dtype: str,
cos_sin_cache: torch.Tensor,
weight_q_norm: torch.Tensor,
weight_k_norm: torch.Tensor,
epsilon: float,
head_size: int,
head_size_v: int,
q_size: int,
kv_size: int,
block_size: int,
is_neox: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
num_token = qkv.shape[0]
q = torch.empty((num_token, q_size//head_size, head_size), device=qkv.device, dtype=qkv.dtype)
k = torch.empty((num_token, kv_size//head_size_v, head_size_v), device=qkv.device, dtype=qkv.dtype)
v = torch.empty((num_token, kv_size//head_size_v, head_size_v), device=qkv.device, dtype=qkv.dtype)
return q, k, v
direct_register_custom_op(
op_name="fused_qkv_split_rmsnorm_rope_kv_store",
op_func=fused_qkv_split_rmsnorm_rope_kv_store_impl,
mutates_args=["qkv", "positions"],
fake_impl=fused_qkv_split_rmsnorm_rope_kv_store_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
\ No newline at end of file
...@@ -956,6 +956,7 @@ class CompilationConfig: ...@@ -956,6 +956,7 @@ class CompilationConfig:
# https://github.com/vllm-project/vllm/issues/33267 # https://github.com/vllm-project/vllm/issues/33267
if not self.use_inductor_graph_partition: if not self.use_inductor_graph_partition:
self.splitting_ops.append("vllm::unified_kv_cache_update") self.splitting_ops.append("vllm::unified_kv_cache_update")
self.splitting_ops.append("vllm::fused_qkv_split_rmsnorm_rope_kv_store")
elif len(self.splitting_ops) == 0: elif len(self.splitting_ops) == 0:
if ( if (
......
...@@ -302,6 +302,7 @@ if TYPE_CHECKING: ...@@ -302,6 +302,7 @@ if TYPE_CHECKING:
VLLM_USE_MOE_W16A16_TRITON: bool = False VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1897,6 +1898,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1897,6 +1898,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
).lower() ).lower()
in ("true", "1") in ("true", "1")
), ),
#If set to 1/True, enable fuse split qkv+rmsnorm+rope+kv update just like glm4.7 moe attention.
"VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE":
lambda: (os.environ.get("VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -32,7 +32,8 @@ import torch ...@@ -32,7 +32,8 @@ import torch
from torch import nn from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig from transformers.models.glm4_moe import Glm4MoeConfig
from vllm.attention.layer import Attention from vllm import envs
from vllm.attention.layer import Attention, FusedQkvSplitRmsNormRopeAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import ( from vllm.distributed import (
...@@ -290,6 +291,8 @@ class Glm4MoeAttention(nn.Module): ...@@ -290,6 +291,8 @@ class Glm4MoeAttention(nn.Module):
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=config.rope_parameters, rope_parameters=config.rope_parameters,
) )
if not envs.VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE:
self.attn = Attention( self.attn = Attention(
self.num_heads, self.num_heads,
self.head_dim, self.head_dim,
...@@ -299,6 +302,16 @@ class Glm4MoeAttention(nn.Module): ...@@ -299,6 +302,16 @@ class Glm4MoeAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
else:
self.attn = FusedQkvSplitRmsNormRopeAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
if self.use_qk_norm: if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
...@@ -310,6 +323,8 @@ class Glm4MoeAttention(nn.Module): ...@@ -310,6 +323,8 @@ class Glm4MoeAttention(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
if not envs.VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm: if self.use_qk_norm:
q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape( q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(
...@@ -321,6 +336,23 @@ class Glm4MoeAttention(nn.Module): ...@@ -321,6 +336,23 @@ class Glm4MoeAttention(nn.Module):
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
else:
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != qkv.device
or cos_sin_cache.dtype != qkv.dtype):
cos_sin_cache = cos_sin_cache.to(qkv.device,
dtype=qkv.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
attn_output = self.attn(qkv,
positions,
cos_sin_cache,
self.q_norm.weight,
self.k_norm.weight,
self.q_norm.variance_epsilon,
is_neox=self.rotary_emb.is_neox_style)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment