Unverified Commit 1fa1d6a9, authored by Benji Beck, committed by GitHub
Browse files

Migrate OvisImagePatchInputs to TensorSchema (#22024)


Signed-off-by: Benji Beck <benjibeck@meta.com>
parent d59c9864
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
""" PyTorch Ovis model.""" """ PyTorch Ovis model."""
import math import math
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -49,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -49,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis import OvisProcessor from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings from .utils import merge_multimodal_embeddings
...@@ -201,25 +202,22 @@ class VisualTokenizer(torch.nn.Module): ...@@ -201,25 +202,22 @@ class VisualTokenizer(torch.nn.Module):
return tokens return tokens
class OvisImagePatchInputs(TypedDict): class OvisImagePatchInputs(TensorSchema):
type: Literal["image_patches"]
flat_data: torch.Tensor
"""
Shape:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""
indicator_tokens: torch.Tensor
""" """
Shape: Dimensions:
`(batch_size * (num_patches + 1))` - batch_patches: Batch size * number of patches
""" - patch_size: patch_size_x * patch_size_y * num_channels
- patch_indicators: Batch size * (number of patches + 1)
patches_per_image: list[int] - patches_per_image: List of number of total patches for each image
""" in the batch.
List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `flat_data`.
""" """
type: Literal["image_patches"]
flat_data: Annotated[torch.Tensor,
TensorShape("batch_patches", "patch_size")]
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
patches_per_image: Annotated[list[int],
TensorShape("num_patches_per_image")]
# This is used to restore the first two dimensions of `flat_data`.
class VisualEmbedding(torch.nn.Embedding): class VisualEmbedding(torch.nn.Embedding):
...@@ -458,9 +456,12 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -458,9 +456,12 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of indicator_tokens. " raise ValueError("Incorrect type of indicator_tokens. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
flat_data = flatten_bn(pixel_values, concat=True)
if flat_data.ndim >= 3:
flat_data = flat_data.flatten(start_dim=1)
return OvisImagePatchInputs( return OvisImagePatchInputs(
type="image_patches", type="image_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), flat_data=flat_data,
patches_per_image=[ patches_per_image=[
x.shape[0] for x in flatten_bn(pixel_values) x.shape[0] for x in flatten_bn(pixel_values)
], ],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment