"docs/vscode:/vscode.git/clone" did not exist on "d88bff1b96c6f4c8abbd3d5ab4758bdc040f7b62"
Unverified Commit d8937de4 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate Gemma3ImagePixelInputs to TensorSchema (#21676)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent e626d286
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict from typing import Annotated, Any, Literal, Optional
import torch import torch
from torch import nn from torch import nn
...@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
...@@ -42,18 +43,21 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -42,18 +43,21 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
logger = init_logger(__name__) logger = init_logger(__name__)
class Gemma3ImagePixelInputs(TypedDict): class Gemma3ImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
""" """
Shape: `(num_patches_total, num_channels, height, width)` Dimensions:
- p: Number of patches total (over each image over each prompt in the
`num_patches_total` is the total number of patches batch)
over each image over each prompt in the batch. - c: Number of channels (3)
- h: Height of each patch
- w: Width of each patch
- bn: Batch size * number of images
""" """
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")]
num_patches: torch.Tensor num_patches: Annotated[torch.Tensor, TensorShape("bn")]
"""Shape: `(batch_size * num_images)`"""
Gemma3ImageInputs = Gemma3ImagePixelInputs Gemma3ImageInputs = Gemma3ImagePixelInputs
...@@ -523,15 +527,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -523,15 +527,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def dtype(self): def dtype(self):
return next(self.parameters()).dtype return next(self.parameters()).dtype
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
image_size = self.config.vision_config.image_size
expected_dims = (3, image_size, image_size)
if data.shape[1:] != expected_dims:
raise ValueError(
"The expected shape of pixel values per image per batch is "
f"{expected_dims}. You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Gemma3ImageInputs]: self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -549,14 +544,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -549,14 +544,15 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise ValueError("Incorrect type of num_crops. " raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}") f"Got type: {type(num_crops)}")
pixel_values = flatten_bn(pixel_values, concat=True) image_size = self.config.vision_config.image_size
num_crops = flatten_bn(num_crops, concat=True)
return Gemma3ImagePixelInputs( return Gemma3ImagePixelInputs(
type="pixel_values", pixel_values=flatten_bn(pixel_values, concat=True),
pixel_values=self._validate_pixel_values(pixel_values), num_patches=flatten_bn(num_crops, concat=True) + 1,
num_patches=num_crops + 1, resolve_bindings={
) "h": image_size,
"w": image_size
})
def _image_pixels_to_features( def _image_pixels_to_features(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment