Commit 04d429f6 authored by guanyu1's avatar guanyu1
Browse files

qwen3.py合入fused_morpe

parent 7676d0c9
......@@ -3628,4 +3628,124 @@ direct_register_custom_op(
op_func=fused_add_rms_norm_opt,
mutates_args=[],
fake_impl=fused_add_rms_norm_opt_fake,
)
"""
qwen3-vl-8b中LLM的修改 rms+mrope dim==1 2026/03/18
"""
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
def rms_rotary_embedding_fuse_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
mutates_args=["query", "key"],
fake_impl=rms_rotary_embedding_fuse_fake,
)
"""
qwen3-vl-8b中LLM模型的修改 rms+mrope dim==2 2026/03/18
"""
def rms_mrope_fuse_fake(
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
head_size: int,
rotary_dim: int,
mrope_section_t: int,
mrope_section_h: int,
mrope_section_w: int,
mrope_interleaved: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
q_residual: torch.Tensor | None = None,
k_residual: torch.Tensor | None = None,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
def rms_mrope_fuse(
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
head_size: int,
rotary_dim: int,
mrope_section_t: int,
mrope_section_h: int,
mrope_section_w: int,
mrope_interleaved: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
q_residual: torch.Tensor | None = None,
k_residual: torch.Tensor | None = None,
) -> None:
from lightop import op as lightop_ops
lightop_ops.fuse_rms_mrope_cuda(
query,
key,
cos,
sin,
[mrope_section_t, mrope_section_h, mrope_section_w],
head_size,
rotary_dim,
mrope_interleaved,
q_weight,
k_weight,
q_residual,
k_residual,
epsilon,
)
direct_register_custom_op(
op_name="rms_mrope_fuse",
op_func=rms_mrope_fuse,
mutates_args=["query","key"],
fake_impl=rms_mrope_fuse_fake,
)
\ No newline at end of file
......@@ -51,8 +51,7 @@ from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
from vllm import _custom_ops as ops
logger = init_logger(__name__)
......@@ -137,58 +136,6 @@ class Qwen3Attention(nn.Module):
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
def rms_rotary_embedding_fuse_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
mutates_args=["query", "key"],
fake_impl=rms_rotary_embedding_fuse_fake,
)
def forward(
self,
positions: torch.Tensor,
......@@ -196,33 +143,87 @@ class Qwen3Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
used_fused = False
if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1:
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
else:
if hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
k,
self.head_dim,
cos_sin_cache,
self.rotary_emb.is_neox_style,
self.q_norm.weight,
self.k_norm.weight,
None,
None,
self.q_norm.variance_epsilon,
)
used_fused = True
else:
logger.warning_once(
"VLLM_USE_FUSED_RMS_ROPE is enabled and positions.ndim == 1, "
"but the RoPE fused op is unavailable; falling back to the "
"default RMSNorm + RoPE path."
)
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2:
mrope_section = getattr(self.rotary_emb, "mrope_section", None)
if mrope_section is not None and hasattr(torch.ops.vllm,
"rms_mrope_fuse"):
# Fused RMSNorm + M-RoPE path through custom op.
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
assert len(mrope_section) == 3
torch.ops.vllm.rms_mrope_fuse(
q,
k,
cos,
sin,
self.head_dim,
self.rotary_emb.rotary_dim,
mrope_section[0],
mrope_section[1],
mrope_section[2],
self.rotary_emb.mrope_interleaved,
self.q_norm.weight,
self.k_norm.weight,
self.q_norm.variance_epsilon,
None,
None,
)
used_fused = True
else:
logger.warning_once(
"VLLM_USE_FUSED_RMS_ROPE is enabled and positions.ndim == 2, "
"but the M-RoPE fused op is unavailable; falling back to the "
"default RMSNorm + RoPE path."
)
if not used_fused:
# Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
if envs.VLLM_USE_APEX_RN:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment