Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef79626d
Commit
ef79626d
authored
Mar 18, 2026
by
guanyu1
Browse files
修改VLLM_USE_FUSED_RMS_ROPE的不同路径
parent
c1cd5334
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
196 deletions
+58
-196
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+58
-79
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+0
-117
No files found.
vllm/model_executor/models/qwen3.py
View file @
ef79626d
...
...
@@ -54,7 +54,6 @@ import vllm.envs as envs
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
class
Qwen3Attention
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -143,87 +142,67 @@ class Qwen3Attention(nn.Module):
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
used_fused
=
False
if
envs
.
VLLM_USE_FUSED_RMS_ROPE
and
positions
.
ndim
==
1
:
if
hasattr
(
torch
.
ops
.
vllm
,
"rms_rotary_embedding_fuse"
):
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
q
.
device
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
torch
.
ops
.
vllm
.
rms_rotary_embedding_fuse
(
positions
,
q
,
k
,
self
.
head_dim
,
cos_sin_cache
,
self
.
rotary_emb
.
is_neox_style
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
None
,
None
,
self
.
q_norm
.
variance_epsilon
,
)
used_fused
=
True
else
:
logger
.
warning_once
(
"VLLM_USE_FUSED_RMS_ROPE is enabled and positions.ndim == 1, "
"but the RoPE fused op is unavailable; falling back to the "
"default RMSNorm + RoPE path."
)
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
q
.
device
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
torch
.
ops
.
vllm
.
rms_rotary_embedding_fuse
(
positions
,
q
,
k
,
self
.
head_dim
,
cos_sin_cache
,
self
.
rotary_emb
.
is_neox_style
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
None
,
None
,
self
.
q_norm
.
variance_epsilon
,
)
elif
envs
.
VLLM_USE_FUSED_RMS_ROPE
and
positions
.
ndim
==
2
:
# Fused RMSNorm + M-RoPE path through custom op.
mrope_section
=
getattr
(
self
.
rotary_emb
,
"mrope_section"
,
None
)
if
mrope_section
is
not
None
and
hasattr
(
torch
.
ops
.
vllm
,
"rms_mrope_fuse"
):
# Fused RMSNorm + M-RoPE path through custom op.
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
q
.
device
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
cos_sin
=
cos_sin_cache
[
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
cos
=
cos
.
contiguous
()
sin
=
sin
.
contiguous
()
assert
len
(
mrope_section
)
==
3
torch
.
ops
.
vllm
.
rms_mrope_fuse
(
q
,
k
,
cos
,
sin
,
self
.
head_dim
,
self
.
rotary_emb
.
rotary_dim
,
mrope_section
[
0
],
mrope_section
[
1
],
mrope_section
[
2
],
self
.
rotary_emb
.
mrope_interleaved
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
self
.
q_norm
.
variance_epsilon
,
None
,
None
,
)
used_fused
=
True
else
:
logger
.
warning_once
(
"VLLM_USE_FUSED_RMS_ROPE is enabled and positions.ndim == 2, "
"but the M-RoPE fused op is unavailable; falling back to the "
"default RMSNorm + RoPE path."
)
if
not
used_fused
:
assert
len
(
mrope_section
)
==
3
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
q
.
device
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
cos_sin
=
cos_sin_cache
[
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
cos
=
cos
.
contiguous
()
sin
=
sin
.
contiguous
()
torch
.
ops
.
vllm
.
rms_mrope_fuse
(
q
,
k
,
cos
,
sin
,
self
.
head_dim
,
self
.
rotary_emb
.
rotary_dim
,
mrope_section
[
0
],
mrope_section
[
1
],
mrope_section
[
2
],
self
.
rotary_emb
.
mrope_interleaved
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
self
.
q_norm
.
variance_epsilon
,
None
,
None
,
)
else
:
# Add qk-norm
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
ef79626d
...
...
@@ -96,7 +96,6 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils.torch_utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
...
...
@@ -361,122 +360,6 @@ class Qwen3MoeAttention(nn.Module):
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
def
rms_rotary_embedding_fuse
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
|
None
,
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
epsilon
:
float
,
q_bias
:
torch
.
Tensor
|
None
=
None
,
k_bias
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
from
lightop
import
rms_rotary_embedding_fuse
as
fused_kernel
fused_kernel
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox_style
,
q_weight
,
k_weight
,
q_bias
,
k_bias
,
epsilon
,
)
def
rms_rotary_embedding_fuse_fake
(
# q_out:torch.Tensor,
# k_out:torch.Tensor,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
|
None
,
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox_style
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
epsilon
:
float
,
q_bias
:
torch
.
Tensor
|
None
=
None
,
k_bias
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
direct_register_custom_op
(
op_name
=
"rms_rotary_embedding_fuse"
,
op_func
=
rms_rotary_embedding_fuse
,
mutates_args
=
[
"query"
,
"key"
],
fake_impl
=
rms_rotary_embedding_fuse_fake
,
)
def
rms_mrope_fuse
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
head_size
:
int
,
rotary_dim
:
int
,
mrope_section_t
:
int
,
mrope_section_h
:
int
,
mrope_section_w
:
int
,
mrope_interleaved
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
epsilon
:
float
,
q_residual
:
torch
.
Tensor
|
None
=
None
,
k_residual
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
from
lightop
import
op
as
lightop_ops
lightop_ops
.
fuse_rms_mrope_cuda
(
query
,
key
,
cos
,
sin
,
[
mrope_section_t
,
mrope_section_h
,
mrope_section_w
],
head_size
,
rotary_dim
,
mrope_interleaved
,
q_weight
,
k_weight
,
q_residual
,
k_residual
,
epsilon
,
)
def
rms_mrope_fuse_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
head_size
:
int
,
rotary_dim
:
int
,
mrope_section_t
:
int
,
mrope_section_h
:
int
,
mrope_section_w
:
int
,
mrope_interleaved
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
epsilon
:
float
,
q_residual
:
torch
.
Tensor
|
None
=
None
,
k_residual
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
direct_register_custom_op
(
op_name
=
"rms_mrope_fuse"
,
op_func
=
rms_mrope_fuse
,
mutates_args
=
[
"query"
,
"key"
],
fake_impl
=
rms_mrope_fuse_fake
,
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment