Unverified Commit 05fae021 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate KimiVLImagePixelInputs to TensorSchema (#21769)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
parent d1bf1b97
......@@ -46,7 +46,7 @@ import copy
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union
import torch
from torch import nn
......@@ -79,6 +79,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import is_pp_missing_parameter, maybe_prefix
......@@ -118,15 +119,22 @@ class KimiVLMultiModalProjector(nn.Module):
return hidden_states
class KimiVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
class KimiVLImagePixelInputs(TensorSchema):
"""
Shape:`(num_patches, num_channels, patch_size, patch_size)`
Dimensions:
- nc: Number of channels
- np: Number of patches
- ps: Patch size
- ni: Number of images
"""
type: Literal["pixel_values"] = "pixel_values"
image_grid_hws: torch.Tensor
"""Shape:`(num_images, 2)`"""
pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("np", 3, "ps", "ps"),
]
image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)]
# TODO: support embeds too
......@@ -348,8 +356,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
patch_size)
pixel_values = pixel_values.to(self.vision_tower.dtype)
# image_grid_hws.shape = (N, 2)
assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
return KimiVLImagePixelInputs(
type="pixel_values",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment