Unverified Commit 05fae021 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate KimiVLImagePixelInputs to TensorSchema (#21769)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
parent d1bf1b97
...@@ -46,7 +46,7 @@ import copy ...@@ -46,7 +46,7 @@ import copy
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal, Optional, TypedDict, Union from typing import Annotated, Any, Literal, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -79,6 +79,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder ...@@ -79,6 +79,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import is_pp_missing_parameter, maybe_prefix from .utils import is_pp_missing_parameter, maybe_prefix
...@@ -118,15 +119,22 @@ class KimiVLMultiModalProjector(nn.Module): ...@@ -118,15 +119,22 @@ class KimiVLMultiModalProjector(nn.Module):
return hidden_states return hidden_states
class KimiVLImagePixelInputs(TypedDict): class KimiVLImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape:`(num_patches, num_channels, patch_size, patch_size)` Dimensions:
- nc: Number of channels
- np: Number of patches
- ps: Patch size
- ni: Number of images
""" """
type: Literal["pixel_values"] = "pixel_values"
image_grid_hws: torch.Tensor pixel_values: Annotated[
"""Shape:`(num_images, 2)`""" Union[torch.Tensor, list[torch.Tensor]],
TensorShape("np", 3, "ps", "ps"),
]
image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)]
# TODO: support embeds too # TODO: support embeds too
...@@ -348,8 +356,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -348,8 +356,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values = pixel_values.reshape(-1, num_channels, patch_size, pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
patch_size) patch_size)
pixel_values = pixel_values.to(self.vision_tower.dtype) pixel_values = pixel_values.to(self.vision_tower.dtype)
# image_grid_hws.shape = (N, 2)
assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
return KimiVLImagePixelInputs( return KimiVLImagePixelInputs(
type="pixel_values", type="pixel_values",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment