Unverified Commit d128d0d5 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate KeyeImageInputs and KeyeVideoInputs to TensorSchema (#21686)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent a6c05028
......@@ -3,7 +3,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union
import numpy as np
import torch
......@@ -46,6 +46,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import (
cached_image_processor_from_config)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
......@@ -102,77 +103,62 @@ def smart_resize(
return h_bar, w_bar
class KeyeImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
class KeyeImagePixelInputs(TensorSchema):
"""
image_grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
Dimensions:
- np: Number of patches
- cps: Number of channels * patch_size * patch_size
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values"]
pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
class KeyeImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
image_embeds: torch.Tensor
"""Supported types:
- list[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors).
Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
the number and resolution of the images.
- `hidden_size` must match the hidden size of language model backbone.
class KeyeImageEmbeddingInputs(TensorSchema):
"""
image_grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
Dimensions:
- nf: Number of image features
- hs: Hidden size (must match the hidden size of language model
backbone)
- ni: Number of images
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["image_embeds"]
image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
class KeyeVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
pixel_values_videos: torch.Tensor
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
class KeyeVideoPixelInputs(TensorSchema):
"""
video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
Dimensions:
- np: Number of patches
- ctps: Number of channels * temporal_patch_size * patch_size *
patch_size
- nv: Number of videos
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values_videos"]
pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
class KeyeVideoEmbeddingInputs(TypedDict):
type: Literal["video_embeds"]
video_embeds: torch.Tensor
"""Supported types:
- list[`torch.Tensor`]: A list of tensors holding all videos' features.
Each tensor holds an video's features.
- `torch.Tensor`: A tensor holding all videos' features
(concatenation of all videos' feature tensors).
Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
the number and resolution of the videos.
- `hidden_size` must match the hidden size of language model backbone.
class KeyeVideoEmbeddingInputs(TensorSchema):
"""
video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
Dimensions:
- nf: Number of video features
- hs: Hidden size (must match the hidden size of language model
backbone)
- nv: Number of videos
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["video_embeds"]
video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
KeyeVideoInputs = Union[KeyeVideoPixelInputs, KeyeVideoEmbeddingInputs]
......@@ -1420,10 +1406,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")
return KeyeImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
......@@ -1436,9 +1418,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return KeyeImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
......@@ -1474,9 +1453,6 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
if not isinstance(video_embeds, torch.Tensor):
raise ValueError("Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}")
return KeyeVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment