Unverified Commit a69693e3 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate Qwen inputs to TensorSchema (#23473)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent 5da4f5d8
...@@ -11,7 +11,7 @@ import math ...@@ -11,7 +11,7 @@ import math
import unicodedata import unicodedata
from collections.abc import Collection, Mapping, Sequence, Set from collections.abc import Collection, Mapping, Sequence, Set
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Callable, Literal, Optional, TypedDict, Union from typing import Annotated, Callable, Literal, Optional, Union
import regex as re import regex as re
import torch import torch
...@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
...@@ -47,26 +48,34 @@ from .qwen import QWenBaseModel, QWenModel ...@@ -47,26 +48,34 @@ from .qwen import QWenBaseModel, QWenModel
from .utils import flatten_bn, merge_multimodal_embeddings from .utils import flatten_bn, merge_multimodal_embeddings
class QwenImagePixelInputs(TypedDict): class QwenImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
data: torch.Tensor
""" """
Shape: `(batch_size * num_images, 3, image_size, image_size)` Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height
- w: Width
Note that image_size is the value in the vision config to which we resize Note that image_size is the value in the vision config to which we resize
the image to in the normalization transform. Currently multi-image support the image to in the normalization transform. Currently multi-image support
can only be leveraged by passing image embeddings directly. can only be leveraged by passing image embeddings directly.
""" """
type: Literal["pixel_values"] = "pixel_values"
data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
class QwenImageEmbeddingInputs(TypedDict): class QwenImageEmbeddingInputs(TensorSchema):
type: Literal["image_embeds"] """
data: torch.Tensor Dimensions:
"""Shape: `(batch_size * num_images, 256, hidden_size)` - bn: Batch size * number of images
- ifs: Image feature size (256)
- hs: Hidden size
`hidden_size` must match the hidden size of the language model backbone `hidden_size` must match the hidden size of the language model backbone
and is stored in the visual config of the model if we have one. and is stored in the visual config of the model if we have one.
""" """
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")]
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
...@@ -697,19 +706,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ...@@ -697,19 +706,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
self.transformer: QwenVLModel self.transformer: QwenVLModel
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.visual["image_size"]
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[QwenImageInputs]: self, **kwargs: object) -> Optional[QwenImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -720,10 +716,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ...@@ -720,10 +716,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
expected_h = expected_w = self.config.visual["image_size"]
resolve_bindings = {"h": expected_h, "w": expected_w}
return QwenImagePixelInputs( return QwenImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values( data=flatten_bn(pixel_values, concat=True),
flatten_bn(pixel_values, concat=True)), resolve_bindings=resolve_bindings,
) )
if image_embeds is not None: if image_embeds is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment