Commit dfb597c8 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'gy-015-qwen3py-fused_mrope' into 'v0.15.1-dev'

qwen3.py合入fused_mrope

See merge request dcutoolkit/deeplearing/vllm!516
parents fca0956a ef79626d
...@@ -3628,4 +3628,124 @@ direct_register_custom_op( ...@@ -3628,4 +3628,124 @@ direct_register_custom_op(
op_func=fused_add_rms_norm_opt, op_func=fused_add_rms_norm_opt,
mutates_args=[], mutates_args=[],
fake_impl=fused_add_rms_norm_opt_fake, fake_impl=fused_add_rms_norm_opt_fake,
)
"""
Modification for the LLM in qwen3-vl-8b: rms+mrope, dim==1 (2026/03/18)
"""
def rms_rotary_embedding_fuse(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Forward every argument, unchanged and in order, to the lightop kernel.

    This wrapper only resolves ``lightop.rms_rotary_embedding_fuse`` at call
    time and invokes it positionally; it returns nothing itself.

    NOTE(review): the op registration for this function declares ``query``
    and ``key`` as mutated arguments, so the kernel presumably writes its
    result into them in place -- confirm against the lightop documentation.
    """
    # Resolved on each call instead of at module import time.
    from lightop import rms_rotary_embedding_fuse as lightop_kernel

    # Positional order expected by the kernel: biases come before epsilon.
    kernel_args = (
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox_style,
        q_weight,
        k_weight,
        q_bias,
        k_bias,
        epsilon,
    )
    lightop_kernel(*kernel_args)
def rms_rotary_embedding_fuse_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Fake implementation of ``rms_rotary_embedding_fuse`` for graph tracing.

    Accepts the same parameters as the real op, touches none of them and
    returns ``None``.
    """
    # Deliberately a no-op: nothing is computed in graph tracing modes.
    return None
# Register the wrapper as a custom op. ``query`` and ``key`` are declared as
# mutated arguments; the fake impl stands in for it during graph tracing.
direct_register_custom_op(
    op_name="rms_rotary_embedding_fuse",
    op_func=rms_rotary_embedding_fuse,
    fake_impl=rms_rotary_embedding_fuse_fake,
    mutates_args=["query", "key"],
)
"""
Modification for the LLM model in qwen3-vl-8b: rms+mrope, dim==2 (2026/03/18)
"""
def rms_mrope_fuse_fake(
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
head_size: int,
rotary_dim: int,
mrope_section_t: int,
mrope_section_h: int,
mrope_section_w: int,
mrope_interleaved: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
q_residual: torch.Tensor | None = None,
k_residual: torch.Tensor | None = None,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
def rms_mrope_fuse(
    query: torch.Tensor,
    key: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    head_size: int,
    rotary_dim: int,
    mrope_section_t: int,
    mrope_section_h: int,
    mrope_section_w: int,
    mrope_interleaved: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    epsilon: float,
    q_residual: torch.Tensor | None = None,
    k_residual: torch.Tensor | None = None,
) -> None:
    """Call ``lightop.op.fuse_rms_mrope_cuda`` with this op's arguments.

    The three scalar ``mrope_section_*`` values are packed into one list
    ``[t, h, w]``, and the arguments are reordered into the positional layout
    the kernel is called with (section list fifth, epsilon last). The wrapper
    itself returns nothing.

    NOTE(review): the op registration for this function declares ``query``
    and ``key`` as mutated arguments, so the kernel presumably writes its
    result into them in place -- confirm against the lightop documentation.
    """
    # Resolved on each call instead of at module import time.
    from lightop import op as lightop_op_module

    # The kernel takes the sections as a single list rather than three ints.
    mrope_section = [mrope_section_t, mrope_section_h, mrope_section_w]

    kernel_args = (
        query,
        key,
        cos,
        sin,
        mrope_section,
        head_size,
        rotary_dim,
        mrope_interleaved,
        q_weight,
        k_weight,
        q_residual,
        k_residual,
        epsilon,
    )
    lightop_op_module.fuse_rms_mrope_cuda(*kernel_args)
# Register the wrapper as a custom op. ``query`` and ``key`` are declared as
# mutated arguments; the fake impl stands in for it during graph tracing.
direct_register_custom_op(
    op_name="rms_mrope_fuse",
    op_func=rms_mrope_fuse,
    fake_impl=rms_mrope_fuse_fake,
    mutates_args=["query", "key"],
)
\ No newline at end of file
...@@ -51,11 +51,9 @@ from .qwen2 import Qwen2Model ...@@ -51,11 +51,9 @@ from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import direct_register_custom_op from vllm import _custom_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
class Qwen3Attention(nn.Module): class Qwen3Attention(nn.Module):
def __init__( def __init__(
self, self,
...@@ -137,58 +135,6 @@ class Qwen3Attention(nn.Module): ...@@ -137,58 +135,6 @@ class Qwen3Attention(nn.Module):
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Forward every argument, unchanged and in order, to the lightop kernel.

    This wrapper only resolves ``lightop.rms_rotary_embedding_fuse`` at call
    time and invokes it positionally; it returns nothing itself.

    NOTE(review): the op registration for this function declares ``query``
    and ``key`` as mutated arguments, so the kernel presumably writes its
    result into them in place -- confirm against the lightop documentation.
    """
    # Resolved on each call instead of at module import time.
    from lightop import rms_rotary_embedding_fuse as lightop_kernel

    # Positional order expected by the kernel: biases come before epsilon.
    kernel_args = (
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox_style,
        q_weight,
        k_weight,
        q_bias,
        k_bias,
        epsilon,
    )
    lightop_kernel(*kernel_args)
def rms_rotary_embedding_fuse_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Fake implementation of ``rms_rotary_embedding_fuse`` for graph tracing.

    Accepts the same parameters as the real op, touches none of them and
    returns ``None``.
    """
    # Deliberately a no-op: nothing is computed in graph tracing modes.
    return None
# Register only when ``torch.ops.vllm`` does not already expose an op with
# this name, so the registration is not repeated.
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
    direct_register_custom_op(
        op_name="rms_rotary_embedding_fuse",
        op_func=rms_rotary_embedding_fuse,
        fake_impl=rms_rotary_embedding_fuse_fake,
        mutates_args=["query", "key"],
    )
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -202,8 +148,8 @@ class Qwen3Attention(nn.Module): ...@@ -202,8 +148,8 @@ class Qwen3Attention(nn.Module):
if (cos_sin_cache.device != q.device if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype): or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device, cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype, dtype=q.dtype,
non_blocking=True) non_blocking=True)
# Persist the converted cache so we don't re-copy/re-allocate # Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU. # on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache self.rotary_emb.cos_sin_cache = cos_sin_cache
...@@ -222,6 +168,40 @@ class Qwen3Attention(nn.Module): ...@@ -222,6 +168,40 @@ class Qwen3Attention(nn.Module):
None, None,
self.q_norm.variance_epsilon, self.q_norm.variance_epsilon,
) )
elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2:
# Fused RMSNorm + M-RoPE path through custom op.
mrope_section = getattr(self.rotary_emb, "mrope_section", None)
assert len(mrope_section) == 3
cos_sin_cache = self.rotary_emb.cos_sin_cache
if (cos_sin_cache.device != q.device
or cos_sin_cache.dtype != q.dtype):
cos_sin_cache = cos_sin_cache.to(q.device,
dtype=q.dtype,
non_blocking=True)
self.rotary_emb.cos_sin_cache = cos_sin_cache
cos_sin = cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
q = q.contiguous()
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
torch.ops.vllm.rms_mrope_fuse(
q,
k,
cos,
sin,
self.head_dim,
self.rotary_emb.rotary_dim,
mrope_section[0],
mrope_section[1],
mrope_section[2],
self.rotary_emb.mrope_interleaved,
self.q_norm.weight,
self.k_norm.weight,
self.q_norm.variance_epsilon,
None,
None,
)
else: else:
# Add qk-norm # Add qk-norm
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
......
...@@ -96,7 +96,6 @@ from vllm import _custom_ops as ops ...@@ -96,7 +96,6 @@ from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -361,122 +360,6 @@ class Qwen3MoeAttention(nn.Module): ...@@ -361,122 +360,6 @@ class Qwen3MoeAttention(nn.Module):
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def rms_rotary_embedding_fuse(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    epsilon: float,
    q_bias: torch.Tensor | None = None,
    k_bias: torch.Tensor | None = None,
) -> None:
    """Call the lightop ``rms_rotary_embedding_fuse`` kernel positionally.

    This signature takes ``epsilon`` before the optional biases, whereas the
    kernel is called with ``q_bias`` and ``k_bias`` first and ``epsilon``
    last; the wrapper performs that reordering and returns nothing itself.

    NOTE(review): the op registration for this function declares ``query``
    and ``key`` as mutated arguments, so the kernel presumably writes its
    result into them in place -- confirm against the lightop documentation.
    """
    # Resolved on each call instead of at module import time.
    from lightop import rms_rotary_embedding_fuse as lightop_kernel

    # Kernel order differs from this signature: biases precede epsilon.
    kernel_args = (
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox_style,
        q_weight,
        k_weight,
        q_bias,
        k_bias,
        epsilon,
    )
    lightop_kernel(*kernel_args)
def rms_rotary_embedding_fuse_fake(
# q_out:torch.Tensor,
# k_out:torch.Tensor,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
q_bias: torch.Tensor | None = None,
k_bias: torch.Tensor | None = None,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
# Register the wrapper as a custom op. ``query`` and ``key`` are declared as
# mutated arguments; the fake impl stands in for it during graph tracing.
direct_register_custom_op(
    op_name="rms_rotary_embedding_fuse",
    op_func=rms_rotary_embedding_fuse,
    fake_impl=rms_rotary_embedding_fuse_fake,
    mutates_args=["query", "key"],
)
def rms_mrope_fuse(
    query: torch.Tensor,
    key: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    head_size: int,
    rotary_dim: int,
    mrope_section_t: int,
    mrope_section_h: int,
    mrope_section_w: int,
    mrope_interleaved: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    epsilon: float,
    q_residual: torch.Tensor | None = None,
    k_residual: torch.Tensor | None = None,
) -> None:
    """Call ``lightop.op.fuse_rms_mrope_cuda`` with this op's arguments.

    The three scalar ``mrope_section_*`` values are packed into one list
    ``[t, h, w]``, and the arguments are reordered into the positional layout
    the kernel is called with (section list fifth, epsilon last). The wrapper
    itself returns nothing.

    NOTE(review): the op registration for this function declares ``query``
    and ``key`` as mutated arguments, so the kernel presumably writes its
    result into them in place -- confirm against the lightop documentation.
    """
    # Resolved on each call instead of at module import time.
    from lightop import op as lightop_op_module

    # The kernel takes the sections as a single list rather than three ints.
    mrope_section = [mrope_section_t, mrope_section_h, mrope_section_w]

    kernel_args = (
        query,
        key,
        cos,
        sin,
        mrope_section,
        head_size,
        rotary_dim,
        mrope_interleaved,
        q_weight,
        k_weight,
        q_residual,
        k_residual,
        epsilon,
    )
    lightop_op_module.fuse_rms_mrope_cuda(*kernel_args)
def rms_mrope_fuse_fake(
query: torch.Tensor,
key: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
head_size: int,
rotary_dim: int,
mrope_section_t: int,
mrope_section_h: int,
mrope_section_w: int,
mrope_interleaved: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
epsilon: float,
q_residual: torch.Tensor | None = None,
k_residual: torch.Tensor | None = None,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
# Register the wrapper as a custom op. ``query`` and ``key`` are declared as
# mutated arguments; the fake impl stands in for it during graph tracing.
direct_register_custom_op(
    op_name="rms_mrope_fuse",
    op_func=rms_mrope_fuse,
    fake_impl=rms_mrope_fuse_fake,
    mutates_args=["query", "key"],
)
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment