Commit e08422ae authored by laibao's avatar laibao
Browse files

[feat] 支持 mRoPE 的 fused RMSNorm+RoPE 路径,并修正 torch.compile 动态维度标注

实现了用于优化张量计算的 rms_mrope_fuse 和 rms_mrope_fuse_fake 方法
更新了 forward:在满足条件时走新的 M-RoPE 融合路径
增强了 Qwen3MoeModel 对动态参数维度的支持,以适配该功能
parent ca4598a4
......@@ -325,6 +325,67 @@ class Qwen3MoeAttention(nn.Module):
fake_impl=rms_rotary_embedding_fuse_fake,
)
def rms_mrope_fuse(
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
head_size: int,
rotary_dim: int,
mrope_section_t: int,
mrope_section_h: int,
mrope_section_w: int,
mrope_interleaved: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_residual: Optional[torch.Tensor],
k_residual: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import op as lightop_ops
lightop_ops.fuse_rms_mrope_cuda(
query,
key,
cos,
sin,
[mrope_section_t, mrope_section_h, mrope_section_w],
head_size,
rotary_dim,
mrope_interleaved,
q_weight,
k_weight,
q_residual,
k_residual,
epsilon,
)
def rms_mrope_fuse_fake(
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
head_size: int,
rotary_dim: int,
mrope_section_t: int,
mrope_section_h: int,
mrope_section_w: int,
mrope_interleaved: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_residual: Optional[torch.Tensor],
k_residual: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
direct_register_custom_op(
op_name="rms_mrope_fuse",
op_func=rms_mrope_fuse,
mutates_args=["query", "key"],
fake_impl=rms_mrope_fuse_fake,
)
def forward(
self,
positions: torch.Tensor,
......@@ -333,7 +394,7 @@ class Qwen3MoeAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
# Add qk-norm
if envs.VLLM_USE_FUSED_RMS_ROPE :
if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1:
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
......@@ -361,6 +422,44 @@ class Qwen3MoeAttention(nn.Module):
self.q_norm.variance_epsilon,
)
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2 and getattr(
self.rotary_emb, "mrope_section", None) is not None:
# Fused RMSNorm + M-RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
mrope_section = self.rotary_emb.mrope_section
assert mrope_section is not None and len(mrope_section) == 3
torch.ops.vllm.rms_mrope_fuse(
q,
k,
cos,
sin,
self.head_dim,
self.rotary_emb.rotary_dim,
mrope_section[0],
mrope_section[1],
mrope_section[2],
self.rotary_emb.mrope_interleaved,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
# Add qk-norm then RoPE (original path).
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
......@@ -462,7 +561,15 @@ class Qwen3MoeDecoderLayer(nn.Module):
return hidden_states, residual
@support_torch_compile
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class Qwen3MoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment