Unverified Commit 83b824c8 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Remove `BaseProcessingInfo.get_mm_max_tokens_per_item` (#16408)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 7678fcd5
...@@ -162,13 +162,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): ...@@ -162,13 +162,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
...@@ -186,14 +179,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): ...@@ -186,14 +179,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
width = height = vision_encoder_info.get_image_size() width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height) return ImageSize(width=width, height=height)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo) _I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
......
...@@ -106,16 +106,6 @@ class MllamaProcessingInfo(BaseProcessingInfo): ...@@ -106,16 +106,6 @@ class MllamaProcessingInfo(BaseProcessingInfo):
image_size = self.get_hf_config().vision_config.image_size image_size = self.get_hf_config().vision_config.image_size
return calc_token_per_chunk(image_size) return calc_token_per_chunk(image_size)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
token_per_chunk = self.get_token_per_chunk_from_config()
mm_max_tokens = vision_config.max_num_tiles * token_per_chunk
return {"image": mm_max_tokens}
def get_num_tiles_per_image(self, image_height: int, def get_num_tiles_per_image(self, image_height: int,
image_width: int) -> int: image_width: int) -> int:
vision_config = self.get_hf_config().vision_config vision_config = self.get_hf_config().vision_config
......
...@@ -498,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): ...@@ -498,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
image_processor = self.get_hf_processor().image_processor image_processor = self.get_hf_processor().image_processor
return image_processor.max_patches return image_processor.max_patches
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
patch_per_chunk = self.get_patch_per_chunk(vision_config)
num_patches = self.get_max_num_tiles() + 1
return {"image": patch_per_chunk * num_patches}
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config vision_config = self.get_hf_config().vision_config
image_size = vision_config.image_size image_size = vision_config.image_size
...@@ -516,14 +505,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): ...@@ -516,14 +505,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
return ImageSize(height=self.get_max_num_tiles() * image_size, return ImageSize(height=self.get_max_num_tiles() * image_size,
width=image_size) width=image_size)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
): ):
......
...@@ -1164,13 +1164,6 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1164,13 +1164,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
...@@ -1195,15 +1188,6 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1195,15 +1188,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return extra + joint return extra + joint
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
......
...@@ -13,7 +13,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -13,7 +13,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs) MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets, BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptUpdate, PromptInsertion, PromptUpdate,
...@@ -72,16 +73,18 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo): ...@@ -72,16 +73,18 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1} return {"image": 1}
def get_mm_max_tokens_per_item( def get_num_image_tokens(
self, self,
seq_len: int, *,
mm_counts: Mapping[str, int], image_width: int,
) -> Mapping[str, int]: image_height: int,
return {"image": self.get_num_image_tokens()} ) -> int:
def get_num_image_tokens(self) -> int:
vision_encoder_info = self.get_vision_encoder_info() vision_encoder_info = self.get_vision_encoder_info()
return vision_encoder_info.get_max_image_tokens()
return vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
)
class PaliGemmaDummyInputsBuilder( class PaliGemmaDummyInputsBuilder(
...@@ -148,12 +151,30 @@ class PaliGemmaMultiModalProcessor( ...@@ -148,12 +151,30 @@ class PaliGemmaMultiModalProcessor(
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens
bos_token_id = tokenizer.bos_token_id bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int) assert isinstance(bos_token_id, int)
def get_insertion(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
image_tokens = [image_token_id] * num_image_tokens
return PromptUpdateDetails.select_token_id(
image_tokens + [bos_token_id],
embed_token_id=image_token_id,
)
# Paligemma 1 and 2 have different tokenizer.add_bos_token # Paligemma 1 and 2 have different tokenizer.add_bos_token
# Insert <image>*n + <bos> after <bos> for Paligemma 1 # Insert <image>*n + <bos> after <bos> for Paligemma 1
# Insert <image>*n + <bos> for Paligemma 2 # Insert <image>*n + <bos> for Paligemma 2
...@@ -162,10 +183,7 @@ class PaliGemmaMultiModalProcessor( ...@@ -162,10 +183,7 @@ class PaliGemmaMultiModalProcessor(
modality="image", modality="image",
target=PromptIndexTargets.prefix( target=PromptIndexTargets.prefix(
[bos_token_id] if tokenizer.add_bos_token else []), [bos_token_id] if tokenizer.add_bos_token else []),
insertion=PromptUpdateDetails.select_token_id( insertion=get_insertion,
image_tokens + [bos_token_id],
embed_token_id=image_token_id,
),
) )
] ]
......
...@@ -321,21 +321,6 @@ class Phi3VProcessingInfo(BaseProcessingInfo): ...@@ -321,21 +321,6 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
return {"image": max_image_tokens}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
......
...@@ -167,13 +167,6 @@ class PixtralProcessingInfo(BaseProcessingInfo): ...@@ -167,13 +167,6 @@ class PixtralProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_vision_config( def get_vision_config(
self, self,
processor: Optional[PixtralProcessorAdapter] = None, processor: Optional[PixtralProcessorAdapter] = None,
...@@ -207,14 +200,6 @@ class PixtralProcessingInfo(BaseProcessingInfo): ...@@ -207,14 +200,6 @@ class PixtralProcessingInfo(BaseProcessingInfo):
return ImageSize(width=max_image_size, height=max_image_size) return ImageSize(width=max_image_size, height=max_image_size)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
...@@ -938,14 +923,6 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): ...@@ -938,14 +923,6 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
) )
return ncols * nrows return ncols * nrows
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size()
return self.get_num_image_tokens(
image_width=image_size,
image_height=image_size,
)
def get_image_size(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size return self.vision_config.image_size
......
...@@ -45,9 +45,6 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): ...@@ -45,9 +45,6 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": 0}
class PrithviGeoSpatialMAEInputBuilder( class PrithviGeoSpatialMAEInputBuilder(
BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]): BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
......
...@@ -109,17 +109,6 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo): ...@@ -109,17 +109,6 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_config = self.get_hf_config()
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
return {"audio": max_output_lengths}
class Qwen2AudioDummyInputsBuilder( class Qwen2AudioDummyInputsBuilder(
BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
......
...@@ -818,16 +818,6 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -818,16 +818,6 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len, mm_counts),
}
def _get_vision_info( def _get_vision_info(
self, self,
*, *,
......
...@@ -530,13 +530,6 @@ class QwenVLProcessingInfo(BaseProcessingInfo): ...@@ -530,13 +530,6 @@ class QwenVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int: def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vision_config = hf_config.visual vision_config = hf_config.visual
......
...@@ -33,9 +33,6 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): ...@@ -33,9 +33,6 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
) -> int: ) -> int:
return self.get_patch_grid_length()**2 return self.get_patch_grid_length()**2
def get_max_image_tokens(self) -> int:
return self.get_patch_grid_length()**2
def get_image_size(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size return self.vision_config.image_size
......
...@@ -459,13 +459,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): ...@@ -459,13 +459,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
...@@ -481,15 +474,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): ...@@ -481,15 +474,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
image_height=image_height, image_height=image_height,
) )
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
...@@ -107,17 +106,6 @@ class UltravoxProcessingInfo(BaseProcessingInfo): ...@@ -107,17 +106,6 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
feature_extractor = self.get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)
return {"audio": max_audio_tokens * _MAX_ENCODER_BATCH_SIZE}
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
): ):
......
...@@ -33,10 +33,6 @@ class VisionEncoderInfo(ABC, Generic[_C]): ...@@ -33,10 +33,6 @@ class VisionEncoderInfo(ABC, Generic[_C]):
) -> int: ) -> int:
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_max_image_tokens(self) -> int:
raise NotImplementedError
@abstractmethod @abstractmethod
def get_image_size(self) -> int: def get_image_size(self) -> int:
raise NotImplementedError raise NotImplementedError
......
...@@ -538,16 +538,9 @@ class WhisperProcessingInfo(BaseProcessingInfo): ...@@ -538,16 +538,9 @@ class WhisperProcessingInfo(BaseProcessingInfo):
assert isinstance(feature_extractor, WhisperFeatureExtractor) assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor return feature_extractor
def get_max_audio_tokens(self) -> int: def get_num_audio_tokens(self) -> int:
return self.get_hf_config().max_source_positions return self.get_hf_config().max_source_positions
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"audio": self.get_max_audio_tokens()}
class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
...@@ -630,7 +623,7 @@ class WhisperMultiModalProcessor( ...@@ -630,7 +623,7 @@ class WhisperMultiModalProcessor(
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
num_tokens = self.info.get_max_audio_tokens() num_tokens = self.info.get_num_audio_tokens()
return [ return [
PromptReplacement( PromptReplacement(
modality="audio", modality="audio",
......
...@@ -1034,21 +1034,6 @@ class BaseProcessingInfo: ...@@ -1034,21 +1034,6 @@ class BaseProcessingInfo:
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
_I = TypeVar("_I", bound=BaseProcessingInfo) _I = TypeVar("_I", bound=BaseProcessingInfo)
......
...@@ -68,7 +68,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -68,7 +68,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
Build the input which, after processing, results in Build the input which, after processing, results in
:code:`self.info.get_mm_max_tokens_per_item()` placeholder tokens. the maximum possible number of placeholder tokens.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -152,8 +152,11 @@ class MultiModalProfiler(Generic[_I]): ...@@ -152,8 +152,11 @@ class MultiModalProfiler(Generic[_I]):
def _get_dummy_mm_inputs( def _get_dummy_mm_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Optional[Mapping[str, int]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if mm_counts is None:
mm_counts = self.get_mm_limits()
factory = self.dummy_inputs factory = self.dummy_inputs
processor_inputs = factory.get_dummy_processor_inputs( processor_inputs = factory.get_dummy_processor_inputs(
seq_len, mm_counts) seq_len, mm_counts)
...@@ -164,53 +167,23 @@ class MultiModalProfiler(Generic[_I]): ...@@ -164,53 +167,23 @@ class MultiModalProfiler(Generic[_I]):
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
) )
def get_and_validate_mm_inputs( def _get_mm_num_tokens(
self, self,
seq_len: int, mm_inputs: MultiModalInputs,
mm_counts: Optional[Mapping[str, int]] = None, ) -> Mapping[str, int]:
) -> tuple[MultiModalInputs, Mapping[str, int]]:
if mm_counts is None:
mm_counts = self.get_mm_limits()
info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
seq_len, mm_counts)
if mm_counts.keys() - mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits` "
f"({set(mm_counts.keys())}) should be a subset of those "
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
placeholders_by_modality = mm_inputs["mm_placeholders"] placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = { return {
modality: sum(item.get_num_embeds() for item in placeholders) modality: sum(item.get_num_embeds() for item in placeholders)
for modality, placeholders in placeholders_by_modality.items() for modality, placeholders in placeholders_by_modality.items()
} }
expected_placeholders_by_modality = {
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
raise AssertionError(
f"The processed dummy data has a total of "
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
return mm_inputs, total_placeholders_by_modality
def get_encoder_dummy_data( def get_encoder_dummy_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None, mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyEncoderData: ) -> DummyEncoderData:
( mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
mm_inputs,
total_placeholders_by_modality,
) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs) mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
# For encoder-decoder models, use encoder prompt token ids instead of # For encoder-decoder models, use encoder prompt token ids instead of
...@@ -232,7 +205,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -232,7 +205,7 @@ class MultiModalProfiler(Generic[_I]):
" is too short " " is too short "
"to hold the multi-modal embeddings in the worst case " "to hold the multi-modal embeddings in the worst case "
f"({total_len} tokens in total, out of which " f"({total_len} tokens in total, out of which "
f"{total_placeholders_by_modality} are reserved for " f"{self._get_mm_num_tokens(mm_inputs)} are reserved for "
"multi-modal embeddings). This may cause certain " "multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when " "multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should " "the input text is short. To avoid this, you should "
...@@ -246,10 +219,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -246,10 +219,7 @@ class MultiModalProfiler(Generic[_I]):
seq_len: int, seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None, mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyDecoderData: ) -> DummyDecoderData:
( mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
mm_inputs,
total_placeholders_by_modality,
) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids) total_len = len(prompt_token_ids)
...@@ -263,7 +233,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -263,7 +233,7 @@ class MultiModalProfiler(Generic[_I]):
"is too short " "is too short "
"to hold the multi-modal embeddings in the worst case " "to hold the multi-modal embeddings in the worst case "
f"({total_len} tokens in total, out of which " f"({total_len} tokens in total, out of which "
f"{total_placeholders_by_modality} are reserved for " f"{self._get_mm_num_tokens(mm_inputs)} are reserved for "
"multi-modal embeddings). This may cause certain " "multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when " "multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should " "the input text is short. To avoid this, you should "
...@@ -278,3 +248,12 @@ class MultiModalProfiler(Generic[_I]): ...@@ -278,3 +248,12 @@ class MultiModalProfiler(Generic[_I]):
multi_modal_data=mm_inputs["mm_kwargs"], multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=mm_inputs["mm_placeholders"], multi_modal_placeholders=mm_inputs["mm_placeholders"],
) )
def get_mm_max_tokens(
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> Mapping[str, int]:
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs)
...@@ -258,10 +258,16 @@ class MultiModalRegistry: ...@@ -258,10 +258,16 @@ class MultiModalRegistry:
""" """
if self.has_processor(model_config): if self.has_processor(model_config):
processor = self.create_processor(model_config, disable_cache=True) processor = self.create_processor(model_config, disable_cache=True)
profiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config) mm_limits = self.get_mm_limits_per_prompt(model_config)
return processor.info.get_mm_max_tokens_per_item(
seq_len, mm_limits) return profiler.get_mm_max_tokens(
seq_len,
{modality: 1
for modality in mm_limits},
)
return { return {
key: plugin.get_max_multimodal_tokens(model_config) key: plugin.get_max_multimodal_tokens(model_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment