Unverified Commit 5cbe8d15 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Registry for processing model inputs (#5214)


Co-authored-by: default avatarywang96 <ywang@roblox.com>
parent 0d0e3a42
from typing import Dict, Tuple, Type, Union from functools import lru_cache
from typing import Dict, Type, Union
import torch import torch
from PIL import Image from PIL import Image
from vllm.config import ModelConfig, VisionLanguageConfig from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import SequenceData from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.image_processor import cached_get_image_processor
from .base import MultiModalData, MultiModalPlugin from .base import MultiModalData, MultiModalPlugin
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_image_processor = lru_cache(get_image_processor)
def _get_dummy_seq_data(seq_len: int,
vlm_config: VisionLanguageConfig) -> SequenceData:
# NOTE: We assume that <image> token is repeated `image_feature_size` times
# and then concatenated with the text prompt
# TODO: Enable other ways of inserting the image into the prompt
token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size
token_ids += [0] * (seq_len - vlm_config.image_feature_size)
return SequenceData(token_ids)
def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor:
if vlm_config.image_processor is None:
values_dtype = torch.float16
else:
values_dtype = torch.uint8
return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype)
def get_dummy_image_data(
seq_len: int,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
"""Standard dummy data factory for image data (to be used in
:meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`)."""
seq_data = _get_dummy_seq_data(seq_len, vlm_config)
values = _get_dummy_values(vlm_config)
config_input_type = vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
fake_mm_data: MultiModalData
if config_input_type == ImageInputType.PIXEL_VALUES:
fake_mm_data = ImagePixelData(values)
elif config_input_type == ImageInputType.IMAGE_FEATURES:
fake_mm_data = ImageFeatureData(values)
else:
raise NotImplementedError
return seq_data, fake_mm_data
class ImagePixelData(MultiModalData): class ImagePixelData(MultiModalData):
""" """
The pixel data of an image. Can be one of: The pixel data of an image. Can be one of:
- :class:``PIL.Image``: An image object. Requires that a HuggingFace - :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace
processor is available to the model. processor is available to the model.
- :class:``torch.Tensor``: The raw pixel data which is passed to the model - :class:`torch.Tensor`: The raw pixel data which is passed to the model
without additional pre-processing. without additional pre-processing.
""" """
...@@ -89,8 +47,8 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]): ...@@ -89,8 +47,8 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
def get_data_type(self) -> Type[ImagePixelData]: def get_data_type(self) -> Type[ImagePixelData]:
return ImagePixelData return ImagePixelData
def _get_hf_image_processor(self, model_config: ModelConfig, def _get_hf_image_processor(self, model_config: ModelConfig):
vlm_config: VisionLanguageConfig): vlm_config = model_config.multimodal_config
if vlm_config is None or vlm_config.image_processor is None: if vlm_config is None or vlm_config.image_processor is None:
return None return None
...@@ -100,14 +58,13 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]): ...@@ -100,14 +58,13 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
revision=vlm_config.image_processor_revision, revision=vlm_config.image_processor_revision,
) )
def _default_input_processor( def _default_input_mapper(self, ctx: InputContext,
self, data: ImagePixelData, model_config: ModelConfig, data: ImagePixelData) -> Dict[str, torch.Tensor]:
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]: model_config = ctx.model_config
image = data.image image = data.image
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
image_processor = self._get_hf_image_processor( image_processor = self._get_hf_image_processor(model_config)
model_config, vlm_config)
if image_processor is None: if image_processor is None:
raise RuntimeError("No HuggingFace processor is available" raise RuntimeError("No HuggingFace processor is available"
"to process the image object") "to process the image object")
...@@ -147,9 +104,10 @@ class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]): ...@@ -147,9 +104,10 @@ class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
def get_data_type(self) -> Type[ImageFeatureData]: def get_data_type(self) -> Type[ImageFeatureData]:
return ImageFeatureData return ImageFeatureData
def _default_input_processor( def _default_input_mapper(
self, data: ImageFeatureData, model_config: ModelConfig, self, ctx: InputContext,
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]: data: ImageFeatureData) -> Dict[str, torch.Tensor]:
model_config = ctx.model_config
image_features = data.image_features.to(model_config.dtype) image_features = data.image_features.to(model_config.dtype)
return {"image_features": image_features} return {"image_features": image_features}
import functools import functools
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, from typing import Any, Optional, Sequence, Type, TypeVar
Tuple, Type, TypeVar)
from vllm.config import ModelConfig, VisionLanguageConfig from torch import nn
from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from .base import MultiModalData, MultiModalPlugin from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData, from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
ImagePixelPlugin) ImagePixelPlugin)
if TYPE_CHECKING:
import torch
from torch import nn
from vllm.sequence import SequenceData
logger = init_logger(__name__) logger = init_logger(__name__)
D = TypeVar("D", bound=MultiModalData) D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"]) N = TypeVar("N", bound=Type[nn.Module])
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
Dict[str, "torch.Tensor"]]
MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig],
Tuple["SequenceData", MultiModalData]]
class MultiModalRegistry: class MultiModalRegistry:
""" """
This registry is used by model runners to dispatch data processing A registry to dispatch data processing
according to its modality and the target model. according to its modality and the target model.
""" """
DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin()) DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())
def __init__(self, def __init__(
*, self,
plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS *,
) -> None: plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS,
) -> None:
self._plugins_by_data_type = {p.get_data_type(): p for p in plugins} self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
self._dummy_factories_by_model_type: Dict[Type["nn.Module"],
MultiModalDummyFactory] = {}
def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None: def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
data_type = plugin.get_data_type() data_type = plugin.get_data_type()
...@@ -62,95 +51,53 @@ class MultiModalRegistry: ...@@ -62,95 +51,53 @@ class MultiModalRegistry:
msg = f"Unknown multi-modal data type: {data_type}" msg = f"Unknown multi-modal data type: {data_type}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
def register_dummy_data(self, factory: MultiModalDummyFactory): def register_input_mapper(
""" self,
Register a dummy data factory to a model class. data_type: Type[D],
mapper: Optional[MultiModalInputMapper[D]] = None,
During memory profiling, the provided function is invoked to create ):
dummy data to be inputted into the model. The modality and shape of
the dummy data should be an upper bound of what the model would receive
at inference time.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type:
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""Create dummy data for memory profiling."""
model_cls = MultiModalPlugin.get_model_cls(model_config)
dummy_factory = self._dummy_factories_by_model_type.get(model_cls)
if dummy_factory is None:
msg = f"No dummy data defined for model class: {model_cls}"
raise NotImplementedError(msg)
return dummy_factory(seq_len, model_config, vlm_config)
def register_input(
self,
data_type: Type[D],
processor: Optional[MultiModalInputProcessor[D]] = None):
""" """
Register an input processor for a specific modality to a model class. Register an input mapper for a specific modality to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details. See :meth:`MultiModalPlugin.register_input_mapper` for more details.
""" """
return self._get_plugin_for_data_type(data_type) \ return self._get_plugin_for_data_type(data_type) \
.register_input_processor(processor) .register_input_mapper(mapper)
def register_image_pixel_input( def register_image_pixel_input_mapper(
self, self,
processor: Optional[ mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None,
MultiModalInputProcessor[ImagePixelData]] = None): ):
""" """
Register an input processor for image pixel data to a model class. Register an input mapper for image pixel data to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details. See :meth:`MultiModalPlugin.register_input_mapper` for more details.
""" """
return self.register_input(ImagePixelData, processor) return self.register_input_mapper(ImagePixelData, mapper)
def register_image_feature_input( def register_image_feature_input_mapper(
self, self,
processor: Optional[ mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None,
MultiModalInputProcessor[ImageFeatureData]] = None): ):
""" """
Register an input processor for image feature data to a model class. Register an input mapper for image feature data to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details. See :meth:`MultiModalPlugin.register_input_mapper` for more details.
""" """
return self.register_input(ImageFeatureData, processor) return self.register_input_mapper(ImageFeatureData, mapper)
def process_input(self, data: MultiModalData, model_config: ModelConfig, def map_input(self, model_config: ModelConfig, data: MultiModalData):
vlm_config: VisionLanguageConfig):
""" """
Apply an input processor to a :class:`~MultiModalData` instance passed Apply an input mapper to a :class:`~MultiModalData` instance passed
to the model. to the model.
See :meth:`MultiModalPlugin.process_input` for more details. See :meth:`MultiModalPlugin.map_input` for more details.
""" """
return self._get_plugin_for_data_type(type(data)) \ return self._get_plugin_for_data_type(type(data)) \
.process_input(data, model_config, vlm_config) .map_input(model_config, data)
def create_input_processor(self, model_config: ModelConfig, def create_input_mapper(self, model_config: ModelConfig):
vlm_config: VisionLanguageConfig):
""" """
Create an input processor (see :meth:`process_input`) for a Create an input mapper (see :meth:`map_input`) for a specific model.
specific model.
""" """
return functools.partial(self.process_input, return functools.partial(self.map_input, model_config)
model_config=model_config,
vlm_config=vlm_config)
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""The global :class:`~MultiModalRegistry` which is used by model runners."""
...@@ -8,12 +8,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union ...@@ -8,12 +8,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch import torch
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.multimodal import MultiModalData from vllm.multimodal import MultiModalData
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
...@@ -221,7 +221,7 @@ class Sequence: ...@@ -221,7 +221,7 @@ class Sequence:
def __init__( def __init__(
self, self,
seq_id: int, seq_id: int,
inputs: LLMInputs, inputs: "LLMInputs",
block_size: int, block_size: int,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
......
from functools import lru_cache
from typing import Optional from typing import Optional
from transformers import AutoImageProcessor from transformers import AutoImageProcessor
...@@ -40,6 +39,3 @@ def get_image_processor( ...@@ -40,6 +39,3 @@ def get_image_processor(
raise e raise e
return processor return processor
cached_get_image_processor = lru_cache(get_image_processor)
...@@ -110,15 +110,9 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -110,15 +110,9 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.block_size, self.block_size,
) )
# Create processor for multi-modal data # Multi-modal data support
if self.vision_language_config is not None: self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ .create_input_mapper(self.model_config)
.create_input_processor(
self.model_config,
self.vision_language_config,
)
else:
self.multi_modal_input_processor = None
# Lazy initialization. # Lazy initialization.
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
...@@ -168,13 +162,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -168,13 +162,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
mm_data = seq_group_metadata.multi_modal_data mm_data = seq_group_metadata.multi_modal_data
if mm_data is not None: if mm_data is not None:
# Process multi-modal data mm_kwargs = self.multi_modal_input_mapper(mm_data)
if self.multi_modal_input_processor is None:
raise ValueError(
"Multi-modal inputs are only supported by "
"vision language models.")
mm_kwargs = self.multi_modal_input_processor(mm_data)
for k, v in mm_kwargs.items(): for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v) multi_modal_kwargs_list[k].append(v)
......
...@@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -25,7 +26,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig ...@@ -25,7 +26,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad) is_pin_memory_available, make_tensor_with_pad)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
...@@ -191,15 +192,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -191,15 +192,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.block_size, self.block_size,
) if num_attn_heads else None ) if num_attn_heads else None
# Create processor for multi-modal data # Multi-modal data support
if self.vision_language_config is not None: self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
self.multi_modal_input_processor = MULTIMODAL_REGISTRY \ .create_input_mapper(self.model_config)
.create_input_processor(
self.model_config,
self.vision_language_config,
)
else:
self.multi_modal_input_processor = None
# Lazy initialization # Lazy initialization
self.model: nn.Module # Set after load_model self.model: nn.Module # Set after load_model
...@@ -506,12 +501,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -506,12 +501,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
mm_data = seq_group_metadata.multi_modal_data mm_data = seq_group_metadata.multi_modal_data
if mm_data is not None: if mm_data is not None:
# Process multi-modal data # Process multi-modal data
if self.multi_modal_input_processor is None: mm_kwargs = self.multi_modal_input_mapper(mm_data)
raise ValueError(
"Multi-modal inputs are only supported by "
"vision language models.")
mm_kwargs = self.multi_modal_input_processor(mm_data)
for k, v in mm_kwargs.items(): for k, v in mm_kwargs.items():
multi_modal_kwargs_list[k].append(v) multi_modal_kwargs_list[k].append(v)
...@@ -764,12 +754,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -764,12 +754,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
if vlm_config is None: seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
seq_data = SequenceData([0] * seq_len) .dummy_data_for_profiling(model_config, seq_len)
dummy_multi_modal_data = None assert len(seq_data.prompt_token_ids) == seq_len
else:
seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \
.dummy_data_for_profiling(seq_len, model_config, vlm_config)
seq = SequenceGroupMetadata( seq = SequenceGroupMetadata(
request_id=str(group_id), request_id=str(group_id),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment