Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b12c902b
Commit
b12c902b
authored
Oct 29, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'cx/v0.11.0-dev' into v0.11.0-dev-omni
parents
c16e075a
f39afa4a
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
303 additions
and
1141 deletions
+303
-1141
vllm/model_executor/layers/rotary_embedding/mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py
+91
-1009
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+74
-11
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+32
-27
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+63
-69
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+31
-14
vllm/version.py
vllm/version.py
+12
-11
No files found.
vllm/model_executor/layers/rotary_embedding/mrope.py
View file @
b12c902b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
from
typing
import
Optional
,
Union
import
numpy
as
np
import
torch
from
transformers
import
PretrainedConfig
from
vllm.triton_utils
import
tl
,
triton
...
...
@@ -62,10 +55,8 @@ def _triton_mrope_forward(
# Updated offsets for half head_dim
cos_offsets
=
tl
.
arange
(
0
,
pad_hd
//
2
)
if
is_interleaved
:
h_mask
=
(((
cos_offsets
%
3
)
==
1
)
&
(
cos_offsets
<=
3
*
mrope_section_h
))
w_mask
=
(((
cos_offsets
%
3
)
==
2
)
&
(
cos_offsets
<=
3
*
mrope_section_w
))
h_mask
=
((
cos_offsets
%
3
)
==
1
)
&
(
cos_offsets
<=
3
*
mrope_section_h
)
w_mask
=
((
cos_offsets
%
3
)
==
2
)
&
(
cos_offsets
<=
3
*
mrope_section_w
)
t_mask
=
~
(
h_mask
|
w_mask
)
else
:
t_end
=
mrope_section_t
...
...
@@ -89,21 +80,25 @@ def _triton_mrope_forward(
# program instance (i.e. for the current token) separately
# ####################################################################
# left half of the head
first_half_q_offsets
=
tl
.
arange
(
0
,
pad_n_qh
)[:,
None
]
*
hd
+
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
first_half_k_offsets
=
tl
.
arange
(
0
,
pad_n_kh
)[:,
None
]
*
hd
+
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
first_q_mask
=
(
tl
.
arange
(
0
,
pad_n_qh
)[:,
None
]
<
n_qh
)
&
(
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
<
rd
//
2
)
first_k_mask
=
(
tl
.
arange
(
0
,
pad_n_kh
)[:,
None
]
<
n_kh
)
&
(
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
<
rd
//
2
)
first_half_q_offsets
=
(
tl
.
arange
(
0
,
pad_n_qh
)[:,
None
]
*
hd
+
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
)
first_half_k_offsets
=
(
tl
.
arange
(
0
,
pad_n_kh
)[:,
None
]
*
hd
+
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
)
first_q_mask
=
(
tl
.
arange
(
0
,
pad_n_qh
)[:,
None
]
<
n_qh
)
&
(
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
<
rd
//
2
)
first_k_mask
=
(
tl
.
arange
(
0
,
pad_n_kh
)[:,
None
]
<
n_kh
)
&
(
tl
.
arange
(
0
,
pad_hd
//
2
)[
None
,
:]
<
rd
//
2
)
q_tile_1
=
tl
.
load
(
q_ptr
+
first_half_q_offsets
,
mask
=
first_q_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
k_tile_1
=
tl
.
load
(
k_ptr
+
first_half_k_offsets
,
mask
=
first_k_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
q_tile_1
=
tl
.
load
(
q_ptr
+
first_half_q_offsets
,
mask
=
first_q_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
k_tile_1
=
tl
.
load
(
k_ptr
+
first_half_k_offsets
,
mask
=
first_k_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
# right half of the head
second_half_q_offsets
=
first_half_q_offsets
+
(
rd
//
2
)
...
...
@@ -111,12 +106,12 @@ def _triton_mrope_forward(
second_q_mask
=
first_q_mask
second_k_mask
=
first_k_mask
q_tile_2
=
tl
.
load
(
q_ptr
+
second_half_q_offsets
,
mask
=
second_q_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
k_tile_2
=
tl
.
load
(
k_ptr
+
second_half_k_offsets
,
mask
=
second_k_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
q_tile_2
=
tl
.
load
(
q_ptr
+
second_half_q_offsets
,
mask
=
second_q_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
k_tile_2
=
tl
.
load
(
k_ptr
+
second_half_k_offsets
,
mask
=
second_k_mask
,
other
=
0
).
to
(
sin_row
.
dtype
)
# y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
# Since cos and sin are now half-size,
...
...
@@ -168,7 +163,7 @@ def triton_mrope(
cos
=
cos
.
contiguous
()
sin
=
sin
.
contiguous
()
_triton_mrope_forward
[(
n_row
,
)](
_triton_mrope_forward
[(
n_row
,)](
q
,
k
,
cos
,
...
...
@@ -189,15 +184,14 @@ def triton_mrope(
return
q
,
k
def
apply_interleaved_rope
(
x
:
torch
.
Tensor
,
mrope_section
:
list
[
int
])
->
torch
.
Tensor
:
def
apply_interleaved_rope
(
x
:
torch
.
Tensor
,
mrope_section
:
list
[
int
])
->
torch
.
Tensor
:
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
"""
x_t
=
x
[
0
].
clone
()
x_t
[...,
1
:
mrope_section
[
1
]
*
3
:
3
]
=
x
[
1
,
...,
1
:
mrope_section
[
1
]
*
3
:
3
]
x_t
[...,
2
:
mrope_section
[
2
]
*
3
:
3
]
=
x
[
2
,
...,
2
:
mrope_section
[
2
]
*
3
:
3
]
x_t
[...,
1
:
mrope_section
[
1
]
*
3
:
3
]
=
x
[
1
,
...,
1
:
mrope_section
[
1
]
*
3
:
3
]
x_t
[...,
2
:
mrope_section
[
2
]
*
3
:
3
]
=
x
[
2
,
...,
2
:
mrope_section
[
2
]
*
3
:
3
]
return
x_t
...
...
@@ -212,17 +206,16 @@ class MRotaryEmbedding(RotaryEmbedding):
base
:
float
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
mrope_section
:
Optional
[
list
[
int
]
]
=
None
,
mrope_section
:
list
[
int
]
|
None
=
None
,
mrope_interleaved
:
bool
=
False
,
# YaRN parameters.
*
,
scaling_factor
:
Optional
[
float
]
=
None
,
scaling_factor
:
float
|
None
=
None
,
extrapolation_factor
:
float
=
1
,
attn_factor
:
float
=
1
,
beta_fast
:
int
=
32
,
beta_slow
:
int
=
1
,
)
->
None
:
self
.
scaling_factor
=
scaling_factor
self
.
extrapolation_factor
=
extrapolation_factor
self
.
attn_factor
=
attn_factor
...
...
@@ -230,8 +223,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self
.
beta_slow
=
beta_slow
if
self
.
scaling_factor
is
not
None
:
# Get n-d magnitude scaling corrected for interpolation
self
.
mscale
=
float
(
yarn_get_mscale
(
self
.
scaling_factor
)
*
attn_factor
)
self
.
mscale
=
float
(
yarn_get_mscale
(
self
.
scaling_factor
)
*
attn_factor
)
else
:
self
.
mscale
=
1.0
...
...
@@ -239,8 +231,14 @@ class MRotaryEmbedding(RotaryEmbedding):
# the input video. We enlarge max_position_embeddings to 4 times to get
# a larger the cos and sin cache.
self
.
cache_max_position_num
=
max_position_embeddings
*
4
super
().
__init__
(
head_size
,
rotary_dim
,
self
.
cache_max_position_num
,
base
,
is_neox_style
,
dtype
)
super
().
__init__
(
head_size
,
rotary_dim
,
self
.
cache_max_position_num
,
base
,
is_neox_style
,
dtype
,
)
self
.
mrope_section
=
mrope_section
self
.
mrope_interleaved
=
mrope_interleaved
...
...
@@ -261,9 +259,9 @@ class MRotaryEmbedding(RotaryEmbedding):
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]:
key
:
torch
.
Tensor
|
None
=
None
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
"""PyTorch-native implementation equivalent to forward().
Args:
...
...
@@ -286,31 +284,27 @@ class MRotaryEmbedding(RotaryEmbedding):
cos
=
apply_interleaved_rope
(
cos
,
self
.
mrope_section
)
sin
=
apply_interleaved_rope
(
sin
,
self
.
mrope_section
)
else
:
cos
=
torch
.
cat
([
m
[
i
]
for
i
,
m
in
enumerate
(
cos
.
split
(
self
.
mrope_section
,
dim
=-
1
))
],
dim
=-
1
)
sin
=
torch
.
cat
([
m
[
i
]
for
i
,
m
in
enumerate
(
sin
.
split
(
self
.
mrope_section
,
dim
=-
1
))
],
dim
=-
1
)
cos
=
torch
.
cat
(
[
m
[
i
]
for
i
,
m
in
enumerate
(
cos
.
split
(
self
.
mrope_section
,
dim
=-
1
))],
dim
=-
1
,
)
sin
=
torch
.
cat
(
[
m
[
i
]
for
i
,
m
in
enumerate
(
sin
.
split
(
self
.
mrope_section
,
dim
=-
1
))],
dim
=-
1
,
)
query_shape
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
@@ -318,10 +312,9 @@ class MRotaryEmbedding(RotaryEmbedding):
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
key
:
torch
.
Tensor
|
None
=
None
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
assert
positions
.
ndim
==
1
or
positions
.
ndim
==
2
assert
key
is
not
None
...
...
@@ -348,17 +341,15 @@ class MRotaryEmbedding(RotaryEmbedding):
return
q
.
reshape
(
query_shape
),
k
.
reshape
(
key_shape
)
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
apply_rotary_emb_dispatch
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
apply_rotary_emb_dispatch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
@@ -366,885 +357,20 @@ class MRotaryEmbedding(RotaryEmbedding):
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]:
key
:
torch
.
Tensor
|
None
=
None
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
)
def
forward_cpu
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]:
key
:
torch
.
Tensor
|
None
=
None
,
offsets
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
]:
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
)
@
classmethod
def
get_input_positions
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Optional
[
Union
[
list
[
list
[
int
]],
torch
.
Tensor
]],
video_grid_thw
:
Optional
[
Union
[
list
[
list
[
int
]],
torch
.
Tensor
]],
second_per_grid_ts
:
Optional
[
list
[
float
]],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
audio_feature_lengths
:
Optional
[
torch
.
Tensor
]
=
None
,
use_audio_in_video
:
bool
=
False
,
)
->
tuple
[
list
[
list
[
int
]],
int
]:
"""Get mrope input positions and delta value."""
image_grid_thw
=
[]
if
image_grid_thw
is
None
else
image_grid_thw
video_grid_thw
=
[]
if
video_grid_thw
is
None
else
video_grid_thw
second_per_grid_ts
=
[]
if
second_per_grid_ts
is
None
else
\
second_per_grid_ts
llm_positions
,
mrope_position_delta
=
\
cls
.
get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
second_per_grid_ts
=
second_per_grid_ts
,
context_len
=
context_len
,
seq_len
=
seq_len
,
audio_feature_lengths
=
audio_feature_lengths
,
use_audio_in_video
=
use_audio_in_video
,
)
return
llm_positions
.
tolist
(),
mrope_position_delta
@
classmethod
def
get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
second_per_grid_ts
:
list
[
float
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
audio_feature_lengths
:
Optional
[
torch
.
Tensor
]
=
None
,
use_audio_in_video
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
from
vllm.transformers_utils.config
import
thinker_uses_mrope
if
thinker_uses_mrope
(
hf_config
)
and
hf_config
.
model_type
==
"qwen2_5_omni"
:
return
cls
.
_omni_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
second_per_grid_ts
=
second_per_grid_ts
,
context_len
=
context_len
,
seq_len
=
seq_len
,
audio_feature_lengths
=
audio_feature_lengths
,
use_audio_in_video
=
use_audio_in_video
,
)
elif
hf_config
.
model_type
in
[
"glm4v"
,
"glm4v_moe"
]:
return
cls
.
_glm4v_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
context_len
=
context_len
,
seq_len
=
seq_len
,
)
elif
hf_config
.
model_type
in
[
"qwen3_vl"
,
"qwen3_vl_moe"
]:
return
cls
.
_qwen3vl_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
context_len
=
context_len
,
seq_len
=
seq_len
,
)
elif
hf_config
.
model_type
in
[
"ernie4_5_moe_vl"
,
"ernie4_5_vl"
]:
return
cls
.
_ernie_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
context_len
=
context_len
,
seq_len
=
seq_len
,
)
elif
"KeyeVL1_5"
in
hf_config
.
model_type
:
return
cls
.
_keye_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
context_len
=
context_len
,
seq_len
=
seq_len
,
)
else
:
return
cls
.
_vl_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
image_grid_thw
=
image_grid_thw
,
video_grid_thw
=
video_grid_thw
,
second_per_grid_ts
=
second_per_grid_ts
,
context_len
=
context_len
,
seq_len
=
seq_len
,
)
@
classmethod
def
_glm4v_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
"""Get mrope input positions and delta value for GLM4V."""
image_token_id
=
hf_config
.
image_token_id
video_start_token_id
=
hf_config
.
video_start_token_id
video_end_token_id
=
hf_config
.
video_end_token_id
spatial_merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
llm_pos_ids_list
:
list
=
[]
if
not
(
image_grid_thw
is
None
and
video_grid_thw
is
None
):
if
isinstance
(
image_grid_thw
,
torch
.
Tensor
):
image_grid_thw
=
image_grid_thw
.
tolist
()
input_token_type
:
list
[
str
]
=
[]
video_check_flg
=
False
for
token
in
input_tokens
:
if
token
==
video_start_token_id
:
video_check_flg
=
True
elif
token
==
video_end_token_id
:
video_check_flg
=
False
if
(
token
==
image_token_id
)
and
(
video_check_flg
is
False
):
input_token_type
.
append
(
"image"
)
elif
(
token
==
image_token_id
)
and
(
video_check_flg
is
True
):
input_token_type
.
append
(
"video"
)
else
:
input_token_type
.
append
(
"text"
)
input_type_group
:
list
[
tuple
[
str
,
int
,
int
]]
=
[]
for
key
,
group_iter
in
itertools
.
groupby
(
enumerate
(
input_token_type
),
lambda
x
:
x
[
1
]):
group_list
=
list
(
group_iter
)
start_index
=
group_list
[
0
][
0
]
end_index
=
group_list
[
-
1
][
0
]
+
1
input_type_group
.
append
((
key
,
start_index
,
end_index
))
video_frame_num
=
1
mm_data_idx
=
0
for
modality_type
,
start_idx
,
end_idx
in
input_type_group
:
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
if
modality_type
==
"image"
:
t
,
h
,
w
=
(
image_grid_thw
[
mm_data_idx
][
0
],
image_grid_thw
[
mm_data_idx
][
1
],
image_grid_thw
[
mm_data_idx
][
2
],
)
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
t_index
=
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
st_idx
)
mm_data_idx
+=
1
elif
modality_type
==
"video"
:
t
,
h
,
w
=
(
video_frame_num
,
image_grid_thw
[
mm_data_idx
][
1
],
image_grid_thw
[
mm_data_idx
][
2
],
)
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
for
t_idx
in
range
(
llm_grid_t
):
t_index
=
torch
.
tensor
(
t_idx
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
1
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
1
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
st_idx
)
mm_data_idx
+=
1
video_frame_num
+=
1
else
:
text_len
=
end_idx
-
start_idx
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
video_frame_num
=
1
else
:
text_len
=
len
(
input_tokens
)
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
))
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
return
llm_positions
,
mrope_position_delta
@
classmethod
def
_qwen3vl_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
"""Get mrope input positions and delta value."""
video_grid_thw
=
[[
1
,
h
,
w
]
for
t
,
h
,
w
in
video_grid_thw
for
_
in
range
(
t
)]
image_token_id
=
hf_config
.
image_token_id
video_token_id
=
hf_config
.
video_token_id
vision_start_token_id
=
hf_config
.
vision_start_token_id
spatial_merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
)
vision_start_indices
=
torch
.
argwhere
(
input_tokens_tensor
==
vision_start_token_id
).
squeeze
(
1
)
vision_tokens
=
input_tokens_tensor
[
vision_start_indices
+
1
]
image_nums
=
(
vision_tokens
==
image_token_id
).
sum
()
video_nums
=
(
vision_tokens
==
video_token_id
).
sum
()
llm_pos_ids_list
:
list
=
[]
st
=
0
remain_images
,
remain_videos
=
image_nums
,
video_nums
image_index
,
video_index
=
0
,
0
for
_
in
range
(
image_nums
+
video_nums
):
if
image_token_id
in
input_tokens
and
remain_images
>
0
:
ed_image
=
input_tokens
.
index
(
image_token_id
,
st
)
else
:
ed_image
=
len
(
input_tokens
)
+
1
if
video_token_id
in
input_tokens
and
remain_videos
>
0
:
ed_video
=
input_tokens
.
index
(
video_token_id
,
st
)
else
:
ed_video
=
len
(
input_tokens
)
+
1
if
ed_image
<
ed_video
:
t
,
h
,
w
=
(
image_grid_thw
[
image_index
][
0
],
image_grid_thw
[
image_index
][
1
],
image_grid_thw
[
image_index
][
2
],
)
image_index
+=
1
remain_images
-=
1
ed
=
ed_image
else
:
t
,
h
,
w
=
(
video_grid_thw
[
video_index
][
0
],
video_grid_thw
[
video_index
][
1
],
video_grid_thw
[
video_index
][
2
],
)
video_index
+=
1
remain_videos
-=
1
ed
=
ed_video
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
text_len
=
ed
-
st
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
text_len
+
st_idx
)
st
=
ed
+
llm_grid_t
*
llm_grid_h
*
llm_grid_w
if
st
<
len
(
input_tokens
):
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
text_len
=
len
(
input_tokens
)
-
st
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
return
llm_positions
,
mrope_position_delta
@
classmethod
def
_ernie_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
"""Get mrope input positions and delta value for Ernie VL."""
image_token_id
=
hf_config
.
im_patch_id
video_start_token_id
=
hf_config
.
video_start_token_id
video_end_token_id
=
hf_config
.
video_end_token_id
spatial_conv_size
=
hf_config
.
spatial_conv_size
temporal_conv_size
=
hf_config
.
temporal_conv_size
llm_pos_ids_list
:
list
=
[]
if
not
(
image_grid_thw
is
None
and
video_grid_thw
is
None
):
if
isinstance
(
image_grid_thw
,
torch
.
Tensor
):
image_grid_thw
=
image_grid_thw
.
tolist
()
input_token_type
:
list
[
str
]
=
[]
video_check_flg
=
False
for
token
in
input_tokens
:
if
token
==
video_start_token_id
:
video_check_flg
=
True
elif
token
==
video_end_token_id
:
video_check_flg
=
False
if
(
token
==
image_token_id
)
and
(
video_check_flg
is
False
):
input_token_type
.
append
(
"image"
)
elif
(
token
==
image_token_id
)
and
(
video_check_flg
is
True
):
input_token_type
.
append
(
"video"
)
else
:
input_token_type
.
append
(
"text"
)
input_type_group
:
list
[
tuple
[
str
,
int
,
int
]]
=
[]
for
key
,
group_iter
in
itertools
.
groupby
(
enumerate
(
input_token_type
),
lambda
x
:
x
[
1
]):
group_list
=
list
(
group_iter
)
start_index
=
group_list
[
0
][
0
]
end_index
=
group_list
[
-
1
][
0
]
+
1
input_type_group
.
append
((
key
,
start_index
,
end_index
))
video_frame_num
=
1
mm_data_idx
=
0
for
modality_type
,
start_idx
,
end_idx
in
input_type_group
:
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
if
modality_type
==
"image"
:
t
,
h
,
w
=
(
image_grid_thw
[
mm_data_idx
][
0
],
image_grid_thw
[
mm_data_idx
][
1
],
image_grid_thw
[
mm_data_idx
][
2
],
)
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_conv_size
,
w
//
spatial_conv_size
t_index
=
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
st_idx
)
mm_data_idx
+=
1
elif
modality_type
==
"video"
:
t
,
h
,
w
=
(
video_grid_thw
[
mm_data_idx
][
0
],
video_grid_thw
[
mm_data_idx
][
1
],
video_grid_thw
[
mm_data_idx
][
2
],
)
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
(
t
//
temporal_conv_size
,
h
//
spatial_conv_size
,
w
//
spatial_conv_size
)
for
t_idx
in
range
(
llm_grid_t
):
t_index
=
torch
.
tensor
(
t_idx
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
1
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
1
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
st_idx
)
mm_data_idx
+=
1
video_frame_num
+=
1
else
:
text_len
=
end_idx
-
start_idx
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
video_frame_num
=
1
else
:
text_len
=
len
(
input_tokens
)
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
))
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
return
llm_positions
,
mrope_position_delta
@
classmethod
def
_keye_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
if
isinstance
(
video_grid_thw
,
list
)
and
len
(
video_grid_thw
)
>
0
:
video_grid_thw
=
video_grid_thw
[
0
]
"""Get mrope input positions and delta value (Keye series)."""
def
split_thw
(
grid_thw
:
Union
[
torch
.
Tensor
,
list
[
int
]])
->
list
[
list
[
int
]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if
isinstance
(
grid_thw
,
list
):
grid_thw
=
torch
.
tensor
(
grid_thw
,
dtype
=
torch
.
long
)
if
grid_thw
.
numel
()
==
0
:
return
[]
t
,
hw
=
grid_thw
[:,
0
],
grid_thw
[:,
1
:]
ones
=
torch
.
ones_like
(
hw
[:,
:
1
])
# [N,1]
out
=
torch
.
cat
([
ones
,
hw
],
dim
=
1
).
repeat_interleave
(
t
,
dim
=
0
)
return
out
.
tolist
()
video_grid_thw
=
split_thw
(
video_grid_thw
)
image_token_id
=
hf_config
.
image_token_id
video_token_id
=
hf_config
.
video_token_id
spatial_merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
image_nums
=
len
(
image_grid_thw
)
frame_nums
=
len
(
video_grid_thw
)
llm_pos_ids_list
:
list
=
[]
st
=
0
remain_images
,
remain_frames
=
image_nums
,
frame_nums
image_index
,
video_index
=
0
,
0
for
_
in
range
(
image_nums
+
frame_nums
):
if
remain_images
>
0
:
try
:
ed_image
=
input_tokens
.
index
(
image_token_id
,
st
)
except
ValueError
:
ed_image
=
len
(
input_tokens
)
+
1
else
:
ed_image
=
len
(
input_tokens
)
+
1
if
remain_frames
>
0
:
try
:
ed_video
=
input_tokens
.
index
(
video_token_id
,
st
)
except
ValueError
:
ed_video
=
len
(
input_tokens
)
+
1
else
:
ed_video
=
len
(
input_tokens
)
+
1
if
ed_image
<
ed_video
:
t
,
h
,
w
=
(
image_grid_thw
[
image_index
][
0
],
image_grid_thw
[
image_index
][
1
],
image_grid_thw
[
image_index
][
2
],
)
image_index
+=
1
remain_images
-=
1
ed
=
ed_image
else
:
t
,
h
,
w
=
(
video_grid_thw
[
video_index
][
0
],
video_grid_thw
[
video_index
][
1
],
video_grid_thw
[
video_index
][
2
],
)
video_index
+=
1
remain_frames
-=
1
ed
=
ed_video
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
text_len
=
ed
-
st
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
)).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
text_len
+
st_idx
)
st
=
ed
+
llm_grid_t
*
llm_grid_h
*
llm_grid_w
if
st
<
len
(
input_tokens
):
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
text_len
=
len
(
input_tokens
)
-
st
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
return
llm_positions
,
mrope_position_delta
@
classmethod
def
_vl_get_input_positions_tensor
(
cls
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
list
[
list
[
int
]],
torch
.
Tensor
],
second_per_grid_ts
:
list
[
float
],
context_len
:
int
=
0
,
seq_len
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
"""Get mrope input positions and delta value."""
image_token_id
=
hf_config
.
image_token_id
video_token_id
=
hf_config
.
video_token_id
vision_start_token_id
=
hf_config
.
vision_start_token_id
spatial_merge_size
=
hf_config
.
vision_config
.
spatial_merge_size
tokens_per_second
=
getattr
(
hf_config
.
vision_config
,
"tokens_per_second"
,
1.0
)
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
)
vision_start_indices
=
torch
.
argwhere
(
input_tokens_tensor
==
vision_start_token_id
).
squeeze
(
1
)
vision_tokens
=
input_tokens_tensor
[
vision_start_indices
+
1
]
image_nums
=
(
vision_tokens
==
image_token_id
).
sum
()
video_nums
=
(
vision_tokens
==
video_token_id
).
sum
()
llm_pos_ids_list
:
list
=
[]
st
=
0
remain_images
,
remain_videos
=
image_nums
,
video_nums
image_index
,
video_index
=
0
,
0
for
_
in
range
(
image_nums
+
video_nums
):
video_second_per_grid_t
=
0.0
if
remain_images
>
0
:
try
:
ed_image
=
input_tokens
.
index
(
image_token_id
,
st
)
except
ValueError
:
ed_image
=
len
(
input_tokens
)
+
1
else
:
ed_image
=
len
(
input_tokens
)
+
1
if
remain_videos
>
0
:
try
:
ed_video
=
input_tokens
.
index
(
video_token_id
,
st
)
except
ValueError
:
ed_video
=
len
(
input_tokens
)
+
1
else
:
ed_video
=
len
(
input_tokens
)
+
1
if
ed_image
<
ed_video
:
t
,
h
,
w
=
(
image_grid_thw
[
image_index
][
0
],
image_grid_thw
[
image_index
][
1
],
image_grid_thw
[
image_index
][
2
],
)
image_index
+=
1
remain_images
-=
1
ed
=
ed_image
else
:
t
,
h
,
w
=
(
video_grid_thw
[
video_index
][
0
],
video_grid_thw
[
video_index
][
1
],
video_grid_thw
[
video_index
][
2
],
)
video_second_per_grid_t
=
1.0
if
second_per_grid_ts
:
video_second_per_grid_t
=
second_per_grid_ts
[
video_index
]
video_index
+=
1
remain_videos
-=
1
ed
=
ed_video
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
text_len
=
ed
-
st
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
)
*
video_second_per_grid_t
*
tokens_per_second
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
text_len
+
st_idx
)
st
=
ed
+
llm_grid_t
*
llm_grid_h
*
llm_grid_w
if
st
<
len
(
input_tokens
):
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
text_len
=
len
(
input_tokens
)
-
st
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
return
llm_positions
,
mrope_position_delta
@classmethod
def _omni_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: Optional[list[float]] = None,
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value (Qwen2.5-Omni version).

    Differences from MRotaryEmbedding:
        1. Add audio support (and related `audio_feature_lengths`).
        2. Add `use_audio_in_video` option to read audio from video inputs.
            In this case, audio and vision position ids will be split into
            chunks and interleaved.

    Example:

        (V_i are vision position ids, A_i are audio position ids)

        |V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
        |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...

    Args:
        input_tokens: Prompt token ids. The scan below advances by the
            number of placeholder tokens it accounts for, so the prompt
            presumably already contains the expanded placeholder runs
            (TODO confirm against the caller).
        hf_config: Config exposing `thinker_config`, from which the token
            ids, `seconds_per_chunk` and the vision settings are read.
        image_grid_thw: One row per image; a list is converted to a
            tensor. Columns are read as (t, h, w).
        video_grid_thw: One row per video; a list is converted to a
            tensor. Columns are read as (t, h, w).
        second_per_grid_ts: Per-video temporal scale. Defaults to 1 for
            every video when empty or None.
        context_len: First column of the position tensor to return.
        seq_len: End column (exclusive) of the position tensor to return.
        audio_feature_lengths: Per-audio feature lengths; must not be None
            when the prompt contains audio (asserted below).
        use_audio_in_video: Whether video placeholders are interleaved
            with the audio track of the same video.

    Returns:
        `(llm_positions, mrope_position_delta)`. `llm_positions` has three
        rows and is sliced to `[context_len:seq_len]` along dim 1.
        `mrope_position_delta` is the largest position + 1 minus
        `len(input_tokens)`.
    """

    # TODO(fyabc): refactor and share more code with
    # _vl_get_input_positions_tensor.

    thinker_config = hf_config.thinker_config

    # Normalise list inputs so the column slicing below works.
    if isinstance(image_grid_thw, list):
        image_grid_thw = torch.tensor(image_grid_thw)
    if isinstance(video_grid_thw, list):
        video_grid_thw = torch.tensor(video_grid_thw)

    audio_token_id = thinker_config.audio_token_index
    image_token_id = thinker_config.image_token_index
    video_token_id = thinker_config.video_token_index
    audio_start_token_id = thinker_config.audio_start_token_id
    audio_end_token_id = thinker_config.audio_end_token_id
    vision_start_token_id = thinker_config.vision_start_token_id
    vision_end_token_id = thinker_config.vision_end_token_id
    seconds_per_chunk = thinker_config.seconds_per_chunk
    spatial_merge_size = thinker_config.vision_config.spatial_merge_size
    tokens_per_second = getattr(thinker_config.vision_config,
                                "tokens_per_second", 25)

    src_item = input_tokens
    audio_seqlens = audio_feature_lengths
    if not second_per_grid_ts:
        second_per_grid_ts = [1] * video_grid_thw.shape[0]
    # Cursors into the per-modality inputs.
    audio_idx = 0
    video_idx = 0
    image_idx = 0
    # `new_src_item` is only used for length accounting: the number of
    # tokens appended per iteration tells the scan how far to advance.
    new_src_item: list[int] = []
    llm_pos_ids_list: list[torch.Tensor] = []

    idx = 0
    while idx < len(src_item):
        new_src_item_len = len(new_src_item)
        # Next free position: one past the largest position emitted by the
        # most recent entry (0 for the very first token).
        start_idx = llm_pos_ids_list[-1].max() + 1 if len(
            llm_pos_ids_list) > 0 else 0
        if src_item[idx] not in [
                audio_token_id, video_token_id, image_token_id
        ]:
            # Any non-placeholder token gets one position, identical on all
            # three rows.
            if use_audio_in_video and idx > 0:
                # In both cases below the token is given the same value as
                # the maximum position of the previous entry instead of the
                # next free one.
                if src_item[idx] == vision_end_token_id and \
                    src_item[idx - 1] == audio_end_token_id:
                    # processing the <|audio_eos|> before <|vision_eos|>
                    start_idx -= 1
                elif src_item[idx] == audio_start_token_id and \
                    src_item[idx - 1] == vision_start_token_id:
                    # processing the <|audio_bos|> after <|vision_bos|>
                    start_idx -= 1
            new_src_item.append(src_item[idx])
            llm_pos_ids = torch.tensor([start_idx],
                                       dtype=torch.long).expand(3, -1)
            llm_pos_ids_list.append(llm_pos_ids)
        elif src_item[idx] == audio_token_id:
            assert audio_seqlens is not None
            audio_seqlen = audio_seqlens[audio_idx]
            # Number of audio placeholder positions derived from the
            # feature length; presumably mirrors the audio encoder's
            # downsampling -- TODO confirm.
            place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
            new_src_item.extend([audio_token_id] * place_num)
            # Audio positions increase by one per placeholder, identical on
            # all three rows.
            llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
            llm_pos_ids_list.append(llm_pos_ids)
            audio_idx += 1
        elif src_item[idx] == image_token_id:
            grid_t = image_grid_thw[image_idx][0]
            grid_hs = image_grid_thw[:, 1]
            grid_ws = image_grid_thw[:, 2]
            # Images use a fixed temporal scale of 1.
            t_index = torch.arange(grid_t) * 1 * tokens_per_second
            llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
                grid_ws)
            llm_pos_ids_list.append(llm_pos_ids)
            vision_seqlen = image_grid_thw[image_idx].prod() // (
                spatial_merge_size**2)
            new_src_item.extend([image_token_id] * vision_seqlen)
            image_idx += 1
        elif src_item[idx] == video_token_id and not use_audio_in_video:
            grid_t = video_grid_thw[video_idx][0]
            grid_hs = video_grid_thw[:, 1]
            grid_ws = video_grid_thw[:, 2]
            # Temporal positions are scaled by the per-video grid duration.
            t_index = (torch.arange(grid_t) *
                       second_per_grid_ts[video_idx] * tokens_per_second)
            llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
                grid_ws)
            llm_pos_ids_list.append(llm_pos_ids)
            vision_seqlen = video_grid_thw[video_idx].prod() // (
                spatial_merge_size**2)
            new_src_item.extend([video_token_id] * vision_seqlen)
            video_idx += 1
        else:
            # read audio from video
            # (reached only for a video placeholder when
            # `use_audio_in_video` is set)
            assert audio_seqlens is not None
            audio_seqlen = audio_seqlens[audio_idx]
            # NOTE(review): `vision_seqlen` is assigned but not read again
            # in this branch.
            vision_seqlen = video_grid_thw[video_idx].prod() // (
                spatial_merge_size**2)
            grid_t = video_grid_thw[video_idx][0]
            grid_h = video_grid_thw[video_idx][1]
            grid_w = video_grid_thw[video_idx][2]
            grid_hs = video_grid_thw[:, 1]
            grid_ws = video_grid_thw[:, 2]
            t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
            t_index = (torch.arange(grid_t) *
                       second_per_grid_ts[video_idx] * tokens_per_second)
            # Group the temporal indices into buckets of width
            # `t_ntoken_per_chunk`.
            t_index_split_chunk = cls._split_list_into_ranges(
                t_index, t_ntoken_per_chunk)
            place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
            pure_audio_len = place_num - 2
            added_audio_len = 0
            audio_llm_pos_ids_list: list[torch.Tensor] = []
            for t_chunk in t_index_split_chunk:
                vision_ntoken_per_chunk = len(
                    t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
                new_src_item.extend([video_token_id] *
                                    vision_ntoken_per_chunk)
                # Every vision chunk is offset from the same `start_idx`;
                # the result is split into one (3, 1) column per token.
                vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
                    start_idx, video_idx, spatial_merge_size, t_chunk,
                    grid_hs, grid_ws).split(1, dim=1)
                llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                # At most `t_ntoken_per_chunk` audio tokens follow each
                # vision chunk, capped by the audio still to be placed.
                new_src_item.extend(
                    min(t_ntoken_per_chunk, pure_audio_len -
                        added_audio_len) * [audio_token_id])
                # The first audio chunk starts at `start_idx`; later chunks
                # continue from the last audio position of the previous
                # chunk.
                audio_start_idx = start_idx if len(
                    audio_llm_pos_ids_list
                ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
                if min(t_ntoken_per_chunk,
                       pure_audio_len - added_audio_len) > 0:
                    audio_llm_pos_ids_list = (torch.arange(
                        min(t_ntoken_per_chunk, pure_audio_len -
                            added_audio_len)).expand(3, -1) +
                                              audio_start_idx).split(1,
                                                                     dim=1)
                else:
                    audio_llm_pos_ids_list = []
                added_audio_len += min(t_ntoken_per_chunk,
                                       pure_audio_len - added_audio_len)
                llm_pos_ids_list.extend(audio_llm_pos_ids_list)
            # Audio left over after the last vision chunk is appended in
            # one run, continuing from the largest position of the last
            # entry.
            if added_audio_len < pure_audio_len:
                new_src_item.extend(
                    (pure_audio_len - added_audio_len) * [audio_token_id])
                audio_llm_pos_ids_list = (
                    torch.arange(pure_audio_len - added_audio_len).expand(
                        3, -1) + llm_pos_ids_list[-1].max() + 1).split(
                            1, dim=1)
                llm_pos_ids_list.extend(audio_llm_pos_ids_list)
            audio_idx += 1
            video_idx += 1
        # move to the next token
        idx += len(new_src_item) - new_src_item_len

    llm_positions = torch.cat(llm_pos_ids_list, dim=1)
    # NOTE(review): this value comes from a tensor `.max()`, so it is a 0-d
    # tensor rather than the plain `int` the return annotation declares.
    mrope_position_delta = torch.cat(
        llm_pos_ids_list, dim=1).max() + 1 - len(src_item)
    llm_positions = llm_positions[:, context_len:seq_len]

    return llm_positions, mrope_position_delta
@
staticmethod
def
_get_llm_pos_ids_for_vision
(
start_idx
:
int
,
vision_idx
:
int
,
spatial_merge_size
:
int
,
t_index
:
list
[
int
],
grid_hs
:
torch
.
Tensor
,
grid_ws
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
llm_pos_ids_list
=
[]
llm_grid_h
=
grid_hs
[
vision_idx
]
//
spatial_merge_size
llm_grid_w
=
grid_ws
[
vision_idx
]
//
spatial_merge_size
h_index
=
(
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
len
(
t_index
),
-
1
,
llm_grid_w
).
flatten
())
w_index
=
(
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
len
(
t_index
),
llm_grid_h
,
-
1
).
flatten
())
t_index_tensor
=
torch
.
Tensor
(
t_index
).
to
(
llm_grid_h
.
device
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
long
().
flatten
()
_llm_pos_ids
=
torch
.
stack
([
t_index_tensor
,
h_index
,
w_index
])
llm_pos_ids_list
.
append
(
_llm_pos_ids
+
start_idx
)
llm_pos_ids
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
)
return
llm_pos_ids
@
staticmethod
def
_split_list_into_ranges
(
lst
:
torch
.
Tensor
,
interval
:
int
)
->
list
[
list
[
int
]]:
ranges
:
list
[
list
[
int
]]
=
[[]
for
_
in
range
((
max
(
lst
)
//
interval
)
+
1
)]
for
num
in
lst
:
index
=
num
//
interval
ranges
[
index
].
append
(
num
)
return
ranges
@
staticmethod
def
get_next_input_positions
(
...
...
@@ -1254,68 +380,24 @@ class MRotaryEmbedding(RotaryEmbedding):
)
->
list
[
list
[
int
]]:
return
[
list
(
range
(
context_len
+
mrope_position_delta
,
seq_len
+
mrope_position_delta
))
for
_
in
range
(
3
)
range
(
context_len
+
mrope_position_delta
,
seq_len
+
mrope_position_delta
)
)
for
_
in
range
(
3
)
]
@
staticmethod
def
get_next_input_positions_tensor
(
out
:
np
.
ndarray
,
out_offset
:
int
,
mrope_position_delta
:
int
,
context_len
:
int
,
num_new_tokens
:
int
):
values
=
np
.
arange
(
mrope_position_delta
+
context_len
,
mrope_position_delta
+
context_len
+
num_new_tokens
,
dtype
=
out
.
dtype
)
out
[:,
out_offset
:
out_offset
+
num_new_tokens
]
=
values
@classmethod
def omni_get_updates_use_audio_in_video(
    cls,
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: Union[list[int], torch.Tensor],
    video_second_per_grid_t: float,
) -> list[int]:
    """Get video prompt updates when `use_audio_in_video` is True.

    In this case, audio and vision update ids will be split into
    chunks and interleaved (details in `_omni_get_input_positions_tensor`).

    <|video_bos|><|VIDEO|><|video_eos|> =>
    <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>

    The returned list starts with the audio start token, alternates runs of
    video and audio tokens (one pair per temporal chunk), appends any audio
    tokens still missing, and ends with the audio end token.
    """
    vision_config = thinker_config.vision_config
    audio_id = thinker_config.audio_token_index
    video_id = thinker_config.video_token_index
    audio_bos_id = thinker_config.audio_start_token_id
    audio_eos_id = thinker_config.audio_end_token_id
    seconds_per_chunk = thinker_config.seconds_per_chunk
    merge_size = vision_config.spatial_merge_size
    tokens_per_second = getattr(vision_config, "tokens_per_second", 25)

    frames = video_grid_thw[0]
    height = video_grid_thw[1]
    width = video_grid_thw[2]

    chunk_width = int(tokens_per_second * seconds_per_chunk)
    frame_times = (torch.arange(frames) * video_second_per_grid_t *
                   tokens_per_second)
    chunks = cls._split_list_into_ranges(frame_times, chunk_width)

    updates = [audio_bos_id]
    audio_emitted = 0
    for chunk in chunks:
        # Vision tokens of this chunk first ...
        video_count = len(chunk) * height * width // (merge_size**2)
        updates.extend([video_id] * video_count)
        # ... then at most one chunk's worth of the remaining audio.
        audio_count = min(chunk_width, audio_len - audio_emitted)
        updates.extend(audio_count * [audio_id])
        audio_emitted += audio_count

    # Audio longer than the video: append the rest in a single run.
    if audio_emitted < audio_len:
        updates.extend((audio_len - audio_emitted) * [audio_id])
    updates.append(audio_eos_id)
    return updates
def get_next_input_positions_tensor(
    out: np.ndarray,
    out_offset: int,
    mrope_position_delta: int,
    context_len: int,
    num_new_tokens: int,
):
    """Fill a column window of `out` with the next token positions.

    Columns `out_offset` to `out_offset + num_new_tokens` (exclusive) of
    every row of `out` are overwritten in place with the consecutive
    integers that start at `mrope_position_delta + context_len`. The values
    are created with the dtype of `out`.
    """
    start = mrope_position_delta + context_len
    stop = start + num_new_tokens
    window = slice(out_offset, out_offset + num_new_tokens)
    out[:, window] = np.arange(start, stop, dtype=out.dtype)
vllm/model_executor/models/interfaces.py
View file @
b12c902b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
,
Mapping
,
MutableSequence
from
collections.abc
import
Iterable
,
Mapping
,
MutableSequence
,
Callable
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Literal
,
Optional
,
Protocol
,
Union
,
overload
,
runtime_checkable
)
...
...
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
)
from
vllm.utils
import
supports_kw
from
.interfaces_base
import
is_pooling_model
from
.interfaces_base
import
is_pooling_model
,
VllmModel
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
...
...
@@ -81,10 +81,9 @@ class SupportsMultiModal(Protocol):
"""
...
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
:
"""
Returns multimodal embeddings generated from multimodal kwargs
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
Note:
...
...
@@ -94,11 +93,11 @@ class SupportsMultiModal(Protocol):
"""
...
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
VllmModel
:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
...
...
@@ -106,19 +105,83 @@ class SupportsMultiModal(Protocol):
"""
...
@
overload
def
get_input_embeddings
(
self
,
input_ids
:
Tensor
)
->
Tensor
:
...
@
overload
def
get_input_embeddings
(
self
,
input_ids
:
Tensor
,
multimodal_embeddings
:
MultiModalEmbeddings
,
*
,
is_multimodal
:
torch
.
Tensor
,
handle_oov_mm_token
:
bool
=
False
,
)
->
Tensor
:
...
def
_get_text_embeddings
(
self
,
input_ids
:
Tensor
,
get_input_embeddings
:
Callable
[[
Tensor
],
Tensor
],
*
,
is_multimodal
:
Optional
[
Tensor
],
handle_oov_mm_token
:
bool
,
)
->
Tensor
:
if
handle_oov_mm_token
and
is_multimodal
is
not
None
:
is_text
=
~
is_multimodal
text_embeds
=
get_input_embeddings
(
input_ids
[
is_text
])
return
torch
.
empty
(
(
input_ids
.
shape
[
0
],
text_embeds
.
shape
[
1
]),
dtype
=
text_embeds
.
dtype
,
device
=
text_embeds
.
device
,
).
masked_scatter_
(
is_text
.
unsqueeze_
(
-
1
),
text_embeds
)
return
get_input_embeddings
(
input_ids
)
def
get_input_embeddings
(
self
,
input_ids
:
Tensor
,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
*
,
is_multimodal
:
Optional
[
Tensor
]
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
Tensor
:
"""
Returns the input embeddings merged from the text embeddings from
input_ids and the multimodal embeddings generated from multimodal
kwargs.
Apply token embeddings to `input_ids`.
If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=True`
to avoid calling the language model's `get_input_embeddings` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
"""
...
from
.utils
import
_merge_multimodal_embeddings
inputs_embeds
=
self
.
_get_text_embeddings
(
input_ids
,
self
.
get_language_model
().
get_input_embeddings
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
return
inputs_embeds
if
is_multimodal
is
None
:
raise
ValueError
(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229."
)
return
_merge_multimodal_embeddings
(
inputs_embeds
=
inputs_embeds
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
)
@
runtime_checkable
class
SupportsMultiModalPruning
(
Protocol
):
"""The interface required for models that support returning both input
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
b12c902b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen team.
# Copyright 2023 The vLLM team.
...
...
@@ -22,6 +22,7 @@
# limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part)."""
import
os
import
math
from
collections.abc
import
Callable
,
Iterable
,
Mapping
,
Sequence
...
...
@@ -48,7 +49,9 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
)
from
transformers.models.whisper
import
WhisperFeatureExtractor
from
vllm.attention.backends.registry
import
_Backend
# from vllm.attention.backends.registry import _Backend
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
...
...
@@ -106,6 +109,7 @@ from .utils import (
_merge_multimodal_embeddings
,
maybe_prefix
,
)
from
.vision
import
(
conv3d_to_linear_weight
,
get_llm_pos_ids_for_vision
,
...
...
@@ -143,18 +147,28 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self
.
hidden_size
=
hidden_size
kernel_size
=
(
temporal_patch_size
,
patch_size
,
patch_size
)
self
.
proj
=
ReplicatedLinear
(
in_channels
*
math
.
prod
(
kernel_size
),
# self.proj = ReplicatedLinear(
# in_channels * math.prod(kernel_size),
# hidden_size,
# bias=True,
# return_bias=False,
# )
self
.
proj
=
nn
.
Conv3d
(
in_channels
,
hidden_size
,
kernel_size
=
kernel_size
,
stride
=
kernel_size
,
bias
=
True
,
return_bias
=
False
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
L
,
C
=
x
.
shape
if
os
.
environ
.
get
(
'PYTORCH_MIOPEN_SUGGEST_NDHWC'
)
==
'1'
:
x
=
x
.
to
(
memory_format
=
torch
.
channels_last_3d
)
x
=
self
.
proj
(
x
)
x
=
x
.
view
(
L
,
-
1
,
self
.
temporal_patch_size
,
self
.
patch_size
,
self
.
patch_size
)
# if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
# x = x.to(memory_format=torch.channels_last_3d)
x
=
self
.
proj
(
x
).
view
(
L
,
self
.
hidden_size
)
return
x
...
...
@@ -308,7 +322,6 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_eps
:
float
=
1e-6
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
vision_config
.
hidden_size
...
...
@@ -380,9 +393,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
()
)
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
check_upstream_fa_availability
(
torch
.
get_default_dtype
()
...
...
@@ -571,8 +582,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
name
.
endswith
(
"patch_embed.proj.weight"
):
loaded_weight
=
conv3d_to_linear_weight
(
loaded_weight
)
#
if name.endswith("patch_embed.proj.weight"):
#
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
...
...
@@ -811,7 +822,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if
is_update_applied
:
prompt_ids
=
self
.
_get_raw_input_ids
(
prompt_ids
,
use_audio_in_video
)
(
prompt_ids
,
prompt_ids
,
prompt
,
mm_placeholders
,
)
=
self
.
_apply_prompt_updates
(
prompt_ids
,
...
...
@@ -829,7 +841,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_item_counts
,
)
else
:
prompt_ids
,
mm_placeholders
=
self
.
_apply_prompt_updates
(
prompt_ids
,
prompt
,
mm_placeholders
=
self
.
_apply_prompt_updates
(
prompt_ids
,
mm_prompt_updates
,
)
...
...
@@ -837,8 +849,7 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_placeholders
,
mm_item_counts
,
)
return
prompt_ids
,
mm_placeholders
return
prompt_ids
,
prompt
,
mm_placeholders
def
get_updates_use_audio_in_video
(
self
,
...
...
@@ -1160,18 +1171,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
self
.
audio_tower
=
Qwen3OmniMoeAudioEncoder
(
thinker_config
.
audio_config
)
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
visual
=
Qwen3Omni_VisionTransformer
(
vision_config
=
thinker_config
.
vision_config
,
norm_eps
=
getattr
(
thinker_config
.
text_config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
attn_backend_override
=
attn_backend_override
,
)
self
.
quant_config
=
quant_config
...
...
@@ -1375,7 +1379,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
.
contiguous
()
)
self
.
_set_deepstack_input_embeds
(
deepstack_input_embeds
)
inputs_embeds
=
_merge_multimodal_embeddings
(
inputs_embeds
=
inputs_embeds
,
multimodal_embeddings
=
multimodal_embeddings
,
...
...
@@ -1434,7 +1437,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
loaded_weights
=
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
return
loaded_weights
@
classmethod
def
get_mrope_input_positions
(
self
,
input_tokens
:
list
[
int
],
...
...
vllm/model_executor/models/utils.py
View file @
b12c902b
...
...
@@ -10,6 +10,7 @@ import torch
import
torch.nn
as
nn
from
torch.func
import
functional_call
from
transformers
import
PretrainedConfig
from
typing_extensions
import
deprecated
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
...
...
@@ -391,92 +392,79 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
return
" + "
.
join
(
_embedding_count_expression
(
inner
)
for
inner
in
embeddings
)
def split_list_into_ranges(lst: torch.Tensor,
                           interval: int) -> list[list[int]]:
    """Group the values of `lst` into consecutive buckets of width `interval`.

    Bucket `i` receives every value `v` with `v // interval == i`. Buckets
    up to the one holding the largest value are always present, so some may
    be empty.

    Args:
        lst: Values to group (a 1-D tensor or any sized iterable of
            numbers). Elements are stored in the result unchanged.
        interval: Width of each bucket.

    Returns:
        The list of buckets, or an empty list when `lst` is empty.
    """
    # `max()` raises on an empty sequence; there is nothing to group.
    if len(lst) == 0:
        return []

    # Convert to a plain int: a float-valued tensor cannot be used by
    # `range()` or as a list index.
    num_buckets = int(max(lst) // interval) + 1
    ranges: list[list[int]] = [[] for _ in range(num_buckets)]
    for num in lst:
        ranges[int(num // interval)].append(num)
    return ranges
def
_merge_multimodal_embeddings
(
inputs_embeds
:
torch
.
Tensor
,
is_multimodal
:
torch
.
Tensor
,
multimodal_embeddings
:
NestedTensors
,
is_multimodal
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""
Merge
`
`multimodal_embeddings`
`
into
`
`inputs_embeds`
`
by overwriting the
positions in
`
`inputs_embeds`
`
corresponding to placeholder tokens in
`
`input_ids`
`
.
Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in `inputs_embeds` corresponding to placeholder tokens in
`input_ids`.
Note:
This updates
`
`inputs_embeds`
`
in place.
This updates `inputs_embeds` in place.
"""
flattened
=
_flatten_embeddings
(
multimodal_embeddings
)
if
len
(
multimodal_embeddings
)
==
0
:
return
inputs_embeds
mm_embeds_flat
=
_flatten_embeddings
(
multimodal_embeddings
)
input_dtype
=
inputs_embeds
.
dtype
try
:
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
inputs_embeds
.
masked_scatter_
(
is_multimodal
.
unsqueeze
(
-
1
),
flattened
.
to
(
dtype
=
inputs_embeds
.
dtype
))
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
inputs_embeds
.
masked_scatter_
(
is_multimodal
.
unsqueeze
(
-
1
),
mm_embeds_flat
.
to
(
dtype
=
input_dtype
)
)
except
RuntimeError
as
e
:
num_actual_tokens
=
len
(
mm_embeds_flat
)
num_expected_tokens
=
is_multimodal
.
sum
().
item
()
assert
isinstance
(
num_expected_tokens
,
int
)
if
flattened
.
shape
[
0
]
!=
num_expected_tokens
:
if
num_actual_tokens
!=
num_expected_tokens
:
expr
=
_embedding_count_expression
(
multimodal_embeddings
)
raise
ValueError
(
f
"Attempted to assign
{
expr
}
=
{
flattened
.
shape
[
0
]
}
"
f
"Attempted to assign
{
expr
}
=
{
num_actual_tokens
}
"
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
from
e
else
:
raise
ValueError
(
"Error during masked scatter operation"
)
from
e
return
inputs_embeds
def
embed_multimodal
(
input_ids
:
torch
.
Tensor
,
multimodal_token_id
:
int
,
get_text_embeds
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
],
multimodal_embeds
:
NestedTensors
,
)
->
torch
.
Tensor
:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
``multimodal_token_id`` is used to determine whether a token ID should
be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
Compared to ``merge_multimodal_embeddings``, this avoids running
``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
which causes issues when the placeholder token ID exceeds the
vocabulary size of the language model.
"""
is_multimodal
=
input_ids
==
multimodal_token_id
is_text
=
~
is_multimodal
text_embeds
=
get_text_embeds
(
input_ids
[
is_text
])
merged_embeds
=
torch
.
empty
(
(
input_ids
.
shape
[
0
],
text_embeds
.
shape
[
1
]),
dtype
=
text_embeds
.
dtype
,
device
=
text_embeds
.
device
,
)
merged_embeds
[
is_text
]
=
text_embeds
raise
ValueError
(
"Error during masked scatter operation"
)
from
e
return
_merge_multimodal_embeddings
(
merged_embeds
,
is_multimodal
,
multimodal_embeds
,
)
return
inputs_embeds
@
deprecated
(
"`merge_multimodal_embeddings` has been replaced with "
"`SupportsMultiModal.get_input_embeddings` and will be "
"removed in v0.12."
)
def
merge_multimodal_embeddings
(
input_ids
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
multimodal_embeddings
:
NestedTensors
,
placeholder_token_id
:
Union
[
int
,
list
[
int
]
]
,
placeholder_token_id
:
int
|
list
[
int
],
)
->
torch
.
Tensor
:
"""
Merge
`
`multimodal_embeddings`
`
into
`
`inputs_embeds`
`
by overwriting the
positions in
`
`inputs_embeds`
`
corresponding to placeholder tokens in
`
`input_ids`
`
.
Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
positions in `inputs_embeds` corresponding to placeholder tokens in
`input_ids`.
`
`placeholder_token_id`
`
can be a list of token ids (e.g, token ids
`placeholder_token_id` can be a list of token ids (e.g, token ids
of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the
`
`input_ids`
`
MUST MATCH the order of
their embeddings in
`
`multimodal_embeddings`
`
since we need to
the order of these tokens in the `input_ids` MUST MATCH the order of
their embeddings in `multimodal_embeddings` since we need to
slice-merge instead of individually scattering.
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
...
...
@@ -491,26 +479,32 @@ def merge_multimodal_embeddings(
input_ids for a correct embedding merge.
Note:
This updates
`
`inputs_embeds`
`
in place.
This updates `inputs_embeds` in place.
"""
if
isinstance
(
placeholder_token_id
,
list
):
placeholder_token_id
=
torch
.
tensor
(
placeholder_token_id
,
pin_memory
=
is_pin_memory_available
()).
to
(
device
=
input_ids
.
device
,
non_blocking
=
True
)
return
_merge_multimodal_embeddings
(
inputs_embeds
,
torch
.
isin
(
input_ids
,
placeholder_token_id
),
multimodal_embeddings
,
)
is_multimodal
=
isin_list
(
input_ids
,
placeholder_token_id
)
else
:
is_multimodal
=
input_ids
==
placeholder_token_id
return
_merge_multimodal_embeddings
(
inputs_embeds
,
(
input_ids
==
placeholder_token_id
)
,
multimodal
_embeddings
,
multimodal_embeddings
=
multimodal_embeddings
,
is_
multimodal
=
is_multimodal
,
)
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """Return a boolean mask telling which `elements` are in the given list.

    The list is first turned into a tensor (in pinned memory when
    `is_pin_memory_available()` says so), moved to the device of `elements`
    with a non-blocking copy, and then handed to `torch.isin`.
    """
    use_pinned = is_pin_memory_available()
    host_values = torch.tensor(test_elements_list, pin_memory=use_pinned)
    device_values = host_values.to(device=elements.device, non_blocking=True)
    return torch.isin(elements, device_values)
class
LayerFn
(
Protocol
):
def
__call__
(
self
,
prefix
:
str
)
->
torch
.
nn
.
Module
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
b12c902b
...
...
@@ -368,6 +368,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype
=
torch
.
int32
)
self
.
num_accepted_tokens
=
self
.
_make_buffer
(
self
.
max_num_reqs
,
dtype
=
torch
.
int64
)
# Only relevant for multimodal models
if
self
.
supports_mm_inputs
:
self
.
is_mm_embed
=
self
.
_make_buffer
(
self
.
max_num_tokens
,
dtype
=
torch
.
bool
)
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if
self
.
uses_mrope
:
...
...
@@ -1612,17 +1615,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
,
scheduler_output
:
"SchedulerOutput"
,
shift_computed_tokens
:
int
=
0
,
)
->
list
[
torch
.
Tensor
]:
)
->
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
mm_embeds
=
list
[
torch
.
Tensor
]()
is_mm_embed
=
self
.
is_mm_embed
.
cpu
is_mm_embed
[:
total_num_scheduled_tokens
]
=
False
req_start_idx
=
0
should_sync_mrope_positions
=
False
mm_embeds
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
mm_embeds_req
:
list
[
torch
.
Tensor
]
=
[]
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
num_computed_tokens
=
\
req_state
.
num_computed_tokens
+
shift_computed_tokens
num_computed_tokens
=
req_state
.
num_computed_tokens
+
shift_computed_tokens
for
mm_feature
in
req_state
.
mm_features
:
pos_info
=
mm_feature
.
mm_position
start_pos
=
pos_info
.
offset
...
...
@@ -1649,12 +1658,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_hash
=
mm_feature
.
identifier
encoder_output
=
self
.
encoder_cache
.
get
(
mm_hash
,
None
)
assert
encoder_output
is
not
None
,
\
f
"Encoder cache miss for
{
mm_hash
}
."
assert
encoder_output
is
not
None
,
f
"Encoder cache miss for
{
mm_hash
}
."
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
is_embed
=
is_embed
[
start_idx
:
end_idx
]
req_start_pos
=
req_start_idx
+
start_pos
-
num_computed_tokens
is_mm_embed
[
req_start_pos
+
start_idx
:
req_start_pos
+
end_idx
]
=
(
True
if
is_embed
is
None
else
is_embed
)
mm_embeds_item
=
gather_mm_placeholders
(
encoder_output
[
start_idx
:
end_idx
],
is_embed
=
is_embed
,
...
...
@@ -1662,6 +1675,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_embeds_req
.
append
(
mm_embeds_item
)
if
self
.
is_multimodal_pruning_enabled
and
self
.
uses_mrope
:
assert
req_state
.
mrope_positions
is
not
None
should_sync_mrope_positions
=
True
mm_embeds_req
,
new_mrope_positions
,
new_delta
=
(
self
.
model
.
recompute_mrope_positions
(
...
...
@@ -1669,19 +1683,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
multimodal_embeddings
=
mm_embeds_req
,
mrope_positions
=
req_state
.
mrope_positions
,
num_computed_tokens
=
req_state
.
num_computed_tokens
,
)
)
assert
req_state
.
mrope_positions
is
not
None
)
)
req_state
.
mrope_positions
.
copy_
(
new_mrope_positions
)
req_state
.
mrope_position_delta
=
new_delta
mm_embeds
.
extend
(
mm_embeds_req
)
req_start_idx
+=
num_scheduled_tokens
is_mm_embed
=
self
.
is_mm_embed
.
copy_to_gpu
(
total_num_scheduled_tokens
)
if
should_sync_mrope_positions
:
self
.
_calc_mrope_positions
(
scheduler_output
)
self
.
mrope_positions
.
copy_to_gpu
(
scheduler_output
.
total_num_scheduled_tokens
)
self
.
mrope_positions
.
copy_to_gpu
(
total_num_scheduled_tokens
)
return
mm_embeds
return
mm_embeds
,
is_mm_embed
def
_extract_encoder_inputs
(
self
,
...
...
@@ -1975,7 +1991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
and
not
self
.
model_config
.
is_encoder_decoder
):
# Run the multimodal encoder if any.
self
.
_execute_mm_encoder
(
scheduler_output
)
mm_embeds
=
self
.
_gather_mm_embeddings
(
scheduler_output
)
mm_embeds
,
is_mm_embed
=
self
.
_gather_mm_embeddings
(
scheduler_output
)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
...
...
@@ -1983,6 +1999,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
inputs_embeds_scheduled
=
self
.
model
.
get_input_embeddings
(
input_ids
=
self
.
input_ids
.
gpu
[:
num_scheduled_tokens
],
multimodal_embeddings
=
mm_embeds
or
None
,
is_multimodal
=
is_mm_embed
,
)
# TODO(woosuk): Avoid the copy. Optimize.
...
...
vllm/version.py
View file @
b12c902b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try
:
from
._version
import
__version__
,
__version_tuple__
__version__
=
"0.11.0"
__version_tuple__
=
(
0
,
11
,
0
)
__hcu_version__
=
f
'0.11.0+das.opt1.alpha.6c015e7.dtk25041'
from
vllm.version
import
__version__
,
__version_tuple__
,
__hcu_version__
except
Exception
as
e
:
import
warnings
warnings
.
warn
(
f
"Failed to read commit hash:
\n
{
e
}
"
,
warnings
.
warn
(
f
"Failed to read commit hash:
\n
+ str(e)
"
,
RuntimeWarning
,
stacklevel
=
2
)
__version__
=
"dev"
__version_tuple__
=
(
0
,
0
,
__version__
)
def
_prev_minor_version_was
(
version_str
):
"""
Check whether a given version matches the previous minor version.
'''
Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
...
...
@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if
__version_tuple__
[
0
:
2
]
==
(
0
,
0
):
return
True
# Note - this won't do the right thing when we release 1.0!
assert
__version_tuple__
[
0
]
==
0
#
assert __version_tuple__[0] == 0
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
version_str
==
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
def
_prev_minor_version
():
"""
For the purpose of testing, return a previous minor version number.
"""
'''
For the purpose of testing, return a previous minor version number.
'''
# In dev tree, this will return "0.-1", but that will work fine"
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment