Unverified commit 83b824c8, authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Remove `BaseProcessingInfo.get_mm_max_tokens_per_item` (#16408)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 7678fcd5
......@@ -162,13 +162,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
......@@ -186,14 +179,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
......
......@@ -106,16 +106,6 @@ class MllamaProcessingInfo(BaseProcessingInfo):
image_size = self.get_hf_config().vision_config.image_size
return calc_token_per_chunk(image_size)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
token_per_chunk = self.get_token_per_chunk_from_config()
mm_max_tokens = vision_config.max_num_tiles * token_per_chunk
return {"image": mm_max_tokens}
def get_num_tiles_per_image(self, image_height: int,
image_width: int) -> int:
vision_config = self.get_hf_config().vision_config
......
......@@ -498,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
image_processor = self.get_hf_processor().image_processor
return image_processor.max_patches
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
patch_per_chunk = self.get_patch_per_chunk(vision_config)
num_patches = self.get_max_num_tiles() + 1
return {"image": patch_per_chunk * num_patches}
def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config
image_size = vision_config.image_size
......@@ -516,14 +505,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
return ImageSize(height=self.get_max_num_tiles() * image_size,
width=image_size)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
):
......
......@@ -1164,13 +1164,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
......@@ -1195,15 +1188,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return extra + joint
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
......
......@@ -13,7 +13,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptUpdate,
......@@ -72,16 +73,18 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(
def get_num_image_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
*,
image_width: int,
image_height: int,
) -> int:
vision_encoder_info = self.get_vision_encoder_info()
return vision_encoder_info.get_max_image_tokens()
return vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
)
class PaliGemmaDummyInputsBuilder(
......@@ -148,12 +151,30 @@ class PaliGemmaMultiModalProcessor(
image_token_id = hf_config.image_token_index
tokenizer = self.info.get_tokenizer()
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
def get_insertion(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
image_tokens = [image_token_id] * num_image_tokens
return PromptUpdateDetails.select_token_id(
image_tokens + [bos_token_id],
embed_token_id=image_token_id,
)
# Paligemma 1 and 2 have different tokenizer.add_bos_token
# Insert <image>*n + <bos> after <bos> for Paligemma 1
# Insert <image>*n + <bos> for Paligemma 2
......@@ -162,10 +183,7 @@ class PaliGemmaMultiModalProcessor(
modality="image",
target=PromptIndexTargets.prefix(
[bos_token_id] if tokenizer.add_bos_token else []),
insertion=PromptUpdateDetails.select_token_id(
image_tokens + [bos_token_id],
embed_token_id=image_token_id,
),
insertion=get_insertion,
)
]
......
......@@ -321,21 +321,6 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
return {"image": max_image_tokens}
def get_num_image_tokens(
self,
*,
......
......@@ -167,13 +167,6 @@ class PixtralProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_vision_config(
self,
processor: Optional[PixtralProcessorAdapter] = None,
......@@ -207,14 +200,6 @@ class PixtralProcessingInfo(BaseProcessingInfo):
return ImageSize(width=max_image_size, height=max_image_size)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
......@@ -938,14 +923,6 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
)
return ncols * nrows
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size()
return self.get_num_image_tokens(
image_width=image_size,
image_height=image_size,
)
def get_image_size(self) -> int:
return self.vision_config.image_size
......
......@@ -45,9 +45,6 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": 0}
class PrithviGeoSpatialMAEInputBuilder(
BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
......
......@@ -109,17 +109,6 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_config = self.get_hf_config()
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
return {"audio": max_output_lengths}
class Qwen2AudioDummyInputsBuilder(
BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
......
......@@ -818,16 +818,6 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len, mm_counts),
}
def _get_vision_info(
self,
*,
......
......@@ -530,13 +530,6 @@ class QwenVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
vision_config = hf_config.visual
......
......@@ -33,9 +33,6 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
) -> int:
return self.get_patch_grid_length()**2
def get_max_image_tokens(self) -> int:
return self.get_patch_grid_length()**2
def get_image_size(self) -> int:
return self.vision_config.image_size
......
......@@ -459,13 +459,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
......@@ -481,15 +474,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
image_height=image_height,
)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
......
......@@ -2,7 +2,6 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
......@@ -107,17 +106,6 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
feature_extractor = self.get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)
return {"audio": max_audio_tokens * _MAX_ENCODER_BATCH_SIZE}
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
):
......
......@@ -33,10 +33,6 @@ class VisionEncoderInfo(ABC, Generic[_C]):
) -> int:
raise NotImplementedError
@abstractmethod
def get_max_image_tokens(self) -> int:
raise NotImplementedError
@abstractmethod
def get_image_size(self) -> int:
raise NotImplementedError
......
......@@ -538,16 +538,9 @@ class WhisperProcessingInfo(BaseProcessingInfo):
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_max_audio_tokens(self) -> int:
def get_num_audio_tokens(self) -> int:
return self.get_hf_config().max_source_positions
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"audio": self.get_max_audio_tokens()}
class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
......@@ -630,7 +623,7 @@ class WhisperMultiModalProcessor(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
num_tokens = self.info.get_max_audio_tokens()
num_tokens = self.info.get_num_audio_tokens()
return [
PromptReplacement(
modality="audio",
......
......@@ -1034,21 +1034,6 @@ class BaseProcessingInfo:
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
_I = TypeVar("_I", bound=BaseProcessingInfo)
......
......@@ -68,7 +68,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
) -> ProcessorInputs:
"""
Build the input which, after processing, results in
:code:`self.info.get_mm_max_tokens_per_item()` placeholder tokens.
the maximum possible number of placeholder tokens.
"""
raise NotImplementedError
......@@ -152,8 +152,11 @@ class MultiModalProfiler(Generic[_I]):
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_counts: Optional[Mapping[str, int]] = None,
) -> MultiModalInputs:
if mm_counts is None:
mm_counts = self.get_mm_limits()
factory = self.dummy_inputs
processor_inputs = factory.get_dummy_processor_inputs(
seq_len, mm_counts)
......@@ -164,53 +167,23 @@ class MultiModalProfiler(Generic[_I]):
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_and_validate_mm_inputs(
def _get_mm_num_tokens(
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> tuple[MultiModalInputs, Mapping[str, int]]:
if mm_counts is None:
mm_counts = self.get_mm_limits()
info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
seq_len, mm_counts)
if mm_counts.keys() - mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits` "
f"({set(mm_counts.keys())}) should be a subset of those "
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
mm_inputs: MultiModalInputs,
) -> Mapping[str, int]:
placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = {
return {
modality: sum(item.get_num_embeds() for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
raise AssertionError(
f"The processed dummy data has a total of "
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
return mm_inputs, total_placeholders_by_modality
def get_encoder_dummy_data(
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyEncoderData:
(
mm_inputs,
total_placeholders_by_modality,
) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
# For encoder-decoder models, use encoder prompt token ids instead of
......@@ -232,7 +205,7 @@ class MultiModalProfiler(Generic[_I]):
" is too short "
"to hold the multi-modal embeddings in the worst case "
f"({total_len} tokens in total, out of which "
f"{total_placeholders_by_modality} are reserved for "
f"{self._get_mm_num_tokens(mm_inputs)} are reserved for "
"multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should "
......@@ -246,10 +219,7 @@ class MultiModalProfiler(Generic[_I]):
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyDecoderData:
(
mm_inputs,
total_placeholders_by_modality,
) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids)
......@@ -263,7 +233,7 @@ class MultiModalProfiler(Generic[_I]):
"is too short "
"to hold the multi-modal embeddings in the worst case "
f"({total_len} tokens in total, out of which "
f"{total_placeholders_by_modality} are reserved for "
f"{self._get_mm_num_tokens(mm_inputs)} are reserved for "
"multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should "
......@@ -278,3 +248,12 @@ class MultiModalProfiler(Generic[_I]):
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=mm_inputs["mm_placeholders"],
)
def get_mm_max_tokens(
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> Mapping[str, int]:
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs)
......@@ -258,10 +258,16 @@ class MultiModalRegistry:
"""
if self.has_processor(model_config):
processor = self.create_processor(model_config, disable_cache=True)
profiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config)
return processor.info.get_mm_max_tokens_per_item(
seq_len, mm_limits)
return profiler.get_mm_max_tokens(
seq_len,
{modality: 1
for modality in mm_limits},
)
return {
key: plugin.get_max_multimodal_tokens(model_config)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment