Commit e220b38b authored by laibao's avatar laibao
Browse files

[FEATURE] 为 Qwen3/Qwen3Moe 引入 QKV split + RMSNorm + RoPE 融合路径

 为 Qwen3 和 Qwen3Moe 增加可选的 fused QKV split + RMSNorm + RoPE 执行路径,
  减少中间张量拆分与重复计算开销,统一相关模型的优化开关控制逻辑。
parent 2dc182c0
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal, Optional from typing import TYPE_CHECKING, Literal, Optional, Tuple
import torch import torch
...@@ -3711,6 +3711,63 @@ direct_register_custom_op( ...@@ -3711,6 +3711,63 @@ direct_register_custom_op(
fake_impl=rms_rotary_embedding_fuse_fake, fake_impl=rms_rotary_embedding_fuse_fake,
) )
"""
qwen3 split_qkv+rn+rope
"""
def qkv_split_rms_rotary_embedding_fuse(
positions: torch.Tensor,
qkv: torch.Tensor,
q_size: int,
kv_size: int,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
residual_q: Optional[torch.Tensor],
residual_k: Optional[torch.Tensor],
epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from lightop import op
return op.qkv_split_rms_rotary_embedding_fuse(
positions,
qkv,
q_size,
kv_size,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
residual_q,
residual_k,
epsilon,
)
def qkv_split_rms_rotary_embedding_fuse_fake(
positions: torch.Tensor,
qkv: torch.Tensor,
q_size: int,
kv_size: int,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
residual_q: Optional[torch.Tensor],
residual_k: Optional[torch.Tensor],
epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = qkv.narrow(-1, 0, q_size)
k = qkv.narrow(-1, q_size, kv_size)
v = qkv.narrow(-1, q_size + kv_size, kv_size)
return q, k, v
direct_register_custom_op(
op_name="qkv_split_rms_rotary_embedding_fuse",
op_func=qkv_split_rms_rotary_embedding_fuse,
mutates_args=["qkv"],
fake_impl=qkv_split_rms_rotary_embedding_fuse_fake,
)
""" """
qwen3-vl-8b中LLM模型的修改 rms+mrope dim==2 2026/03/18 qwen3-vl-8b中LLM模型的修改 rms+mrope dim==2 2026/03/18
""" """
......
...@@ -297,6 +297,7 @@ if TYPE_CHECKING: ...@@ -297,6 +297,7 @@ if TYPE_CHECKING:
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_QKV_SPLIT_RMS_ROPE: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_USE_CAT_MLA: bool = False VLLM_USE_CAT_MLA: bool = False
FP8_USE_MIXED_BATCH: bool = False FP8_USE_MIXED_BATCH: bool = False
...@@ -1889,6 +1890,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1889,6 +1890,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_RMS_ROPE": "VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use split_qkv + fused RMS + RoPE kernel
"VLLM_USE_QKV_SPLIT_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_QKV_SPLIT_RMS_ROPE", "False").lower()
in ("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat # vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT": "VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
......
...@@ -141,8 +141,35 @@ class Qwen3Attention(nn.Module): ...@@ -141,8 +141,35 @@ class Qwen3Attention(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) use_qkv_split_rms_rope = (envs.VLLM_USE_QKV_SPLIT_RMS_ROPE
if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1: and positions.ndim == 1)
if use_qkv_split_rms_rope:
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != qkv.device
or cos_sin_cache.dtype != qkv.dtype):
cos_sin_cache = cos_sin_cache.to(qkv.device,
dtype=qkv.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
q, k, v = torch.ops.vllm.qkv_split_rms_rotary_embedding_fuse(
positions,
qkv,
self.q_size,
self.kv_size,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
if (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
and positions.ndim == 1):
# Fused RMSNorm + RoPE path through custom op. # Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device if (cos_sin_cache.device != q.device
...@@ -168,7 +195,8 @@ class Qwen3Attention(nn.Module): ...@@ -168,7 +195,8 @@ class Qwen3Attention(nn.Module):
None, None,
self.q_norm.variance_epsilon, self.q_norm.variance_epsilon,
) )
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2: elif (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
and positions.ndim == 2):
# Fused RMSNorm + M-RoPE path through custom op. # Fused RMSNorm + M-RoPE path through custom op.
mrope_section = getattr(self.rotary_emb, "mrope_section", None) mrope_section = getattr(self.rotary_emb, "mrope_section", None)
assert len(mrope_section) == 3 assert len(mrope_section) == 3
......
...@@ -366,9 +366,34 @@ class Qwen3MoeAttention(nn.Module): ...@@ -366,9 +366,34 @@ class Qwen3MoeAttention(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
use_qkv_split_rms_rope = (envs.VLLM_USE_QKV_SPLIT_RMS_ROPE
and positions.ndim == 1)
if use_qkv_split_rms_rope:
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != qkv.device
or cos_sin_cache.dtype != qkv.dtype):
cos_sin_cache = cos_sin_cache.to(qkv.device,
dtype=qkv.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
q, k, v = torch.ops.vllm.qkv_split_rms_rotary_embedding_fuse(
positions,
qkv,
self.q_size,
self.kv_size,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm if (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1: and positions.ndim == 1):
# Fused RMSNorm + RoPE path through custom op. # Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device if (cos_sin_cache.device != q.device
...@@ -376,10 +401,7 @@ class Qwen3MoeAttention(nn.Module): ...@@ -376,10 +401,7 @@ class Qwen3MoeAttention(nn.Module):
cos_sin_cache = cos_sin_cache.to(q.device, cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype, dtype=q.dtype,
non_blocking=True) non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache self.rotary_emb.cos_sin_cache = cos_sin_cache
# # q, k 使用 continuous
q = q.contiguous() q = q.contiguous()
k = k.contiguous() k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse( torch.ops.vllm.rms_rotary_embedding_fuse(
...@@ -395,8 +417,9 @@ class Qwen3MoeAttention(nn.Module): ...@@ -395,8 +417,9 @@ class Qwen3MoeAttention(nn.Module):
None, None,
self.q_norm.variance_epsilon, self.q_norm.variance_epsilon,
) )
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2 and getattr( elif (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
self.rotary_emb, "mrope_section", None) is not None: and positions.ndim == 2 and getattr(
self.rotary_emb, "mrope_section", None) is not None):
# Fused RMSNorm + M-RoPE path through custom op. # Fused RMSNorm + M-RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device if (cos_sin_cache.device != q.device
...@@ -447,6 +470,7 @@ class Qwen3MoeAttention(nn.Module): ...@@ -447,6 +470,7 @@ class Qwen3MoeAttention(nn.Module):
k_by_head = self.k_norm.forward_cuda(k_by_head) k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape) k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment