Commit 3cd3d1e6 authored by zhuwenwen's avatar zhuwenwen
Browse files

perf(qwen3): 融合 q/k RMSNorm + RoPE

新增 VLLM_USE_FUSED_RMS_ROPE 分支,走 fused 路径
注册 torch.ops.vllm.rms_rotary_embedding_fuse(direct_register_custom_op)
cos_sin_cache 自动转 device/dtype 并缓存,避免每次重复拷贝
parent 80a6b121
...@@ -50,6 +50,7 @@ from .qwen2 import Qwen2Model ...@@ -50,6 +50,7 @@ from .qwen2 import Qwen2Model
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
maybe_prefix) maybe_prefix)
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -137,6 +138,58 @@ class Qwen3Attention(nn.Module): ...@@ -137,6 +138,58 @@ class Qwen3Attention(nn.Module):
) )
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
def rms_rotary_embedding_fuse_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
mutates_args=["query", "key"],
fake_impl=rms_rotary_embedding_fuse_fake,
)
def forward( def forward(
self, self,
...@@ -145,22 +198,49 @@ class Qwen3Attention(nn.Module): ...@@ -145,22 +198,49 @@ class Qwen3Attention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm if envs.VLLM_USE_FUSED_RMS_ROPE:
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, # Fused RMSNorm + RoPE path through custom op.
self.head_dim) cos_sin_cache = self.rotary_emb.cos_sin_cache
if envs.VLLM_USE_APEX_RN: if (cos_sin_cache.device != q.device
q_by_head = self.q_norm.forward_apex(q_by_head) or cos_sin_cache.dtype != q.dtype):
else: cos_sin_cache = cos_sin_cache.to(q.device,
q_by_head = self.q_norm.forward_cuda(q_by_head) dtype=q.dtype,
q = q_by_head.view(q.shape) non_blocking=True)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, # Persist the converted cache so we don't re-copy/re-allocate
self.head_dim) # on every forward when the original buffer starts on CPU.
if envs.VLLM_USE_APEX_RN: self.rotary_emb.cos_sin_cache = cos_sin_cache
k_by_head = self.k_norm.forward_apex(k_by_head) q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else: else:
k_by_head = self.k_norm.forward_cuda(k_by_head) # Add qk-norm
k = k_by_head.view(k.shape) q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
q, k = self.rotary_emb(positions, q, k) self.head_dim)
if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head)
else:
q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
self.head_dim)
if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head)
else:
k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment