kimi_vl.py 14.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math
46
from collections.abc import Iterable, Mapping, Sequence
47
from dataclasses import dataclass
48
from typing import Annotated, Any, Literal
49
50
51

import torch
from torch import nn
52
from transformers import BatchFeature
53
54
55
from transformers.activations import GELUActivation

from vllm.config import VllmConfig
56
from vllm.config.multimodal import BaseDummyOptions
57
from vllm.model_executor.layers.linear import ReplicatedLinear
58
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
59
60
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.multimodal import MULTIMODAL_REGISTRY
61
62
63
64
65
66
67
68
69
70
71
72
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
73
    BaseDummyInputsBuilder,
74
75
76
77
78
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
79
80
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
81
from vllm.utils.tensor_schema import TensorSchema, TensorShape
82

83
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
84
from .vision import is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model
85
86
87
88
89
90
91
92
93
94


# For dummy input only
@dataclass
class MaxImageTokenMeta:
    """Default image dimensions used when building dummy image inputs.

    The class-level defaults are read directly, without instantiating the
    dataclass, by ``KimiVLDummyInputsBuilder.get_dummy_mm_data``.
    """

    # Width passed to ``_get_dummy_images`` for each dummy image.
    width: int = 1024
    # Height passed to ``_get_dummy_images`` for each dummy image.
    height: int = 1024


class KimiVLMultiModalProjector(nn.Module):
    """Project vision-tower features into the text model's hidden size.

    The pipeline is: LayerNorm over the vision hidden size, a reshape that
    groups the features of one merge window into a single row, then
    ``linear_1`` -> GELU -> ``linear_2``.
    """

    def __init__(
        self,
        config: KimiVLConfig,
        prefix: str = "",
    ):
        """Build the projector layers.

        Args:
            config: Model config; ``vision_config.hidden_size``,
                ``vision_config.merge_kernel_size`` and
                ``text_config.hidden_size`` are read from it.
            prefix: Parameter-name prefix forwarded to the linear layers.
        """
        super().__init__()
        # Stored for reference; this class does not read it afterwards.
        self.use_data_parallel = is_vit_use_data_parallel()

        # Width of one projector input row: the vision hidden size times the
        # number of patches in a merge window (kernel height * kernel width).
        self.hidden_size = (
            config.vision_config.hidden_size
            * config.vision_config.merge_kernel_size[0]
            * config.vision_config.merge_kernel_size[1]
        )

        # Normalization is applied per patch, i.e. before the merge reshape.
        self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-5)
        self.linear_1 = ReplicatedLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_1"),
        )
        self.linear_2 = ReplicatedLinear(
            self.hidden_size,
            config.text_config.hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_2"),
        )
        self.act = GELUActivation()

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Map image features to text-hidden-size embeddings.

        Args:
            image_features: Vision features whose last dimension equals the
                vision hidden size (required by ``pre_norm``).

        Returns:
            A 2-D tensor with one row per merged patch group and
            ``text_config.hidden_size`` columns.
        """
        # Flatten so that each row holds the features of one merge window.
        hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
        # The linear layers return a pair; only the first element is used.
        hidden_states, _ = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


132
class KimiVLImagePixelInputs(TensorSchema):
    """
    Schema for image inputs given as pre-patched pixel values.

    Dimensions:
        - nc: Number of channels (fixed to 3 in the ``pixel_values`` shape)
        - np: Number of patches
        - ps: Patch size
        - ni: Number of images
    """

    type: Literal["pixel_values"] = "pixel_values"

    # Patches of the images, either as one tensor or as a list of tensors.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("np", 3, "ps", "ps"),
    ]

    # One row of two values per image; the name suggests (height, width) of
    # the patch grid -- TODO confirm against the HF image processor.
    image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)]
149
150
151
152
153
154
155
156
157
158
159


# Alias for the image-input schema(s) the model accepts. Only pixel inputs are
# supported for Kimi-VL at the moment, so it names a single schema.
# TODO: support image embeddings as an input type too.
KimiVLImageInputs = KimiVLImagePixelInputs


class KimiVLProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Kimi-VL: config access, limits, token counts."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as ``KimiVLConfig``."""
        return self.ctx.get_hf_config(KimiVLConfig)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limits; ``None`` is given for images."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Compute how many placeholder tokens an image of this size needs.

        Steps:
            1. If the patch grid ``(width // patch) * (height // patch)``
               exceeds the processor's ``in_token_limit``, both sides are
               scaled down by ``sqrt(limit / patches)`` and truncated to int.
            2. Each side is padded up to a multiple of
               ``merge_kernel * patch_size``.
            3. The token count is the product of the padded sides measured in
               merged patches.

        Args:
            image_width: Image width as an ``int``.
            image_height: Image height as an ``int``.

        Returns:
            The number of image tokens.

        Raises:
            AssertionError: If a dimension is not an ``int`` or the image
                processor has no ``merge_kernel_size``.
        """
        image_processor = self.get_hf_processor().image_processor
        patch_size = image_processor.patch_size
        kernel_size = image_processor.merge_kernel_size
        in_token_limit = image_processor.in_token_limit
        height = image_height
        width = image_width
        assert isinstance(height, int), f"height must be int, current height {height}"
        assert isinstance(width, int), f"width must be int, current width {width}"
        assert kernel_size is not None, "kernel_size must be specified"

        # Shrink the image when its raw patch grid is over the token limit.
        num_patches = (width // patch_size) * (height // patch_size)
        if num_patches > in_token_limit:
            scale = math.sqrt(in_token_limit / num_patches)
            width, height = int(width * scale), int(height * scale)

        kernel_height, kernel_width = kernel_size
        # Edge length covered by one merged patch along each axis.
        merged_height = kernel_height * patch_size
        merged_width = kernel_width * patch_size

        # Padding that rounds each side up to a whole number of merged
        # patches; ``(-x) % m`` is 0 when ``x`` is already a multiple of ``m``.
        pad_height = (-height) % merged_height
        pad_width = (-width) % merged_width

        # Calculate new dimensions after padding and patching
        token_height = (height + pad_height) // merged_height
        token_width = (width + pad_width) // merged_width
        return int(token_height * token_width)

    @property
    def image_token_id(self) -> int:
        """Token id of the image placeholder (``media_placeholder_token_id``)."""
        return self.get_hf_config().media_placeholder_token_id


class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]):
    """Build dummy text and image inputs for Kimi-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per dummy image.

        Args:
            mm_counts: Requested item count per modality; a missing
                ``"image"`` entry counts as zero.
        """
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy multimodal data holding ``num_images`` dummy images.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Requested item count per modality; a missing
                ``"image"`` entry counts as zero.
            mm_options: Per-modality dummy options; the ``"image"`` entry
                (or ``None`` if absent) is forwarded as ``overrides``.
        """
        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image")

        return {
            "image": self._get_dummy_images(
                width=MaxImageTokenMeta.width,
                height=MaxImageTokenMeta.height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
    """Multimodal processor hooks for Kimi-VL image inputs."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how the HF processor outputs are split per image.

        Args:
            hf_inputs: HF processor output; ``image_grid_hws`` is read from it
                and defaults to an empty ``(0, 2)`` tensor when absent.
            hf_processor_mm_kwargs: Not used by this implementation.

        Returns:
            Field configs for ``pixel_values`` and ``image_grid_hws``.
        """
        image_grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2)))
        # Product of each row's two entries gives that image's slice length.
        image_grid_sizes = image_grid_hws.prod(-1)

        # pixel_values is merged as a single large tensor
        # image_grid_hws is shapes for each subtensor in pixel_values
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes
            ),
            image_grid_hws=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with one token per image feature.

        Args:
            mm_items: Parsed multimodal items; only ``"image"`` is consulted.
            hf_processor_mm_kwargs: Not used by this implementation.
            out_mm_kwargs: Not used by this implementation.

        Returns:
            A single ``PromptReplacement`` targeting the image token id.
        """
        image_token_id = self.info.image_token_id

        def get_replacement(item_idx: int):
            # Looked up per call, so the items are fetched only when a
            # replacement is actually requested.
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                # Embedding items report their feature size directly.
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


285
286
287
288
289
290
@MULTIMODAL_REGISTRY.register_processor(
    KimiVLMultiModalProcessor,
    info=KimiVLProcessingInfo,
    dummy_inputs=KimiVLDummyInputsBuilder,
)
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Kimi-VL model: MoonViT vision tower, projector, and language model.

    The language model is created through ``init_vllm_registered_model`` with
    the ``DeepseekV2ForCausalLM`` architecture.
    """

    # NOTE(review): presumably advertises support for the data-parallel
    # encoder mode -- confirm against ``SupportsMultiModal``.
    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for one multimodal item.

        Args:
            modality: Modality name; only names starting with ``"image"``
                are accepted.
            i: Item index; not used by this implementation.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the vision tower, the projector, and the language model.

        Args:
            vllm_config: vLLM config; its ``model_config.hf_config`` and
                ``quant_config`` are read here.
            prefix: Parameter-name prefix for the submodules.
        """
        super().__init__()
        model_config = vllm_config.model_config
        config: KimiVLConfig = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        assert isinstance(config.vision_config, MoonViTConfig)
        # Selects the vision-tower execution path in _process_image_pixels.
        self.use_data_parallel = (
            model_config.multimodal_config.mm_encoder_tp_mode == "data"
        )
        self.hidden_size = config.text_config.hidden_size

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = MoonVitPretrainedModel(
                config.vision_config,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = KimiVLMultiModalProjector(
                config=config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["DeepseekV2ForCausalLM"],
            )

        # Re-export the language model's factory under this module's name.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        self.media_placeholder: int = self.config.media_placeholder_token_id

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> KimiVLImageInputs | None:
        """Build the image-input schema from keyword arguments.

        Returns:
            ``None`` when ``pixel_values`` is missing or ``None``; otherwise a
            ``KimiVLImagePixelInputs`` built from ``pixel_values`` and
            ``image_grid_hws``.
        """
        # image input type must be pixel values now
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_hws = kwargs.pop("image_grid_hws", None)

        if pixel_values is None:
            return None

        return KimiVLImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_grid_hws=image_grid_hws,
        )

    # Run the vision tower on the processed pixel_values.
    @torch.inference_mode()
    def _process_image_pixels(
        self, inputs: KimiVLImagePixelInputs
    ) -> list[torch.Tensor] | tuple[torch.Tensor, ...]:
        """Encode pixel values with the vision tower.

        Returns:
            The vision-tower output; the caller asserts that it is a list or
            tuple and treats each element as one chunk of features.
        """
        pixel_values = inputs["pixel_values"]
        image_grid_hws = inputs["image_grid_hws"]
        if self.use_data_parallel:
            # The sharded helper takes the grid sizes as a plain Python list.
            return run_dp_sharded_mrope_vision_model(
                self.vision_tower,
                pixel_values,
                image_grid_hws.tolist(),
                rope_type="rope_2d",
            )
        else:
            return self.vision_tower(pixel_values, image_grid_hws)

    def _process_image_input(
        self, image_input: KimiVLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode and project image inputs into language-model embeddings.

        Returns:
            A tuple of projected tensors, one per vision-tower output chunk.
        """
        assert image_input["type"] == "pixel_values"
        image_features = self._process_image_pixels(image_input)
        assert isinstance(image_features, (list, tuple))
        # Project all chunks in one batched call, then split the result back
        # into per-chunk tensors using the original row counts.
        lengths = [x.shape[0] for x in image_features]
        return self.multi_modal_projector(torch.cat(image_features)).split(lengths)

    def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
        """Return embeddings for the image inputs in ``kwargs``, if any.

        Returns:
            ``None`` when no pixel values are supplied; otherwise the
            projected vision embeddings.
        """
        # Validate the multimodal input keyword arguments
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        # Run multimodal inputs through encoder and projector
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language model.

        Args:
            input_ids: Token ids forwarded to the language model.
            positions: Position tensor forwarded to the language model.
            intermediate_tensors: When given, ``inputs_embeds`` is discarded
                (presumably a non-first pipeline stage -- confirm).
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Accepted but not used by this method.

        Returns:
            Whatever the wrapped language model returns.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Delegate logit computation to the language model.

        Extra keyword arguments are accepted but not forwarded.
        """
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` weight pairs through ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)