Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cea91a32
Unverified
Commit
cea91a32
authored
Sep 19, 2025
by
Isotr0py
Committed by
GitHub
Sep 19, 2025
Browse files
[Kernel][Performance] Add Triton kernel for Qwen3-VL interleaved MRoPE (#25055)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
a684c012
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
88 additions
and
46 deletions
+88
-46
tests/kernels/core/test_mrope.py
tests/kernels/core/test_mrope.py
+66
-32
vllm/model_executor/layers/rotary_embedding/mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py
+22
-14
No files found.
tests/kernels/core/test_mrope.py
View file @
cea91a32
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
NamedTuple
import
pytest
import
pytest
import
torch
import
torch
from
packaging.version
import
Version
from
transformers
import
AutoConfig
from
transformers
import
AutoConfig
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
...
@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
head_size
:
int
,
max_position_embeddings
:
int
,
head_size
:
int
,
max_position_embeddings
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
"""Generate test data for given configuration."""
"""Generate test data for given configuration."""
current_platform
.
seed_everything
(
42
)
# Create 2D positions (3, num_tokens) for multimodal case
# Create 2D positions (3, num_tokens) for multimodal case
positions
=
torch
.
randint
(
0
,
positions
=
torch
.
randint
(
0
,
max_position_embeddings
//
4
,
(
3
,
num_tokens
),
max_position_embeddings
//
4
,
(
3
,
num_tokens
),
...
@@ -33,22 +37,37 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
...
@@ -33,22 +37,37 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
return
positions
,
query
,
key
return
positions
,
query
,
key
def
unroll_model_tp_dict
(
model_tp_dict
):
class
MRoPETestInfo
(
NamedTuple
):
return
[(
model_name
,
tp_size
)
model_name
:
str
for
model_name
,
tp_sizes
in
model_tp_dict
.
items
()
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
for
tp_size
in
tp_sizes
]
atol
:
float
=
1e-2
rtol
:
float
=
1.6e-2
marks
:
list
[
pytest
.
MarkDecorator
]
=
[]
model_tp_dict
=
{
TRANSFORMERS_BASE_VERSION
=
Version
(
TRANSFORMERS_VERSION
).
base_version
"Qwen/Qwen2-VL-7B-Instruct"
:
[
1
,
2
],
"Qwen/Qwen2-VL-72B-Instruct"
:
[
1
,
2
],
"Qwen/Qwen2.5-VL-72B-Instruct"
:
[
1
,
2
],
"zai-org/GLM-4.1V-9B-Thinking"
:
[
1
,
2
],
}
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
MODELS_TO_TEST
=
[
dtype_atol_rtol_list
=
[
MRoPETestInfo
(
model_name
=
"zai-org/GLM-4.1V-9B-Thinking"
),
[
torch
.
bfloat16
,
1e-2
,
1.6e-2
],
MRoPETestInfo
(
model_name
=
"Qwen/Qwen2-VL-7B-Instruct"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen2-VL-72B-Instruct"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen2.5-VL-72B-Instruct"
),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen3-VL-4B-Instruct"
,
marks
=
[
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_BASE_VERSION
)
<
Version
(
"4.57.0"
),
reason
=
"Qwen3-VL only available after Transformers v4.57"
,
)
]),
MRoPETestInfo
(
model_name
=
"Qwen/Qwen3-VL-30B-A3B-Instruct"
,
marks
=
[
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_BASE_VERSION
)
<
Version
(
"4.57.0"
),
reason
=
"Qwen3-VL only available after Transformers v4.57"
,
)
]),
]
]
num_tokens_list
=
[
11
,
8192
]
num_tokens_list
=
[
11
,
8192
]
...
@@ -56,20 +75,29 @@ num_tokens_list = [11, 8192]
...
@@ -56,20 +75,29 @@ num_tokens_list = [11, 8192]
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Skipping CUDA/ROCm only tests."
)
reason
=
"Skipping CUDA/ROCm only tests."
)
@
pytest
.
mark
.
parametrize
(
"model_name, tp_size"
,
@
pytest
.
mark
.
parametrize
(
"model_info, model_name"
,
[
unroll_model_tp_dict
(
model_tp_dict
))
pytest
.
param
(
test_config
,
test_config
.
model_name
,
marks
=
test_config
.
marks
)
@
pytest
.
mark
.
parametrize
(
"dtype, atol, rtol"
,
dtype_atol_rtol_list
)
for
test_config
in
MODELS_TO_TEST
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
num_tokens_list
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
num_tokens_list
)
def
test_mrope
(
model_name
,
tp_size
,
dtype
,
atol
,
rtol
,
num_tokens
):
def
test_mrope
(
model_name
:
str
,
model_info
:
MRoPETestInfo
,
tp_size
:
int
,
dtype
:
torch
.
dtype
,
num_tokens
:
int
):
atol
=
model_info
.
atol
rtol
=
model_info
.
rtol
config
=
AutoConfig
.
from_pretrained
(
model_name
)
config
=
AutoConfig
.
from_pretrained
(
model_name
)
config
=
config
.
get_text_config
()
# get the model config
# get the model config
total_num_kv_heads
=
config
.
num_key_value_heads
total_num_kv_heads
=
config
.
num_key_value_heads
total_num_heads
=
config
.
num_attention_heads
total_num_heads
=
config
.
num_attention_heads
num_heads
=
total_num_heads
//
tp_size
num_heads
=
total_num_heads
//
tp_size
num_kv_heads
=
max
(
1
,
total_num_kv_heads
//
tp_size
)
num_kv_heads
=
max
(
1
,
total_num_kv_heads
//
tp_size
)
head_dim
=
config
.
hidden_size
//
total_num_heads
head_dim
=
(
config
.
head_dim
if
hasattr
(
config
,
"head_dim"
)
else
config
.
hidden_size
//
total_num_heads
)
is_neox_style
=
True
is_neox_style
=
True
rope_theta
=
config
.
rope_theta
rope_theta
=
config
.
rope_theta
...
@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
...
@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Skipping CUDA/ROCm only tests."
)
reason
=
"Skipping CUDA/ROCm only tests."
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model_info, model_name"
,
[
"model_name, tp_size"
,
pytest
.
param
(
test_config
,
test_config
.
model_name
,
marks
=
test_config
.
marks
)
unroll_model_tp_dict
({
for
test_config
in
MODELS_TO_TEST
"Qwen/Qwen2-VL-7B-Instruct"
:
[
1
,
2
],
])
"zai-org/GLM-4.1V-9B-Thinking"
:
[
1
,
2
]
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
}))
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype, atol, rtol"
,
dtype_atol_rtol_list
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
num_tokens_list
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
4
])
def
test_mrope_torch_compile_tracing
(
model_name
:
str
,
def
test_mrope_torch_compile_tracing
(
model_name
,
tp_size
,
dtype
,
atol
,
rtol
,
model_info
:
MRoPETestInfo
,
tp_size
:
int
,
num_tokens
):
dtype
:
torch
.
dtype
,
num_tokens
:
int
):
atol
=
model_info
.
atol
rtol
=
model_info
.
rtol
config
=
AutoConfig
.
from_pretrained
(
model_name
)
config
=
AutoConfig
.
from_pretrained
(
model_name
)
config
=
config
.
get_text_config
()
# get the model config
# get the model config
total_num_kv_heads
=
config
.
num_key_value_heads
total_num_kv_heads
=
config
.
num_key_value_heads
total_num_heads
=
config
.
num_attention_heads
total_num_heads
=
config
.
num_attention_heads
num_heads
=
total_num_heads
//
tp_size
num_heads
=
total_num_heads
//
tp_size
num_kv_heads
=
max
(
1
,
total_num_kv_heads
//
tp_size
)
num_kv_heads
=
max
(
1
,
total_num_kv_heads
//
tp_size
)
head_dim
=
config
.
hidden_size
//
total_num_heads
head_dim
=
(
config
.
head_dim
if
hasattr
(
config
,
"head_dim"
)
else
config
.
hidden_size
//
total_num_heads
)
is_neox_style
=
True
is_neox_style
=
True
rope_theta
=
config
.
rope_theta
rope_theta
=
config
.
rope_theta
max_position
=
config
.
max_position_embeddings
max_position
=
config
.
max_position_embeddings
...
...
vllm/model_executor/layers/rotary_embedding/mrope.py
View file @
cea91a32
...
@@ -15,7 +15,7 @@ from .common import apply_rotary_emb_dispatch
...
@@ -15,7 +15,7 @@ from .common import apply_rotary_emb_dispatch
@
triton
.
jit
@
triton
.
jit
def
_triton_
qwen2vl_
mrope_forward
(
def
_triton_mrope_forward
(
q_ptr
,
q_ptr
,
k_ptr
,
k_ptr
,
cos
,
cos
,
...
@@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward(
...
@@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward(
pad_hd
:
tl
.
constexpr
,
pad_hd
:
tl
.
constexpr
,
mrope_section_t
:
tl
.
constexpr
,
mrope_section_t
:
tl
.
constexpr
,
mrope_section_h
:
tl
.
constexpr
,
mrope_section_h
:
tl
.
constexpr
,
mrope_section_w
:
tl
.
constexpr
,
is_interleaved
:
tl
.
constexpr
,
):
):
# Adapted from
# Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
# This version supports flatten input tensors from vllm
# This version supports flatten input tensors from vllm
# and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
# and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
# instead of (3, bsz, seq_len, head_dim)
# instead of (3, bsz, seq_len, head_dim)
, also supports interleaved rotary
pid
=
tl
.
program_id
(
0
)
pid
=
tl
.
program_id
(
0
)
# locate start address
# locate start address
q_ptr
=
q_ptr
+
pid
*
(
n_qh
*
hd
)
q_ptr
=
q_ptr
+
pid
*
(
n_qh
*
hd
)
...
@@ -47,9 +49,6 @@ def _triton_qwen2vl_mrope_forward(
...
@@ -47,9 +49,6 @@ def _triton_qwen2vl_mrope_forward(
# ####################################################################
# ####################################################################
# Note: cos and sin now have shape (3, num_tokens, head_dim // 2)
# Note: cos and sin now have shape (3, num_tokens, head_dim // 2)
t_end
=
mrope_section_t
h_end
=
t_end
+
mrope_section_h
# Updated stride calculation for half head_dim
# Updated stride calculation for half head_dim
half_rd
=
rd
//
2
half_rd
=
rd
//
2
t_cos
=
cos
+
pid
*
half_rd
t_cos
=
cos
+
pid
*
half_rd
...
@@ -61,7 +60,16 @@ def _triton_qwen2vl_mrope_forward(
...
@@ -61,7 +60,16 @@ def _triton_qwen2vl_mrope_forward(
# Updated offsets for half head_dim
# Updated offsets for half head_dim
cos_offsets
=
tl
.
arange
(
0
,
pad_hd
//
2
)
cos_offsets
=
tl
.
arange
(
0
,
pad_hd
//
2
)
t_mask
=
cos_offsets
<
t_end
if
is_interleaved
:
h_mask
=
(((
cos_offsets
%
3
)
==
1
)
&
(
cos_offsets
<=
3
*
mrope_section_h
))
w_mask
=
(((
cos_offsets
%
3
)
==
2
)
&
(
cos_offsets
<=
3
*
mrope_section_w
))
t_mask
=
~
(
h_mask
|
w_mask
)
else
:
t_end
=
mrope_section_t
h_end
=
t_end
+
mrope_section_h
t_mask
=
cos_offsets
<
mrope_section_t
h_mask
=
(
t_end
<=
cos_offsets
)
&
(
cos_offsets
<
h_end
)
h_mask
=
(
t_end
<=
cos_offsets
)
&
(
cos_offsets
<
h_end
)
w_mask
=
(
h_end
<=
cos_offsets
)
&
(
cos_offsets
<
half_rd
)
w_mask
=
(
h_end
<=
cos_offsets
)
&
(
cos_offsets
<
half_rd
)
...
@@ -131,6 +139,7 @@ def triton_mrope(
...
@@ -131,6 +139,7 @@ def triton_mrope(
mrope_section
:
list
[
int
],
mrope_section
:
list
[
int
],
head_size
:
int
,
head_size
:
int
,
rotary_dim
:
int
,
rotary_dim
:
int
,
mrope_interleaved
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Qwen2VL mrope kernel.
"""Qwen2VL mrope kernel.
...
@@ -158,7 +167,7 @@ def triton_mrope(
...
@@ -158,7 +167,7 @@ def triton_mrope(
cos
=
cos
.
contiguous
()
cos
=
cos
.
contiguous
()
sin
=
sin
.
contiguous
()
sin
=
sin
.
contiguous
()
_triton_
qwen2vl_
mrope_forward
[(
n_row
,
)](
_triton_mrope_forward
[(
n_row
,
)](
q
,
q
,
k
,
k
,
cos
,
cos
,
...
@@ -173,6 +182,8 @@ def triton_mrope(
...
@@ -173,6 +182,8 @@ def triton_mrope(
pad_hd
,
pad_hd
,
mrope_section
[
0
],
mrope_section
[
0
],
mrope_section
[
1
],
mrope_section
[
1
],
mrope_section
[
2
],
mrope_interleaved
,
)
)
return
q
,
k
return
q
,
k
...
@@ -201,7 +212,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -201,7 +212,7 @@ class MRotaryEmbedding(RotaryEmbedding):
is_neox_style
:
bool
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
mrope_section
:
Optional
[
list
[
int
]]
=
None
,
mrope_section
:
Optional
[
list
[
int
]]
=
None
,
mrope_interleaved
:
Optional
[
bool
]
=
False
,
mrope_interleaved
:
bool
=
False
,
)
->
None
:
)
->
None
:
# In Qwen2.5-VL, the maximum index value is related to the duration of
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
# the input video. We enlarge max_position_embeddings to 4 times to get
...
@@ -282,10 +293,6 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -282,10 +293,6 @@ class MRotaryEmbedding(RotaryEmbedding):
assert
positions
.
ndim
==
1
or
positions
.
ndim
==
2
assert
positions
.
ndim
==
1
or
positions
.
ndim
==
2
assert
key
is
not
None
assert
key
is
not
None
if
self
.
mrope_interleaved
:
# TODO: add triton implementation to support mrope-interleaved
return
self
.
forward_native
(
positions
,
query
,
key
)
num_tokens
=
positions
.
shape
[
-
1
]
num_tokens
=
positions
.
shape
[
-
1
]
cos_sin
=
self
.
cos_sin_cache
[
positions
]
cos_sin
=
self
.
cos_sin_cache
[
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
...
@@ -302,6 +309,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -302,6 +309,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self
.
mrope_section
,
self
.
mrope_section
,
self
.
head_size
,
self
.
head_size
,
self
.
rotary_dim
,
self
.
rotary_dim
,
self
.
mrope_interleaved
,
)
)
return
q
.
reshape
(
query_shape
),
k
.
reshape
(
key_shape
)
return
q
.
reshape
(
query_shape
),
k
.
reshape
(
key_shape
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment