Commit e220b38b authored by laibao's avatar laibao
Browse files

[FEATURE] 为 Qwen3/Qwen3Moe 引入 QKV split + RMSNorm + RoPE 融合路径

 为 Qwen3 和 Qwen3Moe 增加可选的 fused QKV split + RMSNorm + RoPE 执行路径,
  减少中间张量拆分与重复计算开销,统一相关模型的优化开关控制逻辑。
parent 2dc182c0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Literal, Optional, Tuple
import torch
......@@ -3711,6 +3711,63 @@ direct_register_custom_op(
fake_impl=rms_rotary_embedding_fuse_fake,
)
"""
qwen3 split_qkv+rn+rope
"""
def qkv_split_rms_rotary_embedding_fuse(
positions: torch.Tensor,
qkv: torch.Tensor,
q_size: int,
kv_size: int,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
residual_q: Optional[torch.Tensor],
residual_k: Optional[torch.Tensor],
epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from lightop import op
return op.qkv_split_rms_rotary_embedding_fuse(
positions,
qkv,
q_size,
kv_size,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
residual_q,
residual_k,
epsilon,
)
def qkv_split_rms_rotary_embedding_fuse_fake(
positions: torch.Tensor,
qkv: torch.Tensor,
q_size: int,
kv_size: int,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
residual_q: Optional[torch.Tensor],
residual_k: Optional[torch.Tensor],
epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = qkv.narrow(-1, 0, q_size)
k = qkv.narrow(-1, q_size, kv_size)
v = qkv.narrow(-1, q_size + kv_size, kv_size)
return q, k, v
direct_register_custom_op(
op_name="qkv_split_rms_rotary_embedding_fuse",
op_func=qkv_split_rms_rotary_embedding_fuse,
mutates_args=["qkv"],
fake_impl=qkv_split_rms_rotary_embedding_fuse_fake,
)
"""
qwen3-vl-8b中LLM模型的修改 rms+mrope dim==2 2026/03/18
"""
......
......@@ -297,6 +297,7 @@ if TYPE_CHECKING:
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_QKV_SPLIT_RMS_ROPE: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_USE_CAT_MLA: bool = False
FP8_USE_MIXED_BATCH: bool = False
......@@ -1889,6 +1890,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in
("true", "1")),
# vLLM will use split_qkv + fused RMS + RoPE kernel
"VLLM_USE_QKV_SPLIT_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_QKV_SPLIT_RMS_ROPE", "False").lower()
in ("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
......
......@@ -141,8 +141,35 @@ class Qwen3Attention(nn.Module):
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1:
use_qkv_split_rms_rope = (envs.VLLM_USE_QKV_SPLIT_RMS_ROPE
and positions.ndim == 1)
if use_qkv_split_rms_rope:
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != qkv.device
or cos_sin_cache.dtype != qkv.dtype):
cos_sin_cache = cos_sin_cache.to(qkv.device,
dtype=qkv.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
q, k, v = torch.ops.vllm.qkv_split_rms_rotary_embedding_fuse(
positions,
qkv,
self.q_size,
self.kv_size,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
if (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
and positions.ndim == 1):
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
......@@ -168,7 +195,8 @@ class Qwen3Attention(nn.Module):
None,
self.q_norm.variance_epsilon,
)
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2:
elif (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
and positions.ndim == 2):
# Fused RMSNorm + M-RoPE path through custom op.
mrope_section = getattr(self.rotary_emb, "mrope_section", None)
assert len(mrope_section) == 3
......
......@@ -366,9 +366,34 @@ class Qwen3MoeAttention(nn.Module):
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
use_qkv_split_rms_rope = (envs.VLLM_USE_QKV_SPLIT_RMS_ROPE
and positions.ndim == 1)
if use_qkv_split_rms_rope:
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != qkv.device
or cos_sin_cache.dtype != qkv.dtype):
cos_sin_cache = cos_sin_cache.to(qkv.device,
dtype=qkv.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
q, k, v = torch.ops.vllm.qkv_split_rms_rotary_embedding_fuse(
positions,
qkv,
self.q_size,
self.kv_size,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1:
if (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
and positions.ndim == 1):
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
......@@ -376,10 +401,7 @@ class Qwen3MoeAttention(nn.Module):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
# # q, k 使用 continuous
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
......@@ -395,8 +417,9 @@ class Qwen3MoeAttention(nn.Module):
None,
self.q_norm.variance_epsilon,
)
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2 and getattr(
self.rotary_emb, "mrope_section", None) is not None:
elif (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
and positions.ndim == 2 and getattr(
self.rotary_emb, "mrope_section", None) is not None):
# Fused RMSNorm + M-RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
......@@ -447,6 +470,7 @@ class Qwen3MoeAttention(nn.Module):
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment