internvl.py 30.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
10
from abc import abstractmethod
11
from collections.abc import Iterable, Mapping, Sequence
12
from typing import Annotated, Literal, TypeAlias, TypeVar
13
14
15

import torch
import torch.nn as nn
16
from transformers import BatchFeature, PretrainedConfig
17

18
from vllm.config import VllmConfig
19
from vllm.config.multimodal import BaseDummyOptions
20
21
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
22
23
24
25
from vllm.model_executor.models.intern_vit import (
    InternVisionModel,
    InternVisionPatchModel,
)
26
from vllm.model_executor.models.module_mapping import MultiModelKeys
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
29
30
31
32
33
34
35
36
37
38
39
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
40
    BaseDummyInputsBuilder,
41
42
43
44
45
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
46
from vllm.sequence import IntermediateTensors
47
48
49
50
from vllm.transformers_utils.processors.internvl import (
    BaseInternVLProcessor,
    InternVLProcessor,
)
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
54
55
56
57
58
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
59
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
60
61


62
class InternVLImagePixelInputs(TensorSchema):
    """
    Image inputs supplied as raw pixel patches.

    Dimensions:
        - bnp: Total number of patches over every image in the batch
          (``pixel_values_flat`` is regrouped per image using ``num_patches``)
        - bn: Batch size * number of images (one ``num_patches`` entry each)
        - h: Height of each image patch
        - w: Width of each image patch

    The channel dimension is fixed to 3.
    """

    type: Literal["pixel_values"]
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
75

76

77
78
79
80
81
82
class InternVLImageEmbeddingInputs(TensorSchema):
    """
    Image inputs supplied as precomputed embeddings (the vision tower is
    skipped for these).

    Dimensions:
        - n: Number of images
        - f: Total image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    type: Literal["image_embeds"]
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
87
88


89
# An image item arrives either as raw pixel patches or as precomputed embeddings.
InternVLImageInputs: TypeAlias = InternVLImagePixelInputs | InternVLImageEmbeddingInputs
90
91


92
class InternVLVideoPixelInputs(TensorSchema):
    """
    Video inputs supplied as raw pixel frames.

    Dimensions:
        - bvf: Total number of frames over every video in the batch
          (``pixel_values_flat`` is regrouped per video using ``num_patches``)
        - bn: Batch size * number of videos (one ``num_patches`` entry per
          video, i.e. its frame count)
        - h: Height of each video frame
        - w: Width of each video frame

    The channel dimension is fixed to 3.
    """

    type: Literal["pixel_values_videos"]
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
105
106


107
108
109
110
111
112
class InternVLVideoEmbeddingInputs(TensorSchema):
    """
    Video inputs supplied as precomputed embeddings (the vision tower is
    skipped for these).

    Dimensions:
        - n: Number of videos
        - f: Total video feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    type: Literal["video_embeds"]
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
117
118


119
# A video item arrives either as raw pixel frames or as precomputed embeddings.
InternVLVideoInputs: TypeAlias = InternVLVideoPixelInputs | InternVLVideoEmbeddingInputs
120
121


122
class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """
    Image-only processing info shared by InternVL-style models.

    Subclasses provide the concrete HF processor through ``get_hf_processor``;
    the token-count helpers here delegate to that processor.
    """

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> BaseInternVLProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # ``None`` means no explicit cap on the number of images.
        limits: dict[str, int | None] = {"image": None}
        return limits

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: BaseInternVLProcessor,
    ) -> int:
        """Return the image token count ``processor`` reports for this size."""
        token_count = processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return token_count

    def get_image_size_with_most_features(self) -> ImageSize:
        """
        Return the image size, among the processor's target aspect ratios,
        that produces the largest image token count.

        Ties keep the first ratio encountered.

        Raises:
            ValueError: if no ratio produces a positive token count.
        """
        processor = self.get_hf_processor()

        tile_size = processor.image_size
        ratios = processor.resolve_target_ratios()

        best_count = 0
        best_size = None
        for ratio_w, ratio_h in ratios:
            cand_width = tile_size * ratio_w
            cand_height = tile_size * ratio_h

            count = self.get_num_image_tokens(
                image_width=cand_width,
                image_height=cand_height,
                processor=processor,
            )
            if count <= best_count:
                continue
            best_count = count
            best_size = ImageSize(width=cand_width, height=cand_height)

        if best_count == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size

    def get_max_image_tokens(self) -> int:
        """Return the token count of the largest-feature image size."""
        processor = self.get_hf_processor()
        largest = self.get_image_size_with_most_features()
        width, height = largest

        return self.get_num_image_tokens(
            image_width=width,
            image_height=height,
            processor=processor,
        )

178
179
180
181

# Processing-info type parameter so the generic dummy-input builder and
# multimodal processor below can be specialized to a concrete info subclass.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


182
183
class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Image-only dummy-input builder shared by InternVL-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One ``<image>`` placeholder per requested dummy image.
        return "".join("<image>" for _ in range(mm_counts.get("image", 0)))

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Create dummy images at the size yielding the most image tokens."""
        largest = self.info.get_image_size_with_most_features()
        width, height = largest

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


211
class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Image-only multimodal processor shared by InternVL-style models."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # Feature placeholders may contain tokens other than the image
        # context token, so the model receives the ID that marks which
        # positions take vision-encoder outputs.
        context_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
        outputs["image_token_id"] = torch.tensor(context_token_id)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        patches_per_image = hf_inputs.get("image_num_patches", torch.empty(0))

        return {
            # Flattened patches are regrouped per image by their patch counts.
            "pixel_values_flat": MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            "image_num_patches": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
            # A single token ID shared by every image item.
            "image_token_id": MultiModalFieldConfig.shared(
                "image", len(patches_per_image)
            ),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        out_mm_data = out_mm_kwargs.get_data()
        if "image_num_patches" in out_mm_data:
            counts_tensor = out_mm_data["image_num_patches"]
            assert isinstance(counts_tensor, torch.Tensor)
            patch_counts = counts_tensor.tolist()
        elif "image_embeds" in out_mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            patch_counts = [None] * len(out_mm_data["image_embeds"])
        else:
            patch_counts = []

        def _image_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )

            item_patches = patch_counts[item_idx]
            assert item_patches is None or isinstance(item_patches, int)

            return hf_processor.get_image_repl(feature_size, item_patches)

        image_update = PromptReplacement(
            modality="image",
            target="<image>",
            replacement=_image_replacement,
        )
        return [image_update]
303
304


305
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """InternVL processing info that adds video handling on top of images."""

    @property
    def supports_video(self):
        # Decided by the HF processor instance.
        processor = self.get_hf_processor()
        return processor.supports_video

    def get_supported_mm_limits(self):
        with_video = self.supports_video
        limits = dict(super().get_supported_mm_limits())
        if with_video:
            limits["video"] = None
        return limits

    def get_video_token(self) -> str | None:
        """Return the video placeholder token for the text backbone, if known."""
        model_type = self.get_hf_config().get_text_config().model_type
        if model_type in ("qwen2", "qwen3", "qwen3_moe"):
            return "<|video_pad|>"
        if model_type == "gpt_oss":
            return "<|reserved_200000|>"
        return None

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """
        Return how many frames each dummy video can have within ``seq_len``
        after reserving room for maximum-size images (never less than 1).
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        processor = self.get_hf_processor()

        image_budget = self.get_max_image_tokens() * num_images
        # Each frame is charged ``processor.num_image_token`` tokens, matching
        # the feature size used for video prompt replacements.
        frame_budget = (seq_len - image_budget) // processor.num_image_token
        frames_per_video = frame_budget // max(num_videos, 1)

        return max(frames_per_video, 1)

    def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
        hf_config = self.get_hf_config()
        tokenizer = self.get_tokenizer()
        video_token = self.get_video_token()

        return self.ctx.init_processor(
            InternVLProcessor,
            config=hf_config,
            tokenizer=tokenizer,
            video_token=video_token,
            **kwargs,
        )


352
class InternVLDummyInputsBuilder(
    BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]
):
    """InternVL dummy-input builder that also produces dummy videos."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        video_text = "<video>" * mm_counts.get("video", 0)
        image_text = super().get_dummy_text(mm_counts)
        return image_text + video_text

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        mm_data = dict(super().get_dummy_mm_data(seq_len, mm_counts, mm_options))

        if not self.info.supports_video:
            return mm_data

        # Dummy frames are square at the vision encoder's native image size.
        frame_size: int = self.info.get_hf_config().vision_config.image_size
        num_frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)

        mm_data["video"] = self._get_dummy_videos(
            width=frame_size,
            height=frame_size,
            num_frames=num_frames,
            num_videos=mm_counts.get("video", 0),
            overrides=mm_options.get("video"),
        )
        return mm_data


class InternVLMultiModalProcessor(
    BaseInternVLMultiModalProcessor[InternVLProcessingInfo]
):
    """InternVL multimodal processor that additionally handles videos."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        if self.info.supports_video:
            video_token_id = hf_processor.video_token_id
            if video_token_id is not None:
                outputs["video_token_id"] = torch.tensor(video_token_id)
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
        merged = dict(image_fields)

        if self.info.supports_video:
            frames_per_video = hf_inputs.get("video_num_patches", torch.empty(0))
            merged.update(
                # Flattened frames are regrouped per video by their counts.
                pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
                    "video", frames_per_video
                ),
                video_num_patches=MultiModalFieldConfig.batched("video"),
                video_token_id=MultiModalFieldConfig.shared(
                    "video", len(frames_per_video)
                ),
            )

        return merged

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        updates = super()._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )

        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        out_mm_data = out_mm_kwargs.get_data()
        if "video_num_patches" in out_mm_data:
            counts_tensor = out_mm_data["video_num_patches"]
            assert isinstance(counts_tensor, torch.Tensor)
            frame_counts = counts_tensor.tolist()
        else:
            frame_counts = []

        def _video_replacement(item_idx: int):
            # Every frame uses the per-patch token count of the processor.
            per_frame_tokens = hf_processor.num_image_token
            item_frames = frame_counts[item_idx]
            assert item_frames is None or isinstance(item_frames, int)

            return hf_processor.get_video_repl(
                per_frame_tokens,
                item_frames,
                video_context_token=hf_processor.video_token,
            )

        if not self.info.supports_video:
            return updates

        video_update = PromptReplacement(
            modality="video",
            target="<video>",
            replacement=_video_replacement,
        )
        return [*updates, video_update]


481
482
483
@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder,
)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """InternVL chat model: vision tower + ``mlp1`` projector + language model.

    Image and video pixel inputs are encoded by ``vision_model``,
    pixel-shuffled, projected by ``mlp1`` to the language model hidden size,
    and handed back as multimodal embeddings.
    """

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality`` (``i`` is unused)."""
        if modality.startswith("image"):
            return "<image>"
        if modality.startswith("video"):
            return "<video>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        # ViT grid tokens per tile, and the tokens left after pixel shuffle.
        self.patch_tokens = (image_size // patch_size) ** 2
        self.num_image_token = int(self.patch_tokens * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        llm_arch_name = config.text_config.architectures[0]
        # Mono variants use InternVisionPatchModel and pass a visual token
        # mask to the language model (see _set_visual_token_mask / forward).
        self.is_mono = llm_arch_name == "InternLM2VEForCausalLM"

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.vision_model = self._init_vision_model(
                config,
                quant_config=quant_config,
                is_mono=self.is_mono,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.mlp1 = self._init_mlp1(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Filled in from the processor-provided token IDs when pixel inputs
        # are parsed.
        self.img_context_token_id = None
        self.video_context_token_id = None

        # Only set for mono models; consumed by the next forward pass.
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _patch_quant_config(
        self, config: PretrainedConfig, quant_config: QuantizationConfig | None
    ):
        """Exclude the vision tower from AWQ quantization when needed.

        The AWQ models from OpenGVLab are missing ``modules_to_not_convert``;
        this adds ``"vision_model"`` back when the text config carries its own
        ``quantization_config`` and the list is empty.
        """
        if isinstance(quant_config, AWQConfig):
            text_config = config.text_config
            llm_quant_config = getattr(text_config, "quantization_config", None)
            if (not quant_config.modules_to_not_convert) and (
                llm_quant_config is not None
            ):
                quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the vision tower.

        Non-mono models get an ``InternVisionModel`` truncated so that the
        layer selected by ``config.select_layer`` (negative values count from
        the end) is the last one. Mono models get ``InternVisionPatchModel``.
        """
        if not is_mono:
            vision_feature_layer = config.select_layer
            if vision_feature_layer < 0:
                num_hidden_layers = (
                    config.vision_config.num_hidden_layers + vision_feature_layer + 1
                )
            else:
                num_hidden_layers = vision_feature_layer + 1

            return InternVisionModel(
                config.vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                prefix=prefix,
            )
        else:
            return InternVisionPatchModel(config.vision_config)

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
        """Build the projector from pixel-shuffled ViT features to the LLM."""
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size
        # pixel_shuffle multiplies the channel dim by (1 / downsample_ratio)^2.
        in_features = vit_hidden_size * int(1 / self.downsample_ratio) ** 2

        return nn.Sequential(
            nn.LayerNorm(in_features),
            nn.Linear(in_features, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels.

        ``x`` is unpacked as ``(n, w, h, c)``. Both spatial dims are scaled by
        ``scale_factor`` and the channel dim is divided by ``scale_factor**2``.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        # ps_version "v1" keeps the transposed layout; other versions swap
        # the two spatial axes back.
        if self.ps_version != "v1":
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode pixel patches and project them into the LLM hidden size.

        The first dim of the result matches the first dim of the vision model
        output; the last dim is the language model hidden size.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        # Drop the leading token (presumably the class token) and keep the
        # square grid of patch tokens.
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> InternVLImageInputs | None:
        """Build typed image inputs from model kwargs; ``None`` if absent.

        For pixel inputs this also records ``image_token_id`` in
        ``self.img_context_token_id``.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        if image_embeds is not None:
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        image_token_id = kwargs["image_token_id"]
        if isinstance(image_token_id, torch.Tensor):
            # The shared token ID may arrive as a (possibly repeated) tensor.
            image_token_id = image_token_id.flatten().unique().item()

        assert isinstance(image_token_id, int)
        self.img_context_token_id = image_token_id

        if pixel_values_flat is not None:
            expected_h = expected_w = self.config.vision_config.image_size
            resolve_bindings = {"h": expected_h, "w": expected_w}

            return InternVLImagePixelInputs(
                type="pixel_values",
                pixel_values_flat=pixel_values_flat,
                num_patches=image_num_patches,
                resolve_bindings=resolve_bindings,
            )

        raise AssertionError("This line should be unreachable.")

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> InternVLVideoInputs | None:
        """Build typed video inputs from model kwargs; ``None`` if absent.

        For pixel inputs this also records ``video_token_id`` in
        ``self.video_context_token_id``.
        """
        pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
        video_num_patches = kwargs.pop("video_num_patches", None)
        # Video embeddings live under their own key; reading "image_embeds"
        # here would misroute image embeddings into the video path.
        video_embeds = kwargs.pop("video_embeds", None)

        if pixel_values_flat_video is None and video_embeds is None:
            return None

        if video_embeds is not None:
            return InternVLVideoEmbeddingInputs(
                type="video_embeds",
                data=video_embeds,
            )

        video_token_id = kwargs["video_token_id"]
        if isinstance(video_token_id, torch.Tensor):
            video_token_id = video_token_id.flatten().unique().item()

        assert isinstance(video_token_id, int)
        self.video_context_token_id = video_token_id

        if pixel_values_flat_video is not None:
            expected_h = expected_w = self.config.vision_config.image_size
            resolve_bindings = {"h": expected_h, "w": expected_w}

            return InternVLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_flat=pixel_values_flat_video,
                num_patches=video_num_patches,
                resolve_bindings=resolve_bindings,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_vision_input(
        self,
        image_input: InternVLImageInputs | InternVLVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Return per-item embeddings flattened to ``(-1, llm_hidden_size)``.

        Precomputed embeddings are returned as given.
        """
        if (
            image_input["type"] == "image_embeds"
            or image_input["type"] == "video_embeds"
        ):
            return image_input["data"]

        image_embeds = self.extract_feature(image_input["pixel_values_flat"])

        num_patches = image_input["num_patches"]

        # Only one image in the current batch
        if len(num_patches) == 1:
            return (image_embeds.view(-1, self.config.text_config.hidden_size),)

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)
        image_feature_sizes = [
            item_patches * feature_size for item_patches in num_patches
        ]
        return image_embeds.split(image_feature_sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video inputs, keyed by modality in kwargs order."""
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values_flat", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if (
                input_key in ("pixel_values_flat_video", "video_embeds")
                and "videos" not in modalities
            ):
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

        return modalities

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """For mono models, mark image-context token positions in ``input_ids``."""
        if self.is_mono:
            assert self.img_context_token_id is not None
            self.visual_token_mask = (input_ids == self.img_context_token_id).reshape(
                -1, 1
            )
        else:
            self.visual_token_mask = None

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode all image/video items in kwargs into embeddings."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_vision_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_vision_input(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed tokens, merging multimodal embeddings when provided."""
        if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
            self._set_visual_token_mask(input_ids)

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language model backbone."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs.update({"visual_token_mask": self.visual_token_mask})
            # The mask belongs to one forward pass; clear it after use.
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping modules this model does not use."""
        # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
        skip_prefixes = [
            "action_embed",
            "temporal_embed",
            "track_embed",
            "track_embed_decoder",
            "box_token",
            "cg_criterion",
            "cg_model",
            "loc_encoder",
            "loc_decoder",
            "sam",
            "temporal_token",
            "track_token",
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="mlp1",
            tower_model="vision_model",
        )

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Convert LLM image tokens to vision-encoder tokens.

        Each patch contributes ``patch_tokens + 1`` encoder tokens (the extra
        leading token is the one dropped in ``extract_feature``).
        """
        if num_image_tokens <= 0 or self.num_image_token <= 0:
            return 0

        num_patches = num_image_tokens // self.num_image_token
        return num_patches * (self.patch_tokens + 1)

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Inverse of ``get_num_mm_encoder_tokens``: encoder -> LLM tokens."""
        if num_vision_tokens <= 0 or self.num_image_token <= 0:
            return 0

        num_patches = num_vision_tokens // (self.patch_tokens + 1)
        return num_patches * self.num_image_token