Commit f1a7696f authored by 王敏's avatar 王敏
Browse files

[perf]添加Module支持split qkv+rmsnorm+rope+kvcache融合算子,GLM4_MOE完成适配

parent 0786df31
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import cast
from typing import cast, Optional
import torch
import torch.nn as nn
......@@ -195,6 +195,8 @@ class Attention(nn.Module, AttentionLayerBase):
block_size = 64 if envs.VLLM_USE_FLASH_ATTN_PA and envs.VLLM_USE_FLASH_MLA else 16
calculate_kv_scales = False
self.block_size = block_size
# llm-compressor models need to set cache_dtype to "fp8" manually.
if getattr(quant_config, "kv_cache_scheme", None) is not None:
kv_cache_dtype = "fp8"
......@@ -494,6 +496,101 @@ class Attention(nn.Module, AttentionLayerBase):
dtype=self.kv_cache_torch_dtype,
)
class FusedQkvSplitRmsNormRopeAttention(Attention):
    """Attention layer whose ``forward`` takes the packed QKV projection.

    ``forward`` passes the packed tensor to the
    ``vllm::fused_qkv_split_rmsnorm_rope_kv_store`` custom op and then runs
    ``vllm::unified_attention_with_output`` on the query/key/value tensors
    that op returns.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """Forward every argument, unchanged and in order, to ``Attention``."""
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            use_alibi_sqrt,
            cache_config,
            quant_config,
            logits_soft_cap,
            per_layer_sliding_window,
            prefix,
            attn_type,
            kv_sharing_target_layer_name,
            attn_backend,
            head_size_v,
            **extra_impl_args,
        )

    def forward(
        self,
        qkv: torch.Tensor,
        positions: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        weight_q_norm: torch.Tensor,
        weight_k_norm: torch.Tensor,
        epsilon: float,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
        is_neox: bool = False,
    ) -> torch.Tensor:
        """Run the fused QKV pre-processing op followed by attention.

        Args:
            qkv: Packed QKV projection; its first dimension is used as the
                token count.
            positions: Passed through to the fused op.
            cos_sin_cache: Passed through to the fused op.
            weight_q_norm: Passed through to the fused op as the query norm
                weight.
            weight_k_norm: Passed through to the fused op as the key norm
                weight.
            epsilon: Passed through to the fused op.
            output_shape: Shape of the output buffer. Defaults to
                ``(num_tokens, num_heads * head_size_v)``.
            is_neox: Passed through to the fused op.

        Returns:
            The attention output buffer viewed as ``(-1, output_shape[-1])``.
        """
        if output_shape is None:
            output_shape = torch.Size(
                (qkv.shape[0], self.num_heads * self.head_size_v)
            )
        hidden_size = output_shape[-1]

        # The attention op writes into this buffer, so it is allocated up
        # front and exposed per head.
        attn_out = torch.empty(
            output_shape, dtype=qkv.dtype, device=qkv.device
        ).view(-1, self.num_heads, self.head_size_v)

        query, key, value = torch.ops.vllm.fused_qkv_split_rmsnorm_rope_kv_store(
            qkv=qkv,
            positions=positions,
            layer_name=self.layer_name,
            kv_cache_dtype=self.kv_cache_dtype,
            cos_sin_cache=cos_sin_cache,
            weight_q_norm=weight_q_norm,
            weight_k_norm=weight_k_norm,
            epsilon=epsilon,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            q_size=self.num_heads * self.head_size,
            kv_size=self.num_kv_heads * self.head_size,
            block_size=self.block_size,
            is_neox=is_neox,
        )

        torch.ops.vllm.unified_attention_with_output(
            query,
            key,
            value,
            attn_out,
            self.layer_name,
            kv_cache_dummy_dep=None,
        )
        return attn_out.view(-1, hidden_size)
class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer.
......@@ -995,3 +1092,107 @@ direct_register_custom_op(
fake_impl=unified_mla_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)
def fused_qkv_split_rmsnorm_rope_kv_store_impl(
    qkv: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
    kv_cache_dtype: str,
    cos_sin_cache: torch.Tensor,
    weight_q_norm: torch.Tensor,
    weight_k_norm: torch.Tensor,
    epsilon: float,
    head_size: int,
    head_size_v: int,
    q_size: int,
    kv_size: int,
    block_size: int,
    is_neox: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Call the lightop fused split-QKV / RMSNorm / RoPE / KV-store kernel.

    The layer's slot mapping and KV cache are looked up in the current
    forward context by ``layer_name``. When the layer has no slot mapping,
    empty placeholder tensors are passed as the key/value buffers.

    Args:
        qkv: Packed QKV tensor; its first dimension is the token count.
        positions: Passed to the kernel as its first positional argument.
        layer_name: Key into the forward context's ``slot_mapping`` and
            ``no_compile_layers``.
        kv_cache_dtype: KV cache dtype name. Values starting with ``"fp8"``
            cause the cache tensors to be re-viewed with the dtype returned
            by ``FlashAttentionBackend.get_fp8_dtype_for_flashattn``.
        cos_sin_cache: Passed through to the kernel.
        weight_q_norm: Passed to the kernel as ``weight_q``.
        weight_k_norm: Passed to the kernel as ``weight_k``.
        epsilon: Passed through to the kernel.
        head_size: Passed to the kernel as ``head_dim`` and used to reshape
            the returned query.
        head_size_v: Used to reshape the returned key and value.
        q_size: Passed to the kernel and used to reshape the returned query.
        kv_size: Passed to the kernel and used to reshape the returned key
            and value.
        block_size: Passed to the kernel as ``page_size``.
        is_neox: Passed through to the kernel.

    Returns:
        ``(q, k, v)`` with ``q`` viewed as
        ``(num_tokens, q_size // head_size, head_size)`` and ``k``/``v``
        viewed as ``(num_tokens, kv_size // head_size_v, head_size_v)``.
    """
    num_tokens = qkv.shape[0]
    forward_context = get_forward_context()
    slot_mapping = forward_context.slot_mapping
    # Validate the type *before* the first dict access so a wrong type is
    # reported with this message instead of failing inside `.get()`.
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
    )
    layer_slot_mapping = slot_mapping.get(layer_name)

    attn_layer = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]

    # Value handed to the kernel as `kv_cache_dtype`. It is kept in its own
    # local so the `str` parameter is not rebound to a different type.
    kernel_kv_cache_dtype = kv_cache_dtype
    if layer_slot_mapping is not None:
        if current_platform.is_rocm():
            key_cache, value_cache = kv_cache
        else:
            key_cache, value_cache = kv_cache.unbind(0)
        if kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

            kernel_kv_cache_dtype = (
                FlashAttentionBackend.get_fp8_dtype_for_flashattn(kv_cache_dtype)
            )
            key_cache = key_cache.view(kernel_kv_cache_dtype)
            value_cache = value_cache.view(kernel_kv_cache_dtype)
    else:
        # No slot mapping for this layer: pass empty placeholder buffers.
        key_cache = torch.empty([0], device=qkv.device, dtype=qkv.dtype)
        value_cache = torch.empty([0], device=qkv.device, dtype=qkv.dtype)

    from lightop import split_qkv_rms_rotary_embedding_fuse_with_kv_store_quant

    q, k, v = split_qkv_rms_rotary_embedding_fuse_with_kv_store_quant(
        positions,
        qkv.contiguous(),
        q_size,
        kv_size,
        cos_sin_cache,
        head_dim=head_size,
        page_size=block_size,
        k_buffer=key_cache,
        v_buffer=value_cache,
        kv_cache_loc=layer_slot_mapping,
        is_neox=is_neox,
        weight_q=weight_q_norm,
        weight_k=weight_k_norm,
        output_dtype=qkv.dtype,
        kv_cache_dtype=kernel_kv_cache_dtype,
        epsilon=epsilon,
        residual_q=None,
        residual_k=None,
        k_scale=None,
        v_scale=None,
    )

    q = q.contiguous().view(num_tokens, q_size // head_size, head_size)
    # NOTE(review): the key is reshaped with head_size_v although the kernel
    # is given head_dim=head_size; this only matters when the two differ.
    # Confirm the intended key layout before relying on head_size != head_size_v.
    k = k.contiguous().view(num_tokens, kv_size // head_size_v, head_size_v)
    v = v.contiguous().view(num_tokens, kv_size // head_size_v, head_size_v)
    return q, k, v
def fused_qkv_split_rmsnorm_rope_kv_store_fake(
    qkv: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
    kv_cache_dtype: str,
    cos_sin_cache: torch.Tensor,
    weight_q_norm: torch.Tensor,
    weight_k_norm: torch.Tensor,
    epsilon: float,
    head_size: int,
    head_size_v: int,
    q_size: int,
    kv_size: int,
    block_size: int,
    is_neox: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only counterpart of the fused op: allocate, compute nothing.

    Only ``qkv`` (token count, device, dtype) and the four size arguments
    are consulted; every other parameter is accepted solely to mirror the
    real op's signature.

    Returns:
        Uninitialised ``(q, k, v)`` tensors on ``qkv``'s device and dtype,
        with ``q`` shaped ``(num_tokens, q_size // head_size, head_size)``
        and ``k``/``v`` shaped
        ``(num_tokens, kv_size // head_size_v, head_size_v)``.
    """
    rows = qkv.shape[0]
    placement = {"device": qkv.device, "dtype": qkv.dtype}

    q_shape = (rows, q_size // head_size, head_size)
    # Key and value share one shape, both derived from head_size_v.
    kv_shape = (rows, kv_size // head_size_v, head_size_v)

    q = torch.empty(q_shape, **placement)
    k = torch.empty(kv_shape, **placement)
    v = torch.empty(kv_shape, **placement)
    return q, k, v
# Register the fused op so it is callable as
# torch.ops.vllm.fused_qkv_split_rmsnorm_rope_kv_store, pairing the real
# implementation with its shape-only fake.
direct_register_custom_op(
    op_name="fused_qkv_split_rmsnorm_rope_kv_store",
    op_func=fused_qkv_split_rmsnorm_rope_kv_store_impl,
    fake_impl=fused_qkv_split_rmsnorm_rope_kv_store_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
    mutates_args=["qkv", "positions"],
)
\ No newline at end of file
......@@ -956,6 +956,7 @@ class CompilationConfig:
# https://github.com/vllm-project/vllm/issues/33267
if not self.use_inductor_graph_partition:
self.splitting_ops.append("vllm::unified_kv_cache_update")
self.splitting_ops.append("vllm::fused_qkv_split_rmsnorm_rope_kv_store")
elif len(self.splitting_ops) == 0:
if (
......
......@@ -302,6 +302,7 @@ if TYPE_CHECKING:
VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE: bool = False
def get_default_cache_root():
......@@ -1897,6 +1898,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
).lower()
in ("true", "1")
),
# If set to 1/True, enable the fused split-QKV + RMSNorm + RoPE + KV-cache
# update op, as used by the GLM-4.7 MoE attention.
"VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE":
lambda: (os.environ.get("VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -32,7 +32,8 @@ import torch
from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig
from vllm.attention.layer import Attention
from vllm import envs
from vllm.attention.layer import Attention, FusedQkvSplitRmsNormRopeAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
......@@ -290,15 +291,27 @@ class Glm4MoeAttention(nn.Module):
max_position=max_position_embeddings,
rope_parameters=config.rope_parameters,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
if not envs.VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE:
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
else:
self.attn = FusedQkvSplitRmsNormRopeAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
......@@ -310,17 +323,36 @@ class Glm4MoeAttention(nn.Module):
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(
q.shape
)
k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape(
k.shape
)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if not envs.VLLM_V1_USE_FUSED_QKV_SPLIT_RMS_ROPE_KVSTORE:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(
q.shape
)
k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape(
k.shape
)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
else:
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != qkv.device
or cos_sin_cache.dtype != qkv.dtype):
cos_sin_cache = cos_sin_cache.to(qkv.device,
dtype=qkv.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
attn_output = self.attn(qkv,
positions,
cos_sin_cache,
self.q_norm.weight,
self.k_norm.weight,
self.q_norm.variance_epsilon,
is_neox=self.rotary_emb.is_neox_style)
output, _ = self.o_proj(attn_output)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment