Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d4e7d34
Unverified
Commit
3d4e7d34
authored
Nov 19, 2025
by
Lukas Geiger
Committed by
GitHub
Nov 19, 2025
Browse files
[Model][QwenVL] Simplify cos/sin rotary embedding indexing (#28962)
Signed-off-by:
Lukas Geiger
<
lukas.geiger94@gmail.com
>
parent
6a25ea5f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
11 additions
and
42 deletions
+11
-42
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+2
-7
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+2
-7
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+2
-7
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+2
-7
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+3
-14
No files found.
vllm/model_executor/models/glm4_1v.py
View file @
3d4e7d34
...
@@ -797,13 +797,8 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -797,13 +797,8 @@ class Glm4vVisionTransformer(nn.Module):
# Use pre-computed cos_sin_cache from RotaryEmbedding
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos_h
=
cos
[
pos_ids
[:,
0
]]
# (num_tokens, rotary_dim // 2)
cos_combined
=
cos
[
pos_ids
].
flatten
(
1
)
cos_w
=
cos
[
pos_ids
[:,
1
]]
sin_combined
=
sin
[
pos_ids
].
flatten
(
1
)
sin_h
=
sin
[
pos_ids
[:,
0
]]
sin_w
=
sin
[
pos_ids
[:,
1
]]
cos_combined
=
torch
.
cat
([
cos_h
,
cos_w
],
dim
=-
1
)
sin_combined
=
torch
.
cat
([
sin_h
,
sin_w
],
dim
=-
1
)
return
cos_combined
,
sin_combined
,
pos_ids
return
cos_combined
,
sin_combined
,
pos_ids
def
compute_attn_mask_seqlen
(
def
compute_attn_mask_seqlen
(
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
3d4e7d34
...
@@ -738,13 +738,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -738,13 +738,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
# Use pre-computed cos_sin_cache from RotaryEmbedding
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_size
)
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_size
)
cos_h
=
cos
[
pos_ids
[:,
0
]]
# (num_tokens, rotary_dim // 2)
cos_combined
=
cos
[
pos_ids
].
flatten
(
1
)
cos_w
=
cos
[
pos_ids
[:,
1
]]
sin_combined
=
sin
[
pos_ids
].
flatten
(
1
)
sin_h
=
sin
[
pos_ids
[:,
0
]]
sin_w
=
sin
[
pos_ids
[:,
1
]]
cos_combined
=
torch
.
cat
([
cos_h
,
cos_w
],
dim
=-
1
)
sin_combined
=
torch
.
cat
([
sin_h
,
sin_w
],
dim
=-
1
)
cos_combined
=
cos_combined
.
reshape
(
cos_combined
=
cos_combined
.
reshape
(
cos_combined
.
shape
[
0
]
//
self
.
spatial_merge_unit
,
cos_combined
.
shape
[
0
]
//
self
.
spatial_merge_unit
,
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
3d4e7d34
...
@@ -724,13 +724,8 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -724,13 +724,8 @@ class Qwen2VisionTransformer(nn.Module):
# Use pre-computed cos_sin_cache from RotaryEmbedding
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos_h
=
cos
[
pos_ids
[:,
0
]]
# (num_tokens, rotary_dim // 2)
cos_combined
=
cos
[
pos_ids
].
flatten
(
1
)
cos_w
=
cos
[
pos_ids
[:,
1
]]
sin_combined
=
sin
[
pos_ids
].
flatten
(
1
)
sin_h
=
sin
[
pos_ids
[:,
0
]]
sin_w
=
sin
[
pos_ids
[:,
1
]]
cos_combined
=
torch
.
cat
([
cos_h
,
cos_w
],
dim
=-
1
)
sin_combined
=
torch
.
cat
([
sin_h
,
sin_w
],
dim
=-
1
)
return
cos_combined
,
sin_combined
return
cos_combined
,
sin_combined
def
compute_attn_mask_seqlen
(
def
compute_attn_mask_seqlen
(
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
3d4e7d34
...
@@ -428,13 +428,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
...
@@ -428,13 +428,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
# Use pre-computed cos_sin_cache from RotaryEmbedding
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos_h
=
cos
[
pos_ids
[:,
0
]]
# (num_tokens, rotary_dim // 2)
cos_combined
=
cos
[
pos_ids
].
flatten
(
1
)
cos_w
=
cos
[
pos_ids
[:,
1
]]
sin_combined
=
sin
[
pos_ids
].
flatten
(
1
)
sin_h
=
sin
[
pos_ids
[:,
0
]]
sin_w
=
sin
[
pos_ids
[:,
1
]]
cos_combined
=
torch
.
cat
([
cos_h
,
cos_w
],
dim
=-
1
)
sin_combined
=
torch
.
cat
([
sin_h
,
sin_w
],
dim
=-
1
)
return
cos_combined
,
sin_combined
return
cos_combined
,
sin_combined
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
3d4e7d34
...
@@ -459,18 +459,13 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -459,18 +459,13 @@ class Qwen3_VisionTransformer(nn.Module):
else
self
.
rot_pos_ids
(
h
,
w
,
self
.
spatial_merge_size
).
repeat
(
t
,
1
)
else
self
.
rot_pos_ids
(
h
,
w
,
self
.
spatial_merge_size
).
repeat
(
t
,
1
)
for
t
,
h
,
w
in
grid_thw
for
t
,
h
,
w
in
grid_thw
]
]
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
.
to
(
self
.
device
,
non_blocking
=
True
)
# Use pre-computed cos_sin_cache from RotaryEmbedding
# Use pre-computed cos_sin_cache from RotaryEmbedding
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos_h
=
cos
[
pos_ids
[:,
0
]]
# (num_tokens, rotary_dim // 2)
cos_combined
=
cos
[
pos_ids
].
flatten
(
1
)
cos_w
=
cos
[
pos_ids
[:,
1
]]
sin_combined
=
sin
[
pos_ids
].
flatten
(
1
)
sin_h
=
sin
[
pos_ids
[:,
0
]]
sin_w
=
sin
[
pos_ids
[:,
1
]]
cos_combined
=
torch
.
cat
([
cos_h
,
cos_w
],
dim
=-
1
)
sin_combined
=
torch
.
cat
([
sin_h
,
sin_w
],
dim
=-
1
)
return
cos_combined
,
sin_combined
return
cos_combined
,
sin_combined
...
@@ -566,12 +561,6 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -566,12 +561,6 @@ class Qwen3_VisionTransformer(nn.Module):
pos_embeds
=
self
.
fast_pos_embed_interpolate
(
grid_thw_list
)
pos_embeds
=
self
.
fast_pos_embed_interpolate
(
grid_thw_list
)
hidden_states
=
hidden_states
+
pos_embeds
hidden_states
=
hidden_states
+
pos_embeds
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
self
.
rot_pos_emb
(
grid_thw_list
)
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
self
.
rot_pos_emb
(
grid_thw_list
)
rotary_pos_emb_cos
=
rotary_pos_emb_cos
.
to
(
hidden_states
.
device
,
non_blocking
=
True
)
rotary_pos_emb_sin
=
rotary_pos_emb_sin
.
to
(
hidden_states
.
device
,
non_blocking
=
True
)
cu_seqlens
=
torch
.
repeat_interleave
(
cu_seqlens
=
torch
.
repeat_interleave
(
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment