Unverified commit 3ea57a56, authored by Benji Beck, committed by GitHub
Browse files

Migrate Idefics3ImagePixelInputs and Idefics3ImageEmbeddingInputs to TensorSchema (#21683)


Signed-off-by: Benji Beck <benjibeck@meta.com>
parent 75856bc2
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -45,6 +45,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -45,6 +45,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
# yapf: disable # yapf: disable
from .idefics2_vision_model import ( from .idefics2_vision_model import (
...@@ -56,26 +57,30 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, ...@@ -56,26 +57,30 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
class Idefics3ImagePixelInputs(TypedDict): class Idefics3ImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
""" """
Shape: `(batch_size * num_images * num_patches, Dimensions:
num_channels, height, width)` - bn: Batch size * number of images
- bnp: Batch size * number of images * number of patches
- c: Number of channels (3)
- h: Height
- w: Width
""" """
type: Literal["pixel_values"]
pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
pixel_attention_mask: torch.Tensor pixel_attention_mask: torch.Tensor
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class Idefics3ImageEmbeddingInputs(TensorSchema):
class Idefics3ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
""" """
Shape: `(batch_size * num_images, image_feature_size, hidden_size)` Dimensions:
`hidden_size` must match the hidden size of language model backbone. - bn: Batch size * number of images
- f: Image feature size
- h: Hidden size (must match the hidden size of language model backbone)
""" """
type: Literal["image_embeds"]
data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
...@@ -614,25 +619,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -614,25 +619,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.lm_head.weight = self.model.text_model.wte.weight self.lm_head.weight = self.model.text_model.wte.weight
self.logits_processor = LogitsProcessor(config.text_config.vocab_size) self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[ImageInputs]: self, **kwargs: object) -> Optional[ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -666,16 +652,17 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -666,16 +652,17 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of num_patches. " raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}") f"Got type: {type(num_patches)}")
pixel_values = flatten_bn(pixel_values, concat=True) expected_h = expected_w = self.config.vision_config.image_size
pixel_attention_mask = flatten_bn(pixel_attention_mask,
concat=True)
num_patches = flatten_bn(num_patches, concat=True)
return Idefics3ImagePixelInputs( return Idefics3ImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=flatten_bn(pixel_values, concat=True),
pixel_attention_mask=pixel_attention_mask, pixel_attention_mask=flatten_bn(pixel_attention_mask,
num_patches=num_patches, concat=True),
num_patches=flatten_bn(num_patches, concat=True),
resolve_bindings={
"h": expected_h,
"w": expected_w
},
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment