Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
736fbf4c
Unverified
Commit
736fbf4c
authored
Oct 04, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 04, 2025
Browse files
[Misc] Require `merge_by_field_config` argument (#26214)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
44ea8513
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
45 deletions
+10
-45
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+9
-33
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+1
-12
No files found.
vllm/multimodal/utils.py
View file @
736fbf4c
...
@@ -15,7 +15,6 @@ import numpy as np
...
@@ -15,7 +15,6 @@ import numpy as np
import
numpy.typing
as
npt
import
numpy.typing
as
npt
import
torch
import
torch
from
PIL
import
Image
,
UnidentifiedImageError
from
PIL
import
Image
,
UnidentifiedImageError
from
typing_extensions
import
deprecated
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.connections
import
HTTPConnection
,
global_http_connection
from
vllm.connections
import
HTTPConnection
,
global_http_connection
...
@@ -376,39 +375,12 @@ def argsort_mm_positions(
...
@@ -376,39 +375,12 @@ def argsort_mm_positions(
return
[(
modality
,
idx
)
for
modality
,
idx
,
_
in
sorted_flat_items
]
return
[(
modality
,
idx
)
for
modality
,
idx
,
_
in
sorted_flat_items
]
# Temporary back-compatibility for plugins that define model runner
@
deprecated
(
"`group_mm_inputs_by_modality` is superseded by "
"`group_mm_kwargs_by_modality` and will be removed in v0.13. "
"Please use `group_mm_kwargs_by_modality` instead."
)
def
group_mm_inputs_by_modality
(
mm_inputs
:
list
[
MultiModalKwargsItems
]
)
->
list
[
list
[
MultiModalKwargsItems
]]:
if
not
mm_inputs
:
return
[]
def
modality_group_func
(
mm_input
:
MultiModalKwargsItems
)
->
Union
[
str
,
int
]:
# If the input has multiple modalities, return an id as the unique key
# for the mm_input input.
if
len
(
mm_input
)
>
1
:
return
id
(
mm_input
)
elif
len
(
mm_input
)
==
1
:
return
next
(
iter
(
mm_input
.
keys
()))
raise
AssertionError
(
"This line should be unreachable."
)
return
[
list
(
group
)
for
_
,
group
in
groupby
(
mm_inputs
,
key
=
modality_group_func
)
]
def
group_mm_kwargs_by_modality
(
def
group_mm_kwargs_by_modality
(
mm_kwargs
:
list
[
MultiModalKwargsItem
],
mm_kwargs
:
list
[
MultiModalKwargsItem
],
*
,
*
,
device
:
torch
.
types
.
Device
=
None
,
device
:
torch
.
types
.
Device
=
None
,
pin_memory
:
bool
=
False
,
pin_memory
:
bool
=
False
,
merge_by_field_config
:
bool
=
Fals
e
,
merge_by_field_config
:
Optional
[
bool
]
=
Non
e
,
)
->
Iterable
[
tuple
[
str
,
int
,
BatchedTensorInputs
]]:
)
->
Iterable
[
tuple
[
str
,
int
,
BatchedTensorInputs
]]:
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
modality together into the same `MultiModalKwargs` instance.
modality together into the same `MultiModalKwargs` instance.
...
@@ -421,15 +393,19 @@ def group_mm_kwargs_by_modality(
...
@@ -421,15 +393,19 @@ def group_mm_kwargs_by_modality(
Yields:
Yields:
A tuple `(modality, num_items, grouped_kwargs)`.
A tuple `(modality, num_items, grouped_kwargs)`.
"""
"""
if
merge_by_field_config
is
None
:
raise
RuntimeError
(
"`group_mm_kwargs_by_modality` now requires "
"`merge_by_field_config` arg, please update your model runner "
"according to https://github.com/vllm-project/vllm/pull/25676."
)
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
MultiModalKwargsItems
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
MultiModalKwargsItems
for
modality
,
items
in
groupby
(
mm_kwargs
,
key
=
lambda
item
:
item
.
modality
):
for
modality
,
items
in
groupby
(
mm_kwargs
,
key
=
lambda
item
:
item
.
modality
):
items_lst
=
list
(
items
)
items_lst
=
list
(
items
)
# TODO: Enable `merge_by_field_config` for all models
# TODO: Deprecate `merge_by_field_config` once
# to avoid creating an extra batch dimension (except for fields
# we have migrated all in-tree models
# that are meant to be stacked anyway).
# We will also need to update each model to remove `flatten_bn`.
if
merge_by_field_config
:
if
merge_by_field_config
:
mm_kwargs_group
:
BatchedTensorInputs
=
dict
(
mm_kwargs_group
:
BatchedTensorInputs
=
dict
(
MultiModalKwargsItems
.
from_seq
(
items_lst
).
get_data
(
MultiModalKwargsItems
.
from_seq
(
items_lst
).
get_data
(
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
736fbf4c
...
@@ -7,10 +7,9 @@ from typing import Optional, cast
...
@@ -7,10 +7,9 @@ from typing import Optional, cast
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
typing_extensions
import
deprecated
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalKwargsItems
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
,
swap_dict_values
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
,
swap_dict_values
...
@@ -53,16 +52,6 @@ class CachedRequestState:
...
@@ -53,16 +52,6 @@ class CachedRequestState:
def
num_tokens
(
self
)
->
int
:
def
num_tokens
(
self
)
->
int
:
return
self
.
num_prompt_tokens
+
len
(
self
.
output_token_ids
)
return
self
.
num_prompt_tokens
+
len
(
self
.
output_token_ids
)
# Temporary back-compatibility for plugins that define model runner
@
property
@
deprecated
(
"`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead."
)
def
mm_inputs
(
self
)
->
list
[
MultiModalKwargsItems
]:
return
[
MultiModalKwargsItems
.
from_seq
([
f
.
data
])
for
f
in
self
.
mm_features
if
f
.
data
is
not
None
]
def
get_token_id
(
self
,
idx
:
int
)
->
int
:
def
get_token_id
(
self
,
idx
:
int
)
->
int
:
if
idx
<
self
.
num_prompt_tokens
:
if
idx
<
self
.
num_prompt_tokens
:
if
self
.
prompt_token_ids
is
None
:
if
self
.
prompt_token_ids
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment