internvl.py 48.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
10
from abc import ABC, abstractmethod
11
from collections.abc import Iterable, Mapping, Sequence
12
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
13

14
import numpy.typing as npt
15
16
17
18
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
19
from transformers import BatchFeature, PretrainedConfig, TensorType
20

21
from vllm.config import VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
24
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
25
26
27
28
from vllm.model_executor.models.intern_vit import (
    InternVisionModel,
    InternVisionPatchModel,
)
29
from vllm.model_executor.models.module_mapping import MultiModelKeys
30
from vllm.multimodal import MULTIMODAL_REGISTRY
31
from vllm.multimodal.image import convert_image_mode
32
33
34
35
36
37
38
39
40
41
42
43
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
44
    BaseDummyInputsBuilder,
45
46
47
48
49
50
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
51
from vllm.sequence import IntermediateTensors
52
from vllm.tokenizers import TokenizerLike
53
from vllm.utils.tensor_schema import TensorSchema, TensorShape
54

55
56
57
58
59
60
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
61
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
62

63
64
65
# Prompt markers used by the InternVL processors: each image tile group (or
# video frame) is expanded to IMG_START + IMG_CONTEXT * n + IMG_END, and the
# token id of IMG_CONTEXT selects the positions filled with vision features.
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<IMG_CONTEXT>"
66
67
68
69
70

# Per-channel (R, G, B) normalization statistics applied by build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


71
class InternVLImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for a batch of images that were split into tiles.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Total number of tiles over all images, i.e. the sum of
          ``num_patches`` (each image's count already includes its
          thumbnail tile, when one is added)
        - c: Number of channels (3)
        - h: Height of each image patch
        - w: Width of each image patch
    """

    type: Literal["pixel_values"]
    # All tiles of all images, concatenated along the first dimension.
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    # Number of tiles contributed by each image (used to split the above).
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
84

85

86
87
88
89
90
91
class InternVLImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings supplied instead of pixel values.

    Dimensions:
        - n: Number of images
        - f: Total image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    type: Literal["image_embeds"]
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
96
97


98
# Image input accepted by the model: either tiled pixel values or
# precomputed image embeddings.
InternVLImageInputs: TypeAlias = InternVLImagePixelInputs | InternVLImageEmbeddingInputs
99
100


101
class InternVLVideoPixelInputs(TensorSchema):
    """
    Pixel-value inputs for a batch of videos, one tile per frame.

    Dimensions:
        - bvf: Batch size * number of videos * num_frames
        - bn: Batch size * number of videos (one frame count per video)
        - c: Number of channels (3)
        - h: Height of each video frame
        - w: Width of each video frame
    """

    type: Literal["pixel_values_videos"]
    # Frames of all videos, concatenated along the first dimension.
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
    # Number of frames contributed by each video.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
114
115


116
117
118
119
120
121
class InternVLVideoEmbeddingInputs(TensorSchema):
    """
    Precomputed video embeddings supplied instead of pixel values.

    Dimensions:
        - n: Number of videos
        - f: Total video feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    type: Literal["video_embeds"]
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
126
127


128
# Video input accepted by the model: either per-frame pixel values or
# precomputed video embeddings.
InternVLVideoInputs: TypeAlias = InternVLVideoPixelInputs | InternVLVideoEmbeddingInputs
129
130


131
132
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
    """Build the preprocessing pipeline applied to every image tile.

    The pipeline converts the image to RGB, resizes it to a square
    ``input_size`` x ``input_size`` tile with bicubic interpolation, turns it
    into a tensor and normalizes it with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Side length, in pixels, of the square output tile.

    Returns:
        A ``torchvision.transforms.Compose`` callable that maps a PIL image to
        a normalized ``(3, input_size, input_size)`` tensor.
    """
    return T.Compose(
        [
            T.Lambda(lambda img: convert_image_mode(img, "RGB")),
            T.Resize(
                (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC
            ),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
145
146


147
148
149
150
151
152
153
154
155
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the tile grid whose aspect ratio is closest to ``aspect_ratio``.

    Args:
        aspect_ratio: Width / height of the original image.
        target_ratios: Candidate ``(columns, rows)`` grids, scanned in order.
        width: Original image width in pixels.
        height: Original image height in pixels.
        image_size: Side length of a single square tile.

    Returns:
        The selected ``(columns, rows)`` grid, or ``(1, 1)`` when
        ``target_ratios`` is empty.

    When two candidates are equally close, the later one wins only if the
    original image area exceeds half of that grid's total tile area, i.e. the
    image has enough pixels to justify the extra tiles.
    """
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    half_tile_area = 0.5 * image_size * image_size

    for ratio in target_ratios:
        cols, rows = ratio
        ratio_diff = abs(aspect_ratio - cols / rows)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff and area > half_tile_area * cols * rows:
            best_ratio = ratio

    return best_ratio


171
172
173
174
175
176
177
def resolve_internvl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective ``(min, max)`` number of tiles per image.

    Args:
        min_dynamic_patch: Minimum tile count when dynamic sizing is enabled.
        max_dynamic_patch: Maximum tile count when dynamic sizing is enabled.
        dynamic_image_size: If ``False``, both bounds collapse to 1.
        use_thumbnail: If ``True`` and more than one tile is allowed, the
            maximum is bumped by one to make room for the thumbnail tile.

    Returns:
        The ``(min_num, max_num)`` tile-count bounds.
    """
    min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch

186

187
188
189
190
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """Enumerate every ``(columns, rows)`` grid with a tile count in range.

    Args:
        min_num: Minimum number of tiles (``columns * rows``).
        max_num: Maximum number of tiles (``columns * rows``).

    Returns:
        The unique grids satisfying ``min_num <= columns * rows <= max_num``,
        sorted by tile count. An empty list if ``min_num > max_num``.
    """
    # NOTE: the set is built and sorted exactly as in the reference
    # implementation; grids with equal tile counts keep the set's iteration
    # order, which find_closest_aspect_ratio relies on for tie-breaking.
    target_ratios = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    return sorted(target_ratios, key=lambda x: x[0] * x[1])


def calculate_internvl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Choose the tile grid for an image and compute its resize target.

    Args:
        orig_width: Original image width in pixels.
        orig_height: Original image height in pixels.
        target_ratios: Candidate ``(columns, rows)`` grids.
        image_size: Side length of a single square tile.
        use_thumbnail: Whether to count an extra thumbnail tile when the grid
            has more than one tile.

    Returns:
        ``(blocks, target_width, target_height)``: the number of tiles
        (including the thumbnail, if counted) and the size in pixels the image
        is resized to before being cut into tiles.
    """
    aspect_ratio = orig_width / orig_height

    # find the closest aspect ratio to the target
    cols, rows = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    # calculate the target width and height
    target_width = image_size * cols
    target_height = image_size * rows
    blocks = cols * rows

    # add thumbnail image if num_blocks != 1
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height
230
231


232
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
233
234
235
236
237
238
239
def dynamic_preprocess_internvl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Split an image into square tiles on the best-matching grid.

    The image is resized so that it exactly covers the chosen
    ``(columns, rows)`` grid of ``image_size`` x ``image_size`` tiles, then
    cropped tile by tile in row-major order. If ``use_thumbnail`` is set and
    more than one tile was produced, a downscaled copy of the whole image is
    appended as a final tile.

    Args:
        image: Source image.
        target_ratios: Candidate ``(columns, rows)`` grids.
        image_size: Side length of a single square tile.
        use_thumbnail: Whether to append the thumbnail tile.

    Returns:
        The list of tiles (plus the optional thumbnail).
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_internvl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    # resize the image so that it is exactly covered by the tile grid
    resized_img = image.resize((target_width, target_height))
    num_cols = target_width // image_size

    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, num_cols)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        # split the image
        processed_images.append(resized_img.crop(box))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
275
276
277
278
279
280
281
282
283
284
def image_to_pixel_values_internvl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Tile an image and convert every tile into a normalized tensor.

    Args:
        image: Source image.
        input_size: Side length of a single square tile.
        min_num: Minimum number of grid tiles.
        max_num: Maximum number of grid tiles.
        use_thumbnail: Whether to append a thumbnail tile when the grid has
            more than one tile.

    Returns:
        A tensor of shape ``(num_tiles, 3, input_size, input_size)``.
    """
    target_ratios = get_internvl_target_ratios(min_num, max_num)

    transform = build_transform(input_size=input_size)

    images = dynamic_preprocess_internvl(
        image,
        target_ratios=target_ratios,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([transform(tile) for tile in images])


297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def video_to_pixel_values_internvl(
    video: npt.NDArray,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Convert every frame of a video into one normalized tile tensor.

    Each frame is wrapped as an RGB PIL image and run through the same tiling
    path as still images; callers are expected to choose ``min_num`` /
    ``max_num`` / ``use_thumbnail`` so that exactly one tile per frame comes
    out, which is asserted.

    Args:
        video: Iterable of frames, each convertible with ``Image.fromarray``.
        input_size: Side length of the square output tile.
        min_num: Minimum number of grid tiles.
        max_num: Maximum number of grid tiles.
        use_thumbnail: Forwarded to the tiling step.

    Returns:
        A tensor stacking one transformed tile per frame.
    """
    ratios = get_internvl_target_ratios(min_num, max_num)
    to_tensor = build_transform(input_size=input_size)

    tiles: list[Image.Image] = []
    for raw_frame in video:
        frame_tiles = dynamic_preprocess_internvl(
            Image.fromarray(raw_frame, mode="RGB"),
            target_ratios=ratios,
            image_size=input_size,
            use_thumbnail=use_thumbnail,
        )
        # Video frames must never be split into several tiles.
        assert len(frame_tiles) == 1
        tiles += frame_tiles

    return torch.stack([to_tensor(tile) for tile in tiles])


324
325
326
327
class BaseInternVLProcessor(ABC):
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    The processor tiles each image into ``image_size`` x ``image_size``
    patches (plus an optional thumbnail), expands every ``<image>``
    placeholder in the prompt into the image replacement text returned by
    ``get_image_repl``, and tokenizes the result.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: TokenizerLike,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> None:
        """
        Args:
            config: Model config. It is read for ``vision_config.image_size``,
                ``vision_config.patch_size``, ``downsample_ratio`` and
                ``use_thumbnail``, and for ``min_dynamic_patch`` /
                ``max_dynamic_patch`` / ``dynamic_image_size`` when the
                corresponding argument is ``None``.
            tokenizer: Tokenizer used to encode the expanded prompt.
            min_dynamic_patch: Overrides ``config.min_dynamic_patch``.
            max_dynamic_patch: Overrides ``config.max_dynamic_patch``.
            dynamic_image_size: Overrides ``config.dynamic_image_size``.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        image_size: int = config.vision_config.image_size
        patch_size: int = config.vision_config.patch_size

        # Explicit arguments take precedence over the config values.
        if min_dynamic_patch is None:
            min_dynamic_patch = config.min_dynamic_patch
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        # Context tokens per tile: the (image_size // patch_size) ** 2 patch
        # grid of the vision encoder scaled by downsample_ratio ** 2.
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
        )
        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        """Token id of the image context token."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl(
        self,
        feature_size: int,
        num_patches: int | None,
    ) -> PromptUpdateDetails[str]:
        """Return the prompt text that replaces a single ``<image>``.

        Args:
            feature_size: Total number of context tokens for the image.
            num_patches: Number of tiles for the image, or ``None`` when it
                is not known (e.g. for precomputed embeddings).
        """
        raise NotImplementedError

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
    ) -> tuple[int, int]:
        """Resolve the ``(min, max)`` tile count for this processor.

        Any argument left as ``None`` falls back to the value stored on the
        instance; the result is computed by ``resolve_internvl_min_max_num``.
        """
        min_dynamic_patch = (
            self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch
        )
        max_dynamic_patch = (
            self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch
        )
        dynamic_image_size = (
            self.dynamic_image_size
            if dynamic_image_size is None
            else dynamic_image_size
        )
        use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail

        return resolve_internvl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        use_thumbnail: bool | None = None,
    ) -> list[tuple[int, int]]:
        """Return the candidate ``(columns, rows)`` grids for this processor."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        return get_internvl_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of context tokens an image of the given size expands to."""
        target_ratios = self.resolve_target_ratios(
            use_thumbnail=False,  # Applied in calculate_targets
        )

        num_patches, _, _ = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=target_ratios,
            use_thumbnail=self.use_thumbnail,
        )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> list[torch.Tensor]:
        """Tile every image; returns one ``(num_tiles, 3, H, W)`` tensor each."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_internvl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
            )
            for image in images
        ]

    def _preprocess_image(
        self,
        text: list[str],
        images: list[Image.Image],
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> tuple[list[str], dict[str, torch.Tensor]]:
        """Tile the images and expand ``<image>`` placeholders in ``text``.

        Returns:
            The updated prompts and a dict with ``pixel_values_flat`` (all
            tiles concatenated) and ``image_num_patches`` (tiles per image);
            the dict is empty when there are no images.
        """
        if len(images) == 0:
            image_inputs = {}
        else:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            image_inputs = {
                "pixel_values_flat": torch.cat(pixel_values_lst),
                "image_num_patches": torch.tensor(
                    [len(item) for item in pixel_values_lst]
                ),
            }

            # The i-th image replaces the first remaining "<image>" in each
            # prompt.
            for pixel_values in pixel_values_lst:
                num_patches = pixel_values.shape[0]
                feature_size = num_patches * self.num_image_token

                image_repl = self.get_image_repl(feature_size, num_patches)
                text = [t.replace("<image>", image_repl.full, 1) for t in text]

        return text, image_inputs

    def _make_batch_input(self, input_item: Any | list[Any] | None = None) -> list[Any]:
        """Normalize ``None`` / a single item / a list into a list."""
        if input_item is None:
            input_item = []
        if not isinstance(input_item, list):
            input_item = [input_item]
        return input_item

    def __call__(
        self,
        text: str | list[str] | None = None,
        images: Image.Image | list[Image.Image] | None = None,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        return_tensors: str | TensorType | None = None,
    ) -> BatchFeature:
        """Expand image placeholders, tokenize, and bundle the image tensors."""
        text, images = [self._make_batch_input(x) for x in (text, images)]

        text, image_inputs = self._preprocess_image(
            text=text,
            images=images,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        text_inputs = self.tokenizer(text)

        combined_outputs = {**text_inputs, **image_inputs}

        return BatchFeature(combined_outputs, tensor_type=return_tensors)
537
538


539
class InternVLProcessor(BaseInternVLProcessor):
    """
    HF Processor for InternVLChatModel with extended video processing logic.

    Code for video processing is adapted from video example:
    https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: TokenizerLike,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        video_token: str | None = None,
    ) -> None:
        """
        Args:
            config: Model config (see ``BaseInternVLProcessor``).
            tokenizer: Tokenizer used to encode the expanded prompt.
            min_dynamic_patch: Overrides ``config.min_dynamic_patch``.
            max_dynamic_patch: Overrides ``config.max_dynamic_patch``.
            dynamic_image_size: Overrides ``config.dynamic_image_size``.
            video_token: Context token used for video frames; video support is
                disabled when it is ``None`` or not in the tokenizer vocab.
        """
        super().__init__(
            config=config,
            tokenizer=tokenizer,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )
        # add extra video token for video processing
        self.video_token = video_token

    @property
    def image_token_id(self) -> int:
        """Token id of ``IMG_CONTEXT`` in the tokenizer vocab."""
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    @property
    def video_token_id(self) -> int | None:
        """Token id of ``video_token``, or ``None`` if unset or unknown."""
        if self.video_token is None:
            return None
        return self.tokenizer.get_vocab().get(self.video_token, None)

    @property
    def supports_video(self) -> bool:
        """Whether the tokenizer knows the configured video token."""
        return self.video_token_id is not None

    def _videos_to_pixel_values_lst(
        self,
        videos: list[npt.NDArray],
        dynamic_image_size: bool | None = None,
    ) -> list[torch.Tensor]:
        """Convert each video into a tensor with exactly one tile per frame."""
        # Frames are never split: the tile count is pinned to 1.
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=1,
            max_dynamic_patch=1,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # No thumbnails are added for video frames
        )

        return [
            video_to_pixel_values_internvl(
                video,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=False,
            )
            for video in videos
        ]

    def _preprocess_video(
        self,
        text: list[str],
        videos: list[npt.NDArray],
        dynamic_image_size: bool | None = None,
    ) -> tuple[list[str], dict[str, torch.Tensor]]:
        """Convert videos to frames and expand ``<video>`` placeholders.

        Videos are ignored (and placeholders left untouched) when there are
        none or when the tokenizer does not support the video token.
        """
        if len(videos) == 0 or not self.supports_video:
            video_inputs = {}
        else:
            pixel_values_lst_video = self._videos_to_pixel_values_lst(
                videos,
                dynamic_image_size=dynamic_image_size,
            )
            video_inputs = {
                "pixel_values_flat_video": torch.cat(pixel_values_lst_video),
                "video_num_patches": torch.tensor(
                    [len(item) for item in pixel_values_lst_video]
                ),
            }

            for pixel_values in pixel_values_lst_video:
                # One tile per frame, so this is the number of frames.
                num_patches = pixel_values.shape[0]

                video_repl = self.get_video_repl(
                    self.num_image_token, num_patches, self.video_token
                )
                text = [t.replace("<video>", video_repl.full, 1) for t in text]

        return text, video_inputs

    def __call__(
        self,
        text: str | list[str] | None = None,
        images: Image.Image | list[Image.Image] | None = None,
        videos: npt.NDArray | list[npt.NDArray] | None = None,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        return_tensors: str | TensorType | None = None,
    ) -> BatchFeature:
        """Expand image and video placeholders, tokenize, and bundle tensors."""
        text, images, videos = [
            self._make_batch_input(x) for x in (text, images, videos)
        ]

        text, image_inputs = self._preprocess_image(
            text=text,
            images=images,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        text, video_inputs = self._preprocess_video(
            text=text,
            videos=videos,
            dynamic_image_size=dynamic_image_size,
        )

        text_inputs = self.tokenizer(text)

        combined_outputs = {**text_inputs, **image_inputs, **video_inputs}

        return BatchFeature(combined_outputs, tensor_type=return_tensors)

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: int | None,
    ) -> PromptUpdateDetails[str]:
        """Return ``<img>`` + ``feature_size`` context tokens + ``</img>``.

        Only the ``IMG_CONTEXT`` tokens are selected as embedding positions.
        """
        repl_features = IMG_CONTEXT * feature_size
        repl_full = IMG_START + repl_features + IMG_END

        return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)

    def get_video_repl(
        self,
        feature_size: int,
        num_patches: int | None = None,
        video_context_token: str = IMG_CONTEXT,
    ) -> PromptUpdateDetails[str]:
        """Return the replacement text for one video.

        Each frame becomes ``"Frame{k}: <img>" + tokens + "</img>"``.

        Args:
            feature_size: Currently unused; the per-frame token count is
                always ``self.num_image_token``.
            num_patches: Number of frames. Must be an ``int``; ``None`` makes
                ``range`` fail.
            video_context_token: Token repeated for every frame feature and
                selected as the embedding position.
        """
        repl_features = video_context_token * self.num_image_token
        repl_features_with_sep = IMG_START + repl_features + IMG_END
        # num_patches is equal to num_frames
        repl_full = "".join(
            [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
        )

        return PromptUpdateDetails.select_text(repl_full, video_context_token)

692
693

class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Basic image-only ProcessingInfo for InternVL-style models."""

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> BaseInternVLProcessor:
        """Build the processor instance used for this model."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only images are supported, with no fixed per-prompt limit."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: BaseInternVLProcessor,
    ) -> int:
        """Delegate to ``processor.get_num_image_tokens``."""
        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size (over all candidate grids) with most tokens.

        Raises:
            ValueError: If no candidate grid yields a positive token count.
        """
        processor = self.get_hf_processor()

        base_size = processor.image_size
        target_ratios = processor.resolve_target_ratios()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for wr, hr in target_ratios:
            width, height = base_size * wr, base_size * hr

            feat_size = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width, height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint

    def get_max_image_tokens(self) -> int:
        """Token count of the largest-feature image size."""
        processor = self.get_hf_processor()
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=processor,
        )

749
750
751
752

# Type parameter for the concrete ProcessingInfo class used by the generic
# InternVL dummy-input builder and multimodal processor.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


753
754
class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Basic image-only DummyInputsBuilder for InternVL-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """One ``<image>`` placeholder per requested image."""
        num_images = mm_counts.get("image", 0)

        return "<image>" * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Dummy images sized to produce the maximum number of tokens."""
        target_width, target_height = self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image")

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


782
class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Basic image-only MultiModalProcessor for InternVL-style models."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the processor and attach the image context token id."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        image_token_id = hf_processor.image_token_id

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        processed_outputs["image_token_id"] = torch.tensor(image_token_id)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split per image."""
        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
        num_images = len(image_num_patches)

        return dict(
            # Tiles are concatenated; split them back using per-image counts.
            pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_patches
            ),
            image_num_patches=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            image_token_id=MultiModalFieldConfig.shared("image", num_images),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``<image>`` with its expanded context tokens."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        out_mm_data = out_mm_kwargs.get_data()
        if "image_num_patches" in out_mm_data:
            image_num_patches = out_mm_data["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_data:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_data["image_embeds"])
        else:
            image_num_patches = []

        def get_replacement_internvl(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_image_repl(feature_size, num_patches)

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]
874
875


876
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for InternVL checkpoints, adding optional video support.

    Video is only advertised when the HF processor reports ``supports_video``;
    otherwise this behaves like the image-only base class.
    """

    @property
    def supports_video(self):
        processor = self.get_hf_processor()
        return processor.supports_video

    def get_supported_mm_limits(self):
        # Decide on video support first, then layer an unlimited ("None")
        # "video" entry over the base limits when it applies.
        include_video = self.supports_video
        limits = {**super().get_supported_mm_limits()}
        if include_video:
            limits["video"] = None
        return limits

    def get_video_token(self) -> str | None:
        """Return the video placeholder token for the text backbone.

        The token is chosen from the text config's ``model_type``; types
        without a known video token yield ``None``.
        """
        model_type = self.get_hf_config().get_text_config().model_type
        if model_type in ("qwen2", "qwen3", "qwen3_moe"):
            return "<|video_pad|>"
        if model_type == "gpt_oss":
            return "<|reserved_200000|>"
        return None

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Compute how many frames each video may use within ``seq_len``.

        The tokens reserved for the maximum number of images are subtracted
        from ``seq_len``. The remainder is divided into frames of
        ``processor.num_image_token`` tokens each and shared evenly among the
        requested videos. The result is never smaller than 1.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        processor = self.get_hf_processor()
        image_budget = self.get_max_image_tokens() * num_images
        frames_total = (seq_len - image_budget) // processor.num_image_token
        frames_each = frames_total // max(num_videos, 1)
        return max(frames_each, 1)

    def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
        """Instantiate ``InternVLProcessor`` bound to this model's config,
        tokenizer and video token; extra ``kwargs`` are forwarded."""
        hf_config = self.get_hf_config()
        tokenizer = self.get_tokenizer()
        video_token = self.get_video_token()
        return self.ctx.init_processor(
            InternVLProcessor,
            config=hf_config,
            tokenizer=tokenizer,
            video_token=video_token,
            **kwargs,
        )


923
class InternVLDummyInputsBuilder(
    BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]
):
    """Dummy-input builder for InternVL that also produces video inputs.

    Adds one ``<video>`` placeholder per requested video to the dummy text.
    When the processor supports video, it also generates square dummy videos
    sized to the vision tower's ``image_size``.
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        video_count = mm_counts.get("video", 0)
        base_text = super().get_dummy_text(mm_counts)
        return base_text + "<video>" * video_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        # Start from a fresh copy of the image dummies built by the base class.
        dummy_data = {**super().get_dummy_mm_data(seq_len, mm_counts, mm_options)}
        if not self.info.supports_video:
            return dummy_data

        side: int = self.info.get_hf_config().vision_config.image_size
        frames = self.info.get_num_frames_with_most_features(seq_len, mm_counts)
        dummy_data["video"] = self._get_dummy_videos(
            width=side,
            height=side,
            num_frames=frames,
            num_videos=mm_counts.get("video", 0),
            overrides=mm_options.get("video"),
        )
        return dummy_data


class InternVLMultiModalProcessor(
    BaseInternVLMultiModalProcessor[InternVLProcessingInfo]
):
    """Multimodal processor for InternVL that handles videos next to images.

    Every video-specific step below is gated on ``self.info.supports_video``.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
        hf_processor = self.info.get_hf_processor(**mm_kwargs)

        if not self.info.supports_video:
            return outputs

        # Attach the video token id as a tensor when the processor defines one.
        video_token_id = hf_processor.video_token_id
        if video_token_id is not None:
            outputs["video_token_id"] = torch.tensor(video_token_id)
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)

        video_fields: dict[str, MultiModalFieldConfig] = {}
        if self.info.supports_video:
            patches_per_video = hf_inputs.get("video_num_patches", torch.empty(0))
            # Flat pixel rows are grouped per video by that video's patch count.
            video_fields["pixel_values_flat_video"] = (
                MultiModalFieldConfig.flat_from_sizes("video", patches_per_video)
            )
            video_fields["video_num_patches"] = MultiModalFieldConfig.batched("video")
            # A single token id value shared across all video items.
            video_fields["video_token_id"] = MultiModalFieldConfig.shared(
                "video", len(patches_per_video)
            )

        return image_fields | video_fields

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        updates = super()._get_prompt_updates(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            out_mm_kwargs=out_mm_kwargs,
        )
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        video_num_patches = []
        out_mm_data = out_mm_kwargs.get_data()
        if "video_num_patches" in out_mm_data:
            patches_tensor = out_mm_data["video_num_patches"]
            assert isinstance(patches_tensor, torch.Tensor)
            video_num_patches = patches_tensor.tolist()

        def get_video_replacement_internvl(item_idx: int):
            # Each frame contributes a fixed number of image tokens.
            tokens_per_frame = hf_processor.num_image_token
            patch_count = video_num_patches[item_idx]
            if patch_count is not None:
                assert isinstance(patch_count, int)
            return hf_processor.get_video_repl(
                tokens_per_frame,
                patch_count,
                video_context_token=hf_processor.video_token,
            )

        if not self.info.supports_video:
            return updates

        video_update = PromptReplacement(
            modality="video",
            target="<video>",
            replacement=get_video_replacement_internvl,
        )
        return [*updates, video_update]


1052
1053
1054
@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
1055
1056
1057
    dummy_inputs=InternVLDummyInputsBuilder,
)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
1058
1059
    supports_encoder_tp_data = True

1060
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``.

        Modality names starting with ``"image"`` map to ``"<image>"`` and names
        starting with ``"video"`` map to ``"<video>"``; ``i`` is not used.

        Raises:
            ValueError: for any other modality.
        """
        for prefix, placeholder in (("image", "<image>"), ("video", "<video>")):
            if modality.startswith(prefix):
                return placeholder
        raise ValueError("Only image or video modality is supported")

1069
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the ``mlp1`` projector and the language model.

        Args:
            vllm_config: vLLM config; the HF config, quantization config and
                multimodal config are all read from it.
            prefix: Weight-name prefix for this module.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self._patch_quant_config(config, quant_config)

        # Token geometry of a single image tile.
        vision_config = config.vision_config
        tile_size = config.force_image_size or vision_config.image_size
        self.patch_size = vision_config.patch_size
        self.patch_tokens = (tile_size // self.patch_size) ** 2
        self.downsample_ratio = config.downsample_ratio
        self.num_image_token = int(self.patch_tokens * (self.downsample_ratio**2))
        self.ps_version = config.ps_version

        self.is_mono = (
            config.text_config.architectures[0] == "InternLM2VEForCausalLM"
        )

        # ``_init_mlp1`` reads ``self.downsample_ratio``, which is set above.
        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.vision_model = self._init_vision_model(
                config,
                quant_config=quant_config,
                is_mono=self.is_mono,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.mlp1 = self._init_mlp1(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Context token ids are filled in by the input parsers.
        self.img_context_token_id = None
        self.video_context_token_id = None
        # Set by ``_set_visual_token_mask`` and consumed (then cleared) in forward.
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
1115

1116
1117
1118
    def _patch_quant_config(
        self, config: PretrainedConfig, quant_config: QuantizationConfig | None
    ):
        """Restore the missing ``modules_to_not_convert`` entry for AWQ models.

        The AWQ checkpoints published by OpenGVLab ship without
        ``modules_to_not_convert``. When ``quant_config`` is an ``AWQConfig``
        whose ``modules_to_not_convert`` is empty, and the text config carries
        its own ``quantization_config``, ``"vision_model"`` is appended to
        that list in place.

        Args:
            config: The InternVL HF config (its ``text_config`` is inspected).
            quant_config: The global quantization config; may be ``None``, in
                which case nothing happens.
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config", None)
        if not quant_config.modules_to_not_convert and llm_quant_config is not None:
            quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the vision encoder.

        Mono-architecture models use ``InternVisionPatchModel``. Otherwise an
        ``InternVisionModel`` is built, truncated so that ``config.select_layer``
        is its last layer (negative values count from the end).
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        select_layer = config.select_layer
        num_hidden_layers = (
            config.vision_config.num_hidden_layers + select_layer + 1
            if select_layer < 0
            else select_layer + 1
        )
        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )
1154

1155
    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
        """Build the projector from vision features to the LM hidden size.

        The input width is the ViT hidden size times
        ``int(1 / downsample_ratio) ** 2``. This presumably matches the channel
        growth produced by ``pixel_shuffle``. The layers are
        LayerNorm -> Linear -> GELU -> Linear.
        """
        fold = int(1 / self.downsample_ratio) ** 2
        in_features = config.vision_config.hidden_size * fold
        out_features = config.text_config.hidden_size

        layers = [
            nn.LayerNorm(in_features),
            nn.Linear(in_features, out_features),
            nn.GELU(),
            nn.Linear(out_features, out_features),
        ]
        return nn.Sequential(*layers)

1168
1169
1170
1171
1172
1173
    def pixel_shuffle(self, x, scale_factor=0.5):
        """Fold spatial positions into channels.

        ``x`` is unpacked as ``(n, w, h, c)`` and reshaped to
        ``(n, h * s, w * s, c / s**2)`` with ``s = scale_factor``. Unless
        ``self.ps_version == "v1"``, the two spatial axes are then swapped.
        """
        batch, width, height, channels = x.size()
        s = scale_factor

        x = x.view(batch, width, int(height * s), int(channels / s))
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch, int(height * s), int(width * s), int(channels / (s * s)))

        if self.ps_version != "v1":
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

1186
    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode pixel tiles and project them into the LM embedding space."""
        # Drop token 0 of every sequence (presumably a class token — confirm).
        tokens = self.vision_model(pixel_values=pixel_values)[:, 1:, :]

        # Lay the remaining tokens out on a square grid before pixel shuffle.
        grid = int(tokens.shape[1] ** 0.5)
        tokens = tokens.reshape(tokens.shape[0], grid, grid, -1)
        tokens = self.pixel_shuffle(tokens, scale_factor=self.downsample_ratio)
        tokens = tokens.reshape(tokens.shape[0], -1, tokens.shape[-1])

        return self.mlp1(tokens)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> InternVLImageInputs | None:
        """Build a typed image input from keyword arguments.

        Returns ``None`` when neither ``pixel_values_flat`` nor
        ``image_embeds`` is given. Precomputed embeddings take priority.
        For pixel inputs, ``image_token_id`` is required and is stored on
        ``self.img_context_token_id``.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if image_embeds is not None:
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )
        if pixel_values_flat is None:
            return None

        token_id = kwargs["image_token_id"]
        if isinstance(token_id, torch.Tensor):
            # Collapse to the single unique value; .item() raises otherwise.
            token_id = token_id.flatten().unique().item()
        assert isinstance(token_id, int)
        self.img_context_token_id = token_id

        tile = self.config.vision_config.image_size
        return InternVLImagePixelInputs(
            type="pixel_values",
            pixel_values_flat=pixel_values_flat,
            num_patches=image_num_patches,
            resolve_bindings={"h": tile, "w": tile},
        )

1233
    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> InternVLVideoInputs | None:
        """Build a typed video input from keyword arguments.

        Pops ``pixel_values_flat_video``, ``video_num_patches`` and
        ``video_embeds`` from this call's kwargs. The return value is:

        - ``None`` when neither pixel values nor embeddings are present;
        - an ``InternVLVideoEmbeddingInputs`` when embeddings are given;
        - otherwise an ``InternVLVideoPixelInputs``.

        For pixel inputs, ``video_token_id`` is required and is stored on
        ``self.video_context_token_id``.
        """
        pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
        video_num_patches = kwargs.pop("video_num_patches", None)
        # Use the video-specific key. Reading "image_embeds" here would
        # reinterpret precomputed *image* embeddings as video embeddings
        # whenever they arrive in the same batch as video pixel values.
        video_embeds = kwargs.pop("video_embeds", None)

        if pixel_values_flat_video is None and video_embeds is None:
            return None

        if video_embeds is not None:
            return InternVLVideoEmbeddingInputs(
                type="video_embeds",
                data=video_embeds,
            )

        video_token_id = kwargs["video_token_id"]
        if isinstance(video_token_id, torch.Tensor):
            # Collapse to the single unique value; .item() raises otherwise.
            video_token_id = video_token_id.flatten().unique().item()

        assert isinstance(video_token_id, int)
        self.video_context_token_id = video_token_id

        # Both early returns failed, so pixel values are present here.
        expected_h = expected_w = self.config.vision_config.image_size
        resolve_bindings = {"h": expected_h, "w": expected_w}

        return InternVLVideoPixelInputs(
            type="pixel_values_videos",
            pixel_values_flat=pixel_values_flat_video,
            num_patches=video_num_patches,
            resolve_bindings=resolve_bindings,
        )

1269
    def _process_vision_input(
        self,
        image_input: InternVLImageInputs | InternVLVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Turn a parsed image/video input into per-item embeddings.

        Precomputed embeddings (``image_embeds`` / ``video_embeds``) are
        returned unchanged. Pixel inputs are encoded with ``extract_feature``,
        flattened to ``(-1, text_config.hidden_size)``, and split into one
        tensor per item according to ``num_patches``.
        """
        if image_input["type"] in ("image_embeds", "video_embeds"):
            return image_input["data"]

        image_embeds = self.extract_feature(image_input["pixel_values_flat"])
        num_patches = image_input["num_patches"]
        hidden_size = self.config.text_config.hidden_size

        # Only one item in the current batch: nothing to split.
        if len(num_patches) == 1:
            return (image_embeds.view(-1, hidden_size),)

        # Each patch contributes ``feature_size`` rows (dim 1 of the encoder
        # output), so an item with ``n`` patches owns ``n * feature_size``
        # consecutive rows of the flattened tensor.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, hidden_size)
        item_sizes = [n * feature_size for n in num_patches]
        return image_embeds.split(item_sizes)
1295

1296
1297
1298
1299
1300
1301
    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image and video inputs into a dict keyed "images"/"videos".

        A modality is parsed the first time one of its trigger keys appears
        in ``kwargs``. Insertion order therefore follows the kwargs order.
        """
        image_keys = ("pixel_values_flat", "image_embeds")
        video_keys = ("pixel_values_flat_video",)

        parsed = {}
        for key in kwargs:
            if "images" not in parsed and key in image_keys:
                parsed["images"] = self._parse_and_validate_image_input(**kwargs)
            if "videos" not in parsed and key in video_keys:
                parsed["videos"] = self._parse_and_validate_video_input(**kwargs)
        return parsed

1312
    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Record which positions hold the image context token (mono only).

        Non-mono models always reset the mask to ``None``. For mono models the
        mask is ``input_ids == img_context_token_id`` reshaped to ``(-1, 1)``.
        """
        if not self.is_mono:
            self.visual_token_mask = None
            return

        assert self.img_context_token_id is not None
        is_context_token = input_ids == self.img_context_token_id
        self.visual_token_mask = is_context_token.reshape(-1, 1)
1320

1321
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode all image/video inputs into a flat tuple of embeddings.

        The result holds one tensor per multimodal item, in the order the
        modalities were parsed. Returns ``[]`` when no multimodal input exists.
        """
        parsed = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not parsed:
            return []

        # Iterate the dict itself so modality order is preserved.
        collected: tuple[torch.Tensor, ...] = ()
        for modality, vision_input in parsed.items():
            if modality in ("images", "videos"):
                collected += tuple(self._process_vision_input(vision_input))
        return collected
1343

1344
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging multimodal embeddings when provided.

        A non-empty ``multimodal_embeddings`` also refreshes the visual token
        mask. Falls back to text-only embedding unless both
        ``multimodal_embeddings`` and ``is_multimodal`` are given.
        """
        if multimodal_embeddings is None:
            return super().embed_input_ids(input_ids)

        # len() rather than truthiness: the embeddings may be a tensor.
        if len(multimodal_embeddings) > 0:
            self._set_visual_token_mask(input_ids)

        if is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )
1365

1366
1367
    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language model backbone.

        ``inputs_embeds`` is ignored whenever ``intermediate_tensors`` is
        supplied. A pending visual token mask (mono models only) is forwarded
        once and then cleared.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        lm_kwargs = dict(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        pending_mask = self.visual_token_mask
        if pending_mask is not None:
            lm_kwargs["visual_token_mask"] = pending_mask
            self.visual_token_mask = None

        return self.language_model.model(**lm_kwargs)

1392
1393
1394
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped language model."""
        lm = self.language_model
        return lm.compute_logits(hidden_states)
1397

1398
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names that were loaded.

        Weights under the prefixes below are skipped. According to the
        original comment, these modules appear in
        OpenGVLab/InternVideo2_5_Chat_8B but are unused here.
        """
        unused_prefixes = [
            "action_embed",
            "temporal_embed",
            "track_embed",
            "track_embed_decoder",
            "box_token",
            "cg_criterion",
            "cg_model",
            "loc_encoder",
            "loc_decoder",
            "sam",
            "temporal_token",
            "track_token",
        ]
        weights_loader = AutoWeightsLoader(self, skip_prefixes=unused_prefixes)
        return weights_loader.load_weights(weights)
1416
1417
1418
1419
1420
1421
1422
1423

    def get_mm_mapping(self) -> MultiModelKeys:
        """Name the submodules acting as language model, connector and tower."""
        prefixes = dict(
            language_model="language_model",
            connector="mlp1",
            tower_model="vision_model",
        )
        return MultiModelKeys.from_string_field(**prefixes)
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Convert a count of LM image tokens into vision-encoder tokens.

        The count is divided (flooring) into patches of ``num_image_token`` LM
        tokens; each patch maps to ``patch_tokens + 1`` encoder tokens.
        Returns 0 when either quantity is non-positive.
        """
        tokens_per_patch = self.num_image_token
        if min(num_image_tokens, tokens_per_patch) <= 0:
            return 0
        return (num_image_tokens // tokens_per_patch) * (self.patch_tokens + 1)

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Convert a count of vision-encoder tokens into LM image tokens.

        The count is divided (flooring) into patches of ``patch_tokens + 1``
        encoder tokens; each patch maps to ``num_image_token`` LM tokens.
        Returns 0 when either quantity is non-positive.
        """
        if num_vision_tokens <= 0 or self.num_image_token <= 0:
            return 0
        patches = num_vision_tokens // (self.patch_tokens + 1)
        return self.num_image_token * patches