Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cea91a32
Unverified
Commit
cea91a32
authored
Sep 19, 2025
by
Isotr0py
Committed by
GitHub
Sep 19, 2025
Browse files
[Kernel][Performance] Add Triton kernel for Qwen3-VL interleaved MRoPE (#25055)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
a684c012
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
88 additions
and
46 deletions
+88
-46
tests/kernels/core/test_mrope.py
tests/kernels/core/test_mrope.py
+66
-32
vllm/model_executor/layers/rotary_embedding/mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py
+22
-14
No files found.
tests/kernels/core/test_mrope.py
View file @
cea91a32
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
NamedTuple
import
pytest
import
torch
from
packaging.version
import
Version
from
transformers
import
AutoConfig
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.platforms
import
current_platform
...
...
@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
head_size
:
int
,
max_position_embeddings
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
"""Generate test data for given configuration."""
current_platform
.
seed_everything
(
42
)
# Create 2D positions (3, num_tokens) for multimodal case
positions
=
torch
.
randint
(
0
,
max_position_embeddings
//
4
,
(
3
,
num_tokens
),
...
...
@@ -33,22 +37,37 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
return
positions
,
query
,
key
def
unroll_model_tp_dict
(
model_tp_dict
):
return
[(
model_name
,
tp_size
)
for
model_name
,
tp_sizes
in
model_tp_dict
.
items
()
for
tp_size
in
tp_sizes
]
model_tp_dict
=
{
"Qwen/Qwen2-VL-7B-Instruct"
:
[
1
,
2
],
"Qwen/Qwen2-VL-72B-Instruct"
:
[
1
,
2
],
"Qwen/Qwen2.5-VL-72B-Instruct"
:
[
1
,
2
],
"zai-org/GLM-4.1V-9B-Thinking"
:
[
1
,
2
],
}
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
dtype_atol_rtol_list
=
[
[
torch
.
bfloat16
,
1e-2
,
1.6e-2
],
class
MRoPETestInfo
(
NamedTuple
):
model_name
:
str
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
atol
:
float
=
1e-2
rtol
:
float
=
1.6e-2
marks
:
list
[
pytest
.
MarkDecorator
]
=
[]
TRANSFORMERS_BASE_VERSION
=
Version
(
TRANSFORMERS_VERSION
).
base_version
MODELS_TO_TEST
=
[
MRoPETestInfo
(
model_name
=
"zai-org/GLM-4.1V-9B-Thinking"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen2-VL-7B-Instruct"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen2-VL-72B-Instruct"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen2.5-VL-72B-Instruct"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen3-VL-4B-Instruct"
,
marks
=
[
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_BASE_VERSION
)
<
Version
(
"4.57.0"
),
reason
=
"Qwen3-VL only available after Transformers v4.57"
,
)
]),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen3-VL-30B-A3B-Instruct"
,
marks
=
[
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_BASE_VERSION
)
<
Version
(
"4.57.0"
),
reason
=
"Qwen3-VL only available after Transformers v4.57"
,
)
]),
]
num_tokens_list
=
[
11
,
8192
]
...
...
@@ -56,20 +75,29 @@ num_tokens_list = [11, 8192]
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Skipping CUDA/ROCm only tests."
)
@
pytest
.
mark
.
parametrize
(
"model_name, tp_size"
,
unroll_model_tp_dict
(
model_tp_dict
))
@
pytest
.
mark
.
parametrize
(
"dtype, atol, rtol"
,
dtype_atol_rtol_list
)
@
pytest
.
mark
.
parametrize
(
"model_info, model_name"
,
[
pytest
.
param
(
test_config
,
test_config
.
model_name
,
marks
=
test_config
.
marks
)
for
test_config
in
MODELS_TO_TEST
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
num_tokens_list
)
def
test_mrope
(
model_name
,
tp_size
,
dtype
,
atol
,
rtol
,
num_tokens
):
def
test_mrope
(
model_name
:
str
,
model_info
:
MRoPETestInfo
,
tp_size
:
int
,
dtype
:
torch
.
dtype
,
num_tokens
:
int
):
atol
=
model_info
.
atol
rtol
=
model_info
.
rtol
config
=
AutoConfig
.
from_pretrained
(
model_name
)
config
=
config
.
get_text_config
()
# get the model config
total_num_kv_heads
=
config
.
num_key_value_heads
total_num_heads
=
config
.
num_attention_heads
num_heads
=
total_num_heads
//
tp_size
num_kv_heads
=
max
(
1
,
total_num_kv_heads
//
tp_size
)
head_dim
=
config
.
hidden_size
//
total_num_heads
head_dim
=
(
config
.
head_dim
if
hasattr
(
config
,
"head_dim"
)
else
config
.
hidden_size
//
total_num_heads
)
is_neox_style
=
True
rope_theta
=
config
.
rope_theta
...
...
@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Skipping CUDA/ROCm only tests."
)
@
pytest
.
mark
.
parametrize
(
"model_name, tp_size"
,
unroll_model_tp_dict
({
"Qwen/Qwen2-VL-7B-Instruct"
:
[
1
,
2
],
"zai-org/GLM-4.1V-9B-Thinking"
:
[
1
,
2
]
}))
@
pytest
.
mark
.
parametrize
(
"dtype, atol, rtol"
,
dtype_atol_rtol_list
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
4
])
def
test_mrope_torch_compile_tracing
(
model_name
,
tp_size
,
dtype
,
atol
,
rtol
,
num_tokens
):
@
pytest
.
mark
.
parametrize
(
"model_info, model_name"
,
[
pytest
.
param
(
test_config
,
test_config
.
model_name
,
marks
=
test_config
.
marks
)
for
test_config
in
MODELS_TO_TEST
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
num_tokens_list
)
def
test_mrope_torch_compile_tracing
(
model_name
:
str
,
model_info
:
MRoPETestInfo
,
tp_size
:
int
,
dtype
:
torch
.
dtype
,
num_tokens
:
int
):
atol
=
model_info
.
atol
rtol
=
model_info
.
rtol
config
=
AutoConfig
.
from_pretrained
(
model_name
)
config
=
config
.
get_text_config
()
# get the model config
total_num_kv_heads
=
config
.
num_key_value_heads
total_num_heads
=
config
.
num_attention_heads
num_heads
=
total_num_heads
//
tp_size
num_kv_heads
=
max
(
1
,
total_num_kv_heads
//
tp_size
)
head_dim
=
config
.
hidden_size
//
total_num_heads
head_dim
=
(
config
.
head_dim
if
hasattr
(
config
,
"head_dim"
)
else
config
.
hidden_size
//
total_num_heads
)
is_neox_style
=
True
rope_theta
=
config
.
rope_theta
max_position
=
config
.
max_position_embeddings
...
...
vllm/model_executor/layers/rotary_embedding/mrope.py
View file @
cea91a32
...
...
@@ -15,7 +15,7 @@ from .common import apply_rotary_emb_dispatch
@
triton
.
jit
def
_triton_
qwen2vl_
mrope_forward
(
def
_triton_mrope_forward
(
q_ptr
,
k_ptr
,
cos
,
...
...
@@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward(
pad_hd
:
tl
.
constexpr
,
mrope_section_t
:
tl
.
constexpr
,
mrope_section_h
:
tl
.
constexpr
,
mrope_section_w
:
tl
.
constexpr
,
is_interleaved
:
tl
.
constexpr
,
):
# Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
# This version supports flatten input tensors from vllm
# and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
# instead of (3, bsz, seq_len, head_dim)
# instead of (3, bsz, seq_len, head_dim)
, also supports interleaved rotary
pid
=
tl
.
program_id
(
0
)
# locate start address
q_ptr
=
q_ptr
+
pid
*
(
n_qh
*
hd
)
...
...
@@ -47,9 +49,6 @@ def _triton_qwen2vl_mrope_forward(
# ####################################################################
# Note: cos and sin now have shape (3, num_tokens, head_dim // 2)
t_end
=
mrope_section_t
h_end
=
t_end
+
mrope_section_h
# Updated stride calculation for half head_dim
half_rd
=
rd
//
2
t_cos
=
cos
+
pid
*
half_rd
...
...
@@ -61,9 +60,18 @@ def _triton_qwen2vl_mrope_forward(
# Updated offsets for half head_dim
cos_offsets
=
tl
.
arange
(
0
,
pad_hd
//
2
)
t_mask
=
cos_offsets
<
t_end
h_mask
=
(
t_end
<=
cos_offsets
)
&
(
cos_offsets
<
h_end
)
w_mask
=
(
h_end
<=
cos_offsets
)
&
(
cos_offsets
<
half_rd
)
if
is_interleaved
:
h_mask
=
(((
cos_offsets
%
3
)
==
1
)
&
(
cos_offsets
<=
3
*
mrope_section_h
))
w_mask
=
(((
cos_offsets
%
3
)
==
2
)
&
(
cos_offsets
<=
3
*
mrope_section_w
))
t_mask
=
~
(
h_mask
|
w_mask
)
else
:
t_end
=
mrope_section_t
h_end
=
t_end
+
mrope_section_h
t_mask
=
cos_offsets
<
mrope_section_t
h_mask
=
(
t_end
<=
cos_offsets
)
&
(
cos_offsets
<
h_end
)
w_mask
=
(
h_end
<=
cos_offsets
)
&
(
cos_offsets
<
half_rd
)
t_cos_row
=
tl
.
load
(
t_cos
+
cos_offsets
,
mask
=
t_mask
,
other
=
0
)
h_cos_row
=
tl
.
load
(
h_cos
+
cos_offsets
,
mask
=
h_mask
,
other
=
0
)
...
...
@@ -131,6 +139,7 @@ def triton_mrope(
mrope_section
:
list
[
int
],
head_size
:
int
,
rotary_dim
:
int
,
mrope_interleaved
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Qwen2VL mrope kernel.
...
...
@@ -158,7 +167,7 @@ def triton_mrope(
cos
=
cos
.
contiguous
()
sin
=
sin
.
contiguous
()
_triton_
qwen2vl_
mrope_forward
[(
n_row
,
)](
_triton_mrope_forward
[(
n_row
,
)](
q
,
k
,
cos
,
...
...
@@ -173,6 +182,8 @@ def triton_mrope(
pad_hd
,
mrope_section
[
0
],
mrope_section
[
1
],
mrope_section
[
2
],
mrope_interleaved
,
)
return
q
,
k
...
...
@@ -201,7 +212,7 @@ class MRotaryEmbedding(RotaryEmbedding):
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
mrope_section
:
Optional
[
list
[
int
]]
=
None
,
mrope_interleaved
:
Optional
[
bool
]
=
False
,
mrope_interleaved
:
bool
=
False
,
)
->
None
:
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
...
...
@@ -282,10 +293,6 @@ class MRotaryEmbedding(RotaryEmbedding):
assert
positions
.
ndim
==
1
or
positions
.
ndim
==
2
assert
key
is
not
None
if
self
.
mrope_interleaved
:
# TODO: add triton implementation to support mrope-interleaved
return
self
.
forward_native
(
positions
,
query
,
key
)
num_tokens
=
positions
.
shape
[
-
1
]
cos_sin
=
self
.
cos_sin_cache
[
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
...
...
@@ -302,6 +309,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self
.
mrope_section
,
self
.
head_size
,
self
.
rotary_dim
,
self
.
mrope_interleaved
,
)
return
q
.
reshape
(
query_shape
),
k
.
reshape
(
key_shape
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment