internvl.py 31.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
10
from abc import abstractmethod
11
from collections.abc import Iterable, Mapping, Sequence
12
from functools import cached_property
13
from typing import Annotated, Literal, TypeAlias, TypeVar
14
15
16

import torch
import torch.nn as nn
17
from transformers import BatchFeature, PretrainedConfig
18

19
from vllm.config import VllmConfig
20
from vllm.config.multimodal import BaseDummyOptions
21
22
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
23
24
25
26
from vllm.model_executor.models.intern_vit import (
    InternVisionModel,
    InternVisionPatchModel,
)
27
from vllm.model_executor.models.module_mapping import MultiModelKeys
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
30
31
32
33
34
35
36
37
38
39
40
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
41
    BaseDummyInputsBuilder,
42
43
44
45
46
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
47
from vllm.sequence import IntermediateTensors
48
from vllm.transformers_utils.processors.internvl import (
49
    InternVLImageProcessor,
50
    InternVLProcessor,
51
    InternVLVideoProcessor,
52
)
53
from vllm.utils.tensor_schema import TensorSchema, TensorShape
54

55
56
57
58
59
60
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
61
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
62
63


64
class InternVLImagePixelInputs(TensorSchema):
    """
    Schema for images supplied as flattened pixel-value patches.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * (1 + num_patches)
        - c: Number of channels (3)
        - h: Height of each image patch
        - w: Width of each image patch
    """

    # Discriminator tag identifying this input variant.
    type: Literal["pixel_values"]
    # Patches of all images, annotated with shape (bnp, 3, h, w).
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    # One entry per image, annotated with shape (bn,).
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
77

78

79
80
81
82
83
84
class InternVLImageEmbeddingInputs(TensorSchema):
    """
    Schema for images supplied as precomputed embeddings.

    Dimensions:
        - n: Number of images
        - f: Total image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    # Discriminator tag identifying this input variant.
    type: Literal["image_embeds"]
    # Either one tensor or a list of tensors, annotated with shape (n, f, h).
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
89
90


91
# Union of the two accepted image input variants: raw pixel patches or
# precomputed embeddings.
InternVLImageInputs: TypeAlias = InternVLImagePixelInputs | InternVLImageEmbeddingInputs
92
93


94
class InternVLVideoPixelInputs(TensorSchema):
    """
    Schema for videos supplied as flattened per-frame pixel values.

    Dimensions:
        - bvf: Batch size * number of videos * num_frames
        - bn: Batch size * number of videos
        - c: Number of channels (3)
        - h: Height of each video frame
        - w: Width of each video frame
    """

    # Discriminator tag identifying this input variant.
    type: Literal["pixel_values_videos"]
    # Frames of all videos, annotated with shape (bvf, 3, h, w).
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
    # One entry per video (the symbolic dimension keeps the name "bn").
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
107
108


109
110
111
112
113
114
class InternVLVideoEmbeddingInputs(TensorSchema):
    """
    Schema for videos supplied as precomputed embeddings.

    Dimensions:
        - n: Number of videos
        - f: Total video feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    # Discriminator tag identifying this input variant.
    type: Literal["video_embeds"]
    # Either one tensor or a list of tensors, annotated with shape (n, f, h).
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
119
120


121
# Union of the two accepted video input variants: raw frame pixels or
# precomputed embeddings.
InternVLVideoInputs: TypeAlias = InternVLVideoPixelInputs | InternVLVideoEmbeddingInputs
122
123


124
class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Image-only processing metadata shared by InternVL-style models."""

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
        """Build the processor; this base implementation always raises."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the modality-to-limit mapping: only "image", mapped to None."""
        limits: dict[str, int | None] = {"image": None}
        return limits

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: InternVLProcessor,
    ) -> int:
        """Ask ``processor`` how many tokens an image of the given size yields."""
        token_count = processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return token_count

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the candidate image size that yields the most image tokens.

        Candidates are the processor's target ratios multiplied by the image
        processor's base ``image_size``. On ties the earliest candidate wins.

        Raises:
            ValueError: If no candidate yields a positive token count.
        """
        processor = self.get_hf_processor()
        tile_size = processor.image_processor.image_size

        best_tokens = 0
        best_size = None
        for width_ratio, height_ratio in processor.resolve_target_ratios():
            candidate_width = tile_size * width_ratio
            candidate_height = tile_size * height_ratio
            tokens = self.get_num_image_tokens(
                image_width=candidate_width,
                image_height=candidate_height,
                processor=processor,
            )
            # Strict comparison: a later candidate must beat the current best.
            if tokens <= best_tokens:
                continue
            best_tokens = tokens
            best_size = ImageSize(width=candidate_width, height=candidate_height)

        if best_size is None or best_tokens == 0:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size

    def get_max_image_tokens(self) -> int:
        """Return the token count of the most feature-rich image size."""
        processor = self.get_hf_processor()
        widest, tallest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=widest,
            image_height=tallest,
            processor=processor,
        )

181
182
183
184

# Type variable for the processing-info class that parameterizes the generic
# dummy-inputs builder and multimodal processor below.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


185
186
class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Image-only dummy-input builder shared by InternVL-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>`` placeholder per requested image."""
        image_count = mm_counts.get("image", 0)
        return image_count * "<image>"

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Build dummy images sized to the most feature-rich image size.

        ``seq_len`` is accepted for interface compatibility and is not used
        here.
        """
        width, height = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)
        overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


214
class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Image-only multimodal processor shared by InternVL-style models."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the parent processor call and attach ``image_token_id``."""
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # The feature placeholders may contain extra tokens, so the model
        # needs the image token ID to pick which positions receive the
        # vision encoder outputs.
        ctx_image_token_id = self.info.get_hf_processor(**mm_kwargs).ctx_image_token_id
        outputs["image_token_id"] = torch.tensor(ctx_image_token_id)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each image-related output field maps onto image items."""
        patch_counts = hf_inputs.get("image_num_patches", torch.empty(0))
        image_count = len(patch_counts)

        return {
            "pixel_values_flat": MultiModalFieldConfig.flat_from_sizes(
                "image", patch_counts
            ),
            "image_num_patches": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
            "image_token_id": MultiModalFieldConfig.shared("image", image_count),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement rule that expands each ``<image>`` target."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        out_mm_data = out_mm_kwargs.get_data()

        patch_counts: list[int | None]
        if "image_num_patches" in out_mm_data:
            raw_counts = out_mm_data["image_num_patches"]
            assert isinstance(raw_counts, torch.Tensor)
            patch_counts = raw_counts.tolist()
        elif "image_embeds" in out_mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            patch_counts = [None] * len(out_mm_data["image_embeds"])
        else:
            patch_counts = []

        def _replace_image(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if not isinstance(items, ImageEmbeddingItems):
                size = items.get_image_size(item_idx)
                num_features = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )
            else:
                num_features = items.get_feature_size(item_idx)

            patches = patch_counts[item_idx]
            assert patches is None or isinstance(patches, int)

            return hf_processor.get_image_repl(patches, num_features=num_features)

        image_update = PromptReplacement(
            modality="image",
            target="<image>",
            replacement=_replace_image,
        )
        return [image_update]
306
307


308
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """InternVL processing metadata with optional video support."""

    def get_image_processor(self, **kwargs):
        """Create the image processor, filling unset options from the HF config."""
        config = self.get_hf_config()
        vision_config = config.vision_config

        merged = self.ctx.get_merged_mm_kwargs(kwargs)
        config_defaults = {
            "image_size": vision_config.image_size,
            "min_dynamic_patch": config.min_dynamic_patch,
            "max_dynamic_patch": config.max_dynamic_patch,
            "dynamic_image_size": config.dynamic_image_size,
            "use_thumbnail": config.use_thumbnail,
        }
        # Caller-supplied values win; the config only fills the gaps.
        for option, value in config_defaults.items():
            merged.setdefault(option, value)

        return InternVLImageProcessor(**merged)

    def get_video_processor(self, **kwargs):
        """Create the video processor, defaulting ``image_size`` from the config."""
        vision_config = self.get_hf_config().vision_config

        merged = self.ctx.get_merged_mm_kwargs(kwargs)
        merged.setdefault("image_size", vision_config.image_size)

        return InternVLVideoProcessor(**merged)

    @cached_property
    def ctx_video_token(self):
        """Video context token for the text backbone, or ``None``.

        ``None`` is returned when the text model type has no entry in the
        table below, or when the mapped token is absent from the tokenizer
        vocabulary.
        """
        token_by_model_type = {
            "qwen2": "<|video_pad|>",
            "qwen3": "<|video_pad|>",
            "qwen3_moe": "<|video_pad|>",
            "gpt_oss": "<|reserved_200000|>",
        }
        text_model_type = self.get_hf_config().get_text_config().model_type

        if text_model_type not in token_by_model_type:
            return None

        candidate = token_by_model_type[text_model_type]
        vocab = self.get_tokenizer().get_vocab()
        return candidate if candidate in vocab else None

    def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
        """Assemble the processor from tokenizer, image and video processors.

        The video processor is only created when a video context token is
        available; otherwise ``None`` is passed for it.
        """
        config = self.get_hf_config()
        vision_config = config.vision_config

        image_processor = self.get_image_processor(**kwargs)
        patches_per_side = image_processor.image_size // vision_config.patch_size
        image_seq_length = int(patches_per_side**2 * (config.downsample_ratio**2))

        ctx_video_token = self.ctx_video_token
        if ctx_video_token:
            video_processor = self.get_video_processor(**kwargs)
        else:
            video_processor = None

        return InternVLProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=image_processor,
            video_processor=video_processor,
            image_seq_length=image_seq_length,
            ctx_video_token=ctx_video_token,
        )

    def get_supported_mm_limits(self):
        """Return the parent's limits plus "video" when a video token exists."""
        video_enabled = bool(self.ctx_video_token)
        limits = dict(super().get_supported_mm_limits())
        if video_enabled:
            limits["video"] = None
        return limits

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Return the per-video frame count that fits in ``seq_len``.

        The maximum image token usage is subtracted from ``seq_len`` first and
        the remainder is divided evenly among the videos. The result is never
        below 1.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        tokens_per_frame = self.get_hf_processor().image_seq_length

        image_budget = self.get_max_image_tokens() * image_count
        frame_budget = (seq_len - image_budget) // tokens_per_frame
        # Guard against division by zero when no videos are requested.
        frames_per_video = frame_budget // max(video_count, 1)

        return max(frames_per_video, 1)
395

396

397
class InternVLDummyInputsBuilder(
    BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]
):
    """InternVL dummy-input builder that also produces dummy videos."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the parent's image placeholders followed by ``<video>`` ones."""
        video_count = mm_counts.get("video", 0)
        image_text = super().get_dummy_text(mm_counts)
        return image_text + video_count * "<video>"

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Build dummy images and, when video is supported, dummy videos.

        Videos are square frames of the vision config's ``image_size``, with
        the frame count taken from ``get_num_frames_with_most_features``.
        """
        mm_data = dict(super().get_dummy_mm_data(seq_len, mm_counts, mm_options))

        if self.info.ctx_video_token:
            frame_side: int = self.info.get_hf_config().vision_config.image_size
            frame_count = self.info.get_num_frames_with_most_features(
                seq_len, mm_counts
            )
            video_count = mm_counts.get("video", 0)
            overrides = mm_options.get("video")

            mm_data["video"] = self._get_dummy_videos(
                width=frame_side,
                height=frame_side,
                num_frames=frame_count,
                num_videos=video_count,
                overrides=overrides,
            )

        return mm_data


class InternVLMultiModalProcessor(
    BaseInternVLMultiModalProcessor[InternVLProcessingInfo]
):
    """InternVL multimodal processor that also handles video inputs."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the parent call and attach ``video_token_id`` when available."""
        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)

        video_token_id = self.info.get_hf_processor(**mm_kwargs).ctx_video_token_id
        if video_token_id is not None:
            outputs["video_token_id"] = torch.tensor(video_token_id)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Extend the parent's image field config with video fields."""
        image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)

        video_fields = {}
        if self.info.ctx_video_token:
            patch_counts = hf_inputs.get("video_num_patches", torch.empty(0))
            video_count = len(patch_counts)
            video_fields = {
                "pixel_values_flat_video": MultiModalFieldConfig.flat_from_sizes(
                    "video", patch_counts
                ),
                "video_num_patches": MultiModalFieldConfig.batched("video"),
                "video_token_id": MultiModalFieldConfig.shared("video", video_count),
            }

        return image_fields | video_fields

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the parent's image updates plus a ``<video>`` replacement.

        When no video context token is available, only the parent's updates
        are returned.
        """
        image_updates = super()._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )
        if self.info.ctx_video_token is None:
            return image_updates

        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        out_mm_data = out_mm_kwargs.get_data()

        patch_counts = []
        if "video_num_patches" in out_mm_data:
            raw_counts = out_mm_data["video_num_patches"]
            assert isinstance(raw_counts, torch.Tensor)
            patch_counts = raw_counts.tolist()

        def _replace_video(item_idx: int):
            patches = patch_counts[item_idx]
            assert patches is None or isinstance(patches, int)

            return hf_processor.get_video_repl(patches)

        video_update = PromptReplacement(
            modality="video",
            target="<video>",
            replacement=_replace_video,
        )
        return [*image_updates, video_update]
518
519


520
521
522
@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
523
524
525
    dummy_inputs=InternVLDummyInputsBuilder,
)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
526
527
    supports_encoder_tp_data = True

528
    @classmethod
529
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
530
531
532
533
534
535
536
        if modality.startswith("image"):
            return "<image>"
        if modality.startswith("video"):
            return "<video>"

        raise ValueError("Only image or video modality is supported")

537
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Supplies the HF config, quantization config and
                multimodal config read below.
            prefix: Name prefix passed through ``maybe_prefix`` for the
                vision and language sub-modules.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        # Must run before the vision model is built so it sees the patched
        # quantization config.
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        self.patch_size = config.vision_config.patch_size
        self.patch_tokens = (image_size // self.patch_size) ** 2
        self.num_image_token = int(self.patch_tokens * (config.downsample_ratio**2))
        # ``_init_mlp1`` reads ``self.downsample_ratio``, so set it first.
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.is_mono = config.text_config.architectures[0] == "InternLM2VEForCausalLM"

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.vision_model = self._init_vision_model(
                config,
                quant_config=quant_config,
                is_mono=self.is_mono,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.mlp1 = self._init_mlp1(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Filled in later by the input parsing / embedding methods.
        self.img_context_token_id = None
        self.video_context_token_id = None
        self.visual_token_mask = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
583

584
585
586
    def _patch_quant_config(
        self, config: PretrainedConfig, quant_config: QuantizationConfig
    ):
        """Add "vision_model" to an AWQ config's ``modules_to_not_convert``.

        The AWQ checkpoints from OpenGVLab omit ``modules_to_not_convert``, so
        it is restored here. Nothing happens unless ``quant_config`` is an
        ``AWQConfig``, its ``modules_to_not_convert`` is empty, and the text
        config carries a ``quantization_config``.
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config", None)
        if quant_config.modules_to_not_convert or llm_quant_config is None:
            return

        quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the vision encoder.

        For mono models an ``InternVisionPatchModel`` is returned. Otherwise
        an ``InternVisionModel`` is built with its depth cut at
        ``config.select_layer``; a negative ``select_layer`` counts back from
        the vision config's ``num_hidden_layers``.
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        select_layer = config.select_layer
        layer_count = select_layer + 1
        if select_layer < 0:
            layer_count += config.vision_config.num_hidden_layers

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=layer_count,
            prefix=prefix,
        )
622

623
    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
624
625
626
627
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        return nn.Sequential(
628
629
630
631
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(
                vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
            ),
632
633
634
635
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

636
637
638
639
640
641
    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
642
643
644
645
646
647
648
        x = x.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        if self.ps_version == "v1":
649
650
651
652
653
            pass
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

654
    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
655
656
657
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

658
        h = w = int(vit_embeds.shape[1] ** 0.5)
659
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
660
661
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
662
663
664
665
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _parse_and_validate_image_input(
666
        self, **kwargs: object
667
    ) -> InternVLImageInputs | None:
668
669
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
670
        image_embeds = kwargs.pop("image_embeds", None)
671

672
        if pixel_values_flat is None and image_embeds is None:
673
674
            return None

675
676
677
        if image_embeds is not None:
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
678
                data=image_embeds,
679
680
            )

681
        image_token_id = kwargs["image_token_id"]
682
683
684
685
686
        if isinstance(image_token_id, torch.Tensor):
            image_token_id = image_token_id.flatten().unique().item()

        assert isinstance(image_token_id, int)
        self.img_context_token_id = image_token_id
687

688
        if pixel_values_flat is not None:
689
690
            expected_h = expected_w = self.config.vision_config.image_size
            resolve_bindings = {"h": expected_h, "w": expected_w}
691

692
693
            return InternVLImagePixelInputs(
                type="pixel_values",
694
                pixel_values_flat=pixel_values_flat,
695
                num_patches=image_num_patches,
696
                resolve_bindings=resolve_bindings,
697
            )
698
699
700

        raise AssertionError("This line should be unreachable.")

701
    def _parse_and_validate_video_input(
702
        self, **kwargs: object
703
    ) -> InternVLVideoPixelInputs | None:
704
705
706
707
708
709
710
711
        pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
        video_num_patches = kwargs.pop("video_num_patches", None)
        video_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat_video is None and video_embeds is None:
            return None

        if video_embeds is not None:
712
            return InternVLVideoEmbeddingInputs(
713
                type="video_embeds",
714
                data=video_embeds,
715
716
717
            )

        video_token_id = kwargs["video_token_id"]
718
719
720
721
722
        if isinstance(video_token_id, torch.Tensor):
            video_token_id = video_token_id.flatten().unique().item()

        assert isinstance(video_token_id, int)
        self.video_context_token_id = video_token_id
723
724

        if pixel_values_flat_video is not None:
725
726
            expected_h = expected_w = self.config.vision_config.image_size
            resolve_bindings = {"h": expected_h, "w": expected_w}
727
728
729

            return InternVLVideoPixelInputs(
                type="pixel_values_videos",
730
                pixel_values_flat=pixel_values_flat_video,
731
                num_patches=video_num_patches,
732
                resolve_bindings=resolve_bindings,
733
734
735
736
            )

        raise AssertionError("This line should be unreachable.")

737
    def _process_vision_input(
738
        self,
739
        image_input: InternVLImageInputs | InternVLVideoInputs,
740
    ) -> tuple[torch.Tensor, ...]:
741
742
743
744
        if (
            image_input["type"] == "image_embeds"
            or image_input["type"] == "video_embeds"
        ):
745
746
            return image_input["data"]

747
        image_embeds = self.extract_feature(image_input["pixel_values_flat"])
748

749
        num_patches = image_input["num_patches"]
750
751

        # Only one image in the current batch
752
        if len(num_patches) == 1:
753
            return (image_embeds.view(-1, self.config.text_config.hidden_size),)
754
755
756
757

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
758
        image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)
759
        image_feature_sizes = [
760
            num_patches * feature_size for num_patches in num_patches
761
        ]
762
        return image_embeds.split(image_feature_sizes)
763

764
765
766
767
768
769
    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
770
771
772
773
774
775
776
            if (
                input_key in ("pixel_values_flat", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
            if input_key in ("pixel_values_flat_video",) and "videos" not in modalities:
                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
777
778
779

        return modalities

780
    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
781
        if self.is_mono:
782
            assert self.img_context_token_id is not None
783
784
785
            self.visual_token_mask = (input_ids == self.img_context_token_id).reshape(
                -1, 1
            )
786
        else:
787
            self.visual_token_mask = None
788

789
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute embeddings for every multimodal item found in ``kwargs``.

        Args:
            **kwargs: Raw multimodal keyword inputs, forwarded to
                ``_parse_and_validate_multimodal_inputs``.

        Returns:
            An empty list when no modality is present. Otherwise a tuple
            holding one entry per item yielded by ``_process_vision_input``,
            laid out in the order the modalities appear in the parsed
            inputs.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # Collect one embedding per multimodal data item (image or video).
        collected = []

        # NOTE: Walking the dict in insertion order keeps the embeddings in
        # the same order as the modalities themselves.
        for modality, vision_input in modalities.items():
            if modality in ("images", "videos"):
                collected.extend(self._process_vision_input(vision_input))

        return tuple(collected)
811

812
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` via the parent class implementation.

        Before delegating, the visual token mask is refreshed through
        ``_set_visual_token_mask`` whenever non-empty multimodal embeddings
        are supplied.

        Args:
            input_ids: Token ids to embed.
            multimodal_embeddings: Optional multimodal embeddings passed on
                to the parent implementation.
            is_multimodal: Optional tensor passed on to the parent
                implementation together with ``multimodal_embeddings``.

        Returns:
            The result of the parent ``embed_input_ids``. The multimodal
            arguments are forwarded only when both are provided.
        """
        # len() rather than truthiness: the embeddings may be a tensor.
        has_embeddings = (
            multimodal_embeddings is not None and len(multimodal_embeddings) > 0
        )
        if has_embeddings:
            self._set_visual_token_mask(input_ids)

        # Branching on both arguments satisfies the type checker for each
        # overload of the parent method.
        if multimodal_embeddings is not None and is_multimodal is not None:
            return super().embed_input_ids(
                input_ids,
                multimodal_embeddings=multimodal_embeddings,
                is_multimodal=is_multimodal,
            )

        return super().embed_input_ids(input_ids)
831

832
833
    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run ``self.language_model.model`` on the given inputs.

        Args:
            input_ids: Token ids, or ``None``.
            positions: Position tensor passed through to the language model.
            intermediate_tensors: When provided, ``inputs_embeds`` is
                discarded and ``None`` is passed in its place.
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Accepted but not used by this method.

        Returns:
            Whatever ``self.language_model.model`` returns.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Only required if the model is mono-architecture: hand over the
        # stored mask once, then clear it so it is not reused.
        mono_kwargs = {}
        pending_mask = self.visual_token_mask
        if pending_mask is not None:
            mono_kwargs["visual_token_mask"] = pending_mask
            self.visual_token_mask = None

        return self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **mono_kwargs,
        )

858
859
860
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
861
    ) -> torch.Tensor | None:
862
        return self.language_model.compute_logits(hidden_states)
863

864
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model via ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``
            (annotated as a set of names).
        """
        # Modules present in OpenGVLab/InternVideo2_5_Chat_8B checkpoints
        # that this model does not use.
        unused_module_prefixes = [
            "action_embed",
            "temporal_embed",
            "track_embed",
            "track_embed_decoder",
            "box_token",
            "cg_criterion",
            "cg_model",
            "loc_encoder",
            "loc_decoder",
            "sam",
            "temporal_token",
            "track_token",
        ]
        return AutoWeightsLoader(
            self, skip_prefixes=unused_module_prefixes
        ).load_weights(weights)
882
883
884
885
886
887
888
889

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module prefixes of this multimodal model's components.

        Returns:
            A ``MultiModelKeys`` built from the language model, connector
            and vision tower prefixes.
        """
        component_prefixes = {
            "language_model": "language_model",
            "connector": "mlp1",
            "tower_model": "vision_model",
        }
        return MultiModelKeys.from_string_field(**component_prefixes)
892
893
894
895
896
897
898
899
900
901
902
903
904
905

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Convert an image-token count into a vision-encoder token count.

        The count is floor-divided by ``self.num_image_token`` to get a
        number of whole patches, and each patch accounts for
        ``self.patch_tokens + 1`` encoder tokens (the extra one is
        presumably a class token; confirm against the vision model).

        Args:
            num_image_tokens: Number of image tokens.

        Returns:
            The encoder token count, or 0 when either ``num_image_tokens``
            or ``self.num_image_token`` is not positive.
        """
        if num_image_tokens > 0 and self.num_image_token > 0:
            whole_patches = num_image_tokens // self.num_image_token
            tokens_per_patch = self.patch_tokens + 1
            return whole_patches * tokens_per_patch

        return 0

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Convert a vision-encoder token count into a connector token count.

        The count is floor-divided by ``self.patch_tokens + 1`` to get a
        number of whole patches, and each patch accounts for
        ``self.num_image_token`` connector tokens.

        Args:
            num_vision_tokens: Number of vision-encoder tokens.

        Returns:
            The connector token count, or 0 when either
            ``num_vision_tokens`` or ``self.num_image_token`` is not
            positive.
        """
        if num_vision_tokens > 0 and self.num_image_token > 0:
            tokens_per_patch = self.patch_tokens + 1
            whole_patches = num_vision_tokens // tokens_per_patch
            return whole_patches * self.num_image_token

        return 0