internvl.py 49.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
9
from abc import ABC, abstractmethod
10
from collections.abc import Iterable, Mapping, Sequence
11
from typing import Any, Literal, Optional, TypedDict, TypeVar, Union
12

13
import numpy.typing as npt
14
15
16
17
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
18
from transformers import BatchEncoding, PretrainedConfig, TensorType
19

20
from vllm.config import VllmConfig
21
22
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
23
24
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
25
from vllm.model_executor.sampling_metadata import SamplingMetadata
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.image import convert_image_mode
28
29
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, NestedTensors)
30
31
32
33
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
34
                                        PromptUpdate, PromptUpdateDetails)
35
from vllm.multimodal.profiling import BaseDummyInputsBuilder
36
from vllm.sequence import IntermediateTensors
37
from vllm.transformers_utils.tokenizer import AnyTokenizer
38

39
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
40
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
41
                    maybe_prefix, merge_multimodal_embeddings)
42
43
44
45
46
47
48
49
50
51
52

# Marker strings used when expanding an image placeholder in the prompt:
# IMG_START / IMG_END wrap a run of repeated IMG_CONTEXT tokens.
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel mean / std passed to `T.Normalize` in `build_transform`.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image inputs given as raw pixel tiles (tagged `"pixel_values"`)."""

    type: Literal["pixel_values"]
    pixel_values_flat: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """

    num_patches: torch.Tensor
    """Shape: `(batch_size * num_images)`"""

62

63
64
class InternVLImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed embeddings (tagged `"image_embeds"`)."""

    type: Literal["image_embeds"]
    data: Union[torch.Tensor, list[torch.Tensor]]
    """ 
    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors of shape `(total_image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Image inputs are either raw pixel tiles or precomputed embeddings; the
# `type` field of each TypedDict tells the two variants apart.
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
class InternVLVideoPixelInputs(TypedDict):
    """Video inputs given as raw frames (tagged `"pixel_values_videos"`)."""

    type: Literal["pixel_values_videos"]
    pixel_values_flat: torch.Tensor
    """
    Shape:
    `(batch_size * num_videos * num_frames, num_channels, height, width)`
    """

    num_patches: torch.Tensor
    """Shape: `(batch_size * num_videos)`"""


class InternVLVideoEmbeddingInputs(TypedDict):
    """Video inputs given as precomputed embeddings.

    `data` is either a tensor of shape
    `(num_videos, total_video_feature_size, hidden_size)` or a list of
    tensors of shape `(total_video_feature_size, hidden_size)`.

    `hidden_size` must match the hidden size of language model backbone.
    """

    # Discriminator tag for this variant.
    type: Literal["video_embeds"]
    # Embeddings, batched into one tensor or given per video.
    data: Union[torch.Tensor, list[torch.Tensor]]


# Video inputs are either raw frames or precomputed embeddings; the `type`
# field of each TypedDict tells the two variants apart.
InternVLVideoInputs = Union[InternVLVideoPixelInputs,
                            InternVLVideoEmbeddingInputs]


105
106
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int) -> T.Compose:
    """Build the per-tile preprocessing pipeline.

    The returned transform converts its input to RGB, resizes it to
    `input_size` x `input_size` with bicubic interpolation, converts it to a
    tensor and normalizes it with `IMAGENET_MEAN` / `IMAGENET_STD`.

    Args:
        input_size: Side length (in pixels) of the square output.

    Returns:
        The composed torchvision transform.
    """
    return T.Compose([
        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])


117
118
119
120
121
122
123
124
125
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the tiling grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Aspect ratio of the image to match.
        target_ratios: Candidate grids; each entry's first element divided
            by its second is compared against `aspect_ratio`. They are
            visited in the given order.
        width: Image width in pixels (only used for tie-breaking).
        height: Image height in pixels (only used for tie-breaking).
        image_size: Side length of a single square tile.

    Returns:
        The best candidate, or `(1, 1)` when `target_ratios` is empty.
    """
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie: switch to this later candidate only when the image area
            # exceeds half of the area this candidate's tiles would cover.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


141
142
143
144
145
146
147
def resolve_internvl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective minimum / maximum number of tiles per image.

    Args:
        min_dynamic_patch: Configured lower bound on the tile count.
        max_dynamic_patch: Configured upper bound on the tile count.
        dynamic_image_size: When false, both bounds collapse to 1.
        use_thumbnail: When true and the upper bound is not 1, the upper
            bound is increased by one.

    Returns:
        `(min_num, max_num)` after applying the rules above.
    """
    if not dynamic_image_size:
        min_dynamic_patch = 1
        max_dynamic_patch = 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch

156

157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """Enumerate candidate tiling grids.

    Collects every pair `(i, j)` of positive integers whose product lies in
    `[min_num, max_num]` and returns the pairs sorted by that product.
    """
    grid_shapes: set[tuple[int, int]] = set()
    for side_limit in range(min_num, max_num + 1):
        for first in range(1, side_limit + 1):
            for second in range(1, side_limit + 1):
                if min_num <= first * second <= max_num:
                    grid_shapes.add((first, second))

    def tile_count(shape: tuple[int, int]) -> int:
        return shape[0] * shape[1]

    return sorted(grid_shapes, key=tile_count)


def calculate_internvl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Compute the tile count and resize target for an image.

    Args:
        orig_width: Original image width in pixels.
        orig_height: Original image height in pixels.
        target_ratios: Candidate grids passed to `find_closest_aspect_ratio`.
        image_size: Side length of a single square tile.
        use_thumbnail: When true and more than one tile is produced, one
            extra block is counted.

    Returns:
        `(blocks, target_width, target_height)`.
    """
    aspect_ratio = orig_width / orig_height

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # add thumbnail image if num_blocks != 1
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height
197
198


199
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
200
201
202
203
204
205
206
def dynamic_preprocess_internvl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Resize an image to its target grid and cut it into square tiles.

    Args:
        image: Source image.
        target_ratios: Candidate grids passed to
            `calculate_internvl_targets`.
        image_size: Side length of each square tile.
        use_thumbnail: When true and more than one tile was produced, a
            copy of the whole image resized to one tile is appended.

    Returns:
        The tiles in row-major order, followed by the optional thumbnail.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_internvl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    # resize the image
    resized_img = image.resize((target_width, target_height))

    tiles_per_row = target_width // image_size
    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, tiles_per_row)
        box = (col * image_size, row * image_size, (col + 1) * image_size,
               (row + 1) * image_size)
        # split the image
        processed_images.append(resized_img.crop(box))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
240
241
242
243
244
245
246
247
248
249
def image_to_pixel_values_internvl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Convert one image into a stacked tensor of preprocessed tiles.

    Args:
        image: Source image.
        input_size: Side length of each square tile.
        min_num: Minimum tile count used to enumerate candidate grids.
        max_num: Maximum tile count used to enumerate candidate grids.
        use_thumbnail: Forwarded to `dynamic_preprocess_internvl`.

    Returns:
        The transformed tiles stacked along a new leading dimension.
    """
    target_ratios = get_internvl_target_ratios(min_num, max_num)

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess_internvl(
        image,
        target_ratios=target_ratios,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([transform(tile) for tile in tiles])


262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def video_to_pixel_values_internvl(
    video: npt.NDArray,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Convert a video into a stacked tensor of preprocessed frames.

    Each element obtained by iterating `video` is wrapped as an RGB PIL
    image and run through `dynamic_preprocess_internvl`, which must yield
    exactly one tile per frame. The resulting frames are then transformed
    and stacked along a new leading dimension.
    """
    ratios = get_internvl_target_ratios(min_num, max_num)
    to_tensor = build_transform(input_size=input_size)

    collected: list[Image.Image] = []
    for raw_frame in video:
        frame_tiles = dynamic_preprocess_internvl(
            Image.fromarray(raw_frame, mode="RGB"),
            target_ratios=ratios,
            image_size=input_size,
            use_thumbnail=use_thumbnail,
        )
        # Every frame is expected to stay a single tile.
        assert len(frame_tiles) == 1
        collected += frame_tiles

    return torch.stack([to_tensor(tile) for tile in collected])


289
290
291
292
class BaseInternVLProcessor(ABC):
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> None:
        """Store the tokenizer and resolve tiling settings.

        Each of `min_dynamic_patch`, `max_dynamic_patch` and
        `dynamic_image_size` falls back to the attribute of the same name
        on `config` when left as `None`.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        image_size: int = config.vision_config.image_size
        patch_size: int = config.vision_config.patch_size

        if min_dynamic_patch is None:
            min_dynamic_patch = config.min_dynamic_patch
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        # Number of context tokens contributed by a single tile.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        """Token ID of the image context token (subclass-defined)."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the replacement for one `<image>` placeholder."""
        raise NotImplementedError

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve `(min_num, max_num)`; `None` arguments use instance values."""
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_internvl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> list[tuple[int, int]]:
        """Return the candidate tiling grids for the resolved tile bounds."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        return get_internvl_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of context tokens for an image of this size."""
        target_ratios = self.resolve_target_ratios(
            use_thumbnail=False,  # Applied in calculate_targets
        )

        num_patches, _, _ = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=target_ratios,
            use_thumbnail=self.use_thumbnail,
        )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each image into its own stacked tensor of tiles."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_internvl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
            ) for image in images
        ]

    def _preprocess_image(
        self,
        text: list[str],
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> tuple[list[str], dict[str, NestedTensors]]:
        """Tile the images and expand `<image>` placeholders in the text.

        Returns:
            The updated texts and a dict holding `pixel_values_flat` and
            `image_num_patches` (empty when there are no images).
        """
        image_inputs: dict[str, NestedTensors] = {}
        if len(images) == 0:
            return text, image_inputs

        pixel_values_lst = self._images_to_pixel_values_lst(
            images,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )
        image_inputs = {
            "pixel_values_flat":
            torch.cat(pixel_values_lst),
            "image_num_patches":
            torch.tensor([len(item) for item in pixel_values_lst]),
        }

        for pixel_values in pixel_values_lst:
            num_patches = pixel_values.shape[0]
            feature_size = num_patches * self.num_image_token

            image_repl = self.get_image_repl(feature_size, num_patches)
            # Replace only the first remaining placeholder in every text,
            # so images are consumed in order.
            text = [t.replace('<image>', image_repl.full, 1) for t in text]

        return text, image_inputs

    def _make_batch_input(self,
                          input_item: Optional[Union[Any, list[Any]]] = None):
        """Normalize `None` / a single item / a list into a list."""
        if input_item is None:
            input_item = []
        if not isinstance(input_item, list):
            input_item = [input_item]
        return input_item

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        """Preprocess images, expand placeholders and tokenize the text.

        Returns:
            The tokenizer output merged with the image inputs.
        """
        text, images = [self._make_batch_input(x) for x in (text, images)]

        text, image_inputs = self._preprocess_image(
            text=text,
            images=images,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        text_inputs = self.tokenizer(text)

        return {
            **BatchEncoding(text_inputs, tensor_type=return_tensors),
            **image_inputs,
        }
498
499


500
class InternVLProcessor(BaseInternVLProcessor):
    """
    HF Processor for InternVLChatModel with extended video processing logic.

    Code for video processing is adapted from video example:
    https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        video_token: Optional[str] = None,
    ) -> None:
        """Initialize the base processor and remember the video token.

        Args:
            video_token: Context token used for video frames, or `None`
                to disable video support.
        """
        super().__init__(
            config=config,
            tokenizer=tokenizer,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )
        # add extra video token for video processing
        self.video_token = video_token

    @property
    def image_token_id(self) -> int:
        """Vocabulary ID of `IMG_CONTEXT`."""
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    @property
    def video_token_id(self) -> Optional[int]:
        """Vocabulary ID of the video token, or `None` if unset / unknown."""
        if self.video_token is None:
            return None
        return self.tokenizer.get_vocab().get(self.video_token, None)

    @property
    def supports_video(self) -> bool:
        """Whether the video token resolves to a vocabulary ID."""
        return self.video_token_id is not None

    def _videos_to_pixel_values_lst(
        self,
        videos: list[npt.NDArray],
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each video into its own stacked tensor of frames."""
        # Tile bounds are pinned to 1 so that each frame stays one tile.
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=1,
            max_dynamic_patch=1,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            video_to_pixel_values_internvl(
                video,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=False,
            ) for video in videos
        ]

    def _preprocess_video(
        self,
        text: list[str],
        videos: list[npt.NDArray],
        dynamic_image_size: Optional[bool] = None,
    ) -> tuple[list[str], dict[str, NestedTensors]]:
        """Convert videos and expand `<video>` placeholders in the text.

        Returns:
            The updated texts and a dict holding `pixel_values_flat_video`
            and `video_num_patches`. The dict is empty, and the text is
            left untouched, when there are no videos or video is
            unsupported.
        """
        video_inputs: dict[str, NestedTensors] = {}
        if len(videos) == 0 or not self.supports_video:
            return text, video_inputs

        pixel_values_lst_video = self._videos_to_pixel_values_lst(
            videos,
            dynamic_image_size=dynamic_image_size,
        )
        video_inputs = {
            "pixel_values_flat_video":
            torch.cat(pixel_values_lst_video),
            "video_num_patches":
            torch.tensor([len(item) for item in pixel_values_lst_video]),
        }

        for pixel_values in pixel_values_lst_video:
            num_patches = pixel_values.shape[0]

            video_repl = self.get_video_repl(self.num_image_token,
                                             num_patches, self.video_token)
            # Replace only the first remaining placeholder in every text,
            # so videos are consumed in order.
            text = [t.replace('<video>', video_repl.full, 1) for t in text]

        return text, video_inputs

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        videos: Optional[Union[npt.NDArray, list[npt.NDArray]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        """Preprocess images and videos, expand placeholders and tokenize.

        Returns:
            The tokenizer output merged with the image and video inputs.
        """
        text, images, videos = [
            self._make_batch_input(x) for x in (text, images, videos)
        ]

        text, image_inputs = self._preprocess_image(
            text=text,
            images=images,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        text, video_inputs = self._preprocess_video(
            text=text,
            videos=videos,
            dynamic_image_size=dynamic_image_size,
        )

        text_inputs = self.tokenizer(text)

        return {
            **BatchEncoding(text_inputs, tensor_type=return_tensors),
            **image_inputs,
            **video_inputs,
        }

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build `IMG_START + IMG_CONTEXT * feature_size + IMG_END`.

        `num_patches` is accepted for interface compatibility and unused.
        """
        repl_features = IMG_CONTEXT * feature_size
        repl_full = IMG_START + repl_features + IMG_END

        return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)

    def get_video_repl(
        self,
        feature_size: int,
        num_patches: Optional[int] = None,
        video_context_token: str = IMG_CONTEXT,
    ) -> PromptUpdateDetails[str]:
        """Build the replacement for one `<video>` placeholder.

        Each frame contributes `Frame{k}: ` followed by
        `self.num_image_token` context tokens wrapped in
        `IMG_START` / `IMG_END`.

        NOTE: `feature_size` is not read here; the per-frame token count
        always comes from `self.num_image_token`. `num_patches` must be an
        integer (the number of frames) despite its `None` default.
        """
        repl_features = video_context_token * self.num_image_token
        repl_features_with_sep = IMG_START + repl_features + IMG_END
        # num_patches is equal to num_frames
        repl_full = ''.join([
            f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches)
        ])

        return PromptUpdateDetails.select_text(repl_full, video_context_token)

653
654

class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Basic image-only ProcessingInfo for InternVL-style models."""

    @abstractmethod
    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> BaseInternVLProcessor:
        """Return the processor instance (subclass-defined)."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Declare the supported modalities: images, with no fixed limit."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[BaseInternVLProcessor],
    ) -> int:
        """Return the context-token count for an image of the given size.

        When `processor` is `None`, one is obtained from
        `get_hf_processor()`.
        """
        if processor is None:
            processor = self.get_hf_processor()

        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Find the image size that yields the most context tokens.

        Every candidate grid from the processor is scaled by the base tile
        size and scored with `get_num_image_tokens`; the first size with
        the highest score wins.

        Raises:
            ValueError: If no candidate produces a positive token count.
        """
        processor = self.get_hf_processor()

        base_size = processor.image_size
        target_ratios = processor.resolve_target_ratios()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for wr, hr in target_ratios:
            width, height = base_size * wr, base_size * hr

            feat_size = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width,
                                                     height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint

    def get_max_image_tokens(self) -> int:
        """Return the token count of the most token-heavy image size."""
        processor = self.get_hf_processor()
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=processor,
        )

721
722
723
724

# Type parameter for the generic builder / processor classes below; it is
# bound to BaseInternVLProcessingInfo.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


725
726
class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Basic image-only DummyInputsBuilder for InternVL-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one `<image>` placeholder per requested image."""
        num_images = mm_counts.get("image", 0)

        return "<image>" * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images at the size with the most context tokens.

        `seq_len` is not used here.
        """
        target_width, target_height = \
            self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }


750
751
class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """ Basic image-only MultiModalProcessor for InternVL-style models."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the parent implementation and attach `image_token_id`."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        image_token_id = hf_processor.image_token_id

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        processed_outputs["image_token_id"] = torch.tensor(image_token_id)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto image items."""
        # Missing "image_num_patches" is treated as zero images.
        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
        num_images = len(image_num_patches)

        return dict(
            pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_patches),
            image_num_patches=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            image_token_id=MultiModalFieldConfig.shared("image", num_images),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the update replacing each `<image>` placeholder.

        The replacement is computed lazily per item index by asking the
        processor for the image replacement matching that item's feature
        size and tile count.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "image_num_patches" in out_mm_kwargs:
            image_num_patches = out_mm_kwargs["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            image_num_patches = []

        def get_replacement_internvl(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_image_repl(feature_size, num_patches)

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]
837
838


839
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """InternVL processing info with optional video handling.

    Whether video is available is taken from the HF processor; when it is,
    a ``"video"`` modality is added to what the base class advertises.
    """

    @property
    def supports_video(self):
        """Whether the HF processor reports video support."""
        processor = self.get_hf_processor()
        return processor.supports_video

    def get_supported_mm_limits(self):
        """Return the base limits, plus ``"video": None`` if supported."""
        video_enabled = self.supports_video
        limits = {**super().get_supported_mm_limits()}
        if video_enabled:
            limits["video"] = None
        return limits

    def get_video_token(self) -> Optional[str]:
        """Return the video placeholder token for the text backbone.

        Only text models of type ``"qwen2"`` map to a token; every other
        model type yields ``None``.
        """
        model_type = self.get_hf_config().get_text_config().model_type
        return "<|video_pad|>" if model_type == "qwen2" else None

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Compute the per-video frame budget that fits in ``seq_len``.

        The tokens reserved for images (max image tokens times the image
        count) are subtracted from ``seq_len``; what is left is divided by
        the processor's ``num_image_token`` and shared evenly between the
        videos. The result is never below 1.
        """
        image_count = mm_counts.get("image", 0)
        video_count = mm_counts.get("video", 0)

        hf_processor = self.get_hf_processor()

        tokens_for_images = self.get_max_image_tokens() * image_count
        remaining_tokens = seq_len - tokens_for_images
        frame_budget = remaining_tokens // hf_processor.num_image_token
        frames_per_video = frame_budget // max(video_count, 1)

        return max(frames_per_video, 1)

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> InternVLProcessor:
        """Create the ``InternVLProcessor`` for this model.

        The three patch options are forwarded only when they are not
        ``None``; ``video_token`` is always forwarded (it may be ``None``).
        """
        optional_overrides = {
            "min_dynamic_patch": min_dynamic_patch,
            "max_dynamic_patch": max_dynamic_patch,
            "dynamic_image_size": dynamic_image_size,
        }
        kwargs.update({
            name: value
            for name, value in optional_overrides.items()
            if value is not None
        })

        kwargs["video_token"] = self.get_video_token()

        return self.ctx.init_processor(
            InternVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )


898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
class InternVLDummyInputsBuilder(
        BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]):
    """Dummy-input builder for InternVL that also emits video inputs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the base dummy text followed by one ``<video>`` per video."""
        video_placeholders = "<video>" * mm_counts.get("video", 0)

        return super().get_dummy_text(mm_counts) + video_placeholders

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return the base dummy data, plus dummy videos when supported.

        Dummy videos are square (``vision_config.image_size`` on each side)
        and use the frame count computed by
        ``get_num_frames_with_most_features``.
        """
        mm_data = {
            **super().get_dummy_mm_data(seq_len=seq_len, mm_counts=mm_counts)
        }
        if not self.info.supports_video:
            return mm_data

        hf_config = self.info.get_hf_config()
        side: int = hf_config.vision_config.image_size
        frames_per_video = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts)
        video_count = mm_counts.get("video", 0)

        mm_data["video"] = self._get_dummy_videos(
            width=side,
            height=side,
            num_frames=frames_per_video,
            num_videos=video_count,
        )
        return mm_data


class InternVLMultiModalProcessor(
        BaseInternVLMultiModalProcessor[InternVLProcessingInfo]):
    """Multimodal processor for InternVL with added video handling."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the base processor and attach ``video_token_id`` if known.

        The id is added as a tensor only when video is supported and the
        HF processor exposes a non-``None`` ``video_token_id``.
        """
        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs)

        processor = self.info.get_hf_processor(**mm_kwargs)
        if not self.info.supports_video:
            return outputs

        token_id = processor.video_token_id
        if token_id is not None:
            outputs["video_token_id"] = torch.tensor(token_id)
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the base field config merged with the video fields."""
        image_fields = super()._get_mm_fields_config(hf_inputs,
                                                     hf_processor_mm_kwargs)

        video_fields = {}
        if self.info.supports_video:
            # Fall back to an empty tensor so a request without videos
            # still yields a well-formed (zero-video) configuration.
            patches = hf_inputs.get("video_num_patches", torch.empty(0))
            video_count = len(patches)
            video_fields = {
                "pixel_values_flat_video":
                MultiModalFieldConfig.flat_from_sizes("video", patches),
                "video_num_patches":
                MultiModalFieldConfig.batched("video"),
                "video_token_id":
                MultiModalFieldConfig.shared("video", video_count),
            }

        return image_fields | video_fields

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Extend the base prompt updates with a ``<video>`` replacement.

        The replacement is appended only when video is supported.
        """
        updates: list[PromptUpdate] = super()._get_prompt_updates(
            mm_items, hf_processor_mm_kwargs, out_mm_kwargs)

        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        patches_per_video = []
        if "video_num_patches" in out_mm_kwargs:
            patches_tensor = out_mm_kwargs["video_num_patches"]
            assert isinstance(patches_tensor, torch.Tensor)
            patches_per_video = patches_tensor.tolist()

        def _video_replacement(item_idx: int):
            tokens_per_patch = processor.num_image_token
            patch_count = patches_per_video[item_idx]
            assert patch_count is None or isinstance(patch_count, int)

            return processor.get_video_repl(
                tokens_per_patch,
                patch_count,
                video_context_token=processor.video_token,
            )

        if self.info.supports_video:
            video_update = PromptReplacement(
                modality="video",
                target="<video>",
                replacement=_video_replacement,
            )
            updates.append(video_update)
        return updates


1013
1014
1015
1016
@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
1017
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
1018

1019
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision model, the language model and the projector.

        Args:
            vllm_config: Engine configuration; the HF model config, the
                quantization config and the multimodal config are read
                from it.
            prefix: Name prefix used for the vision and language
                submodules.
        """
        super().__init__()

        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.multimodal_config = model_config.multimodal_config
        self._patch_quant_config(hf_config, quant_config)

        vision_config = hf_config.vision_config
        image_size = hf_config.force_image_size or vision_config.image_size
        patch_size = vision_config.patch_size
        patches_per_side = image_size // patch_size
        self.patch_size = patch_size
        self.num_image_token = int(patches_per_side**2 *
                                   (hf_config.downsample_ratio**2))
        self.downsample_ratio = hf_config.downsample_ratio
        self.ps_version = hf_config.ps_version

        self.llm_arch_name = hf_config.text_config.architectures[0]
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'

        # Submodules are created in this order: vision, language, projector.
        self.vision_model = self._init_vision_model(
            hf_config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=hf_config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.mlp1 = self._init_mlp1(hf_config)

        # Filled in lazily while parsing multimodal inputs.
        self.img_context_token_id = None
        self.video_context_token_id = None

        # Set per step by `_set_visual_token_mask`, consumed by `forward`.
        self.visual_token_mask = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)
1061

1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: QuantizationConfig):
        """Add ``"vision_model"`` to an AWQ config's unconverted modules.

        The AWQ models from OpenGVLab are missing `modules_to_not_convert`,
        so it is patched back in. Nothing happens unless ``quant_config``
        is an ``AWQConfig`` with an empty ``modules_to_not_convert`` and the
        text config carries a ``quantization_config``.
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config",
                                   None)
        if quant_config.modules_to_not_convert or llm_quant_config is None:
            return

        quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the vision model.

        Mono-architecture models get an ``InternVisionPatchModel`` built
        from the vision config alone. Otherwise an ``InternVisionModel`` is
        built and truncated after the layer selected by
        ``config.select_layer`` (negative values count from the end).
        """
        vision_config = config.vision_config
        if is_mono:
            return InternVisionPatchModel(vision_config)

        feature_layer = config.select_layer
        layers_to_keep = feature_layer + 1
        if feature_layer < 0:
            # Negative index: resolve relative to the total layer count.
            layers_to_keep += vision_config.num_hidden_layers

        return InternVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=layers_to_keep,
            prefix=prefix,
        )
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        return nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
    def pixel_shuffle(self, x, scale_factor=0.5):
        """Scale both spatial axes of ``x`` by ``scale_factor``.

        ``x`` is unpacked as ``(N, W, H, C)``. Both spatial axes are scaled
        by ``scale_factor`` and the channel axis grows by
        ``1 / scale_factor ** 2``. For ``ps_version == 'v1'`` the final
        swap of the two spatial axes is skipped.
        """
        n, w, h, c = x.size()
        out_h = int(h * scale_factor)
        out_w = int(w * scale_factor)

        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, out_h, int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.transpose(1, 2).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // scale^2
        x = x.view(n, out_h, out_w, int(c / (scale_factor * scale_factor)))

        if self.ps_version != 'v1':
            x = x.transpose(1, 2).contiguous()
        return x

1125
    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision model and project its tokens with ``mlp1``.

        The first token along dim 1 is dropped, the rest are laid out on a
        square grid, passed through ``pixel_shuffle`` with
        ``downsample_ratio``, flattened again and fed to ``mlp1``.
        """
        # Drop the leading token (index 0 along the sequence axis).
        tokens = self.vision_model(pixel_values=pixel_values)[:, 1:, :]

        batch = tokens.shape[0]
        side = int(tokens.shape[1]**0.5)
        grid = tokens.reshape(batch, side, side, -1)
        grid = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)

        flat = grid.reshape(grid.shape[0], -1, grid.shape[-1])
        return self.mlp1(flat)

1138
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
1139
1140
1141
1142
1143
1144
1145
1146

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
1147
                expected_expr = str(expected_dims)
1148
                raise ValueError(
1149
1150
1151
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")
1152
1153
1154
1155
1156
1157
1158

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Turn image keyword arguments into a typed image input.

        Precomputed ``image_embeds`` take precedence over pixel values.
        When pixel values are used, ``img_context_token_id`` is set from
        the ``image_token_id`` keyword argument.

        Returns:
            ``None`` when neither ``pixel_values_flat`` nor
            ``image_embeds`` is given, otherwise the parsed input.

        Raises:
            ValueError: If a provided value is neither a tensor nor a list.
        """
        accepted_types = (torch.Tensor, list)

        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if image_embeds is not None:
            if not isinstance(image_embeds, accepted_types):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        if pixel_values_flat is None:
            return None

        image_token_id = kwargs["image_token_id"]
        assert isinstance(image_token_id, torch.Tensor)
        # `.item()` requires exactly one distinct token id.
        self.img_context_token_id = image_token_id.flatten().unique().item()

        if not isinstance(pixel_values_flat, accepted_types):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values_flat)}")

        if not isinstance(image_num_patches, accepted_types):
            raise ValueError("Incorrect type of image_num_patches. "
                             f"Got type: {type(image_num_patches)}")

        flat_pixels = flatten_bn(pixel_values_flat, concat=True)
        patch_counts = flatten_bn(image_num_patches, concat=True)

        return InternVLImagePixelInputs(
            type="pixel_values",
            pixel_values_flat=self._validate_pixel_values(flat_pixels),
            num_patches=patch_counts,
        )

1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[InternVLVideoPixelInputs]:
        pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
        video_num_patches = kwargs.pop("video_num_patches", None)
        video_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat_video is None and video_embeds is None:
            return None

        if video_embeds is not None:
            if not isinstance(video_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of video embeddings. "
                                 f"Got type: {type(video_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="video_embeds",
                data=flatten_bn(video_embeds),
            )

        video_token_id = kwargs["video_token_id"]
        assert isinstance(video_token_id, torch.Tensor)
        self.video_context_token_id = video_token_id.flatten().unique().item()

        if pixel_values_flat_video is not None:
            if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values_flat_video)}")

            if not isinstance(video_num_patches, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image_num_patches. "
                                 f"Got type: {type(video_num_patches)}")

            pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
                                                 concat=True)
            video_num_patches = flatten_bn(video_num_patches, concat=True)

            return InternVLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_flat=self._validate_pixel_values(
                    pixel_values_flat_video),
                num_patches=video_num_patches,
            )

        raise AssertionError("This line should be unreachable.")

1247
1248
    def _process_image_input(
        self,
1249
1250
        image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs],
    ) -> tuple[torch.Tensor, ...]:
1251
1252
1253
1254
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
1255

1256
        image_embeds = self.extract_feature(image_input["pixel_values_flat"])
1257

1258
        num_patches = image_input["num_patches"]
1259
1260

        # Only one image in the current batch
1261
        if len(num_patches) == 1:
1262
1263
            return (image_embeds.view(-1,
                                      self.config.text_config.hidden_size), )
1264
1265
1266
1267
1268
1269
1270

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1,
                                         self.config.text_config.hidden_size)
        image_feature_sizes = [
1271
            num_patches * feature_size for num_patches in num_patches
1272
        ]
1273
        return image_embeds.split(image_feature_sizes)
1274

1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key in ("pixel_values_flat",
                             "image_embeds") and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if input_key in ("pixel_values_flat_video",
                             ) and "videos" not in modalities:
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)

        return modalities

1292
    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
1293
        if self.is_mono:
1294
            assert self.img_context_token_id is not None
1295
            self.visual_token_mask = (
1296
1297
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
1298
            self.visual_token_mask = None
1299

1300
1301
1302
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        language_model = self.language_model
        return language_model

1303
    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute embeddings for all image and video items in ``kwargs``.

        Returns:
            ``None`` when no multimodal input is present; otherwise a tuple
            of tensors, with each tensor corresponding to a multimodal data
            item (image or video).
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return None

        embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality, parsed_input in modalities.items():
            # Images and videos go through the same embedding path.
            if modality in ("images", "videos"):
                embeddings += self._process_image_input(parsed_input)

        return embeddings
1327
1328
1329
1330

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in the multimodal embeddings.

        Args:
            input_ids: Token ids to embed with the language model.
            multimodal_embeddings: Image/video embeddings to place at the
                image and video context-token positions. ``None`` or an
                empty container means the text embeddings are returned
                unchanged.

        Returns:
            The input embeddings, with multimodal positions replaced when
            embeddings were supplied.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)

        # An empty container carries nothing to merge. Treat it like None
        # instead of tripping the context-token assertion below.
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        context_token_ids = [
            token_id for token_id in (self.img_context_token_id,
                                      self.video_context_token_id)
            if token_id is not None
        ]
        assert len(context_token_ids) >= 1

        self._set_visual_token_mask(input_ids)
        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            multimodal_embeddings,
            context_token_ids,
        )

1350
1351
1352
1353
1354
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language model on text plus multimodal inputs.

        When ``intermediate_tensors`` is given, both ``input_ids`` and
        ``inputs_embeds`` are dropped. Otherwise, if no ``inputs_embeds``
        were supplied, they are built from ``input_ids`` and the multimodal
        ``kwargs``. A pending ``visual_token_mask`` is forwarded to the
        language model and then cleared.
        """
        if intermediate_tensors is not None:
            input_ids = inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            mm_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      mm_embeddings)
            input_ids = None

        model_inputs = dict(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        # Only required if the model is mono-architecture
        pending_mask = self.visual_token_mask
        if pending_mask is not None:
            model_inputs["visual_token_mask"] = pending_mask
            self.visual_token_mask = None

        return self.language_model.model(**model_inputs)

1387
1388
1389
1390
1391
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped language model."""
        logits = self.language_model.compute_logits(hidden_states,
                                                    sampling_metadata)
        return logits

1395
1396
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``weights`` via ``AutoWeightsLoader``, skipping some prefixes.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
        unused_module_prefixes = [
            "action_embed",
            "temporal_embed",
            "track_embed",
            "track_embed_decoder",
            "box_token",
            "cg_criterion",
            "cg_model",
            "loc_encoder",
            "loc_decoder",
            "sam",
            "temporal_token",
            "track_token",
        ]
        weights_loader = AutoWeightsLoader(
            self, skip_prefixes=unused_module_prefixes)
        return weights_loader.load_weights(weights)