Unverified commit 99f80944, authored by Benji Beck, committed by GitHub
Browse files

Migrate tarsier inputs to TensorSchema (#23500)


Signed-off-by: Benji Beck <benjibeck@meta.com>
parent 170e8ea9
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union, cast) Union, cast)
import torch import torch
...@@ -34,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -34,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...@@ -43,14 +44,28 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, ...@@ -43,14 +44,28 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
from .vision import VisionEncoderInfo, get_vision_encoder_info from .vision import VisionEncoderInfo, get_vision_encoder_info
class TarsierImagePixelInputs(TypedDict): class TarsierImagePixelInputs(TensorSchema):
type: Literal["pixel_values"] """
pixel_values: torch.Tensor Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height
- w: Width
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
class TarsierImageEmbeddingInputs(TypedDict): class TarsierImageEmbeddingInputs(TensorSchema):
type: Literal["image_embeds"] """
data: torch.Tensor Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match the hidden size of language model
backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
TarsierImageInputs = Union[TarsierImagePixelInputs, TarsierImageInputs = Union[TarsierImagePixelInputs,
...@@ -432,18 +447,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -432,18 +447,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) # Assuming 3 channels
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[TarsierImageInputs]: self, **kwargs: object) -> Optional[TarsierImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -459,8 +462,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -459,8 +462,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
return TarsierImagePixelInputs( return TarsierImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_pixel_values( pixel_values=flatten_bn(pixel_values, concat=True),
flatten_bn(pixel_values, concat=True)),
) )
if image_embeds is not None: if image_embeds is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment