Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f39afa4a
Commit
f39afa4a
authored
Oct 27, 2025
by
cx
Browse files
添加qwen3-omni支持
parent
a9c37628
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
2032 additions
and
1115 deletions
+2032
-1115
vllm/model_executor/layers/rotary_embedding/mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py
+91
-1010
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+74
-11
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+1724
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+63
-69
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+36
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+31
-14
vllm/version.py
vllm/version.py
+12
-11
No files found.
vllm/model_executor/layers/rotary_embedding/mrope.py
View file @
f39afa4a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
from
typing
import
Optional
,
Union
import
numpy
as
np
import
torch
from
transformers
import
PretrainedConfig
from
vllm.triton_utils
import
tl
,
triton
...
...
@@ -62,10 +55,8 @@ def _triton_mrope_forward(
# Updated offsets for half head_dim
cos_offsets
=
tl
.
arange
(
0
,
pad_hd
//
2
)
if
is_interleaved
:
h_mask
=
(((
cos_offsets
%
3
)
==
1
)
&
(
cos_offsets
<=
3
*
mrope_section_h
))
w_mask
=
(((
cos_offsets
%
3
)
==
2
)
&
(
cos_offsets
<=
3
*
mrope_section_w
))
h_mask
=
((
cos_offsets
%
3
)
==
1
)
&
(
cos_offsets
<=
3
*
mrope_section_h
)
w_mask
=
((
cos_offsets
%
3
)
==
2
)
&
(
cos_offsets
<=
3
*
mrope_section_w
)
t_mask
=
~
(
h_mask
|
w_mask
)
else
:
t_end
=
mrope_section_t
...
...
@@ -89,21 +80,25 @@ def _triton_mrope_forward(
# program instance (i.e. for the current token) separately
# ####################################################################
# left half of the head
first_half_q_offsets
=
tl
.
arange
(
0
,
pad_n_qh
)[:,
None
]
*
hd
+
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
first_half_k_offsets
=
tl
.
arange
(
0
,
pad_n_kh
)[:,
None
]
*
hd
+
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
first_q_mask
=
(
tl
.
arange
(
0
,
pad_n_qh
)[:,
None
]
<
n_qh
)
&
(
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
<
rd
//
2
)
first_k_mask
=
(
tl
.
arange
(
0
,
pad_n_kh
)[:,
None
]
<
n_kh
)
&
(
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
<
rd
//
2
)
first_half_q_offsets
=
(
tl
.
arange
(
0
,
pad_n_qh
)[:,
None
]
*
hd
+
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
)
first_half_k_offsets
=
(
tl
.
arange
(
0
,
pad_n_kh
)[:,
None
]
*
hd
+
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
)
first_q_mask
=
(
tl
.
arange
(
0
,
pad_n_qh
)[:,
None
]
<
n_qh
)
&
(
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
<
rd
//
2
)
first_k_mask
=
(
tl
.
arange
(
0
,
pad_n_kh
)[:,
None
]
<
n_kh
)
&
(
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
<
rd
//
2
)
q_tile_1
=
tl
.
load
(
q_ptr
+
first_half_q_offsets
,
mask
=
first_q_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
k_tile_1
=
tl
.
load
(
k_ptr
+
first_half_k_offsets
,
mask
=
first_k_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
q_tile_1
=
tl
.
load
(
q_ptr
+
first_half_q_offsets
,
mask
=
first_q_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
k_tile_1
=
tl
.
load
(
k_ptr
+
first_half_k_offsets
,
mask
=
first_k_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
# right half of the head
second_half_q_offsets
=
first_half_q_offsets
+
(
rd
//
2
)
...
...
@@ -111,12 +106,12 @@ def _triton_mrope_forward(
second_q_mask
=
first_q_mask
second_k_mask
=
first_k_mask
q_tile_2
=
tl
.
load
(
q_ptr
+
second_half_q_offsets
,
mask
=
second_q_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
k_tile_2
=
tl
.
load
(
k_ptr
+
second_half_k_offsets
,
mask
=
second_k_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
q_tile_2
=
tl
.
load
(
q_ptr
+
second_half_q_offsets
,
mask
=
second_q_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
k_tile_2
=
tl
.
load
(
k_ptr
+
second_half_k_offsets
,
mask
=
second_k_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
# Since cos and sin are now half-size,
...
...
@@ -168,7 +163,7 @@ def triton_mrope(
cos
=
cos
.
contiguous
()
sin
=
sin
.
contiguous
()
_triton_mrope_forward
[(
n_row
,
)](
_triton_mrope_forward
[(
n_row
,)](
q
,
k
,
cos
,
...
...
@@ -189,15 +184,14 @@ def triton_mrope(
return
q
,
k
def
apply_interleaved_rope
(
x
:
torch
.
Tensor
,
mrope_section
:
list
[
int
])
->
torch
.
Tensor
:
def
apply_interleaved_rope
(
x
:
torch
.
Tensor
,
mrope_section
:
list
[
int
])
->
torch
.
Tensor
:
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
"""
x_t
=
x
[
0
].
clone
()
x_t
[...,
1
:
mrope_section
[
1
]
*
3
:
3
]
=
x
[
1
,
...,
1
:
mrope_section
[
1
]
*
3
:
3
]
x_t
[...,
2
:
mrope_section
[
2
]
*
3
:
3
]
=
x
[
2
,
...,
2
:
mrope_section
[
2
]
*
3
:
3
]
x_t
[...,
1
:
mrope_section
[
1
]
*
3
:
3
]
=
x
[
1
,
...,
1
:
mrope_section
[
1
]
*
3
:
3
]
x_t
[...,
2
:
mrope_section
[
2
]
*
3
:
3
]
=
x
[
2
,
...,
2
:
mrope_section
[
2
]
*
3
:
3
]
return
x_t
...
...
@@ -212,17 +206,16 @@ class MRotaryEmbedding(RotaryEmbedding):
base
:
float
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
mrope_section
:
Optional
[
list
[
int
]
]
=
None
,
mrope_section
:
list
[
int
]
|
None
=
None
,
mrope_interleaved
:
bool
=
False
,
# YaRN parameters.
*
,
scaling_factor
:
Optional
[
float
]
=
None
,
scaling_factor
:
float
|
None
=
None
,
extrapolation_factor
:
float
=
1
,
attn_factor
:
float
=
1
,
beta_fast
:
int
=
32
,
beta_slow
:
int
=
1
,
)
->
None
:
self
.
scaling_factor
=
scaling_factor
self
.
extrapolation_factor
=
extrapolation_factor
self
.
attn_factor
=
attn_factor
...
...
@@ -230,8 +223,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self
.
beta_slow
=
beta_slow
if
self
.
scaling_factor
is
not
None
:
# Get n-d magnitude scaling corrected for interpolation
self
.
mscale
=
float
(
yarn_get_mscale
(
self
.
scaling_factor
)
*
attn_factor
)
self
.
mscale
=
float
(
yarn_get_mscale
(
self
.
scaling_factor
)
*
attn_factor
)
else
:
self
.
mscale
=
1.0
...
...
@@ -239,8 +231,14 @@ class MRotaryEmbedding(RotaryEmbedding):
# the input video. We enlarge max_position_embeddings to 4 times to get
# a larger the cos and sin cache.
self
.
cache_max_position_num
=
max_position_embeddings
*
4
super
().
__init__
(
head_size
,
rotary_dim
,
self
.
cache_max_position_num
,
base
,
is_neox_style
,
dtype
)
super
().
__init__
(
head_size
,
rotary_dim
,
self
.
cache_max_position_num
,
base
,
is_neox_style
,
dtype
,
)
self
.
mrope_section
=
mrope_section
self
.
mrope_interleaved
=
mrope_interleaved
...
...
@@ -261,9 +259,9 @@ class MRotaryEmbedding(RotaryEmbedding):
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]:
key
:
torch
.
Tensor
|
None
=
None
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
"""PyTorch-native implementation equivalent to forward().
Args:
...
...
@@ -286,31 +284,27 @@ class MRotaryEmbedding(RotaryEmbedding):
cos
=
apply_interleaved_rope
(
cos
,
self
.
mrope_section
)
sin
=
apply_interleaved_rope
(
sin
,
self
.
mrope_section
)
else
:
cos
=
torch
.
cat
([
m
[
i
]
for
i
,
m
in
enumerate
(
cos
.
split
(
self
.
mrope_section
,
dim
=-
1
))
],
dim
=-
1
)
sin
=
torch
.
cat
([
m
[
i
]
for
i
,
m
in
enumerate
(
sin
.
split
(
self
.
mrope_section
,
dim
=-
1
))
],
dim
=-
1
)
cos
=
torch
.
cat
(
[
m
[
i
]
for
i
,
m
in
enumerate
(
cos
.
split
(
self
.
mrope_section
,
dim
=-
1
))],
dim
=-
1
,
)
sin
=
torch
.
cat
(
[
m
[
i
]
for
i
,
m
in
enumerate
(
sin
.
split
(
self
.
mrope_section
,
dim
=-
1
))],
dim
=-
1
,
)
query_shape
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
@@ -318,10 +312,9 @@ class MRotaryEmbedding(RotaryEmbedding):
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
key
:
torch
.
Tensor
|
None
=
None
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
assert
positions
.
ndim
==
1
or
positions
.
ndim
==
2
assert
key
is
not
None
...
...
@@ -348,17 +341,15 @@ class MRotaryEmbedding(RotaryEmbedding):
return
q
.
reshape
(
query_shape
),
k
.
reshape
(
key_shape
)
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
@@ -366,886 +357,20 @@ class MRotaryEmbedding(RotaryEmbedding):
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]:
key
:
torch
.
Tensor
|
None
=
None
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
)
def
forward_cpu
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]:
key
:
torch
.
Tensor
|
None
=
None
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
)
@
classmethod
def
get_input_positions
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Optional
[
Union
[
list
[
list
[
int
]],
torch
.
Tensor
]],
video_grid_thw
:
Optional
[
Union
[
list
[
list
[
int
]],
torch
.
Tensor
]],
second_per_grid_ts
:
Optional
[
list
[
float
]],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
audio_feature_lengths
:
Optional
[
torch
.
Tensor
]
=
None
,
use_audio_in_video
:
bool
=
False
,
)
->
tuple
[
list
[
list
[
int
]],
int
]:
"""Get mrope input positions and delta value."""
image_grid_thw
=
[]
if
image_grid_thw
is
None
else
image_grid_thw
video_grid_thw
=
[]
if
video_grid_thw
is
None
else
video_grid_thw
second_per_grid_ts
=
[]
if
second_per_grid_ts
is
None
else
\
second_per_grid_ts
llm_positions
,
mrope_position_delta
=
\
cls
.
get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
second_per_grid_ts
=
second_per_grid_ts
,
context_len
=
context_len
,
seq_len
=
seq_len
,
audio_feature_lengths
=
audio_feature_lengths
,
use_audio_in_video
=
use_audio_in_video
,
)
return
llm_positions
.
tolist
(),
mrope_position_delta
@
classmethod
def
get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
second_per_grid_ts
:
list
[
float
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
audio_feature_lengths
:
Optional
[
torch
.
Tensor
]
=
None
,
use_audio_in_video
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
from
vllm.transformers_utils.config
import
thinker_uses_mrope
if
thinker_uses_mrope
(
hf_config
):
return
cls
.
_omni_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
second_per_grid_ts
=
second_per_grid_ts
,
context_len
=
context_len
,
seq_len
=
seq_len
,
audio_feature_lengths
=
audio_feature_lengths
,
use_audio_in_video
=
use_audio_in_video
,
)
elif
hf_config
.
model_type
in
[
"glm4v"
,
"glm4v_moe"
]:
return
cls
.
_glm4v_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
context_len
=
context_len
,
seq_len
=
seq_len
,
)
elif
hf_config
.
model_type
in
[
"qwen3_vl"
,
"qwen3_vl_moe"
]:
return
cls
.
_qwen3vl_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
context_len
=
context_len
,
seq_len
=
seq_len
,
)
elif
hf_config
.
model_type
in
[
"ernie4_5_moe_vl"
,
"ernie4_5_vl"
]:
return
cls
.
_ernie_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
context_len
=
context_len
,
seq_len
=
seq_len
,
)
elif
"KeyeVL1_5"
in
hf_config
.
model_type
:
return
cls
.
_keye_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
context_len
=
context_len
,
seq_len
=
seq_len
,
)
else
:
return
cls
.
_vl_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
second_per_grid_ts
=
second_per_grid_ts
,
context_len
=
context_len
,
seq_len
=
seq_len
,
)
@
classmethod
def
_glm4v_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
"""Get mrope input positions and delta value for GLM4V."""
image_token_id
=
hf_config
.
image_token_id
video_start_token_id
=
hf_config
.
video_start_token_id
video_end_token_id
=
hf_config
.
video_end_token_id
spatial_merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
llm_pos_ids_list
:
list
=
[]
if
not
(
image_grid_thw
is
None
and
video_grid_thw
is
None
):
if
isinstance
(
image_grid_thw
,
torch
.
Tensor
):
image_grid_thw
=
image_grid_thw
.
tolist
()
input_token_type
:
list
[
str
]
=
[]
video_check_flg
=
False
for
token
in
input_tokens
:
if
token
==
video_start_token_id
:
video_check_flg
=
True
elif
token
==
video_end_token_id
:
video_check_flg
=
False
if
(
token
==
image_token_id
)
and
(
video_check_flg
is
False
):
input_token_type
.
append
(
"image"
)
elif
(
token
==
image_token_id
)
and
(
video_check_flg
is
True
):
input_token_type
.
append
(
"video"
)
else
:
input_token_type
.
append
(
"text"
)
input_type_group
:
list
[
tuple
[
str
,
int
,
int
]]
=
[]
for
key
,
group_iter
in
itertools
.
groupby
(
enumerate
(
input_token_type
),
lambda
x
:
x
[
1
]):
group_list
=
list
(
group_iter
)
start_index
=
group_list
[
0
][
0
]
end_index
=
group_list
[
-
1
][
0
]
+
1
input_type_group
.
append
((
key
,
start_index
,
end_index
))
video_frame_num
=
1
mm_data_idx
=
0
for
modality_type
,
start_idx
,
end_idx
in
input_type_group
:
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
if
modality_type
==
"image"
:
t
,
h
,
w
=
(
image_grid_thw
[
mm_data_idx
][
0
],
image_grid_thw
[
mm_data_idx
][
1
],
image_grid_thw
[
mm_data_idx
][
2
],
)
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
t_index
=
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
st_idx
)
mm_data_idx
+=
1
elif
modality_type
==
"video"
:
t
,
h
,
w
=
(
video_frame_num
,
image_grid_thw
[
mm_data_idx
][
1
],
image_grid_thw
[
mm_data_idx
][
2
],
)
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
for
t_idx
in
range
(
llm_grid_t
):
t_index
=
torch
.
tensor
(
t_idx
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
1
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
1
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
st_idx
)
mm_data_idx
+=
1
video_frame_num
+=
1
else
:
text_len
=
end_idx
-
start_idx
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
video_frame_num
=
1
else
:
text_len
=
len
(
input_tokens
)
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
))
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
return
llm_positions
,
mrope_position_delta
@
classmethod
def
_qwen3vl_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
"""Get mrope input positions and delta value."""
video_grid_thw
=
[[
1
,
h
,
w
]
for
t
,
h
,
w
in
video_grid_thw
for
_
in
range
(
t
)]
image_token_id
=
hf_config
.
image_token_id
video_token_id
=
hf_config
.
video_token_id
vision_start_token_id
=
hf_config
.
vision_start_token_id
spatial_merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
)
vision_start_indices
=
torch
.
argwhere
(
input_tokens_tensor
==
vision_start_token_id
).
squeeze
(
1
)
vision_tokens
=
input_tokens_tensor
[
vision_start_indices
+
1
]
image_nums
=
(
vision_tokens
==
image_token_id
).
sum
()
video_nums
=
(
vision_tokens
==
video_token_id
).
sum
()
llm_pos_ids_list
:
list
=
[]
st
=
0
remain_images
,
remain_videos
=
image_nums
,
video_nums
image_index
,
video_index
=
0
,
0
for
_
in
range
(
image_nums
+
video_nums
):
if
image_token_id
in
input_tokens
and
remain_images
>
0
:
ed_image
=
input_tokens
.
index
(
image_token_id
,
st
)
else
:
ed_image
=
len
(
input_tokens
)
+
1
if
video_token_id
in
input_tokens
and
remain_videos
>
0
:
ed_video
=
input_tokens
.
index
(
video_token_id
,
st
)
else
:
ed_video
=
len
(
input_tokens
)
+
1
if
ed_image
<
ed_video
:
t
,
h
,
w
=
(
image_grid_thw
[
image_index
][
0
],
image_grid_thw
[
image_index
][
1
],
image_grid_thw
[
image_index
][
2
],
)
image_index
+=
1
remain_images
-=
1
ed
=
ed_image
else
:
t
,
h
,
w
=
(
video_grid_thw
[
video_index
][
0
],
video_grid_thw
[
video_index
][
1
],
video_grid_thw
[
video_index
][
2
],
)
video_index
+=
1
remain_videos
-=
1
ed
=
ed_video
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
text_len
=
ed
-
st
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
text_len
+
st_idx
)
st
=
ed
+
llm_grid_t
*
llm_grid_h
*
llm_grid_w
if
st
<
len
(
input_tokens
):
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
text_len
=
len
(
input_tokens
)
-
st
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
return
llm_positions
,
mrope_position_delta
@
classmethod
def
_ernie_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
"""Get mrope input positions and delta value for Ernie VL."""
image_token_id
=
hf_config
.
im_patch_id
video_start_token_id
=
hf_config
.
video_start_token_id
video_end_token_id
=
hf_config
.
video_end_token_id
spatial_conv_size
=
hf_config
.
spatial_conv_size
temporal_conv_size
=
hf_config
.
temporal_conv_size
llm_pos_ids_list
:
list
=
[]
if
not
(
image_grid_thw
is
None
and
video_grid_thw
is
None
):
if
isinstance
(
image_grid_thw
,
torch
.
Tensor
):
image_grid_thw
=
image_grid_thw
.
tolist
()
input_token_type
:
list
[
str
]
=
[]
video_check_flg
=
False
for
token
in
input_tokens
:
if
token
==
video_start_token_id
:
video_check_flg
=
True
elif
token
==
video_end_token_id
:
video_check_flg
=
False
if
(
token
==
image_token_id
)
and
(
video_check_flg
is
False
):
input_token_type
.
append
(
"image"
)
elif
(
token
==
image_token_id
)
and
(
video_check_flg
is
True
):
input_token_type
.
append
(
"video"
)
else
:
input_token_type
.
append
(
"text"
)
input_type_group
:
list
[
tuple
[
str
,
int
,
int
]]
=
[]
for
key
,
group_iter
in
itertools
.
groupby
(
enumerate
(
input_token_type
),
lambda
x
:
x
[
1
]):
group_list
=
list
(
group_iter
)
start_index
=
group_list
[
0
][
0
]
end_index
=
group_list
[
-
1
][
0
]
+
1
input_type_group
.
append
((
key
,
start_index
,
end_index
))
video_frame_num
=
1
mm_data_idx
=
0
for
modality_type
,
start_idx
,
end_idx
in
input_type_group
:
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
if
modality_type
==
"image"
:
t
,
h
,
w
=
(
image_grid_thw
[
mm_data_idx
][
0
],
image_grid_thw
[
mm_data_idx
][
1
],
image_grid_thw
[
mm_data_idx
][
2
],
)
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_conv_size
,
w
//
spatial_conv_size
t_index
=
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
st_idx
)
mm_data_idx
+=
1
elif
modality_type
==
"video"
:
t
,
h
,
w
=
(
video_grid_thw
[
mm_data_idx
][
0
],
video_grid_thw
[
mm_data_idx
][
1
],
video_grid_thw
[
mm_data_idx
][
2
],
)
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
(
t
//
temporal_conv_size
,
h
//
spatial_conv_size
,
w
//
spatial_conv_size
)
for
t_idx
in
range
(
llm_grid_t
):
t_index
=
torch
.
tensor
(
t_idx
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
1
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
1
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
st_idx
)
mm_data_idx
+=
1
video_frame_num
+=
1
else
:
text_len
=
end_idx
-
start_idx
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
video_frame_num
=
1
else
:
text_len
=
len
(
input_tokens
)
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
))
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
return
llm_positions
,
mrope_position_delta
@
classmethod
def
_keye_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
if
isinstance
(
video_grid_thw
,
list
)
and
len
(
video_grid_thw
)
>
0
:
video_grid_thw
=
video_grid_thw
[
0
]
"""Get mrope input positions and delta value (Keye series)."""
def
split_thw
(
grid_thw
:
Union
[
torch
.
Tensor
,
list
[
int
]])
->
list
[
list
[
int
]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if
isinstance
(
grid_thw
,
list
):
grid_thw
=
torch
.
tensor
(
grid_thw
,
dtype
=
torch
.
long
)
if
grid_thw
.
numel
()
==
0
:
return
[]
t
,
hw
=
grid_thw
[:,
0
],
grid_thw
[:,
1
:]
ones
=
torch
.
ones_like
(
hw
[:,
:
1
])
# [N,1]
out
=
torch
.
cat
([
ones
,
hw
],
dim
=
1
).
repeat_interleave
(
t
,
dim
=
0
)
return
out
.
tolist
()
video_grid_thw
=
split_thw
(
video_grid_thw
)
image_token_id
=
hf_config
.
image_token_id
video_token_id
=
hf_config
.
video_token_id
spatial_merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
image_nums
=
len
(
image_grid_thw
)
frame_nums
=
len
(
video_grid_thw
)
llm_pos_ids_list
:
list
=
[]
st
=
0
remain_images
,
remain_frames
=
image_nums
,
frame_nums
image_index
,
video_index
=
0
,
0
for
_
in
range
(
image_nums
+
frame_nums
):
if
remain_images
>
0
:
try
:
ed_image
=
input_tokens
.
index
(
image_token_id
,
st
)
except
ValueError
:
ed_image
=
len
(
input_tokens
)
+
1
else
:
ed_image
=
len
(
input_tokens
)
+
1
if
remain_frames
>
0
:
try
:
ed_video
=
input_tokens
.
index
(
video_token_id
,
st
)
except
ValueError
:
ed_video
=
len
(
input_tokens
)
+
1
else
:
ed_video
=
len
(
input_tokens
)
+
1
if
ed_image
<
ed_video
:
t
,
h
,
w
=
(
image_grid_thw
[
image_index
][
0
],
image_grid_thw
[
image_index
][
1
],
image_grid_thw
[
image_index
][
2
],
)
image_index
+=
1
remain_images
-=
1
ed
=
ed_image
else
:
t
,
h
,
w
=
(
video_grid_thw
[
video_index
][
0
],
video_grid_thw
[
video_index
][
1
],
video_grid_thw
[
video_index
][
2
],
)
video_index
+=
1
remain_frames
-=
1
ed
=
ed_video
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
text_len
=
ed
-
st
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
)).
long
().
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
text_len
+
st_idx
)
st
=
ed
+
llm_grid_t
*
llm_grid_h
*
llm_grid_w
if
st
<
len
(
input_tokens
):
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
text_len
=
len
(
input_tokens
)
-
st
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
return
llm_positions
,
mrope_position_delta
@
classmethod
def
_vl_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
second_per_grid_ts
:
list
[
float
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
"""Get mrope input positions and delta value."""
image_token_id
=
hf_config
.
image_token_id
video_token_id
=
hf_config
.
video_token_id
vision_start_token_id
=
hf_config
.
vision_start_token_id
spatial_merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
tokens_per_second
=
getattr
(
hf_config
.
vision_config
,
"tokens_per_second"
,
1.0
)
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
)
vision_start_indices
=
torch
.
argwhere
(
input_tokens_tensor
==
vision_start_token_id
).
squeeze
(
1
)
vision_tokens
=
input_tokens_tensor
[
vision_start_indices
+
1
]
image_nums
=
(
vision_tokens
==
image_token_id
).
sum
()
video_nums
=
(
vision_tokens
==
video_token_id
).
sum
()
llm_pos_ids_list
:
list
=
[]
st
=
0
remain_images
,
remain_videos
=
image_nums
,
video_nums
image_index
,
video_index
=
0
,
0
for
_
in
range
(
image_nums
+
video_nums
):
video_second_per_grid_t
=
0.0
if
remain_images
>
0
:
try
:
ed_image
=
input_tokens
.
index
(
image_token_id
,
st
)
except
ValueError
:
ed_image
=
len
(
input_tokens
)
+
1
else
:
ed_image
=
len
(
input_tokens
)
+
1
if
remain_videos
>
0
:
try
:
ed_video
=
input_tokens
.
index
(
video_token_id
,
st
)
except
ValueError
:
ed_video
=
len
(
input_tokens
)
+
1
else
:
ed_video
=
len
(
input_tokens
)
+
1
if
ed_image
<
ed_video
:
t
,
h
,
w
=
(
image_grid_thw
[
image_index
][
0
],
image_grid_thw
[
image_index
][
1
],
image_grid_thw
[
image_index
][
2
],
)
image_index
+=
1
remain_images
-=
1
ed
=
ed_image
else
:
t
,
h
,
w
=
(
video_grid_thw
[
video_index
][
0
],
video_grid_thw
[
video_index
][
1
],
video_grid_thw
[
video_index
][
2
],
)
video_second_per_grid_t
=
1.0
if
second_per_grid_ts
:
video_second_per_grid_t
=
second_per_grid_ts
[
video_index
]
video_index
+=
1
remain_videos
-=
1
ed
=
ed_video
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
text_len
=
ed
-
st
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
)
*
video_second_per_grid_t
*
tokens_per_second
).
long
().
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
text_len
+
st_idx
)
st
=
ed
+
llm_grid_t
*
llm_grid_h
*
llm_grid_w
if
st
<
len
(
input_tokens
):
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
text_len
=
len
(
input_tokens
)
-
st
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
return
llm_positions
,
mrope_position_delta
@
classmethod
def
_omni_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
second_per_grid_ts
:
Optional
[
list
[
float
]]
=
None
,
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
audio_feature_lengths
:
Optional
[
torch
.
Tensor
]
=
None
,
use_audio_in_video
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
Differences from MRotaryEmbedding:
1. Add audio support (and related `audio_feature_lengths`).
2. Add `use_audio_in_video` option to read audio from video inputs.
In this case, audio and vision position ids will be split into
chunks and interleaved.
Example:
(V_i are vision position ids, A_i are audio position ids)
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
"""
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
thinker_config
=
hf_config
.
thinker_config
audio_token_id
=
thinker_config
.
audio_token_index
image_token_id
=
thinker_config
.
image_token_index
video_token_id
=
thinker_config
.
video_token_index
audio_start_token_id
=
thinker_config
.
audio_start_token_id
audio_end_token_id
=
thinker_config
.
audio_end_token_id
vision_start_token_id
=
thinker_config
.
vision_start_token_id
vision_end_token_id
=
thinker_config
.
vision_end_token_id
seconds_per_chunk
=
thinker_config
.
seconds_per_chunk
spatial_merge_size
=
thinker_config
.
vision_config
.
spatial_merge_size
tokens_per_second
=
getattr
(
thinker_config
.
vision_config
,
"tokens_per_second"
,
25
)
if
isinstance
(
image_grid_thw
,
list
):
image_grid_thw
=
torch
.
tensor
(
image_grid_thw
)
if
isinstance
(
video_grid_thw
,
list
):
video_grid_thw
=
torch
.
tensor
(
video_grid_thw
)
src_item
=
input_tokens
audio_seqlens
=
audio_feature_lengths
if
not
second_per_grid_ts
:
second_per_grid_ts
=
[
1
]
*
video_grid_thw
.
shape
[
0
]
audio_idx
=
0
video_idx
=
0
image_idx
=
0
new_src_item
:
list
[
int
]
=
[]
llm_pos_ids_list
:
list
[
torch
.
Tensor
]
=
[]
idx
=
0
while
idx
<
len
(
src_item
):
new_src_item_len
=
len
(
new_src_item
)
start_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
if
src_item
[
idx
]
not
in
[
audio_token_id
,
video_token_id
,
image_token_id
]:
if
use_audio_in_video
and
idx
>
0
:
if
src_item
[
idx
]
==
vision_end_token_id
and
\
src_item
[
idx
-
1
]
==
audio_end_token_id
:
# processing the <|audio_eos|> before <|vision_eos|>
start_idx
-=
1
elif
src_item
[
idx
]
==
audio_start_token_id
and
\
src_item
[
idx
-
1
]
==
vision_start_token_id
:
# processing the <|audio_bos|> after <|vision_eos|>
start_idx
-=
1
new_src_item
.
append
(
src_item
[
idx
])
llm_pos_ids
=
torch
.
tensor
([
start_idx
],
dtype
=
torch
.
long
).
expand
(
3
,
-
1
)
llm_pos_ids_list
.
append
(
llm_pos_ids
)
elif
src_item
[
idx
]
==
audio_token_id
:
assert
audio_seqlens
is
not
None
audio_seqlen
=
audio_seqlens
[
audio_idx
]
place_num
=
(((
audio_seqlen
-
1
)
//
2
+
1
-
2
)
//
2
+
1
)
new_src_item
.
extend
([
audio_token_id
]
*
place_num
)
llm_pos_ids
=
torch
.
arange
(
place_num
).
expand
(
3
,
-
1
)
+
start_idx
llm_pos_ids_list
.
append
(
llm_pos_ids
)
audio_idx
+=
1
elif
src_item
[
idx
]
==
image_token_id
:
grid_t
=
image_grid_thw
[
image_idx
][
0
]
grid_hs
=
image_grid_thw
[:,
1
]
grid_ws
=
image_grid_thw
[:,
2
]
t_index
=
(
torch
.
arange
(
grid_t
)
*
1
*
tokens_per_second
).
long
()
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
start_idx
,
image_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
llm_pos_ids_list
.
append
(
llm_pos_ids
)
vision_seqlen
=
image_grid_thw
[
image_idx
].
prod
()
//
(
spatial_merge_size
**
2
)
new_src_item
.
extend
([
image_token_id
]
*
vision_seqlen
)
image_idx
+=
1
elif
src_item
[
idx
]
==
video_token_id
and
not
use_audio_in_video
:
grid_t
=
video_grid_thw
[
video_idx
][
0
]
grid_hs
=
video_grid_thw
[:,
1
]
grid_ws
=
video_grid_thw
[:,
2
]
t_index
=
(
torch
.
arange
(
grid_t
)
*
second_per_grid_ts
[
video_idx
]
*
tokens_per_second
).
long
()
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
start_idx
,
video_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
llm_pos_ids_list
.
append
(
llm_pos_ids
)
vision_seqlen
=
video_grid_thw
[
video_idx
].
prod
()
//
(
spatial_merge_size
**
2
)
new_src_item
.
extend
([
video_token_id
]
*
vision_seqlen
)
video_idx
+=
1
else
:
# read audio from video
assert
audio_seqlens
is
not
None
audio_seqlen
=
audio_seqlens
[
audio_idx
]
vision_seqlen
=
video_grid_thw
[
video_idx
].
prod
()
//
(
spatial_merge_size
**
2
)
grid_t
=
video_grid_thw
[
video_idx
][
0
]
grid_h
=
video_grid_thw
[
video_idx
][
1
]
grid_w
=
video_grid_thw
[
video_idx
][
2
]
grid_hs
=
video_grid_thw
[:,
1
]
grid_ws
=
video_grid_thw
[:,
2
]
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_index
=
(
torch
.
arange
(
grid_t
)
*
second_per_grid_ts
[
video_idx
]
*
tokens_per_second
).
long
()
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index
,
t_ntoken_per_chunk
)
place_num
=
(((
audio_seqlen
-
1
)
//
2
+
1
-
2
)
//
2
+
1
)
+
2
pure_audio_len
=
place_num
-
2
added_audio_len
=
0
audio_llm_pos_ids_list
:
list
[
torch
.
Tensor
]
=
[]
for
t_chunk
in
t_index_split_chunk
:
vision_ntoken_per_chunk
=
len
(
t_chunk
)
*
grid_h
*
grid_w
//
(
spatial_merge_size
**
2
)
new_src_item
.
extend
([
video_token_id
]
*
vision_ntoken_per_chunk
)
vision_llm_pos_ids_list
=
cls
.
_get_llm_pos_ids_for_vision
(
start_idx
,
video_idx
,
spatial_merge_size
,
t_chunk
,
grid_hs
,
grid_ws
).
split
(
1
,
dim
=
1
)
llm_pos_ids_list
.
extend
(
vision_llm_pos_ids_list
)
new_src_item
.
extend
(
min
(
t_ntoken_per_chunk
,
pure_audio_len
-
added_audio_len
)
*
[
audio_token_id
])
audio_start_idx
=
start_idx
if
len
(
audio_llm_pos_ids_list
)
==
0
else
audio_llm_pos_ids_list
[
-
1
][
0
].
item
()
+
1
if
min
(
t_ntoken_per_chunk
,
pure_audio_len
-
added_audio_len
)
>
0
:
audio_llm_pos_ids_list
=
(
torch
.
arange
(
min
(
t_ntoken_per_chunk
,
pure_audio_len
-
added_audio_len
)).
expand
(
3
,
-
1
)
+
audio_start_idx
).
split
(
1
,
dim
=
1
)
else
:
audio_llm_pos_ids_list
=
[]
added_audio_len
+=
min
(
t_ntoken_per_chunk
,
pure_audio_len
-
added_audio_len
)
llm_pos_ids_list
.
extend
(
audio_llm_pos_ids_list
)
if
added_audio_len
<
pure_audio_len
:
new_src_item
.
extend
(
(
pure_audio_len
-
added_audio_len
)
*
[
audio_token_id
])
audio_llm_pos_ids_list
=
(
torch
.
arange
(
pure_audio_len
-
added_audio_len
).
expand
(
3
,
-
1
)
+
llm_pos_ids_list
[
-
1
].
max
()
+
1
).
split
(
1
,
dim
=
1
)
llm_pos_ids_list
.
extend
(
audio_llm_pos_ids_list
)
audio_idx
+=
1
video_idx
+=
1
# move to the next token
idx
+=
len
(
new_src_item
)
-
new_src_item_len
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
)
mrope_position_delta
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
max
()
+
1
-
len
(
src_item
)
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
return
llm_positions
,
mrope_position_delta
@
staticmethod
def
_get_llm_pos_ids_for_vision
(
start_idx
:
int
,
vision_idx
:
int
,
spatial_merge_size
:
int
,
t_index
:
list
[
int
],
grid_hs
:
torch
.
Tensor
,
grid_ws
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
llm_pos_ids_list
=
[]
llm_grid_h
=
grid_hs
[
vision_idx
]
//
spatial_merge_size
llm_grid_w
=
grid_ws
[
vision_idx
]
//
spatial_merge_size
h_index
=
(
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
len
(
t_index
),
-
1
,
llm_grid_w
).
flatten
())
w_index
=
(
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
len
(
t_index
),
llm_grid_h
,
-
1
).
flatten
())
t_index_tensor
=
torch
.
Tensor
(
t_index
).
to
(
llm_grid_h
.
device
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
long
().
flatten
()
_llm_pos_ids
=
torch
.
stack
([
t_index_tensor
,
h_index
,
w_index
])
llm_pos_ids_list
.
append
(
_llm_pos_ids
+
start_idx
)
llm_pos_ids
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
)
return
llm_pos_ids
@
staticmethod
def
_split_list_into_ranges
(
lst
:
torch
.
Tensor
,
interval
:
int
)
->
list
[
list
[
int
]]:
ranges
:
list
[
list
[
int
]]
=
[[]
for
_
in
range
((
max
(
lst
)
//
interval
)
+
1
)]
for
num
in
lst
:
index
=
num
//
interval
ranges
[
index
].
append
(
num
)
return
ranges
@
staticmethod
def
get_next_input_positions
(
mrope_position_delta
:
int
,
...
...
@@ -1254,68 +379,24 @@ class MRotaryEmbedding(RotaryEmbedding):
)
->
list
[
list
[
int
]]:
return
[
list
(
range
(
context_len
+
mrope_position_delta
,
seq_len
+
mrope_position_delta
))
for
_
in
range
(
3
)
range
(
context_len
+
mrope_position_delta
,
seq_len
+
mrope_position_delta
)
)
for
_
in
range
(
3
)
]
@
staticmethod
def
get_next_input_positions_tensor
(
out
:
np
.
ndarray
,
out_offset
:
int
,
mrope_position_delta
:
int
,
context_len
:
int
,
num_new_tokens
:
int
):
values
=
np
.
arange
(
mrope_position_delta
+
context_len
,
mrope_position_delta
+
context_len
+
num_new_tokens
,
dtype
=
out
.
dtype
)
out
[:,
out_offset
:
out_offset
+
num_new_tokens
]
=
values
@classmethod
def omni_get_updates_use_audio_in_video(
    cls,
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: Union[list[int], torch.Tensor],
    video_second_per_grid_t: float,
) -> list[int]:
    """Get video prompt updates when `use_audio_in_video` is True.

    In this case, audio and vision update ids will be split into
    chunks and interleaved (details in `_omni_get_input_positions_tensor`).

    <|video_bos|><|VIDEO|><|video_eos|> =>
    <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>

    Returns:
        Token ids: the audio start token, the interleaved video/audio
        placeholder tokens, any leftover audio placeholders, then the
        audio end token.
    """
    audio_token_id = thinker_config.audio_token_index
    video_token_id = thinker_config.video_token_index
    audio_start_token_id = thinker_config.audio_start_token_id
    audio_end_token_id = thinker_config.audio_end_token_id
    seconds_per_chunk = thinker_config.seconds_per_chunk
    vision_config = thinker_config.vision_config
    spatial_merge_size = vision_config.spatial_merge_size
    tokens_per_second = getattr(vision_config, "tokens_per_second", 25)

    grid_t, grid_h, grid_w = (video_grid_thw[0], video_grid_thw[1],
                              video_grid_thw[2])
    merge_area = spatial_merge_size**2
    chunk_capacity = int(tokens_per_second * seconds_per_chunk)

    frame_positions = (torch.arange(grid_t) * video_second_per_grid_t *
                       tokens_per_second).long()
    frame_chunks = cls._split_list_into_ranges(frame_positions,
                                               chunk_capacity)

    updates = [audio_start_token_id]
    audio_emitted = 0
    for frames in frame_chunks:
        # Video placeholders for the frames that fall into this chunk.
        video_count = len(frames) * grid_h * grid_w // merge_area
        updates.extend([video_token_id] * video_count)

        # Followed by at most one chunk's worth of audio placeholders.
        audio_count = min(chunk_capacity, audio_len - audio_emitted)
        updates.extend(audio_count * [audio_token_id])
        audio_emitted += audio_count

    # Audio longer than the video: flush what is left in one run.
    if audio_emitted < audio_len:
        updates.extend((audio_len - audio_emitted) * [audio_token_id])

    updates.append(audio_end_token_id)
    return updates
def get_next_input_positions_tensor(
    out: np.ndarray,
    out_offset: int,
    mrope_position_delta: int,
    context_len: int,
    num_new_tokens: int,
):
    """Fill a column window of `out` with consecutive position ids.

    Columns `out_offset .. out_offset + num_new_tokens - 1` of every row
    receive the integers beginning at
    `mrope_position_delta + context_len`. `out` is modified in place and
    nothing is returned.
    """
    start = mrope_position_delta + context_len
    stop = start + num_new_tokens
    column_window = slice(out_offset, out_offset + num_new_tokens)
    out[:, column_window] = np.arange(start, stop, dtype=out.dtype)
\ No newline at end of file
vllm/model_executor/models/interfaces.py
View file @
f39afa4a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
,
Mapping
,
MutableSequence
from
collections.abc
import
Iterable
,
Mapping
,
MutableSequence
,
Callable
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Literal
,
Optional
,
Protocol
,
Union
,
overload
,
runtime_checkable
)
...
...
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.utils
import
supports_kw
from
.interfaces_base
import
is_pooling_model
from
.interfaces_base
import
is_pooling_model
,
VllmModel
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
...
...
@@ -81,10 +81,9 @@ class SupportsMultiModal(Protocol):
"""
...
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
"""
Returns multimodal embeddings generated from multimodal kwargs
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
Note:
...
...
@@ -94,11 +93,11 @@ class SupportsMultiModal(Protocol):
"""
...
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
VllmModel
:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
...
...
@@ -106,19 +105,83 @@ class SupportsMultiModal(Protocol):
"""
...
@overload
def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
    # Typing-only signature: token IDs alone, no multimodal embeddings.
    ...

@overload
def get_input_embeddings(
    self,
    input_ids: Tensor,
    multimodal_embeddings: MultiModalEmbeddings,
    *,
    is_multimodal: torch.Tensor,
    handle_oov_mm_token: bool = False,
) -> Tensor:
    # Typing-only signature: when multimodal embeddings are supplied, the
    # `is_multimodal` mask must be passed by keyword as well.
    ...
def
_get_text_embeddings
(
self
,
input_ids
:
Tensor
,
get_input_embeddings
:
Callable
[[
Tensor
],
Tensor
],
*
,
is_multimodal
:
Optional
[
Tensor
],
handle_oov_mm_token
:
bool
,
)
->
Tensor
:
if
handle_oov_mm_token
and
is_multimodal
is
not
None
:
is_text
=
~
is_multimodal
text_embeds
=
get_input_embeddings
(
input_ids
[
is_text
])
return
torch
.
empty
(
(
input_ids
.
shape
[
0
],
text_embeds
.
shape
[
1
]),
dtype
=
text_embeds
.
dtype
,
device
=
text_embeds
.
device
,
).
masked_scatter_
(
is_text
.
unsqueeze_
(
-
1
),
text_embeds
)
return
get_input_embeddings
(
input_ids
)
def get_input_embeddings(
    self,
    input_ids: Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    *,
    is_multimodal: Optional[Tensor] = None,
    handle_oov_mm_token: bool = False,
) -> Tensor:
    """
    Apply token embeddings to `input_ids`.

    If `multimodal_embeddings` is passed, scatter them into
    `input_ids` according to the mask `is_multimodal`.

    In case the multi-modal token IDs exceed the vocabulary size of
    the language model, you can set `handle_oov_mm_token=True`
    to avoid calling the language model's `get_input_embeddings` method
    on those tokens. Note however that doing so increases memory usage
    as an additional buffer is needed to hold the input embeddings.

    Raises:
        ValueError: If non-empty `multimodal_embeddings` are given
            without an `is_multimodal` mask.
    """
    # Imported lazily inside the method rather than at module level.
    from .utils import _merge_multimodal_embeddings

    inputs_embeds = self._get_text_embeddings(
        input_ids,
        self.get_language_model().get_input_embeddings,
        is_multimodal=is_multimodal,
        handle_oov_mm_token=handle_oov_mm_token,
    )

    # Text-only input: nothing to merge.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return inputs_embeds

    if is_multimodal is None:
        raise ValueError(
            "`get_input_embeddings` now requires `is_multimodal` arg, "
            "please update your model runner according to "
            "https://github.com/vllm-project/vllm/pull/16229.")

    return _merge_multimodal_embeddings(
        inputs_embeds=inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )
@
runtime_checkable
class
SupportsMultiModalPruning
(
Protocol
):
"""The interface required for models that support returning both input
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
0 → 100644
View file @
f39afa4a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part)."""
from
collections.abc
import
Callable
,
Iterable
,
Mapping
,
Sequence
from
functools
import
partial
from
typing
import
Any
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
packaging.version
import
Version
from
transformers
import
PretrainedConfig
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe
import
(
Qwen3OmniMoeConfig
,
Qwen3OmniMoeThinkerConfig
,
)
from
transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe
import
(
Qwen3OmniMoeAudioEncoder
,
)
from
transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe
import
(
Qwen3OmniMoeProcessor
,
)
from
transformers.models.whisper
import
WhisperFeatureExtractor
# from vllm.attention.backends.registry import _Backend
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
_ACTIVATION_REGISTRY
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.qwen2_audio
import
(
Qwen2AudioFeatureInputs
,
Qwen2AudioProcessingInfo
,
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargsItems
from
vllm.multimodal.parse
import
AudioProcessorItems
,
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
MultiModalPromptUpdates
,
PlaceholderFeaturesInfo
,
PromptReplacement
,
PromptUpdate
,
)
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMRoPE
,
SupportsMultiModal
,
SupportsPP
,
)
from
.qwen2_5_omni_thinker
import
(
Qwen2_5OmniConditionalGenerationMixin
,
Qwen2_5OmniThinkerDummyInputsBuilder
,
Qwen2_5OmniThinkerMultiModalProcessor
,
Qwen2_5OmniThinkerProcessingInfo
,
)
from
.qwen2_5_vl
import
(
Qwen2_5_VisionAttention
,
Qwen2_5_VisionRotaryEmbedding
,
Qwen2_5_VLProcessingInfo
,
)
from
.qwen3_moe
import
Qwen3MoeForCausalLM
,
Qwen3MoeModel
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
_merge_multimodal_embeddings
,
maybe_prefix
,
)
from
.vision
import
get_llm_pos_ids_for_vision
,
get_vit_attn_backend
try
:
import
flash_attn
except
(
ImportError
,
ModuleNotFoundError
):
flash_attn
=
None
logger
=
init_logger
(
__name__
)
def
_get_feat_extract_output_lengths
(
input_lengths
:
torch
.
Tensor
):
input_lengths_leave
=
input_lengths
%
100
feat_lengths
=
(
input_lengths_leave
-
1
)
//
2
+
1
output_lengths
=
(
((
feat_lengths
-
1
)
//
2
+
1
-
1
)
//
2
+
1
+
(
input_lengths
//
100
)
*
13
)
return
feat_lengths
,
output_lengths
class Qwen3_VisionPatchEmbed(nn.Module):
    """Project flattened patches to `hidden_size` with a single Conv3d.

    The convolution uses the same
    `(temporal_patch_size, patch_size, patch_size)` tuple for kernel and
    stride.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        patch_shape = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=patch_shape,
            stride=patch_shape,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a 2-D `(num_patches, flattened_patch)` tensor.

        Each row is reshaped to
        `(channels, temporal_patch_size, patch_size, patch_size)`, run
        through the convolution, and viewed as
        `(num_patches, hidden_size)`.
        """
        # Unpacking also enforces that `x` has exactly two dimensions.
        num_patches, _ = x.shape
        patches = x.view(
            num_patches,
            -1,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        embedded = self.proj(patches)
        return embedded.view(num_patches, self.hidden_size)
class Qwen3_VisionMLP(nn.Module):
    """Two-layer MLP: `linear_fc1` -> activation -> `linear_fc2`.

    `linear_fc1` is a `ColumnParallelLinear` and `linear_fc2` a
    `RowParallelLinear`; both are built with `return_bias=False`.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Keyword arguments shared by both projections.
        shared_kwargs = dict(
            bias=bias,
            quant_config=quant_config,
            return_bias=False,
        )
        self.linear_fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            prefix=f"{prefix}.linear_fc1",
            **shared_kwargs,
        )
        self.linear_fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            prefix=f"{prefix}.linear_fc2",
            **shared_kwargs,
        )
        self.act_fn = act_fn

    def forward(self, x: torch.Tensor):
        """Apply `linear_fc1`, the activation, then `linear_fc2`."""
        hidden = self.linear_fc1(x)
        activated = self.act_fn(hidden)
        return self.linear_fc2(activated)
class Qwen3_VisionBlock(nn.Module):
    """Pre-norm block: attention and MLP, each with a residual add.

    `norm_layer` defaults to `nn.LayerNorm` with `eps=1e-6`.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Callable[[int], nn.Module] | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        make_norm = (norm_layer if norm_layer is not None else partial(
            nn.LayerNorm, eps=1e-6))

        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)
        self.attn = Qwen2_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = Qwen3_VisionMLP(
            dim,
            mlp_hidden_dim,
            act_fn=act_fn,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,  # Only used for Flash Attention
        seqlens: list[int] | None = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Return `x` after the attention and MLP residual branches."""
        attn_out = self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        residual = x + attn_out
        return residual + self.mlp(self.norm2(residual))
class Qwen3_VisionPatchMerger(nn.Module):
    """Normalize, group and project vision features to `d_model`.

    `hidden_size` is `context_dim * spatial_merge_size ** 2`. With
    `use_postshuffle_norm` the norm is applied after reshaping to
    `hidden_size`; otherwise it is applied at `context_dim` before the
    reshape. `norm_layer` defaults to `nn.LayerNorm` with `eps=1e-6`.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        use_postshuffle_norm: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.use_postshuffle_norm = use_postshuffle_norm

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        # The norm width follows the point where the norm is applied.
        norm_dim = self.hidden_size if use_postshuffle_norm else context_dim
        self.ln_q = norm_layer(norm_dim)

        fc1 = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
        )
        fc2 = RowParallelLinear(
            self.hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
        )
        self.mlp = nn.ModuleList([fc1, nn.GELU(), fc2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the norm/reshape, then fc1 -> GELU -> fc2."""
        if self.use_postshuffle_norm:
            merged = self.ln_q(x.view(-1, self.hidden_size))
        else:
            merged = self.ln_q(x).view(-1, self.hidden_size)

        fc1, activation, fc2 = self.mlp
        # Both parallel linears return a pair; only the first element is
        # used here.
        hidden, _ = fc1(merged)
        projected, _ = fc2(activation(hidden))
        return projected
