Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1409ef91
Unverified
Commit
1409ef91
authored
Jun 04, 2025
by
Lukas Geiger
Committed by
GitHub
Jun 03, 2025
Browse files
[Core] Cast multimodal input in hf processor (#18862)
Signed-off-by:
Lukas Geiger
<
lukas.geiger94@gmail.com
>
parent
4555143e
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
25 additions
and
25 deletions
+25
-25
vllm/inputs/registry.py
vllm/inputs/registry.py
+24
-2
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+1
-7
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+0
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+0
-2
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+0
-2
vllm/worker/cpu_enc_dec_model_runner.py
vllm/worker/cpu_enc_dec_model_runner.py
+0
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+0
-1
vllm/worker/cpu_pooling_model_runner.py
vllm/worker/cpu_pooling_model_runner.py
+0
-1
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+0
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+0
-1
vllm/worker/multi_step_neuron_model_runner.py
vllm/worker/multi_step_neuron_model_runner.py
+0
-1
vllm/worker/multi_step_neuronx_distributed_model_runner.py
vllm/worker/multi_step_neuronx_distributed_model_runner.py
+0
-1
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+0
-2
vllm/worker/pooling_model_runner.py
vllm/worker/pooling_model_runner.py
+0
-1
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+0
-1
No files found.
vllm/inputs/registry.py
View file @
1409ef91
...
...
@@ -4,9 +4,12 @@ from collections.abc import Mapping
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
Optional
,
Union
import
torch
from
transformers
import
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
from
typing_extensions
import
TypeVar
from
vllm.jsontree
import
JSONTree
,
json_map_leaves
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
resolve_mm_processor_kwargs
...
...
@@ -21,6 +24,8 @@ _T = TypeVar("_T")
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
_P
=
TypeVar
(
"_P"
,
bound
=
ProcessorMixin
,
default
=
ProcessorMixin
)
logger
=
init_logger
(
__name__
)
@
dataclass
(
frozen
=
True
)
class
InputContext
:
...
...
@@ -134,7 +139,7 @@ class InputProcessingContext(InputContext):
hf_processor
:
ProcessorMixin
,
data
:
Mapping
[
str
,
object
],
kwargs
:
Mapping
[
str
,
object
]
=
{},
)
->
BatchFeature
:
)
->
Union
[
BatchFeature
,
JSONTree
]
:
"""
Call `hf_processor` on the prompt `data`
(text, image, audio...) with configurable options `kwargs`.
...
...
@@ -154,8 +159,25 @@ class InputProcessingContext(InputContext):
allow_var_kwargs
=
True
,
)
def
maybe_cast_dtype
(
x
):
# This mimics the behavior of transformers.BatchFeature
if
isinstance
(
x
,
torch
.
Tensor
)
and
x
.
is_floating_point
():
return
x
.
to
(
dtype
=
self
.
model_config
.
dtype
)
return
x
try
:
return
hf_processor
(
**
data
,
**
merged_kwargs
,
return_tensors
=
"pt"
)
output
=
hf_processor
(
**
data
,
**
merged_kwargs
,
return_tensors
=
"pt"
)
# this emulates output.to(dtype=self.model_config.dtype)
cast_output
=
json_map_leaves
(
maybe_cast_dtype
,
output
)
if
isinstance
(
output
,
BatchFeature
):
return
BatchFeature
(
cast_output
)
logger
.
warning_once
(
f
"
{
type
(
hf_processor
).
__name__
}
did not return `BatchFeature`. "
"Make sure to match the behaviour of `ProcessorMixin` when "
"implementing custom processors."
)
return
cast_output
except
Exception
as
exc
:
msg
=
(
f
"Failed to apply
{
type
(
hf_processor
).
__name__
}
"
f
"on data=
{
data
}
with kwargs=
{
merged_kwargs
}
"
)
...
...
vllm/multimodal/inputs.py
View file @
1409ef91
...
...
@@ -747,17 +747,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
batched_inputs
:
BatchedTensorInputs
,
*
,
device
:
torch
.
types
.
Device
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
BatchedTensorInputs
:
json_inputs
=
cast
(
JSONTree
[
torch
.
Tensor
],
batched_inputs
)
def
maybe_cast_dtype
(
x
:
torch
.
Tensor
):
# This mimics the behavior of transformers.BatchFeature
return
x
.
to
(
dtype
=
dtype
)
if
x
.
is_floating_point
()
else
x
json_mapped
=
json_map_leaves
(
# NOTE: Cast the dtype before sending it to device
lambda
x
:
maybe_cast_dtype
(
x
).
to
(
device
=
device
,
non_blocking
=
True
),
lambda
x
:
x
.
to
(
device
=
device
,
non_blocking
=
True
),
json_inputs
,
)
...
...
vllm/spec_decode/draft_model_runner.py
View file @
1409ef91
...
...
@@ -297,7 +297,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
dtype
=
self
.
model_runner
.
model_config
.
dtype
,
device
=
self
.
device
,
),
**
model_execute_kwargs
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
1409ef91
...
...
@@ -957,7 +957,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
batched_mm_inputs
=
MultiModalKwargs
.
batch
(
grouped_mm_inputs
)
batched_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
batched_mm_inputs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
...
...
@@ -1951,7 +1950,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
[
dummy_mm_kwargs
]
*
max_num_mm_items
)
batched_dummy_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
batched_dummy_mm_inputs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
1409ef91
...
...
@@ -718,7 +718,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
batched_mm_inputs
=
MultiModalKwargs
.
batch
(
grouped_mm_inputs
)
batched_mm_inputs
=
MultiModalKwargs
.
as_kwargs
(
batched_mm_inputs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
...
...
@@ -1560,7 +1559,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
batch_size
)
return
MultiModalKwargs
.
as_kwargs
(
batched_dummy_mm_inputs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
...
...
vllm/worker/cpu_enc_dec_model_runner.py
View file @
1409ef91
...
...
@@ -300,7 +300,6 @@ class CPUEncoderDecoderModelRunner(
model_input
.
encoder_input_positions
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
"intermediate_tensors"
:
...
...
vllm/worker/cpu_model_runner.py
View file @
1409ef91
...
...
@@ -630,7 +630,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
if
model_input
.
multi_modal_kwargs
is
not
None
:
multimodal_kwargs
=
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
execute_model_kwargs
=
{}
...
...
vllm/worker/cpu_pooling_model_runner.py
View file @
1409ef91
...
...
@@ -53,7 +53,6 @@ class CPUPoolingModelRunner(
model_input
.
input_positions
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
**
cross_enc_kwargs
,
...
...
vllm/worker/enc_dec_model_runner.py
View file @
1409ef91
...
...
@@ -205,7 +205,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
**
seqlen_agnostic_kwargs
,
...
...
vllm/worker/model_runner.py
View file @
1409ef91
...
...
@@ -1848,7 +1848,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
**
seqlen_agnostic_kwargs
,
...
...
vllm/worker/multi_step_neuron_model_runner.py
View file @
1409ef91
...
...
@@ -73,7 +73,6 @@ class MultiStepNeuronModelRunner(NeuronModelRunner):
input_block_ids
=
model_input
.
input_block_ids
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
)
...
...
vllm/worker/multi_step_neuronx_distributed_model_runner.py
View file @
1409ef91
...
...
@@ -52,7 +52,6 @@ class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
sampling_params
=
sampling_params
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
)
...
...
vllm/worker/neuron_model_runner.py
View file @
1409ef91
...
...
@@ -395,7 +395,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
adapter_ids
=
model_input
.
adapter_ids
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
)
...
...
@@ -408,7 +407,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_block_ids
=
model_input
.
input_block_ids
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
)
...
...
vllm/worker/pooling_model_runner.py
View file @
1409ef91
...
...
@@ -122,7 +122,6 @@ class PoolingModelRunner(
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
multi_modal_kwargs
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
**
cross_enc_kwargs
,
...
...
vllm/worker/xpu_model_runner.py
View file @
1409ef91
...
...
@@ -565,7 +565,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
),
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment