Unverified Commit 1fa1d6a9 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate OvisImagePatchInputs to TensorSchema (#22024)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent d59c9864
......@@ -19,7 +19,7 @@
""" PyTorch Ovis model."""
import math
from collections.abc import Iterable, Mapping
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
......@@ -49,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import merge_multimodal_embeddings
......@@ -201,25 +202,22 @@ class VisualTokenizer(torch.nn.Module):
return tokens
class OvisImagePatchInputs(TypedDict):
type: Literal["image_patches"]
flat_data: torch.Tensor
"""
Shape:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""
indicator_tokens: torch.Tensor
class OvisImagePatchInputs(TensorSchema):
"""
Shape:
`(batch_size * (num_patches + 1))`
"""
patches_per_image: list[int]
"""
List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `flat_data`.
Dimensions:
- batch_patches: Batch size * number of patches
- patch_size: patch_size_x * patch_size_y * num_channels
- patch_indicators: Batch size * (number of patches + 1)
- patches_per_image: List of number of total patches for each image
in the batch.
"""
type: Literal["image_patches"]
flat_data: Annotated[torch.Tensor,
TensorShape("batch_patches", "patch_size")]
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
patches_per_image: Annotated[list[int],
TensorShape("num_patches_per_image")]
# This is used to restore the first two dimensions of `flat_data`.
class VisualEmbedding(torch.nn.Embedding):
......@@ -458,9 +456,12 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of indicator_tokens. "
f"Got type: {type(pixel_values)}")
flat_data = flatten_bn(pixel_values, concat=True)
if flat_data.ndim >= 3:
flat_data = flat_data.flatten(start_dim=1)
return OvisImagePatchInputs(
type="image_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
flat_data=flat_data,
patches_per_image=[
x.shape[0] for x in flatten_bn(pixel_values)
],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment