Unverified commit c87eac18, authored by Cyrus Leung and committed by GitHub
Browse files

[Refactor] Move MM item count validation outside of processor (#33396)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f45870b5
......@@ -921,7 +921,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
)
processor = MULTIMODAL_REGISTRY.create_processor(model_config)
processor._supported_mm_limits = {"image": num_supported}
processor.info.get_supported_mm_limits = lambda: {"image": num_supported}
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
......
......@@ -528,7 +528,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else:
num_items = len(self._items_by_modality[original_modality]) + 1
self.mm_processor.validate_num_items(input_modality, num_items)
self.mm_processor.info.validate_num_items(input_modality, num_items)
# Track original modality for vision_chunk items
if use_vision_chunk:
......
......@@ -176,9 +176,7 @@ class LoRAModelManager:
)
mm_budget = MultiModalBudget(vllm_config, mm_registry)
limit_per_prompt: int = max(
self.mm_processor_info.get_allowed_mm_limits().values()
)
limit_per_prompt = max(self.mm_processor_info.allowed_mm_limits.values())
num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
mm_budget.get_encoder_budget()
)
......
......@@ -7,11 +7,8 @@ from abc import abstractmethod
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Any,
overload,
)
from functools import cached_property
from typing import TYPE_CHECKING, Any, overload
import torch
from typing_extensions import TypeVar
......@@ -615,13 +612,18 @@ class BaseProcessingInfo:
"""
raise NotImplementedError
def get_allowed_mm_limits(self) -> Mapping[str, int]:
"""Return the maximum allowed number of items for each modality."""
supported_mm_limits = self.get_supported_mm_limits()
@cached_property
def supported_mm_limits(self) -> Mapping[str, int | None]:
"""The maximum supported number of items for each modality."""
return self.get_supported_mm_limits()
@cached_property
def allowed_mm_limits(self) -> Mapping[str, int]:
"""The maximum allowed number of items for each modality."""
mm_config = self.ctx.get_mm_config()
allowed_limits = dict[str, int]()
for modality, supported_limit in supported_mm_limits.items():
for modality, supported_limit in self.supported_mm_limits.items():
user_limit = mm_config.get_limit_per_prompt(modality)
allowed_limits[modality] = (
......@@ -632,6 +634,27 @@ class BaseProcessingInfo:
return allowed_limits
def validate_num_items(self, modality: str, num_items: int) -> None:
    """
    Check `num_items` against the per-prompt limits for `modality`.

    The effective limit is the smaller of the supported and the allowed
    limit. A modality missing from either mapping is treated as having a
    limit of 0, and a supported limit of `None` is replaced by the allowed
    limit before comparing.

    Raises:
        ValueError: If `num_items` exceeds the effective limit.
    """
    max_supported = self.supported_mm_limits.get(modality, 0)
    max_allowed = self.allowed_mm_limits.get(modality, 0)
    if max_supported is None:
        max_supported = max_allowed

    effective_limit = min(max_supported, max_allowed)
    if num_items <= effective_limit:
        return

    # Only point at the CLI flag when the request still fits within the
    # supported limit, i.e. when raising the configured limit could help.
    hint = (
        " Set `--limit-mm-per-prompt` to increase this limit."
        if num_items <= max_supported
        else ""
    )
    raise ValueError(
        f"At most {effective_limit} {modality}(s) may be provided in one prompt."
        + hint
    )
def get_mm_max_tokens_per_item(
self,
seq_len: int,
......
......@@ -17,7 +17,7 @@ from typing import (
import regex as re
import torch
from typing_extensions import TypeVar, assert_never
from typing_extensions import TypeVar, assert_never, deprecated
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
......@@ -1000,17 +1000,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else:
self.data_parser = self.info.get_data_parser()
# Avoid unnecessary recomputation
self._supported_mm_limits = self.info.get_supported_mm_limits()
self._allowed_mm_limits = self.info.get_allowed_mm_limits()
@property
@deprecated("Will be removed in v0.17. Use `info.supported_mm_limits` instead.")
def supported_mm_limits(self):
return self._supported_mm_limits
return self.info.supported_mm_limits
@property
@deprecated("Will be removed in v0.17. Use `info.allowed_mm_limits` instead.")
def allowed_mm_limits(self):
return self._allowed_mm_limits
return self.info.allowed_mm_limits
def __call__(
self,
......@@ -1022,27 +1020,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) -> MultiModalInputs:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids)
def validate_num_items(
    self,
    modality: str,
    num_items: int,
) -> None:
    """
    Raise `ValueError` when `num_items` is above the limit for `modality`.

    Limits are read from `self.supported_mm_limits` and
    `self.allowed_mm_limits`; a modality absent from a mapping counts as 0.
    """
    supported = self.supported_mm_limits.get(modality, 0)
    allowed = self.allowed_mm_limits.get(modality, 0)

    # A supported limit of `None` is replaced by the allowed limit.
    supported_cap = allowed if supported is None else supported
    cap = min(supported_cap, allowed)

    if num_items > cap:
        suffix = ""
        if num_items <= supported_cap:
            # Still within the supported limit, so the CLI flag is relevant.
            suffix = " Set `--limit-mm-per-prompt` to increase this limit."
        raise ValueError(
            f"At most {cap} {modality}(s) may be provided in one prompt." + suffix
        )
def _to_mm_items(
self,
mm_data: MultiModalDataDict,
......@@ -1066,7 +1043,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
for modality, items in mm_items.items():
self.validate_num_items(modality, len(items))
self.info.validate_num_items(modality, len(items))
return mm_items
......
......@@ -168,7 +168,7 @@ class MultiModalRegistry:
)
if profiler_limits is None:
profiler_limits = processor.allowed_mm_limits
profiler_limits = processor.info.allowed_mm_limits
mm_counts = {
modality: 1 for modality, limit in profiler_limits.items() if limit > 0
......@@ -200,7 +200,6 @@ class MultiModalRegistry:
self,
model_config: "ModelConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]:
"""
......@@ -210,10 +209,8 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(
model_config, observability_config, cache=cache
)
return processor.allowed_mm_limits
info = self._create_processing_info(model_config, observability_config)
return info.allowed_mm_limits
def register_processor(
self,
......@@ -324,7 +321,7 @@ class MultiModalRegistry:
model_config, observability_config, cache=cache
)
if mm_counts is None:
mm_counts = processor.allowed_mm_limits
mm_counts = processor.info.allowed_mm_limits
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
seq_len=seq_len,
......
......@@ -40,7 +40,7 @@ class MultiModalBudget:
self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
model_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment