Unverified Commit 7675ba30 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Remove redundant `ClassRegistry` (#29681)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 7c1ed458
...@@ -233,7 +233,7 @@ def _test_processing_correctness( ...@@ -233,7 +233,7 @@ def _test_processing_correctness(
) )
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] factories = model_cls._processor_factory
ctx = InputProcessingContext( ctx = InputProcessingContext(
model_config, model_config,
tokenizer=cached_tokenizer_from_config(model_config), tokenizer=cached_tokenizer_from_config(model_config),
......
...@@ -193,7 +193,7 @@ def test_model_tensor_schema(model_id: str): ...@@ -193,7 +193,7 @@ def test_model_tensor_schema(model_id: str):
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
assert supports_multimodal(model_cls) assert supports_multimodal(model_cls)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] factories = model_cls._processor_factory
inputs_parse_methods = [] inputs_parse_methods = []
for attr_name in dir(model_cls): for attr_name in dir(model_cls):
......
...@@ -32,11 +32,13 @@ if TYPE_CHECKING: ...@@ -32,11 +32,13 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.multimodal.registry import _ProcessorFactories
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
else: else:
VllmConfig = object VllmConfig = object
WeightsMapper = object WeightsMapper = object
MultiModalFeatureSpec = object MultiModalFeatureSpec = object
_ProcessorFactories = object
IntermediateTensors = object IntermediateTensors = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -87,6 +89,11 @@ class SupportsMultiModal(Protocol): ...@@ -87,6 +89,11 @@ class SupportsMultiModal(Protocol):
A set indicating CPU-only multimodal fields. A set indicating CPU-only multimodal fields.
""" """
_processor_factory: ClassVar[_ProcessorFactories]
"""
Set internally by `MultiModalRegistry.register_processor`.
"""
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
""" """
......
...@@ -2,14 +2,11 @@ ...@@ -2,14 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
import torch.nn as nn
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
from vllm.utils.collection_utils import ClassRegistry
from .cache import BaseMultiModalProcessorCache from .cache import BaseMultiModalProcessorCache
from .processing import ( from .processing import (
...@@ -26,10 +23,11 @@ from .profiling import ( ...@@ -26,10 +23,11 @@ from .profiling import (
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal
logger = init_logger(__name__) logger = init_logger(__name__)
N = TypeVar("N", bound=type[nn.Module]) N = TypeVar("N", bound=type["SupportsMultiModal"])
_I = TypeVar("_I", bound=BaseProcessingInfo) _I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True) _I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
...@@ -95,9 +93,6 @@ class MultiModalRegistry: ...@@ -95,9 +93,6 @@ class MultiModalRegistry:
A registry that dispatches data processing according to the model. A registry that dispatches data processing according to the model.
""" """
def __init__(self) -> None:
self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]()
def _extract_mm_options( def _extract_mm_options(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
...@@ -207,7 +202,7 @@ class MultiModalRegistry: ...@@ -207,7 +202,7 @@ class MultiModalRegistry:
""" """
def wrapper(model_cls: N) -> N: def wrapper(model_cls: N) -> N:
if self._processor_factories.contains(model_cls, strict=True): if "_processor_factory" in model_cls.__dict__:
logger.warning( logger.warning(
"Model class %s already has a multi-modal processor " "Model class %s already has a multi-modal processor "
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
...@@ -215,7 +210,7 @@ class MultiModalRegistry: ...@@ -215,7 +210,7 @@ class MultiModalRegistry:
self, self,
) )
self._processor_factories[model_cls] = _ProcessorFactories( model_cls._processor_factory = _ProcessorFactories(
info=info, info=info,
dummy_inputs=dummy_inputs, dummy_inputs=dummy_inputs,
processor=processor, processor=processor,
...@@ -225,12 +220,13 @@ class MultiModalRegistry: ...@@ -225,12 +220,13 @@ class MultiModalRegistry:
return wrapper return wrapper
def _get_model_cls(self, model_config: "ModelConfig"): def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
# Avoid circular import # Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config) model_cls, _ = get_model_architecture(model_config)
return model_cls assert hasattr(model_cls, "_processor_factory")
return cast("SupportsMultiModal", model_cls)
def _create_processing_ctx( def _create_processing_ctx(
self, self,
...@@ -248,7 +244,7 @@ class MultiModalRegistry: ...@@ -248,7 +244,7 @@ class MultiModalRegistry:
tokenizer: AnyTokenizer | None = None, tokenizer: AnyTokenizer | None = None,
) -> BaseProcessingInfo: ) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls] factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer) ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.info(ctx) return factories.info(ctx)
...@@ -266,7 +262,7 @@ class MultiModalRegistry: ...@@ -266,7 +262,7 @@ class MultiModalRegistry:
raise ValueError(f"{model_config.model} is not a multimodal model") raise ValueError(f"{model_config.model} is not a multimodal model")
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls] factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer) ctx = self._create_processing_ctx(model_config, tokenizer)
......
...@@ -6,64 +6,37 @@ Contains helpers that are applied to collections. ...@@ -6,64 +6,37 @@ Contains helpers that are applied to collections.
This is similar in concept to the `collections` module. This is similar in concept to the `collections` module.
""" """
from collections import UserDict, defaultdict from collections import defaultdict
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
from typing import Generic, Literal, TypeVar from typing import Generic, Literal, TypeVar
from typing_extensions import TypeIs, assert_never from typing_extensions import TypeIs, assert_never
T = TypeVar("T") T = TypeVar("T")
U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable) _K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V") _V = TypeVar("_V")
class ClassRegistry(UserDict[type[T], _V]): class LazyDict(Mapping[str, _V], Generic[_V]):
"""
A registry that acts like a dictionary but searches for other classes
in the MRO if the original class is not found.
"""
def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
raise KeyError(key)
def __contains__(self, key: object) -> bool:
return self.contains(key)
def contains(self, key: object, *, strict: bool = False) -> bool:
if not isinstance(key, type):
return False
if strict:
return key in self.data
return any(cls in self.data for cls in key.mro())
class LazyDict(Mapping[str, T], Generic[T]):
""" """
Evaluates dictionary items only when they are accessed. Evaluates dictionary items only when they are accessed.
Adapted from: https://stackoverflow.com/a/47212782/5082708 Adapted from: https://stackoverflow.com/a/47212782/5082708
""" """
def __init__(self, factory: dict[str, Callable[[], T]]): def __init__(self, factory: dict[str, Callable[[], _V]]):
self._factory = factory self._factory = factory
self._dict: dict[str, T] = {} self._dict: dict[str, _V] = {}
def __getitem__(self, key: str) -> T: def __getitem__(self, key: str) -> _V:
if key not in self._dict: if key not in self._dict:
if key not in self._factory: if key not in self._factory:
raise KeyError(key) raise KeyError(key)
self._dict[key] = self._factory[key]() self._dict[key] = self._factory[key]()
return self._dict[key] return self._dict[key]
def __setitem__(self, key: str, value: Callable[[], T]): def __setitem__(self, key: str, value: Callable[[], _V]):
self._factory[key] = value self._factory[key] = value
def __iter__(self): def __iter__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment