Unverified commit 0b8caf90, authored by Benji Beck and committed by GitHub
Browse files

Migrate DeepseekVL2ImageInputs to TensorSchema (#21658)


Signed-off-by: Benji Beck <benjibeck@meta.com>
parent ccf27cc4
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" """Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -36,6 +36,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import ( ...@@ -36,6 +36,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor) DeepseekVLV2Processor)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import is_list_of from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
...@@ -46,25 +47,30 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -46,25 +47,30 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
_IMAGE_TOKEN = "<image>" _IMAGE_TOKEN = "<image>"
class DeepseekVL2ImagePixelInputs(TypedDict): class DeepseekVL2ImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
""" """
images_spatial_crop: torch.Tensor Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
""" """
Shape: `(batch_size * num_images, 2)` type: Literal["pixel_values"]
""" data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", 3, "h", "w")]
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
class DeepseekVL2VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. class DeepseekVL2VImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- f: Image feature size
- h: Hidden size (must match language model backbone)
""" """
type: Literal["image_embeds"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "f", "h")]
DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs, DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs,
...@@ -439,46 +445,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -439,46 +445,6 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
model = model.to(dtype=torch.get_default_dtype()) model = model.to(dtype=torch.get_default_dtype())
return model return model
def _validate_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _validate_images_spatial_crop(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
expected_dims = 2
def _validate_shape(d: torch.Tensor):
actual_dims = d.size(-1)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
f"The expected shape of image sizes per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]: self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -489,25 +455,18 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -489,25 +455,18 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return None return None
if pixel_values is not None: if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)): expected_h = expected_w = self.vision_config.image_size
raise ValueError("Incorrect type of pixel values. " return DeepseekVL2ImagePixelInputs(type="pixel_values",
f"Got type: {type(pixel_values)}") data=flatten_bn(pixel_values),
images_spatial_crop=flatten_bn(
if not isinstance(images_spatial_crop, (torch.Tensor, list)): images_spatial_crop,
raise ValueError("Incorrect type of image sizes. " concat=True),
f"Got type: {type(images_spatial_crop)}") resolve_bindings={
"h": expected_h,
return DeepseekVL2ImagePixelInputs( "w": expected_w,
type="pixel_values", })
data=self._validate_pixel_values(flatten_bn(pixel_values)),
images_spatial_crop=self._validate_images_spatial_crop(
flatten_bn(images_spatial_crop, concat=True)))
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return DeepseekVL2VImageEmbeddingInputs( return DeepseekVL2VImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds), data=flatten_bn(image_embeds),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment