# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
46
from collections.abc import Iterable, Mapping, Sequence
47
from dataclasses import dataclass
48
from typing import Annotated, Any, Literal
49
50
51

import torch
from torch import nn
52
from transformers import BatchFeature
53
54
55
from transformers.activations import GELUActivation

from vllm.config import VllmConfig
56
from vllm.config.multimodal import BaseDummyOptions
57
from vllm.model_executor.layers.linear import ReplicatedLinear
58
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
59
60
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.multimodal import MULTIMODAL_REGISTRY
61
62
63
64
65
66
67
68
69
70
71
72
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
73
    BaseDummyInputsBuilder,
74
75
76
77
78
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
79
80
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
81
from vllm.utils.tensor_schema import TensorSchema, TensorShape
82

83
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
84
from .vision import run_dp_sharded_mrope_vision_model
85
86
87
88
89
90
91
92
93
94


# For dummy input only
@dataclass
class MaxImageTokenMeta:
    """Image size used when fabricating dummy image inputs.

    Callers read the class-level defaults directly
    (``MaxImageTokenMeta.width`` / ``MaxImageTokenMeta.height``), so the
    defaults below are the values that actually matter.
    """

    # Width handed to the dummy-image generator.
    width: int = 1024
    # Height handed to the dummy-image generator.
    height: int = 1024


class KimiVLMultiModalProjector(nn.Module):
    """Map vision-tower features into the language model's hidden space.

    Pipeline: LayerNorm over each vision feature vector, then a reshape that
    concatenates ``merge_kernel_size[0] * merge_kernel_size[1]`` consecutive
    vectors into one row, then Linear -> GELU -> Linear down to the text
    model's hidden size.
    """

    def __init__(
        self, config: KimiVLConfig, use_data_parallel: bool = False, prefix: str = ""
    ):
        super().__init__()
        # NOTE(review): stored but not read anywhere in this module; both
        # linear layers are ReplicatedLinear regardless of this flag.
        self.use_data_parallel = use_data_parallel

        # Width of one merged row: the vision hidden size times the number of
        # feature vectors folded together by the merge kernel.
        self.hidden_size = (
            config.vision_config.hidden_size
            * config.vision_config.merge_kernel_size[0]
            * config.vision_config.merge_kernel_size[1]
        )

        # Normalization is applied per vision feature vector, i.e. before the
        # merge reshape in forward().
        self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-5)
        self.linear_1 = ReplicatedLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_1"),
        )
        self.linear_2 = ReplicatedLinear(
            self.hidden_size,
            config.text_config.hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_2"),
        )
        self.act = GELUActivation()

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project vision features to text-model embeddings.

        Args:
            image_features: tensor whose last dim is the vision hidden size.
                NOTE(review): assumes the total number of feature vectors is a
                multiple of the merge-kernel area and that each kernel's vectors
                are contiguous, as required by the ``view`` below.

        Returns:
            A 2-D tensor of shape ``(num_merged_tokens, text_hidden_size)``.
        """
        hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
        # ReplicatedLinear returns a pair; only the first element is used.
        hidden_states, _ = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


130
class KimiVLImagePixelInputs(TensorSchema):
    """
    Pre-processed pixel inputs for one or more images.

    Dimensions:
        - np: Number of patches
        - ps: Patch size
        - ni: Number of images

    The channel dimension of ``pixel_values`` is fixed at 3.
    Each row of ``image_grid_hws`` describes one image's patch grid; the
    product of a row is that image's patch count.
    """

    type: Literal["pixel_values"] = "pixel_values"

    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("np", 3, "ps", "ps"),
    ]

    image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)]
157


# Type of the image input consumed by KimiVLForConditionalGeneration.
# Only pixel inputs are accepted today (the model parses only "pixel_values"
# and asserts that type), so this is a plain alias rather than a union.
# TODO: support precomputed image embeddings too.
KimiVLImageInputs = KimiVLImagePixelInputs


class KimiVLProcessingInfo(BaseProcessingInfo):
    """Kimi-VL processing metadata: HF config access, limits, token counts."""

    def get_hf_config(self):
        """Return the model's HF config as a ``KimiVLConfig``."""
        return self.ctx.get_hf_config(KimiVLConfig)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Images are the only modality; ``None`` presumably means no fixed cap."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many placeholder tokens one image expands to.

        The sizing steps are:

        1. If the image holds more than ``in_token_limit`` whole patches, both
           sides are scaled down by the same factor.
        2. Each side is padded up to a multiple of
           ``merge_kernel_size * patch_size``.
        3. Every merge-kernel block becomes one token.

        NOTE(review): intended to match the HF image processor's resize/pad
        logic; confirm when that processor changes.

        Raises:
            AssertionError: if a dimension is not an ``int`` or the processor
                has no ``merge_kernel_size``.
        """
        image_processor = self.get_hf_processor().image_processor
        patch_size = image_processor.patch_size
        kernel_size = image_processor.merge_kernel_size
        in_token_limit = image_processor.in_token_limit
        height = image_height
        width = image_width

        assert isinstance(height, int), f"height must be int, current height {height}"
        assert isinstance(width, int), f"width must be int, current width {width}"
        assert kernel_size is not None, "kernel_size must be specified"

        num_patches = (width // patch_size) * (height // patch_size)
        if num_patches > in_token_limit:
            # Uniform downscale that brings the patch count to roughly the limit
            # while keeping the aspect ratio.
            scale = math.sqrt(in_token_limit / num_patches)
            width, height = int(width * scale), int(height * scale)

        kernel_height, kernel_width = kernel_size
        # Pixel extent of one merge-kernel block along each axis.
        block_height = kernel_height * patch_size
        block_width = kernel_width * patch_size

        # Round each side up to the next multiple of the block size.
        pad_height = (block_height - height % block_height) % block_height
        pad_width = (block_width - width % block_width) % block_width

        token_height = (height + pad_height) // block_height
        token_width = (width + pad_width) // block_width
        return int(token_height * token_width)

    @property
    def image_token_id(self) -> int:
        """Token id that marks an image placeholder in the prompt."""
        return self.get_hf_config().media_placeholder_token_id


class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]):
    """Builds dummy text and image inputs for Kimi-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one HF image token per requested dummy image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images.

        The images are sized by ``MaxImageTokenMeta``. Any ``"image"`` entry
        in ``mm_options`` is forwarded as overrides. ``seq_len`` is accepted
        for interface compatibility but is not used.
        """
        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=MaxImageTokenMeta.width,
                height=MaxImageTokenMeta.height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
    """Maps Kimi-VL HF processor outputs onto vLLM multimodal fields and
    expands image placeholder tokens in the prompt."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how the HF outputs split into per-image items.

        ``pixel_values`` is one tensor holding every image's patches back to
        back. Each row of ``image_grid_hws`` belongs to one image, and its
        product is that image's patch count. Those counts slice
        ``pixel_values`` into per-image chunks.
        """
        # An empty (0, 2) grid yields an empty size vector when no image was
        # processed.
        image_grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2)))
        image_grid_sizes = image_grid_hws.prod(-1)

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes
            ),
            image_grid_hws=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with that image's token run."""
        image_token_id = self.info.image_token_id

        def get_replacement(item_idx: int):
            # Looked up lazily, so prompts without images never touch "image".
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                # Precomputed embeddings already know their token count.
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


283
284
285
286
287
288
@MULTIMODAL_REGISTRY.register_processor(
    KimiVLMultiModalProcessor,
    info=KimiVLProcessingInfo,
    dummy_inputs=KimiVLDummyInputsBuilder,
)
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Kimi-VL: MoonViT vision tower plus projector feeding a DeepSeek-V2 LM."""

    # NOTE(review): presumably advertises that the vision encoder supports the
    # "data" encoder-TP mode handled via ``use_data_parallel`` below; confirm
    # against SupportsMultiModal.
    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for an image item.

        Any modality name starting with ``"image"`` maps to the same Kimi media
        placeholder. The item index ``i`` is not used.

        Raises:
            ValueError: for any non-image modality.
        """
        if modality.startswith("image"):
            return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        super().__init__()
        model_config = vllm_config.model_config
        config: KimiVLConfig = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        assert isinstance(config.vision_config, MoonViTConfig)
        # In "data" encoder-TP mode the vision tower is run through the
        # DP-sharded helper in _process_image_pixels instead of a direct call.
        self.use_data_parallel = (
            model_config.multimodal_config.mm_encoder_tp_mode == "data"
        )
        self.hidden_size = config.text_config.hidden_size

        # Vision-side modules are built under the "image" tower marker.
        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = MoonVitPretrainedModel(
                config.vision_config,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = KimiVLMultiModalProjector(
                config=config,
                use_data_parallel=self.use_data_parallel,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        # The text backbone is instantiated as a DeepseekV2ForCausalLM.
        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["DeepseekV2ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        self.media_placeholder: int = self.config.media_placeholder_token_id

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> KimiVLImageInputs | None:
        """Pop image kwargs and wrap them in the pixel-input schema.

        Returns ``None`` when no ``pixel_values`` were supplied.
        """
        # Only pixel-value inputs are supported for now.
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_hws = kwargs.pop("image_grid_hws", None)

        if pixel_values is None:
            return None

        return KimiVLImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_grid_hws=image_grid_hws,
        )

    # Run the vision tower on the processed pixel_values.
    @torch.inference_mode()
    def _process_image_pixels(self, inputs: KimiVLImagePixelInputs) -> torch.Tensor:
        """Encode pixel inputs with the vision tower.

        In data-parallel encoder mode the work is sharded through
        ``run_dp_sharded_mrope_vision_model``. Otherwise the tower is called
        directly with the pixels and per-image grids.
        """
        pixel_values = inputs["pixel_values"]
        image_grid_hws = inputs["image_grid_hws"]
        if self.use_data_parallel:
            return run_dp_sharded_mrope_vision_model(
                self.vision_tower,
                pixel_values,
                image_grid_hws.tolist(),
                rope_type="rope_2d",
            )
        else:
            return self.vision_tower(pixel_values, image_grid_hws)

    def _process_image_input(self, image_input: KimiVLImageInputs) -> torch.Tensor:
        """Encode and project images, returning one embedding chunk per image.

        The per-image feature tensors from the vision tower are concatenated
        and projected in one batch. The result is then split back into
        per-image chunks.
        """
        assert image_input["type"] == "pixel_values"
        image_features = self._process_image_pixels(image_input)
        assert isinstance(image_features, (list, tuple))
        lengths = [x.shape[0] for x in image_features]
        return self.multi_modal_projector(torch.cat(image_features)).split(lengths)

    def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
        """Return per-image embeddings, or ``None`` if no image input is given."""
        # Validate the multimodal input keyword arguments
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        # Run multimodal inputs through encoder and projector
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model and return its output unchanged.

        The output is either a hidden-states tensor (the value later passed to
        ``compute_logits``) or ``IntermediateTensors``. When
        ``intermediate_tensors`` is given, ``inputs_embeds`` is dropped so the
        LM consumes the incoming intermediate state instead. Extra ``kwargs``
        are ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights into submodules by name."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)