# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
46
from collections.abc import Iterable, Mapping, Sequence
47
from dataclasses import dataclass
48
from typing import Annotated, Any, Literal
49
50
51

import torch
from torch import nn
52
from transformers import BatchFeature
53
54
55
from transformers.activations import GELUActivation

from vllm.config import VllmConfig
56
from vllm.config.multimodal import BaseDummyOptions
57
from vllm.model_executor.layers.linear import ReplicatedLinear
58
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
59
60
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.multimodal import MULTIMODAL_REGISTRY
61
62
63
64
65
66
67
68
69
70
71
72
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
73
    BaseDummyInputsBuilder,
74
75
76
77
78
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
79
80
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
81
from vllm.utils.tensor_schema import TensorSchema, TensorShape
82

83
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
84
from .vision import is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model
85
86
87
88
89
90
91
92
93
94


# For dummy input only
@dataclass
class MaxImageTokenMeta:
    """Default image dimensions used when building dummy image inputs."""

    # The defaults are plain class-level values, so they can be read straight
    # off the class (``MaxImageTokenMeta.width``) without creating an instance.
    width: int = 1024
    height: int = 1024


class KimiVLMultiModalProjector(nn.Module):
    """Project vision features to the text model's hidden size.

    The pipeline is: layer norm over the vision hidden size, a reshape into
    rows of ``self.hidden_size`` values, then ``linear_1`` -> GELU ->
    ``linear_2``.
    """

    def __init__(
        self,
        config: KimiVLConfig,
        prefix: str = "",
    ):
        """Build the projector layers.

        Args:
            config: Model config; its ``vision_config`` and ``text_config``
                supply the layer sizes.
            prefix: Name prefix passed on to the linear layers.
        """
        super().__init__()
        self.use_data_parallel = is_vit_use_data_parallel()

        vision_cfg = config.vision_config
        merge_kernel = vision_cfg.merge_kernel_size
        # Projector width: vision hidden size times both merge-kernel extents.
        self.hidden_size = vision_cfg.hidden_size * merge_kernel[0] * merge_kernel[1]

        # The norm is applied before the reshape, hence the un-merged width.
        self.pre_norm = nn.LayerNorm(vision_cfg.hidden_size, eps=1e-5)
        self.linear_1 = ReplicatedLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_1"),
        )
        self.linear_2 = ReplicatedLinear(
            self.hidden_size,
            config.text_config.hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_2"),
        )
        self.act = GELUActivation()

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Normalize, regroup and project ``image_features``.

        Returns:
            The output of ``linear_2``.
        """
        normed = self.pre_norm(image_features)
        merged = normed.view(-1, self.hidden_size)
        # Each linear layer returns a pair; only the first element is used.
        projected, _ = self.linear_1(merged)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


132
class KimiVLImagePixelInputs(TensorSchema):
    """
    Schema for pixel-value image inputs.

    Dimensions:
        - nc: Number of channels (written as the literal 3 in the
          `pixel_values` shape below)
        - np: Number of patches
        - ps: Patch size
        - ni: Number of images
    """

    type: Literal["pixel_values"] = "pixel_values"

    # Either one tensor or a list of tensors, checked against (np, 3, ps, ps).
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("np", 3, "ps", "ps"),
    ]

    # One row of two values per image.
    image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)]
149
150
151
152
153
154
155
156
157
158
159


# TODO: support embedding inputs too.
# Only pixel-value image inputs are supported for Kimi-VL for now, so the
# image-input type is simply an alias of the pixel-input schema.
KimiVLImageInputs = KimiVLImagePixelInputs


class KimiVLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Kimi-VL: config access and image token counts."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as ``KimiVLConfig``."""
        return self.ctx.get_hf_config(KimiVLConfig)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; only ``"image"`` is listed."""
        # NOTE(review): ``None`` presumably means "no fixed limit" - confirm
        # against the base class contract.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Compute how many placeholder tokens an image of this size needs.

        The image is first shrunk (keeping aspect ratio) if its patch count
        exceeds the image processor's ``in_token_limit``; each side is then
        padded up to a multiple of ``merge_kernel * patch_size`` and the
        resulting grid of merged cells is counted.

        Args:
            image_width: Image width, as an ``int``.
            image_height: Image height, as an ``int``.

        Returns:
            Number of merged cells along the height times the number along
            the width.
        """
        image_processor = self.get_hf_processor().image_processor
        patch = image_processor.patch_size
        merge_kernel = image_processor.merge_kernel_size
        token_budget = image_processor.in_token_limit
        width, height = image_width, image_height

        assert isinstance(height, int), f"height must be int, current height {height}"
        assert isinstance(width, int), f"width must be int, current width {width}"
        assert merge_kernel is not None, "kernel_size must be specified"

        patch_count = (width // patch) * (height // patch)
        if patch_count > token_budget:
            # Scale both sides by the same factor so the patch count drops to
            # roughly the budget.
            shrink = math.sqrt(token_budget / patch_count)
            width, height = int(width * shrink), int(height * shrink)

        kernel_rows, kernel_cols = merge_kernel
        cell_height = kernel_rows * patch
        cell_width = kernel_cols * patch

        # Padding needed to reach the next multiple of the cell size
        # (zero when the side is already aligned).
        pad_height = (cell_height - height % cell_height) % cell_height
        pad_width = (cell_width - width % cell_width) % cell_width

        # Grid dimensions after padding, measured in merged cells.
        rows = (height + pad_height) // cell_height
        cols = (width + pad_width) // cell_width
        return int(rows * cols)

    @property
    def image_token_id(self) -> int:
        """Token id taken from the HF config's ``media_placeholder_token_id``."""
        return self.get_hf_config().media_placeholder_token_id


class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]):
    """Builds dummy prompt text and dummy image data for Kimi-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image."""
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized by ``MaxImageTokenMeta``.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Item counts per modality; only ``"image"`` is read.
            mm_options: Optional per-modality options; the ``"image"`` entry
                is passed on as ``overrides``.
            mm_processor_kwargs: Not used by this implementation.
        """
        image_count = mm_counts.get("image", 0)

        image_overrides = None
        if mm_options:
            image_overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=MaxImageTokenMeta.width,
            height=MaxImageTokenMeta.height,
            num_images=image_count,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
    """Multimodal processor for Kimi-VL image inputs."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how the HF outputs are split across image items."""
        # With no images present, fall back to an empty (0, 2) grid tensor.
        grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2)))
        # Product over the last dimension gives one size per image.
        sizes_per_image = grid_hws.prod(-1)

        # pixel_values is merged as a single large tensor;
        # image_grid_hws holds the shape of each sub-tensor in pixel_values.
        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", sizes_per_image
            ),
            "image_grid_hws": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with one token per feature."""
        placeholder_id = self.info.image_token_id

        def expand_placeholder(item_idx: int):
            # Looked up on every call, exactly when a replacement is requested.
            image_items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(image_items, ImageEmbeddingItems):
                # Embedding inputs report their own feature size.
                token_count = image_items.get_feature_size(item_idx)
            else:
                size = image_items.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )

            return [placeholder_id] * token_count

        replacement = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=expand_placeholder,
        )
        return [replacement]


286
287
288
289
290
291
@MULTIMODAL_REGISTRY.register_processor(
    KimiVLMultiModalProcessor,
    info=KimiVLProcessingInfo,
    dummy_inputs=KimiVLDummyInputsBuilder,
)
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Kimi-VL model: vision tower, multimodal projector and language model.

    The vision tower is a ``MoonVitPretrainedModel`` and the language model is
    created through ``init_vllm_registered_model`` with the
    ``DeepseekV2ForCausalLM`` architecture.
    """

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for an item of ``modality``.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: Engine configuration; the model and quantization
                configs are read from it.
            prefix: Name prefix used for the sub-modules.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config: KimiVLConfig = model_config.hf_config

        self.config = hf_config
        self.quant_config = vllm_config.quant_config

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, MoonViTConfig)

        encoder_tp_mode = model_config.multimodal_config.mm_encoder_tp_mode
        self.use_data_parallel = encoder_tp_mode == "data"
        self.hidden_size = hf_config.text_config.hidden_size

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = MoonVitPretrainedModel(
                vision_config,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = KimiVLMultiModalProjector(
                config=hf_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=hf_config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["DeepseekV2ForCausalLM"],
            )

        # Re-export the language model's factory under this model's name.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        self.media_placeholder: int = self.config.media_placeholder_token_id

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> KimiVLImageInputs | None:
        """Wrap the image kwargs in a schema object, or return ``None``.

        ``None`` is returned when no ``pixel_values`` entry is present.
        """
        # Only pixel-value inputs are handled for now.
        pixel_values = kwargs.get("pixel_values")
        if pixel_values is None:
            return None

        return KimiVLImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_grid_hws=kwargs.get("image_grid_hws"),
        )

    # Run the vision tower on already-processed pixel_values.
    @torch.inference_mode()
    def _process_image_pixels(self, inputs: KimiVLImagePixelInputs) -> torch.Tensor:
        """Encode the pixel patches with the vision tower.

        NOTE(review): the caller asserts the result is a list or tuple, so the
        ``torch.Tensor`` return annotation looks too narrow - confirm.
        """
        patches = inputs["pixel_values"]
        grid_hws = inputs["image_grid_hws"]

        if not self.use_data_parallel:
            return self.vision_tower(patches, grid_hws)

        # Data-parallel path: the helper takes the grid sizes as a Python list.
        return run_dp_sharded_mrope_vision_model(
            self.vision_tower,
            patches,
            grid_hws.tolist(),
            rope_type="rope_2d",
        )

    def _process_image_input(self, image_input: KimiVLImageInputs) -> torch.Tensor:
        """Encode and project the images, returning one chunk per image.

        NOTE(review): the value returned is the result of ``Tensor.split``
        (a tuple of tensors), not a single tensor as annotated - confirm.
        """
        assert image_input["type"] == "pixel_values"
        per_image_features = self._process_image_pixels(image_input)
        assert isinstance(per_image_features, (list, tuple))

        # Remember each image's row count so the projected batch can be cut
        # back into per-image pieces.
        row_counts = [features.shape[0] for features in per_image_features]
        projected = self.multi_modal_projector(torch.cat(per_image_features))
        return projected.split(row_counts)

    def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
        """Return the projected image embeddings, or ``None`` without images."""
        # Validate the multimodal input keyword arguments.
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        # Run multimodal inputs through the encoder and the projector.
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language model and return whatever it returns.

        ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
        supplied. NOTE(review): the ``IntermediateTensors`` return annotation
        may be narrower than what the language model actually returns -
        confirm.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=embeds,
        )

    def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load ``weights`` into this module through ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)