Unverified commit ccf27cc4, authored by Benji Beck and committed by GitHub
Browse files

Migrate Blip2ImagePixelInputs and Blip2ImageEmbeddingInputs to TensorSchema (#21656)


Signed-off-by: Benji Beck <benjibeck@meta.com>
parent c6573698
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -22,6 +22,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -22,6 +22,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion, PromptUpdate) PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .blip import BlipVisionModel from .blip import BlipVisionModel
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
...@@ -34,19 +35,27 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, ...@@ -34,19 +35,27 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
_IMAGE_TOKEN_ID = 50265 _IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict): class Blip2ImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: torch.Tensor data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Blip2ImageEmbeddingInputs(TypedDict): class Blip2ImageEmbeddingInputs(TensorSchema):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
""" """
Dimensions:
- bn: Batch size * number of images
- f: Image feature size
- h: Hidden size (must match the hidden size of language model backbone)
"""
type: Literal["image_embeds"]
data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs] Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
...@@ -551,21 +560,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -551,21 +560,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _create_image_input(self,
h = w = self.config.vision_config.image_size **kwargs: object) -> Optional[Blip2ImageInputs]:
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
...@@ -573,27 +569,19 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -573,27 +569,19 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)): expected_h = expected_w = self.config.vision_config.image_size
raise ValueError("Incorrect type of pixel values. " return Blip2ImagePixelInputs(type="pixel_values",
f"Got type: {type(pixel_values)}") data=flatten_bn(pixel_values,
concat=True),
pixel_values = flatten_bn(pixel_values, concat=True) resolve_bindings={
"h": expected_h,
return Blip2ImagePixelInputs( "w": expected_w
type="pixel_values", })
data=self._validate_pixel_values(pixel_values),
)
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
image_embeds = flatten_bn(image_embeds, concat=True)
return Blip2ImageEmbeddingInputs( return Blip2ImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=image_embeds, data=flatten_bn(image_embeds, concat=True),
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment