Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
30d08911
Unverified
Commit
30d08911
authored
Sep 21, 2025
by
Roger Wang
Committed by
GitHub
Sep 21, 2025
Browse files
[MM][Perf] Minor Optimization on Qwen3-VL `fast_pos_embed_interpolate` (#25337)
Signed-off-by:
Roger Wang
<
hey@rogerw.io
>
parent
cf56cf78
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
60 additions
and
75 deletions
+60
-75
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+60
-75
No files found.
vllm/model_executor/models/qwen3_vl.py
View file @
30d08911
...
@@ -270,6 +270,7 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -270,6 +270,7 @@ class Qwen3_VisionTransformer(nn.Module):
self
.
temporal_patch_size
=
vision_config
.
temporal_patch_size
self
.
temporal_patch_size
=
vision_config
.
temporal_patch_size
self
.
deepstack_visual_indexes
=
vision_config
.
deepstack_visual_indexes
self
.
deepstack_visual_indexes
=
vision_config
.
deepstack_visual_indexes
self
.
use_data_parallel
=
use_data_parallel
self
.
use_data_parallel
=
use_data_parallel
self
.
num_grid_per_side
=
int
(
self
.
num_position_embeddings
**
0.5
)
# NOTE: This is used for creating empty tensor for all_gather for
# NOTE: This is used for creating empty tensor for all_gather for
# DP ViT. Here out_hidden_size is enlarged due to deepstack
# DP ViT. Here out_hidden_size is enlarged due to deepstack
...
@@ -377,82 +378,68 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -377,82 +378,68 @@ class Qwen3_VisionTransformer(nn.Module):
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
return
rotary_pos_emb
return
rotary_pos_emb
def
fast_pos_embed_interpolate
(
self
,
grid_thw
):
def
fast_pos_embed_interpolate
(
self
,
num_grid_per_side
=
int
(
self
.
num_position_embeddings
**
0.5
)
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
idx_list
=
[[]
for
_
in
range
(
4
)]
num_grid_per_side
=
self
.
num_grid_per_side
weight_list
=
[[]
for
_
in
range
(
4
)]
m_size
=
self
.
spatial_merge_size
hidden_dim
=
self
.
pos_embed
.
embedding_dim
outputs
=
[]
for
t
,
h
,
w
in
grid_thw
:
for
t
,
h
,
w
in
grid_thw
:
h_idxs
=
torch
.
linspace
(
0
,
h_idxs
=
torch
.
linspace
(
0
,
num_grid_per_side
-
1
,
num_grid_per_side
-
1
,
h
,
h
,
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
,
device
=
self
.
device
)
w_idxs
=
torch
.
linspace
(
0
,
w_idxs
=
torch
.
linspace
(
0
,
num_grid_per_side
-
1
,
num_grid_per_side
-
1
,
w
,
w
,
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
,
device
=
self
.
device
)
h_idxs_floor
=
h_idxs
.
to
(
torch
.
long
)
w_idxs_floor
=
w_idxs
.
to
(
torch
.
long
)
h_floor
=
h_idxs
.
to
(
torch
.
long
)
h_idxs_ceil
=
torch
.
clamp
(
h_idxs
.
to
(
torch
.
long
)
+
1
,
w_floor
=
w_idxs
.
to
(
torch
.
long
)
max
=
num_grid_per_side
-
1
)
h_ceil
=
torch
.
clamp
(
h_floor
+
1
,
max
=
num_grid_per_side
-
1
)
w_idxs_ceil
=
torch
.
clamp
(
w_idxs
.
to
(
torch
.
long
)
+
1
,
w_ceil
=
torch
.
clamp
(
w_floor
+
1
,
max
=
num_grid_per_side
-
1
)
max
=
num_grid_per_side
-
1
)
dh
=
h_idxs
-
h_floor
dh
=
h_idxs
-
h_idxs_floor
dw
=
w_idxs
-
w_floor
dw
=
w_idxs
-
w_idxs_floor
w00
=
((
1
-
dh
)[:,
None
]
*
(
1
-
dw
)[
None
,
:]).
reshape
(
-
1
)
idx_list
[
0
].
extend
(((
h_idxs_floor
*
num_grid_per_side
)[
None
].
T
+
w01
=
((
1
-
dh
)[:,
None
]
*
dw
[
None
,
:]).
reshape
(
-
1
)
w_idxs_floor
[
None
]).
flatten
().
tolist
()
*
t
)
w10
=
(
dh
[:,
None
]
*
(
1
-
dw
)[
None
,
:]).
reshape
(
-
1
)
idx_list
[
1
].
extend
(((
h_idxs_floor
*
num_grid_per_side
)[
None
].
T
+
w11
=
(
dh
[:,
None
]
*
dw
[
None
,
:]).
reshape
(
-
1
)
w_idxs_ceil
[
None
]).
flatten
().
tolist
()
*
t
)
idx_list
[
2
].
extend
(((
h_idxs_ceil
*
num_grid_per_side
)[
None
].
T
+
idx00
=
(
h_floor
[:,
None
]
*
num_grid_per_side
+
w_idxs_floor
[
None
]).
flatten
().
tolist
()
*
t
)
w_floor
[
None
,
:]).
reshape
(
-
1
)
idx_list
[
3
].
extend
(((
h_idxs_ceil
*
num_grid_per_side
)[
None
].
T
+
idx01
=
(
h_floor
[:,
None
]
*
num_grid_per_side
+
w_idxs_ceil
[
None
]).
flatten
().
tolist
()
*
t
)
w_ceil
[
None
,
:]).
reshape
(
-
1
)
idx10
=
(
h_ceil
[:,
None
]
*
num_grid_per_side
+
weight_list
[
0
].
extend
(
w_floor
[
None
,
:]).
reshape
(
-
1
)
((
1
-
dh
)[
None
].
T
*
(
1
-
dw
)[
None
]).
flatten
().
tolist
()
*
t
)
idx11
=
(
h_ceil
[:,
None
]
*
num_grid_per_side
+
weight_list
[
1
].
extend
(
w_ceil
[
None
,
:]).
reshape
(
-
1
)
((
1
-
dh
)[
None
].
T
*
dw
[
None
]).
flatten
().
tolist
()
*
t
)
weight_list
[
2
].
extend
(
indices
=
torch
.
stack
([
idx00
,
idx01
,
idx10
,
idx11
],
dim
=
0
)
(
dh
[
None
].
T
*
(
1
-
dw
)[
None
]).
flatten
().
tolist
()
*
t
)
weights
=
torch
.
stack
([
w00
,
w01
,
w10
,
w11
],
weight_list
[
3
].
extend
(
dim
=
0
).
to
(
dtype
=
self
.
dtype
,
(
dh
[
None
].
T
*
dw
[
None
]).
flatten
().
tolist
()
*
t
)
device
=
self
.
device
)
weights
=
weights
.
unsqueeze
(
-
1
)
device
=
self
.
pos_embed
.
weight
.
device
dtype
=
self
.
pos_embed
.
weight
.
dtype
embeds
=
self
.
pos_embed
(
indices
)
weighted_embeds
=
embeds
*
weights
p0
=
self
.
pos_embed
(
p0
,
p1
,
p2
,
p3
=
weighted_embeds
.
unbind
(
dim
=
0
)
torch
.
tensor
(
combined
=
p0
+
p1
+
p2
+
p3
idx_list
[
0
],
dtype
=
torch
.
long
,
device
=
device
))
*
torch
.
tensor
(
weight_list
[
0
],
dtype
=
dtype
,
device
=
device
)[:,
None
]
combined
=
combined
.
view
(
h
*
w
,
hidden_dim
)
p1
=
self
.
pos_embed
(
repeated
=
combined
.
unsqueeze
(
0
).
expand
(
t
,
-
1
,
-
1
).
contiguous
()
torch
.
tensor
(
repeated
=
repeated
.
view
(
t
,
h
//
m_size
,
m_size
,
w
//
m_size
,
idx_list
[
1
],
dtype
=
torch
.
long
,
device
=
device
))
*
torch
.
tensor
(
m_size
,
hidden_dim
)
weight_list
[
1
],
dtype
=
dtype
,
device
=
device
)[:,
None
]
repeated
=
repeated
.
permute
(
0
,
1
,
3
,
2
,
4
,
p2
=
self
.
pos_embed
(
5
).
reshape
(
-
1
,
hidden_dim
)
torch
.
tensor
(
outputs
.
append
(
repeated
)
idx_list
[
2
],
dtype
=
torch
.
long
,
device
=
device
))
*
torch
.
tensor
(
weight_list
[
2
],
dtype
=
dtype
,
device
=
device
)[:,
None
]
return
torch
.
cat
(
outputs
,
dim
=
0
)
p3
=
self
.
pos_embed
(
torch
.
tensor
(
idx_list
[
3
],
dtype
=
torch
.
long
,
device
=
device
))
*
torch
.
tensor
(
weight_list
[
3
],
dtype
=
dtype
,
device
=
device
)[:,
None
]
patch_pos_embeds
=
p0
+
p1
+
p2
+
p3
patch_pos_embeds
=
patch_pos_embeds
.
split
(
[
t
*
h
*
w
for
t
,
h
,
w
in
grid_thw
])
patch_pos_embeds_permute
=
[]
m_size
=
self
.
spatial_merge_size
for
pos_embed
,
(
t
,
h
,
w
)
in
zip
(
patch_pos_embeds
,
grid_thw
):
pos_embed
=
pos_embed
.
view
(
t
,
h
//
m_size
,
m_size
,
w
//
m_size
,
m_size
,
-
1
).
permute
(
0
,
1
,
3
,
2
,
4
,
5
).
flatten
(
0
,
4
)
patch_pos_embeds_permute
.
append
(
pos_embed
)
patch_pos_embeds
=
torch
.
cat
(
patch_pos_embeds_permute
)
return
patch_pos_embeds
def
compute_attn_mask_seqlen
(
def
compute_attn_mask_seqlen
(
self
,
self
,
...
@@ -477,12 +464,9 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -477,12 +464,9 @@ class Qwen3_VisionTransformer(nn.Module):
hidden_states
=
hidden_states
+
pos_embeds
hidden_states
=
hidden_states
+
pos_embeds
rotary_pos_emb
=
self
.
rot_pos_emb
(
grid_thw
)
rotary_pos_emb
=
self
.
rot_pos_emb
(
grid_thw
)
if
isinstance
(
grid_thw
,
list
):
grid_thw_tensor
=
torch
.
tensor
(
grid_thw
,
grid_thw_tensor
=
torch
.
tensor
(
grid_thw
,
device
=
hidden_states
.
device
,
device
=
self
.
device
,
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
else
:
grid_thw_tensor
=
grid_thw
cu_seqlens
=
torch
.
repeat_interleave
(
cu_seqlens
=
torch
.
repeat_interleave
(
grid_thw_tensor
[:,
1
]
*
grid_thw_tensor
[:,
2
],
grid_thw_tensor
[:,
1
]
*
grid_thw_tensor
[:,
2
],
...
@@ -1224,7 +1208,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1224,7 +1208,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
grid_thw_list
,
grid_thw_list
,
rope_type
=
"rope_3d"
)
rope_type
=
"rope_3d"
)
else
:
else
:
image_embeds
=
self
.
visual
(
pixel_values
,
grid_thw
=
grid_thw
)
image_embeds
=
self
.
visual
(
pixel_values
,
grid_thw
=
grid_thw_list
)
# Split concatenated embeddings for each image item.
# Split concatenated embeddings for each image item.
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment