Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b76753f0
Unverified
Commit
b76753f0
authored
Aug 11, 2025
by
Isotr0py
Committed by
GitHub
Aug 10, 2025
Browse files
[Bugfix][Kernel] Support partial rotary embedding for MRoPE triton kernel (#22593)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
b81fe83b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
18 deletions
+30
-18
tests/kernels/core/test_mrope.py
tests/kernels/core/test_mrope.py
+14
-6
vllm/model_executor/layers/rotary_embedding/mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py
+16
-12
No files found.
tests/kernels/test_mrope.py
→
tests/kernels/
core/
test_mrope.py
View file @
b76753f0
...
...
@@ -42,12 +42,13 @@ def unroll_model_tp_dict(model_tp_dict):
model_tp_dict
=
{
"Qwen/Qwen2-VL-7B-Instruct"
:
[
1
,
2
],
"Qwen/Qwen2-VL-72B-Instruct"
:
[
1
,
2
],
"Qwen/Qwen2.5-VL-72B-Instruct"
:
[
1
,
2
]
"Qwen/Qwen2.5-VL-72B-Instruct"
:
[
1
,
2
],
"zai-org/GLM-4.1V-9B-Thinking"
:
[
1
,
2
],
}
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
dtype_atol_rtol_list
=
[
[
torch
.
bfloat16
,
1e-
5
,
1.6e-2
],
[
torch
.
bfloat16
,
1e-
2
,
1.6e-2
],
]
num_tokens_list
=
[
11
,
8192
]
...
...
@@ -73,10 +74,12 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
rope_theta
=
config
.
rope_theta
max_position
=
config
.
max_position_embeddings
partial_rotary_factor
=
getattr
(
config
,
"partial_rotary_factor"
,
1.0
)
rotary_dim
=
int
(
head_dim
*
partial_rotary_factor
)
mrope_helper_class
=
get_rope
(
head_size
=
head_dim
,
rotary_dim
=
head
_dim
,
rotary_dim
=
rotary
_dim
,
max_position
=
max_position
,
base
=
rope_theta
,
is_neox_style
=
is_neox_style
,
...
...
@@ -110,7 +113,10 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
reason
=
"Skipping CUDA/ROCm only tests."
)
@
pytest
.
mark
.
parametrize
(
"model_name, tp_size"
,
unroll_model_tp_dict
({
"Qwen/Qwen2-VL-7B-Instruct"
:
[
1
,
2
]}))
unroll_model_tp_dict
({
"Qwen/Qwen2-VL-7B-Instruct"
:
[
1
,
2
],
"zai-org/GLM-4.1V-9B-Thinking"
:
[
1
,
2
]
}))
@
pytest
.
mark
.
parametrize
(
"dtype, atol, rtol"
,
dtype_atol_rtol_list
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
4
])
def
test_mrope_torch_compile_tracing
(
model_name
,
tp_size
,
dtype
,
atol
,
rtol
,
...
...
@@ -126,10 +132,12 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
is_neox_style
=
True
rope_theta
=
config
.
rope_theta
max_position
=
config
.
max_position_embeddings
partial_rotary_factor
=
getattr
(
config
,
"partial_rotary_factor"
,
1.0
)
rotary_dim
=
int
(
head_dim
*
partial_rotary_factor
)
mrope_helper_class
=
get_rope
(
head_size
=
head_dim
,
rotary_dim
=
head
_dim
,
rotary_dim
=
rotary
_dim
,
max_position
=
max_position
,
base
=
rope_theta
,
is_neox_style
=
is_neox_style
,
...
...
@@ -145,7 +153,7 @@ def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
# Create a wrapper that makes the in-place function appear functional
def
functional_forward_cuda
(
pos
,
q
,
k
):
"""Wrapper that converts in-place operation to functional style
CUDA Graph does not support in-place operations.
This wrapper creates working copies of the
input tensors and modifies them.
...
...
vllm/model_executor/layers/rotary_embedding/mrope.py
View file @
b76753f0
...
...
@@ -25,6 +25,7 @@ def _triton_qwen2vl_mrope_forward(
n_qh
:
tl
.
constexpr
,
n_kh
:
tl
.
constexpr
,
hd
:
tl
.
constexpr
,
rd
:
tl
.
constexpr
,
pad_n_qh
:
tl
.
constexpr
,
pad_n_kh
:
tl
.
constexpr
,
pad_hd
:
tl
.
constexpr
,
...
...
@@ -51,19 +52,19 @@ def _triton_qwen2vl_mrope_forward(
h_end
=
t_end
+
mrope_section_h
# Updated stride calculation for half head_dim
half_
h
d
=
h
d
//
2
t_cos
=
cos
+
pid
*
half_
h
d
h_cos
=
t_cos
+
num_tokens
*
half_
h
d
w_cos
=
h_cos
+
num_tokens
*
half_
h
d
t_sin
=
sin
+
pid
*
half_
h
d
h_sin
=
t_sin
+
num_tokens
*
half_
h
d
w_sin
=
h_sin
+
num_tokens
*
half_
h
d
half_
r
d
=
r
d
//
2
t_cos
=
cos
+
pid
*
half_
r
d
h_cos
=
t_cos
+
num_tokens
*
half_
r
d
w_cos
=
h_cos
+
num_tokens
*
half_
r
d
t_sin
=
sin
+
pid
*
half_
r
d
h_sin
=
t_sin
+
num_tokens
*
half_
r
d
w_sin
=
h_sin
+
num_tokens
*
half_
r
d
# Updated offsets for half head_dim
cos_offsets
=
tl
.
arange
(
0
,
pad_hd
//
2
)
t_mask
=
cos_offsets
<
t_end
h_mask
=
(
t_end
<=
cos_offsets
)
&
(
cos_offsets
<
h_end
)
w_mask
=
(
h_end
<=
cos_offsets
)
&
(
cos_offsets
<
half_
h
d
)
w_mask
=
(
h_end
<=
cos_offsets
)
&
(
cos_offsets
<
half_
r
d
)
t_cos_row
=
tl
.
load
(
t_cos
+
cos_offsets
,
mask
=
t_mask
,
other
=
0
)
h_cos_row
=
tl
.
load
(
h_cos
+
cos_offsets
,
mask
=
h_mask
,
other
=
0
)
...
...
@@ -85,9 +86,9 @@ def _triton_qwen2vl_mrope_forward(
first_half_k_offsets
=
tl
.
arange
(
0
,
pad_n_kh
)[:,
None
]
*
hd
+
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
first_q_mask
=
(
tl
.
arange
(
0
,
pad_n_qh
)[:,
None
]
<
n_qh
)
&
(
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
<
h
d
//
2
)
0
,
pad_hd
//
2
)[
None
,
:]
<
r
d
//
2
)
first_k_mask
=
(
tl
.
arange
(
0
,
pad_n_kh
)[:,
None
]
<
n_kh
)
&
(
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
<
h
d
//
2
)
0
,
pad_hd
//
2
)[
None
,
:]
<
r
d
//
2
)
q_tile_1
=
tl
.
load
(
q_ptr
+
first_half_q_offsets
,
mask
=
first_q_mask
,
...
...
@@ -97,8 +98,8 @@ def _triton_qwen2vl_mrope_forward(
other
=
0
).
to
(
sin_row
.
dtype
)
# right half of the head
second_half_q_offsets
=
first_half_q_offsets
+
(
h
d
//
2
)
second_half_k_offsets
=
first_half_k_offsets
+
(
h
d
//
2
)
second_half_q_offsets
=
first_half_q_offsets
+
(
r
d
//
2
)
second_half_k_offsets
=
first_half_k_offsets
+
(
r
d
//
2
)
second_q_mask
=
first_q_mask
second_k_mask
=
first_k_mask
...
...
@@ -130,6 +131,7 @@ def triton_mrope(
sin
:
torch
.
Tensor
,
mrope_section
:
list
[
int
],
head_size
:
int
,
rotary_dim
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Qwen2VL mrope kernel.
...
...
@@ -166,6 +168,7 @@ def triton_mrope(
n_q_head
,
n_kv_head
,
head_size
,
rotary_dim
,
pad_n_q_head
,
pad_n_kv_head
,
pad_hd
,
...
...
@@ -300,6 +303,7 @@ class MRotaryEmbedding(RotaryEmbedding):
sin
,
self
.
mrope_section
,
self
.
head_size
,
self
.
rotary_dim
,
)
return
q
.
reshape
(
query_shape
),
k
.
reshape
(
key_shape
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment