"tests/vscode:/vscode.git/clone" did not exist on "6c04638214d413dfafa2ab1bf3f16069878e60f9"
internvl.py 32.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
10
from abc import abstractmethod
11
from collections.abc import Iterable, Mapping, Sequence
12
from functools import cached_property
13
from typing import Annotated, Literal, TypeAlias, TypeVar
14
15
16

import torch
import torch.nn as nn
17
from transformers import BatchFeature, PretrainedConfig
18

19
from vllm.config import VllmConfig
20
from vllm.config.multimodal import BaseDummyOptions
21
from vllm.inputs import MultiModalDataDict
22
23
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
24
25
26
27
from vllm.model_executor.models.intern_vit import (
    InternVisionModel,
    InternVisionPatchModel,
)
28
from vllm.model_executor.models.module_mapping import MultiModelKeys
29
from vllm.multimodal import MULTIMODAL_REGISTRY
30
from vllm.multimodal.inputs import (
31
    BatchedTensorInputs,
32
33
34
35
36
37
38
39
40
41
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
42
    BaseDummyInputsBuilder,
43
44
45
46
47
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
48
from vllm.sequence import IntermediateTensors
49
from vllm.transformers_utils.processors.internvl import (
50
    InternVLImageProcessor,
51
    InternVLProcessor,
52
    InternVLVideoProcessor,
53
)
54
from vllm.utils.tensor_schema import TensorSchema, TensorShape
55

56
57
58
59
60
61
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
62
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
63
64


65
class InternVLImagePixelInputs(TensorSchema):
    """Pixel-value inputs for images, with the patches of every image stacked.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Total number of patches over all images (sum of `num_patches`)
        - c: Number of channels (3)
        - h: Height of each image patch
        - w: Width of each image patch
    """

    # Discriminator used to tell pixel inputs apart from embedding inputs.
    type: Literal["pixel_values"]
    # Patches of all images concatenated along the first axis.
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    # Number of rows of `pixel_values_flat` that belong to each image.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
78

79

80
81
82
83
84
85
class InternVLImageEmbeddingInputs(TensorSchema):
    """Precomputed image embeddings, passed through without the vision encoder.

    Dimensions:
        - n: Number of images
        - f: Total image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    # Discriminator used to tell embedding inputs apart from pixel inputs.
    type: Literal["image_embeds"]
    # Either one stacked tensor or a list with one tensor per image.
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
90
91


92
# Image inputs arrive either as raw pixel patches or as precomputed embeddings;
# the `type` field of each schema tells the two apart.
InternVLImageInputs: TypeAlias = InternVLImagePixelInputs | InternVLImageEmbeddingInputs
93
94


95
class InternVLVideoPixelInputs(TensorSchema):
    """Pixel-value inputs for videos, with the frames of every video stacked.

    Dimensions:
        - bvf: Batch size * number of videos * num_frames
        - bn: Batch size * number of videos
        - c: Number of channels (3)
        - h: Height of each video frame
        - w: Width of each video frame
    """

    # Discriminator used to tell video pixel inputs apart from video embeddings.
    type: Literal["pixel_values_videos"]
    # Frames of all videos concatenated along the first axis.
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
    # Number of rows of `pixel_values_flat` that belong to each video
    # (one entry per video, not per image).
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
108
109


110
111
112
113
114
115
class InternVLVideoEmbeddingInputs(TensorSchema):
    """Precomputed video embeddings, passed through without the vision encoder.

    Dimensions:
        - n: Number of videos
        - f: Total video feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    # Discriminator used to tell embedding inputs apart from pixel inputs.
    type: Literal["video_embeds"]
    # Either one stacked tensor or a list with one tensor per video.
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
120
121


122
# Video inputs arrive either as raw pixel frames or as precomputed embeddings;
# the `type` field of each schema tells the two apart.
InternVLVideoInputs: TypeAlias = InternVLVideoPixelInputs | InternVLVideoEmbeddingInputs
123
124


125
class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Image-only processing info shared by InternVL-style models.

    Subclasses provide the concrete processor through `get_hf_processor`; this
    base derives per-image token counts and the worst-case image size from it.
    """

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
        """Build the processor used for tokenization and image preprocessing."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only the image modality is supported here (its limit is left as None)."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: InternVLProcessor,
    ) -> int:
        """Delegate to `processor.get_num_image_tokens` for an image of this size."""
        size_kwargs = {"image_width": image_width, "image_height": image_height}
        return processor.get_num_image_tokens(**size_kwargs)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the candidate image size that produces the most image tokens.

        Each tiling ratio from `resolve_target_ratios()` is scaled by the base
        tile size; on ties, the earliest ratio wins.

        Raises:
            ValueError: if no candidate produces a positive token count.
        """
        hf_processor = self.get_hf_processor()
        tile_size = hf_processor.image_processor.image_size

        best_count = 0
        best_size: ImageSize | None = None
        for ratio_w, ratio_h in hf_processor.resolve_target_ratios():
            width = tile_size * ratio_w
            height = tile_size * ratio_h
            num_tokens = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=hf_processor,
            )
            # Strict comparison keeps the first maximum.
            if num_tokens > best_count:
                best_count = num_tokens
                best_size = ImageSize(width=width, height=height)

        if best_count == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size

    def get_max_image_tokens(self) -> int:
        """Token count of the image size returned by `get_image_size_with_most_features`."""
        hf_processor = self.get_hf_processor()
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=width,
            image_height=height,
            processor=hf_processor,
        )

182
183
184
185

# Type parameter for the processing-info class used by the generic InternVL
# dummy-input builder and multimodal processor base classes below.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


186
187
class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Image-only dummy-input builder shared by InternVL-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one `<image>` placeholder per requested image."""
        return "<image>" * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return the requested number of dummy images at the most expensive size."""
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


215
class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Image-only multimodal processor shared by InternVL-style models."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach the image context token ID.

        The feature placeholders may contain extra tokens, so the model needs
        this ID to select which positions receive vision-encoder outputs.
        """
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        context_token_id = self.info.get_hf_processor(**mm_kwargs).ctx_image_token_id
        outputs["image_token_id"] = torch.tensor(context_token_id)
        return outputs

    def _get_image_fields_config(self, hf_inputs: BatchFeature):
        """Describe how each image-related output is split per image."""
        patches_per_image = hf_inputs.get("image_num_patches", torch.empty(0))
        return {
            # Patches of all images are concatenated; split them back by count.
            "pixel_values_flat": MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            "image_num_patches": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
            # A single token ID shared by every image.
            "image_token_id": MultiModalFieldConfig.shared(
                "image", len(patches_per_image)
            ),
        }

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Only image fields exist for this base processor."""
        return self._get_image_fields_config(hf_inputs)

    def _get_prompt_repl_image(
        self,
        mm_items: MultiModalDataItems,
        hf_processor: InternVLProcessor,
        out_mm_data: BatchedTensorInputs,
    ):
        """Build the replacement that expands each `<image>` into its feature tokens."""
        if "image_num_patches" in out_mm_data:
            counts_tensor = out_mm_data["image_num_patches"]
            assert isinstance(counts_tensor, torch.Tensor)
            patch_counts = counts_tensor.tolist()
        elif "image_embeds" in out_mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            patch_counts = [None] * len(out_mm_data["image_embeds"])
        else:
            patch_counts = []

        def _replacement_for(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(items, ImageEmbeddingItems):
                num_features = items.get_feature_size(item_idx)
            else:
                size = items.get_image_size(item_idx)
                num_features = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )

            num_patches = patch_counts[item_idx]
            assert num_patches is None or isinstance(num_patches, int)
            return hf_processor.get_image_repl(num_patches, num_features=num_features)

        return PromptReplacement(
            modality="image",
            target="<image>",
            replacement=_replacement_for,
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the single image prompt replacement."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_repl = self._get_prompt_repl_image(
            mm_items, hf_processor, out_mm_kwargs.get_data()
        )
        return [image_repl]
318
319


320
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """InternVL processing info that adds video support when the tokenizer allows it."""

    def get_image_processor(self, **kwargs):
        """Build an `InternVLImageProcessor`, filling unset options from the HF config."""
        hf_config = self.get_hf_config()
        options = self.ctx.get_merged_mm_kwargs(kwargs)

        config_defaults = (
            ("image_size", hf_config.vision_config.image_size),
            ("min_dynamic_patch", hf_config.min_dynamic_patch),
            ("max_dynamic_patch", hf_config.max_dynamic_patch),
            ("dynamic_image_size", hf_config.dynamic_image_size),
            ("use_thumbnail", hf_config.use_thumbnail),
        )
        for option_name, fallback in config_defaults:
            options.setdefault(option_name, fallback)

        return InternVLImageProcessor(**options)

    def get_video_processor(self, **kwargs):
        """Build an `InternVLVideoProcessor`, defaulting `image_size` from the HF config."""
        hf_config = self.get_hf_config()
        options = self.ctx.get_merged_mm_kwargs(kwargs)
        options.setdefault("image_size", hf_config.vision_config.image_size)
        return InternVLVideoProcessor(**options)

    @cached_property
    def ctx_video_token(self):
        """Video context token for this text backbone, or None if video is unsupported.

        The backbone's model type must have a known token, and that token must
        also be present in the tokenizer vocabulary.
        """
        model_type = self.get_hf_config().get_text_config().model_type
        token = {
            "qwen2": "<|video_pad|>",
            "qwen3": "<|video_pad|>",
            "qwen3_moe": "<|video_pad|>",
            "gpt_oss": "<|reserved_200000|>",
        }.get(model_type)

        if token is None or token not in self.get_tokenizer().get_vocab():
            return None
        return token

    def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
        """Assemble the InternVL processor (images plus, if supported, videos)."""
        hf_config = self.get_hf_config()
        image_processor = self.get_image_processor(**kwargs)

        # Tokens per tile: (tile size / patch size)^2 patches, scaled by the
        # square of the downsample ratio.
        patches_per_side = (
            image_processor.image_size // hf_config.vision_config.patch_size
        )
        image_seq_length = int(
            patches_per_side**2 * (hf_config.downsample_ratio**2)
        )

        video_token = self.ctx_video_token
        video_processor = None
        if video_token:
            video_processor = self.get_video_processor(**kwargs)

        return InternVLProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=image_processor,
            video_processor=video_processor,
            image_seq_length=image_seq_length,
            ctx_video_token=video_token,
        )

    def get_supported_mm_limits(self):
        """Image limits from the base class, plus an unlimited video entry if supported."""
        supports_video = bool(self.ctx_video_token)
        limits = dict(super().get_supported_mm_limits())
        if supports_video:
            limits["video"] = None
        return limits

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video that fit in `seq_len` after reserving the image budget.

        Every frame is counted as `image_seq_length` tokens; the result is at
        least 1.
        """
        hf_processor = self.get_hf_processor()
        tokens_per_frame = hf_processor.image_seq_length

        image_budget = self.get_max_image_tokens() * mm_counts.get("image", 0)
        num_videos = max(mm_counts.get("video", 0), 1)

        total_frames = (seq_len - image_budget) // tokens_per_frame
        return max(total_frames // num_videos, 1)
407

408

409
class InternVLDummyInputsBuilder(
    BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]
):
    """InternVL dummy-input builder that also emits videos when supported."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Image placeholders from the base class, followed by one `<video>` per video."""
        video_text = "<video>" * mm_counts.get("video", 0)
        return super().get_dummy_text(mm_counts) + video_text

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Dummy images from the base class, plus dummy videos if video is supported."""
        dummy_data = super().get_dummy_mm_data(seq_len, mm_counts, mm_options)
        if not self.info.ctx_video_token:
            return {**dummy_data}

        # Video frames are square, sized like a single vision tile.
        frame_size: int = self.info.get_hf_config().vision_config.image_size
        num_frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)
        dummy_videos = self._get_dummy_videos(
            width=frame_size,
            height=frame_size,
            num_frames=num_frames,
            num_videos=mm_counts.get("video", 0),
            overrides=mm_options.get("video"),
        )
        return {**dummy_data, "video": dummy_videos}


class InternVLMultiModalProcessor(
    BaseInternVLMultiModalProcessor[InternVLProcessingInfo]
):
    """InternVL multimodal processor that adds video inputs on top of images."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the image pipeline, then attach the video context token ID if one exists."""
        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)

        video_token_id = self.info.get_hf_processor(**mm_kwargs).ctx_video_token_id
        if video_token_id is not None:
            outputs["video_token_id"] = torch.tensor(video_token_id)

        return outputs

    def _get_video_fields_config(self, hf_inputs: BatchFeature):
        """Describe how each video-related output is split per video."""
        patches_per_video = hf_inputs.get("video_num_patches", torch.empty(0))
        return {
            "pixel_values_flat_video": MultiModalFieldConfig.flat_from_sizes(
                "video", patches_per_video
            ),
            "video_num_patches": MultiModalFieldConfig.batched("video"),
            "video_token_id": MultiModalFieldConfig.shared(
                "video", len(patches_per_video)
            ),
        }

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Image fields, plus video fields when the model supports video."""
        image_fields = self._get_image_fields_config(hf_inputs)
        if not self.info.ctx_video_token:
            return image_fields
        return {**image_fields, **self._get_video_fields_config(hf_inputs)}

    def _get_prompt_repl_video(
        self,
        mm_items: MultiModalDataItems,
        hf_processor: InternVLProcessor,
        out_mm_data: BatchedTensorInputs,
    ):
        """Build the replacement that expands each `<video>` into its frame tokens."""
        if "video_num_patches" in out_mm_data:
            counts_tensor = out_mm_data["video_num_patches"]
            assert isinstance(counts_tensor, torch.Tensor)
            patch_counts = counts_tensor.tolist()
        else:
            patch_counts = []

        def _replacement_for(item_idx: int):
            num_patches = patch_counts[item_idx]
            assert num_patches is None or isinstance(num_patches, int)
            return hf_processor.get_video_repl(num_patches)

        return PromptReplacement(
            modality="video",
            target="<video>",
            replacement=_replacement_for,
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Image replacement first, then the video replacement when supported."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        out_mm_data = out_mm_kwargs.get_data()

        updates: list[PromptUpdate] = [
            self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data)
        ]
        if self.info.ctx_video_token is not None:
            video_repl = self._get_prompt_repl_video(
                mm_items, hf_processor, out_mm_data
            )
            updates.append(video_repl)

        return updates
537
538


539
540
541
@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
542
543
544
    dummy_inputs=InternVLDummyInputsBuilder,
)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
545
546
    supports_encoder_tp_data = True

547
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Prompt placeholder for a modality (matched by its `image`/`video` prefix)."""
        prefix_to_placeholder = (("image", "<image>"), ("video", "<video>"))
        for prefix, placeholder in prefix_to_placeholder:
            if modality.startswith(prefix):
                return placeholder

        raise ValueError("Only image or video modality is supported")

556
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the MLP projector and the language model."""
        super().__init__()

        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quant_config = vllm_config.quant_config
        mm_config = model_config.multimodal_config

        self.config = hf_config
        self.multimodal_config = mm_config
        self.use_data_parallel = mm_config.mm_encoder_tp_mode == "data"
        self._patch_quant_config(hf_config, quant_config)

        vision_config = hf_config.vision_config
        image_size = hf_config.force_image_size or vision_config.image_size
        self.patch_size = vision_config.patch_size
        self.patch_tokens = (image_size // self.patch_size) ** 2
        self.downsample_ratio = hf_config.downsample_ratio
        self.num_image_token = int(self.patch_tokens * (self.downsample_ratio**2))
        self.ps_version = hf_config.ps_version

        # The "mono" variant (InternLM2VE backbone) uses InternVisionPatchModel.
        self.is_mono = (
            hf_config.text_config.architectures[0] == "InternLM2VEForCausalLM"
        )

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.vision_model = self._init_vision_model(
                hf_config,
                quant_config=quant_config,
                is_mono=self.is_mono,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.mlp1 = self._init_mlp1(hf_config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=hf_config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Context token IDs are recorded later, while parsing multimodal kwargs.
        self.img_context_token_id = None
        self.video_context_token_id = None
        # Only populated for the mono variant (see `_set_visual_token_mask`).
        self.visual_token_mask = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
602

603
604
605
    def _patch_quant_config(
        self, config: PretrainedConfig, quant_config: QuantizationConfig | None
    ) -> None:
        """Exclude the vision tower from AWQ conversion for OpenGVLab AWQ checkpoints.

        The AWQ checkpoints published by OpenGVLab are missing
        `modules_to_not_convert`. When the text config carries its own
        `quantization_config` and no modules are excluded, `"vision_model"` is
        added back. `quant_config` is None for unquantized models (it comes
        straight from `vllm_config.quant_config`), in which case this does nothing.
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config", None)
        if llm_quant_config is None or quant_config.modules_to_not_convert:
            return

        # Defensive: only an empty/missing exclusion list reaches this point;
        # make sure it is a list before appending.
        if quant_config.modules_to_not_convert is None:
            quant_config.modules_to_not_convert = []
        quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the vision encoder.

        Mono models get an `InternVisionPatchModel`. Otherwise an
        `InternVisionModel` is built with only the layers up to and including
        `config.select_layer`; a negative index counts back from the last layer.
        """
        vision_config = config.vision_config
        if is_mono:
            return InternVisionPatchModel(vision_config)

        last_layer = config.select_layer
        if last_layer < 0:
            # Turn a from-the-end index into an absolute layer index.
            last_layer += vision_config.num_hidden_layers

        return InternVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=last_layer + 1,
            prefix=prefix,
        )
641

642
    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
        """Projector from pixel-shuffled ViT features to the LLM hidden size.

        `pixel_shuffle` multiplies the channel width by `int(1 / downsample_ratio) ** 2`,
        so the projector input is the ViT hidden size times that factor.
        """
        channel_factor = int(1 / self.downsample_ratio) ** 2
        in_features = config.vision_config.hidden_size * channel_factor
        out_features = config.text_config.hidden_size

        # Layer order fixes the submodule indices (0..3) used by checkpoint weights.
        layers = [
            nn.LayerNorm(in_features),
            nn.Linear(in_features, out_features),
            nn.GELU(),
            nn.Linear(out_features, out_features),
        ]
        return nn.Sequential(*layers)

655
656
657
658
659
660
    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on an (N, W, H, C) feature map.

        With s = scale_factor, the result is (N, H*s, W*s, C/s^2) for ps_version
        "v1"; for any other version the two spatial axes are swapped back,
        giving (N, W*s, H*s, C/s^2).
        """
        n, w, h, c = x.size()
        new_h = int(h * scale_factor)
        new_w = int(w * scale_factor)

        # (N, W, H, C) -> (N, W, H*s, C/s)
        x = x.view(n, w, new_h, int(c / scale_factor))
        # -> (N, H*s, W, C/s)
        x = x.permute(0, 2, 1, 3).contiguous()
        # -> (N, H*s, W*s, C/s^2)
        x = x.view(n, new_h, new_w, int(c / (scale_factor * scale_factor)))

        if self.ps_version != "v1":
            # Restore the original (W, H) axis order.
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

673
    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode image tiles and project them into the language model's space."""
        encoded = self.vision_model(pixel_values=pixel_values)
        # Drop the first token (presumably CLS) and keep the patch tokens.
        patch_tokens = encoded[:, 1:, :]

        # Patch tokens are assumed to form a square grid.
        side = int(patch_tokens.shape[1] ** 0.5)
        grid = patch_tokens.reshape(patch_tokens.shape[0], side, side, -1)

        shuffled = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)
        tokens = shuffled.reshape(shuffled.shape[0], -1, shuffled.shape[-1])
        return self.mlp1(tokens)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> InternVLImageInputs | None:
        """Pull image inputs out of the model kwargs.

        Precomputed `image_embeds` take priority over `pixel_values_flat`. For
        pixel inputs the image context token ID is also stored on
        `self.img_context_token_id`. Returns None when neither input is present.
        """
        pixels = kwargs.pop("pixel_values_flat", None)
        patch_counts = kwargs.pop("image_num_patches", None)
        embeds = kwargs.pop("image_embeds", None)

        if embeds is not None:
            return InternVLImageEmbeddingInputs(type="image_embeds", data=embeds)
        if pixels is None:
            return None

        token_id = kwargs["image_token_id"]
        if isinstance(token_id, torch.Tensor):
            # Collapse the shared token-ID tensor to a single Python int.
            token_id = token_id.flatten().unique().item()
        assert isinstance(token_id, int)
        self.img_context_token_id = token_id

        if pixels is not None:
            side = self.config.vision_config.image_size
            return InternVLImagePixelInputs(
                type="pixel_values",
                pixel_values_flat=pixels,
                num_patches=patch_counts,
                resolve_bindings={"h": side, "w": side},
            )

        raise AssertionError("This line should be unreachable.")

720
    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> InternVLVideoInputs | None:
        """Pull video inputs out of the model kwargs.

        Precomputed `video_embeds` take priority over `pixel_values_flat_video`.
        For pixel inputs the video context token ID is also stored on
        `self.video_context_token_id`. Returns None when neither input is present.
        """
        pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
        video_num_patches = kwargs.pop("video_num_patches", None)
        # Video embeddings live under "video_embeds". The caller passes the full
        # kwargs (including any image inputs), so reading "image_embeds" here
        # would misroute image embeddings into the video path.
        video_embeds = kwargs.pop("video_embeds", None)

        if pixel_values_flat_video is None and video_embeds is None:
            return None

        if video_embeds is not None:
            return InternVLVideoEmbeddingInputs(
                type="video_embeds",
                data=video_embeds,
            )

        video_token_id = kwargs["video_token_id"]
        if isinstance(video_token_id, torch.Tensor):
            # Collapse the shared token-ID tensor to a single Python int.
            video_token_id = video_token_id.flatten().unique().item()

        assert isinstance(video_token_id, int)
        self.video_context_token_id = video_token_id

        # Here pixel_values_flat_video is guaranteed to be non-None.
        expected_h = expected_w = self.config.vision_config.image_size
        return InternVLVideoPixelInputs(
            type="pixel_values_videos",
            pixel_values_flat=pixel_values_flat_video,
            num_patches=video_num_patches,
            resolve_bindings={"h": expected_h, "w": expected_w},
        )

756
    def _process_vision_input(
        self,
        image_input: InternVLImageInputs | InternVLVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Turn parsed image/video inputs into one embedding tensor per item.

        Precomputed embeddings are returned unchanged. Pixel inputs are encoded,
        flattened to (tokens, hidden_size), and split per item using each
        item's patch count times the number of tokens per patch.
        """
        if image_input["type"] in ("image_embeds", "video_embeds"):
            return image_input["data"]

        features = self.extract_feature(image_input["pixel_values_flat"])
        patches_per_item = image_input["num_patches"]
        hidden_size = self.config.text_config.hidden_size

        # A single item in the batch: nothing to split.
        if len(patches_per_item) == 1:
            return (features.view(-1, hidden_size),)

        tokens_per_patch = features.shape[1]
        flat_features = features.view(-1, hidden_size)
        split_sizes = [count * tokens_per_patch for count in patches_per_item]
        return flat_features.split(split_sizes)
782

783
784
785
786
787
788
    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image and video inputs, keyed by "images" / "videos".

        Modalities are inserted in the order their first trigger key appears in
        `kwargs`, so callers iterating the result see that same order.
        """
        trigger_keys = {
            "pixel_values_flat": "images",
            "image_embeds": "images",
            "pixel_values_flat_video": "videos",
        }
        parsers = {
            "images": self._parse_and_validate_image_input,
            "videos": self._parse_and_validate_video_input,
        }

        modalities = {}
        for input_key in kwargs:
            modality = trigger_keys.get(input_key)
            if modality is None or modality in modalities:
                continue
            modalities[modality] = parsers[modality](**kwargs)

        return modalities

799
    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Store a (num_tokens, 1) mask of image-context positions (mono models only)."""
        if not self.is_mono:
            self.visual_token_mask = None
            return

        assert self.img_context_token_id is not None
        is_context_token = input_ids == self.img_context_token_id
        self.visual_token_mask = is_context_token.reshape(-1, 1)
807

808
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute vision embeddings for every image/video item in ``kwargs``.

        Returns:
            An empty list when no multimodal input key is present. Otherwise a
            tuple of tensors, one per multimodal data item (image or video),
            ordered by modality in the order returned by
            ``_parse_and_validate_multimodal_inputs`` and then by item.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # Accumulate in a list (repeated tuple concatenation is quadratic in
        # the number of items) and freeze into a tuple at the end.
        multimodal_embeddings: list[torch.Tensor] = []

        # NOTE: iterate the dict itself so the modality order established by
        # the parser is preserved.
        for modality, mm_input in modalities.items():
            # NOTE(review): the per-modality parsers appear able to yield None
            # (e.g. a recognised key whose value is None); skip such entries
            # instead of failing inside _process_vision_input. TODO confirm.
            if mm_input is None:
                continue
            if modality in ("images", "videos"):
                multimodal_embeddings.extend(self._process_vision_input(mm_input))

        return tuple(multimodal_embeddings)
830

831
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, delegating to the parent implementation.

        Side effect: when a non-empty ``multimodal_embeddings`` is supplied,
        ``self.visual_token_mask`` is refreshed from ``input_ids`` via
        ``_set_visual_token_mask``; ``forward`` later reads that attribute.

        The parent's multimodal overload is used only when both
        ``multimodal_embeddings`` and ``is_multimodal`` are provided; otherwise
        the text-only overload is called.
        """
        has_mm_items = (
            multimodal_embeddings is not None and len(multimodal_embeddings) > 0
        )
        if has_mm_items:
            self._set_visual_token_mask(input_ids)

        # Two explicit calls so the type checker can match each overload.
        if multimodal_embeddings is not None and is_multimodal is not None:
            return super().embed_input_ids(
                input_ids,
                multimodal_embeddings=multimodal_embeddings,
                is_multimodal=is_multimodal,
            )
        return super().embed_input_ids(input_ids)
850

851
852
    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the wrapped language model's backbone.

        ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
        given. A pending ``self.visual_token_mask`` (only set for
        mono-architecture models) is passed through once and then cleared.
        Extra ``kwargs`` are accepted but not forwarded.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        lm_kwargs = dict(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        # Only mono-architecture models produce a mask; consume it so it is
        # not passed again on a subsequent call.
        pending_mask = self.visual_token_mask
        if pending_mask is not None:
            lm_kwargs["visual_token_mask"] = pending_mask
            self.visual_token_mask = None

        return self.language_model.model(**lm_kwargs)

877
878
879
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
880
    ) -> torch.Tensor | None:
881
        return self.language_model.compute_logits(hidden_states)
882

883
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint weights into this model.

        Parameters whose names start with one of the prefixes below are
        skipped. Returns whatever ``AutoWeightsLoader.load_weights`` reports
        as the set of loaded names.
        """
        # Checkpoints such as OpenGVLab/InternVideo2_5_Chat_8B contain modules
        # that this model does not define; ignore them while loading.
        unused_module_prefixes = [
            "action_embed",
            "temporal_embed",
            "track_embed",
            "track_embed_decoder",
            "box_token",
            "cg_criterion",
            "cg_model",
            "loc_encoder",
            "loc_decoder",
            "sam",
            "temporal_token",
            "track_token",
        ]
        weights_loader = AutoWeightsLoader(
            self, skip_prefixes=unused_module_prefixes
        )
        loaded_names = weights_loader.load_weights(weights)
        return loaded_names
901
902
903
904
905
906
907
908

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the submodule name prefixes of this multimodal model.

        The language model lives under ``language_model``, the vision-to-text
        connector under ``mlp1`` and the vision tower under ``vision_model``.
        """
        prefixes = dict(
            language_model="language_model",
            connector="mlp1",
            tower_model="vision_model",
        )
        return MultiModelKeys.from_string_field(**prefixes)
911
912
913
914
915
916
917
918
919
920
921
922
923
924

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Convert an LLM-side image token count into vision-encoder tokens.

        The count is divided (floor) by ``self.num_image_token`` to get a
        number of patches, each of which costs ``self.patch_tokens + 1``
        encoder tokens. Non-positive inputs or a non-positive
        ``num_image_token`` yield 0.
        """
        tokens_per_patch = self.num_image_token
        if min(num_image_tokens, tokens_per_patch) <= 0:
            return 0
        patch_count = num_image_tokens // tokens_per_patch
        return patch_count * (self.patch_tokens + 1)

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Convert a vision-encoder token count into connector output tokens.

        The count is divided (floor) by ``self.patch_tokens + 1`` to get a
        number of patches, each of which yields ``self.num_image_token``
        tokens. Non-positive inputs or a non-positive ``num_image_token``
        yield 0.
        """
        if num_vision_tokens <= 0 or self.num_image_token <= 0:
            return 0
        encoder_tokens_per_patch = self.patch_tokens + 1
        patch_count = num_vision_tokens // encoder_tokens_per_patch
        return patch_count * self.num_image_token