Unverified Commit 2a0596bc authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Reorganize profiling/processing-related code (#11812)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f1214117
......@@ -4,12 +4,13 @@ from collections import defaultdict
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union)
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from vllm import envs
from vllm.inputs import DummyData, InputProcessingContext
import vllm.envs as envs
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
......@@ -20,7 +21,9 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
from .profiling import BaseProfilingInfo
if TYPE_CHECKING:
from .profiling import BaseDummyInputsBuilder
logger = init_logger(__name__)
......@@ -46,8 +49,8 @@ class PromptReplacement:
if it does not depend on the input.
"""
def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
return _BoundPromptReplacement(
def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
return BoundPromptReplacement(
tokenizer=tokenizer,
modality=self.modality,
_target=self.target,
......@@ -128,7 +131,7 @@ class _BoundPromptSequence:
@dataclass
class _BoundPromptReplacement:
class BoundPromptReplacement:
tokenizer: AnyTokenizer = field(repr=False)
modality: str
......@@ -207,7 +210,7 @@ def iter_token_matches(
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
prompt_repl: _BoundPromptReplacement
prompt_repl: BoundPromptReplacement
@property
def modality(self) -> str:
......@@ -255,7 +258,7 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
@dataclass
class _PlaceholderInfo:
class PlaceholderInfo:
modality: str
item_idx: int
start_idx: int
......@@ -274,7 +277,7 @@ class _PlaceholderInfo:
def find_token_matches(
prompt: list[int],
prompt_repls: Sequence[_BoundPromptReplacement],
prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [
......@@ -286,7 +289,7 @@ def find_token_matches(
def find_text_matches(
prompt: str,
prompt_repls: Sequence[_BoundPromptReplacement],
prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [
......@@ -390,9 +393,9 @@ def replace_text_matches(
def _iter_modality_placeholders(
prompt: list[int],
modality: str,
modality_repls: Sequence[_BoundPromptReplacement],
modality_repls: Sequence[BoundPromptReplacement],
modal_item_count: int,
) -> Iterable[_PlaceholderInfo]:
) -> Iterable[PlaceholderInfo]:
if modal_item_count == 0:
return
......@@ -413,7 +416,7 @@ def _iter_modality_placeholders(
continue
if prompt[start_idx:end_idx] == repl_tokens:
yield _PlaceholderInfo(
yield PlaceholderInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx,
......@@ -434,10 +437,10 @@ def _iter_modality_placeholders(
def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]:
) -> Iterable[PlaceholderInfo]:
"""
For each modality, yield each set of placeholder tokens found in
:code:`prompt`.
......@@ -455,10 +458,10 @@ def _iter_placeholders(
def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]:
) -> Mapping[str, list[PlaceholderInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
return dict(full_groupby_modality(it))
......@@ -524,29 +527,59 @@ class ProcessingCache:
self._cache.put(cache_key, output_kwargs)
class ProcessingMixin:
"""
Contains helper functions to perform processing.
class BaseProcessingInfo:
"""Base class containing information to perform processing."""
Not to be confused with :class:`transformers.ProcessorMixin`.
"""
ctx: InputProcessingContext
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__()
def _get_tokenizer(self) -> AnyTokenizer:
self.ctx = ctx
@property
def model_id(self) -> str:
return self.ctx.model_config.model
def get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer
def _get_hf_config(self) -> PretrainedConfig:
def get_hf_config(self) -> PretrainedConfig:
return self.ctx.get_hf_config()
def _get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
"""
Subclasses can override this method to handle
specific kwargs from model config or user inputs.
"""
return self.ctx.get_hf_processor(**kwargs)
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
_I = TypeVar("_I", bound=BaseProcessingInfo)
class BaseMultiModalProcessor(ProcessingMixin, ABC):
class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
Abstract base class to process multi-modal inputs to be used in vLLM.
......@@ -554,18 +587,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
def __init__(self,
ctx: InputProcessingContext,
info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__()
self.ctx = ctx
self.info = info
self.dummy_inputs = dummy_inputs
self.cache = cache
self.enable_sanity_checks = enable_sanity_checks
self.data_parser = self._get_data_parser()
self.profiling_info = self._get_profiling_info()
def __call__(
self,
......@@ -585,13 +619,6 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
return MultiModalDataParser()
def _get_profiling_info(self) -> BaseProfilingInfo:
"""
Get the profiling information to find the worst-case memory usage of
the model.
"""
raise NotImplementedError
def _to_mm_items(
self,
mm_data: MultiModalDataDict,
......@@ -602,7 +629,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
mm_limits = self.ctx.get_mm_config().limit_per_prompt
mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
for modality, items in mm_items.items():
limit = mm_limits.get(modality, 1)
if len(items) > limit:
......@@ -646,19 +673,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _find_mm_placeholders(
self,
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]:
) -> Mapping[str, list[PlaceholderInfo]]:
return find_mm_placeholders(mm_prompt_repls, new_token_ids,
mm_item_counts)
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
) -> tuple[Mapping[str, object], Mapping[str, object]]:
processor_data = dict[str, object]()
passthrough_data = dict[str, object]()
for items in mm_items.values():
processor_data.update(items.get_processor_data())
......@@ -678,8 +705,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
Call the HF processor on the prompt text and
associated multi-modal data.
"""
return self.ctx.call_hf_processor(
self._get_hf_processor(**mm_kwargs),
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
mm_kwargs,
)
......@@ -738,8 +765,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text
dummy_inputs = self.profiling_info.get_dummy_processor_inputs(
self.ctx.model_config.max_model_len,
dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
self.info.ctx.model_config.max_model_len,
mm_missing_counts,
)
......@@ -762,7 +789,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
caching the results and reusing cached results.
"""
cache = self.cache
model_id = self.ctx.model_config.model
model_id = self.info.model_id
_, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data:
......@@ -838,8 +865,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _bind_and_group_repls(
self,
prompt_repls: list[PromptReplacement],
) -> dict[str, list[_BoundPromptReplacement]]:
tokenizer = self._get_tokenizer()
) -> dict[str, list[BoundPromptReplacement]]:
tokenizer = self.info.get_tokenizer()
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
return dict(full_groupby_modality(it))
......@@ -859,10 +886,10 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _apply_prompt_replacements(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
tokenizer = self._get_tokenizer()
) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
tokenizer = self.info.get_tokenizer()
mm_token_matches = {
modality: find_token_matches(token_ids, prompt_repls)
......@@ -950,7 +977,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _validate_mm_placeholders(
self,
mm_placeholders: Mapping[str, list[_PlaceholderInfo]],
mm_placeholders: Mapping[str, list[PlaceholderInfo]],
mm_item_counts: Mapping[str, int],
*,
allow_missing: bool = False,
......@@ -1001,7 +1028,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# instead of rehashing.
if envs.VLLM_USE_V1:
model_id = self.ctx.model_config.model
model_id = self.info.model_id
mm_hashes = {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
......@@ -1046,7 +1073,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
allow_missing=True,
)
mm_missing_repls = dict[str, list[_BoundPromptReplacement]]()
mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
for modality, missing_repl_count in mm_missing_repl_counts.items():
if missing_repl_count == 0:
mm_missing_repls[modality] = []
......@@ -1059,7 +1086,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()):
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
prompt_text = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders
else:
......@@ -1090,79 +1117,3 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges,
)
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
profiling = self.profiling_info
processor_inputs = profiling.get_dummy_processor_inputs(
seq_len, mm_counts)
return self.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
profiling = self.profiling_info
mm_counts = profiling.get_mm_limits()
mm_max_tokens_per_item = profiling.get_mm_max_tokens_per_item(seq_len)
if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits`"
f"({set(mm_counts.keys())}) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
raise AssertionError(
f"The processed dummy data has a total of "
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
total_len = len(prompt_token_ids)
# V0 does not support chunked prefill.
if total_len > seq_len and not envs.VLLM_USE_V1:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None,
multi_modal_placeholders=None,
)
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=placeholders_by_modality,
)
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Optional
from typing import Generic, TypeVar
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm.inputs import InputProcessingContext
import vllm.envs as envs
from vllm.inputs import DummyData
from vllm.logger import init_logger
from .inputs import MultiModalDataDict
from .inputs import MultiModalDataDict, MultiModalInputsV2
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
logger = init_logger(__name__)
......@@ -23,39 +25,19 @@ class ProcessorInputs:
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
class BaseProfilingInfo(ABC):
"""
Abstract base class that provides the information necessary to profile
multi-modal models.
"""
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__()
_I = TypeVar("_I", bound=BaseProcessingInfo)
self.ctx = ctx
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
class BaseDummyInputsBuilder(ABC, Generic[_I]):
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
Abstract base class that constructs the dummy data to profile
multi-modal models.
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
def __init__(self, info: _I) -> None:
super().__init__()
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
self.info = info
@abstractmethod
def get_dummy_processor_inputs(
......@@ -64,8 +46,8 @@ class BaseProfilingInfo(ABC):
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
"""
Build the multi-modal portion of the input which, after processing,
results in `mm_max_tokens` in :meth:`get_mm_max_tokens_per_item`.
Build the input which, after processing, results in
`self.info.get_mm_max_tokens_per_item()` placeholder tokens.
"""
raise NotImplementedError
......@@ -99,11 +81,33 @@ class BaseProfilingInfo(ABC):
video = np.zeros((num_frames, width, height, 3))
return [video] * num_videos
def get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.ctx.get_mm_config()
class MultiModalProfiler(Generic[_I]):
"""
Contains code for running memory profiling for multi-modal models.
"""
def __init__(
self,
processor: BaseMultiModalProcessor[_I],
) -> None:
super().__init__()
self.processor = processor
@property
def processing_info(self) -> BaseProcessingInfo:
return self.processor.info
@property
def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
return self.processor.dummy_inputs
def _get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.processing_info.ctx.get_mm_config()
mm_limit_per_prompt = mm_config.limit_per_prompt
supported_mm_limits = self.get_supported_mm_limits()
supported_mm_limits = self.processing_info.get_supported_mm_limits()
mm_limits = {
modality: mm_limit_per_prompt.get(modality, 1)
......@@ -119,3 +123,81 @@ class BaseProfilingInfo(ABC):
f"at most {supported_limit} {modality} items.")
return mm_limits
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
factory = self.dummy_inputs
processor_inputs = factory.get_dummy_processor_inputs(
seq_len, mm_counts)
return self.processor.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
mm_counts = self._get_mm_limits()
info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)
if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits`"
f"({set(mm_counts.keys())}) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
raise AssertionError(
f"The processed dummy data has a total of "
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
total_len = len(prompt_token_ids)
# V0 does not support chunked prefill.
if total_len > seq_len and not envs.VLLM_USE_V1:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None,
multi_modal_placeholders=None,
)
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=placeholders_by_modality,
)
import functools
from collections import UserDict
from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol,
Sequence, Type, TypeVar)
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional,
Protocol, Sequence, Type, TypeVar)
import torch.nn as nn
......@@ -14,7 +15,9 @@ from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import BaseMultiModalProcessor, ProcessingCache
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache)
from .profiling import BaseDummyInputsBuilder
from .utils import cached_get_tokenizer
from .video import VideoPlugin
......@@ -27,20 +30,59 @@ logger = init_logger(__name__)
MM_CACHE_SIZE = 256
N = TypeVar("N", bound=Type[nn.Module])
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
class MultiModalProcessorFactory(Protocol):
class ProcessingInfoFactory(Protocol[_I_co]):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def __call__(
self,
ctx: InputProcessingContext,
) -> _I_co:
...
class DummyInputsBuilderFactory(Protocol[_I]):
"""
Constructs a :class:`BaseDummyInputsBuilder` instance from the context.
"""
def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
...
class MultiModalProcessorFactory(Protocol[_I]):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def __call__(
self,
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor:
) -> BaseMultiModalProcessor[_I]:
...
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
info: ProcessingInfoFactory[_I]
processor: MultiModalProcessorFactory[_I]
dummy_inputs: DummyInputsBuilderFactory[_I]
def build_processor(
self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
):
info = self.info(ctx)
dummy_inputs_builder = self.dummy_inputs(info)
return self.processor(info, dummy_inputs_builder, cache=cache)
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
"""
Wraps `_limits_by_model` for a more informative error message
......@@ -71,7 +113,7 @@ class MultiModalRegistry:
self._plugins = {p.get_data_key(): p for p in plugins}
self._processor_factories = ClassRegistry[nn.Module,
MultiModalProcessorFactory]()
_ProcessorFactories]()
# This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
......@@ -224,7 +266,7 @@ class MultiModalRegistry:
tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len
return processor.profiling_info.get_mm_max_tokens_per_item(seq_len)
return processor.info.get_mm_max_tokens_per_item(seq_len)
return {
key: plugin.get_max_multimodal_tokens(model_config)
......@@ -315,7 +357,10 @@ class MultiModalRegistry:
def register_processor(
self,
factory: MultiModalProcessorFactory,
processor: MultiModalProcessorFactory[_I],
*,
info: ProcessingInfoFactory[_I],
dummy_inputs: DummyInputsBuilderFactory[_I],
):
"""
Register a multi-modal processor to a model class. The processor
......@@ -336,7 +381,11 @@ class MultiModalRegistry:
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._processor_factories[model_cls] = factory
self._processor_factories[model_cls] = _ProcessorFactories(
info=info,
dummy_inputs=dummy_inputs,
processor=processor,
)
return model_cls
......@@ -359,15 +408,15 @@ class MultiModalRegistry:
self,
model_config: "ModelConfig",
tokenizer: AnyTokenizer,
) -> BaseMultiModalProcessor:
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
Create a multi-modal processor for a specific model and tokenizer.
"""
model_cls = self._get_model_cls(model_config)
processor_factory = self._processor_factories[model_cls]
factories = self._processor_factories[model_cls]
ctx = InputProcessingContext(model_config, tokenizer)
cache = (None if model_config.disable_mm_preprocessor_cache else
self._processing_cache)
return processor_factory(ctx, cache=cache)
return factories.build_processor(ctx, cache=cache)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment