Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eb28e806
Unverified
Commit
eb28e806
authored
Jan 13, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 13, 2026
Browse files
[Refactor] Remove `get_encoder_dummy_data` (#32241)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
542a4059
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
21 additions
and
82 deletions
+21
-82
vllm/model_executor/models/nemotron_parse.py
vllm/model_executor/models/nemotron_parse.py
+4
-4
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+4
-4
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+4
-4
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+1
-25
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+0
-38
vllm/v1/engine/input_processor.py
vllm/v1/engine/input_processor.py
+8
-7
No files found.
vllm/model_executor/models/nemotron_parse.py
View file @
eb28e806
...
@@ -605,6 +605,10 @@ class NemotronParseProcessingInfo(BaseProcessingInfo):
...
@@ -605,6 +605,10 @@ class NemotronParseProcessingInfo(BaseProcessingInfo):
**
kwargs
,
**
kwargs
,
)
)
@
property
def
skip_prompt_length_check
(
self
)
->
bool
:
return
True
# Because the encoder prompt is padded
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
return
{
"image"
:
1
}
return
{
"image"
:
1
}
...
@@ -657,10 +661,6 @@ class NemotronParseMultiModalProcessor(
...
@@ -657,10 +661,6 @@ class NemotronParseMultiModalProcessor(
)
->
str
|
list
[
int
]:
)
->
str
|
list
[
int
]:
return
[
0
]
return
[
0
]
@
property
def
pad_dummy_encoder_prompt
(
self
)
->
bool
:
return
True
def
_call_hf_processor
(
def
_call_hf_processor
(
self
,
self
,
prompt
:
str
,
prompt
:
str
,
...
...
vllm/model_executor/models/whisper.py
View file @
eb28e806
...
@@ -681,6 +681,10 @@ class WhisperProcessingInfo(BaseProcessingInfo):
...
@@ -681,6 +681,10 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def
get_hf_config
(
self
)
->
WhisperConfig
:
def
get_hf_config
(
self
)
->
WhisperConfig
:
return
self
.
ctx
.
get_hf_config
(
WhisperConfig
)
return
self
.
ctx
.
get_hf_config
(
WhisperConfig
)
@
property
def
skip_prompt_length_check
(
self
)
->
bool
:
return
True
# Because the encoder prompt is padded
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
return
{
"audio"
:
1
}
return
{
"audio"
:
1
}
...
@@ -733,10 +737,6 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
...
@@ -733,10 +737,6 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
target_channels
=
self
.
info
.
get_target_channels
(),
target_channels
=
self
.
info
.
get_target_channels
(),
)
)
@
property
def
pad_dummy_encoder_prompt
(
self
)
->
bool
:
return
True
def
create_encoder_prompt
(
def
create_encoder_prompt
(
self
,
self
,
prompt
:
str
|
list
[
int
],
prompt
:
str
|
list
[
int
],
...
...
vllm/multimodal/processing.py
View file @
eb28e806
...
@@ -1396,6 +1396,10 @@ class BaseProcessingInfo:
...
@@ -1396,6 +1396,10 @@ class BaseProcessingInfo:
"""
"""
return
self
.
ctx
.
get_hf_processor
(
**
kwargs
)
return
self
.
ctx
.
get_hf_processor
(
**
kwargs
)
@
property
def
skip_prompt_length_check
(
self
)
->
bool
:
return
False
@
abstractmethod
@
abstractmethod
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
"""
"""
...
@@ -2403,10 +2407,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -2403,10 +2407,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
property
def
pad_dummy_encoder_prompt
(
self
)
->
bool
:
return
False
def
create_decoder_prompt
(
def
create_decoder_prompt
(
self
,
self
,
prompt
:
str
|
list
[
int
],
prompt
:
str
|
list
[
int
],
...
...
vllm/multimodal/profiling.py
View file @
eb28e806
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Generic
,
NamedTuple
,
TypeVar
,
cast
from
typing
import
Generic
,
NamedTuple
,
TypeVar
import
numpy
as
np
import
numpy
as
np
import
numpy.typing
as
npt
import
numpy.typing
as
npt
...
@@ -19,7 +19,6 @@ from vllm.logger import init_logger
...
@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from
.inputs
import
(
from
.inputs
import
(
MultiModalDataDict
,
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
,
MultiModalInputs
,
MultiModalKwargsItems
,
MultiModalKwargsItems
,
MultiModalPlaceholderDict
,
MultiModalPlaceholderDict
,
...
@@ -27,7 +26,6 @@ from .inputs import (
...
@@ -27,7 +26,6 @@ from .inputs import (
from
.processing
import
(
from
.processing
import
(
BaseMultiModalProcessor
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
)
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -282,28 +280,6 @@ class MultiModalProfiler(Generic[_I]):
...
@@ -282,28 +280,6 @@ class MultiModalProfiler(Generic[_I]):
for
modality
,
placeholders
in
placeholders_by_modality
.
items
()
for
modality
,
placeholders
in
placeholders_by_modality
.
items
()
}
}
def
get_encoder_dummy_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]
|
None
=
None
,
mm_options
:
Mapping
[
str
,
BaseDummyOptions
]
|
None
=
None
,
)
->
DummyEncoderData
:
mm_inputs
=
self
.
_get_dummy_mm_inputs
(
seq_len
,
mm_counts
,
mm_options
)
mm_inputs
=
cast
(
MultiModalEncDecInputs
,
mm_inputs
)
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
encoder_prompt_token_ids
=
mm_inputs
[
"encoder_prompt_token_ids"
]
total_len
=
len
(
encoder_prompt_token_ids
)
processor
=
cast
(
EncDecMultiModalProcessor
,
self
.
processor
)
if
processor
.
pad_dummy_encoder_prompt
:
num_tokens_to_pad
=
max
(
total_len
,
seq_len
)
-
total_len
encoder_prompt_token_ids
.
extend
([
0
]
*
num_tokens_to_pad
)
return
DummyEncoderData
(
encoder_prompt_token_ids
)
def
get_decoder_dummy_data
(
def
get_decoder_dummy_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
...
...
vllm/multimodal/registry.py
View file @
eb28e806
...
@@ -18,7 +18,6 @@ from .processing import (
...
@@ -18,7 +18,6 @@ from .processing import (
from
.profiling
import
(
from
.profiling
import
(
BaseDummyInputsBuilder
,
BaseDummyInputsBuilder
,
DummyDecoderData
,
DummyDecoderData
,
DummyEncoderData
,
MultiModalProfiler
,
MultiModalProfiler
,
)
)
...
@@ -317,43 +316,6 @@ class MultiModalRegistry:
...
@@ -317,43 +316,6 @@ class MultiModalRegistry:
return
dummy_data
return
dummy_data
def
get_encoder_dummy_data
(
self
,
model_config
:
"ModelConfig"
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]
|
None
=
None
,
*
,
cache
:
BaseMultiModalProcessorCache
|
None
=
None
,
observability_config
:
ObservabilityConfig
|
None
=
None
,
)
->
DummyEncoderData
:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
"""
processor
=
self
.
create_processor
(
model_config
,
observability_config
,
cache
=
cache
)
profiler
:
MultiModalProfiler
=
MultiModalProfiler
(
processor
)
# Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged.
mm_options
=
self
.
_extract_mm_options
(
model_config
)
dummy_data
=
profiler
.
get_encoder_dummy_data
(
seq_len
,
mm_counts
,
mm_options
)
# Having more tokens is over-conservative but otherwise fine
token_ids
=
dummy_data
.
prompt_token_ids
if
len
(
token_ids
)
<
seq_len
:
logger
.
warning_once
(
"Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead."
,
# noqa: E501
seq_len
,
len
(
token_ids
),
)
return
dummy_data
def
get_encdec_max_encoder_len
(
self
,
model_config
:
"ModelConfig"
)
->
int
:
def
get_encdec_max_encoder_len
(
self
,
model_config
:
"ModelConfig"
)
->
int
:
"""
"""
Get the maximum length of the encoder input for encoder-decoder models.
Get the maximum length of the encoder input for encoder-decoder models.
...
...
vllm/v1/engine/input_processor.py
View file @
eb28e806
...
@@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...
@@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from
vllm.multimodal.cache
import
processor_cache_from_config
from
vllm.multimodal.cache
import
processor_cache_from_config
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalUUIDDict
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalUUIDDict
from
vllm.multimodal.parse
import
MultiModalDataParser
from
vllm.multimodal.parse
import
MultiModalDataParser
from
vllm.multimodal.processing
import
EncDecMultiModalProcessor
,
set_request_id
from
vllm.multimodal.processing
import
set_request_id
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
_SAMPLING_EPS
,
SamplingParams
from
vllm.sampling_params
import
_SAMPLING_EPS
,
SamplingParams
...
@@ -655,17 +655,18 @@ class InputProcessor:
...
@@ -655,17 +655,18 @@ class InputProcessor:
max_prompt_len
=
self
.
model_config
.
max_model_len
max_prompt_len
=
self
.
model_config
.
max_model_len
if
prompt_len
>
max_prompt_len
:
if
prompt_len
>
max_prompt_len
:
if
prompt_type
==
"encoder"
and
model_config
.
is_multimodal_model
:
if
model_config
.
is_multimodal_model
:
mm_registry
=
self
.
input_preprocessor
.
mm_registry
mm_registry
=
self
.
input_preprocessor
.
mm_registry
mm_processor
=
mm_registry
.
create_processor
(
model_cls
=
mm_registry
.
_get_model_cls
(
model_config
)
factories
=
model_cls
.
_processor_factory
ctx
=
mm_registry
.
_create_processing_ctx
(
model_config
,
model_config
,
self
.
vllm_config
.
observability_config
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
)
)
assert
isinstance
(
mm_processor
,
EncDecMultiModalProcessor
)
mm_info
=
factories
.
info
(
ctx
)
if
mm_
processor
.
pad_dummy_encoder_prompt
:
if
mm_
info
.
skip_prompt_length_check
:
return
# Skip encoder length check for Whisper
return
if
model_config
.
is_multimodal_model
:
if
model_config
.
is_multimodal_model
:
suggestion
=
(
suggestion
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment