Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0ea80c87
Unverified
Commit
0ea80c87
authored
Sep 26, 2025
by
Cyrus Leung
Committed by
GitHub
Sep 25, 2025
Browse files
[Model] Define `merge_by_field_config` MM interface (#25676)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
b8d9e4a3
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
44 additions
and
12 deletions
+44
-12
tests/models/multimodal/processing/test_tensor_schema.py
tests/models/multimodal/processing/test_tensor_schema.py
+18
-5
vllm/config/model.py
vllm/config/model.py
+2
-3
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+6
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+7
-2
No files found.
tests/models/multimodal/processing/test_tensor_schema.py
View file @
0ea80c87
...
@@ -19,6 +19,8 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
...
@@ -19,6 +19,8 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment
,
init_distributed_environment
,
initialize_model_parallel
)
initialize_model_parallel
)
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models.interfaces
import
(
SupportsMultiModal
,
supports_multimodal
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
InputProcessingContext
)
InputProcessingContext
)
...
@@ -88,6 +90,7 @@ def resize_mm_data(
...
@@ -88,6 +90,7 @@ def resize_mm_data(
def
create_batched_mm_kwargs
(
def
create_batched_mm_kwargs
(
model_cls
:
type
[
SupportsMultiModal
],
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
processor
:
BaseMultiModalProcessor
,
processor
:
BaseMultiModalProcessor
,
size_factors
:
tuple
[
float
,
...]
=
(
1.0
,
0.5
,
0.25
),
size_factors
:
tuple
[
float
,
...]
=
(
1.0
,
0.5
,
0.25
),
...
@@ -127,16 +130,22 @@ def create_batched_mm_kwargs(
...
@@ -127,16 +130,22 @@ def create_batched_mm_kwargs(
mm_data
=
resized_mm_data
,
mm_data
=
resized_mm_data
,
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
tokenization_kwargs
=
processor_inputs
.
tokenization_kwargs
,
tokenization_kwargs
=
processor_inputs
.
tokenization_kwargs
,
)[
"mm_kwargs"
]
)[
"mm_kwargs"
]
.
require_data
()
items
=
[
items
=
[
item
for
modality
in
supported_mm_limits
item
for
modality
in
supported_mm_limits
for
item
in
mm_kwargs
[
modality
]
for
item
in
mm_kwargs
[
modality
]
]
]
return
group_mm_kwargs_by_modality
(
items
)
return
group_mm_kwargs_by_modality
(
items
,
merge_by_field_config
=
model_cls
.
merge_by_field_config
,
)
@
contextmanager
@
contextmanager
def
initialize_dummy_model
(
model_cls
:
nn
.
Module
,
model_config
:
ModelConfig
):
def
initialize_dummy_model
(
model_cls
:
type
[
nn
.
Module
],
model_config
:
ModelConfig
,
):
temp_file
=
tempfile
.
mkstemp
()[
1
]
temp_file
=
tempfile
.
mkstemp
()[
1
]
init_distributed_environment
(
init_distributed_environment
(
world_size
=
1
,
world_size
=
1
,
...
@@ -198,8 +207,12 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
...
@@ -198,8 +207,12 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
hf_overrides
=
hf_overrides_fn
,
hf_overrides
=
hf_overrides_fn
,
skip_tokenizer_init
=
model_info
.
skip_tokenizer_init
,
skip_tokenizer_init
=
model_info
.
skip_tokenizer_init
,
enforce_eager
=
model_info
.
enforce_eager
,
enforce_eager
=
model_info
.
enforce_eager
,
dtype
=
model_info
.
dtype
)
dtype
=
model_info
.
dtype
,
)
model_cls
=
MULTIMODAL_REGISTRY
.
_get_model_cls
(
model_config
)
model_cls
=
MULTIMODAL_REGISTRY
.
_get_model_cls
(
model_config
)
assert
supports_multimodal
(
model_cls
)
factories
=
MULTIMODAL_REGISTRY
.
_processor_factories
[
model_cls
]
factories
=
MULTIMODAL_REGISTRY
.
_processor_factories
[
model_cls
]
inputs_parse_methods
=
[]
inputs_parse_methods
=
[]
...
@@ -228,7 +241,7 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
...
@@ -228,7 +241,7 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
with
initialize_dummy_model
(
model_cls
,
model_config
)
as
model
:
with
initialize_dummy_model
(
model_cls
,
model_config
)
as
model
:
for
modality
,
_
,
mm_kwargs
in
create_batched_mm_kwargs
(
for
modality
,
_
,
mm_kwargs
in
create_batched_mm_kwargs
(
model_config
,
processor
):
model_cls
,
model_config
,
processor
):
for
method_name
in
inputs_parse_methods
:
for
method_name
in
inputs_parse_methods
:
print
(
f
"Testing `
{
method_name
}
` with modality=
{
modality
}
"
print
(
f
"Testing `
{
method_name
}
` with modality=
{
modality
}
"
f
"and mm_kwargs
{
list
(
mm_kwargs
.
keys
())
}
"
)
f
"and mm_kwargs
{
list
(
mm_kwargs
.
keys
())
}
"
)
...
...
vllm/config/model.py
View file @
0ea80c87
...
@@ -63,13 +63,12 @@ ConvertType = Literal["none", "embed", "classify", "reward"]
...
@@ -63,13 +63,12 @@ ConvertType = Literal["none", "embed", "classify", "reward"]
ConvertOption
=
Literal
[
"auto"
,
ConvertType
]
ConvertOption
=
Literal
[
"auto"
,
ConvertType
]
TaskOption
=
Literal
[
"auto"
,
"generate"
,
"embedding"
,
"embed"
,
"classify"
,
TaskOption
=
Literal
[
"auto"
,
"generate"
,
"embedding"
,
"embed"
,
"classify"
,
"score"
,
"reward"
,
"transcription"
,
"draft"
]
"score"
,
"reward"
,
"transcription"
,
"draft"
]
_ResolvedTask
=
Literal
[
"generate"
,
"transcription"
,
"encode"
,
"embed"
,
"classify"
,
"reward"
,
"draft"
]
TokenizerMode
=
Literal
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]
TokenizerMode
=
Literal
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]
ModelDType
=
Literal
[
"auto"
,
"half"
,
"float16"
,
"bfloat16"
,
"float"
,
"float32"
]
ModelDType
=
Literal
[
"auto"
,
"half"
,
"float16"
,
"bfloat16"
,
"float"
,
"float32"
]
LogprobsMode
=
Literal
[
"raw_logits"
,
"raw_logprobs"
,
"processed_logits"
,
LogprobsMode
=
Literal
[
"raw_logits"
,
"raw_logprobs"
,
"processed_logits"
,
"processed_logprobs"
]
"processed_logprobs"
]
HfOverrides
=
Union
[
dict
[
str
,
Any
],
Callable
[[
type
],
type
]]
HfOverrides
=
Union
[
dict
[
str
,
Any
],
Callable
[[
PretrainedConfig
],
PretrainedConfig
]]
ModelImpl
=
Literal
[
"auto"
,
"vllm"
,
"transformers"
,
"terratorch"
]
ModelImpl
=
Literal
[
"auto"
,
"vllm"
,
"transformers"
,
"terratorch"
]
_RUNNER_TASKS
:
dict
[
RunnerType
,
list
[
TaskOption
]]
=
{
_RUNNER_TASKS
:
dict
[
RunnerType
,
list
[
TaskOption
]]
=
{
...
...
vllm/model_executor/models/interfaces.py
View file @
0ea80c87
...
@@ -64,6 +64,12 @@ class SupportsMultiModal(Protocol):
...
@@ -64,6 +64,12 @@ class SupportsMultiModal(Protocol):
`multimodal_config.mm_encoder_tp_mode="data"`.
`multimodal_config.mm_encoder_tp_mode="data"`.
"""
"""
merge_by_field_config
:
ClassVar
[
bool
]
=
False
"""
A flag that indicates which implementation of
`vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
"""
@
classmethod
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
"""
"""
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
0ea80c87
...
@@ -40,7 +40,8 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
...
@@ -40,7 +40,8 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.model_loader
import
TensorizerLoader
,
get_model_loader
from
vllm.model_executor.models.interfaces
import
(
is_mixture_of_experts
,
from
vllm.model_executor.models.interfaces
import
(
SupportsMultiModal
,
is_mixture_of_experts
,
supports_eagle3
,
supports_eagle3
,
supports_mrope
,
supports_mrope
,
supports_transcription
)
supports_transcription
)
...
@@ -777,11 +778,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -777,11 +778,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_kwargs
.
append
(
feature
.
data
)
mm_kwargs
.
append
(
feature
.
data
)
# Input all modalities at once
# Input all modalities at once
model
=
cast
(
SupportsMultiModal
,
self
.
model
)
mm_kwargs_combined
:
BatchedTensorInputs
=
{}
mm_kwargs_combined
:
BatchedTensorInputs
=
{}
for
_
,
_
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
for
_
,
_
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
mm_kwargs
,
mm_kwargs
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
merge_by_field_config
=
model
.
merge_by_field_config
,
):
):
mm_kwargs_combined
.
update
(
mm_kwargs_group
)
mm_kwargs_combined
.
update
(
mm_kwargs_group
)
...
@@ -1525,11 +1528,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1525,11 +1528,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same batch while still being able to benefit from batching
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
# encoder outputs.
model
=
cast
(
SupportsMultiModal
,
self
.
model
)
encoder_outputs
=
[]
encoder_outputs
=
[]
for
_
,
num_items
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
for
_
,
num_items
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
mm_kwargs
,
mm_kwargs
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
merge_by_field_config
=
model
.
merge_by_field_config
,
):
):
# Run the encoder.
# Run the encoder.
# `curr_group_outputs` is either of the following:
# `curr_group_outputs` is either of the following:
...
@@ -1538,7 +1543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1538,7 +1543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# 2. A list or tuple (length: num_items) of tensors, each of shape
# 2. A list or tuple (length: num_items) of tensors, each of shape
# (feature_size, hidden_size) in case the feature size is dynamic
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
# depending on the input multimodal items.
curr_group_outputs
=
self
.
model
.
get_multimodal_embeddings
(
curr_group_outputs
=
model
.
get_multimodal_embeddings
(
**
mm_kwargs_group
)
**
mm_kwargs_group
)
sanity_check_mm_encoder_outputs
(
sanity_check_mm_encoder_outputs
(
...
@@ -1623,11 +1628,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1623,11 +1628,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return
{}
return
{}
# Group MM kwargs by modality and extract features
# Group MM kwargs by modality and extract features
model
=
cast
(
SupportsMultiModal
,
self
.
model
)
encoder_features
=
{}
encoder_features
=
{}
for
_
,
_
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
for
_
,
_
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
mm_kwargs
,
mm_kwargs
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
merge_by_field_config
=
model
.
merge_by_field_config
,
):
):
# Add the grouped features to encoder_features dict
# Add the grouped features to encoder_features dict
# This allows the model to receive them as kwargs (e.g.,
# This allows the model to receive them as kwargs (e.g.,
...
@@ -2839,11 +2846,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2839,11 +2846,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_mm_item
=
dummy_mm_data
[
modality
][
0
]
dummy_mm_item
=
dummy_mm_data
[
modality
][
0
]
dummy_mm_items
=
[
dummy_mm_item
]
*
max_items_per_batch
dummy_mm_items
=
[
dummy_mm_item
]
*
max_items_per_batch
model
=
cast
(
SupportsMultiModal
,
self
.
model
)
return
next
(
mm_kwargs_group
return
next
(
mm_kwargs_group
for
_
,
_
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
for
_
,
_
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
dummy_mm_items
,
dummy_mm_items
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
merge_by_field_config
=
model
.
merge_by_field_config
,
))
))
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
0ea80c87
...
@@ -30,7 +30,8 @@ from vllm.logger import init_logger
...
@@ -30,7 +30,8 @@ from vllm.logger import init_logger
from
vllm.lora.layers
import
BaseLayerWithLoRA
from
vllm.lora.layers
import
BaseLayerWithLoRA
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader.tpu
import
TPUModelLoader
from
vllm.model_executor.model_loader.tpu
import
TPUModelLoader
from
vllm.model_executor.models.interfaces
import
supports_transcription
from
vllm.model_executor.models.interfaces
import
(
SupportsMultiModal
,
supports_transcription
)
from
vllm.model_executor.models.interfaces_base
import
(
from
vllm.model_executor.models.interfaces_base
import
(
is_pooling_model
,
is_text_generation_model
)
is_pooling_model
,
is_text_generation_model
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
@@ -834,11 +835,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -834,11 +835,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the same batch while still being able to benefit from batching
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
# encoder outputs.
model
=
cast
(
SupportsMultiModal
,
self
.
model
)
encoder_outputs
=
[]
encoder_outputs
=
[]
for
_
,
num_items
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
for
_
,
num_items
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
mm_kwargs
,
mm_kwargs
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
merge_by_field_config
=
model
.
merge_by_field_config
,
):
):
# Run the encoder.
# Run the encoder.
# `curr_group_outputs` is either of the following:
# `curr_group_outputs` is either of the following:
...
@@ -848,7 +851,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -848,7 +851,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# (feature_size, hidden_size) in case the feature size is dynamic
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
# depending on the input multimodal items.
torch_xla
.
sync
(
wait
=
False
)
torch_xla
.
sync
(
wait
=
False
)
curr_group_outputs
=
self
.
model
.
get_multimodal_embeddings
(
curr_group_outputs
=
model
.
get_multimodal_embeddings
(
**
mm_kwargs_group
)
**
mm_kwargs_group
)
torch_xla
.
sync
(
wait
=
False
)
torch_xla
.
sync
(
wait
=
False
)
...
@@ -1805,11 +1808,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1805,11 +1808,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_mm_item
=
dummy_mm_data
[
modality
][
0
]
dummy_mm_item
=
dummy_mm_data
[
modality
][
0
]
dummy_mm_items
=
[
dummy_mm_item
]
*
max_items_per_batch
dummy_mm_items
=
[
dummy_mm_item
]
*
max_items_per_batch
model
=
cast
(
SupportsMultiModal
,
self
.
model
)
return
next
(
grouped_mm_kwargs
return
next
(
grouped_mm_kwargs
for
_
,
_
,
grouped_mm_kwargs
in
group_mm_kwargs_by_modality
(
for
_
,
_
,
grouped_mm_kwargs
in
group_mm_kwargs_by_modality
(
dummy_mm_items
,
dummy_mm_items
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
merge_by_field_config
=
model
.
merge_by_field_config
,
))
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment