Unverified commit 787cdb38, authored by Benji Beck, committed by GitHub
Browse files

Migrate DonutImagePixelInputs to TensorSchema (#23509)


Signed-off-by: Benji Beck <benjibeck@meta.com>
parent a5203d04
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo, ...@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
PromptIndexTargets, PromptInsertion, PromptIndexTargets, PromptInsertion,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.utils.tensor_schema import TensorSchema, TensorShape
class MBartDecoderWrapper(nn.Module): class MBartDecoderWrapper(nn.Module):
...@@ -132,10 +133,16 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): ...@@ -132,10 +133,16 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
return loaded_params return loaded_params
class DonutImagePixelInputs(TypedDict): class DonutImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- c: Number of channels (3)
- h: Height
- w: Width
"""
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: torch.Tensor data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
"""Shape: (batch_size, num_channel, height, width)"""
class DonutProcessingInfo(BaseProcessingInfo): class DonutProcessingInfo(BaseProcessingInfo):
...@@ -275,27 +282,6 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -275,27 +282,6 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
) )
self.pad_token_id = config.pad_token_id self.pad_token_id = config.pad_token_id
def _validate_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
# size = self.processor_config["size"]
h, w = self.config.encoder.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
raise ValueError(
"The expected shape of pixel values per batch "
f"is {expected_dims}. You supplied {actual_dims}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(self, **kwargs: object): def _parse_and_validate_image_input(self, **kwargs: object):
pixel_values: Optional[Union[list[list[torch.Tensor]], pixel_values: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor], list[torch.Tensor],
...@@ -314,11 +300,14 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -314,11 +300,14 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
"Both pixel values and image embeds are provided.") "Both pixel values and image embeds are provided.")
if pixel_values is not None: if pixel_values is not None:
return DonutImagePixelInputs( h, w = self.config.encoder.image_size
type="pixel_values", return DonutImagePixelInputs(type="pixel_values",
data=self._validate_pixel_values( data=flatten_bn(pixel_values,
flatten_bn(pixel_values, concat=True)), concat=True),
) resolve_bindings={
"h": h,
"w": w,
})
if image_embeds is not None: if image_embeds is not None:
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment