Unverified commit 46785034, authored by Benji Beck and committed by GitHub
Browse files

Migrate MiniCPMVImageInputs to TensorSchema (#21939)


Signed-off-by: Benji Beck <benjibeck@meta.com>
parent 93d06524
...@@ -27,7 +27,7 @@ import math ...@@ -27,7 +27,7 @@ import math
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Any, Callable, Literal, Optional, TypedDict, Union from typing import Annotated, Any, Callable, Literal, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -63,6 +63,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder ...@@ -63,6 +63,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists from vllm.utils import flatten_2d_lists
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
...@@ -74,36 +75,47 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, ...@@ -74,36 +75,47 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
_MAX_FRAMES_PER_VIDEO = 16 _MAX_FRAMES_PER_VIDEO = 16
class MiniCPMVImagePixelInputs(TypedDict): class MiniCPMVImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values: list[torch.Tensor]
""" """
Shape: `(batch_size * num_images * num_slices, num_channels, height, width)` Dimensions:
- bns: Batch size * number of images * number of slices
Note that the image size may vary, so we pass it as a list - bn: Batch size * number of images
instead of a batched tensor. - c: Number of channels
- h: Height
- w: Width
""" """
tgt_sizes: torch.Tensor type: Literal["pixel_values"] = "pixel_values"
# Note that the image size may vary, so we pass it as a list instead of a
# batched tensor.
pixel_values: Annotated[
list[torch.Tensor],
TensorShape("bns", "c", "h", "w"),
]
tgt_sizes: Annotated[
torch.Tensor,
TensorShape("bns", 2), # This should be in `(height, width)` format.
]
num_slices: Annotated[
torch.Tensor,
TensorShape("bn"),
]
class MiniCPMVImageEmbeddingInputs(TensorSchema):
""" """
Shape: `(batch_size * num_images * num_slices, 2)` Dimensions:
- bn: Batch size * number of images
This should be in `(height, width)` format. - ns: Number of slices
- hs: Hidden size (must match language model backbone)
""" """
num_slices: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
image_embeds: Union[torch.Tensor, list[torch.Tensor]] image_embeds: Annotated[
""" Union[torch.Tensor, list[torch.Tensor]],
Shape: `(batch_size * num_images, num_slices, hidden_size)` TensorShape("bn", "ns", "hs"),
]
`hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor.
"""
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
...@@ -832,11 +844,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -832,11 +844,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values)) pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True) tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
if len(pixel_values_flat) != len(tgt_sizes_flat):
raise ValueError("Inconsistent flattened lengths, found: "
f"{len(pixel_values_flat)} vs. "
f"{len(tgt_sizes_flat)}")
return MiniCPMVImagePixelInputs( return MiniCPMVImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=pixel_values_flat, pixel_values=pixel_values_flat,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment