Commit e220b38b authored by laibao
Browse files

[FEATURE] 为 Qwen3/Qwen3Moe 引入 QKV split + RMSNorm + RoPE 融合路径

 为 Qwen3 和 Qwen3Moe 增加可选的 fused QKV split + RMSNorm + RoPE 执行路径,
  减少中间张量拆分与重复计算开销,统一相关模型的优化开关控制逻辑。
parent 2dc182c0
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal, Optional from typing import TYPE_CHECKING, Literal, Optional, Tuple
import torch import torch
...@@ -3711,6 +3711,63 @@ direct_register_custom_op( ...@@ -3711,6 +3711,63 @@ direct_register_custom_op(
fake_impl=rms_rotary_embedding_fuse_fake, fake_impl=rms_rotary_embedding_fuse_fake,
) )
"""
qwen3 split_qkv+rn+rope
"""
def qkv_split_rms_rotary_embedding_fuse(
    positions: torch.Tensor,
    qkv: torch.Tensor,
    q_size: int,
    kv_size: int,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    residual_q: Optional[torch.Tensor],
    residual_k: Optional[torch.Tensor],
    epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Dispatch to the lightop kernel of the same name.

    Every argument is forwarded positionally, unchanged and in declaration
    order, and the kernel's return value is handed back as-is.

    NOTE(review): judging by the op name, the kernel presumably splits
    ``qkv`` into q/k/v and applies RMSNorm and RoPE to q and k in one pass;
    the kernel itself is not visible here, so confirm against lightop.
    """
    # Imported at call time rather than module import time.
    from lightop import op

    fused_kernel = op.qkv_split_rms_rotary_embedding_fuse
    kernel_args = (
        positions,
        qkv,
        q_size,
        kv_size,
        head_size,
        cos_sin_cache,
        is_neox_style,
        q_weight,
        k_weight,
        residual_q,
        residual_k,
        epsilon,
    )
    return fused_kernel(*kernel_args)
def qkv_split_rms_rotary_embedding_fuse_fake(
    positions: torch.Tensor,
    qkv: torch.Tensor,
    q_size: int,
    kv_size: int,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    residual_q: Optional[torch.Tensor],
    residual_k: Optional[torch.Tensor],
    epsilon: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (shape-only) implementation of the fused QKV split op.

    Performs no normalization or rotary math: it only slices the last
    dimension of ``qkv`` into three consecutive segments of widths
    ``q_size``, ``kv_size`` and ``kv_size`` and returns them as ``(q, k, v)``.
    All other arguments are accepted for signature parity and ignored.

    NOTE(review): the returned tensors are ``narrow`` views of ``qkv``, i.e.
    they alias the input. PyTorch's custom-op guidance asks that outputs not
    alias inputs; confirm this matches what the real kernel returns.
    """
    segment_widths = (q_size, kv_size, kv_size)
    pieces = []
    start = 0
    for width in segment_widths:
        # Each segment begins where the previous one ended.
        pieces.append(qkv.narrow(-1, start, width))
        start += width
    return tuple(pieces)
# Register the fused QKV split op under the name
# "qkv_split_rms_rotary_embedding_fuse", pairing the real implementation with
# its fake counterpart. `qkv` is declared as an argument the op mutates.
direct_register_custom_op(
op_name="qkv_split_rms_rotary_embedding_fuse",
op_func=qkv_split_rms_rotary_embedding_fuse,
mutates_args=["qkv"],
fake_impl=qkv_split_rms_rotary_embedding_fuse_fake,
)
""" """
qwen3-vl-8b中LLM模型的修改 rms+mrope dim==2 2026/03/18 qwen3-vl-8b中LLM模型的修改 rms+mrope dim==2 2026/03/18
""" """
......
...@@ -297,6 +297,7 @@ if TYPE_CHECKING: ...@@ -297,6 +297,7 @@ if TYPE_CHECKING:
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_QKV_SPLIT_RMS_ROPE: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_USE_CAT_MLA: bool = False VLLM_USE_CAT_MLA: bool = False
FP8_USE_MIXED_BATCH: bool = False FP8_USE_MIXED_BATCH: bool = False
...@@ -1889,6 +1890,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1889,6 +1890,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_RMS_ROPE": "VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use split_qkv + fused RMS + RoPE kernel
"VLLM_USE_QKV_SPLIT_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_QKV_SPLIT_RMS_ROPE", "False").lower()
in ("true", "1")),
# vLLM will use lightop for dpsk mtp fill + rms*2 + cat # vLLM will use lightop for dpsk mtp fill + rms*2 + cat
"VLLM_USE_FUSED_FILL_RMS_CAT": "VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
......
...@@ -141,24 +141,21 @@ class Qwen3Attention(nn.Module): ...@@ -141,24 +141,21 @@ class Qwen3Attention(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) use_qkv_split_rms_rope = (envs.VLLM_USE_QKV_SPLIT_RMS_ROPE
if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1: and positions.ndim == 1)
# Fused RMSNorm + RoPE path through custom op. if use_qkv_split_rms_rope:
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device if (cos_sin_cache.device != qkv.device
or cos_sin_cache.dtype != q.dtype): or cos_sin_cache.dtype != qkv.dtype):
cos_sin_cache = cos_sin_cache.to(q.device, cos_sin_cache = cos_sin_cache.to(qkv.device,
dtype=q.dtype, dtype=qkv.dtype,
non_blocking=True) non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache self.rotary_emb.cos_sin_cache = cos_sin_cache
q = q.contiguous() q, k, v = torch.ops.vllm.qkv_split_rms_rotary_embedding_fuse(
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions, positions,
q, qkv,
k, self.q_size,
self.kv_size,
self.head_dim, self.head_dim,
cos_sin_cache, cos_sin_cache,
self.rotary_emb.is_neox_style, self.rotary_emb.is_neox_style,
...@@ -168,55 +165,86 @@ class Qwen3Attention(nn.Module): ...@@ -168,55 +165,86 @@ class Qwen3Attention(nn.Module):
None, None,
self.q_norm.variance_epsilon, self.q_norm.variance_epsilon,
) )
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2:
# Fused RMSNorm + M-RoPE path through custom op.
mrope_section = getattr(self.rotary_emb, "mrope_section", None)
assert len(mrope_section) == 3
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
torch.ops.vllm.rms_mrope_fuse(
q,
k,
cos,
sin,
self.head_dim,
self.rotary_emb.rotary_dim,
mrope_section[0],
mrope_section[1],
mrope_section[2],
self.rotary_emb.mrope_interleaved,
self.q_norm.weight,
self.k_norm.weight,
self.q_norm.variance_epsilon,
None,
None,
)
else: else:
# Add qk-norm q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) dim=-1)
if envs.VLLM_USE_APEX_RN: if (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
q_by_head = self.q_norm.forward_apex(q_by_head) and positions.ndim == 1):
else: # Fused RMSNorm + RoPE path through custom op.
q_by_head = self.q_norm.forward_cuda(q_by_head) cos_sin_cache = self.rotary_emb.cos_sin_cache
q = q_by_head.view(q.shape) if (cos_sin_cache.device != q.device
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) or cos_sin_cache.dtype != q.dtype):
if envs.VLLM_USE_APEX_RN: cos_sin_cache = cos_sin_cache.to(q.device,
k_by_head = self.k_norm.forward_apex(k_by_head) dtype=q.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
elif (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
and positions.ndim == 2):
# Fused RMSNorm + M-RoPE path through custom op.
mrope_section = getattr(self.rotary_emb, "mrope_section", None)
assert len(mrope_section) == 3
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
torch.ops.vllm.rms_mrope_fuse(
q,
k,
cos,
sin,
self.head_dim,
self.rotary_emb.rotary_dim,
mrope_section[0],
mrope_section[1],
mrope_section[2],
self.rotary_emb.mrope_interleaved,
self.q_norm.weight,
self.k_norm.weight,
self.q_norm.variance_epsilon,
None,
None,
)
else: else:
k_by_head = self.k_norm.forward_cuda(k_by_head) # Add qk-norm
k = k_by_head.view(k.shape) q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q, k = self.rotary_emb(positions, q, k) if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
...@@ -366,26 +366,21 @@ class Qwen3MoeAttention(nn.Module): ...@@ -366,26 +366,21 @@ class Qwen3MoeAttention(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) use_qkv_split_rms_rope = (envs.VLLM_USE_QKV_SPLIT_RMS_ROPE
# Add qk-norm and positions.ndim == 1)
if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1: if use_qkv_split_rms_rope:
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device if (cos_sin_cache.device != qkv.device
or cos_sin_cache.dtype != q.dtype): or cos_sin_cache.dtype != qkv.dtype):
cos_sin_cache = cos_sin_cache.to(q.device, cos_sin_cache = cos_sin_cache.to(qkv.device,
dtype=q.dtype, dtype=qkv.dtype,
non_blocking=True) non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache self.rotary_emb.cos_sin_cache = cos_sin_cache
# # q, k 使用 continuous q, k, v = torch.ops.vllm.qkv_split_rms_rotary_embedding_fuse(
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions, positions,
q, qkv,
k, self.q_size,
self.kv_size,
self.head_dim, self.head_dim,
cos_sin_cache, cos_sin_cache,
self.rotary_emb.is_neox_style, self.rotary_emb.is_neox_style,
...@@ -395,58 +390,87 @@ class Qwen3MoeAttention(nn.Module): ...@@ -395,58 +390,87 @@ class Qwen3MoeAttention(nn.Module):
None, None,
self.q_norm.variance_epsilon, self.q_norm.variance_epsilon,
) )
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2 and getattr(
self.rotary_emb, "mrope_section", None) is not None:
# Fused RMSNorm + M-RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
mrope_section = self.rotary_emb.mrope_section
assert mrope_section is not None and len(mrope_section) == 3
torch.ops.vllm.rms_mrope_fuse(
q,
k,
cos,
sin,
self.head_dim,
self.rotary_emb.rotary_dim,
mrope_section[0],
mrope_section[1],
mrope_section[2],
self.rotary_emb.mrope_interleaved,
self.q_norm.weight,
self.k_norm.weight,
self.q_norm.variance_epsilon,
None,
None,
)
else: else:
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if envs.VLLM_USE_APEX_RN: if (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
q_by_head = self.q_norm.forward_apex(q_by_head) and positions.ndim == 1):
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
elif (not use_qkv_split_rms_rope and envs.VLLM_USE_FUSED_RMS_ROPE
and positions.ndim == 2 and getattr(
self.rotary_emb, "mrope_section", None) is not None):
# Fused RMSNorm + M-RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
mrope_section = self.rotary_emb.mrope_section
assert mrope_section is not None and len(mrope_section) == 3
torch.ops.vllm.rms_mrope_fuse(
q,
k,
cos,
sin,
self.head_dim,
self.rotary_emb.rotary_dim,
mrope_section[0],
mrope_section[1],
mrope_section[2],
self.rotary_emb.mrope_interleaved,
self.q_norm.weight,
self.k_norm.weight,
self.q_norm.variance_epsilon,
None,
None,
)
else: else:
q_by_head = self.q_norm.forward_cuda(q_by_head) q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
q = q_by_head.view(q.shape) if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment