Unverified commit 98d6682c, authored by xwjiang2010 and committed by GitHub
Browse files

[VLM] Remove `image_input_type` from VLM config (#5852)


Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
parent 2c37540a
......@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
from typing_extensions import NotRequired
if TYPE_CHECKING:
from vllm.multimodal import MultiModalData
from vllm.multimodal import MultiModalDataDict
class ParsedText(TypedDict):
......@@ -72,7 +72,7 @@ class TextPrompt(TypedDict):
prompt: str
"""The input text to be tokenized before passing to the model."""
multi_modal_data: NotRequired["MultiModalData"]
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
......@@ -85,7 +85,7 @@ class TokensPrompt(TypedDict):
prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""
multi_modal_data: NotRequired["MultiModalData"]
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
......@@ -103,7 +103,7 @@ class TextTokensPrompt(TypedDict):
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
multi_modal_data: NotRequired["MultiModalData"]
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
......@@ -128,7 +128,6 @@ class LLMInputs(TypedDict):
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
"""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
......@@ -137,7 +136,7 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""
multi_modal_data: NotRequired[Optional["MultiModalData"]]
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
......
......@@ -12,7 +12,7 @@ from .data import LLMInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal import MultiModalData
from vllm.multimodal import MultiModalDataDict
from vllm.sequence import SequenceData
logger = init_logger(__name__)
......@@ -66,7 +66,8 @@ class InputContext:
N = TypeVar("N", bound=Type[nn.Module])
DummyDataFactory = Callable[[InputContext, int],
Tuple["SequenceData", Optional["MultiModalData"]]]
Tuple["SequenceData",
Optional["MultiModalDataDict"]]]
"""
Create dummy data to be inputted into the model.
......@@ -94,7 +95,7 @@ class InputRegistry:
self,
ctx: InputContext,
seq_len: int,
) -> Tuple["SequenceData", Optional["MultiModalData"]]:
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
The default dummy data factory represents the longest possible text
that can be inputted to the model.
......
......@@ -84,9 +84,8 @@ def _get_model_initialization_kwargs(
if supports_vision(model_class):
if vlm_config is None:
raise ValueError("Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments.")
raise ValueError("Provide vision related configurations "
"through LLM entrypoint or engine arguments.")
extra_kwargs["vlm_config"] = vlm_config
......
......@@ -12,7 +12,6 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import ImageFeatureData, ImagePixelData
from vllm.sequence import SequenceData
......@@ -49,7 +48,7 @@ def dummy_seq_data_for_clip(
return SequenceData(token_ids)
def dummy_pixel_data_for_clip(
def dummy_image_for_clip(
hf_config: CLIPVisionConfig,
*,
image_width_override: Optional[int] = None,
......@@ -62,22 +61,7 @@ def dummy_pixel_data_for_clip(
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return ImagePixelData(image)
def dummy_feature_data_for_clip(
hf_config: CLIPVisionConfig,
*,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
values = torch.zeros((1, image_feature_size, hf_config.hidden_size),
dtype=torch.float16)
return ImageFeatureData(values)
return {"image": image}
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
......
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
......@@ -17,11 +17,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput
from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
dummy_seq_data_for_clip)
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision
_KEYS_TO_MODIFY_MAPPING = {
......@@ -76,17 +75,10 @@ class LlavaImagePixelInputs(TypedDict):
"""Shape: (batch_size, num_channels, height, width)"""
class LlavaImageFeatureInputs(TypedDict):
type: Literal["image_features"]
data: torch.Tensor
"""Shape: (batch_size, image_feature_size, hidden_size)"""
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
LlavaImageInputs = LlavaImagePixelInputs
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
......@@ -97,22 +89,14 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
image_token_id=hf_config.image_token_index,
)
image_input_type = multimodal_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
mm_data: MultiModalData
if image_input_type == ImageInputType.PIXEL_VALUES:
mm_data = dummy_pixel_data_for_clip(vision_config)
elif image_input_type == ImageInputType.IMAGE_FEATURES:
mm_data = dummy_feature_data_for_clip(vision_config)
mm_data = dummy_image_for_clip(vision_config)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
......@@ -126,11 +110,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.vlm_config = vlm_config
if self.vlm_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config)
else:
self.vision_tower = None
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
......@@ -165,44 +146,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_features = kwargs.pop("image_features", None)
expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type == ImageInputType.PIXEL_VALUES:
if image_features is not None:
raise ValueError(
"Expected pixel values but got image features")
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(pixel_values),
)
if pixel_values is None:
return None
if expected_input_type == ImageInputType.IMAGE_FEATURES:
if pixel_values is not None:
raise ValueError(
"Expected image features but got pixel values")
if image_features is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_features, torch.Tensor):
raise ValueError("Incorrect type of image features. "
f"Got type: {type(image_features)}")
return LlavaImageFeatureInputs(
type="image_features",
data=self._validate_image_data(image_features),
)
return None
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(pixel_values),
)
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
......@@ -237,12 +192,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "pixel_values":
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
else:
image_features = image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
def forward(
......@@ -273,25 +224,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
This model has two modes of image inputs:
`PIXEL_VALUES` and `IMAGE_FEATURES`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
Expects a batch with shape `[1, 3, 336, 336]`.
(Only applicable to `PIXEL_VALUES` mode)
image_features: The image features for each input image outputted by
the vision tower before passing to the multi-modal projector.
Expects a batch with shape `[1, 576, 1024]`.
(Only applicable to `IMAGE_FEATURES` mode)
See also:
Each input maps to huggingface implementation, as follows:
- `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L360
- `image_features`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L437
"""
image_input = self._parse_and_validate_image_input(**kwargs)
......
from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
Union)
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
......@@ -21,12 +21,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.multimodal.image import ImagePixelData
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput
from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
dummy_seq_data_for_clip, get_clip_patch_grid_length)
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_patch_grid_length)
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector, merge_vision_embeddings
......@@ -47,17 +46,7 @@ class LlavaNextImagePixelInputs(TypedDict):
"""Shape: (batch_size, 2)"""
class LlavaNextImageFeatureInputs(TypedDict):
type: Literal["image_features"]
data: torch.Tensor
"""Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)"""
image_sizes: NotRequired[torch.Tensor]
"""Shape: (batch_size, 2)"""
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageFeatureInputs]
LlavaNextImageInputs = LlavaNextImagePixelInputs
def _get_llava_next_num_unpadded_features(
......@@ -138,20 +127,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
image_feature_size_override=image_feature_size,
)
image_input_type = multimodal_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
mm_data: MultiModalData
if image_input_type == ImageInputType.PIXEL_VALUES:
mm_data = dummy_pixel_data_for_clip(
vision_config,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
elif image_input_type == ImageInputType.IMAGE_FEATURES:
mm_data = dummy_feature_data_for_clip(
vision_config,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
vision_config,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
return seq_data, mm_data
......@@ -159,32 +139,26 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
raise NotImplementedError(msg)
def _pixel_mapper(ctx: InputContext,
data: ImagePixelData) -> Dict[str, torch.Tensor]:
image = data.image
def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
if isinstance(image, torch.Tensor):
pixel_values = image.to(ctx.model_config.dtype)
batch_size, _, _, h, w = pixel_values.shape
image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])
if isinstance(image, Image.Image):
return {"pixel_values": pixel_values, "image_sizes": image_sizes}
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != (image.width, image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != (image.width, image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
image = image.resize((w, h))
data.image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin("image") \
._default_input_mapper(ctx, image)
return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
._default_input_mapper(ctx, data)
raise TypeError(f"Invalid type for 'image': {type(image)}")
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper)
@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
......@@ -198,11 +172,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.vlm_config = vlm_config
if self.vlm_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config=config.vision_config)
else:
raise TypeError("Image features are not supported by LLaVA-NeXT")
self.vision_tower = CLIPVisionModel(config=config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
......@@ -255,36 +225,23 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_features = kwargs.pop("image_features", None)
expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type == ImageInputType.PIXEL_VALUES:
if image_features is not None:
raise ValueError(
"Expected pixel values but got image features")
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is None or image_sizes is None:
return None
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixels(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
assert expected_input_type != ImageInputType.IMAGE_FEATURES, (
"Failed to validate this at initialization time")
return None
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixels(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
......@@ -391,11 +348,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(
self, image_input: LlavaNextImageInputs) -> torch.Tensor:
if image_input["type"] == "pixel_values":
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
else:
image_features = image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
patch_embeddings = self.multi_modal_projector(image_features)
......
......@@ -35,10 +35,9 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData
from vllm.sequence import SamplerOutput
from .clip import dummy_pixel_data_for_clip, dummy_seq_data_for_clip
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision
logger = init_logger(__name__)
......@@ -286,7 +285,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
image_token_id=32044,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_pixel_data_for_clip(
mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width,
image_height_override=dummy_height,
......@@ -331,8 +330,7 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
def _image_processor(ctx: InputContext,
data: ImagePixelData) -> Dict[str, torch.Tensor]:
image = data.image
image: object) -> Dict[str, torch.Tensor]:
if isinstance(image, Image.Image):
# Temporary patch before dynamic number of image tokens is supported
......@@ -343,13 +341,14 @@ def _image_processor(ctx: InputContext,
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
data.image = image.resize((w, h))
image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
._default_input_mapper(ctx, data)
return MULTIMODAL_REGISTRY._get_plugin("image") \
._default_input_mapper(ctx, image)
raise TypeError(f"Invalid type for 'image': {type(image)}")
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_image_processor)
@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):
......@@ -375,14 +374,6 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type != ImageInputType.PIXEL_VALUES:
raise ValueError(
f"Unexpected image input type: {expected_input_type}."
"Phi3v only support pixel_values input currently.")
if pixel_values is not None and image_sizes is not None:
return Phi3VImagePixelInputs(type="pixel_values",
data=pixel_values,
......
from .base import MultiModalData, MultiModalPlugin
from .base import MultiModalDataDict, MultiModalPlugin
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
......@@ -11,6 +11,8 @@ See also:
"""
__all__ = [
"MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY",
"MultiModalRegistry"
"MultiModalPlugin",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",
"MultiModalDataDict",
]
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
TypeVar)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type,
TypedDict, TypeVar, Union)
from vllm.config import ModelConfig
from vllm.inputs import InputContext
......@@ -8,38 +8,35 @@ from vllm.logger import init_logger
if TYPE_CHECKING:
import torch
from PIL import Image
from torch import nn
logger = init_logger(__name__)
N = TypeVar("N", bound=Type["nn.Module"])
class MultiModalData:
"""
Base class that contains multi-modal data.
To add a new modality, add a new file under ``multimodal`` directory.
In this new file, subclass :class:`~MultiModalData` and
:class:`~MultiModalPlugin`.
class MultiModalDataBuiltins(TypedDict, total=False):
image: "Image.Image"
Finally, register the new plugin to
:const:`vllm.multimodal.MULTIMODAL_REGISTRY`.
This enables models to call :meth:`MultiModalRegistry.map_input` for
the new modality.
"""
pass
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
"""
A dictionary containing an item for each modality type to input.
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])
The data belonging to each modality is converted into keyword arguments
to the model by the corresponding mapper. By default, the mapper of
the corresponding plugin with the same modality key is applied.
"""
MultiModalInputMapper = Callable[[InputContext, D], Dict[str, "torch.Tensor"]]
MultiModalInputMapper = Callable[[InputContext, object], Dict[str,
"torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""
class MultiModalPlugin(ABC, Generic[D]):
class MultiModalPlugin(ABC):
"""
Base class that defines data processing logic for a specific modality.
......@@ -52,19 +49,18 @@ class MultiModalPlugin(ABC, Generic[D]):
def __init__(self) -> None:
self._input_mappers: Dict[Type["nn.Module"],
MultiModalInputMapper[D]] = {}
MultiModalInputMapper] = {}
@abstractmethod
def get_data_type(self) -> Type[D]:
def get_data_key(self) -> str:
"""
Get the modality (subclass of :class:`~MultiModalData`) served by
this plugin.
Get the data key corresponding to the modality.
"""
raise NotImplementedError
@abstractmethod
def _default_input_mapper(self, ctx: InputContext,
data: D) -> Dict[str, "torch.Tensor"]:
data: object) -> Dict[str, "torch.Tensor"]:
"""Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers.
......@@ -73,11 +69,10 @@ class MultiModalPlugin(ABC, Generic[D]):
def register_input_mapper(
self,
mapper: Optional[MultiModalInputMapper[D]] = None,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper to a model class.
When the model receives input data that matches the modality served by
this plugin (see :meth:`get_data_type`), the provided function is
invoked to transform the data into a dictionary of model inputs.
......@@ -102,11 +97,13 @@ class MultiModalPlugin(ABC, Generic[D]):
return wrapper
def map_input(self, model_config: ModelConfig,
data: D) -> Dict[str, "torch.Tensor"]:
data: object) -> Dict[str, "torch.Tensor"]:
"""
Apply an input mapper to a :class:`~MultiModalData` instance passed
Apply an input mapper to a data passed
to the model, transforming the data into a dictionary of model inputs.
If the data is not something that the mapper expects, throws TypeError.
The model is identified by ``model_config``.
TODO: Add guide [ref: PR #5276]
......
from functools import lru_cache
from typing import Dict, Type, Union
from typing import Dict
import torch
from PIL import Image
......@@ -9,105 +9,36 @@ from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.image_processor import get_image_processor
from .base import MultiModalData, MultiModalPlugin
from .base import MultiModalPlugin
logger = init_logger(__name__)
cached_get_image_processor = lru_cache(get_image_processor)
class ImagePixelData(MultiModalData):
"""
The pixel data of an image. Can be one of:
class ImagePlugin(MultiModalPlugin):
- :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace
processor is available to the model.
- :class:`torch.Tensor`: The raw pixel data which is passed to the model
without additional pre-processing.
"""
def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
if isinstance(image, Image.Image):
# So that this class can be created inside the Image context manager
image.load()
self.image = image
def __repr__(self) -> str:
image = self.image
if isinstance(image, Image.Image):
return f"{type(self).__name__}(image={image})"
return (f"{type(self).__name__}(image=torch.Tensor(shape="
f"{image.shape}, dtype={image.dtype}))")
class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
def get_data_type(self) -> Type[ImagePixelData]:
return ImagePixelData
def get_data_key(self) -> str:
return "image"
def _get_hf_image_processor(self, model_config: ModelConfig):
vlm_config = model_config.multimodal_config
if vlm_config is None or vlm_config.image_processor is None:
return None
return cached_get_image_processor(
vlm_config.image_processor,
trust_remote_code=model_config.trust_remote_code,
revision=vlm_config.image_processor_revision,
)
model_config.model,
trust_remote_code=model_config.trust_remote_code)
def _default_input_mapper(self, ctx: InputContext,
data: ImagePixelData) -> Dict[str, torch.Tensor]:
data: object) -> Dict[str, torch.Tensor]:
model_config = ctx.model_config
image = data.image
if isinstance(image, Image.Image):
if isinstance(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available"
"to process the image object")
try:
return image_processor.preprocess(image, return_tensors="pt") \
return image_processor.preprocess(data, return_tensors="pt") \
.to(model_config.dtype).data
except Exception:
logger.error("Failed to process image (%s)", image)
logger.error("Failed to process image (%s)", data)
raise
elif isinstance(image, torch.Tensor):
pixel_values = image.to(model_config.dtype)
return {"pixel_values": pixel_values}
raise TypeError(f"Invalid image type: {type(image)}")
class ImageFeatureData(MultiModalData):
"""
The feature vector of an image, passed directly to the model.
This should be the output of the vision tower.
"""
def __init__(self, image_features: torch.Tensor) -> None:
self.image_features = image_features
def __repr__(self) -> str:
image_features = self.image_features
return (f"{type(self).__name__}(image_features=torch.Tensor(shape="
f"{image_features.shape}, dtype={image_features.dtype}))")
class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
def get_data_type(self) -> Type[ImageFeatureData]:
return ImageFeatureData
def _default_input_mapper(
self, ctx: InputContext,
data: ImageFeatureData) -> Dict[str, torch.Tensor]:
model_config = ctx.model_config
image_features = data.image_features.to(model_config.dtype)
return {"image_features": image_features}
raise TypeError(f"Invalid type for 'image': {type(data)}")
import functools
from typing import Any, Optional, Sequence, Type, TypeVar
from typing import Optional, Sequence, Type, TypeVar
from torch import nn
from vllm.config import ModelConfig
from vllm.logger import init_logger
from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
ImagePixelPlugin)
from .base import MultiModalDataDict, MultiModalInputMapper, MultiModalPlugin
from .image import ImagePlugin
logger = init_logger(__name__)
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type[nn.Module])
......@@ -20,81 +18,91 @@ class MultiModalRegistry:
"""
A registry to dispatch data processing
according to its modality and the target model.
The registry handles both external and internal data input.
"""
DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())
DEFAULT_PLUGINS = (ImagePlugin(), )
def __init__(
self,
*,
plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS,
) -> None:
self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
self,
*,
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
self._plugins = {p.get_data_key(): p for p in plugins}
def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
data_type = plugin.get_data_type()
def register_plugin(self, plugin: MultiModalPlugin) -> None:
data_type_key = plugin.get_data_key()
if data_type in self._plugins_by_data_type:
if data_type_key in self._plugins:
logger.warning(
"A plugin is already registered for data type %s, "
"and will be overwritten by the new plugin %s.", data_type,
"and will be overwritten by the new plugin %s.", data_type_key,
plugin)
self._plugins_by_data_type[data_type] = plugin
self._plugins[data_type_key] = plugin
def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]):
for typ in data_type.mro():
plugin = self._plugins_by_data_type.get(typ)
if plugin is not None:
return plugin
def _get_plugin(self, data_type_key: str):
plugin = self._plugins.get(data_type_key)
if plugin is not None:
return plugin
msg = f"Unknown multi-modal data type: {data_type}"
msg = f"Unknown multi-modal data type: {data_type_key}"
raise NotImplementedError(msg)
def register_input_mapper(
def register_image_input_mapper(
self,
data_type: Type[D],
mapper: Optional[MultiModalInputMapper[D]] = None,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper for a specific modality to a model class.
Register an input mapper for image data to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self._get_plugin_for_data_type(data_type) \
.register_input_mapper(mapper)
return self.register_input_mapper("image", mapper)
def _process_input(self, key: str, value: object,
model_config: ModelConfig):
plugin = self._plugins.get(key)
if plugin:
return plugin.map_input(model_config, value)
msg = f"Unknown multi-modal data type: {key}"
raise NotImplementedError(msg)
def register_image_pixel_input_mapper(
def register_input_mapper(
self,
mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None,
data_type: str,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper for image pixel data to a model class.
Register an input mapper for a specific modality to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self.register_input_mapper(ImagePixelData, mapper)
def register_image_feature_input_mapper(
self,
mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None,
):
plugin = self._plugins.get(data_type)
if not plugin:
msg = f"Unknown multi-modal data type: {data_type}"
raise NotImplementedError(msg)
return plugin.register_input_mapper(mapper)
def register_image_input(self,
mapper: Optional[MultiModalInputMapper] = None):
"""
Register an input mapper for image feature data to a model class.
Register an input mapper for image pixel data to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self.register_input_mapper(ImageFeatureData, mapper)
return self.register_input_mapper("image", mapper)
def map_input(self, model_config: ModelConfig, data: MultiModalData):
def map_input(self, model_config: ModelConfig, data: MultiModalDataDict):
"""
Apply an input mapper to a :class:`~MultiModalData` instance passed
to the model.
Apply an input mapper to the data passed to the model.
See :meth:`MultiModalPlugin.map_input` for more details.
"""
return self._get_plugin_for_data_type(type(data)) \
.map_input(model_config, data)
result_list = [
self._process_input(k, v, model_config) for k, v in data.items()
]
return {k: v for d in result_list for k, v in d.items()}
def create_input_mapper(self, model_config: ModelConfig):
"""
......
......@@ -8,7 +8,7 @@ from PIL import Image
from vllm.config import ModelConfig
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.image import ImagePixelData
from vllm.multimodal.base import MultiModalDataDict
class ImageFetchAiohttp:
......@@ -53,14 +53,10 @@ class ImageFetchAiohttp:
"Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")
image.load()
return image
# Fetch the image at ``image_url`` and wrap it in ``ImagePixelData``.
#
# The awaited result of ``ImageFetchAiohttp.fetch_image`` is used as a
# context manager, and the ``ImagePixelData`` is built and returned from
# inside that ``with`` block.
#
# NOTE(review): a second definition with this same name appears further
# down in this text and returns a ``{"image": ...}`` dict instead. This
# one looks like the pre-change side of the diff -- confirm which
# definition is live before relying on either.
async def async_get_and_parse_image(image_url: str) -> ImagePixelData:
with await ImageFetchAiohttp.fetch_image(image_url) as image:
return ImagePixelData(image)
def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
"""Encode a pillow image to base64 format."""
......@@ -91,3 +87,8 @@ def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
raise ValueError(
f"Unsupported model type: {config.hf_config.model_type}")
return full_prompt
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
    """
    Fetch the image at ``image_url`` and package it as multi-modal data.

    The awaited result of :meth:`ImageFetchAiohttp.fetch_image` is returned
    in a dictionary under the ``"image"`` key.
    """
    return {"image": await ImageFetchAiohttp.fetch_image(image_url)}
......@@ -14,7 +14,7 @@ from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.multimodal import MultiModalData
from vllm.multimodal import MultiModalDataDict
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
......@@ -280,8 +280,8 @@ class Sequence:
return self.inputs["prompt_token_ids"]
@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
return self.inputs.get("multi_modal_data")
def multi_modal_data(self) -> "MultiModalDataDict":
return self.inputs.get("multi_modal_data") or {}
@property
def lora_int_id(self) -> int:
......@@ -457,7 +457,7 @@ class SequenceGroup:
return next(iter(self.seqs_dict.values())).prompt_token_ids
@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data
......@@ -639,7 +639,7 @@ class SequenceGroupMetadata:
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional["MultiModalData"] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
) -> None:
......
from typing import Optional
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
......@@ -12,7 +10,6 @@ def get_image_processor(
processor_name: str,
*args,
trust_remote_code: bool = False,
revision: Optional[str] = None,
**kwargs,
) -> BaseImageProcessor:
"""Gets an image processor for the given model name via HuggingFace."""
......@@ -21,7 +18,6 @@ def get_image_processor(
processor_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
......
......@@ -504,7 +504,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
is not None else 1))
mm_data = seq_group_metadata.multi_modal_data
if mm_data is not None:
if mm_data:
# Process multi-modal data
mm_kwargs = self.multi_modal_input_mapper(mm_data)
for k, v in mm_kwargs.items():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment