Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9489f51
Unverified
Commit
b9489f51
authored
Nov 18, 2025
by
Canlin Guo
Committed by
GitHub
Nov 18, 2025
Browse files
[Model][Perf] Use cos and sin cache in QwenVL (#28798)
Signed-off-by:
gcanlin
<
canlinguosdu@gmail.com
>
parent
285eaa42
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
218 additions
and
217 deletions
+218
-217
vllm/model_executor/layers/rotary_embedding/base.py
vllm/model_executor/layers/rotary_embedding/base.py
+5
-0
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+38
-50
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+65
-58
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+46
-89
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+30
-10
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+34
-10
No files found.
vllm/model_executor/layers/rotary_embedding/base.py
View file @
b9489f51
...
@@ -83,6 +83,11 @@ class RotaryEmbeddingBase(CustomOp):
...
@@ -83,6 +83,11 @@ class RotaryEmbeddingBase(CustomOp):
):
):
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
def
get_cos_sin
(
self
,
seqlen
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
cos_sin
=
self
.
cos_sin_cache
[:
seqlen
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
return
cos
,
sin
class
RotaryEmbedding
(
RotaryEmbeddingBase
):
class
RotaryEmbedding
(
RotaryEmbeddingBase
):
def
__init__
(
def
__init__
(
...
...
vllm/model_executor/models/glm4_1v.py
View file @
b9489f51
...
@@ -65,6 +65,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -65,6 +65,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
@@ -341,7 +342,8 @@ class Glm4vVisionAttention(nn.Module):
...
@@ -341,7 +342,8 @@ class Glm4vVisionAttention(nn.Module):
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -353,10 +355,12 @@ class Glm4vVisionAttention(nn.Module):
...
@@ -353,10 +355,12 @@ class Glm4vVisionAttention(nn.Module):
batch_size
=
q
.
shape
[
1
]
batch_size
=
q
.
shape
[
1
]
q
,
k
,
v
=
(
rearrange
(
x
,
"s b ... -> b s ..."
).
contiguous
()
for
x
in
(
q
,
k
,
v
))
q
,
k
,
v
=
(
rearrange
(
x
,
"s b ... -> b s ..."
).
contiguous
()
for
x
in
(
q
,
k
,
v
))
if
rotary_pos_emb
is
not
None
:
if
rotary_pos_emb
_cos
is
not
None
and
rotary_pos_emb_sin
is
not
None
:
# [2 * b, s, heads, head_dim]
# [2 * b, s, heads, head_dim]
qk_concat
=
torch
.
cat
([
q
,
k
],
dim
=
0
)
qk_concat
=
torch
.
cat
([
q
,
k
],
dim
=
0
)
qk_rotated
=
apply_rotary_pos_emb_vision
(
qk_concat
,
rotary_pos_emb
)
qk_rotated
=
apply_rotary_pos_emb_vision
(
qk_concat
,
rotary_pos_emb_cos
,
rotary_pos_emb_sin
)
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
if
self
.
is_flash_attn_backend
:
if
self
.
is_flash_attn_backend
:
...
@@ -454,14 +458,16 @@ class Glm4vVisionBlock(nn.Module):
...
@@ -454,14 +458,16 @@ class Glm4vVisionBlock(nn.Module):
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
x_attn
=
self
.
attn
(
x_attn
=
self
.
attn
(
self
.
norm1
(
x
),
self
.
norm1
(
x
),
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
seqlens
=
seqlens
,
)
)
...
@@ -660,44 +666,6 @@ class Glm4vVisionEmbeddings(nn.Module):
...
@@ -660,44 +666,6 @@ class Glm4vVisionEmbeddings(nn.Module):
return
embeddings
return
embeddings
class
Glm4vVisionRotaryEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
:
int
,
theta
:
float
=
10000.0
)
->
None
:
super
().
__init__
()
self
.
dim
=
dim
self
.
theta
=
theta
inv_freq
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim
,
2
,
dtype
=
torch
.
float
)
/
dim
))
self
.
register_buffer
(
"inv_freq"
,
inv_freq
,
persistent
=
False
)
self
.
_seq_len_cached
=
0
self
.
_freqs_cached
=
None
def
update_freqs_cache
(
self
,
seqlen
:
int
)
->
None
:
if
seqlen
>
self
.
_seq_len_cached
:
seqlen
*=
2
self
.
_seq_len_cached
=
seqlen
self
.
inv_freq
=
1.0
/
(
self
.
theta
**
(
torch
.
arange
(
0
,
self
.
dim
,
2
,
dtype
=
torch
.
float
,
device
=
self
.
inv_freq
.
device
,
)
/
self
.
dim
)
)
seq
=
torch
.
arange
(
seqlen
,
device
=
self
.
inv_freq
.
device
,
dtype
=
self
.
inv_freq
.
dtype
)
freqs
=
torch
.
outer
(
seq
,
self
.
inv_freq
)
self
.
_freqs_cached
=
freqs
def
forward
(
self
,
seqlen
:
int
)
->
torch
.
Tensor
:
self
.
update_freqs_cache
(
seqlen
)
return
self
.
_freqs_cached
[:
seqlen
]
class
Glm4vVisionTransformer
(
nn
.
Module
):
class
Glm4vVisionTransformer
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -731,7 +699,13 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -731,7 +699,13 @@ class Glm4vVisionTransformer(nn.Module):
norm_layer
=
partial
(
RMSNorm
,
eps
=
norm_eps
)
norm_layer
=
partial
(
RMSNorm
,
eps
=
norm_eps
)
head_dim
=
self
.
hidden_size
//
self
.
num_heads
head_dim
=
self
.
hidden_size
//
self
.
num_heads
self
.
rotary_pos_emb
=
Glm4vVisionRotaryEmbedding
(
head_dim
//
2
)
self
.
rotary_pos_emb
=
get_rope
(
head_size
=
head_dim
,
rotary_dim
=
head_dim
//
2
,
max_position
=
8192
,
base
=
10000.0
,
is_neox_style
=
True
,
)
self
.
blocks
=
nn
.
ModuleList
(
self
.
blocks
=
nn
.
ModuleList
(
[
[
Glm4vVisionBlock
(
Glm4vVisionBlock
(
...
@@ -789,7 +763,9 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -789,7 +763,9 @@ class Glm4vVisionTransformer(nn.Module):
def
device
(
self
)
->
torch
.
device
:
def
device
(
self
)
->
torch
.
device
:
return
self
.
patch_embed
.
proj
.
weight
.
device
return
self
.
patch_embed
.
proj
.
weight
.
device
def
rot_pos_emb
(
self
,
grid_thw
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
rot_pos_emb
(
self
,
grid_thw
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
pos_ids
=
[]
pos_ids
=
[]
for
t
,
h
,
w
in
grid_thw
:
for
t
,
h
,
w
in
grid_thw
:
hpos_ids
=
torch
.
arange
(
h
).
unsqueeze
(
1
).
expand
(
-
1
,
w
)
hpos_ids
=
torch
.
arange
(
h
).
unsqueeze
(
1
).
expand
(
-
1
,
w
)
...
@@ -817,9 +793,18 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -817,9 +793,18 @@ class Glm4vVisionTransformer(nn.Module):
pos_ids
.
append
(
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
))
pos_ids
.
append
(
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
))
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
max_grid_size
=
grid_thw
[:,
1
:].
max
()
max_grid_size
=
grid_thw
[:,
1
:].
max
()
rotary_pos_emb_full
=
self
.
rotary_pos_emb
(
max_grid_size
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
# Use pre-computed cos_sin_cache from RotaryEmbedding
return
rotary_pos_emb
,
pos_ids
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos_h
=
cos
[
pos_ids
[:,
0
]]
# (num_tokens, rotary_dim // 2)
cos_w
=
cos
[
pos_ids
[:,
1
]]
sin_h
=
sin
[
pos_ids
[:,
0
]]
sin_w
=
sin
[
pos_ids
[:,
1
]]
cos_combined
=
torch
.
cat
([
cos_h
,
cos_w
],
dim
=-
1
)
sin_combined
=
torch
.
cat
([
sin_h
,
sin_w
],
dim
=-
1
)
return
cos_combined
,
sin_combined
,
pos_ids
def
compute_attn_mask_seqlen
(
def
compute_attn_mask_seqlen
(
self
,
self
,
...
@@ -848,7 +833,9 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -848,7 +833,9 @@ class Glm4vVisionTransformer(nn.Module):
x
=
self
.
post_conv_layernorm
(
x
)
x
=
self
.
post_conv_layernorm
(
x
)
# compute position embedding
# compute position embedding
rotary_pos_emb
,
image_type_ids
=
self
.
rot_pos_emb
(
grid_thw
)
rotary_pos_emb_cos
,
rotary_pos_emb_sin
,
image_type_ids
=
self
.
rot_pos_emb
(
grid_thw
)
# compute cu_seqlens
# compute cu_seqlens
cu_seqlens
=
torch
.
repeat_interleave
(
cu_seqlens
=
torch
.
repeat_interleave
(
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]
...
@@ -867,7 +854,8 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -867,7 +854,8 @@ class Glm4vVisionTransformer(nn.Module):
x
=
blk
(
x
=
blk
(
x
,
x
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
seqlens
=
seqlens
,
)
)
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
b9489f51
...
@@ -64,6 +64,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -64,6 +64,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.vision
import
should_torch_compile_mm_vit
from
vllm.model_executor.models.vision
import
should_torch_compile_mm_vit
...
@@ -363,7 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -363,7 +364,8 @@ class Qwen2_5_VisionAttention(nn.Module):
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -378,13 +380,15 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -378,13 +380,15 @@ class Qwen2_5_VisionAttention(nn.Module):
head
=
self
.
num_attention_heads_per_partition
,
head
=
self
.
num_attention_heads_per_partition
,
)
)
if
rotary_pos_emb
is
not
None
:
if
rotary_pos_emb
_cos
is
not
None
and
rotary_pos_emb_sin
is
not
None
:
qk
,
v
=
qkv
[:,
:,
:
2
],
qkv
[:,
:,
2
]
qk
,
v
=
qkv
[:,
:,
:
2
],
qkv
[:,
:,
2
]
qk_reshaped
=
einops
.
rearrange
(
qk_reshaped
=
einops
.
rearrange
(
qk
,
"b s two head head_dim -> (two b) s head head_dim"
,
two
=
2
qk
,
"b s two head head_dim -> (two b) s head head_dim"
,
two
=
2
)
)
qk_rotated
=
apply_rotary_pos_emb_vision
(
qk_reshaped
,
rotary_pos_emb
)
qk_rotated
=
apply_rotary_pos_emb_vision
(
qk_reshaped
,
cos
=
rotary_pos_emb_cos
,
sin
=
rotary_pos_emb_sin
)
qk_rotated
=
qk_rotated
.
view
(
qk_rotated
=
qk_rotated
.
view
(
2
,
2
,
batch_size
,
batch_size
,
...
@@ -434,7 +438,8 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -434,7 +438,8 @@ class Qwen2_5_VisionAttention(nn.Module):
dynamic_arg_dims
=
{
dynamic_arg_dims
=
{
"x"
:
0
,
"x"
:
0
,
"cu_seqlens"
:
0
,
"cu_seqlens"
:
0
,
"rotary_pos_emb"
:
0
,
"rotary_pos_emb_cos"
:
0
,
"rotary_pos_emb_sin"
:
0
,
"seqlens"
:
0
,
"seqlens"
:
0
,
},
},
mark_unbacked_dims
=
{
"seqlens"
:
0
},
mark_unbacked_dims
=
{
"seqlens"
:
0
},
...
@@ -485,14 +490,16 @@ class Qwen2_5_VisionBlock(nn.Module):
...
@@ -485,14 +490,16 @@ class Qwen2_5_VisionBlock(nn.Module):
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
x_attn
=
self
.
attn
(
x_attn
=
self
.
attn
(
self
.
norm1
(
x
),
self
.
norm1
(
x
),
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
seqlens
=
seqlens
,
)
)
...
@@ -588,42 +595,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
...
@@ -588,42 +595,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
return
out
return
out
class
Qwen2_5_VisionRotaryEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
:
int
,
theta
:
float
=
10000.0
)
->
None
:
super
().
__init__
()
self
.
dim
=
dim
self
.
theta
=
theta
inv_freq
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim
,
2
,
dtype
=
torch
.
float
,
device
=
"cpu"
)
/
dim
)
)
self
.
register_buffer
(
"inv_freq"
,
inv_freq
,
persistent
=
False
)
self
.
_seq_len_cached
=
0
self
.
_freqs_cached
=
None
def
update_freqs_cache
(
self
,
seqlen
:
int
)
->
None
:
if
seqlen
>
self
.
_seq_len_cached
:
seqlen
*=
2
self
.
_seq_len_cached
=
seqlen
self
.
inv_freq
=
1.0
/
(
self
.
theta
**
(
torch
.
arange
(
0
,
self
.
dim
,
2
,
dtype
=
torch
.
float
,
device
=
self
.
inv_freq
.
device
)
/
self
.
dim
)
)
seq
=
torch
.
arange
(
seqlen
,
device
=
self
.
inv_freq
.
device
,
dtype
=
self
.
inv_freq
.
dtype
)
freqs
=
torch
.
outer
(
seq
,
self
.
inv_freq
)
self
.
_freqs_cached
=
freqs
def
forward
(
self
,
seqlen
:
int
)
->
torch
.
Tensor
:
self
.
update_freqs_cache
(
seqlen
)
return
self
.
_freqs_cached
[:
seqlen
]
class
Qwen2_5_VisionTransformer
(
nn
.
Module
):
class
Qwen2_5_VisionTransformer
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -666,7 +637,13 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -666,7 +637,13 @@ class Qwen2_5_VisionTransformer(nn.Module):
norm_layer
=
partial
(
RMSNorm
,
eps
=
norm_eps
)
norm_layer
=
partial
(
RMSNorm
,
eps
=
norm_eps
)
head_dim
=
self
.
hidden_size
//
self
.
num_heads
head_dim
=
self
.
hidden_size
//
self
.
num_heads
self
.
rotary_pos_emb
=
Qwen2_5_VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
rotary_pos_emb
=
get_rope
(
head_size
=
head_dim
,
rotary_dim
=
head_dim
//
2
,
max_position
=
8192
,
base
=
10000.0
,
is_neox_style
=
True
,
)
use_upstream_fa
=
False
use_upstream_fa
=
False
self
.
attn_backend
=
get_vit_attn_backend
(
self
.
attn_backend
=
get_vit_attn_backend
(
...
@@ -757,15 +734,30 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -757,15 +734,30 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
)
pos_ids
=
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
)
pos_ids
=
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
)
max_size
=
max
(
h
,
w
)
max_size
=
max
(
h
,
w
)
rotary_pos_emb_full
=
self
.
rotary_pos_emb
(
max_size
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
# Use pre-computed cos_sin_cache from RotaryEmbedding
rotary_pos_emb
=
rotary_pos_emb
.
reshape
(
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_size
)
rotary_pos_emb
.
shape
[
0
]
//
self
.
spatial_merge_unit
,
cos_h
=
cos
[
pos_ids
[:,
0
]]
# (num_tokens, rotary_dim // 2)
cos_w
=
cos
[
pos_ids
[:,
1
]]
sin_h
=
sin
[
pos_ids
[:,
0
]]
sin_w
=
sin
[
pos_ids
[:,
1
]]
cos_combined
=
torch
.
cat
([
cos_h
,
cos_w
],
dim
=-
1
)
sin_combined
=
torch
.
cat
([
sin_h
,
sin_w
],
dim
=-
1
)
cos_combined
=
cos_combined
.
reshape
(
cos_combined
.
shape
[
0
]
//
self
.
spatial_merge_unit
,
self
.
spatial_merge_unit
,
-
1
,
)
sin_combined
=
sin_combined
.
reshape
(
sin_combined
.
shape
[
0
]
//
self
.
spatial_merge_unit
,
self
.
spatial_merge_unit
,
self
.
spatial_merge_unit
,
-
1
,
-
1
,
)
)
return
rotary_pos_emb
return
cos_combined
,
sin_combined
def
get_window_index_thw
(
self
,
grid_t
,
grid_h
,
grid_w
):
def
get_window_index_thw
(
self
,
grid_t
,
grid_h
,
grid_w
):
vit_merger_window_size
=
(
vit_merger_window_size
=
(
...
@@ -807,14 +799,19 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -807,14 +799,19 @@ class Qwen2_5_VisionTransformer(nn.Module):
@
lru_cache
(
maxsize
=
1024
)
# noqa: B019
@
lru_cache
(
maxsize
=
1024
)
# noqa: B019
def
get_rope_by_thw
(
self
,
t
,
h
,
w
):
def
get_rope_by_thw
(
self
,
t
,
h
,
w
):
window_index_thw
,
cu_seqlens_window_thw
=
self
.
get_window_index_thw
(
t
,
h
,
w
)
window_index_thw
,
cu_seqlens_window_thw
=
self
.
get_window_index_thw
(
t
,
h
,
w
)
rotary_pos_emb_thw
=
self
.
rotary_pos_emb_thw
(
t
,
h
,
w
)
cos_thw
,
sin_thw
=
self
.
rotary_pos_emb_thw
(
t
,
h
,
w
)
rotary_pos_emb_thw
=
rotary_pos_emb_thw
[
window_index_thw
,
:,
:]
rotary_pos_emb_thw
=
rotary_pos_emb_thw
.
flatten
(
start_dim
=
0
,
end_dim
=
1
)
cos_thw
=
cos_thw
[
window_index_thw
,
:,
:]
cos_thw
=
cos_thw
.
flatten
(
start_dim
=
0
,
end_dim
=
1
)
sin_thw
=
sin_thw
[
window_index_thw
,
:,
:]
sin_thw
=
sin_thw
.
flatten
(
start_dim
=
0
,
end_dim
=
1
)
cu_seqlens_thw
=
torch
.
repeat_interleave
(
cu_seqlens_thw
=
torch
.
repeat_interleave
(
torch
.
tensor
([
h
*
w
],
dtype
=
torch
.
int32
),
t
torch
.
tensor
([
h
*
w
],
dtype
=
torch
.
int32
),
t
)
)
return
(
return
(
rotary_pos_emb_thw
,
cos_thw
,
sin_thw
,
window_index_thw
,
window_index_thw
,
cu_seqlens_window_thw
,
cu_seqlens_window_thw
,
cu_seqlens_thw
,
cu_seqlens_thw
,
...
@@ -849,7 +846,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -849,7 +846,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# patchify
# patchify
seq_len
,
_
=
x
.
size
()
seq_len
,
_
=
x
.
size
()
rotary_pos_emb
=
[]
rotary_pos_emb_cos
=
[]
rotary_pos_emb_sin
=
[]
window_index
:
list
=
[]
window_index
:
list
=
[]
cu_window_seqlens
:
list
=
[
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
)]
cu_window_seqlens
:
list
=
[
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
)]
cu_seqlens
:
list
=
[]
cu_seqlens
:
list
=
[]
...
@@ -865,7 +863,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -865,7 +863,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
llm_w
=
w
//
self
.
spatial_merge_size
llm_w
=
w
//
self
.
spatial_merge_size
(
(
rotary_pos_emb_thw
,
cos_thw
,
sin_thw
,
window_index_thw
,
window_index_thw
,
cu_seqlens_window_thw
,
cu_seqlens_window_thw
,
cu_seqlens_thw
,
cu_seqlens_thw
,
...
@@ -878,11 +877,13 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -878,11 +877,13 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_window_seqlens_last
=
cu_seqlens_window_thw
[
-
1
]
cu_window_seqlens_last
=
cu_seqlens_window_thw
[
-
1
]
cu_window_seqlens
.
append
(
cu_seqlens_window_thw
)
cu_window_seqlens
.
append
(
cu_seqlens_window_thw
)
rotary_pos_emb
.
append
(
rotary_pos_emb_thw
)
rotary_pos_emb_cos
.
append
(
cos_thw
)
rotary_pos_emb_sin
.
append
(
sin_thw
)
cu_seqlens
.
append
(
cu_seqlens_thw
)
cu_seqlens
.
append
(
cu_seqlens_thw
)
rotary_pos_emb
=
torch
.
cat
(
rotary_pos_emb
)
rotary_pos_emb_cos
=
torch
.
cat
(
rotary_pos_emb_cos
)
rotary_pos_emb_sin
=
torch
.
cat
(
rotary_pos_emb_sin
)
window_index
=
torch
.
cat
(
window_index
)
window_index
=
torch
.
cat
(
window_index
)
# compute reverse indices
# compute reverse indices
reverse_indices
=
self
.
invert_permutation
(
window_index
)
reverse_indices
=
self
.
invert_permutation
(
window_index
)
...
@@ -901,7 +902,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -901,7 +902,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens
=
cu_seqlens
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
cu_seqlens
=
cu_seqlens
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
cu_window_seqlens
=
cu_window_seqlens
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
cu_window_seqlens
=
cu_window_seqlens
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
rotary_pos_emb
=
rotary_pos_emb
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
rotary_pos_emb_cos
=
rotary_pos_emb_cos
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
rotary_pos_emb_sin
=
rotary_pos_emb_sin
.
to
(
device
=
self
.
device
,
non_blocking
=
True
)
window_index
=
window_index
.
to
(
device
=
hidden_states
.
device
,
non_blocking
=
True
)
window_index
=
window_index
.
to
(
device
=
hidden_states
.
device
,
non_blocking
=
True
)
reverse_indices
=
reverse_indices
.
to
(
reverse_indices
=
reverse_indices
.
to
(
device
=
hidden_states
.
device
,
non_blocking
=
True
device
=
hidden_states
.
device
,
non_blocking
=
True
...
@@ -928,7 +934,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -928,7 +934,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
hidden_states
=
blk
(
hidden_states
=
blk
(
hidden_states
,
hidden_states
,
cu_seqlens
=
cu_seqlens_now
,
cu_seqlens
=
cu_seqlens_now
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen_now
,
max_seqlen
=
max_seqlen_now
,
seqlens
=
seqlens_now
,
seqlens
=
seqlens_now
,
)
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
b9489f51
...
@@ -32,7 +32,7 @@ from typing import Annotated, Any, Literal, TypeAlias
...
@@ -32,7 +32,7 @@ from typing import Annotated, Any, Literal, TypeAlias
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
einops
import
rearrange
,
repeat
from
einops
import
rearrange
from
transformers
import
BatchFeature
from
transformers
import
BatchFeature
from
transformers.models.qwen2_vl
import
Qwen2VLImageProcessor
,
Qwen2VLProcessor
from
transformers.models.qwen2_vl
import
Qwen2VLImageProcessor
,
Qwen2VLProcessor
from
transformers.models.qwen2_vl.configuration_qwen2_vl
import
(
from
transformers.models.qwen2_vl.configuration_qwen2_vl
import
(
...
@@ -59,7 +59,9 @@ from vllm.model_executor.layers.linear import (
...
@@ -59,7 +59,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding.common
import
(
from
vllm.model_executor.layers.rotary_embedding.common
import
(
apply_rotary_emb_torch
,
dispatch_rotary_emb_function
,
dispatch_rotary_emb_function
,
)
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -275,47 +277,13 @@ class Qwen2VisionMLP(nn.Module):
...
@@ -275,47 +277,13 @@ class Qwen2VisionMLP(nn.Module):
return
x
return
x
def
rotate_half
(
x
:
torch
.
Tensor
,
interleaved
:
bool
=
False
)
->
torch
.
Tensor
:
def
apply_rotary_pos_emb_vision
(
if
not
interleaved
:
t
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
x1
,
x2
=
x
.
chunk
(
2
,
dim
=-
1
)
return
torch
.
cat
((
-
x2
,
x1
),
dim
=-
1
)
else
:
x1
,
x2
=
x
[...,
::
2
],
x
[...,
1
::
2
]
return
rearrange
(
torch
.
stack
((
-
x2
,
x1
),
dim
=-
1
),
"... d two -> ... (d two)"
,
two
=
2
)
def
apply_rotary_emb_torch
(
x
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
interleaved
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
rotary_emb_function
=
dispatch_rotary_emb_function
(
x: (batch_size, seqlen, nheads, headdim)
default
=
partial
(
apply_rotary_emb_torch
,
is_neox_style
=
True
)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim
=
cos
.
shape
[
-
1
]
*
2
assert
ro_dim
<=
x
.
shape
[
-
1
]
cos
=
repeat
(
cos
,
"... d -> ... 1 (2 d)"
if
not
interleaved
else
"... d -> ... 1 (d 2)"
)
sin
=
repeat
(
sin
,
"... d -> ... 1 (2 d)"
if
not
interleaved
else
"... d -> ... 1 (d 2)"
)
)
return
torch
.
cat
(
output
=
rotary_emb_function
(
t
,
cos
,
sin
).
type_as
(
t
)
[
x
[...,
:
ro_dim
]
*
cos
+
rotate_half
(
x
[...,
:
ro_dim
],
interleaved
)
*
sin
,
x
[...,
ro_dim
:],
],
dim
=-
1
,
)
def
apply_rotary_pos_emb_vision
(
t
:
torch
.
Tensor
,
freqs
:
torch
.
Tensor
)
->
torch
.
Tensor
:
rotary_emb_function
=
dispatch_rotary_emb_function
(
default
=
apply_rotary_emb_torch
)
t_
=
t
.
float
()
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
output
=
rotary_emb_function
(
t_
,
cos
,
sin
).
type_as
(
t
)
return
output
return
output
...
@@ -412,7 +380,8 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -412,7 +380,8 @@ class Qwen2VisionAttention(nn.Module):
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -424,10 +393,12 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -424,10 +393,12 @@ class Qwen2VisionAttention(nn.Module):
batch_size
=
q
.
shape
[
1
]
batch_size
=
q
.
shape
[
1
]
q
,
k
,
v
=
(
rearrange
(
x
,
"s b ... -> b s ..."
)
for
x
in
(
q
,
k
,
v
))
q
,
k
,
v
=
(
rearrange
(
x
,
"s b ... -> b s ..."
)
for
x
in
(
q
,
k
,
v
))
if
rotary_pos_emb
is
not
None
:
# [2 * b, s, heads, head_dim]
# [2 * b, s, heads, head_dim]
qk_concat
=
torch
.
cat
([
q
,
k
],
dim
=
0
)
qk_concat
=
torch
.
cat
([
q
,
k
],
dim
=
0
)
qk_rotated
=
apply_rotary_pos_emb_vision
(
qk_concat
,
rotary_pos_emb
)
qk_rotated
=
apply_rotary_pos_emb_vision
(
qk_concat
,
rotary_pos_emb_cos
,
rotary_pos_emb_sin
)
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
if
self
.
is_flash_attn_backend
:
if
self
.
is_flash_attn_backend
:
...
@@ -534,14 +505,16 @@ class Qwen2VisionBlock(nn.Module):
...
@@ -534,14 +505,16 @@ class Qwen2VisionBlock(nn.Module):
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
x
=
x
+
self
.
attn
(
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
self
.
norm1
(
x
),
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
seqlens
=
seqlens
,
)
)
...
@@ -628,40 +601,6 @@ class Qwen2VisionPatchMerger(nn.Module):
...
@@ -628,40 +601,6 @@ class Qwen2VisionPatchMerger(nn.Module):
return
out
return
out
class
Qwen2VisionRotaryEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
dim
:
int
,
theta
:
float
=
10000.0
)
->
None
:
super
().
__init__
()
self
.
dim
=
dim
self
.
theta
=
theta
inv_freq
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim
,
2
,
dtype
=
torch
.
float
)
/
dim
))
self
.
register_buffer
(
"inv_freq"
,
inv_freq
,
persistent
=
False
)
self
.
_seq_len_cached
=
0
self
.
_freqs_cached
=
None
def
update_freqs_cache
(
self
,
seqlen
:
int
)
->
None
:
if
seqlen
>
self
.
_seq_len_cached
:
seqlen
*=
2
self
.
_seq_len_cached
=
seqlen
self
.
inv_freq
=
1.0
/
(
self
.
theta
**
(
torch
.
arange
(
0
,
self
.
dim
,
2
,
dtype
=
torch
.
float
,
device
=
self
.
inv_freq
.
device
)
/
self
.
dim
)
)
seq
=
torch
.
arange
(
seqlen
,
device
=
self
.
inv_freq
.
device
,
dtype
=
self
.
inv_freq
.
dtype
)
freqs
=
torch
.
outer
(
seq
,
self
.
inv_freq
)
self
.
_freqs_cached
=
freqs
def
forward
(
self
,
seqlen
:
int
)
->
torch
.
Tensor
:
self
.
update_freqs_cache
(
seqlen
)
return
self
.
_freqs_cached
[:
seqlen
]
class
Qwen2VisionTransformer
(
nn
.
Module
):
class
Qwen2VisionTransformer
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -700,7 +639,13 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -700,7 +639,13 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
norm_eps
)
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
norm_eps
)
head_dim
=
embed_dim
//
num_heads
head_dim
=
embed_dim
//
num_heads
self
.
rotary_pos_emb
=
Qwen2VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
rotary_pos_emb
=
get_rope
(
head_size
=
head_dim
,
rotary_dim
=
head_dim
//
2
,
max_position
=
8192
,
base
=
10000.0
,
is_neox_style
=
True
,
)
self
.
blocks
=
nn
.
ModuleList
(
self
.
blocks
=
nn
.
ModuleList
(
[
[
...
@@ -744,7 +689,9 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -744,7 +689,9 @@ class Qwen2VisionTransformer(nn.Module):
def
device
(
self
)
->
torch
.
device
:
def
device
(
self
)
->
torch
.
device
:
return
self
.
patch_embed
.
proj
.
weight
.
device
return
self
.
patch_embed
.
proj
.
weight
.
device
def
rot_pos_emb
(
self
,
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
def
rot_pos_emb
(
self
,
grid_thw
:
list
[
list
[
int
]]
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
pos_ids
=
[]
pos_ids
=
[]
max_grid_size
=
0
max_grid_size
=
0
for
t
,
h
,
w
in
grid_thw
:
for
t
,
h
,
w
in
grid_thw
:
...
@@ -773,9 +720,18 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -773,9 +720,18 @@ class Qwen2VisionTransformer(nn.Module):
pos_ids
.
append
(
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
))
pos_ids
.
append
(
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
))
max_grid_size
=
max
(
max_grid_size
,
h
,
w
)
max_grid_size
=
max
(
max_grid_size
,
h
,
w
)
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
rotary_pos_emb_full
=
self
.
rotary_pos_emb
(
max_grid_size
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
# Use pre-computed cos_sin_cache from RotaryEmbedding
return
rotary_pos_emb
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos_h
=
cos
[
pos_ids
[:,
0
]]
# (num_tokens, rotary_dim // 2)
cos_w
=
cos
[
pos_ids
[:,
1
]]
sin_h
=
sin
[
pos_ids
[:,
0
]]
sin_w
=
sin
[
pos_ids
[:,
1
]]
cos_combined
=
torch
.
cat
([
cos_h
,
cos_w
],
dim
=-
1
)
sin_combined
=
torch
.
cat
([
sin_h
,
sin_w
],
dim
=-
1
)
return
cos_combined
,
sin_combined
def
compute_attn_mask_seqlen
(
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
self
,
cu_seqlens
:
torch
.
Tensor
...
@@ -806,7 +762,7 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -806,7 +762,7 @@ class Qwen2VisionTransformer(nn.Module):
grid_thw_list
=
grid_thw
.
tolist
()
grid_thw_list
=
grid_thw
.
tolist
()
# compute position embedding
# compute position embedding
rotary_pos_emb
=
self
.
rot_pos_emb
(
grid_thw_list
)
rotary_pos_emb
_cos
,
rotary_pos_emb_sin
=
self
.
rot_pos_emb
(
grid_thw_list
)
# compute cu_seqlens
# compute cu_seqlens
cu_seqlens
=
torch
.
repeat_interleave
(
cu_seqlens
=
torch
.
repeat_interleave
(
...
@@ -824,7 +780,8 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -824,7 +780,8 @@ class Qwen2VisionTransformer(nn.Module):
x
=
blk
(
x
=
blk
(
x
,
x
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
seqlens
=
seqlens
,
)
)
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
b9489f51
...
@@ -60,6 +60,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -60,6 +60,7 @@ from vllm.model_executor.layers.linear import (
)
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.qwen2_audio
import
Qwen2AudioProcessingInfo
from
vllm.model_executor.models.qwen2_audio
import
Qwen2AudioProcessingInfo
...
@@ -90,7 +91,6 @@ from .qwen2_5_omni_thinker import (
...
@@ -90,7 +91,6 @@ from .qwen2_5_omni_thinker import (
)
)
from
.qwen2_5_vl
import
(
from
.qwen2_5_vl
import
(
Qwen2_5_VisionAttention
,
Qwen2_5_VisionAttention
,
Qwen2_5_VisionRotaryEmbedding
,
Qwen2_5_VLProcessingInfo
,
Qwen2_5_VLProcessingInfo
,
)
)
from
.qwen3_moe
import
Qwen3MoeForCausalLM
,
Qwen3MoeModel
from
.qwen3_moe
import
Qwen3MoeForCausalLM
,
Qwen3MoeModel
...
@@ -221,14 +221,16 @@ class Qwen3_VisionBlock(nn.Module):
...
@@ -221,14 +221,16 @@ class Qwen3_VisionBlock(nn.Module):
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
x
=
x
+
self
.
attn
(
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
self
.
norm1
(
x
),
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
seqlens
=
seqlens
,
)
)
...
@@ -332,7 +334,13 @@ class Qwen3Omni_VisionTransformer(nn.Module):
...
@@ -332,7 +334,13 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
norm_eps
)
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
norm_eps
)
head_dim
=
self
.
hidden_size
//
self
.
num_heads
head_dim
=
self
.
hidden_size
//
self
.
num_heads
self
.
rotary_pos_emb
=
Qwen2_5_VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
rotary_pos_emb
=
get_rope
(
head_size
=
head_dim
,
rotary_dim
=
head_dim
//
2
,
max_position
=
8192
,
base
=
10000.0
,
is_neox_style
=
True
,
)
self
.
blocks
=
nn
.
ModuleList
(
self
.
blocks
=
nn
.
ModuleList
(
[
[
...
@@ -416,9 +424,19 @@ class Qwen3Omni_VisionTransformer(nn.Module):
...
@@ -416,9 +424,19 @@ class Qwen3Omni_VisionTransformer(nn.Module):
pos_ids
.
append
(
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
))
pos_ids
.
append
(
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
))
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
max_grid_size
=
grid_thw
[:,
1
:].
max
()
max_grid_size
=
grid_thw
[:,
1
:].
max
()
rotary_pos_emb_full
=
self
.
rotary_pos_emb
(
max_grid_size
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
# Use pre-computed cos_sin_cache from RotaryEmbedding
return
rotary_pos_emb
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos_h
=
cos
[
pos_ids
[:,
0
]]
# (num_tokens, rotary_dim // 2)
cos_w
=
cos
[
pos_ids
[:,
1
]]
sin_h
=
sin
[
pos_ids
[:,
0
]]
sin_w
=
sin
[
pos_ids
[:,
1
]]
cos_combined
=
torch
.
cat
([
cos_h
,
cos_w
],
dim
=-
1
)
sin_combined
=
torch
.
cat
([
sin_h
,
sin_w
],
dim
=-
1
)
return
cos_combined
,
sin_combined
def
fast_pos_embed_interpolate
(
self
,
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
def
fast_pos_embed_interpolate
(
self
,
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
num_grid_per_side
=
self
.
num_grid_per_side
num_grid_per_side
=
self
.
num_grid_per_side
...
@@ -508,7 +526,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
...
@@ -508,7 +526,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
if
self
.
apply_vit_abs_pos_embed
:
if
self
.
apply_vit_abs_pos_embed
:
pos_embeds
=
self
.
fast_pos_embed_interpolate
(
grid_thw
)
pos_embeds
=
self
.
fast_pos_embed_interpolate
(
grid_thw
)
hidden_states
=
hidden_states
+
pos_embeds
hidden_states
=
hidden_states
+
pos_embeds
rotary_pos_emb
=
self
.
rot_pos_emb
(
grid_thw
)
rotary_pos_emb
_cos
,
rotary_pos_emb_sin
=
self
.
rot_pos_emb
(
grid_thw
)
cu_seqlens
=
torch
.
repeat_interleave
(
cu_seqlens
=
torch
.
repeat_interleave
(
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]
...
@@ -519,7 +537,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
...
@@ -519,7 +537,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
cu_seqlens
=
F
.
pad
(
cu_seqlens
,
(
1
,
0
),
value
=
0
)
cu_seqlens
=
F
.
pad
(
cu_seqlens
,
(
1
,
0
),
value
=
0
)
hidden_states
=
hidden_states
.
unsqueeze
(
1
)
hidden_states
=
hidden_states
.
unsqueeze
(
1
)
rotary_pos_emb
=
rotary_pos_emb
.
to
(
hidden_states
.
device
)
rotary_pos_emb_cos
=
rotary_pos_emb_cos
.
to
(
hidden_states
.
device
)
rotary_pos_emb_sin
=
rotary_pos_emb_sin
.
to
(
hidden_states
.
device
)
max_seqlen
,
seqlens
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
max_seqlen
,
seqlens
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
hidden_states_list
=
[]
hidden_states_list
=
[]
...
@@ -529,7 +548,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
...
@@ -529,7 +548,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
hidden_states
=
blk
(
hidden_states
=
blk
(
hidden_states
,
hidden_states
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
seqlens
=
seqlens
,
)
)
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
b9489f51
...
@@ -63,6 +63,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -63,6 +63,7 @@ from vllm.model_executor.layers.linear import (
)
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
...
@@ -95,7 +96,6 @@ from .interfaces import (
...
@@ -95,7 +96,6 @@ from .interfaces import (
)
)
from
.qwen2_5_vl
import
(
from
.qwen2_5_vl
import
(
Qwen2_5_VisionAttention
,
Qwen2_5_VisionAttention
,
Qwen2_5_VisionRotaryEmbedding
,
Qwen2_5_VLImageEmbeddingInputs
,
Qwen2_5_VLImageEmbeddingInputs
,
Qwen2_5_VLImageInputs
,
Qwen2_5_VLImageInputs
,
Qwen2_5_VLImagePixelInputs
,
Qwen2_5_VLImagePixelInputs
,
...
@@ -232,14 +232,16 @@ class Qwen3_VisionBlock(nn.Module):
...
@@ -232,14 +232,16 @@ class Qwen3_VisionBlock(nn.Module):
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
rotary_pos_emb_cos
:
torch
.
Tensor
,
rotary_pos_emb_sin
:
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
max_seqlen
:
torch
.
Tensor
,
# Only used for Flash Attention
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
seqlens
:
torch
.
Tensor
,
# Only used for xFormers
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
x
=
x
+
self
.
attn
(
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
self
.
norm1
(
x
),
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
seqlens
=
seqlens
,
)
)
...
@@ -339,7 +341,13 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -339,7 +341,13 @@ class Qwen3_VisionTransformer(nn.Module):
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
norm_eps
)
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
norm_eps
)
head_dim
=
self
.
hidden_size
//
self
.
num_heads
head_dim
=
self
.
hidden_size
//
self
.
num_heads
self
.
rotary_pos_emb
=
Qwen2_5_VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
rotary_pos_emb
=
get_rope
(
head_size
=
head_dim
,
rotary_dim
=
head_dim
//
2
,
max_position
=
8192
,
base
=
10000.0
,
is_neox_style
=
True
,
)
self
.
merger
=
Qwen3_VisionPatchMerger
(
self
.
merger
=
Qwen3_VisionPatchMerger
(
d_model
=
vision_config
.
out_hidden_size
,
d_model
=
vision_config
.
out_hidden_size
,
...
@@ -452,9 +460,19 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -452,9 +460,19 @@ class Qwen3_VisionTransformer(nn.Module):
for
t
,
h
,
w
in
grid_thw
for
t
,
h
,
w
in
grid_thw
]
]
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
rotary_pos_emb_full
=
self
.
rotary_pos_emb
(
max_grid_size
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
# Use pre-computed cos_sin_cache from RotaryEmbedding
return
rotary_pos_emb
cos
,
sin
=
self
.
rotary_pos_emb
.
get_cos_sin
(
max_grid_size
)
cos_h
=
cos
[
pos_ids
[:,
0
]]
# (num_tokens, rotary_dim // 2)
cos_w
=
cos
[
pos_ids
[:,
1
]]
sin_h
=
sin
[
pos_ids
[:,
0
]]
sin_w
=
sin
[
pos_ids
[:,
1
]]
cos_combined
=
torch
.
cat
([
cos_h
,
cos_w
],
dim
=-
1
)
sin_combined
=
torch
.
cat
([
sin_h
,
sin_w
],
dim
=-
1
)
return
cos_combined
,
sin_combined
def
fast_pos_embed_interpolate
(
self
,
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
def
fast_pos_embed_interpolate
(
self
,
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
num_grid_per_side
=
self
.
num_grid_per_side
num_grid_per_side
=
self
.
num_grid_per_side
...
@@ -547,8 +565,13 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -547,8 +565,13 @@ class Qwen3_VisionTransformer(nn.Module):
pos_embeds
=
self
.
fast_pos_embed_interpolate
(
grid_thw_list
)
pos_embeds
=
self
.
fast_pos_embed_interpolate
(
grid_thw_list
)
hidden_states
=
hidden_states
+
pos_embeds
hidden_states
=
hidden_states
+
pos_embeds
rotary_pos_emb
=
self
.
rot_pos_emb
(
grid_thw_list
)
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
self
.
rot_pos_emb
(
grid_thw_list
)
rotary_pos_emb
=
rotary_pos_emb
.
to
(
hidden_states
.
device
,
non_blocking
=
True
)
rotary_pos_emb_cos
=
rotary_pos_emb_cos
.
to
(
hidden_states
.
device
,
non_blocking
=
True
)
rotary_pos_emb_sin
=
rotary_pos_emb_sin
.
to
(
hidden_states
.
device
,
non_blocking
=
True
)
cu_seqlens
=
torch
.
repeat_interleave
(
cu_seqlens
=
torch
.
repeat_interleave
(
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]
...
@@ -564,7 +587,8 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -564,7 +587,8 @@ class Qwen3_VisionTransformer(nn.Module):
hidden_states
=
blk
(
hidden_states
=
blk
(
hidden_states
,
hidden_states
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_emb_cos
=
rotary_pos_emb_cos
,
rotary_pos_emb_sin
=
rotary_pos_emb_sin
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
seqlens
=
seqlens
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment