internvl.py 33 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
9
from abc import ABC, abstractmethod
10
from collections.abc import Iterable, Mapping, Sequence
11
from typing import Literal, Optional, TypedDict, TypeVar, Union
12
13
14
15
16

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
17
from transformers import BatchEncoding, PretrainedConfig, TensorType
18

19
from vllm.config import VllmConfig
20
21
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
22
23
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
24
from vllm.model_executor.sampling_metadata import SamplingMetadata
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.image import convert_image_mode
27
28
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, NestedTensors)
29
30
31
32
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
33
                                        PromptUpdate, PromptUpdateDetails)
34
from vllm.multimodal.profiling import BaseDummyInputsBuilder
35
from vllm.sequence import IntermediateTensors
36
from vllm.transformers_utils.tokenizer import AnyTokenizer
37

38
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
39
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
40
                    maybe_prefix, merge_multimodal_embeddings)
41
42
43
44
45
46
47
48
49
50
51

# Marker strings used when expanding an image placeholder in the prompt:
# the expanded text is IMG_START + (IMG_CONTEXT repeated) + IMG_END.
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel mean / std passed to `T.Normalize` in `build_transform`.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image input supplied as already-preprocessed pixel tiles."""

    type: Literal["pixel_values"]

    pixel_values_flat: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """

    num_patches: torch.Tensor
    """Shape: `(batch_size * num_images)`"""

61

62
63
class InternVLImageEmbeddingInputs(TypedDict):
    """Image input supplied as precomputed embeddings."""

    type: Literal["image_embeds"]

    data: Union[torch.Tensor, list[torch.Tensor]]
    """
    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors of shape `(total_image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either form of image input; the two members are told apart by their
# `type` key ("pixel_values" vs "image_embeds").
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


77
78
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
    """Build the per-tile preprocessing pipeline.

    The returned transform converts the image to RGB, resizes it to a
    square of ``input_size`` pixels with bicubic interpolation, converts it
    to a tensor and normalizes it with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Side length, in pixels, of the square output.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    return T.Compose([
        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


89
90
91
92
93
94
95
96
97
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the tiling grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Aspect ratio (width / height) of the original image.
        target_ratios: Candidate ``(columns, rows)`` grids, scanned in order.
        width: Original image width in pixels.
        height: Original image height in pixels.
        image_size: Side length of one square tile in pixels.

    Returns:
        The selected ``(columns, rows)`` grid, or ``(1, 1)`` if
        ``target_ratios`` is empty.
    """
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height

    for ratio in target_ratios:
        cols, rows = ratio[0], ratio[1]
        ratio_diff = abs(aspect_ratio - cols / rows)

        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif (ratio_diff == best_ratio_diff
              and area > 0.5 * image_size * image_size * cols * rows):
            # Exact tie: prefer the later candidate only when the image
            # covers more than half of that candidate grid's pixel area.
            best_ratio = ratio

    return best_ratio


113
114
115
116
117
118
119
def resolve_internvl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective minimum and maximum number of tiles per image.

    Args:
        min_dynamic_patch: Configured minimum tile count.
        max_dynamic_patch: Configured maximum tile count.
        dynamic_image_size: When false, both bounds collapse to 1.
        use_thumbnail: When true and the resolved maximum is not 1, the
            maximum is increased by one.

    Returns:
        ``(min_num, max_num)``.
    """
    if not dynamic_image_size:
        min_dynamic_patch = 1
        max_dynamic_patch = 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch

128

129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """Enumerate candidate tiling grids for a given tile-count range.

    Collects every distinct pair ``(i, j)`` with ``1 <= i, j <= max_num``
    whose product lies in ``[min_num, max_num]`` and returns the pairs
    ordered by that product.
    """
    candidates = set()
    for upper in range(min_num, max_num + 1):
        for cols in range(1, upper + 1):
            for rows in range(1, upper + 1):
                if min_num <= cols * rows <= max_num:
                    candidates.add((cols, rows))

    def tile_count(grid):
        return grid[0] * grid[1]

    return sorted(candidates, key=tile_count)


def calculate_internvl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Compute the tile count and resize target for an image.

    Args:
        orig_width: Original image width in pixels.
        orig_height: Original image height in pixels.
        target_ratios: Candidate ``(columns, rows)`` grids.
        image_size: Side length of one square tile in pixels.
        use_thumbnail: When true and more than one tile is produced, one
            extra block is counted.

    Returns:
        ``(blocks, target_width, target_height)``, where the target size is
        the chosen grid scaled by ``image_size``.
    """
    aspect_ratio = orig_width / orig_height

    # find the closest aspect ratio to the target
    cols, rows = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    # calculate the target width and height
    target_width = image_size * cols
    target_height = image_size * rows
    blocks = cols * rows

    # add thumbnail image if num_blocks != 1
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height
169
170


171
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
172
173
174
175
176
177
178
def dynamic_preprocess_internvl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Resize an image to its target grid and cut it into square tiles.

    Args:
        image: Source image.
        target_ratios: Candidate ``(columns, rows)`` grids.
        image_size: Side length of one square tile in pixels.
        use_thumbnail: When true and more than one tile was produced, a
            copy of the whole image resized to ``image_size`` square is
            appended as the last element.

    Returns:
        The tiles in row-major order, optionally followed by the thumbnail.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_internvl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    # resize the image
    resized_img = image.resize((target_width, target_height))

    # Number of tiles per row is the same for every tile; compute it once.
    tiles_per_row = target_width // image_size

    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, tiles_per_row)
        box = (col * image_size, row * image_size,
               (col + 1) * image_size, (row + 1) * image_size)
        # split the image
        processed_images.append(resized_img.crop(box))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))

    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
212
213
214
215
216
217
218
219
220
221
def image_to_pixel_values_internvl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Convert one image into a stacked tensor of preprocessed tiles.

    Args:
        image: Source image.
        input_size: Side length of one square tile in pixels.
        min_num: Minimum tile count used to enumerate candidate grids.
        max_num: Maximum tile count used to enumerate candidate grids.
        use_thumbnail: Forwarded to ``dynamic_preprocess_internvl``.

    Returns:
        The transformed tiles stacked along a new leading dimension.
    """
    target_ratios = get_internvl_target_ratios(min_num, max_num)
    transform = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_internvl(
        image,
        target_ratios=target_ratios,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([transform(tile) for tile in tiles])


234
235
236
237
class BaseInternVLProcessor(ABC):
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> None:
        """Store the config/tokenizer and resolve the tiling settings.

        Each keyword override falls back to the attribute of the same name
        on ``config`` when it is ``None``.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        image_size: int = config.vision_config.image_size
        patch_size: int = config.vision_config.patch_size

        if min_dynamic_patch is None:
            min_dynamic_patch = config.min_dynamic_patch
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        # Number of placeholder tokens contributed by a single tile.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        """Token ID of the image placeholder token."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the text that replaces one ``<image>`` placeholder."""
        raise NotImplementedError

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve ``(min_num, max_num)`` tiles per image.

        Any argument left as ``None`` falls back to the value stored on
        this processor.
        """
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_internvl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> list[tuple[int, int]]:
        """Return the candidate tiling grids for the resolved tile range."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        return get_internvl_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens for an image size."""
        target_ratios = self.resolve_target_ratios(
            use_thumbnail=False,  # Applied in calculate_targets
        )

        num_patches, _, _ = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=target_ratios,
            use_thumbnail=self.use_thumbnail,
        )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each image into its own stacked tensor of tiles."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_internvl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
            ) for image in images
        ]

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        """Tokenize ``text`` and preprocess ``images``.

        For every image, the first remaining ``<image>`` occurrence in each
        text entry is replaced by that image's expanded placeholder text
        before tokenization.

        Returns:
            The tokenizer output merged with ``pixel_values_flat`` and
            ``image_num_patches`` when at least one image was given.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        image_inputs: dict[str, NestedTensors] = {}
        if len(images) != 0:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            image_inputs = {
                "pixel_values_flat":
                torch.cat(pixel_values_lst),
                "image_num_patches":
                torch.tensor([len(item) for item in pixel_values_lst]),
            }

            for pixel_values in pixel_values_lst:
                num_patches = pixel_values.shape[0]
                feature_size = num_patches * self.num_image_token

                image_repl = self.get_image_repl(feature_size, num_patches)
                # count=1: each image consumes exactly one placeholder.
                text = [t.replace('<image>', image_repl.full, 1) for t in text]

        text_inputs = self.tokenizer(text)

        return {
            **BatchEncoding(text_inputs, tensor_type=return_tensors),
            **image_inputs,
        }
425
426


427
428
429
430
431
432
class InternVLProcessor(BaseInternVLProcessor):
    """Processor that uses ``IMG_CONTEXT`` as the image placeholder token."""

    @property
    def image_token_id(self) -> int:
        """Vocabulary ID of the ``IMG_CONTEXT`` token."""
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the replacement text for one image.

        The text is ``IMG_START``, then ``IMG_CONTEXT`` repeated
        ``feature_size`` times, then ``IMG_END``. ``num_patches`` is not
        used by this implementation.
        """
        repl_features = IMG_CONTEXT * feature_size
        repl_full = IMG_START + repl_features + IMG_END

        return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
442
443
444
445
446
447


class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Processing info shared by InternVL-style processors."""

    @abstractmethod
    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> BaseInternVLProcessor:
        """Return the processor instance; implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits (only ``"image"`` is listed)."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[BaseInternVLProcessor],
    ) -> int:
        """Return the placeholder token count for an image size.

        A default processor from ``get_hf_processor()`` is used when
        ``processor`` is ``None``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Find the image size that yields the most placeholder tokens.

        Every candidate grid is scaled by the tile size and evaluated; the
        first size reaching the maximum token count is returned.

        Raises:
            ValueError: If no candidate yields a positive token count.
        """
        processor = self.get_hf_processor()

        base_size = processor.image_size
        target_ratios = processor.resolve_target_ratios()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for wr, hr in target_ratios:
            width, height = base_size * wr, base_size * hr

            feat_size = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width,
                                                     height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint


# Type variable for the processing-info class parameterizing the generic
# builder / processor classes below.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy text and image inputs for InternVL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>`` placeholder per requested image."""
        num_images = mm_counts.get("image", 0)

        return "<image>" * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images at the size that yields the most tokens.

        ``seq_len`` is not used by this implementation.
        """
        target_width, target_height = \
            self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }


class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor for InternVL image inputs."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the base processor and attach ``image_token_id``."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        image_token_id = hf_processor.image_token_id

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        processed_outputs["image_token_id"] = torch.tensor(image_token_id)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output field maps onto images."""
        # Defaults to an empty tensor when the key is absent, which makes
        # num_images 0.
        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
        num_images = len(image_num_patches)

        return dict(
            pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_patches),
            image_num_patches=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            image_token_id=MultiModalFieldConfig.shared("image", num_images),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the update that expands each ``<image>`` placeholder."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "image_num_patches" in out_mm_kwargs:
            image_num_patches = out_mm_kwargs["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            image_num_patches = []

        def get_replacement_internvl(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_image_repl(feature_size, num_patches)

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]
614
615


616
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that constructs ``InternVLProcessor`` instances."""

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> InternVLProcessor:
        """Create an ``InternVLProcessor`` via ``self.ctx.init_processor``.

        Only the overrides that are not ``None`` are forwarded.
        """
        if min_dynamic_patch is not None:
            kwargs["min_dynamic_patch"] = min_dynamic_patch
        if max_dynamic_patch is not None:
            kwargs["max_dynamic_patch"] = max_dynamic_patch
        if dynamic_image_size is not None:
            kwargs["dynamic_image_size"] = dynamic_image_size

        return self.ctx.init_processor(
            InternVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
645
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
646

647
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision model, the language model and the projector.

        Args:
            vllm_config: Supplies the HF config, the quantization config
                and the multimodal config.
            prefix: Prefix combined with each submodule name through
                ``maybe_prefix``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        # `force_image_size`, when truthy, overrides the vision config.
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.llm_arch_name = config.text_config.architectures[0]
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.mlp1 = self._init_mlp1(config)

        self.img_context_token_id = None
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)
687

688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: QuantizationConfig):
        """Add ``"vision_model"`` to an AWQ config's skip list if needed.

        The AWQ models from OpenGVLab are missing `modules_to_not_convert`,
        so it is patched back in when the list is empty and the text config
        carries a `quantization_config`.
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config",
                                   None)
        if quant_config.modules_to_not_convert or llm_quant_config is None:
            return

        quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Construct the vision model.

        When ``is_mono`` is true an ``InternVisionPatchModel`` is returned.
        Otherwise an ``InternVisionModel`` is built with its layer count
        limited according to ``config.select_layer``.
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        vision_feature_layer = config.select_layer
        if vision_feature_layer < 0:
            # Negative index counts back from the last layer.
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )
724
725
726
727
728
729
730
731
732
733
734
735
736

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Build the module applied to vision features in `extract_feature`.

        Layers: LayerNorm -> Linear -> GELU -> Linear. The input width is
        the vision hidden size times ``int(1 / downsample_ratio) ** 2`` and
        the output width is the text hidden size.
        """
        out_dim = config.text_config.hidden_size
        in_dim = (config.vision_config.hidden_size *
                  int(1 / self.downsample_ratio)**2)

        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        ]
        return nn.Sequential(*layers)

737
738
739
740
741
742
743
744
745
746
747
748
749
750
    def pixel_shuffle(self, x, scale_factor=0.5):
        """Rearrange a 4-D tensor, trading spatial size for channel depth.

        The two middle dimensions are each scaled by ``scale_factor`` and
        the last dimension is divided by ``scale_factor ** 2``. Unless
        ``self.ps_version`` is ``'v1'``, the two middle dimensions are
        swapped once more at the end.
        """
        batch, width, height, channels = x.size()
        scaled_h = int(height * scale_factor)
        scaled_w = int(width * scale_factor)

        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(batch, width, scaled_h, int(channels / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch, scaled_h, scaled_w,
                   int(channels / (scale_factor * scale_factor)))

        if self.ps_version != 'v1':
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

751
    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision model and project its output through ``mlp1``.

        The first position along dimension 1 of the vision output is
        dropped. The remaining positions are reshaped into a square grid,
        passed through ``pixel_shuffle``, flattened again and fed to
        ``mlp1``.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

        # Remaining positions are treated as a square h x w grid.
        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        return self.mlp1(vit_embeds)

764
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
765
766
767
768
769
770
771
772

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
773
                expected_expr = str(expected_dims)
774
                raise ValueError(
775
776
777
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")
778
779
780
781
782
783
784

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Turn raw keyword inputs into a typed image-input dict.

        Returns ``None`` when neither ``pixel_values_flat`` nor
        ``image_embeds`` is present. Embeddings take precedence over pixel
        values. On the pixel-value path ``self.img_context_token_id`` is
        updated from ``image_token_id``.

        Raises:
            ValueError: If an input is neither a tensor nor a list, or the
                pixel values have an unexpected shape.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        image_token_id = kwargs["image_token_id"]
        assert isinstance(image_token_id, torch.Tensor)
        # `.item()` requires exactly one distinct ID across the batch.
        self.img_context_token_id = image_token_id.flatten().unique().item()

        if pixel_values_flat is not None:
            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values_flat)}")

            if not isinstance(image_num_patches, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image_num_patches. "
                                 f"Got type: {type(image_num_patches)}")

            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
            image_num_patches = flatten_bn(image_num_patches, concat=True)

            return InternVLImagePixelInputs(
                type="pixel_values",
                pixel_values_flat=self._validate_pixel_values(
                    pixel_values_flat),
                num_patches=image_num_patches,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
        """Turn a parsed image input into image embeddings.

        Precomputed embeddings (``type == "image_embeds"``) are returned
        unchanged. Otherwise the flattened pixel values are passed through
        ``self.extract_feature`` and the result is reshaped to rows of
        ``self.config.text_config.hidden_size``.

        Returns:
            The ``data`` entry for embedding inputs; a tensor with a leading
            dimension of 1 when ``num_patches`` has a single entry; otherwise
            the tuple produced by splitting the flattened embeddings into one
            chunk per entry of ``num_patches``.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None

        image_embeds = self.extract_feature(image_input["pixel_values_flat"])

        num_patches = image_input["num_patches"]

        # Only one image in the current batch
        if len(num_patches) == 1:
            return image_embeds.view(
                -1, self.config.text_config.hidden_size).unsqueeze(0)

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1,
                                         self.config.text_config.hidden_size)
        # Use a distinct loop name so the comprehension does not shadow
        # `num_patches`.
        image_feature_sizes = [
            patch_count * feature_size for patch_count in num_patches
        ]
        return image_embeds.split(image_feature_sizes)
855

856
    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Store the mask of image-context tokens on ``self``.

        When ``self.is_mono`` is true, ``self.visual_token_mask`` becomes a
        boolean column (reshaped to ``(-1, 1)``) marking the positions where
        ``input_ids`` equals ``self.img_context_token_id``. Otherwise the
        mask is cleared to ``None``.
        """
        if self.is_mono:
            self.visual_token_mask = (
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
            self.visual_token_mask = None
862

863
864
865
    def get_language_model(self) -> torch.nn.Module:
        """Return the module stored in ``self.language_model``."""
        language_model = self.language_model
        return language_model

866
    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute image embeddings from the multimodal keyword arguments.

        Returns:
            ``None`` when ``kwargs`` carries no image input, otherwise the
            result of ``self._process_image_input`` on the parsed input.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        return self._process_image_input(image_input)
873
874
875
876

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in multimodal embeddings if given.

        The text embeddings come from the language model. When
        ``multimodal_embeddings`` is provided, the visual token mask is
        refreshed via ``self._set_visual_token_mask`` and the embeddings are
        combined through ``merge_multimodal_embeddings`` using
        ``self.img_context_token_id`` as the placeholder token id.

        Returns:
            The text embeddings, merged with ``multimodal_embeddings`` when
            those are supplied.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            # The id is set while parsing pixel-value image inputs.
            assert self.img_context_token_id is not None
            self._set_visual_token_mask(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.img_context_token_id,
            )
        return inputs_embeds

891
892
893
894
895
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language model on text and (optionally) image inputs.

        When ``intermediate_tensors`` is given, ``input_ids`` and
        ``inputs_embeds`` are both dropped. Otherwise, if no
        ``inputs_embeds`` were passed, they are built here from ``input_ids``
        and the multimodal ``kwargs``.

        Returns:
            Whatever ``self.language_model.model`` returns for the assembled
            keyword arguments.
        """

        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs["visual_token_mask"] = self.visual_token_mask
            # Clear the stored mask so it is not reused by a later call.
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

928
929
930
931
932
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

936
937
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``weights`` into this model through ``AutoWeightsLoader``.

        Weights whose names start with one of ``skip_prefixes`` are passed to
        the loader as prefixes to skip.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``
            (presumably the names of the loaded parameters; confirm against
            the loader).
        """
        # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
        skip_prefixes = [
            "action_embed", "temporal_embed", "track_embed",
            "track_embed_decoder", "box_token", "cg_criterion", "cg_model",
            "loc_encoder", "loc_decoder", "sam", "temporal_token",
            "track_token"
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)