# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import copy
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Annotated, Any, Literal

import torch
from torch import nn
from transformers import BatchFeature, DeepseekV2Config
from transformers.activations import GELUActivation

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
from .vision import run_dp_sharded_mrope_vision_model


# For dummy input only: fixed image size used by
# KimiVLDummyInputsBuilder.get_dummy_mm_data when it builds dummy images.
@dataclass
class MaxImageTokenMeta:
    # Dummy image width (forwarded as `width` to `_get_dummy_images`).
    width: int = 1024
    # Dummy image height (forwarded as `height` to `_get_dummy_images`).
    height: int = 1024


class KimiVLMultiModalProjector(nn.Module):
    """Project vision-tower features into the text model's hidden size.

    ``forward`` applies a LayerNorm over ``vision_config.hidden_size``. It
    then reshapes the result into rows of ``vision_config.hidden_size *
    merge_kernel_size[0] * merge_kernel_size[1]`` values and maps each row
    through linear -> GELU -> linear to ``text_config.hidden_size``.

    NOTE(review): the reshape assumes the vision tower already lays out each
    merge window's vectors contiguously; confirm against MoonViT.
    """

    def __init__(
        self, config: KimiVLConfig, use_data_parallel: bool = False, prefix: str = ""
    ):
        super().__init__()
        self.use_data_parallel = use_data_parallel

        vision_cfg = config.vision_config
        merge_kernel = vision_cfg.merge_kernel_size
        # Width of one flattened merge window (vision dim x kernel area).
        self.hidden_size = vision_cfg.hidden_size * merge_kernel[0] * merge_kernel[1]

        self.pre_norm = torch.nn.LayerNorm(vision_cfg.hidden_size, eps=1e-5)
        self.linear_1 = ReplicatedLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_1"),
        )
        self.linear_2 = ReplicatedLinear(
            self.hidden_size,
            config.text_config.hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_2"),
        )
        self.act = GELUActivation()

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Return projected rows of width ``text_config.hidden_size``."""
        normed = self.pre_norm(image_features)
        rows = normed.view(-1, self.hidden_size)
        # ReplicatedLinear returns (output, bias); only the output is used.
        projected, _ = self.linear_1(rows)
        projected, _ = self.linear_2(self.act(projected))
        return projected


class KimiVLImagePixelInputs(TensorSchema):
    """
    Pre-processed image patches together with each image's patch grid.

    Dimensions:
        - np: Number of patches (summed over all images when pixel_values
          is a single tensor)
        - ps: Patch size (patches are square, ps x ps)
        - ni: Number of images

    The channel dimension of pixel_values is fixed at 3.
    """

    type: Literal["pixel_values"] = "pixel_values"

    # Patch pixels: one tensor for all images, or a list of tensors.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("np", 3, "ps", "ps"),
    ]

    # One (h, w) patch-grid row per image; h * w is that image's patch count.
    image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)]


# Type of the image inputs accepted by KimiVLForConditionalGeneration.
# TODO: support embeds too. Only pixel-value input is supported for kimi-vl
# now, so this simply aliases the pixel-value schema.
KimiVLImageInputs = KimiVLImagePixelInputs


class KimiVLProcessingInfo(BaseProcessingInfo):
    """Config access, modality limits and image token counts for Kimi-VL."""

    def get_hf_config(self):
        """Return the model's HF config, resolved as ``KimiVLConfig``."""
        return self.ctx.get_hf_config(KimiVLConfig)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only the image modality is supported; its limit is ``None``."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens an image expands to.

        If the image's patch grid exceeds the processor's ``in_token_limit``,
        both sides are scaled by the same factor (truncated to ``int``). Each
        side is then padded up to a multiple of ``merge_kernel_size *
        patch_size``, and one token is counted per merge window.
        """
        image_processor = self.get_hf_processor().image_processor
        patch = image_processor.patch_size
        kernel_size = image_processor.merge_kernel_size
        token_limit = image_processor.in_token_limit
        width, height = image_width, image_height

        assert isinstance(height, int), f"height must be int, current height {height}"
        assert isinstance(width, int), f"width must be int, current width {width}"
        assert kernel_size is not None, "kernel_size must be specified"

        num_patches = (width // patch) * (height // patch)
        if num_patches > token_limit:
            # Shrink both sides uniformly so the patch grid fits the limit.
            ratio = math.sqrt(token_limit / num_patches)
            width, height = int(width * ratio), int(height * ratio)

        kernel_height, kernel_width = kernel_size
        window_h = kernel_height * patch
        window_w = kernel_width * patch
        # Pad each side up to a whole number of merge windows.
        padded_h = height + (window_h - height % window_h) % window_h
        padded_w = width + (window_w - width % window_w) % window_w
        return int((padded_h // window_h) * (padded_w // window_w))

    @property
    def image_token_id(self) -> int:
        """Id of the media placeholder token, read from the HF config."""
        hf_config = self.get_hf_config()
        return hf_config.media_placeholder_token_id


class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]):
    """Builds dummy prompt text and images for Kimi-VL profiling runs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image."""
        image_token = self.info.get_hf_processor().image_token
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images sized per MaxImageTokenMeta.

        Any ``"image"`` entry of ``mm_options`` is forwarded as overrides.
        ``seq_len`` is part of the interface but is not used here.
        """
        if mm_options:
            image_overrides = mm_options.get("image")
        else:
            image_overrides = None

        dummy_images = self._get_dummy_images(
            width=MaxImageTokenMeta.width,
            height=MaxImageTokenMeta.height,
            num_images=mm_counts.get("image", 0),
            overrides=image_overrides,
        )
        return {"image": dummy_images}


class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
    """Splits HF processor outputs per image and expands image placeholders."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each HF output field maps to individual images.

        ``pixel_values`` is one flat tensor holding every image's patches.
        ``image_grid_hws`` holds each image's (h, w) patch grid, so the
        row-wise products give the per-image slice sizes.
        """
        grid_hws = hf_inputs.get("image_grid_hws", torch.empty((0, 2)))
        patches_per_image = grid_hws.prod(-1)
        return {
            "pixel_values": MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            "image_grid_hws": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with N copies of itself.

        For embedding inputs N is the item's feature size. For pixel inputs
        it is ``get_num_image_tokens`` of the image's width and height.
        """
        placeholder_id = self.info.image_token_id

        def _expand(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )
            if isinstance(items, ImageEmbeddingItems):
                count = items.get_feature_size(item_idx)
            else:
                size = items.get_image_size(item_idx)
                count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            return [placeholder_id] * count

        image_replacement = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=_expand,
        )
        return [image_replacement]


@MULTIMODAL_REGISTRY.register_processor(
    KimiVLMultiModalProcessor,
    info=KimiVLProcessingInfo,
    dummy_inputs=KimiVLDummyInputsBuilder,
)
class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Kimi-VL: MoonViT vision tower + projector + DeepSeek-V2 language model.

    Images are encoded by ``vision_tower`` and mapped to the text hidden size
    by ``multi_modal_projector``; ``embed_multimodal`` returns the result.
    Text runs through ``language_model``; logits come from ``lm_head``, which
    exists only on the last pipeline-parallel rank.
    """

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for one image.

        Raises:
            ValueError: if ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        super().__init__()
        model_config = vllm_config.model_config
        config: KimiVLConfig = model_config.hf_config
        self.config = config
        quant_config = vllm_config.quant_config

        assert isinstance(config.vision_config, MoonViTConfig)
        # In "data" encoder-TP mode, _process_image_pixels runs the tower via
        # run_dp_sharded_mrope_vision_model instead of calling it directly.
        self.use_data_parallel = (
            model_config.multimodal_config.mm_encoder_tp_mode == "data"
        )
        self.hidden_size = config.text_config.hidden_size
        self.vision_tower = MoonVitPretrainedModel(
            config.vision_config,
            self.use_data_parallel,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )

        self.multi_modal_projector = KimiVLMultiModalProjector(
            config=config,
            use_data_parallel=self.use_data_parallel,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )

        self.quant_config = quant_config
        # Build the language model from a copy of the config whose hf_config
        # is the text sub-config, leaving the caller's config untouched.
        sub_vllm_config = copy.deepcopy(vllm_config)
        sub_vllm_config.model_config.hf_config = (
            sub_vllm_config.model_config.hf_config.text_config
        )
        self.language_model = DeepseekV2Model(
            vllm_config=sub_vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.text_config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            # Non-final pipeline stages hold no LM head.
            self.lm_head = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
        self.media_placeholder: int = self.config.media_placeholder_token_id

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> KimiVLImageInputs | None:
        """Build pixel inputs from ``kwargs``; ``None`` if no pixel_values."""
        # image input type must be pixel values now
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_hws = kwargs.pop("image_grid_hws", None)

        if pixel_values is None:
            return None

        return KimiVLImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_grid_hws=image_grid_hws,
        )

    # Run the vision tower (vt) on processed pixel_values.
    @torch.inference_mode()
    def _process_image_pixels(
        self, inputs: KimiVLImagePixelInputs
    ) -> Sequence[torch.Tensor]:
        """Encode pixel inputs; returns per-image feature tensors.

        The caller (_process_image_input) asserts the result is a list or
        tuple, so the return type is a sequence rather than a single tensor.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]
        image_grid_hws = inputs["image_grid_hws"]
        if self.use_data_parallel:
            return run_dp_sharded_mrope_vision_model(
                self.vision_tower,
                pixel_values,
                image_grid_hws.tolist(),
                rope_type="rope_2d",
            )
        return self.vision_tower(pixel_values, image_grid_hws)

    def _process_image_input(
        self, image_input: KimiVLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode and project images; returns one embedding tensor per image."""
        assert image_input["type"] == "pixel_values"
        image_features = self._process_image_pixels(image_input)
        assert isinstance(image_features, (list, tuple))
        # Project all images in one batch, then split back per image.
        lengths = [x.shape[0] for x in image_features]
        return self.multi_modal_projector(torch.cat(image_features)).split(lengths)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
        """Return per-image embeddings, or ``None`` when no images are given."""
        # Validate the multimodal input keyword arguments
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        # Run multimodal inputs through encoder and projector
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model.

        On non-first pipeline stages (``intermediate_tensors`` given),
        ``inputs_embeds`` is ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply ``lm_head`` through the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states, **kwargs)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this model.

        ``weights`` yields ``(name, tensor)`` or ``(name, tensor, kwargs)``
        tuples. The optional kwargs dict is forwarded to every weight loader.

        Names are first rewritten ``language_model.lm_head`` -> ``lm_head``
        and ``language_model.model`` -> ``language_model``. Then:

        * names containing ``"vision"`` use the default loading path;
        * gate/up (and q/k/v when MHA is used) projections are loaded as
          shards of their stacked parameter;
        * routed-expert weights are loaded into the fused MoE parameters;
        * anything else uses the parameter's ``weight_loader`` (or
          ``default_weight_loader``).

        Rotary caches, spec-decode layers, extra GPTQ biases and parameters
        that belong to other pipeline stages are skipped.
        """
        config = self.config.text_config
        _KEYS_TO_MODIFY_MAPPING = {
            "language_model.lm_head": "lm_head",
            "language_model.model": "language_model",
        }
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        use_mha = (
            config.model_type == "deepseek"
            or config.qk_nope_head_dim + config.qk_rope_head_dim == 0
        )
        if use_mha:
            stacked_params_mapping += [
                (".qkv_proj", ".q_proj", "q"),
                (".qkv_proj", ".k_proj", "k"),
                (".qkv_proj", ".v_proj", "v"),
            ]
        if getattr(config, "n_routed_experts", None):
            # Params for weights, fp8 weight scales, fp8 activation scales
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=config.n_routed_experts,
            )
        else:
            expert_params_mapping = []

        params_dict = dict(self.named_parameters())
        for args in weights:
            name, loaded_weight = args[:2]
            kwargs = args[2] if len(args) > 2 else {}
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    # We have mlp.experts[0].gate_proj in the checkpoint.
                    # Since we handle the experts below in expert_params_mapping,
                    # we need to skip here BEFORE we update the name, otherwise
                    # name will be updated to mlp.experts[0].gate_up_proj, which
                    # will then be updated below in expert_params_mapping
                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                    if ("mlp.experts." in name) and name not in params_dict:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id, **kwargs)
                    break
                else:
                    for (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                    ) in expert_params_mapping:
                        if weight_name not in name:
                            continue
                        # Map into a separate local so a skipped mapping does
                        # not leak a rewritten name into later iterations or
                        # into the default-loading path below.
                        name_mapped = name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name_mapped, self):
                            continue

                        param = params_dict[name_mapped]
                        weight_loader = param.weight_loader
                        weight_loader(
                            param,
                            loaded_weight,
                            name_mapped,
                            expert_id=expert_id,
                            shard_id=shard_id,
                            **kwargs,
                        )
                        break
                    else:
                        use_default_weight_loading = True
            if use_default_weight_loading:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, **kwargs)


566
567
def get_spec_layer_idx_from_weight_name(
    config: DeepseekV2Config, weight_name: str
568
) -> int | None:
569
570
571
    if hasattr(config, "num_nextn_predict_layers") and (
        config.num_nextn_predict_layers > 0
    ):
572
573
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
574
            if weight_name.startswith(f"model.layers.{layer_idx + i}."):
575
576
                return layer_idx + i
    return None