class
Qwen3Omni_VisionTransformer
(
nn
.
Module
):
def
__init__
(
self
,
vision_config
,
norm_eps
:
float
=
1e-6
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
vision_config
.
hidden_size
self
.
num_heads
=
vision_config
.
num_heads
self
.
image_size
=
vision_config
.
image_size
self
.
patch_size
=
vision_config
.
patch_size
self
.
spatial_merge_size
=
vision_config
.
spatial_merge_size
self
.
spatial_merge_unit
=
self
.
spatial_merge_size
**
2
self
.
temporal_patch_size
=
vision_config
.
temporal_patch_size
self
.
num_grid_per_side
=
self
.
image_size
//
self
.
patch_size
self
.
apply_vit_abs_pos_embed
=
vision_config
.
apply_vit_abs_pos_embed
self
.
deepstack_visual_indexes
=
vision_config
.
deepstack_visual_indexes
self
.
patch_embed
=
Qwen3_VisionPatchEmbed
(
patch_size
=
self
.
patch_size
,
temporal_patch_size
=
self
.
temporal_patch_size
,
in_channels
=
vision_config
.
in_channels
,
hidden_size
=
self
.
hidden_size
,
)
# vit pos embedding, TODO: spatial_patch_size vs patch_size
if
self
.
apply_vit_abs_pos_embed
:
self
.
pos_embed
=
nn
.
Embedding
(
self
.
num_grid_per_side
**
2
,
self
.
hidden_size
)
else
:
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
empty
([
1
,
self
.
num_grid_per_side
**
2
,
self
.
hidden_size
])
)
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
norm_eps
)
head_dim
=
self
.
hidden_size
//
self
.
num_heads
self
.
rotary_pos_emb
=
Qwen2_5_VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
blocks
=
nn
.
ModuleList
(
[
Qwen3_VisionBlock
(
dim
=
self
.
hidden_size
,
num_heads
=
self
.
num_heads
,
mlp_hidden_dim
=
vision_config
.
intermediate_size
,
act_fn
=
_ACTIVATION_REGISTRY
[
vision_config
.
hidden_act
],
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
)
for
layer_idx
in
range
(
vision_config
.
depth
)
]
)
self
.
merger
=
Qwen3_VisionPatchMerger
(
d_model
=
vision_config
.
out_hidden_size
,
context_dim
=
self
.
hidden_size
,
norm_layer
=
norm_layer
,
spatial_merge_size
=
self
.
spatial_merge_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.merger"
,
)
if
self
.
deepstack_visual_indexes
is
not
None
:
self
.
merger_list
=
nn
.
ModuleList
(
[
Qwen3_VisionPatchMerger
(
d_model
=
vision_config
.
out_hidden_size
,
context_dim
=
self
.
hidden_size
,
spatial_merge_size
=
self
.
spatial_merge_size
,
use_postshuffle_norm
=
True
,
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.merger_list.
{
layer_idx
}
"
,
)
for
layer_idx
in
range
(
len
(
self
.
deepstack_visual_indexes
))
]
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
()
)
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
check_upstream_fa_availability
(
torch
.
get_default_dtype
()
):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
return
self
.
patch_embed
.
proj
.
weight
.
dtype
@
property
def
device
(
self
)
->
torch
.
device
:
return
self
.
patch_embed
.
proj
.
weight
.
device
def
rot_pos_emb
(
self
,
grid_thw
):
pos_ids
=
[]
for
t
,
h
,
w
in
grid_thw
:
hpos_ids
=
torch
.
arange
(
h
).
unsqueeze
(
1
).
expand
(
-
1
,
w
)
hpos_ids
=
hpos_ids
.
reshape
(
h
//
self
.
spatial_merge_size
,
self
.
spatial_merge_size
,
w
//
self
.
spatial_merge_size
,
self
.
spatial_merge_size
,
)
hpos_ids
=
hpos_ids
.
permute
(
0
,
2
,
1
,
3
)
hpos_ids
=
hpos_ids
.
flatten
()
wpos_ids
=
torch
.
arange
(
w
).
unsqueeze
(
0
).
expand
(
h
,
-
1
)
wpos_ids
=
wpos_ids
.
reshape
(
h
//
self
.
spatial_merge_size
,
self
.
spatial_merge_size
,
w
//
self
.
spatial_merge_size
,
self
.
spatial_merge_size
,
)
wpos_ids
=
wpos_ids
.
permute
(
0
,
2
,
1
,
3
)
wpos_ids
=
wpos_ids
.
flatten
()
pos_ids
.
append
(
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
))
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
max_grid_size
=
grid_thw
[:,
1
:].
max
()
rotary_pos_emb_full
=
self
.
rotary_pos_emb
(
max_grid_size
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
return
rotary_pos_emb
def
fast_pos_embed_interpolate
(
self
,
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
num_grid_per_side
=
self
.
num_grid_per_side
m_size
=
self
.
spatial_merge_size
hidden_dim
=
self
.
pos_embed
.
embedding_dim
outputs
=
[]
for
t
,
h
,
w
in
grid_thw
:
h_idxs
=
torch
.
linspace
(
0
,
num_grid_per_side
-
1
,
h
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
w_idxs
=
torch
.
linspace
(
0
,
num_grid_per_side
-
1
,
w
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
h_floor
=
h_idxs
.
to
(
torch
.
long
)
w_floor
=
w_idxs
.
to
(
torch
.
long
)
h_ceil
=
torch
.
clamp
(
h_floor
+
1
,
max
=
num_grid_per_side
-
1
)
w_ceil
=
torch
.
clamp
(
w_floor
+
1
,
max
=
num_grid_per_side
-
1
)
dh
=
h_idxs
-
h_floor
dw
=
w_idxs
-
w_floor
# Create meshgrid view for all h, w vars
dh_grid
,
dw_grid
=
torch
.
meshgrid
(
dh
,
dw
,
indexing
=
"ij"
)
h_floor_grid
,
w_floor_grid
=
torch
.
meshgrid
(
h_floor
,
w_floor
,
indexing
=
"ij"
)
h_ceil_grid
,
w_ceil_grid
=
torch
.
meshgrid
(
h_ceil
,
w_ceil
,
indexing
=
"ij"
)
h_floor_grid_idx
=
h_floor_grid
*
num_grid_per_side
h_ceil_grid_idx
=
h_ceil_grid
*
num_grid_per_side
# original computation of weights
# w00 = (1 - dh_grid) * (1 - dw_grid)
# w01 = (1 - dh_grid) * dw_grid
# w10 = dh_grid * (1 - dw_grid)
# w11 = dh_grid * dw_grid
# we reuse w11 here to avoid duplicate
# dh_grid * dw_grid computation
w11
=
dh_grid
*
dw_grid
w10
=
dh_grid
-
w11
w01
=
dw_grid
-
w11
w00
=
1
-
dh_grid
-
dw_grid
+
w11
idx00
=
h_floor_grid_idx
+
w_floor_grid
idx01
=
h_floor_grid_idx
+
w_ceil_grid
idx10
=
h_ceil_grid_idx
+
w_floor_grid
idx11
=
h_ceil_grid_idx
+
w_ceil_grid
indices
=
torch
.
stack
([
idx00
,
idx01
,
idx10
,
idx11
],
dim
=
0
).
reshape
(
4
,
-
1
)
weights
=
torch
.
stack
([
w00
,
w01
,
w10
,
w11
],
dim
=
0
).
reshape
(
4
,
-
1
,
1
)
weights
=
weights
.
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
embeds
=
self
.
pos_embed
(
indices
)
weighted_embeds
=
embeds
*
weights
p0
,
p1
,
p2
,
p3
=
weighted_embeds
.
unbind
(
dim
=
0
)
combined
=
p0
+
p1
+
p2
+
p3
combined
=
combined
.
view
(
h
*
w
,
hidden_dim
)
repeated
=
combined
.
unsqueeze
(
0
).
expand
(
t
,
-
1
,
-
1
).
contiguous
()
repeated
=
repeated
.
view
(
t
,
h
//
m_size
,
m_size
,
w
//
m_size
,
m_size
,
hidden_dim
)
repeated
=
repeated
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
).
reshape
(
-
1
,
hidden_dim
)
outputs
.
append
(
repeated
)
return
torch
.
cat
(
outputs
,
dim
=
0
)
def compute_attn_mask_seqlen(
    self,
    cu_seqlens: torch.Tensor,
) -> tuple[int | None, list[int] | None]:
    """Derive backend-specific sequence-length metadata from ``cu_seqlens``.

    Args:
        cu_seqlens: Cumulative sequence boundaries; consecutive differences
            give the per-sequence lengths.

    Returns:
        ``(max_seqlen, seqlens)``. Only the element needed by the active
        attention backend is populated: ``max_seqlen`` for FLASH_ATTN,
        ``seqlens`` for XFORMERS. For any other backend both are ``None``.
    """
    backend = self.attn_backend
    if backend == _Backend.FLASH_ATTN:
        return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item(), None
    if backend == _Backend.XFORMERS:
        return None, (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return None, None
def forward(
    self,
    x: torch.Tensor,
    grid_thw: list[list[int]],
) -> torch.Tensor:
    """Encode visual patches and append the deepstack features.

    Args:
        x: Patch input; moved to this module's device/dtype before the
            patch embedding is applied.
        grid_thw: One ``(t, h, w)`` entry per image/video.
            NOTE(review): annotated as a nested list, but the body indexes it
            as ``grid_thw[:, k]`` and reads ``grid_thw.dtype``, so a 2-D
            tensor is what this code actually requires - confirm with callers.

    Returns:
        The merged visual features. When ``deepstack_visual_indexes`` is set,
        the outputs of ``merger_list`` for the tapped layers are concatenated
        after the main features along dim 1.
    """
    hidden_states = self.patch_embed(x.to(device=self.device, dtype=self.dtype))
    if self.apply_vit_abs_pos_embed:
        hidden_states = hidden_states + self.fast_pos_embed_interpolate(grid_thw)
    rotary_pos_emb = self.rot_pos_emb(grid_thw)

    # One attention sequence per frame: h * w tokens, repeated t times.
    tokens_per_frame = grid_thw[:, 1] * grid_thw[:, 2]
    cumsum_dtype = grid_thw.dtype if torch.jit.is_tracing() else torch.int32
    cu_seqlens = torch.repeat_interleave(tokens_per_frame, grid_thw[:, 0]).cumsum(
        dim=0,
        dtype=cumsum_dtype,
    )
    # Prepend the leading 0 boundary.
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

    hidden_states = hidden_states.unsqueeze(1)
    rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
    max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)

    deepstack_visual_indexes = self.deepstack_visual_indexes
    tapped_states = []
    for layer_num, blk in enumerate(self.blocks):
        hidden_states = blk(
            hidden_states,
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        if (
            deepstack_visual_indexes is not None
            and layer_num in deepstack_visual_indexes
        ):
            tapped_states.append(hidden_states)

    hidden_states = self.merger(hidden_states)
    if deepstack_visual_indexes is None:
        return hidden_states

    # Concatenate the main features with each deepstack level along the
    # feature dim: [seq_len, hidden_size * (1 + depth_of_deepstack)].
    merged = [hidden_states]
    for level, tapped in enumerate(tapped_states):
        merged.append(self.merger_list[level](tapped))
    return torch.cat(merged, dim=1)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this module's parameters.

    Separate ``attn.q`` / ``attn.k`` / ``attn.v`` checkpoint entries are
    routed into the fused ``attn.qkv`` parameter through that parameter's
    ``weight_loader`` with the matching shard id. Every other entry is loaded
    by name, using the parameter's ``weight_loader`` if it has one and
    ``default_weight_loader`` otherwise.

    Returns:
        The set of (post-rename) parameter names that were loaded.
    """
    # (fused parameter fragment, checkpoint fragment, shard id)
    qkv_shards = [
        ("attn.qkv.", "attn.q.", "q"),
        ("attn.qkv.", "attn.k.", "k"),
        ("attn.qkv.", "attn.v.", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        # First mapping entry whose checkpoint fragment occurs in the name.
        shard = next((entry for entry in qkv_shards if entry[1] in name), None)
        if shard is None:
            param = params_dict[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, loaded_weight)
        else:
            fused_fragment, ckpt_fragment, shard_id = shard
            name = name.replace(ckpt_fragment, fused_fragment)
            param = params_dict[name]
            loader = param.weight_loader
            loader(param, loaded_weight, shard_id)
        loaded_params.add(name)
    return loaded_params
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        "deepstack_input_embeds": 0,
    }
)
class Qwen3MoeLLMModel(Qwen3MoeModel):
    """Qwen3-MoE decoder stack that can add deepstack visual features.

    After decoder layer ``i`` runs, the tensor stored under
    ``deepstack_input_embeds_{i}`` (when provided) is added to the hidden
    states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.deepstack_multiscale_layer_start = 1

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        deepstack_input_embeds: IntermediateTensors | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's decoder layers.

        Returns:
            The normalized hidden states on the last pipeline rank, otherwise
            an ``IntermediateTensors`` carrying ``hidden_states`` and
            ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.get_input_embeddings(input_ids)
            )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        local_layers = self.layers[self.start_layer : self.end_layer]
        # Enumerate with the global layer index so it can be used as the
        # deepstack level key.
        for layer_idx, layer in enumerate(local_layers, start=self.start_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

            if (
                deepstack_input_embeds is not None
                and 0 <= layer_idx < len(deepstack_input_embeds)
            ):
                level_key = f"deepstack_input_embeds_{layer_idx}"
                hidden_states = hidden_states + deepstack_input_embeds[level_key]

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
    """Causal-LM wrapper whose backbone is ``Qwen3MoeLLMModel``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Start the initializer chain *after* Qwen3MoeForCausalLM in the MRO,
        # so Qwen3MoeForCausalLM.__init__ itself is not executed and the
        # sub-modules below are built here instead.
        super(Qwen3MoeForCausalLM, self).__init__()

        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quantization

        self.model = Qwen3MoeLLMModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quantization,
        )
        if self.config.tie_word_embeddings:
            # Share the embedding matrix with the output head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )
class Qwen3OmniMoeThinkerProcessingInfo(
    Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo
):
    """Processing info (config, HF processor, limits) for the thinker."""

    def get_hf_config(self):
        """Return the ``thinker_config`` of the Qwen3-Omni HF config."""
        omni_config = self.ctx.get_hf_config(Qwen3OmniMoeConfig)
        return omni_config.thinker_config

    def get_hf_processor(self, **kwargs: object) -> Qwen3OmniMoeProcessor:
        """Build the HF processor, defaulting ``use_fast`` to ``True``.

        Placeholder-token attributes that the processor does not define are
        filled in with the Qwen pad tokens.
        """
        use_fast = kwargs.pop("use_fast", True)
        processor = self.ctx.get_hf_processor(
            Qwen3OmniMoeProcessor,
            use_fast=use_fast,
            **kwargs,
        )
        fallback_tokens = (
            ("audio_token", "<|audio_pad|>"),
            ("image_token", "<|image_pad|>"),
            ("video_token", "<|video_pad|>"),
        )
        for attr_name, token in fallback_tokens:
            if not hasattr(processor, attr_name):
                setattr(processor, attr_name, token)
        return processor

    def get_feature_extractor(self, **kwargs: object):
        """Return the processor's audio feature extractor.

        Asserts that it is a ``WhisperFeatureExtractor``.
        """
        feature_extractor = self.get_hf_processor(
            **kwargs
        ).feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limits; ``None`` for all three."""
        return dict.fromkeys(("audio", "image", "video"))
# The Qwen3-Omni thinker uses the Qwen2.5-Omni dummy-inputs builder as is;
# this alias only gives it a model-specific name.
Qwen3OmniMoeThinkerDummyInputsBuilder = Qwen2_5OmniThinkerDummyInputsBuilder
class
Qwen3OmniMoeThinkerMultiModalProcessor
(
Qwen2_5OmniThinkerMultiModalProcessor
,
):
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Run the HF processor with audio pre-padded to the hop length.

    Audio arrives under the ``"audios"`` key and is forwarded to the parent
    implementation under ``"audio"``. After the parent call,
    ``feature_attention_mask`` and ``audio_feature_lengths`` are rebuilt from
    the (padded) waveform lengths.
    """
    mm_data = dict(mm_data)
    audios = mm_data.pop("audios", [])

    # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
    feature_extractor = self.info.get_feature_extractor()
    hop_length = feature_extractor.hop_length

    def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray:
        # Zero-pad the last axis up to the next multiple of hop_length.
        remainder = x.shape[-1] % hop_length
        if remainder == 0:
            return x
        return np.pad(
            x,
            (0, hop_length - remainder),
            mode="constant",
            constant_values=0,
        )

    if audios:
        # NOTE: Qwen3-Omni processor accept "audio"
        # To make sure the cache works with padding=True, we pre-padded
        # the audio to multiple of hop_length.
        padded_audios = []
        for audio in audios:
            if isinstance(audio, np.ndarray):
                padded_audios.append(pad_to_hop_length(audio, hop_length))
            else:
                # Non-array items are treated as (waveform, extra) pairs.
                padded_audios.append(
                    (pad_to_hop_length(audio[0], hop_length), audio[1])
                )
        mm_data["audio"] = padded_audios

        mm_kwargs = dict(**mm_kwargs)
        # TODO(Isotr0py): Remove this patch after upstream fix PR
        # released and Transformers version update:
        # https://github.com/huggingface/transformers/pull/41473
        if (
            Version(TRANSFORMERS_VERSION) < Version("4.58.0")
            and "truncation" not in mm_kwargs
        ):
            mm_kwargs["truncation"] = False

    hf_inputs = super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
        tok_kwargs=tok_kwargs,
    )

    has_audio_outputs = (
        "audio_feature_lengths" in hf_inputs
        and "feature_attention_mask" in hf_inputs
    )
    if has_audio_outputs and (audios := mm_data.get("audio", [])):

        def count_frames(audio) -> int:
            # Number of hop-length frames covered by one waveform.
            num_samples = len(audio[0]) if isinstance(audio, tuple) else len(audio)
            full_hops, leftover = divmod(num_samples, hop_length)
            frames = full_hops if leftover == 0 else full_hops - 1
            if mm_kwargs.get("truncation", False):
                frames = min(frames, feature_extractor.n_samples // hop_length)
            return frames

        audio_num_frames = [count_frames(audio) for audio in audios]
        hf_inputs["feature_attention_mask"] = [
            torch.ones(num_frame) for num_frame in audio_num_frames
        ]
        hf_inputs["audio_feature_lengths"] = torch.tensor(audio_num_frames)
    return hf_inputs
def
_maybe_apply_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
prompt_ids
:
list
[
int
],
mm_kwargs
:
MultiModalKwargsItems
,
mm_prompt_updates
:
MultiModalPromptUpdates
,
is_update_applied
:
bool
,
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]]:
"""
Qwen3-Omni reimplements this function to handle `use_audio_in_video`.
"""
mm_item_counts
=
mm_items
.
get_all_counts
()
self
.
_validate_mm_kwargs
(
mm_kwargs
,
mm_item_counts
)
use_audio_in_video
=
False
if
"video"
in
mm_kwargs
:
for
item
in
mm_kwargs
[
"video"
]:
if
item
and
item
[
"use_audio_in_video"
].
data
:
use_audio_in_video
=
True
else
:
use_audio_in_video
=
False
if
use_audio_in_video
and
"video"
in
mm_item_counts
:
assert
"audio"
in
mm_item_counts
mm_item_counts
[
"audio"
]
-=
mm_item_counts
[
"video"
]
# Special case with `use_audio_in_video=True`
if
use_audio_in_video
:
if
is_update_applied
:
prompt_ids
=
self
.
_get_raw_input_ids
(
prompt_ids
,
use_audio_in_video
)
(
prompt_ids
,
prompt
,
mm_placeholders
,
)
=
self
.
_apply_prompt_updates
(
prompt_ids
,
mm_prompt_updates
,
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
)
# normal case with `use_audio_in_video=False`
elif
is_update_applied
:
mm_placeholders
=
self
.
_find_mm_placeholders
(
prompt_ids
,
mm_prompt_updates
,
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
,
)
else
:
prompt_ids
,
prompt
,
mm_placeholders
=
self
.
_apply_prompt_updates
(
prompt_ids
,
mm_prompt_updates
,
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
,
)
# print("3333333333333333", prompt_ids, prompt, mm_placeholders)
return
prompt_ids
,
prompt
,
mm_placeholders
def get_updates_use_audio_in_video(
    self,
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: list[int] | torch.Tensor,
    video_second_per_grid_t: float,
) -> list[int]:
    """Interleave video and audio placeholder tokens by temporal position.

    Audio token ``k`` sits at position ``k``. Every video token of frame
    ``f`` sits at ``f * video_second_per_grid_t * position_id_per_seconds``.
    The two streams are merged in position order (a video token wins a tie)
    and wrapped in the audio start/end tokens.

    Args:
        thinker_config: Supplies the token ids, ``position_id_per_seconds``
            and ``vision_config.spatial_merge_size``.
        audio_len: Number of audio placeholder tokens.
        video_grid_thw: ``(t, h, w)`` grid of the video.
        video_second_per_grid_t: Seconds covered by one temporal grid step.

    Returns:
        The placeholder token-id sequence.
    """
    audio_token_id = thinker_config.audio_token_id
    video_token_id = thinker_config.video_token_id
    audio_start_token_id = thinker_config.audio_start_token_id
    audio_end_token_id = thinker_config.audio_end_token_id
    merge_size = thinker_config.vision_config.spatial_merge_size
    position_id_per_seconds = thinker_config.position_id_per_seconds

    audio_positions = np.arange(audio_len)

    # Each frame index is repeated once per merged spatial patch.
    rows = video_grid_thw[1] // merge_size
    cols = video_grid_thw[2] // merge_size
    frame_ids = np.arange(video_grid_thw[0]).reshape(-1, 1, 1)
    video_positions = np.broadcast_to(
        frame_ids, (frame_ids.shape[0], rows, cols)
    ).reshape(-1)
    video_positions = (
        video_positions * video_second_per_grid_t * position_id_per_seconds
    )

    num_video = len(video_positions)
    num_audio = len(audio_positions)
    video_cursor = 0
    audio_cursor = 0
    updates = [audio_start_token_id]
    while video_cursor < num_video and audio_cursor < num_audio:
        if video_positions[video_cursor] <= audio_positions[audio_cursor]:
            updates.append(video_token_id)
            video_cursor += 1
        else:
            updates.append(audio_token_id)
            audio_cursor += 1

    # Flush whichever stream still has tokens left.
    updates.extend([video_token_id] * (num_video - video_cursor))
    updates.extend([audio_token_id] * (num_audio - audio_cursor))
    updates.append(audio_end_token_id)
    return updates
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    """Build the placeholder replacements for audio, image and video.

    Each single ``<|*_pad|>`` token is replaced by as many copies as the
    encoder produces features for the item. With ``use_audio_in_video`` the
    video placeholder is replaced by an interleaved audio/video sequence.
    """
    processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
    tokenizer = self.info.get_tokenizer()
    image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
    vocab = tokenizer.get_vocab()

    audio_token = processor.audio_token
    image_token = processor.image_token
    video_token = processor.video_token
    audio_token_id = vocab[audio_token]
    image_token_id = vocab[image_token]
    video_token_id = vocab[video_token]

    out_mm_data = out_mm_kwargs.get_data()
    audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
    feature_attention_mask = out_mm_data.get("feature_attention_mask")

    # Explicit lengths take precedence over the attention mask.
    if audio_feature_lengths is not None:
        _, audio_output_lens = _get_feat_extract_output_lengths(
            audio_feature_lengths
        )
        audio_output_lengths = audio_output_lens.tolist()
    elif feature_attention_mask is not None:
        assert isinstance(feature_attention_mask, torch.Tensor)
        _, audio_output_lens = _get_feat_extract_output_lengths(
            feature_attention_mask.sum(-1)
        )
        audio_output_lengths = audio_output_lens.tolist()
    else:
        audio_output_lengths = []

    # Counters shared by the closures below; they are advanced as the
    # replacement callbacks are invoked.
    # number of audios read from video.
    audio_in_video_item_idx = 0
    audio_item_idx = 0

    def get_replacement_qwen2_audio(item_idx: int):
        nonlocal audio_item_idx
        item_idx += audio_in_video_item_idx
        audio_item_idx += 1

        num_features = audio_output_lengths[item_idx]
        if num_features == 0:
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio = audios.get(item_idx)
            raise ValueError(
                f"The audio {audio} (len={len(audio)}) is too short "
                "to be represented inside the model"
            )
        return [audio_token_id] * num_features

    def get_replacement_qwen2_vision(item_idx: int, modality: str):
        grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx]
        assert isinstance(grid_thw, torch.Tensor)
        merge_length = image_processor.merge_size**2

        token_id = image_token_id if modality == "image" else video_token_id
        return [token_id] * (int(grid_thw.prod()) // merge_length)

    use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False)
    thinker_config = self.info.get_hf_config()

    def get_replacement_qwen2_use_audio_in_video(item_idx: int):
        nonlocal audio_in_video_item_idx
        audio_num_features = audio_output_lengths[audio_item_idx + item_idx]
        video_grid_thw = out_mm_data["video_grid_thw"][item_idx]
        audio_in_video_item_idx += 1

        second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None)
        video_second_per_grid_t = (
            second_per_grid_ts[item_idx] if second_per_grid_ts else 1.0
        )
        return self.get_updates_use_audio_in_video(
            thinker_config=thinker_config,
            audio_len=audio_num_features,
            video_grid_thw=video_grid_thw,
            video_second_per_grid_t=video_second_per_grid_t,
        )

    if use_audio_in_video:
        video_replacement_fn = get_replacement_qwen2_use_audio_in_video
    else:
        video_replacement_fn = partial(get_replacement_qwen2_vision, modality="video")

    return [
        PromptReplacement(
            modality="audio",
            target=audio_token,
            replacement=get_replacement_qwen2_audio,
        ),
        PromptReplacement(
            modality="image",
            target=image_token,
            replacement=partial(get_replacement_qwen2_vision, modality="image"),
        ),
        PromptReplacement(
            modality="video",
            target=video_token,
            replacement=video_replacement_fn,
        ),
    ]
def _validate_mm_placeholders(
    self,
    mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
    mm_item_counts: Mapping[str, int],
) -> None:
    """Validate the placeholders with the base processor's implementation.

    The call names ``BaseMultiModalProcessor`` explicitly instead of going
    through ``super()``.
    NOTE(review): presumably this bypasses an override in the Qwen2.5-Omni
    parent processor - confirm against that class.
    """
    BaseMultiModalProcessor[
        Qwen2_5OmniThinkerProcessingInfo
    ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts)
def
_get_raw_input_ids
(
self
,
token_ids
:
list
[
int
],
use_audio_in_video
:
bool
=
False
,
)
->
list
[
int
]:
tokenizer
=
self
.
info
.
get_tokenizer
()
vision_bos_token
=
tokenizer
.
encode
(
tokenizer
.
vision_bos_token
)[
0
]
vision_eos_token
=
tokenizer
.
encode
(
tokenizer
.
vision_eos_token
)[
0
]
audio_bos_token
=
tokenizer
.
encode
(
tokenizer
.
audio_bos_token
)[
0
]
audio_eos_token
=
tokenizer
.
encode
(
tokenizer
.
audio_eos_token
)[
0
]
audio_token
=
tokenizer
.
encode
(
"<|audio_pad|>"
)[
0
]
image_token
=
tokenizer
.
encode
(
"<|image_pad|>"
)[
0
]
video_token
=
tokenizer
.
encode
(
"<|video_pad|>"
)[
0
]
result
=
token_ids
[:]
if
use_audio_in_video
:
while
True
:
start
=
None
for
i
in
range
(
len
(
result
)
-
1
):
if
result
[
i
:
i
+
2
]
==
[
vision_bos_token
,
audio_bos_token
]:
start
=
i
break
if
start
is
not
None
:
end
=
None
for
i
in
range
(
start
+
2
,
len
(
result
)
-
1
):
if
result
[
i
:
i
+
2
]
==
[
audio_eos_token
,
vision_eos_token
]:
end
=
i
break
if
end
is
not
None
:
result
=
(
result
[:
start
]
+
[
vision_bos_token
,
video_token
,
vision_eos_token
]
+
result
[
end
+
2
:]
)
else
:
break
for
mm_token
in
[
audio_token
,
image_token
,
video_token
]:
compressed
=
[]
for
x
in
result
:
if
x
!=
mm_token
or
(
not
compressed
or
compressed
[
-
1
]
!=
mm_token
):
compressed
.
append
(
x
)
result
=
compressed
return
result
class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMixin):
    """Qwen3-Omni overrides for multimodal tensor reshaping and audio encoding."""

    def _validate_and_reshape_mm_tensor(
        self, mm_input: object, name: str, dim: int = 0
    ) -> torch.Tensor:
        """Concatenate a batched multimodal input into a single tensor.

        Args:
            mm_input: A tensor (its first-dim slices are concatenated), a
                list of tensors, or a list of lists of tensors.
            name: Field name; ``"feature_attention_mask"`` forces
                concatenation along the last dim.
            dim: Concatenation dim for every other field.

        Raises:
            ValueError: If ``mm_input`` is neither a tensor nor a list.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")

        concat_dim = -1 if name == "feature_attention_mask" else dim
        if isinstance(mm_input, torch.Tensor):
            return torch.concat(list(mm_input), dim=concat_dim)
        if isinstance(mm_input[0], list):
            flattened = [torch.concat(group, dim=concat_dim) for group in mm_input]
            return torch.concat(flattened, dim=concat_dim)
        return torch.concat(mm_input, dim=concat_dim)

    def _process_audio_input(
        self,
        audio_input: Qwen2AudioFeatureInputs,
        audio_hashes: list[str] = None,
        cached_audio_features: torch.Tensor = None,
    ) -> torch.Tensor:
        """Encode audio features and split the result per audio item.

        ``audio_hashes`` and ``cached_audio_features`` are accepted but not
        used by this implementation.
        """
        features = audio_input["input_features"]
        feature_lengths = audio_input["audio_feature_lengths"]

        if features.ndim == 3:
            # Drop the leading dim; it must be of size 1.
            assert features.shape[0] == 1
            features = features.squeeze(0)

        if not isinstance(feature_lengths, torch.Tensor):
            feature_lengths = torch.cat(feature_lengths)
        if feature_lengths.ndim == 2:
            feature_lengths = feature_lengths.reshape(-1)

        aftercnn_lengths, output_lengths = _get_feat_extract_output_lengths(
            feature_lengths
        )
        encoded = self.audio_tower(
            features.to(self.audio_tower.dtype),
            feature_lens=feature_lengths,
            aftercnn_lens=aftercnn_lengths,
        )
        return encoded.last_hidden_state.split(output_lengths.tolist())
@
MULTIMODAL_REGISTRY
.
register_processor
(
Qwen3OmniMoeThinkerMultiModalProcessor
,
info
=
Qwen3OmniMoeThinkerProcessingInfo
,
dummy_inputs
=
Qwen3OmniMoeThinkerDummyInputsBuilder
,
)
class
Qwen3OmniMoeThinkerForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsMRoPE
,
Qwen3OmniMoeConditionalGenerationMixin
,
):
# Prefix rewrites applied to checkpoint weight names when loading: the
# thinker's LM head and decoder go under ``language_model.``, and any other
# ``thinker.`` prefix is stripped.
hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "thinker.lm_head.": "language_model.lm_head.",
        "thinker.model.": "language_model.model.",
        "thinker.": "",
    }
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
    """Return the prompt placeholder string for one multimodal item.

    Args:
        modality: Modality name; matched by its ``image`` / ``video`` /
            ``audio`` prefix.
        i: Item index (not used by this implementation).

    Raises:
        ValueError: If the modality matches none of the three prefixes.
    """
    placeholders = (
        ("image", "<|vision_start|><|image_pad|><|vision_end|>"),
        ("video", "<|vision_start|><|video_pad|><|vision_end|>"),
        ("audio", "<|audio_start|><|audio_pad|><|audio_end|>"),
    )
    for modality_prefix, placeholder in placeholders:
        if modality.startswith(modality_prefix):
            return placeholder
    raise ValueError("Only image, video or audio modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the audio tower, vision tower, language model and deepstack buffers."""
    super().__init__()
    thinker_config: Qwen3OmniMoeThinkerConfig = (
        vllm_config.model_config.hf_config.thinker_config
    )
    quant_config = vllm_config.quant_config
    vision_config = thinker_config.vision_config
    text_config = thinker_config.text_config

    self.config = thinker_config
    self.multimodal_config = vllm_config.model_config.multimodal_config

    # force "use_flash_attention_2=True" to audio tower to align
    # the results.
    if flash_attn is None:
        logger.warning(
            "flash_attn is not available, the model may not yield the "
            "exactly same result as the transformers implementation "
            "in the audio tower part."
        )
    else:
        audio_config = thinker_config.audio_config
        audio_config._attn_implementation_autoset = True
        audio_config._attn_implementation = "flash_attention_2"

    self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)

    self.visual = Qwen3Omni_VisionTransformer(
        vision_config=vision_config,
        norm_eps=getattr(text_config, "rms_norm_eps", 1e-6),
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "visual"),
    )
    self.quant_config = quant_config

    self.language_model = Qwen3MoeLLMForCausalLM(
        vllm_config=vllm_config.with_hf_config(
            text_config, architectures=["Qwen3MoeForCausalLM"]
        ),
        prefix=maybe_prefix(prefix, "language_model"),
    )
    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors
    )

    self.use_deepstack = hasattr(vision_config, "deepstack_visual_indexes")
    if self.use_deepstack:
        self.deepstack_num_level = len(vision_config.deepstack_visual_indexes)
        # One reusable zero buffer per deepstack level, sized for the largest
        # batch the scheduler can produce. These are plain tensors kept in a
        # list, not registered module buffers.
        self.deepstack_input_embeds = [
            torch.zeros(
                vllm_config.scheduler_config.max_num_batched_tokens,
                text_config.hidden_size,
            )
            for _ in range(self.deepstack_num_level)
        ]
    else:
        self.deepstack_num_level = 0
        self.deepstack_input_embeds = None

    self.visual_dim = vision_config.out_hidden_size
    self.multiscale_dim = self.visual_dim * self.deepstack_num_level
def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
    """Return the first ``num_tokens`` rows of every deepstack buffer.

    The slices are keyed ``deepstack_input_embeds_{level}``. The buffers are
    only read here; clearing is done by ``_clear_deepstack_input_embeds``.
    """
    sliced = {}
    for level in range(self.deepstack_num_level):
        buffer = self.deepstack_input_embeds[level]
        sliced[f"deepstack_input_embeds_{level}"] = buffer[:num_tokens]
    return IntermediateTensors(sliced)
def
_set_deepstack_input_embeds
(
self
,
deepstack_input_embeds
:
torch
.
Tensor
)
->
None
:
# set deepstack_input_embeds to buffer
num_tokens
=
deepstack_input_embeds
.
size
(
1
)
if
num_tokens
>
self
.
deepstack_input_embeds
[
0
].
size
(
0
):
self
.
deepstack_input_embeds
=
[
torch
.
zeros
(
num_tokens
,
self
.
config
.
text_config
.
hidden_size
,
device
=
self
.
deepstack_input_embeds
[
0
].
device
,
dtype
=
self
.
deepstack_input_embeds
[
0
].
dtype
,
)
for
_
in
range
(
self
.
deepstack_num_level
)
]
for
idx
in
range
(
self
.
deepstack_num_level
):
self
.
deepstack_input_embeds
[
idx
][:
num_tokens
].
copy_
(
deepstack_input_embeds
[
idx
]
)
def
_clear_deepstack_input_embeds
(
self
,
num_tokens
:
int
)
->
None
:
# clear deepstack_input_embeds in buffer
if
num_tokens
>
0
:
for
idx
in
range
(
self
.
deepstack_num_level
):
self
.
deepstack_input_embeds
[
idx
][:
num_tokens
].
zero_
()
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
mm_input_by_modality
=
{}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for
input_key
in
kwargs
:
if
(
input_key
in
(
"pixel_values"
,
"image_embeds"
)
and
"image"
not
in
mm_input_by_modality
):
mm_input_by_modality
[
"image"
]
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
(
input_key
in
(
"pixel_values_videos"
,
"video_embeds"
)
and
"video"
not
in
mm_input_by_modality
):
mm_input_by_modality
[
"video"
]
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
if
(
input_key
in
(
"input_audio_features"
)
and
"audio"
not
in
mm_input_by_modality
):
mm_input_by_modality
[
"audio"
]
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
return
mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
    """Return the wrapped language model (``self.language_model``)."""
    return self.language_model
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings | None:
    """Encode all multimodal inputs found in ``kwargs``.

    Returns:
        An empty list when ``kwargs`` holds no multimodal input; otherwise a
        tuple with one entry per multimodal item, ordered by the modality
        order of ``kwargs``.
    """
    mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
    if not mm_input_by_modality:
        return []

    # The result multimodal_embeddings is tuple of tensors, with each
    # tensor corresponding to a multimodal data item (image, video or audio).
    multimodal_embeddings: tuple[torch.Tensor, ...] = ()

    # NOTE: It is important to iterate over the keys in this dictionary
    # to preserve the order of the modalities.
    for modality, multimodal_input in mm_input_by_modality.items():
        if modality == "image":
            multimodal_embeddings += tuple(
                self._process_image_input(multimodal_input)
            )
        elif modality == "video":
            multimodal_embeddings += tuple(
                self._process_video_input(multimodal_input)
            )
        elif modality == "audio":
            multimodal_embeddings += tuple(
                self._process_audio_input(multimodal_input)
            )
    return multimodal_embeddings
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
    handle_oov_mm_token: bool = False,
) -> torch.Tensor:
    """Embed the prompt and merge the multimodal embeddings into it.

    Embeddings whose last dim differs from the text hidden size are treated
    as vision embeddings carrying extra deepstack features: the main part is
    merged into the text embeddings, and the remainder is stored in the
    deepstack buffers.

    NOTE(review): vision entries of ``multimodal_embeddings`` are replaced in
    place by their main part, so in that path the argument must support item
    assignment, and ``is_multimodal`` must be a tensor - confirm callers.
    """
    inputs_embeds = self._get_text_embeddings(
        input_ids,
        self.language_model.get_input_embeddings,
        is_multimodal=is_multimodal,
        handle_oov_mm_token=handle_oov_mm_token,
    )
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return inputs_embeds

    # TODO (ywang96): support overlapping modality embeddings so that
    # `use_audio_in_video` will work on V1.
    text_hidden_size = self.config.text_config.hidden_size
    vision_flags = [
        item.shape[-1] != text_hidden_size for item in multimodal_embeddings
    ]

    if self.visual.deepstack_visual_indexes is not None and any(vision_flags):
        multiscale_len = len(self.visual.deepstack_visual_indexes)
        multiscale_parts = []
        is_vision = torch.zeros_like(is_multimodal)
        mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]

        offset = 0
        for index, item in enumerate(multimodal_embeddings):
            num_tokens = item.shape[0]
            item_positions = mm_positions[offset : offset + num_tokens]
            offset += num_tokens

            if item.shape[-1] == text_hidden_size:
                # Audio embeddings
                is_vision[item_positions] = False
                continue

            # Vision embeddings: split the feature dim into the main visual
            # feature and the multi-scale (deepstack) features.
            visual_dim = item.shape[-1] // (multiscale_len + 1)
            multi_dim = visual_dim * multiscale_len
            main_part, multiscale_part = torch.split(
                item, [visual_dim, multi_dim], dim=-1
            )
            multimodal_embeddings[index] = main_part
            multiscale_parts.append(multiscale_part)
            is_vision[item_positions] = True

        deepstack_input_embeds = inputs_embeds.new_zeros(
            inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
        )
        deepstack_input_embeds = _merge_multimodal_embeddings(
            inputs_embeds=deepstack_input_embeds,
            multimodal_embeddings=multiscale_parts,
            is_multimodal=is_vision,
        )
        # Reshape to be indexed by deepstack level first, then token.
        deepstack_input_embeds = (
            deepstack_input_embeds.view(
                inputs_embeds.shape[0], multiscale_len, visual_dim
            )
            .permute(1, 0, 2)
            .contiguous()
        )
        self._set_deepstack_input_embeds(deepstack_input_embeds)

    return _merge_multimodal_embeddings(
        inputs_embeds=inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor | IntermediateTensors:
    """Run the language model, feeding it the buffered deepstack embeddings.

    ``inputs_embeds`` is ignored whenever ``intermediate_tensors`` is given.
    On the first pipeline rank the deepstack buffers are read before the
    language model runs and zeroed again afterwards.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None

    embeds_on_first_rank = (
        inputs_embeds is not None and get_pp_group().is_first_rank
    )

    deepstack_input_embeds = None
    if self.use_deepstack and embeds_on_first_rank:
        deepstack_input_embeds = self._get_deepstack_input_embeds(
            inputs_embeds.size(0)
        )

    hidden_states = self.language_model.model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds=inputs_embeds,
        # args for deepstack
        deepstack_input_embeds=deepstack_input_embeds,
    )

    if embeds_on_first_rank:
        self._clear_deepstack_input_embeds(inputs_embeds.size(0))
    return hidden_states
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor | None:
    """Compute logits by delegating to the wrapped language model."""
    return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load the thinker weights from a Qwen3-Omni checkpoint.

    Weights under the ``talker.`` and ``code2wav.`` prefixes are skipped; the
    remaining names are rewritten through ``hf_to_vllm_mapper``.

    Returns:
        The set of loaded weight names reported by the loader.
    """
    skipped_prefixes = ["talker.", "code2wav."]
    loader = AutoWeightsLoader(self, skip_prefixes=skipped_prefixes)
    return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@classmethod
def get_mrope_input_positions(
    self,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: list[list[int]] | torch.Tensor | None,
    video_grid_thw: list[list[int]] | torch.Tensor | None,
    second_per_grid_ts: list[float] | None = None,
    context_len: int = 0,
    seq_len: int | None = None,
    audio_feature_lengths: torch.Tensor | None = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Compute 3-row position ids for a prompt mixing text, audio and vision.

    The prompt is consumed segment by segment. On every iteration the next
    vision-start and audio-start tokens at or after the cursor ``st`` are
    located, and whichever comes first is processed: the text before it, the
    start token, the modality payload, and the end token each append a
    ``(3, n)`` tensor to ``llm_pos_ids_list``. Text, start and end tokens use
    the same value in all three rows; image/video payload positions come from
    ``get_llm_pos_ids_for_vision``. Every new segment starts at one past the
    largest position emitted so far.

    Because of ``@classmethod`` the first parameter (named ``self`` here)
    receives the class, not an instance.

    Args:
        input_tokens: Prompt token ids.
        hf_config: Config object; only ``hf_config.thinker_config`` is read.
        image_grid_thw: Per-image grid rows; column 0 is read as the temporal
            size and columns 1 and 2 as the grid height/width inputs. Lists
            are converted to tensors.
        video_grid_thw: Same layout as ``image_grid_thw``, for videos.
        second_per_grid_ts: Per-video scale applied to the temporal index.
            When ``None``, 1.0 is used for every video.
        context_len: Not read by this implementation.
        seq_len: Not read; it is overwritten with the number of input tokens.
        audio_feature_lengths: Per-audio feature lengths, indexed by the order
            in which audio segments are met. Non-tensor values are converted
            to a ``torch.long`` tensor.
        use_audio_in_video: When true, a vision-start token immediately
            followed by an audio-start token is handled as one video+audio
            segment whose position columns are merged.

    Returns:
        ``(llm_positions, mrope_position_delta)`` where ``llm_positions`` has
        shape ``(3, len(input_tokens))`` and ``mrope_position_delta`` is
        ``llm_positions.max() + 1 - len(input_tokens)`` (a 0-dim tensor as
        computed here, despite the ``int`` annotation).

    Raises:
        ValueError: If the tensor built from ``input_tokens`` is not 1-D.
        RuntimeError: If the number of produced positions differs from the
            number of input tokens.
    """
    config = hf_config.thinker_config
    if isinstance(image_grid_thw, list):
        image_grid_thw = torch.tensor(image_grid_thw)
    if isinstance(video_grid_thw, list):
        video_grid_thw = torch.tensor(video_grid_thw)
    input_ids = torch.tensor(input_tokens)
    if input_ids is None or input_ids.ndim != 1:
        raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")

    # The caller-supplied seq_len is replaced by the actual token count.
    seq_len = input_ids.shape[0]
    if audio_feature_lengths is not None and not isinstance(
        audio_feature_lengths, torch.Tensor
    ):
        audio_feature_lengths = torch.as_tensor(
            audio_feature_lengths, dtype=torch.long
        )
    # Default to a scale of 1.0 per video when none is provided.
    if second_per_grid_ts is None:
        if video_grid_thw is not None and video_grid_thw.numel() > 0:
            second_per_grids = torch.ones(
                video_grid_thw.shape[0], dtype=torch.float32
            )
        else:
            second_per_grids = torch.tensor([], dtype=torch.float32)
    else:
        second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)

    spatial_merge_size = config.vision_config.spatial_merge_size
    image_token_id = config.image_token_id
    video_token_id = config.video_token_id
    audio_token_id = config.audio_token_id
    vision_start_token_id = config.vision_start_token_id
    audio_start_token_id = config.audio_start_token_id
    position_id_per_seconds = config.position_id_per_seconds

    # The token right after each vision-start token tells which kind of
    # vision segment follows.
    vision_start_indices = torch.argwhere(
        input_ids == vision_start_token_id
    ).squeeze(1)
    if vision_start_indices.numel() > 0:
        vision_tokens = input_ids[vision_start_indices + 1]
    else:
        vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
    audio_nums = torch.sum(input_ids == audio_start_token_id)
    image_nums = (vision_tokens == image_token_id).sum()
    # With audio-in-video, a video is counted by an audio-start token that
    # directly follows a vision-start token.
    video_nums = (
        (vision_tokens == audio_start_token_id).sum()
        if use_audio_in_video
        else (vision_tokens == video_token_id).sum()
    )
    llm_pos_ids_list: list[torch.Tensor] = []
    # `st` is the cursor into input_tokens; the *_idx counters select the
    # next grid / feature-length entry of each modality.
    st = 0
    image_idx = 0
    video_idx = 0
    audio_idx = 0
    remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums  # noqa: E501
    # Number of segments to process; with audio-in-video the video count is
    # not added on top of the audio count.
    multimodal_nums = (
        image_nums + audio_nums
        if use_audio_in_video
        else image_nums + video_nums + audio_nums
    )  # noqa: E501

    for _ in range(multimodal_nums):
        # Next free position: one past the largest id emitted so far.
        st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
        # Locate the next vision-start / audio-start token at or after `st`.
        # `len(input_tokens) + 1` is a sentinel larger than any valid index.
        if (image_token_id in input_tokens or video_token_id in input_tokens) and (
            remain_videos > 0 or remain_images > 0
        ):
            ed_vision_start = input_tokens.index(vision_start_token_id, st)
        else:
            ed_vision_start = len(input_tokens) + 1
        if audio_token_id in input_tokens and remain_audios > 0:
            ed_audio_start = input_tokens.index(audio_start_token_id, st)
        else:
            ed_audio_start = len(input_tokens) + 1
        min_ed = min(ed_vision_start, ed_audio_start)

        if min_ed == ed_audio_start:
            # --- Audio segment: text, start token, audio payload, end token.
            text_len = min_ed - st
            if text_len != 0:
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                llm_pos_ids_list.append(
                    torch.arange(text_len, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            bos_len = 1
            llm_pos_ids_list.append(
                torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            # Helper defined outside this view; its second return value is
            # used as the number of audio positions.
            _, audio_len = _get_feat_extract_output_lengths(
                audio_feature_lengths[audio_idx]
            )
            llm_pos_ids = (
                torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            llm_pos_ids_list.append(llm_pos_ids)
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            eos_len = 1
            llm_pos_ids_list.append(
                torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            st += text_len + bos_len + audio_len + eos_len
            audio_idx += 1
            remain_audios -= 1
        elif (
            min_ed == ed_vision_start
            and input_ids[ed_vision_start + 1] == image_token_id
        ):
            # --- Image segment: text, start token, image payload, end token.
            text_len = min_ed - st
            if text_len != 0:
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                llm_pos_ids_list.append(
                    torch.arange(text_len, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            bos_len = 1
            llm_pos_ids_list.append(
                torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            grid_t = image_grid_thw[image_idx][0]
            grid_hs = image_grid_thw[:, 1]
            grid_ws = image_grid_thw[:, 2]
            # Temporal index scaled by the configured positions-per-second.
            t_index = torch.arange(grid_t) * position_id_per_seconds
            llm_pos_ids = get_llm_pos_ids_for_vision(
                st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
            )
            # Number of prompt tokens the image occupies: t*h*w divided by
            # the squared spatial merge size.
            image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
            llm_pos_ids_list.append(llm_pos_ids)
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            eos_len = 1
            llm_pos_ids_list.append(
                torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            st += text_len + bos_len + image_len + eos_len
            image_idx += 1
            remain_images -= 1
        elif (
            min_ed == ed_vision_start
            and input_ids[ed_vision_start + 1] == video_token_id
            and not use_audio_in_video
        ):
            # --- Video segment without interleaved audio.
            text_len = min_ed - st
            if text_len != 0:
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                llm_pos_ids_list.append(
                    torch.arange(text_len, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            bos_len = 1
            llm_pos_ids_list.append(
                torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            grid_t = video_grid_thw[video_idx][0]
            grid_hs = video_grid_thw[:, 1]
            grid_ws = video_grid_thw[:, 2]
            # Temporal index scaled by this video's per-grid factor and the
            # configured positions-per-second.
            t_index = (
                torch.arange(grid_t)
                * float(second_per_grids[video_idx].item())
                * position_id_per_seconds
            )
            llm_pos_ids = get_llm_pos_ids_for_vision(
                st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
            )
            video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
            llm_pos_ids_list.append(llm_pos_ids)
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            eos_len = 1
            llm_pos_ids_list.append(
                torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            st += text_len + bos_len + video_len + eos_len
            video_idx += 1
            remain_videos -= 1
        elif (
            min_ed == ed_vision_start
            and ed_vision_start + 1 == ed_audio_start
            and use_audio_in_video
        ):
            # --- Video with audio: vision-start directly followed by
            # audio-start; both payloads share one merged position range.
            text_len = min_ed - st
            if text_len != 0:
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                llm_pos_ids_list.append(
                    torch.arange(text_len, dtype=torch.long)
                    .view(1, -1)
                    .expand(3, -1)
                    + st_idx
                )
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            bos_len = 1
            bos_block = (
                torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            # Appended twice: the two consecutive start tokens receive the
            # same position value.
            llm_pos_ids_list.append(bos_block)
            llm_pos_ids_list.append(bos_block)
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            _, audio_len = _get_feat_extract_output_lengths(
                audio_feature_lengths[audio_idx]
            )
            # Audio and video positions both start from the same st_idx.
            audio_llm_pos_ids = (
                torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            grid_t = video_grid_thw[video_idx][0]
            grid_hs = video_grid_thw[:, 1]
            grid_ws = video_grid_thw[:, 2]
            t_index = (
                torch.arange(grid_t)
                * float(second_per_grids[video_idx].item())
                * position_id_per_seconds
            )
            video_llm_pos_ids = get_llm_pos_ids_for_vision(
                st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
            )
            # Two-pointer merge of the video and audio columns ordered by
            # their row-0 value; ties go to the video column.
            video_data_index, audio_data_index = 0, 0
            while (
                video_data_index < video_llm_pos_ids.shape[-1]
                and audio_data_index < audio_llm_pos_ids.shape[-1]
            ):
                if (
                    video_llm_pos_ids[0][video_data_index]
                    <= audio_llm_pos_ids[0][audio_data_index]
                ):
                    llm_pos_ids_list.append(
                        video_llm_pos_ids[
                            :, video_data_index : video_data_index + 1
                        ]
                    )
                    video_data_index += 1
                else:
                    llm_pos_ids_list.append(
                        audio_llm_pos_ids[
                            :, audio_data_index : audio_data_index + 1
                        ]
                    )
                    audio_data_index += 1
            # Flush whichever stream still has columns left.
            if video_data_index < video_llm_pos_ids.shape[-1]:
                llm_pos_ids_list.append(
                    video_llm_pos_ids[
                        :, video_data_index : video_llm_pos_ids.shape[-1]
                    ]
                )
            if audio_data_index < audio_llm_pos_ids.shape[-1]:
                llm_pos_ids_list.append(
                    audio_llm_pos_ids[
                        :, audio_data_index : audio_llm_pos_ids.shape[-1]
                    ]
                )
            video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            eos_len = 1
            eos_block = (
                torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            # Appended twice, mirroring the two start tokens above.
            llm_pos_ids_list.append(eos_block)
            llm_pos_ids_list.append(eos_block)
            st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2  # noqa: E501
            audio_idx += 1
            video_idx += 1
            remain_videos -= 1
            remain_audios -= 1

    # Trailing text after the last multimodal segment.
    if st < len(input_tokens):
        st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
        text_len = len(input_tokens) - st
        llm_pos_ids_list.append(
            torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
            + st_idx
        )

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    # Sanity check: exactly one position column per input token.
    if llm_positions.shape[1] != seq_len:
        raise RuntimeError("Position ids length mismatch with input ids length")

    mrope_position_delta = llm_positions.max() + 1 - seq_len
    return llm_positions, mrope_position_delta
\ No newline at end of file
vllm/model_executor/models/registry.py
View file @
f39afa4a
...
...
@@ -269,6 +269,7 @@ _MULTIMODAL_MODELS = {
"Qwen2AudioForConditionalGeneration"
:
(
"qwen2_audio"
,
"Qwen2AudioForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniModel"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen3OmniMoeForConditionalGeneration"
:
(
"qwen3_omni_moe_thinker"
,
"Qwen3OmniMoeThinkerForConditionalGeneration"
),
"Qwen3VLForConditionalGeneration"
:
(
"qwen3_vl"
,
"Qwen3VLForConditionalGeneration"
),
# noqa: E501
"Qwen3VLMoeForConditionalGeneration"
:
(
"qwen3_vl_moe"
,
"Qwen3VLMoeForConditionalGeneration"
),
# noqa: E501
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
...
...
vllm/model_executor/models/utils.py
View file @
f39afa4a
...
...
@@ -10,6 +10,7 @@ import torch
import
torch.nn
as
nn
from
torch.func
import
functional_call
from
transformers
import
PretrainedConfig
from
typing_extensions
import
deprecated
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
...
...
@@ -391,92 +392,79 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
return
" + "
.
join
(
_embedding_count_expression
(
inner
)
for
inner
in
embeddings
)
def split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]:
    """Group values into consecutive buckets of width ``interval``.

    Bucket ``i`` collects every value ``v`` with ``v // interval == i``, in
    the order the values appear in ``lst``. Buckets are created for every
    index from 0 up to the bucket of the largest value, so buckets with no
    values are present as empty lists.

    Args:
        lst: Values to group (a 1-D tensor or any sized iterable of
            non-negative integers).
        interval: Width of each bucket.

    Returns:
        A list of buckets; an empty list when ``lst`` has no elements.
    """
    # `len` is used instead of truthiness because a multi-element tensor has
    # no unambiguous boolean value. Previously an empty input crashed inside
    # `max()` with "arg is an empty sequence".
    if len(lst) == 0:
        return []

    num_buckets = (max(lst) // interval) + 1
    ranges: list[list[int]] = [[] for _ in range(num_buckets)]
    for num in lst:
        ranges[num // interval].append(num)
    return ranges
def
_merge_multimodal_embeddings
(
inputs_embeds
:
torch
.
Tensor
,
is_multimodal
:
torch
.
Tensor
,
multimodal_embeddings
:
NestedTensors
,
is_multimodal
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Merge
`
`multimodal_embeddings`
`
into
`
`inputs_embeds`
`
by overwriting the
positions in
`
`inputs_embeds`
`
corresponding to placeholder tokens in
`
`input_ids`
`
.
Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in `inputs_embeds` corresponding to placeholder tokens in
`input_ids`.
Note:
This updates
`
`inputs_embeds`
`
in place.
This updates `inputs_embeds` in place.
"""
flattened
=
_flatten_embeddings
(
multimodal_embeddings
)
if
len
(
multimodal_embeddings
)
==
0
:
return
inputs_embeds
mm_embeds_flat
=
_flatten_embeddings
(
multimodal_embeddings
)
input_dtype
=
inputs_embeds
.
dtype
try
:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
inputs_embeds
.
masked_scatter_
(
is_multimodal
.
unsqueeze
(
-
1
),
flattened
.
to
(
dtype
=
inputs_embeds
.
dtype
))
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds
.
masked_scatter_
(
is_multimodal
.
unsqueeze
(
-
1
),
mm_embeds_flat
.
to
(
dtype
=
input_dtype
)
)
except
RuntimeError
as
e
:
num_actual_tokens
=
len
(
mm_embeds_flat
)
num_expected_tokens
=
is_multimodal
.
sum
().
item
()
assert
isinstance
(
num_expected_tokens
,
int
)
if
flattened
.
shape
[
0
]
!=
num_expected_tokens
:
if
num_actual_tokens
!=
num_expected_tokens
:
expr
=
_embedding_count_expression
(
multimodal_embeddings
)
raise
ValueError
(
f
"Attempted to assign
{
expr
}
=
{
flattened
.
shape
[
0
]
}
"
f
"Attempted to assign
{
expr
}
=
{
num_actual_tokens
}
"
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
from
e
else
:
raise
ValueError
(
"Error during masked scatter operation"
)
from
e
return
inputs_embeds
def
embed_multimodal
(
input_ids
:
torch
.
Tensor
,
multimodal_token_id
:
int
,
get_text_embeds
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
],
multimodal_embeds
:
NestedTensors
,
)
->
torch
.
Tensor
:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
Compared to ``merge_multimodal_embeddings`, this avoids running
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal
=
input_ids
==
multimodal_token_id
is_text
=
~
is_multimodal
text_embeds
=
get_text_embeds
(
input_ids
[
is_text
])
merged_embeds
=
torch
.
empty
(
(
input_ids
.
shape
[
0
],
text_embeds
.
shape
[
1
]),
dtype
=
text_embeds
.
dtype
,
device
=
text_embeds
.
device
,
)
merged_embeds
[
is_text
]
=
text_embeds
raise
ValueError
(
"Error during masked scatter operation"
)
from
e
return
_merge_multimodal_embeddings
(
merged_embeds
,
is_multimodal
,
multimodal_embeds
,
)
return
inputs_embeds
@
deprecated
(
"`merge_multimodal_embeddings` has been replaced with "
"`SupportsMultiModal.get_input_embeddings` and will be "
"removed in v0.12."
)
def
merge_multimodal_embeddings
(
input_ids
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
multimodal_embeddings
:
NestedTensors
,
placeholder_token_id
:
Union
[
int
,
list
[
int
]
]
,
placeholder_token_id
:
int
|
list
[
int
],
)
->
torch
.
Tensor
:
"""
Merge
`
`multimodal_embeddings`
`
into
`
`inputs_embeds`
`
by overwriting the
positions in
`
`inputs_embeds`
`
corresponding to placeholder tokens in
`
`input_ids`
`
.
Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in `inputs_embeds` corresponding to placeholder tokens in
`input_ids`.
`
`placeholder_token_id`
`
can be a list of token ids (e.g, token ids
`placeholder_token_id` can be a list of token ids (e.g, token ids
of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the
`
`input_ids`
`
MUST MATCH the order of
their embeddings in
`
`multimodal_embeddings`
`
since we need to
the order of these tokens in the `input_ids` MUST MATCH the order of
their embeddings in `multimodal_embeddings` since we need to
slice-merge instead of individually scattering.
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
...
...
@@ -491,26 +479,32 @@ def merge_multimodal_embeddings(
input_ids for a correct embedding merge.
Note:
This updates
`
`inputs_embeds`
`
in place.
This updates `inputs_embeds` in place.
"""
if
isinstance
(
placeholder_token_id
,
list
):
placeholder_token_id
=
torch
.
tensor
(
placeholder_token_id
,
pin_memory
=
is_pin_memory_available
()).
to
(
device
=
input_ids
.
device
,
non_blocking
=
True
)
return
_merge_multimodal_embeddings
(
inputs_embeds
,
torch
.
isin
(
input_ids
,
placeholder_token_id
),
multimodal_embeddings
,
)
is_multimodal
=
isin_list
(
input_ids
,
placeholder_token_id
)
else
:
is_multimodal
=
input_ids
==
placeholder_token_id
return
_merge_multimodal_embeddings
(
inputs_embeds
,
(
input_ids
==
placeholder_token_id
)
,
multimodal
_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_
multimodal
=
is_multimodal
,
)
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """Apply ``torch.isin`` to ``elements`` against a Python list of ints.

    The list is first turned into a tensor (created with ``pin_memory`` set
    from ``is_pin_memory_available()``) and moved to ``elements.device`` with
    ``non_blocking=True``.

    Args:
        elements: Tensor whose entries are tested for membership.
        test_elements_list: Values to test against.

    Returns:
        The result of ``torch.isin(elements, <tensor of test values>)``.
    """
    use_pinned = is_pin_memory_available()
    host_values = torch.tensor(test_elements_list, pin_memory=use_pinned)
    device_values = host_values.to(device=elements.device, non_blocking=True)
    return torch.isin(elements, device_values)
class
LayerFn
(
Protocol
):
def
__call__
(
self
,
prefix
:
str
)
->
torch
.
nn
.
Module
:
...
...
vllm/model_executor/models/vision.py
View file @
f39afa4a
...
...
@@ -402,3 +402,39 @@ def run_dp_sharded_mrope_vision_model(
assert
len
(
out_embeddings
)
==
len
(
original_order_embeddings
),
"Found unassigned embeddings"
return
out_embeddings
def get_llm_pos_ids_for_vision(
    start_idx: int,
    vision_idx: int,
    spatial_merge_size: int,
    t_index: list[int],
    grid_hs: torch.Tensor,
    grid_ws: torch.Tensor,
) -> torch.Tensor:
    """Build stacked (temporal, height, width) position ids for one item.

    The height and width of item ``vision_idx`` are floor-divided by
    ``spatial_merge_size``. Each entry of ``t_index`` is then paired with
    every cell of the resulting grid, cells being enumerated row by row.

    Args:
        start_idx: Offset added to every produced id.
        vision_idx: Index of the item inside ``grid_hs`` / ``grid_ws``.
        spatial_merge_size: Divisor applied to the grid height and width.
        t_index: Temporal index values, one per temporal step.
        grid_hs: Grid heights, indexed by ``vision_idx``.
        grid_ws: Grid widths, indexed by ``vision_idx``.

    Returns:
        A tensor with three rows (temporal, height, width ids) and
        ``len(t_index) * merged_h * merged_w`` columns.
    """
    merged_h = grid_hs[vision_idx] // spatial_merge_size
    merged_w = grid_ws[vision_idx] // spatial_merge_size
    num_steps = len(t_index)
    cells_per_step = merged_h * merged_w

    # Row id of every cell, repeated for each temporal step.
    row_ids = torch.arange(merged_h).view(1, -1, 1)
    row_ids = row_ids.expand(num_steps, -1, merged_w).flatten()

    # Column id of every cell, repeated for each temporal step.
    col_ids = torch.arange(merged_w).view(1, 1, -1)
    col_ids = col_ids.expand(num_steps, merged_h, -1).flatten()

    # Each temporal value is repeated once per cell, then cast to long.
    step_ids = torch.Tensor(t_index).to(merged_h.device).view(-1, 1)
    step_ids = step_ids.expand(-1, cells_per_step).long().flatten()

    stacked = torch.stack([step_ids, row_ids, col_ids])
    return stacked + start_idx
vllm/v1/worker/gpu_model_runner.py
View file @
f39afa4a
...
...
@@ -368,6 +368,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype
=
torch
.
int32
)
self
.
num_accepted_tokens
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int64
)
# Only relevant for multimodal models
if
self
.
supports_mm_inputs
:
self
.
is_mm_embed
=
self
.
_make_buffer
(
self
.
max_num_tokens
,
dtype
=
torch
.
bool
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if
self
.
uses_mrope
:
...
...
@@ -1612,17 +1615,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
,
scheduler_output
:
"SchedulerOutput"
,
shift_computed_tokens
:
int
=
0
,
)
->
list
[
torch
.
Tensor
]:
)
->
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
mm_embeds
=
list
[
torch
.
Tensor
]()
is_mm_embed
=
self
.
is_mm_embed
.
cpu
is_mm_embed
[:
total_num_scheduled_tokens
]
=
False
req_start_idx
=
0
should_sync_mrope_positions
=
False
mm_embeds
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
mm_embeds_req
:
list
[
torch
.
Tensor
]
=
[]
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
num_computed_tokens
=
\
req_state
.
num_computed_tokens
+
shift_computed_tokens
num_computed_tokens
=
req_state
.
num_computed_tokens
+
shift_computed_tokens
for
mm_feature
in
req_state
.
mm_features
:
pos_info
=
mm_feature
.
mm_position
start_pos
=
pos_info
.
offset
...
...
@@ -1649,12 +1658,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_hash
=
mm_feature
.
identifier
encoder_output
=
self
.
encoder_cache
.
get
(
mm_hash
,
None
)
assert
encoder_output
is
not
None
,
\
f
"Encoder cache miss for
{
mm_hash
}
."
assert
encoder_output
is
not
None
,
f
"Encoder cache miss for
{
mm_hash
}
."
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
is_embed
=
is_embed
[
start_idx
:
end_idx
]
req_start_pos
=
req_start_idx
+
start_pos
-
num_computed_tokens
is_mm_embed
[
req_start_pos
+
start_idx
:
req_start_pos
+
end_idx
]
=
(
True
if
is_embed
is
None
else
is_embed
)
mm_embeds_item
=
gather_mm_placeholders
(
encoder_output
[
start_idx
:
end_idx
],
is_embed
=
is_embed
,
...
...
@@ -1662,6 +1675,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds_req
.
append
(
mm_embeds_item
)
if
self
.
is_multimodal_pruning_enabled
and
self
.
uses_mrope
:
assert
req_state
.
mrope_positions
is
not
None
should_sync_mrope_positions
=
True
mm_embeds_req
,
new_mrope_positions
,
new_delta
=
(
self
.
model
.
recompute_mrope_positions
(
...
...
@@ -1669,19 +1683,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
multimodal_embeddings
=
mm_embeds_req
,
mrope_positions
=
req_state
.
mrope_positions
,
num_computed_tokens
=
req_state
.
num_computed_tokens
,
)
)
assert
req_state
.
mrope_positions
is
not
None
)
)
req_state
.
mrope_positions
.
copy_
(
new_mrope_positions
)
req_state
.
mrope_position_delta
=
new_delta
mm_embeds
.
extend
(
mm_embeds_req
)
req_start_idx
+=
num_scheduled_tokens
is_mm_embed
=
self
.
is_mm_embed
.
copy_to_gpu
(
total_num_scheduled_tokens
)
if
should_sync_mrope_positions
:
self
.
_calc_mrope_positions
(
scheduler_output
)
self
.
mrope_positions
.
copy_to_gpu
(
scheduler_output
.
total_num_scheduled_tokens
)
self
.
mrope_positions
.
copy_to_gpu
(
total_num_scheduled_tokens
)
return
mm_embeds
return
mm_embeds
,
is_mm_embed
def
_extract_encoder_inputs
(
self
,
...
...
@@ -1975,7 +1991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
and
not
self
.
model_config
.
is_encoder_decoder
):
# Run the multimodal encoder if any.
self
.
_execute_mm_encoder
(
scheduler_output
)
mm_embeds
=
self
.
_gather_mm_embeddings
(
scheduler_output
)
mm_embeds
,
is_mm_embed
=
self
.
_gather_mm_embeddings
(
scheduler_output
)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
...
...
@@ -1983,6 +1999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
inputs_embeds_scheduled
=
self
.
model
.
get_input_embeddings
(
input_ids
=
self
.
input_ids
.
gpu
[:
num_scheduled_tokens
],
multimodal_embeddings
=
mm_embeds
or
None
,
is_multimodal
=
is_mm_embed
,
)
# TODO(woosuk): Avoid the copy. Optimize.
...
...
vllm/version.py
View file @
f39afa4a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try
:
from
._version
import
__version__
,
__version_tuple__
__version__
=
"0.11.0"
__version_tuple__
=
(
0
,
11
,
0
)
__hcu_version__
=
f
'0.11.0+das.opt1.alpha.6c015e7.dtk25041'
from
vllm.version
import
__version__
,
__version_tuple__
,
__hcu_version__
except
Exception
as
e
:
import
warnings
warnings
.
warn
(
f
"Failed to read commit hash:
\n
{
e
}
"
,
warnings
.
warn
(
f
"Failed to read commit hash:
\n
+ str(e)
"
,
RuntimeWarning
,
stacklevel
=
2
)
__version__
=
"dev"
__version_tuple__
=
(
0
,
0
,
__version__
)
def
_prev_minor_version_was
(
version_str
):
"""
Check whether a given version matches the previous minor version.
'''
Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
...
...
@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if
__version_tuple__
[
0
:
2
]
==
(
0
,
0
):
return
True
# Note - this won't do the right thing when we release 1.0!
assert
__version_tuple__
[
0
]
==
0
#
assert __version_tuple__[0] == 0
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
version_str
==
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
def _prev_minor_version():
    """For the purpose of testing, return a previous minor version number."""
    # In dev tree, this will return "0.-1", but that will work fine
    minor = __version_tuple__[1]
    assert isinstance(minor, int)
    major = __version_tuple__[0]
    return f"{major}.{minor - 1}"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment