Unverified commit c4477f55, authored by Benji Beck and committed by GitHub
Browse files

Migrate Mistral3ImagePixelInputs to TensorSchema (#21945)


Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent dfd23820
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union) Union)
import torch import torch
...@@ -32,6 +32,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -32,6 +32,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdateDetails) PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
...@@ -42,16 +43,24 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -42,16 +43,24 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .vision import get_vision_encoder_info from .vision import get_vision_encoder_info
class Mistral3ImagePixelInputs(TypedDict): class Mistral3ImagePixelInputs(TensorSchema):
type: Literal["pixel_values_pixtral"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_images, num_channels, height, width)` Dimensions:
- bn: Batch size * number of images
Note that `height` or `width` may be different per batch and image, - c: Number of channels (3)
in which case the data is passed as a list instead of a batched tensor. - h: Height of each image
- w: Width of each image
""" """
type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
# Note that `height` or `width` may be different per batch and image,
# in which case the data is passed as a list instead of a batched tensor.
pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
]
class Mistral3PatchMerger(nn.Module): class Mistral3PatchMerger(nn.Module):
""" """
...@@ -456,19 +465,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, ...@@ -456,19 +465,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]: self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment