sglang · Commits · 4395c87a

Unverified commit 4395c87a · authored Jul 17, 2025 by Mick · committed by GitHub Jul 16, 2025
refactor: unify names of the feature field of MultimodalDataItem (#8075)
Parent: c28ad199
Showing 20 changed files with 46 additions and 69 deletions
python/sglang/srt/managers/multimodal_processors/qwen_audio.py  +1 −1
python/sglang/srt/managers/schedule_batch.py  +14 −27
python/sglang/srt/models/clip.py  +1 −1
python/sglang/srt/models/deepseek_janus_pro.py  +1 −1
python/sglang/srt/models/deepseek_vl2.py  +2 −2
python/sglang/srt/models/gemma3_mm.py  +1 −1
python/sglang/srt/models/gemma3n_mm.py  +2 −4
python/sglang/srt/models/internvl.py  +1 −1
python/sglang/srt/models/kimi_vl.py  +1 −1
python/sglang/srt/models/llava.py  +2 −2
python/sglang/srt/models/llavavid.py  +1 −1
python/sglang/srt/models/minicpmo.py  +3 −7
python/sglang/srt/models/minicpmv.py  +1 −1
python/sglang/srt/models/mistral.py  +1 −1
python/sglang/srt/models/mllama.py  +3 −5
python/sglang/srt/models/mllama4.py  +1 −1
python/sglang/srt/models/phi4mm.py  +1 −3
python/sglang/srt/models/qwen2_5_vl.py  +4 −4
python/sglang/srt/models/qwen2_audio.py  +1 −1
python/sglang/srt/models/qwen2_vl.py  +4 −4
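The gist of the refactor: MultimodalDataItem previously carried one raw-feature field per modality (pixel_values, audio_features, input_features), and every call site had to know which one to read; after this commit the processor output always lands in the single feature field. The snippet below is a minimal stand-in sketch of that layout, not sglang code (the real class lives in python/sglang/srt/managers/schedule_batch.py, and the collect_features helper here is hypothetical).

# Stand-in sketch of the field unification made by this commit (not sglang code).
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Optional

class Modality(Enum):  # mirrors the enum used throughout the diff below
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()

@dataclass
class Item:  # stand-in for MultimodalDataItem after this commit
    modality: Modality
    feature: Any = None  # raw processor output, regardless of modality
    precomputed_features: Optional[Any] = None

def collect_features(items):
    # Hypothetical helper: call sites no longer branch on modality to find the raw data.
    return [item.feature for item in items]

items = [
    Item(Modality.AUDIO, feature="input_features from the audio processor"),
    Item(Modality.IMAGE, feature="pixel_values from the image processor"),
]
print(collect_features(items))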
python/sglang/srt/managers/multimodal_processors/qwen_audio.py
@@ -78,7 +78,7 @@ class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
         output_lengths = (input_lengths - 2) // 2 + 1
         item = MultimodalDataItem(
-            audio_features=res["input_features"],
+            feature=res["input_features"],
             audio_feature_lens=output_lengths,
             audio_offsets=audio_offsets,
             modality=Modality.AUDIO,
python/sglang/srt/managers/schedule_batch.py
@@ -207,13 +207,12 @@ class MultimodalDataItem:
     modality: Modality

     hash: int = None
     pad_value: int = None
-    image_sizes: Tuple[int, int] = None
     offsets: Optional[list] = None
-    # the real data, pixel_values or audio_features
-    # data: Union[List[torch.Tensor], List[np.ndarray]]
-    pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
-    audio_features: Union[torch.Tensor, np.ndarray] = None
+    # the raw features returned by processor, e.g. pixel_values or audio_features
+    feature: Union[torch.Tensor, np.ndarray] = None
+    image_sizes: Tuple[int, int] = None
     audio_feature_lens: Optional[List[torch.Tensor]] = None
     audio_offsets: Optional[List[Tuple[int, int]]] = None

     precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
@@ -238,7 +237,6 @@ class MultimodalDataItem:
     image_grid_hws: Optional[List[torch.Tensor]] = None

     # For gemma3n
-    input_features: Optional[torch.Tensor] = None
     input_features_mask: Optional[torch.Tensor] = None

     @staticmethod
@@ -254,18 +252,11 @@ class MultimodalDataItem:
         from sglang.srt.managers.mm_utils import hash_feature

         if self.hash is None:
-            if self.precomputed_features is not None:
-                self.hash = hash_feature(self.precomputed_features)
-            elif self.is_audio():
-                if self.audio_features is not None:
-                    self.hash = hash_feature(self.audio_features)
-                elif self.input_features is not None:
-                    self.hash = hash_feature(self.input_features)
-            elif self.is_video():
-                self.hash = hash_feature(self.pixel_values_videos)
+            if self.feature is not None:
+                hashed_feature = self.feature
             else:
-                self.hash = hash_feature(self.pixel_values)
+                hashed_feature = self.precomputed_features
+            self.hash = hash_feature(hashed_feature)

         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)
@@ -275,8 +266,7 @@ class MultimodalDataItem:
     def is_audio(self):
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.audio_features)
-            or not MultimodalDataItem.is_empty_list(self.input_features)
+            or not MultimodalDataItem.is_empty_list(self.feature)
         )

     def is_image(self):
@@ -284,13 +274,13 @@ class MultimodalDataItem:
             self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
         ) and (
             self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.pixel_values)
+            or not MultimodalDataItem.is_empty_list(self.feature)
         )

     def is_video(self):
         return (self.modality == Modality.VIDEO) and (
             self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
+            or not MultimodalDataItem.is_empty_list(self.feature)
         )

     def is_valid(self) -> bool:
@@ -311,7 +301,7 @@ class MultimodalDataItem:
         return ret

     def merge(self, other):
-        self.pixel_values += other.pixel_values
+        self.feature += other.feature
         self.image_sizes += other.image_sizes
         self.image_offsets += other.image_offsets
         self.hash = hash((self.hash, other.hash))
@@ -354,7 +344,6 @@ class MultimodalInputs:
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

         for item in ret.mm_items:
             item.set_pad_value()
@@ -1278,11 +1267,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             if mm_input is None:
                 continue
             for mm_item in mm_input.mm_items:
-                pixel_values = getattr(mm_item, "pixel_values", None)
+                pixel_values = getattr(mm_item, "feature", None)
                 if isinstance(pixel_values, torch.Tensor):
-                    mm_item.pixel_values = pixel_values.to(
-                        self.device, non_blocking=True
-                    )
+                    mm_item.feature = pixel_values.to(self.device, non_blocking=True)

         self.multimodal_inputs = multimodal_inputs
         self.token_type_ids = token_type_ids_tensor
         self.seq_lens_sum = sum(seq_lens)
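Every model-side change that follows is the same mechanical substitution in each model's feature-extraction path. A minimal sketch of the pattern (not sglang's exact code; `items` is assumed to be a list of MultimodalDataItem):

# Shape of the per-model edit repeated in the files below (sketch only).
import torch

def get_image_feature_before(items):
    # pre-commit: readers had to know the modality-specific field name
    return torch.cat([item.pixel_values for item in items], dim=0)

def get_image_feature_after(items):
    # post-commit: the unified `feature` field works for image, video and audio items
    return torch.cat([item.feature for item in items], dim=0)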
python/sglang/srt/models/clip.py
@@ -463,7 +463,7 @@ class CLIPModel(nn.Module):
         if forward_batch.mm_inputs is not None:
             mm_inputs = forward_batch.mm_inputs
             pixel_values_list = [
-                item.pixel_values
+                item.feature
                 for item in flatten_nested_list(
                     [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
                 )
python/sglang/srt/models/deepseek_janus_pro.py
@@ -1960,7 +1960,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         self.logits_processor = LogitsProcessor(config)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
+        pixel_values = torch.concat([item.feature for item in items], dim=0)
         bs, n = pixel_values.shape[0:2]
         pixel_values = pixel_values.to(
             device=self.vision_model.device, dtype=self.vision_model.dtype
python/sglang/srt/models/deepseek_vl2.py
@@ -268,9 +268,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
         # TODO: can it be batched ?
         images_in_this_batch = []
         for item in items:
-            assert item.pixel_values.dim() == 4
+            assert item.feature.dim() == 4
             image_feature = self.vision.forward_features(
-                item.pixel_values.type(next(self.vision.parameters()).dtype).to(
+                item.feature.type(next(self.vision.parameters()).dtype).to(
                     device=next(self.vision.parameters()).device
                 )
             )
python/sglang/srt/models/gemma3_mm.py
@@ -283,7 +283,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
-        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        all_pixel_values = flatten_nested_list([item.feature for item in items])
         vision_outputs_list = []

         for pixel_values_batch in all_pixel_values:
python/sglang/srt/models/gemma3n_mm.py
@@ -265,7 +265,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
-        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        all_pixel_values = flatten_nested_list([item.feature for item in items])
         vision_outputs_list = []

         for pixel_values_batch in all_pixel_values:
@@ -316,9 +316,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`).
         """
         # Extract audio features and masks from items
-        all_input_features = flatten_nested_list(
-            [item.input_features for item in items]
-        )
+        all_input_features = flatten_nested_list([item.feature for item in items])
         all_input_features_mask = flatten_nested_list(
             [~item.input_features_mask for item in items]
         )
         # Note(Xinyuan): reverse the mask according to the HF implementation
python/sglang/srt/models/internvl.py
@@ -510,7 +510,7 @@ class InternVLChatModel(nn.Module):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = torch.cat([item.pixel_values for item in items])
+        pixel_values = torch.cat([item.feature for item in items])
         image_features = self.extract_feature(pixel_values)
         return image_features
python/sglang/srt/models/kimi_vl.py
@@ -144,7 +144,7 @@ class KimiVLForConditionalGeneration(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = (
-            torch.cat([item.pixel_values for item in items], dim=0)
+            torch.cat([item.feature for item in items], dim=0)
             .type(self.vision_tower.dtype)
             .to(self.vision_tower.device)
         )
python/sglang/srt/models/llava.py
@@ -186,7 +186,7 @@ class LlavaBaseForCausalLM(nn.Module):
         bs = forward_batch.batch_size
         pixel_values = flatten_nested_list(
             [
-                [item.pixel_values for item in image_inputs[i].mm_items]
+                [item.feature for item in image_inputs[i].mm_items]
                 for i in range(bs)
                 if need_vision[i]
             ]
@@ -753,7 +753,7 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         features = []
         for item in items:
             # in each item, we assume pixel_values is always batched
-            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            pixel_values, image_sizes = item.feature, item.image_sizes
             image_outputs = self.vision_tower(
                 pixel_values, image_sizes, output_hidden_states=True
             )
python/sglang/srt/models/llavavid.py
@@ -135,7 +135,7 @@ class LlavaVidForCausalLM(nn.Module):
         if need_vision.any():
             pixel_values = flatten_nested_list(
                 [
-                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    [item.feature for item in image_inputs[i].mm_items]
                     for i in range(bs)
                     if need_vision[i]
                 ]
python/sglang/srt/models/minicpmo.py
@@ -1552,9 +1552,7 @@ class MiniCPMO(MiniCPMBaseModel):
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        wavforms = flatten_nested_list(
-            [item.audio_features for item in items if item.audio_features]
-        )
+        wavforms = flatten_nested_list([item.feature for item in items if item.feature])
         # list, [[x1, x2], [y1], [z1]]
         audio_feature_lens_raw = flatten_nested_list(
             [item.audio_feature_lens for item in items if item.audio_feature_lens]
@@ -1659,9 +1657,7 @@ class MiniCPMO(MiniCPMBaseModel):
            List[List[torch.Tensor]]: audio embeddings
        """
        # (bs, 80, frames) or [], multi audios need filled in advance
-        wavforms = flatten_nested_list(
-            [item.audio_features for item in items if item.audio_features]
-        )
+        wavforms = flatten_nested_list([item.feature for item in items if item.feature])
        # list, [[x1, x2], [y1], [z1]]
        audio_feature_lens_raw = flatten_nested_list(
            [item.audio_feature_lens for item in items if item.audio_feature_lens]
@@ -1778,7 +1774,7 @@ class MiniCPMO(MiniCPMBaseModel):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        pixel_values = flatten_nested_list([item.feature for item in items])
         tgt_sizes = torch.stack(
             flatten_nested_list([item.tgt_size for item in items]), dim=0
         )
python/sglang/srt/models/minicpmv.py
@@ -724,7 +724,7 @@ class MiniCPMV2_6(MiniCPMBaseModel):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        pixel_values = flatten_nested_list([item.feature for item in items])
         tgt_sizes = torch.stack(
             flatten_nested_list([item.tgt_size for item in items]), dim=0
         )
python/sglang/srt/models/mistral.py
@@ -56,7 +56,7 @@ class Mistral3ForConditionalGeneration:
         features = []
         for item in items:
             # in each item, we assume pixel_values is always batched
-            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            pixel_values, image_sizes = item.feature, item.image_sizes
             image_outputs = self.vision_tower(
                 pixel_values, image_sizes, output_hidden_states=True
             )
python/sglang/srt/models/mllama.py
@@ -838,9 +838,7 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        pixel_values = torch.cat(
-            [item.pixel_values for item in mm_inputs.mm_items], dim=0
-        )
+        pixel_values = torch.cat([item.feature for item in mm_inputs.mm_items], dim=0)
         pad_values = [item.pad_value for item in mm_inputs.mm_items]

         num_concurrent_media, num_tiles = pixel_values.shape[1:3]
@@ -862,7 +860,7 @@ class MllamaForConditionalGeneration(nn.Module):
             if not forward_batch.encoder_cached[i] and mm_input is not None:
                 pixel_values = torch.cat(
-                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                    [item.feature for item in mm_input.mm_items], dim=0
                 )
                 max_num_images = max(max_num_images, pixel_values.shape[1])
@@ -897,7 +895,7 @@ class MllamaForConditionalGeneration(nn.Module):
                 encoder_lens_need.append(forward_batch.encoder_lens[k])
                 pixel_values = torch.cat(
-                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                    [item.feature for item in mm_input.mm_items], dim=0
                 )
                 for j in range(pixel_values.shape[1]):
                     img = pixel_values[0, j]
python/sglang/srt/models/mllama4.py
@@ -147,7 +147,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             raise ValueError("Vision model not available for text-only checkpoint")

         pixel_values = (
-            torch.concat([item.pixel_values for item in items])
+            torch.concat([item.feature for item in items])
             .to(next(self.vision_model.parameters()).device)
             .type(next(self.vision_model.parameters()).dtype)
         )
python/sglang/srt/models/phi4mm.py
@@ -422,9 +422,7 @@ class Phi4MMForCausalLM(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         dtype = next(self.vision_encoder.parameters()).dtype
-        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
-            dtype
-        )
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
         image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0)
         image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
         image_embeds = self.vision_encoder(
python/sglang/srt/models/qwen2_5_vl.py
@@ -497,7 +497,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
             self.visual.dtype
         )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -508,9 +508,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat(
-            [getattr(item, "pixel_values_videos") for item in items], dim=0
-        ).type(self.visual.dtype)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()
python/sglang/srt/models/qwen2_audio.py
@@ -118,7 +118,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
     def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # Extract audio features from input items
-        input_features = torch.cat([item.audio_features for item in items], dim=0).type(
+        input_features = torch.cat([item.feature for item in items], dim=0).type(
             self.audio_tower.dtype
         )
python/sglang/srt/models/qwen2_vl.py
@@ -484,7 +484,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
             self.visual.dtype
         )
         image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
@@ -495,9 +495,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # in qwen-vl, last dim is the same
-        pixel_values = torch.cat(
-            [item.pixel_values_videos for item in items], dim=0
-        ).type(self.visual.dtype)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
         video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
         assert pixel_values.dim() == 2, pixel_values.dim()
         assert video_grid_thw.dim() == 2, video_grid_thw.dim()