Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
04d429f6
Commit
04d429f6
authored
Mar 18, 2026
by
guanyu1
Browse files
qwen3.py合入fused_morpe
parent
7676d0c9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
201 additions
and
80 deletions
+201
-80
vllm/_custom_ops.py
vllm/_custom_ops.py
+120
-0
vllm/model_executor/models/qwen3.py
vllm/model_executor/models/qwen3.py
+81
-80
No files found.
vllm/_custom_ops.py
View file @
04d429f6
...
...
@@ -3628,4 +3628,124 @@ direct_register_custom_op(
op_func
=
fused_add_rms_norm_opt
,
mutates_args
=
[],
fake_impl
=
fused_add_rms_norm_opt_fake
,
)
"""
Changes to the LLM part of qwen3-vl-8b: fused rms + mrope, dim == 1 (2026/03/18)
"""
def rms_rotary_embedding_fuse(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Forward every argument, unchanged and in order, to the lightop kernel.

    This is a thin adapter around ``lightop.rms_rotary_embedding_fuse``; it
    adds no logic of its own and returns nothing. The op is registered in
    this file with ``mutates_args=["query", "key"]``, so the kernel is
    presumably expected to write its result into ``query`` / ``key`` in
    place -- NOTE(review): confirm against the lightop kernel.

    Args:
        positions: Position tensor handed to the kernel.
        query: Query tensor (declared as mutated by the op registration).
        key: Key tensor, or ``None`` (declared as mutated as well).
        head_size: Per-head size passed through to the kernel.
        cos_sin_cache: Rotary cos/sin cache passed through to the kernel.
        is_neox_style: Rotary layout flag passed through to the kernel.
        q_weight: Weight applied on the query side.
        k_weight: Weight applied on the key side.
        q_bias: Optional query-side bias, or ``None``.
        k_bias: Optional key-side bias, or ``None``.
        epsilon: Epsilon value passed through to the kernel.
    """
    # Imported lazily so that merely importing this module does not require
    # the lightop package to be installed.
    from lightop import rms_rotary_embedding_fuse as _lightop_kernel

    kernel_args = (
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox_style,
        q_weight,
        k_weight,
        q_bias,
        k_bias,
        epsilon,
    )
    _lightop_kernel(*kernel_args)
def rms_rotary_embedding_fuse_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Fake implementation of ``rms_rotary_embedding_fuse`` for graph tracing.

    Mirrors the signature of the real op but deliberately does nothing:
    no argument is read or modified and ``None`` is returned.
    """
    # Intentionally a no-op; tracing modes only need a callable with the
    # same signature as the real op.
    return None
# Expose the wrapper as ``torch.ops.vllm.rms_rotary_embedding_fuse``.
# The ``hasattr`` guard skips registration when an op with this name is
# already present in the ``vllm`` namespace.
# ``query`` and ``key`` are declared as mutated arguments; the fake impl is
# the no-op used by graph tracing modes.
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
    direct_register_custom_op(
        op_name="rms_rotary_embedding_fuse",
        op_func=rms_rotary_embedding_fuse,
        mutates_args=["query", "key"],
        fake_impl=rms_rotary_embedding_fuse_fake,
    )
"""
Changes to the LLM model of qwen3-vl-8b: fused rms + mrope, dim == 2 (2026/03/18)
"""
def
rms_mrope_fuse_fake
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
head_size
:
int
,
rotary_dim
:
int
,
mrope_section_t
:
int
,
mrope_section_h
:
int
,
mrope_section_w
:
int
,
mrope_interleaved
:
bool
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
epsilon
:
float
,
q_residual
:
torch
.
Tensor
|
None
=
None
,
k_residual
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
def rms_mrope_fuse(
    query: torch.Tensor,
    key: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    head_size: int,
    rotary_dim: int,
    mrope_section_t: int,
    mrope_section_h: int,
    mrope_section_w: int,
    mrope_interleaved: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    epsilon: float,
    q_residual: torch.Tensor | None = None,
    k_residual: torch.Tensor | None = None,
) -> None:
    """Call lightop's ``fuse_rms_mrope_cuda`` kernel with these arguments.

    The three section sizes arrive as separate ints and are regrouped into a
    single ``[t, h, w]`` list for the kernel. Note that the kernel's argument
    order differs from this wrapper's: the section list comes right after
    ``sin``, and ``epsilon`` goes last, after the two residuals.

    Nothing is returned. The op is registered in this file with
    ``mutates_args=["query", "key"]``, so the kernel is presumably expected
    to update ``query`` / ``key`` in place -- NOTE(review): confirm against
    the lightop kernel.

    Args:
        query: Query tensor (declared as mutated by the op registration).
        key: Key tensor (declared as mutated by the op registration).
        cos: Cosine values passed through to the kernel.
        sin: Sine values passed through to the kernel.
        head_size: Per-head size passed through to the kernel.
        rotary_dim: Rotary dimension passed through to the kernel.
        mrope_section_t: First entry of the section list.
        mrope_section_h: Second entry of the section list.
        mrope_section_w: Third entry of the section list.
        mrope_interleaved: Interleaving flag passed through to the kernel.
        q_weight: Weight applied on the query side.
        k_weight: Weight applied on the key side.
        epsilon: Epsilon value passed through to the kernel.
        q_residual: Optional query-side residual, ``None`` by default.
        k_residual: Optional key-side residual, ``None`` by default.
    """
    # Imported lazily so that merely importing this module does not require
    # the lightop package to be installed.
    from lightop import op as _lightop_op

    section_sizes = [mrope_section_t, mrope_section_h, mrope_section_w]
    _lightop_op.fuse_rms_mrope_cuda(
        query,
        key,
        cos,
        sin,
        section_sizes,
        head_size,
        rotary_dim,
        mrope_interleaved,
        q_weight,
        k_weight,
        q_residual,
        k_residual,
        epsilon,
    )
# Expose the wrapper as ``torch.ops.vllm.rms_mrope_fuse``.
# Guarded the same way as the ``rms_rotary_embedding_fuse`` registration in
# this file: if an op with this name is already present in the ``vllm``
# namespace (e.g. the module body is executed a second time), registration
# is skipped instead of being attempted twice.
# ``query`` and ``key`` are declared as mutated arguments; the fake impl is
# the no-op used by graph tracing modes.
if not hasattr(torch.ops.vllm, "rms_mrope_fuse"):
    direct_register_custom_op(
        op_name="rms_mrope_fuse",
        op_func=rms_mrope_fuse,
        mutates_args=["query", "key"],
        fake_impl=rms_mrope_fuse_fake,
    )
\ No newline at end of file
vllm/model_executor/models/qwen3.py
View file @
04d429f6
...
...
@@ -51,8 +51,7 @@ from .qwen2 import Qwen2Model
from
.utils
import
AutoWeightsLoader
,
PPMissingLayer
,
extract_layer_index
,
maybe_prefix
import
vllm.envs
as
envs
from
vllm.utils
import
direct_register_custom_op
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
...
...
@@ -137,58 +136,6 @@ class Qwen3Attention(nn.Module):
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
rms_norm_eps
)
def rms_rotary_embedding_fuse(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Forward every argument, unchanged and in order, to the lightop kernel.

    Thin adapter around ``lightop.rms_rotary_embedding_fuse``; it adds no
    logic of its own and returns nothing. The accompanying registration
    declares ``query`` and ``key`` as mutated arguments, so the kernel is
    presumably expected to write into them in place -- NOTE(review):
    confirm against the lightop kernel.
    """
    # Imported lazily so that merely importing this module does not require
    # the lightop package to be installed.
    from lightop import rms_rotary_embedding_fuse as _lightop_kernel

    kernel_args = (
        positions,
        query,
        key,
        head_size,
        cos_sin_cache,
        is_neox_style,
        q_weight,
        k_weight,
        q_bias,
        k_bias,
        epsilon,
    )
    _lightop_kernel(*kernel_args)
def rms_rotary_embedding_fuse_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_bias: Optional[torch.Tensor],
    k_bias: Optional[torch.Tensor],
    epsilon: float,
) -> None:
    """Fake implementation of ``rms_rotary_embedding_fuse`` for graph tracing.

    Mirrors the signature of the real op but deliberately does nothing:
    no argument is read or modified and ``None`` is returned.
    """
    # Intentionally a no-op; tracing modes only need a callable with the
    # same signature as the real op.
    return None
# Expose the wrapper as ``torch.ops.vllm.rms_rotary_embedding_fuse``.
# The ``hasattr`` guard skips registration when an op with this name is
# already present in the ``vllm`` namespace.
# ``query`` and ``key`` are declared as mutated arguments; the fake impl is
# the no-op used by graph tracing modes.
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
    direct_register_custom_op(
        op_name="rms_rotary_embedding_fuse",
        op_func=rms_rotary_embedding_fuse,
        mutates_args=["query", "key"],
        fake_impl=rms_rotary_embedding_fuse_fake,
    )
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
@@ -196,33 +143,87 @@ class Qwen3Attention(nn.Module):
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
used_fused
=
False
if
envs
.
VLLM_USE_FUSED_RMS_ROPE
and
positions
.
ndim
==
1
:
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
q
.
device
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
torch
.
ops
.
vllm
.
rms_rotary_embedding_fuse
(
positions
,
q
,
k
,
self
.
head_dim
,
cos_sin_cache
,
self
.
rotary_emb
.
is_neox_style
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
None
,
None
,
self
.
q_norm
.
variance_epsilon
,
)
else
:
if
hasattr
(
torch
.
ops
.
vllm
,
"rms_rotary_embedding_fuse"
):
# Fused RMSNorm + RoPE path through custom op.
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
q
.
device
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
torch
.
ops
.
vllm
.
rms_rotary_embedding_fuse
(
positions
,
q
,
k
,
self
.
head_dim
,
cos_sin_cache
,
self
.
rotary_emb
.
is_neox_style
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
None
,
None
,
self
.
q_norm
.
variance_epsilon
,
)
used_fused
=
True
else
:
logger
.
warning_once
(
"VLLM_USE_FUSED_RMS_ROPE is enabled and positions.ndim == 1, "
"but the RoPE fused op is unavailable; falling back to the "
"default RMSNorm + RoPE path."
)
elif
envs
.
VLLM_USE_FUSED_RMS_ROPE
and
positions
.
ndim
==
2
:
mrope_section
=
getattr
(
self
.
rotary_emb
,
"mrope_section"
,
None
)
if
mrope_section
is
not
None
and
hasattr
(
torch
.
ops
.
vllm
,
"rms_mrope_fuse"
):
# Fused RMSNorm + M-RoPE path through custom op.
cos_sin_cache
=
self
.
rotary_emb
.
cos_sin_cache
if
(
cos_sin_cache
.
device
!=
q
.
device
or
cos_sin_cache
.
dtype
!=
q
.
dtype
):
cos_sin_cache
=
cos_sin_cache
.
to
(
q
.
device
,
dtype
=
q
.
dtype
,
non_blocking
=
True
)
self
.
rotary_emb
.
cos_sin_cache
=
cos_sin_cache
cos_sin
=
cos_sin_cache
[
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
q
=
q
.
contiguous
()
k
=
k
.
contiguous
()
cos
=
cos
.
contiguous
()
sin
=
sin
.
contiguous
()
assert
len
(
mrope_section
)
==
3
torch
.
ops
.
vllm
.
rms_mrope_fuse
(
q
,
k
,
cos
,
sin
,
self
.
head_dim
,
self
.
rotary_emb
.
rotary_dim
,
mrope_section
[
0
],
mrope_section
[
1
],
mrope_section
[
2
],
self
.
rotary_emb
.
mrope_interleaved
,
self
.
q_norm
.
weight
,
self
.
k_norm
.
weight
,
self
.
q_norm
.
variance_epsilon
,
None
,
None
,
)
used_fused
=
True
else
:
logger
.
warning_once
(
"VLLM_USE_FUSED_RMS_ROPE is enabled and positions.ndim == 2, "
"but the M-RoPE fused op is unavailable; falling back to the "
"default RMSNorm + RoPE path."
)
if
not
used_fused
:
# Add qk-norm
q_by_head
=
q
.
view
(
*
q
.
shape
[:
-
1
],
q
.
shape
[
-
1
]
//
self
.
head_dim
,
self
.
head_dim
)
if
envs
.
VLLM_USE_APEX_RN
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment