Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c67abd61
Unverified
Commit
c67abd61
authored
Mar 29, 2025
by
Roger Wang
Committed by
GitHub
Mar 29, 2025
Browse files
[V1] Support interleaved modality items (#15605)
Signed-off-by:
Roger Wang
<
ywang@roblox.com
>
parent
6fa7cd3d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
211 additions
and
121 deletions
+211
-121
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/conftest.py
tests/conftest.py
+19
-22
tests/models/decoder_only/vision_language/test_interleaved.py
...s/models/decoder_only/vision_language/test_interleaved.py
+77
-0
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+64
-16
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+30
-52
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+20
-31
No files found.
.buildkite/test-pipeline.yaml
View file @
c67abd61
...
@@ -431,6 +431,7 @@ steps:
...
@@ -431,6 +431,7 @@ steps:
-
pytest -v -s models/encoder_decoder/audio_language -m core_model
-
pytest -v -s models/encoder_decoder/audio_language -m core_model
-
pytest -v -s models/encoder_decoder/language -m core_model
-
pytest -v -s models/encoder_decoder/language -m core_model
-
pytest -v -s models/encoder_decoder/vision_language -m core_model
-
pytest -v -s models/encoder_decoder/vision_language -m core_model
-
pytest -v -s models/decoder_only/vision_language/test_interleaved.py
-
label
:
Multi-Modal Models Test (Extended)
1
# 48m
-
label
:
Multi-Modal Models Test (Extended)
1
# 48m
optional
:
true
optional
:
true
...
...
tests/conftest.py
View file @
c67abd61
...
@@ -747,30 +747,27 @@ class VllmRunner:
...
@@ -747,30 +747,27 @@ class VllmRunner:
videos
:
Optional
[
PromptVideoInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
)
->
list
[
TextPrompt
]:
)
->
list
[
TextPrompt
]:
if
images
is
not
None
:
assert
len
(
prompts
)
==
len
(
images
)
if
videos
is
not
None
:
assert
len
(
prompts
)
==
len
(
videos
)
if
audios
is
not
None
:
if
any
(
x
is
not
None
and
len
(
x
)
!=
len
(
prompts
)
assert
len
(
prompts
)
==
len
(
audios
)
for
x
in
[
images
,
videos
,
audios
]):
raise
ValueError
(
"All non-None multimodal inputs must have the same length as "
"prompts"
)
inputs
=
[
TextPrompt
(
prompt
=
prompt
)
for
prompt
in
prompts
]
inputs
=
[]
if
images
is
not
None
:
for
i
,
prompt
in
enumerate
(
prompts
):
for
i
,
image
in
enumerate
(
images
):
multi_modal_data
=
{}
if
image
is
not
None
:
if
images
is
not
None
and
(
image
:
=
images
[
i
])
is
not
None
:
inputs
[
i
][
"multi_modal_data"
]
=
{
"image"
:
image
}
multi_modal_data
[
"image"
]
=
image
if
videos
is
not
None
and
(
video
:
=
videos
[
i
])
is
not
None
:
if
videos
is
not
None
:
multi_modal_data
[
"video"
]
=
video
for
i
,
video
in
enumerate
(
videos
):
if
audios
is
not
None
and
(
audio
:
=
audios
[
i
])
is
not
None
:
if
video
is
not
None
:
multi_modal_data
[
"audio"
]
=
audio
inputs
[
i
][
"multi_modal_data"
]
=
{
"video"
:
video
}
inputs
.
append
(
if
audios
is
not
None
:
TextPrompt
(
prompt
=
prompt
,
for
i
,
audio
in
enumerate
(
audios
):
multi_modal_data
=
multi_modal_data
if
audio
is
not
None
:
if
multi_modal_data
else
None
))
inputs
[
i
][
"multi_modal_data"
]
=
{
"audio"
:
audio
}
return
inputs
return
inputs
...
...
tests/models/decoder_only/vision_language/test_interleaved.py
0 → 100644
View file @
c67abd61
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.video
import
VideoAsset
models
=
[
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
]
def
base_prompt
(
modalities_str
:
str
)
->
str
:
return
f
"<|im_start|>user
{
modalities_str
}
\n
Describe what you see from these items.<|im_end|><|im_start|>assistant
\n
"
# noqa: E501
INTERLEAVED_PROMPT
=
base_prompt
(
"<image><video><image>
\n
"
)
NONINTERLEAVED_PROMPT
=
base_prompt
(
"<image><image><video>
\n
"
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
def
test_models
(
vllm_runner
,
model
,
dtype
:
str
,
max_tokens
:
int
)
->
None
:
"""
This is a simple test to check if interleaved and non-interleaved prompts
give the same result.
"""
image_cherry
=
ImageAsset
(
"cherry_blossom"
).
pil_image
.
convert
(
"RGB"
)
image_stop
=
ImageAsset
(
"stop_sign"
).
pil_image
.
convert
(
"RGB"
)
images
=
[
image_cherry
,
image_stop
]
video
=
VideoAsset
(
name
=
"sample_demo_1.mp4"
,
num_frames
=
16
).
np_ndarrays
inputs
=
[
(
[
INTERLEAVED_PROMPT
],
[
images
],
[
video
],
),
(
[
NONINTERLEAVED_PROMPT
],
[
images
],
[
video
],
),
]
with
vllm_runner
(
model
,
task
=
"generate"
,
dtype
=
dtype
,
limit_mm_per_prompt
=
{
"image"
:
2
},
max_model_len
=
32768
,
max_num_seqs
=
2
,
tensor_parallel_size
=
1
,
enforce_eager
=
True
)
as
vllm_model
:
vllm_outputs_per_case
=
[
vllm_model
.
generate_greedy
(
prompts
,
max_tokens
,
images
=
images
,
videos
=
videos
)
for
prompts
,
images
,
videos
in
inputs
]
all_results
=
[
output
[
0
][
1
]
for
output
in
vllm_outputs_per_case
]
outputs
=
[(
total_str
,
total_str
.
find
(
"assistant
\n
"
)
+
len
(
"assistant
\n
"
))
for
total_str
in
all_results
]
prompt_lengths
=
[
prompt_len
for
_
,
prompt_len
in
outputs
]
generated_strs
=
[
total_str
[
prompt_len
:]
for
total_str
,
prompt_len
in
outputs
]
interleaved_prompt_len
,
noninterleaved_prompt_len
=
prompt_lengths
interleaved_output_str
,
noninterleaved_output_str
=
generated_strs
# The two prompts are identical except for the order of modality tokens.
assert
interleaved_prompt_len
==
noninterleaved_prompt_len
# The two generated strings should be different because of the
# interleaved modality tokens.
assert
interleaved_output_str
!=
noninterleaved_output_str
tests/multimodal/test_utils.py
View file @
c67abd61
...
@@ -155,7 +155,7 @@ def test_merge_and_sort_multimodal_metadata():
...
@@ -155,7 +155,7 @@ def test_merge_and_sort_multimodal_metadata():
]
]
},
},
mm_hashes
=
{
"image"
:
[
"hash1"
,
"hash2"
]},
mm_hashes
=
{
"image"
:
[
"hash1"
,
"hash2"
]},
expected_modalities
=
[
"image"
],
expected_modalities
=
[
"image"
,
"image"
],
expected_ranges
=
[
expected_ranges
=
[
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
3
,
length
=
2
),
PlaceholderRange
(
offset
=
3
,
length
=
2
),
...
@@ -172,7 +172,7 @@ def test_merge_and_sort_multimodal_metadata():
...
@@ -172,7 +172,7 @@ def test_merge_and_sort_multimodal_metadata():
]
]
},
},
mm_hashes
=
None
,
mm_hashes
=
None
,
expected_modalities
=
[
"image"
],
expected_modalities
=
[
"image"
,
"image"
],
expected_ranges
=
[
expected_ranges
=
[
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
2
,
length
=
2
),
PlaceholderRange
(
offset
=
2
,
length
=
2
),
...
@@ -197,7 +197,7 @@ def test_merge_and_sort_multimodal_metadata():
...
@@ -197,7 +197,7 @@ def test_merge_and_sort_multimodal_metadata():
"image"
:
[
"image_hash1"
,
"image_hash2"
],
"image"
:
[
"image_hash1"
,
"image_hash2"
],
"audio"
:
[
"audio_hash1"
,
"audio_hash2"
],
"audio"
:
[
"audio_hash1"
,
"audio_hash2"
],
},
},
expected_modalities
=
[
"audio"
,
"image"
],
expected_modalities
=
[
"audio"
,
"audio"
,
"image"
,
"image"
],
expected_ranges
=
[
expected_ranges
=
[
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
2
,
length
=
3
),
PlaceholderRange
(
offset
=
2
,
length
=
3
),
...
@@ -223,7 +223,7 @@ def test_merge_and_sort_multimodal_metadata():
...
@@ -223,7 +223,7 @@ def test_merge_and_sort_multimodal_metadata():
]
]
},
},
mm_hashes
=
None
,
mm_hashes
=
None
,
expected_modalities
=
[
"audio"
,
"image"
],
expected_modalities
=
[
"audio"
,
"audio"
,
"image"
,
"image"
],
expected_ranges
=
[
expected_ranges
=
[
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
2
,
length
=
3
),
PlaceholderRange
(
offset
=
2
,
length
=
3
),
...
@@ -254,7 +254,9 @@ def test_merge_and_sort_multimodal_metadata():
...
@@ -254,7 +254,9 @@ def test_merge_and_sort_multimodal_metadata():
"audio"
:
[
"audio_hash1"
],
"audio"
:
[
"audio_hash1"
],
"video"
:
[
"video_hash1"
,
"video_hash2"
,
"video_hash3"
]
"video"
:
[
"video_hash1"
,
"video_hash2"
,
"video_hash3"
]
},
},
expected_modalities
=
[
"audio"
,
"video"
,
"image"
],
expected_modalities
=
[
"audio"
,
"video"
,
"video"
,
"video"
,
"image"
,
"image"
],
expected_ranges
=
[
expected_ranges
=
[
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
3
,
length
=
4
),
PlaceholderRange
(
offset
=
3
,
length
=
4
),
...
@@ -300,12 +302,19 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
...
@@ -300,12 +302,19 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
"image"
:
[
"image_hash1"
,
"image_hash2"
],
"image"
:
[
"image_hash1"
,
"image_hash2"
],
"audio"
:
[
"audio_hash1"
,
"audio_hash2"
],
"audio"
:
[
"audio_hash1"
,
"audio_hash2"
],
},
},
expected_modalities
=
[],
expected_modalities
=
[
"image"
,
"audio"
,
"image"
,
"audio"
],
expected_ranges
=
[],
expected_ranges
=
[
expected_hashes
=
None
,
PlaceholderRange
(
offset
=
0
,
length
=
4
),
PlaceholderRange
(
offset
=
5
,
length
=
2
),
PlaceholderRange
(
offset
=
8
,
length
=
2
),
PlaceholderRange
(
offset
=
11
,
length
=
4
),
],
expected_hashes
=
[
"image_hash1"
,
"audio_hash1"
,
"image_hash2"
,
"audio_hash2"
],
),
),
# <image> <image> <
video> <audi
o> <image>
# <image> <image> <
audio> <vide
o> <image>
TestCase
(
TestCase
(
mm_positions
=
{
mm_positions
=
{
"image"
:
[
"image"
:
[
...
@@ -321,15 +330,54 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
...
@@ -321,15 +330,54 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
]
]
},
},
mm_hashes
=
None
,
mm_hashes
=
None
,
expected_modalities
=
[],
expected_modalities
=
[
"image"
,
"image"
,
"audio"
,
"video"
,
"image"
],
expected_ranges
=
[],
expected_ranges
=
[
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
2
,
length
=
3
),
PlaceholderRange
(
offset
=
5
,
length
=
2
),
PlaceholderRange
(
offset
=
8
,
length
=
5
),
PlaceholderRange
(
offset
=
20
,
length
=
4
),
],
expected_hashes
=
None
,
expected_hashes
=
None
,
),
),
# <image> <audio> <video> <image> with hashes
TestCase
(
mm_positions
=
{
"image"
:
[
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
18
,
length
=
4
),
],
"audio"
:
[
PlaceholderRange
(
offset
=
6
,
length
=
2
),
],
"video"
:
[
PlaceholderRange
(
offset
=
10
,
length
=
5
),
]
},
mm_hashes
=
{
"image"
:
[
"image_hash1"
,
"image_hash2"
],
"audio"
:
[
"audio_hash1"
],
"video"
:
[
"video_hash1"
],
},
expected_modalities
=
[
"image"
,
"audio"
,
"video"
,
"image"
],
expected_ranges
=
[
PlaceholderRange
(
offset
=
0
,
length
=
2
),
PlaceholderRange
(
offset
=
6
,
length
=
2
),
PlaceholderRange
(
offset
=
10
,
length
=
5
),
PlaceholderRange
(
offset
=
18
,
length
=
4
),
],
expected_hashes
=
[
"image_hash1"
,
"audio_hash1"
,
"video_hash1"
,
"image_hash2"
],
),
]
]
for
case
in
test_cas
es
:
for
(
mm_positions
,
mm_hashes
,
expected_modalities
,
expected_rang
es
,
with
pytest
.
raises
(
ValueError
)
as
ex_info
:
expected_hashes
)
in
test_cases
:
merge_and_sort_multimodal_metadata
(
case
.
mm_positions
,
modalities
,
ranges
,
hashes
=
merge_and_sort_multimodal_metadata
(
case
.
mm_hashes
)
mm_positions
,
mm_hashes
)
assert
"Interleaved mixed-modality"
in
str
(
ex_info
.
value
)
assert
modalities
==
expected_modalities
assert
ranges
==
expected_ranges
assert
hashes
==
expected_hashes
vllm/multimodal/utils.py
View file @
c67abd61
...
@@ -303,14 +303,10 @@ def merge_and_sort_multimodal_metadata(
...
@@ -303,14 +303,10 @@ def merge_and_sort_multimodal_metadata(
Optionally if a MultiModalHashDict is given, same operation will be
Optionally if a MultiModalHashDict is given, same operation will be
applied to the object and the sorted list of hashes will be returned.
applied to the object and the sorted list of hashes will be returned.
Raises:
ValueError: If the input prompt has interleaved placeholders from
different modalities (e.g, "<image><audio><image> Describe the
content.")
Returns:
Returns:
list[str]: Sorted list of involved modalities.
list[str]: List of item modalities in order of their positions in
the input sequence.
list[PlaceholderRange]: Sorted list of all PlaceholdeRanges from
list[PlaceholderRange]: Sorted list of all PlaceholdeRanges from
mm_positions.
mm_positions.
Optional[list[str]]: Sorted list of all hashes from mm_hashes if
Optional[list[str]]: Sorted list of all hashes from mm_hashes if
...
@@ -324,50 +320,33 @@ def merge_and_sort_multimodal_metadata(
...
@@ -324,50 +320,33 @@ def merge_and_sort_multimodal_metadata(
# For single modality, placeholder ranges and hashes are already sorted
# For single modality, placeholder ranges and hashes are already sorted
# so we can return the list directly.
# so we can return the list directly.
if
len
(
modalities
)
==
1
:
if
len
(
modalities
)
==
1
:
if
mm_hashes
is
None
:
modality
=
modalities
[
0
]
return
modalities
,
list
(
mm_positions
[
modalities
[
0
]]),
None
placeholder_list
=
list
(
mm_positions
[
modality
])
else
:
return
modalities
,
list
(
mm_positions
[
modalities
[
0
]]),
list
(
return
[
modality
]
*
len
(
mm_hashes
[
modalities
[
0
]])
placeholder_list
),
placeholder_list
,
None
if
not
mm_hashes
else
mm_hashes
[
modality
]
placeholder_lists_with_modality
=
[(
modality
,
mm_positions
[
modality
])
for
modality
in
modalities
]
# Create a list of (modality, placeholder, hash) tuples for all placeholders
all_items
=
[]
if
mm_hashes
is
None
:
for
modality
in
modalities
:
sorted_placeholder_lists
=
sorted
(
placeholder_lists_with_modality
,
placeholder_list
=
list
(
mm_positions
[
modality
])
key
=
lambda
x
:
x
[
1
][
0
][
'offset'
])
hash_list
:
list
[
Optional
[
str
]]
=
list
(
sorted_hash_lists
=
None
mm_hashes
[
modality
])
if
mm_hashes
and
modality
in
mm_hashes
else
[
else
:
None
hashes_lists
=
[
]
*
len
(
placeholder_list
)
mm_hashes
[
modality
]
for
modality
in
modalities
if
modality
in
mm_hashes
for
placeholder
,
hash_value
in
zip
(
placeholder_list
,
hash_list
):
]
all_items
.
append
((
modality
,
placeholder
,
hash_value
))
sorted_pairs
=
sorted
(
zip
(
placeholder_lists_with_modality
,
hashes_lists
),
# Sort all items by offset
key
=
lambda
x
:
x
[
0
][
1
][
0
][
'offset'
])
all_items
.
sort
(
key
=
lambda
x
:
x
[
1
][
'offset'
])
sorted_placeholder_tuple
,
sorted_hash_tuple
=
zip
(
*
sorted_pairs
)
sorted_placeholder_lists
=
list
(
sorted_placeholder_tuple
)
# Split into separate lists
sorted_hash_lists
=
list
(
sorted_hash_tuple
)
sorted_modalities
=
[
item
[
0
]
for
item
in
all_items
]
merged_placeholders
=
[
item
[
1
]
for
item
in
all_items
]
sorted_modalities
=
[
modality
for
modality
,
_
in
sorted_placeholder_lists
]
merged_hashes
=
[
str
(
item
[
2
])
for
item
in
all_items
]
if
mm_hashes
is
not
None
else
None
# Flatten sorted list of lists to a single list and verify there is no
# interleaving of placeholders from different modalities.
merged_placeholders
:
list
[
PlaceholderRange
]
=
[]
for
modality
,
placeholder_list
in
sorted_placeholder_lists
:
if
merged_placeholders
and
placeholder_list
[
0
][
'offset'
]
<
merged_placeholders
[
-
1
][
'offset'
]:
raise
ValueError
(
"Interleaved mixed-modality inference is currently not "
"supported."
)
merged_placeholders
.
extend
(
placeholder_list
)
if
sorted_hash_lists
is
not
None
:
merged_hashes
=
[]
for
hash_list
in
sorted_hash_lists
:
merged_hashes
.
extend
(
hash_list
)
else
:
merged_hashes
=
None
return
sorted_modalities
,
merged_placeholders
,
merged_hashes
return
sorted_modalities
,
merged_placeholders
,
merged_hashes
...
@@ -383,8 +362,7 @@ def group_mm_inputs_by_modality(
...
@@ -383,8 +362,7 @@ def group_mm_inputs_by_modality(
Returns:
Returns:
list[list[MultiModalKwargs]]: List of list of MultiModalKwargs, each
list[list[MultiModalKwargs]]: List of list of MultiModalKwargs, each
inner list contains consecutive MultiModalKwargs with same modality, or
inner list contains consecutive MultiModalKwargs with same modality.
one with multimodal modalities.
"""
"""
if
not
mm_inputs
:
if
not
mm_inputs
:
return
[]
return
[]
...
...
vllm/v1/engine/processor.py
View file @
c67abd61
...
@@ -234,22 +234,11 @@ class Processor:
...
@@ -234,22 +234,11 @@ class Processor:
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
decoder_mm_inputs
=
decoder_inputs
[
"mm_kwargs"
]
decoder_mm_inputs
=
decoder_inputs
[
"mm_kwargs"
]
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# contains the kwargs for all items from all modalities.
# This code separates them so that there is one set of kwargs
# per item per modality.
individual_mm_inputs
=
[
MultiModalKwargs
.
from_items
([
item
])
for
modality
in
decoder_mm_inputs
.
modalities
for
item
in
decoder_mm_inputs
.
get_items
(
modality
)
]
# Merge and flatten multimodal placeholders, hashes and inputs
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
# in the input sequence.
# NOTE: interleaved modalities are not supported.
(
(
sorted_modalities
,
sorted_
item_
modalities
,
sorted_mm_positions
,
sorted_mm_positions
,
sorted_mm_hashes
,
sorted_mm_hashes
,
)
=
merge_and_sort_multimodal_metadata
(
)
=
merge_and_sort_multimodal_metadata
(
...
@@ -257,26 +246,26 @@ class Processor:
...
@@ -257,26 +246,26 @@ class Processor:
decoder_inputs
[
"mm_hashes"
]
if
self
.
use_hash
else
None
,
decoder_inputs
[
"mm_hashes"
]
if
self
.
use_hash
else
None
,
)
)
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# modalities involved.
# is a single MultiModalKwargs for all items from all modalities.
if
len
(
sorted_modalities
)
>
1
:
# This code flattens kwargs for individual items in a list and
modality_order_dict
=
{
# sorts them by each item's position in the input sequence if there
modality
:
order
# are multiple modalities.
for
order
,
modality
in
enumerate
(
sorted_modalities
)
unique_modalities
=
set
(
sorted_item_modalities
)
}
if
len
(
unique_modalities
)
>
1
:
sorted_mm_inputs
=
[]
# Sanity check to make sure each multimodal input has only one
used_indices
=
{
modality
:
0
for
modality
in
unique_modalities
}
# modality key.
for
modality
in
sorted_item_modalities
:
for
mm_input
in
individual_mm_inputs
:
items
=
decoder_mm_inputs
.
get_items
(
modality
)
assert
len
(
mm_input
.
modalities
)
==
1
item
=
items
[
used_indices
[
modality
]]
sorted_mm_inputs
.
append
(
MultiModalKwargs
.
from_items
([
item
# Sort MultiModalKwargs to match sorted_mm_positions
]))
sorted_mm_inputs
=
sorted
(
used_indices
[
modality
]
+=
1
individual_mm_inputs
,
key
=
lambda
mm_input
:
modality_order_dict
[
list
(
mm_input
.
modalities
)[
0
]])
else
:
else
:
sorted_mm_inputs
=
individual_mm_inputs
sorted_mm_inputs
=
[
MultiModalKwargs
.
from_items
([
item
])
for
item
in
decoder_mm_inputs
.
get_items
(
sorted_item_modalities
[
0
])
]
return
EngineCoreRequest
(
return
EngineCoreRequest
(
request_id
=
request_id
,
request_id
=
request_id
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment