Unverified Commit a7527728 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate Paligemma inputs to TensorSchema (#23470)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent 9dc30b70
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union
import torch
from torch import nn
......@@ -21,6 +21,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
......@@ -32,19 +33,27 @@ from .vision import get_vision_encoder_info
logger = init_logger(__name__)
class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class PaliGemmaImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height
- w: Width
"""
type: Literal["pixel_values"] = "pixel_values"
data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
class PaliGemmaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
class PaliGemmaImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
......@@ -279,19 +288,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
......@@ -301,22 +297,17 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_values = flatten_bn(pixel_values, concat=True)
return PaliGemmaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
h = w = self.config.vision_config.image_size
return PaliGemmaImagePixelInputs(type="pixel_values",
data=pixel_values,
resolve_bindings={
"h": h,
"w": w
})
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
image_embeds = flatten_bn(image_embeds, concat=True)
return PaliGemmaImageEmbeddingInputs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment