Unverified Commit eb28e806 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Remove `get_encoder_dummy_data` (#32241)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 542a4059
...@@ -605,6 +605,10 @@ class NemotronParseProcessingInfo(BaseProcessingInfo): ...@@ -605,6 +605,10 @@ class NemotronParseProcessingInfo(BaseProcessingInfo):
**kwargs, **kwargs,
) )
@property
def skip_prompt_length_check(self) -> bool:
return True # Because the encoder prompt is padded
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": 1} return {"image": 1}
...@@ -657,10 +661,6 @@ class NemotronParseMultiModalProcessor( ...@@ -657,10 +661,6 @@ class NemotronParseMultiModalProcessor(
) -> str | list[int]: ) -> str | list[int]:
return [0] return [0]
@property
def pad_dummy_encoder_prompt(self) -> bool:
return True
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
......
...@@ -681,6 +681,10 @@ class WhisperProcessingInfo(BaseProcessingInfo): ...@@ -681,6 +681,10 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> WhisperConfig: def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig) return self.ctx.get_hf_config(WhisperConfig)
@property
def skip_prompt_length_check(self) -> bool:
return True # Because the encoder prompt is padded
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 1} return {"audio": 1}
...@@ -733,10 +737,6 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo ...@@ -733,10 +737,6 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
target_channels=self.info.get_target_channels(), target_channels=self.info.get_target_channels(),
) )
@property
def pad_dummy_encoder_prompt(self) -> bool:
return True
def create_encoder_prompt( def create_encoder_prompt(
self, self,
prompt: str | list[int], prompt: str | list[int],
......
...@@ -1396,6 +1396,10 @@ class BaseProcessingInfo: ...@@ -1396,6 +1396,10 @@ class BaseProcessingInfo:
""" """
return self.ctx.get_hf_processor(**kwargs) return self.ctx.get_hf_processor(**kwargs)
@property
def skip_prompt_length_check(self) -> bool:
return False
@abstractmethod @abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
""" """
...@@ -2403,10 +2407,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -2403,10 +2407,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
""" """
raise NotImplementedError raise NotImplementedError
@property
def pad_dummy_encoder_prompt(self) -> bool:
return False
def create_decoder_prompt( def create_decoder_prompt(
self, self,
prompt: str | list[int], prompt: str | list[int],
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Generic, NamedTuple, TypeVar, cast from typing import Generic, NamedTuple, TypeVar
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -19,7 +19,6 @@ from vllm.logger import init_logger ...@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from .inputs import ( from .inputs import (
MultiModalDataDict, MultiModalDataDict,
MultiModalEncDecInputs,
MultiModalInputs, MultiModalInputs,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalPlaceholderDict, MultiModalPlaceholderDict,
...@@ -27,7 +26,6 @@ from .inputs import ( ...@@ -27,7 +26,6 @@ from .inputs import (
from .processing import ( from .processing import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
EncDecMultiModalProcessor,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -282,28 +280,6 @@ class MultiModalProfiler(Generic[_I]): ...@@ -282,28 +280,6 @@ class MultiModalProfiler(Generic[_I]):
for modality, placeholders in placeholders_by_modality.items() for modality, placeholders in placeholders_by_modality.items()
} }
def get_encoder_dummy_data(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> DummyEncoderData:
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
total_len = len(encoder_prompt_token_ids)
processor = cast(EncDecMultiModalProcessor, self.processor)
if processor.pad_dummy_encoder_prompt:
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
return DummyEncoderData(encoder_prompt_token_ids)
def get_decoder_dummy_data( def get_decoder_dummy_data(
self, self,
seq_len: int, seq_len: int,
......
...@@ -18,7 +18,6 @@ from .processing import ( ...@@ -18,7 +18,6 @@ from .processing import (
from .profiling import ( from .profiling import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
DummyDecoderData, DummyDecoderData,
DummyEncoderData,
MultiModalProfiler, MultiModalProfiler,
) )
...@@ -317,43 +316,6 @@ class MultiModalRegistry: ...@@ -317,43 +316,6 @@ class MultiModalRegistry:
return dummy_data return dummy_data
def get_encoder_dummy_data(
self,
model_config: "ModelConfig",
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
*,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> DummyEncoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
"""
processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options)
# Having more tokens is over-conservative but otherwise fine
token_ids = dummy_data.prompt_token_ids
if len(token_ids) < seq_len:
logger.warning_once(
"Expected at least %d dummy encoder tokens for profiling, but found %d tokens instead.", # noqa: E501
seq_len,
len(token_ids),
)
return dummy_data
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
""" """
Get the maximum length of the encoder input for encoder-decoder models. Get the maximum length of the encoder input for encoder-decoder models.
......
...@@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry ...@@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import EncDecMultiModalProcessor, set_request_id from vllm.multimodal.processing import set_request_id
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
...@@ -655,17 +655,18 @@ class InputProcessor: ...@@ -655,17 +655,18 @@ class InputProcessor:
max_prompt_len = self.model_config.max_model_len max_prompt_len = self.model_config.max_model_len
if prompt_len > max_prompt_len: if prompt_len > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model: if model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor( model_cls = mm_registry._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = mm_registry._create_processing_ctx(
model_config, model_config,
self.vllm_config.observability_config,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
assert isinstance(mm_processor, EncDecMultiModalProcessor) mm_info = factories.info(ctx)
if mm_processor.pad_dummy_encoder_prompt: if mm_info.skip_prompt_length_check:
return # Skip encoder length check for Whisper return
if model_config.is_multimodal_model: if model_config.is_multimodal_model:
suggestion = ( suggestion = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment