"tests/vscode:/vscode.git/clone" did not exist on "6a09612b2e0e09d037a220ea8115632b8084e008"
Unverified commit 56d04089, authored by Benji Beck and committed by GitHub
Browse files

Migrate Interns1 inputs to TensorSchema (#23510)


Signed-off-by: Benji Beck <benjibeck@meta.com>
parent 7be0cb8e
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import regex as re import regex as re
import torch import torch
...@@ -32,6 +32,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -32,6 +32,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
...@@ -62,51 +63,60 @@ class InternS1MultiModalProjector(nn.Module): ...@@ -62,51 +63,60 @@ class InternS1MultiModalProjector(nn.Module):
return hidden_states return hidden_states
class InternS1ImagePixelInputs(TypedDict): class InternS1ImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
""" """
Shape: Dimensions:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)` - bnp: Batch size * number of images * (1 + num_patches)
- c: Number of channels (3)
- h: Height
- w: Width
- bn: Batch size * number of images
""" """
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
class InternS1ImageEmbeddingInputs(TypedDict): class InternS1ImageEmbeddingInputs(TensorSchema):
type: Literal["image_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
""" """
A tensor of shape `(num_images, total_image_feature_size, hidden_size)` Dimensions:
or a list of tensors of shape `(total_image_feature_size, hidden_size)` - ni: Number of images
- tifs: Total image feature size
`hidden_size` must match the hidden size of language model backbone. - hs: Hidden size (must match language model backbone)
""" """
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("ni", "tifs", "hs")]
InternS1ImageInputs = Union[InternS1ImagePixelInputs, InternS1ImageInputs = Union[InternS1ImagePixelInputs,
InternS1ImageEmbeddingInputs] InternS1ImageEmbeddingInputs]
class InternS1VideoPixelInputs(TypedDict): class InternS1VideoPixelInputs(TensorSchema):
type: Literal["pixel_values_videos"]
pixel_values: torch.Tensor
""" """
Shape: Dimensions:
`(batch_size * num_video * num_frames, num_channels, height, width)` - bnv: Batch size * number of videos * number of frames
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height
- w: Width
""" """
type: Literal["pixel_values_videos"] = "pixel_values_videos"
num_patches: torch.Tensor pixel_values: Annotated[torch.Tensor, TensorShape("bnv", 3, "h", "w")]
"""Shape: `(batch_size * num_images)`""" num_patches: Annotated[torch.Tensor, TensorShape("bn")]
class InternS1VideoEmbeddingInputs(TypedDict): class InternS1VideoEmbeddingInputs(TensorSchema):
type: Literal["video_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
""" """
A tensor of shape `(num_videos, total_video_feature_size, hidden_size)` Dimensions:
or a list of tensors of shape `(total_video_feature_size, hidden_size)` - nv: Number of videos
- tvfs: Total video feature size
`hidden_size` must match the hidden size of language model backbone. - hs: Hidden size (must match language model backbone)
""" """
type: Literal["video_embeds"] = "video_embeds"
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("nv", "tvfs", "hs")]
InternS1VideoInputs = Union[InternS1VideoPixelInputs, InternS1VideoInputs = Union[InternS1VideoPixelInputs,
...@@ -572,26 +582,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -572,26 +582,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
vit_embeds = self.multi_modal_projector(vit_embeds) vit_embeds = self.multi_modal_projector(vit_embeds)
return vit_embeds return vit_embeds
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h, w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternS1ImageInputs]: self, **kwargs: object) -> Optional[InternS1ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -627,10 +617,15 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -627,10 +617,15 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values = flatten_bn(pixel_values, concat=True) pixel_values = flatten_bn(pixel_values, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True)
h, w = self.config.vision_config.image_size
return InternS1ImagePixelInputs( return InternS1ImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=pixel_values,
num_patches=image_num_patches, num_patches=image_num_patches,
resolve_bindings={
"h": h,
"w": w,
},
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
...@@ -671,11 +666,15 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -671,11 +666,15 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
concat=True) concat=True)
video_num_patches = flatten_bn(video_num_patches, concat=True) video_num_patches = flatten_bn(video_num_patches, concat=True)
h, w = self.config.vision_config.image_size
return InternS1VideoPixelInputs( return InternS1VideoPixelInputs(
type="pixel_values_videos", type="pixel_values_videos",
pixel_values=self._validate_pixel_values(
pixel_values_flat_video),
num_patches=video_num_patches, num_patches=video_num_patches,
pixel_values=pixel_values_flat_video,
resolve_bindings={
"h": h,
"w": w,
},
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment