Unverified commit c87eac18, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Move MM item count validation outside of processor (#33396)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f45870b5
...@@ -921,7 +921,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): ...@@ -921,7 +921,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
) )
processor = MULTIMODAL_REGISTRY.create_processor(model_config) processor = MULTIMODAL_REGISTRY.create_processor(model_config)
processor._supported_mm_limits = {"image": num_supported} processor.info.get_supported_mm_limits = lambda: {"image": num_supported}
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most") exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
......
...@@ -528,7 +528,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -528,7 +528,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else: else:
num_items = len(self._items_by_modality[original_modality]) + 1 num_items = len(self._items_by_modality[original_modality]) + 1
self.mm_processor.validate_num_items(input_modality, num_items) self.mm_processor.info.validate_num_items(input_modality, num_items)
# Track original modality for vision_chunk items # Track original modality for vision_chunk items
if use_vision_chunk: if use_vision_chunk:
......
...@@ -176,9 +176,7 @@ class LoRAModelManager: ...@@ -176,9 +176,7 @@ class LoRAModelManager:
) )
mm_budget = MultiModalBudget(vllm_config, mm_registry) mm_budget = MultiModalBudget(vllm_config, mm_registry)
limit_per_prompt: int = max( limit_per_prompt = max(self.mm_processor_info.allowed_mm_limits.values())
self.mm_processor_info.get_allowed_mm_limits().values()
)
num_encoder_tokens = self.model.get_num_mm_encoder_tokens( num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
mm_budget.get_encoder_budget() mm_budget.get_encoder_budget()
) )
......
...@@ -7,11 +7,8 @@ from abc import abstractmethod ...@@ -7,11 +7,8 @@ from abc import abstractmethod
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import ( from functools import cached_property
TYPE_CHECKING, from typing import TYPE_CHECKING, Any, overload
Any,
overload,
)
import torch import torch
from typing_extensions import TypeVar from typing_extensions import TypeVar
...@@ -615,13 +612,18 @@ class BaseProcessingInfo: ...@@ -615,13 +612,18 @@ class BaseProcessingInfo:
""" """
raise NotImplementedError raise NotImplementedError
def get_allowed_mm_limits(self) -> Mapping[str, int]: @cached_property
"""Return the maximum allowed number of items for each modality.""" def supported_mm_limits(self) -> Mapping[str, int | None]:
supported_mm_limits = self.get_supported_mm_limits() """The maximum supported number of items for each modality."""
return self.get_supported_mm_limits()
@cached_property
def allowed_mm_limits(self) -> Mapping[str, int]:
"""The maximum allowed number of items for each modality."""
mm_config = self.ctx.get_mm_config() mm_config = self.ctx.get_mm_config()
allowed_limits = dict[str, int]() allowed_limits = dict[str, int]()
for modality, supported_limit in supported_mm_limits.items(): for modality, supported_limit in self.supported_mm_limits.items():
user_limit = mm_config.get_limit_per_prompt(modality) user_limit = mm_config.get_limit_per_prompt(modality)
allowed_limits[modality] = ( allowed_limits[modality] = (
...@@ -632,6 +634,27 @@ class BaseProcessingInfo: ...@@ -632,6 +634,27 @@ class BaseProcessingInfo:
return allowed_limits return allowed_limits
def validate_num_items(self, modality: str, num_items: int) -> None:
"""
Raise `ValueError` if the number of input items for the given modality
is invalid.
"""
supported_limit = self.supported_mm_limits.get(modality, 0)
allowed_limit = self.allowed_mm_limits.get(modality, 0)
if supported_limit is None:
supported_limit = allowed_limit
limit = min(supported_limit, allowed_limit)
if num_items > limit:
msg = f"At most {limit} {modality}(s) may be provided in one prompt."
if num_items <= supported_limit:
msg += " Set `--limit-mm-per-prompt` to increase this limit."
raise ValueError(msg)
def get_mm_max_tokens_per_item( def get_mm_max_tokens_per_item(
self, self,
seq_len: int, seq_len: int,
......
...@@ -17,7 +17,7 @@ from typing import ( ...@@ -17,7 +17,7 @@ from typing import (
import regex as re import regex as re
import torch import torch
from typing_extensions import TypeVar, assert_never from typing_extensions import TypeVar, assert_never, deprecated
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -1000,17 +1000,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1000,17 +1000,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else: else:
self.data_parser = self.info.get_data_parser() self.data_parser = self.info.get_data_parser()
# Avoid unnecessary recomputation
self._supported_mm_limits = self.info.get_supported_mm_limits()
self._allowed_mm_limits = self.info.get_allowed_mm_limits()
@property @property
@deprecated("Will be removed in v0.17. Use `info.supported_mm_limits` instead.")
def supported_mm_limits(self): def supported_mm_limits(self):
return self._supported_mm_limits return self.info.supported_mm_limits
@property @property
@deprecated("Will be removed in v0.17. Use `info.allowed_mm_limits` instead.")
def allowed_mm_limits(self): def allowed_mm_limits(self):
return self._allowed_mm_limits return self.info.allowed_mm_limits
def __call__( def __call__(
self, self,
...@@ -1022,27 +1020,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1022,27 +1020,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) -> MultiModalInputs: ) -> MultiModalInputs:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids) return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids)
def validate_num_items(
self,
modality: str,
num_items: int,
) -> None:
supported_limit = self.supported_mm_limits.get(modality, 0)
allowed_limit = self.allowed_mm_limits.get(modality, 0)
if supported_limit is None:
supported_limit = allowed_limit
limit = min(supported_limit, allowed_limit)
if num_items > limit:
msg = f"At most {limit} {modality}(s) may be provided in one prompt."
if num_items <= supported_limit:
msg += " Set `--limit-mm-per-prompt` to increase this limit."
raise ValueError(msg)
def _to_mm_items( def _to_mm_items(
self, self,
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
...@@ -1066,7 +1043,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1066,7 +1043,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) )
for modality, items in mm_items.items(): for modality, items in mm_items.items():
self.validate_num_items(modality, len(items)) self.info.validate_num_items(modality, len(items))
return mm_items return mm_items
......
...@@ -168,7 +168,7 @@ class MultiModalRegistry: ...@@ -168,7 +168,7 @@ class MultiModalRegistry:
) )
if profiler_limits is None: if profiler_limits is None:
profiler_limits = processor.allowed_mm_limits profiler_limits = processor.info.allowed_mm_limits
mm_counts = { mm_counts = {
modality: 1 for modality, limit in profiler_limits.items() if limit > 0 modality: 1 for modality, limit in profiler_limits.items() if limit > 0
...@@ -200,7 +200,6 @@ class MultiModalRegistry: ...@@ -200,7 +200,6 @@ class MultiModalRegistry:
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
*, *,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None, observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]: ) -> Mapping[str, int]:
""" """
...@@ -210,10 +209,8 @@ class MultiModalRegistry: ...@@ -210,10 +209,8 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
return {} return {}
processor = self.create_processor( info = self._create_processing_info(model_config, observability_config)
model_config, observability_config, cache=cache return info.allowed_mm_limits
)
return processor.allowed_mm_limits
def register_processor( def register_processor(
self, self,
...@@ -324,7 +321,7 @@ class MultiModalRegistry: ...@@ -324,7 +321,7 @@ class MultiModalRegistry:
model_config, observability_config, cache=cache model_config, observability_config, cache=cache
) )
if mm_counts is None: if mm_counts is None:
mm_counts = processor.allowed_mm_limits mm_counts = processor.info.allowed_mm_limits
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs( processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
seq_len=seq_len, seq_len=seq_len,
......
...@@ -40,7 +40,7 @@ class MultiModalBudget: ...@@ -40,7 +40,7 @@ class MultiModalBudget:
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs self.max_num_reqs = scheduler_config.max_num_seqs
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache) self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality( max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
model_config, model_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment