Unverified commit 2a0596bc, authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Reorganize profiling/processing-related code (#11812)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f1214117
...@@ -4,12 +4,13 @@ from collections import defaultdict ...@@ -4,12 +4,13 @@ from collections import defaultdict
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union)
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from vllm import envs import vllm.envs as envs
from vllm.inputs import DummyData, InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens) encode_tokens)
...@@ -20,7 +21,9 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig, ...@@ -20,7 +21,9 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser from .parse import MultiModalDataItems, MultiModalDataParser
from .profiling import BaseProfilingInfo
if TYPE_CHECKING:
from .profiling import BaseDummyInputsBuilder
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -46,8 +49,8 @@ class PromptReplacement: ...@@ -46,8 +49,8 @@ class PromptReplacement:
if it does not depend on the input. if it does not depend on the input.
""" """
def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement": def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
return _BoundPromptReplacement( return BoundPromptReplacement(
tokenizer=tokenizer, tokenizer=tokenizer,
modality=self.modality, modality=self.modality,
_target=self.target, _target=self.target,
...@@ -128,7 +131,7 @@ class _BoundPromptSequence: ...@@ -128,7 +131,7 @@ class _BoundPromptSequence:
@dataclass @dataclass
class _BoundPromptReplacement: class BoundPromptReplacement:
tokenizer: AnyTokenizer = field(repr=False) tokenizer: AnyTokenizer = field(repr=False)
modality: str modality: str
...@@ -207,7 +210,7 @@ def iter_token_matches( ...@@ -207,7 +210,7 @@ def iter_token_matches(
@dataclass(repr=False) @dataclass(repr=False)
class _PromptReplacementMatch(ABC): class _PromptReplacementMatch(ABC):
prompt_repl: _BoundPromptReplacement prompt_repl: BoundPromptReplacement
@property @property
def modality(self) -> str: def modality(self) -> str:
...@@ -255,7 +258,7 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch): ...@@ -255,7 +258,7 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
@dataclass @dataclass
class _PlaceholderInfo: class PlaceholderInfo:
modality: str modality: str
item_idx: int item_idx: int
start_idx: int start_idx: int
...@@ -274,7 +277,7 @@ class _PlaceholderInfo: ...@@ -274,7 +277,7 @@ class _PlaceholderInfo:
def find_token_matches( def find_token_matches(
prompt: list[int], prompt: list[int],
prompt_repls: Sequence[_BoundPromptReplacement], prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]: ) -> list[_PromptReplacementTokenMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`.""" """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [ return [
...@@ -286,7 +289,7 @@ def find_token_matches( ...@@ -286,7 +289,7 @@ def find_token_matches(
def find_text_matches( def find_text_matches(
prompt: str, prompt: str,
prompt_repls: Sequence[_BoundPromptReplacement], prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]: ) -> list[_PromptReplacementTextMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`.""" """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [ return [
...@@ -390,9 +393,9 @@ def replace_text_matches( ...@@ -390,9 +393,9 @@ def replace_text_matches(
def _iter_modality_placeholders( def _iter_modality_placeholders(
prompt: list[int], prompt: list[int],
modality: str, modality: str,
modality_repls: Sequence[_BoundPromptReplacement], modality_repls: Sequence[BoundPromptReplacement],
modal_item_count: int, modal_item_count: int,
) -> Iterable[_PlaceholderInfo]: ) -> Iterable[PlaceholderInfo]:
if modal_item_count == 0: if modal_item_count == 0:
return return
...@@ -413,7 +416,7 @@ def _iter_modality_placeholders( ...@@ -413,7 +416,7 @@ def _iter_modality_placeholders(
continue continue
if prompt[start_idx:end_idx] == repl_tokens: if prompt[start_idx:end_idx] == repl_tokens:
yield _PlaceholderInfo( yield PlaceholderInfo(
modality=modality, modality=modality,
item_idx=item_idx, item_idx=item_idx,
start_idx=start_idx, start_idx=start_idx,
...@@ -434,10 +437,10 @@ def _iter_modality_placeholders( ...@@ -434,10 +437,10 @@ def _iter_modality_placeholders(
def _iter_placeholders( def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int], prompt: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]: ) -> Iterable[PlaceholderInfo]:
""" """
For each modality, yield each set of placeholder tokens found in For each modality, yield each set of placeholder tokens found in
:code:`prompt`. :code:`prompt`.
...@@ -455,10 +458,10 @@ def _iter_placeholders( ...@@ -455,10 +458,10 @@ def _iter_placeholders(
def find_mm_placeholders( def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int], prompt: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]: ) -> Mapping[str, list[PlaceholderInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts) it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
...@@ -524,29 +527,59 @@ class ProcessingCache: ...@@ -524,29 +527,59 @@ class ProcessingCache:
self._cache.put(cache_key, output_kwargs) self._cache.put(cache_key, output_kwargs)
class ProcessingMixin: class BaseProcessingInfo:
""" """Base class containing information to perform processing."""
Contains helper functions to perform processing.
Not to be confused with :class:`transformers.ProcessorMixin`. def __init__(self, ctx: InputProcessingContext) -> None:
""" super().__init__()
ctx: InputProcessingContext
def _get_tokenizer(self) -> AnyTokenizer: self.ctx = ctx
@property
def model_id(self) -> str:
    """
    Identifier of the model, read from ``self.ctx.model_config.model``.
    """
    model_config = self.ctx.model_config
    return model_config.model
def get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer return self.ctx.tokenizer
def _get_hf_config(self) -> PretrainedConfig: def get_hf_config(self) -> PretrainedConfig:
return self.ctx.get_hf_config() return self.ctx.get_hf_config()
def _get_hf_processor(self, **kwargs: object) -> ProcessorMixin: def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
""" """
Subclasses can override this method to handle Subclasses can override this method to handle
specific kwargs from model config or user inputs. specific kwargs from model config or user inputs.
""" """
return self.ctx.get_hf_processor(**kwargs) return self.ctx.get_hf_processor(**kwargs)
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """
    Report, for each modality, the largest number of items that is
    supported.

    Use `None` as the value when the number of items is unlimited.
    A modality that is left out of the returned dictionary is treated
    as not supported at all.
    """
    # Subclasses must provide the actual limits.
    raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
    """
    Report, for each modality, the largest number of tokens that a single
    data item can occupy.

    The returned dictionary is expected to use exactly the keys that
    :meth:`get_supported_mm_limits` returns.
    """
    # Subclasses must provide the actual token counts.
    raise NotImplementedError
_I = TypeVar("_I", bound=BaseProcessingInfo)
class BaseMultiModalProcessor(ProcessingMixin, ABC):
class BaseMultiModalProcessor(ABC, Generic[_I]):
""" """
Abstract base class to process multi-modal inputs to be used in vLLM. Abstract base class to process multi-modal inputs to be used in vLLM.
...@@ -554,18 +587,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -554,18 +587,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
""" """
def __init__(self, def __init__(self,
ctx: InputProcessingContext, info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]",
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None: enable_sanity_checks: bool = True) -> None:
super().__init__() super().__init__()
self.ctx = ctx self.info = info
self.dummy_inputs = dummy_inputs
self.cache = cache self.cache = cache
self.enable_sanity_checks = enable_sanity_checks self.enable_sanity_checks = enable_sanity_checks
self.data_parser = self._get_data_parser() self.data_parser = self._get_data_parser()
self.profiling_info = self._get_profiling_info()
def __call__( def __call__(
self, self,
...@@ -585,13 +619,6 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -585,13 +619,6 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
""" """
return MultiModalDataParser() return MultiModalDataParser()
def _get_profiling_info(self) -> BaseProfilingInfo:
    """
    Return the profiling information that is used to determine the
    worst-case memory usage of the model.

    This base implementation always raises; subclasses override it.
    """
    raise NotImplementedError
def _to_mm_items( def _to_mm_items(
self, self,
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
...@@ -602,7 +629,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -602,7 +629,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
""" """
mm_items = self.data_parser.parse_mm_data(mm_data) mm_items = self.data_parser.parse_mm_data(mm_data)
mm_limits = self.ctx.get_mm_config().limit_per_prompt mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
for modality, items in mm_items.items(): for modality, items in mm_items.items():
limit = mm_limits.get(modality, 1) limit = mm_limits.get(modality, 1)
if len(items) > limit: if len(items) > limit:
...@@ -646,19 +673,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -646,19 +673,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _find_mm_placeholders( def _find_mm_placeholders(
self, self,
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
new_token_ids: list[int], new_token_ids: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]: ) -> Mapping[str, list[PlaceholderInfo]]:
return find_mm_placeholders(mm_prompt_repls, new_token_ids, return find_mm_placeholders(mm_prompt_repls, new_token_ids,
mm_item_counts) mm_item_counts)
def _get_hf_mm_data( def _get_hf_mm_data(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]: ) -> tuple[Mapping[str, object], Mapping[str, object]]:
processor_data = dict[str, Any]() processor_data = dict[str, object]()
passthrough_data = dict[str, Any]() passthrough_data = dict[str, object]()
for items in mm_items.values(): for items in mm_items.values():
processor_data.update(items.get_processor_data()) processor_data.update(items.get_processor_data())
...@@ -678,8 +705,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -678,8 +705,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
Call the HF processor on the prompt text and Call the HF processor on the prompt text and
associated multi-modal data. associated multi-modal data.
""" """
return self.ctx.call_hf_processor( return self.info.ctx.call_hf_processor(
self._get_hf_processor(**mm_kwargs), self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data), dict(text=prompt, **mm_data),
mm_kwargs, mm_kwargs,
) )
...@@ -738,8 +765,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -738,8 +765,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# Some HF processors (e.g. Qwen2-VL) expect corresponding # Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text # multi-modal tokens to be in the prompt text
dummy_inputs = self.profiling_info.get_dummy_processor_inputs( dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
self.ctx.model_config.max_model_len, self.info.ctx.model_config.max_model_len,
mm_missing_counts, mm_missing_counts,
) )
...@@ -762,7 +789,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -762,7 +789,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
caching the results and reusing cached results. caching the results and reusing cached results.
""" """
cache = self.cache cache = self.cache
model_id = self.ctx.model_config.model model_id = self.info.model_id
_, passthrough_data = self._get_hf_mm_data(mm_data_items) _, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data: if cache is None or passthrough_data:
...@@ -838,8 +865,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -838,8 +865,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _bind_and_group_repls( def _bind_and_group_repls(
self, self,
prompt_repls: list[PromptReplacement], prompt_repls: list[PromptReplacement],
) -> dict[str, list[_BoundPromptReplacement]]: ) -> dict[str, list[BoundPromptReplacement]]:
tokenizer = self._get_tokenizer() tokenizer = self.info.get_tokenizer()
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
...@@ -859,10 +886,10 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -859,10 +886,10 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _apply_prompt_replacements( def _apply_prompt_replacements(
self, self,
token_ids: list[int], token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
tokenizer = self._get_tokenizer() tokenizer = self.info.get_tokenizer()
mm_token_matches = { mm_token_matches = {
modality: find_token_matches(token_ids, prompt_repls) modality: find_token_matches(token_ids, prompt_repls)
...@@ -950,7 +977,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -950,7 +977,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _validate_mm_placeholders( def _validate_mm_placeholders(
self, self,
mm_placeholders: Mapping[str, list[_PlaceholderInfo]], mm_placeholders: Mapping[str, list[PlaceholderInfo]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
*, *,
allow_missing: bool = False, allow_missing: bool = False,
...@@ -1001,7 +1028,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -1001,7 +1028,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# instead of rehashing. # instead of rehashing.
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
model_id = self.ctx.model_config.model model_id = self.info.model_id
mm_hashes = { mm_hashes = {
modality: [ modality: [
MultiModalHasher.hash_kwargs(model_id=model_id, MultiModalHasher.hash_kwargs(model_id=model_id,
...@@ -1046,7 +1073,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -1046,7 +1073,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
allow_missing=True, allow_missing=True,
) )
mm_missing_repls = dict[str, list[_BoundPromptReplacement]]() mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
for modality, missing_repl_count in mm_missing_repl_counts.items(): for modality, missing_repl_count in mm_missing_repl_counts.items():
if missing_repl_count == 0: if missing_repl_count == 0:
mm_missing_repls[modality] = [] mm_missing_repls[modality] = []
...@@ -1059,7 +1086,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -1059,7 +1086,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# If HF processor already inserts placeholder tokens, # If HF processor already inserts placeholder tokens,
# there is no need for us to insert them # there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()): if all(len(repls) == 0 for repls in mm_missing_repls.items()):
tokenizer = self._get_tokenizer() tokenizer = self.info.get_tokenizer()
prompt_text = decode_tokens(tokenizer, prompt_ids) prompt_text = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders mm_placeholders = hf_mm_placeholders
else: else:
...@@ -1090,79 +1117,3 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC): ...@@ -1090,79 +1117,3 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
mm_hashes=mm_hashes, mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges, mm_placeholders=mm_placeholder_ranges,
) )
def _get_dummy_mm_inputs(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
    """
    Build dummy processor inputs via ``self.profiling_info`` and run them
    through :meth:`apply`.

    `seq_len` and `mm_counts` are forwarded unchanged to
    `get_dummy_processor_inputs`.
    """
    dummy = self.profiling_info.get_dummy_processor_inputs(
        seq_len,
        mm_counts,
    )

    return self.apply(
        prompt_text=dummy.prompt_text,
        mm_data=dummy.mm_data,
        hf_processor_mm_kwargs=dummy.hf_processor_mm_kwargs,
    )
def get_dummy_data(self, seq_len: int) -> DummyData:
    """
    Build the dummy data used for profiling with sequence length `seq_len`.

    Dummy multi-modal inputs are created for the per-modality item counts
    from `self.profiling_info.get_mm_limits()` and run through the
    processor. For every modality with placeholders, the total placeholder
    length must equal
    `get_mm_max_tokens_per_item(seq_len)[modality] * mm_counts[modality]`.

    Returns:
        `DummyData` with the processed prompt padded with `0` up to
        `seq_len`, plus the multi-modal kwargs and placeholders.
        When V1 is not in use and the processed prompt is longer than
        `seq_len`, a warning is logged and dummy data without multi-modal
        data is returned instead.

    Raises:
        AssertionError: If the two mappings have different keys, or if the
            placeholder totals differ from the expected ones.
    """
    profiling = self.profiling_info
    mm_counts = profiling.get_mm_limits()
    mm_max_tokens_per_item = profiling.get_mm_max_tokens_per_item(seq_len)

    if mm_counts.keys() != mm_max_tokens_per_item.keys():
        # Both key sets are formatted as "`name` ({...})".
        raise AssertionError(
            "The keys returned by `get_supported_mm_limits` "
            f"({set(mm_counts.keys())}) should be the same as those "
            "returned by `get_mm_max_tokens_per_item` "
            f"({set(mm_max_tokens_per_item.keys())})")

    mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
    prompt_token_ids = mm_inputs["prompt_token_ids"]
    placeholders_by_modality = mm_inputs["mm_placeholders"]

    total_placeholders_by_modality = {
        modality: sum(item["length"] for item in placeholders)
        for modality, placeholders in placeholders_by_modality.items()
    }
    expected_placeholders_by_modality = {
        modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
        for modality in placeholders_by_modality
    }
    if total_placeholders_by_modality != expected_placeholders_by_modality:
        raise AssertionError(
            "The processed dummy data has a total of "
            f"{total_placeholders_by_modality} placeholder tokens, which "
            f"is not the expected {expected_placeholders_by_modality} "
            "tokens.")

    # Imported locally (not at module level) to avoid a circular import.
    from vllm.sequence import SequenceData

    total_len = len(prompt_token_ids)

    # V0 does not support chunked prefill.
    if total_len > seq_len and not envs.VLLM_USE_V1:
        logger.warning(
            "The context length (%d) of the model is too short "
            "to hold the multi-modal embeddings in the worst case "
            "(%d tokens in total, out of which %s are reserved for "
            "multi-modal embeddings). This may cause certain multi-modal "
            "inputs to fail during inference, even when the input text is "
            "short. To avoid this, you should increase `max_model_len`, "
            "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
            total_len, total_placeholders_by_modality)

        return DummyData(
            seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
            multi_modal_data=None,
            multi_modal_placeholders=None,
        )

    # Pad with token id 0 up to `seq_len`; a non-positive difference
    # yields an empty list, so longer prompts are left untouched.
    prompt_token_ids.extend([0] * (seq_len - total_len))

    return DummyData(
        seq_data=SequenceData.from_seqs(prompt_token_ids),
        multi_modal_data=mm_inputs["mm_kwargs"],
        multi_modal_placeholders=placeholders_by_modality,
    )
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Generic, TypeVar
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from PIL import Image from PIL import Image
from vllm.inputs import InputProcessingContext import vllm.envs as envs
from vllm.inputs import DummyData
from vllm.logger import init_logger from vllm.logger import init_logger
from .inputs import MultiModalDataDict from .inputs import MultiModalDataDict, MultiModalInputsV2
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -23,39 +25,19 @@ class ProcessorInputs: ...@@ -23,39 +25,19 @@ class ProcessorInputs:
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
class BaseProfilingInfo(ABC): _I = TypeVar("_I", bound=BaseProcessingInfo)
class BaseDummyInputsBuilder(ABC, Generic[_I]):
""" """
Abstract base class that provides the information necessary to profile Abstract base class that constructs the dummy data to profile
multi-modal models. multi-modal models.
""" """
def __init__(self, ctx: InputProcessingContext) -> None: def __init__(self, info: _I) -> None:
super().__init__() super().__init__()
self.ctx = ctx self.info = info
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """
    Give the maximum number of items supported for every modality.

    `None` stands for an unlimited number of items, while a modality
    missing from the returned dictionary is not supported at all.
    """
    # No default: concrete classes have to implement this.
    raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
    """
    Give the maximum possible token count of one data item for every
    modality.

    The keys of the returned dictionary should match those returned by
    :meth:`get_supported_mm_limits`.
    """
    # No default: concrete classes have to implement this.
    raise NotImplementedError
@abstractmethod @abstractmethod
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
...@@ -64,8 +46,8 @@ class BaseProfilingInfo(ABC): ...@@ -64,8 +46,8 @@ class BaseProfilingInfo(ABC):
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
Build the multi-modal portion of the input which, after processing, Build the input which, after processing, results in
results in `mm_max_tokens` in :meth:`get_mm_max_tokens_per_item`. `self.info.get_mm_max_tokens_per_item()` placeholder tokens.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -99,11 +81,33 @@ class BaseProfilingInfo(ABC): ...@@ -99,11 +81,33 @@ class BaseProfilingInfo(ABC):
video = np.zeros((num_frames, width, height, 3)) video = np.zeros((num_frames, width, height, 3))
return [video] * num_videos return [video] * num_videos
def get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.ctx.get_mm_config() class MultiModalProfiler(Generic[_I]):
"""
Contains code for running memory profiling for multi-modal models.
"""
def __init__(self, processor: BaseMultiModalProcessor[_I]) -> None:
    """
    Keep a reference to the multi-modal processor to be profiled.

    The processor is stored as ``self.processor``.
    """
    super().__init__()

    self.processor = processor
@property
def processing_info(self) -> BaseProcessingInfo:
    """The ``info`` object held by the wrapped processor."""
    wrapped = self.processor
    return wrapped.info
@property
def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
    """The ``dummy_inputs`` builder held by the wrapped processor."""
    wrapped = self.processor
    return wrapped.dummy_inputs
def _get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.processing_info.ctx.get_mm_config()
mm_limit_per_prompt = mm_config.limit_per_prompt mm_limit_per_prompt = mm_config.limit_per_prompt
supported_mm_limits = self.get_supported_mm_limits() supported_mm_limits = self.processing_info.get_supported_mm_limits()
mm_limits = { mm_limits = {
modality: mm_limit_per_prompt.get(modality, 1) modality: mm_limit_per_prompt.get(modality, 1)
...@@ -119,3 +123,81 @@ class BaseProfilingInfo(ABC): ...@@ -119,3 +123,81 @@ class BaseProfilingInfo(ABC):
f"at most {supported_limit} {modality} items.") f"at most {supported_limit} {modality} items.")
return mm_limits return mm_limits
def _get_dummy_mm_inputs(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
    """
    Ask ``self.dummy_inputs`` for dummy processor inputs and run them
    through ``self.processor.apply``.

    `seq_len` and `mm_counts` are passed on unchanged to
    `get_dummy_processor_inputs`.
    """
    dummy = self.dummy_inputs.get_dummy_processor_inputs(
        seq_len,
        mm_counts,
    )

    return self.processor.apply(
        prompt_text=dummy.prompt_text,
        mm_data=dummy.mm_data,
        hf_processor_mm_kwargs=dummy.hf_processor_mm_kwargs,
    )
def get_dummy_data(self, seq_len: int) -> DummyData:
    """
    Build the dummy data used to profile the model at length `seq_len`.

    Dummy multi-modal inputs are created for the per-modality item counts
    from `self._get_mm_limits()` and run through the processor. For every
    modality with placeholders, the total placeholder length must equal
    `get_mm_max_tokens_per_item(seq_len)[modality] * mm_counts[modality]`.

    Returns:
        `DummyData` with the processed prompt padded with `0` up to
        `seq_len`, plus the multi-modal kwargs and placeholders.
        When V1 is not in use and the processed prompt is longer than
        `seq_len`, a warning is logged and dummy data without multi-modal
        data is returned instead.

    Raises:
        AssertionError: If the two mappings have different keys, or if the
            placeholder totals differ from the expected ones.
    """
    mm_counts = self._get_mm_limits()

    info = self.processing_info
    mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)

    if mm_counts.keys() != mm_max_tokens_per_item.keys():
        # Both key sets are formatted as "`name` ({...})".
        raise AssertionError(
            "The keys returned by `get_supported_mm_limits` "
            f"({set(mm_counts.keys())}) should be the same as those "
            "returned by `get_mm_max_tokens_per_item` "
            f"({set(mm_max_tokens_per_item.keys())})")

    mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
    prompt_token_ids = mm_inputs["prompt_token_ids"]
    placeholders_by_modality = mm_inputs["mm_placeholders"]

    total_placeholders_by_modality = {
        modality: sum(item["length"] for item in placeholders)
        for modality, placeholders in placeholders_by_modality.items()
    }
    expected_placeholders_by_modality = {
        modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
        for modality in placeholders_by_modality
    }
    if total_placeholders_by_modality != expected_placeholders_by_modality:
        raise AssertionError(
            "The processed dummy data has a total of "
            f"{total_placeholders_by_modality} placeholder tokens, which "
            f"is not the expected {expected_placeholders_by_modality} "
            "tokens.")

    # Imported locally (not at module level) to avoid a circular import.
    from vllm.sequence import SequenceData

    total_len = len(prompt_token_ids)

    # V0 does not support chunked prefill.
    if total_len > seq_len and not envs.VLLM_USE_V1:
        logger.warning(
            "The context length (%d) of the model is too short "
            "to hold the multi-modal embeddings in the worst case "
            "(%d tokens in total, out of which %s are reserved for "
            "multi-modal embeddings). This may cause certain multi-modal "
            "inputs to fail during inference, even when the input text is "
            "short. To avoid this, you should increase `max_model_len`, "
            "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
            total_len, total_placeholders_by_modality)

        return DummyData(
            seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
            multi_modal_data=None,
            multi_modal_placeholders=None,
        )

    # Pad with token id 0 up to `seq_len`; a non-positive difference
    # yields an empty list, so longer prompts are left untouched.
    prompt_token_ids.extend([0] * (seq_len - total_len))

    return DummyData(
        seq_data=SequenceData.from_seqs(prompt_token_ids),
        multi_modal_data=mm_inputs["mm_kwargs"],
        multi_modal_placeholders=placeholders_by_modality,
    )
import functools import functools
from collections import UserDict from collections import UserDict
from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol, from dataclasses import dataclass
Sequence, Type, TypeVar) from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional,
Protocol, Sequence, Type, TypeVar)
import torch.nn as nn import torch.nn as nn
...@@ -14,7 +15,9 @@ from .audio import AudioPlugin ...@@ -14,7 +15,9 @@ from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import BaseMultiModalProcessor, ProcessingCache from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache)
from .profiling import BaseDummyInputsBuilder
from .utils import cached_get_tokenizer from .utils import cached_get_tokenizer
from .video import VideoPlugin from .video import VideoPlugin
...@@ -27,20 +30,59 @@ logger = init_logger(__name__) ...@@ -27,20 +30,59 @@ logger = init_logger(__name__)
MM_CACHE_SIZE = 256 MM_CACHE_SIZE = 256
N = TypeVar("N", bound=Type[nn.Module]) N = TypeVar("N", bound=Type[nn.Module])
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
class MultiModalProcessorFactory(Protocol): class ProcessingInfoFactory(Protocol[_I_co]):
"""Constructs a :class:`MultiModalProcessor` instance from the context.""" """Constructs a :class:`MultiModalProcessor` instance from the context."""
def __call__( def __call__(
self, self,
ctx: InputProcessingContext, ctx: InputProcessingContext,
) -> _I_co:
...
class DummyInputsBuilderFactory(Protocol[_I]):
    """
    Callable protocol: given a processing info object, constructs the
    matching :class:`BaseDummyInputsBuilder` instance.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
        # Protocol stub; implementations supply the body.
        ...
class MultiModalProcessorFactory(Protocol[_I]):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def __call__(
self,
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor: ) -> BaseMultiModalProcessor[_I]:
... ...
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    Immutable bundle of the factories needed to build a multi-modal
    processor.
    """

    # Called with the input processing context to create the info object.
    info: ProcessingInfoFactory[_I]
    # Called with the info object and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Called with the info object to create the dummy inputs builder.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ):
        """
        Create the info object from `ctx`, create the dummy inputs builder
        from that info, and return the processor built from both.

        `cache` is forwarded to the processor factory as a keyword
        argument.
        """
        processing_info = self.info(ctx)
        builder = self.dummy_inputs(processing_info)

        return self.processor(
            processing_info,
            builder,
            cache=cache,
        )
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
""" """
Wraps `_limits_by_model` for a more informative error message Wraps `_limits_by_model` for a more informative error message
...@@ -71,7 +113,7 @@ class MultiModalRegistry: ...@@ -71,7 +113,7 @@ class MultiModalRegistry:
self._plugins = {p.get_data_key(): p for p in plugins} self._plugins = {p.get_data_key(): p for p in plugins}
self._processor_factories = ClassRegistry[nn.Module, self._processor_factories = ClassRegistry[nn.Module,
MultiModalProcessorFactory]() _ProcessorFactories]()
# This is used for non-multimodal models # This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
...@@ -224,7 +266,7 @@ class MultiModalRegistry: ...@@ -224,7 +266,7 @@ class MultiModalRegistry:
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer) processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
return processor.profiling_info.get_mm_max_tokens_per_item(seq_len) return processor.info.get_mm_max_tokens_per_item(seq_len)
return { return {
key: plugin.get_max_multimodal_tokens(model_config) key: plugin.get_max_multimodal_tokens(model_config)
...@@ -315,7 +357,10 @@ class MultiModalRegistry: ...@@ -315,7 +357,10 @@ class MultiModalRegistry:
def register_processor( def register_processor(
self, self,
factory: MultiModalProcessorFactory, processor: MultiModalProcessorFactory[_I],
*,
info: ProcessingInfoFactory[_I],
dummy_inputs: DummyInputsBuilderFactory[_I],
): ):
""" """
Register a multi-modal processor to a model class. The processor Register a multi-modal processor to a model class. The processor
...@@ -336,7 +381,11 @@ class MultiModalRegistry: ...@@ -336,7 +381,11 @@ class MultiModalRegistry:
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
model_cls, self) model_cls, self)
self._processor_factories[model_cls] = factory self._processor_factories[model_cls] = _ProcessorFactories(
info=info,
dummy_inputs=dummy_inputs,
processor=processor,
)
return model_cls return model_cls
...@@ -359,15 +408,15 @@ class MultiModalRegistry: ...@@ -359,15 +408,15 @@ class MultiModalRegistry:
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
) -> BaseMultiModalProcessor: ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
""" """
Create a multi-modal processor for a specific model and tokenizer. Create a multi-modal processor for a specific model and tokenizer.
""" """
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
processor_factory = self._processor_factories[model_cls] factories = self._processor_factories[model_cls]
ctx = InputProcessingContext(model_config, tokenizer) ctx = InputProcessingContext(model_config, tokenizer)
cache = (None if model_config.disable_mm_preprocessor_cache else cache = (None if model_config.disable_mm_preprocessor_cache else
self._processing_cache) self._processing_cache)
return processor_factory(ctx, cache=cache) return factories.build_processor(ctx, cache=cache)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment