Unverified commit 749f7925, authored by David Xia, committed by GitHub
Browse files

[Frontend] decrease import time of vllm.multimodal (#18031)


Co-authored-by: Aaron Pham <Aaronpham0103@gmail.com>
parent 85686500
...@@ -10,40 +10,43 @@ from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, ...@@ -10,40 +10,43 @@ from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final) Union, cast, final)
import numpy as np import numpy as np
import torch
import torch.types
from PIL.Image import Image
from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias from typing_extensions import NotRequired, TypeAlias
from vllm.jsontree import JSONTree, json_map_leaves from vllm.jsontree import JSONTree, json_map_leaves
from vllm.utils import full_groupby, is_list_of from vllm.utils import LazyLoader, full_groupby, is_list_of
if TYPE_CHECKING: if TYPE_CHECKING:
import torch
import torch.types
from PIL.Image import Image
from transformers.feature_extraction_utils import BatchFeature
from .hasher import MultiModalHashDict from .hasher import MultiModalHashDict
else:
torch = LazyLoader("torch", globals(), "torch")
_T = TypeVar("_T") _T = TypeVar("_T")
HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor] HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
""" """
A {class}`transformers.image_utils.ImageInput` representing a single image A {class}`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`. item, which can be passed to a HuggingFace `ImageProcessor`.
""" """
HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor, HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
list[np.ndarray], list[torch.Tensor]] list[np.ndarray], list["torch.Tensor"]]
""" """
A {class}`transformers.image_utils.VideoInput` representing a single video A {class}`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace `VideoProcessor`. item, which can be passed to a HuggingFace `VideoProcessor`.
""" """
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor] HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
""" """
Represents a single audio Represents a single audio
item, which can be passed to a HuggingFace `AudioProcessor`. item, which can be passed to a HuggingFace `AudioProcessor`.
""" """
ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor] ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
""" """
A {class}`transformers.image_utils.ImageInput` representing a single image A {class}`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace `ImageProcessor`. item, which can be passed to a HuggingFace `ImageProcessor`.
...@@ -53,7 +56,7 @@ which are treated as image embeddings; ...@@ -53,7 +56,7 @@ which are treated as image embeddings;
these are directly passed to the model without HF processing. these are directly passed to the model without HF processing.
""" """
VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor] VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"]
""" """
A {class}`transformers.image_utils.VideoInput` representing a single video A {class}`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace `VideoProcessor`. item, which can be passed to a HuggingFace `VideoProcessor`.
...@@ -64,7 +67,7 @@ these are directly passed to the model without HF processing. ...@@ -64,7 +67,7 @@ these are directly passed to the model without HF processing.
""" """
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
torch.Tensor] "torch.Tensor"]
""" """
Represents a single audio Represents a single audio
item, which can be passed to a HuggingFace `AudioProcessor`. item, which can be passed to a HuggingFace `AudioProcessor`.
...@@ -132,7 +135,7 @@ class PlaceholderRange: ...@@ -132,7 +135,7 @@ class PlaceholderRange:
length: int length: int
"""The length of the placeholder.""" """The length of the placeholder."""
is_embed: Optional[torch.Tensor] = None is_embed: Optional["torch.Tensor"] = None
""" """
A boolean mask of shape `(length,)` indicating which positions A boolean mask of shape `(length,)` indicating which positions
between `offset` and `offset + length` to assign embeddings to. between `offset` and `offset + length` to assign embeddings to.
...@@ -158,8 +161,8 @@ class PlaceholderRange: ...@@ -158,8 +161,8 @@ class PlaceholderRange:
return nested_tensors_equal(self.is_embed, other.is_embed) return nested_tensors_equal(self.is_embed, other.is_embed)
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
tuple[torch.Tensor, ...]] "torch.Tensor", tuple["torch.Tensor", ...]]
""" """
Uses a list instead of a tensor if the dimensions of each element do not match. Uses a list instead of a tensor if the dimensions of each element do not match.
""" """
...@@ -465,7 +468,7 @@ class MultiModalFieldConfig: ...@@ -465,7 +468,7 @@ class MultiModalFieldConfig:
@staticmethod @staticmethod
def flat_from_sizes(modality: str, def flat_from_sizes(modality: str,
size_per_item: torch.Tensor, size_per_item: "torch.Tensor",
dim: int = 0): dim: int = 0):
""" """
Defines a field where an element in the batch is obtained by Defines a field where an element in the batch is obtained by
...@@ -602,7 +605,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -602,7 +605,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
@staticmethod @staticmethod
def from_hf_inputs( def from_hf_inputs(
hf_inputs: BatchFeature, hf_inputs: "BatchFeature",
config_by_key: Mapping[str, MultiModalFieldConfig], config_by_key: Mapping[str, MultiModalFieldConfig],
): ):
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key` # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
...@@ -792,7 +795,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -792,7 +795,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return self._items_by_modality[modality] return self._items_by_modality[modality]
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]] MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
""" """
A dictionary containing placeholder ranges for each modality. A dictionary containing placeholder ranges for each modality.
""" """
...@@ -823,7 +826,7 @@ class MultiModalInputs(TypedDict): ...@@ -823,7 +826,7 @@ class MultiModalInputs(TypedDict):
mm_hashes: Optional["MultiModalHashDict"] mm_hashes: Optional["MultiModalHashDict"]
"""The hashes of the multi-modal data.""" """The hashes of the multi-modal data."""
mm_placeholders: MultiModalPlaceholderDict mm_placeholders: "MultiModalPlaceholderDict"
""" """
For each modality, information about the placeholder tokens in For each modality, information about the placeholder tokens in
`prompt_token_ids`. `prompt_token_ids`.
......
...@@ -8,11 +8,9 @@ from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional, ...@@ -8,11 +8,9 @@ from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional,
import numpy as np import numpy as np
import torch import torch
from PIL.Image import Image
from transformers import BatchFeature
from typing_extensions import TypeAlias, TypeGuard, assert_never from typing_extensions import TypeAlias, TypeGuard, assert_never
from vllm.utils import is_list_of from vllm.utils import LazyLoader, is_list_of
from .audio import AudioResampler from .audio import AudioResampler
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
...@@ -22,6 +20,11 @@ from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, ...@@ -22,6 +20,11 @@ from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
_T = TypeVar("_T") _T = TypeVar("_T")
_I = TypeVar("_I") _I = TypeVar("_I")
if TYPE_CHECKING:
import PIL.Image as PILImage
else:
PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
class ModalityDataItems(ABC, Generic[_T, _I]): class ModalityDataItems(ABC, Generic[_T, _I]):
""" """
...@@ -131,6 +134,8 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], ...@@ -131,6 +134,8 @@ class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
Mapping[str, MultiModalFieldConfig], Mapping[str, MultiModalFieldConfig],
], ],
) -> None: ) -> None:
from transformers.feature_extraction_utils import BatchFeature
super().__init__(data, modality) super().__init__(data, modality)
missing_required_data_keys = required_fields - data.keys() missing_required_data_keys = required_fields - data.keys()
...@@ -200,7 +205,7 @@ class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): ...@@ -200,7 +205,7 @@ class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
def get_image_size(self, item_idx: int) -> ImageSize: def get_image_size(self, item_idx: int) -> ImageSize:
image = self.get(item_idx) image = self.get(item_idx)
if isinstance(image, Image): if isinstance(image, PILImage.Image):
return ImageSize(*image.size) return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)): if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape _, h, w = image.shape
...@@ -226,7 +231,7 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): ...@@ -226,7 +231,7 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
def get_frame_size(self, item_idx: int) -> ImageSize: def get_frame_size(self, item_idx: int) -> ImageSize:
image = self.get(item_idx)[0] # Assume that the video isn't empty image = self.get(item_idx)[0] # Assume that the video isn't empty
if isinstance(image, Image): if isinstance(image, PILImage.Image):
return ImageSize(*image.size) return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)): if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape _, h, w = image.shape
...@@ -399,7 +404,7 @@ class MultiModalDataParser: ...@@ -399,7 +404,7 @@ class MultiModalDataParser:
if self._is_embeddings(data): if self._is_embeddings(data):
return ImageEmbeddingItems(data) return ImageEmbeddingItems(data)
if (isinstance(data, Image) if (isinstance(data, PILImage.Image)
or isinstance(data, or isinstance(data,
(np.ndarray, torch.Tensor)) and data.ndim == 3): (np.ndarray, torch.Tensor)) and data.ndim == 3):
data_items = [data] data_items = [data]
...@@ -420,7 +425,7 @@ class MultiModalDataParser: ...@@ -420,7 +425,7 @@ class MultiModalDataParser:
if self._is_embeddings(data): if self._is_embeddings(data):
return VideoEmbeddingItems(data) return VideoEmbeddingItems(data)
if (is_list_of(data, Image) if (is_list_of(data, PILImage.Image)
or isinstance(data, or isinstance(data,
(np.ndarray, torch.Tensor)) and data.ndim == 4): (np.ndarray, torch.Tensor)) and data.ndim == 4):
data_items = [data] data_items = [data]
......
...@@ -13,7 +13,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, ...@@ -13,7 +13,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast) TypeVar, Union, cast)
import torch import torch
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
...@@ -31,6 +30,10 @@ from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, ...@@ -31,6 +30,10 @@ from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from .profiling import BaseDummyInputsBuilder from .profiling import BaseDummyInputsBuilder
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1047,10 +1050,10 @@ class BaseProcessingInfo: ...@@ -1047,10 +1050,10 @@ class BaseProcessingInfo:
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer return self.ctx.tokenizer
def get_hf_config(self) -> PretrainedConfig: def get_hf_config(self) -> "PretrainedConfig":
return self.ctx.get_hf_config() return self.ctx.get_hf_config()
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: def get_hf_processor(self, **kwargs: object) -> "ProcessorMixin":
""" """
Subclasses can override this method to handle Subclasses can override this method to handle
specific kwargs from model config or user inputs. specific kwargs from model config or user inputs.
...@@ -1165,7 +1168,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1165,7 +1168,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
@abstractmethod @abstractmethod
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: "BatchFeature",
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
"""Given the HF-processed data, output the metadata of each field.""" """Given the HF-processed data, output the metadata of each field."""
...@@ -1222,7 +1225,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1222,7 +1225,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# This refers to the data to be passed to HF processor. # This refers to the data to be passed to HF processor.
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> "BatchFeature":
""" """
Call the HF processor on the prompt text and Call the HF processor on the prompt text and
associated multi-modal data. associated multi-modal data.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment