"vllm/vscode:/vscode.git/clone" did not exist on "ccd0d1d9067a0bf24330a87044dca272e4a5228c"
internvl.py 33 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
9
from abc import ABC, abstractmethod
10
from collections.abc import Iterable, Mapping, Sequence
11
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
12
13
14
15
16

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
17
from transformers import BatchEncoding, PretrainedConfig, TensorType
18

19
from vllm.config import VllmConfig
20
21
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
22
23
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
24
from vllm.model_executor.sampling_metadata import SamplingMetadata
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
27
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, NestedTensors)
28
29
30
31
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
32
                                        PromptUpdate, PromptUpdateDetails)
33
from vllm.multimodal.profiling import BaseDummyInputsBuilder
34
from vllm.sequence import IntermediateTensors
35
from vllm.transformers_utils.tokenizer import AnyTokenizer
36

37
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
38
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
39
                    maybe_prefix, merge_multimodal_embeddings)
40
41
42
43
44
45
46
47
48
49
50

# Marker strings used when an image placeholder is expanded in the prompt:
# IMG_START / IMG_END wrap a run of repeated IMG_CONTEXT tokens.
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel mean / std handed to T.Normalize in build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel values (``type == "pixel_values"``)."""

    type: Literal["pixel_values"]

    pixel_values_flat: torch.Tensor
    """
    Patches of every image concatenated along the first dimension.

    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """

    num_patches: torch.Tensor
    """Patch count of each image. Shape: `(batch_size * num_images)`"""

60

61
62
class InternVLImageEmbeddingInputs(TypedDict):
    """Image input supplied as embeddings (``type == "image_embeds"``)."""

    type: Literal["image_embeds"]

    data: Union[torch.Tensor, list[torch.Tensor]]
    """
    Either one tensor of shape
    `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors, each of shape
    `(total_image_feature_size, hidden_size)`.

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either form of image input handled by the model: raw pixel values or
# ready-made embeddings. The "type" key tells the two apart.
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


76
77
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
    """Compose the per-patch preprocessing pipeline.

    The pipeline converts non-RGB images to RGB, resizes to a square of
    ``input_size`` pixels with bicubic interpolation, converts to a tensor
    and normalizes with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """

    def _ensure_rgb(img):
        # Images that are already RGB are passed through untouched.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


88
89
90
91
92
93
94
95
96
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the tiling grid whose aspect ratio is nearest to ``aspect_ratio``.

    Candidates are scanned in order. A strictly closer candidate always
    wins. On an exact tie the later candidate replaces the current choice
    only when the image area exceeds half of the area that candidate's grid
    would cover. Returns ``(1, 1)`` when ``target_ratios`` is empty.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    image_area = width * height

    for candidate in target_ratios:
        cols, rows = candidate
        gap = abs(aspect_ratio - cols / rows)

        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue

        # Tie: take this grid only if the image is large enough for it.
        if (gap == smallest_gap
                and image_area > 0.5 * image_size * image_size * cols * rows):
            chosen = candidate

    return chosen


112
113
114
115
116
117
118
def resolve_internvl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Return the effective ``(min, max)`` patch-count bounds.

    Both bounds collapse to 1 when dynamic image size is disabled. When a
    thumbnail is used and the upper bound is not 1, the upper bound is
    raised by one.
    """
    if not dynamic_image_size:
        min_dynamic_patch = max_dynamic_patch = 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch

127

128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """List every ``(i, j)`` grid with ``min_num <= i * j <= max_num``.

    The result is sorted by ascending ``i * j``. Grids with equal products
    keep the iteration order of the intermediate set.
    """
    grids: set[tuple[int, int]] = set()
    # The loop nesting matches the insertion order of the original
    # comprehension so that the set (and hence tie order) is built the same.
    for n in range(min_num, max_num + 1):
        for i in range(1, n + 1):
            for j in range(1, n + 1):
                if min_num <= i * j <= max_num:
                    grids.add((i, j))

    return sorted(grids, key=lambda grid: grid[0] * grid[1])


def calculate_internvl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Work out how an image will be tiled.

    Returns ``(blocks, target_width, target_height)``. The target size is
    the chosen grid multiplied by ``image_size``. ``blocks`` is the number
    of grid cells, plus one for the thumbnail when ``use_thumbnail`` is set
    and the grid has more than one cell.
    """
    # Choose the grid whose aspect ratio best matches the original image.
    grid_w, grid_h = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    target_width = image_size * grid_w
    target_height = image_size * grid_h
    blocks = grid_w * grid_h

    # A single-cell grid gets no extra thumbnail block.
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height
168
169


170
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
171
172
173
174
175
176
177
def dynamic_preprocess_internvl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Resize ``image`` to its target grid and cut it into square tiles.

    Tiles are produced row by row, each ``image_size`` pixels square. When
    ``use_thumbnail`` is set and more than one tile was produced, a copy of
    the whole image resized to ``image_size`` x ``image_size`` is appended.
    """
    width, height = image.size

    # The thumbnail is appended below, so it is excluded from this count.
    blocks, target_width, target_height = calculate_internvl_targets(
        orig_width=width,
        orig_height=height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    resized_img = image.resize((target_width, target_height))
    tiles_per_row = target_width // image_size

    processed_images = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        top = row * image_size
        processed_images.append(
            resized_img.crop((left, top, left + image_size,
                              top + image_size)))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))

    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
211
212
213
214
215
216
217
218
219
220
def image_to_pixel_values_internvl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Turn one image into a stacked tensor of preprocessed patches.

    The image is tiled with ``dynamic_preprocess_internvl`` using the grids
    allowed by ``min_num`` / ``max_num``. Each tile goes through the
    transform from ``build_transform`` and the results are stacked along a
    new first dimension.
    """
    target_ratios = get_internvl_target_ratios(min_num, max_num)
    transform = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_internvl(
        image,
        target_ratios=target_ratios,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([transform(tile) for tile in tiles])


233
234
235
236
class BaseInternVLProcessor(ABC):
    """
    Hand-written replacement for the HF processor that this model does
    not define.

    It tiles images into patches, converts them to pixel values and expands
    each ``<image>`` placeholder in the prompt before tokenizing.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> None:
        """Store the config/tokenizer and resolve the tiling settings.

        Each keyword argument left as ``None`` is read from the attribute
        of the same name on ``config``.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        vision_config = config.vision_config
        image_size: int = vision_config.image_size
        patch_size: int = vision_config.patch_size

        if min_dynamic_patch is None:
            min_dynamic_patch = config.min_dynamic_patch
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        # Number of placeholder tokens contributed by a single patch.
        grid_side = image_size // patch_size
        self.num_image_token = int(grid_side**2 *
                                   (config.downsample_ratio**2))
        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        """Token ID of the image context token; defined by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the text that replaces one image placeholder."""
        raise NotImplementedError

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve the patch-count bounds, defaulting to instance settings.

        Arguments left as ``None`` take the value stored on the processor.
        """
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_internvl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> list[tuple[int, int]]:
        """Return the candidate tiling grids for the resolved bounds."""
        bounds = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )
        min_num, max_num = bounds

        return get_internvl_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Count the placeholder tokens an image of this size expands to."""
        # The thumbnail is accounted for by calculate_internvl_targets,
        # so it must not also widen the candidate grids here.
        target_ratios = self.resolve_target_ratios(use_thumbnail=False)

        targets = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=target_ratios,
            use_thumbnail=self.use_thumbnail,
        )
        num_patches = targets[0]

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each image into its stacked patch tensor."""
        # The thumbnail is added inside image_to_pixel_values_internvl.
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,
        )

        pixel_values_lst = []
        for image in images:
            pixel_values_lst.append(
                image_to_pixel_values_internvl(
                    image,
                    input_size=self.image_size,
                    min_num=min_num,
                    max_num=max_num,
                    use_thumbnail=self.use_thumbnail,
                ))
        return pixel_values_lst

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        """Tokenize ``text`` and preprocess ``images``.

        For every image, the first remaining ``<image>`` in each text is
        replaced by the full replacement string from ``get_image_repl``.
        The result holds the tokenizer outputs, plus ``pixel_values_flat``
        and ``image_num_patches`` when at least one image was given.
        """
        # Normalize both inputs to lists.
        if text is None:
            text = []
        elif not isinstance(text, list):
            text = [text]

        if images is None:
            images = []
        elif not isinstance(images, list):
            images = [images]

        image_inputs: dict[str, NestedTensors] = {}
        if images:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            image_inputs["pixel_values_flat"] = torch.cat(pixel_values_lst)
            image_inputs["image_num_patches"] = torch.tensor(
                [len(item) for item in pixel_values_lst])

            for pixel_values in pixel_values_lst:
                num_patches = pixel_values.shape[0]
                feature_size = num_patches * self.num_image_token

                replacement = self.get_image_repl(feature_size,
                                                  num_patches).full
                text = [t.replace('<image>', replacement, 1) for t in text]

        text_inputs = self.tokenizer(text)
        encoded = BatchEncoding(text_inputs, tensor_type=return_tensors)

        return {
            **encoded,
            **image_inputs,
        }
424
425


426
427
428
429
430
431
class InternVLProcessor(BaseInternVLProcessor):
    """Processor that expands image placeholders with the InternVL markers."""

    @property
    def image_token_id(self) -> int:
        """ID of ``IMG_CONTEXT`` looked up in the tokenizer vocabulary."""
        vocab = self.tokenizer.get_vocab()
        return vocab[IMG_CONTEXT]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build ``IMG_START`` + ``feature_size`` x ``IMG_CONTEXT`` + ``IMG_END``.

        ``num_patches`` is accepted for interface compatibility and is not
        used here.
        """
        repl_full = IMG_START + IMG_CONTEXT * feature_size + IMG_END

        return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
441
442
443
444
445
446


class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Processing information shared by InternVL-style processors."""

    @abstractmethod
    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> BaseInternVLProcessor:
        """Create the processor; the concrete class is chosen by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Report the supported modalities (only ``"image"``, limit ``None``)."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[BaseInternVLProcessor],
    ) -> int:
        """Delegate the token count to ``processor``.

        A processor is created with ``get_hf_processor()`` when ``None`` is
        passed.
        """
        active = self.get_hf_processor() if processor is None else processor

        return active.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Find the image size that yields the largest token count.

        One image size is tried per target grid (the grid multiplied by the
        processor's base image size). The first size reaching the maximum
        is returned.

        Raises:
            ValueError: If no candidate produced a positive token count.
        """
        processor = self.get_hf_processor()

        base_size = processor.image_size

        best_tokens = 0
        best_size = None
        for width_mult, height_mult in processor.resolve_target_ratios():
            width = base_size * width_mult
            height = base_size * height_mult

            tokens = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if tokens > best_tokens:
                best_tokens = tokens
                best_size = ImageSize(width=width, height=height)

        if best_tokens == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size


# Type variable for processing-info classes derived from
# BaseInternVLProcessingInfo; it parameterizes the generic classes below.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds placeholder prompts and images for the requested image count."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<image>`` placeholder per requested image."""
        return "<image>" * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized to yield the most image tokens.

        ``seq_len`` is not used by this implementation.
        """
        size = self.info.get_image_size_with_most_features()
        target_width, target_height = size
        num_images = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=target_width,
                                              height=target_height,
                                              num_images=num_images)
        return {"image": dummy_images}


class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor that wires the InternVL processor into vLLM."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the parent implementation and attach ``image_token_id``."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
        processed_outputs["image_token_id"] = torch.tensor(token_id)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output field maps onto images."""
        # With no images the key is absent; fall back to an empty tensor.
        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
        num_images = len(image_num_patches)

        return {
            "pixel_values_flat":
            MultiModalFieldConfig.flat_from_sizes("image",
                                                  image_num_patches),
            "image_num_patches":
            MultiModalFieldConfig.batched("image"),
            "image_embeds":
            MultiModalFieldConfig.batched("image"),
            "image_token_id":
            MultiModalFieldConfig.shared("image", num_images),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement rule for the ``<image>`` placeholder.

        Each placeholder is replaced by the processor's ``get_image_repl``
        output for the corresponding image item.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "image_num_patches" in out_mm_kwargs:
            patches_tensor = out_mm_kwargs["image_num_patches"]
            assert isinstance(patches_tensor, torch.Tensor)
            image_num_patches = patches_tensor.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            num_embeds = len(out_mm_kwargs["image_embeds"])
            image_num_patches = [None] * num_embeds
        else:
            image_num_patches = []

        def _replacement_for(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(items, ImageEmbeddingItems):
                # Embedding inputs already carry their feature size.
                feature_size = items.get_feature_size(item_idx)
            else:
                size = items.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_image_repl(feature_size, num_patches)

        image_replacement = PromptReplacement(
            modality="image",
            target="<image>",
            replacement=_replacement_for,
        )
        return [image_replacement]
613
614


615
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that instantiates ``InternVLProcessor``."""

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> InternVLProcessor:
        """Create an ``InternVLProcessor`` through ``self.ctx``.

        Only overrides that are not ``None`` are forwarded, together with
        any extra keyword arguments.
        """
        overrides = (
            ("min_dynamic_patch", min_dynamic_patch),
            ("max_dynamic_patch", max_dynamic_patch),
            ("dynamic_image_size", dynamic_image_size),
        )
        for key, value in overrides:
            if value is not None:
                kwargs[key] = value

        return self.ctx.init_processor(
            InternVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
644
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
645

646
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision model, language model and projector.

        All settings are read from ``vllm_config``. ``prefix`` is combined
        with each submodule name through ``maybe_prefix``.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config
        self._patch_quant_config(config, quant_config)

        vision_config = config.vision_config
        image_size = config.force_image_size or vision_config.image_size
        patch_size = vision_config.patch_size
        self.patch_size = patch_size
        # Number of image tokens contributed by a single patch.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        text_config = config.text_config
        self.llm_arch_name = text_config.architectures[0]
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'

        # Submodules are created in a fixed order: vision, language, mlp1.
        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.mlp1 = self._init_mlp1(config)

        # Both are filled in later, once image inputs have been parsed.
        self.img_context_token_id = None
        self.visual_token_mask = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)
686

687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: QuantizationConfig):
        """Add ``"vision_model"`` to an AWQ config's unconverted modules.

        The AWQ models from OpenGVLab are missing `modules_to_not_convert`,
        so it is patched back in. Nothing happens unless ``quant_config``
        is an ``AWQConfig`` whose list is empty and the text config has a
        ``quantization_config``.
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config",
                                   None)
        if quant_config.modules_to_not_convert or llm_quant_config is None:
            return

        quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the vision encoder.

        Mono models get an ``InternVisionPatchModel``. Otherwise an
        ``InternVisionModel`` is built with its layer count limited to the
        layer chosen by ``config.select_layer``.
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        select_layer = config.select_layer
        if select_layer < 0:
            # A negative index counts back from the last hidden layer.
            layer_count = (config.vision_config.num_hidden_layers +
                           select_layer + 1)
        else:
            layer_count = select_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=layer_count,
            prefix=prefix,
        )
723
724
725
726
727
728
729
730
731
732
733
734
735

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Build the projector from vision features to the LLM hidden size.

        The layers are LayerNorm, Linear, GELU and Linear. The input width
        is the vision hidden size times ``int(1 / downsample_ratio) ** 2``.
        """
        llm_hidden_size = config.text_config.hidden_size
        in_features = (config.vision_config.hidden_size *
                       int(1 / self.downsample_ratio)**2)

        layers = [
            nn.LayerNorm(in_features),
            nn.Linear(in_features, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        ]
        return nn.Sequential(*layers)

736
737
738
739
740
741
742
743
744
745
746
747
748
749
    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial size for channel depth on an ``(N, W, H, C)`` tensor.

        Both spatial dimensions are scaled by ``scale_factor`` and the
        channel dimension by ``1 / scale_factor ** 2``. Unless
        ``self.ps_version`` is ``'v1'``, the two spatial axes are swapped
        once more at the end.
        """
        n, w, h, c = x.size()
        scaled_h = int(h * scale_factor)
        scaled_w = int(w * scale_factor)

        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, scaled_h, int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, scaled_h, scaled_w,
                   int(c / (scale_factor * scale_factor)))

        if self.ps_version != 'v1':
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

750
    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode patches and project them with ``mlp1``.

        The vision model output loses its first token (presumably the class
        token -- TODO confirm). The rest is laid out on a square grid,
        pixel-shuffled with ``self.downsample_ratio``, flattened back to a
        sequence and projected.
        """
        tokens = self.vision_model(pixel_values=pixel_values)
        tokens = tokens[:, 1:, :]

        # The remaining tokens are assumed to form a square grid.
        side = int(tokens.shape[1]**0.5)
        grid = tokens.reshape(tokens.shape[0], side, side, -1)
        grid = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)

        sequence = grid.reshape(grid.shape[0], -1, grid.shape[-1])
        return self.mlp1(sequence)

763
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
764
765
766
767
768
769
770
771

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
772
                expected_expr = str(expected_dims)
773
                raise ValueError(
774
775
776
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")
777
778
779
780
781
782
783

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Turn raw keyword inputs into a typed image-input dict.

        Returns ``None`` when neither pixel values nor embeddings are
        given. Embeddings take precedence when both are present. For pixel
        inputs, ``image_token_id`` is read from ``kwargs`` and stored on
        ``self.img_context_token_id``.

        Raises:
            ValueError: If an input is neither a tensor nor a list, or if
                the pixel values fail shape validation.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        tensor_or_list = (torch.Tensor, list)

        if image_embeds is not None:
            if not isinstance(image_embeds, tensor_or_list):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        image_token_id = kwargs["image_token_id"]
        assert isinstance(image_token_id, torch.Tensor)
        # .item() requires all supplied IDs to collapse to a single value.
        self.img_context_token_id = image_token_id.flatten().unique().item()

        if pixel_values_flat is None:
            raise AssertionError("This line should be unreachable.")

        if not isinstance(pixel_values_flat, tensor_or_list):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values_flat)}")

        if not isinstance(image_num_patches, tensor_or_list):
            raise ValueError("Incorrect type of image_num_patches. "
                             f"Got type: {type(image_num_patches)}")

        pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
        image_num_patches = flatten_bn(image_num_patches, concat=True)

        return InternVLImagePixelInputs(
            type="pixel_values",
            pixel_values_flat=self._validate_pixel_values(pixel_values_flat),
            num_patches=image_num_patches,
        )

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
        """Convert parsed image inputs into image embeddings.

        Precomputed embeddings (``type == "image_embeds"``) are returned
        unchanged. Pixel values are run through ``self.extract_feature`` and
        reshaped to rows of width ``self.config.text_config.hidden_size``.

        Returns:
            For a single image, one tensor with a leading dimension of size 1.
            For several images, the tuple produced by ``Tensor.split``, one
            chunk per entry of ``num_patches``.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None

        image_embeds = self.extract_feature(image_input["pixel_values_flat"])
        num_patches = image_input["num_patches"]
        hidden_size = self.config.text_config.hidden_size

        # Only one image in the current batch
        if len(num_patches) == 1:
            return image_embeds.view(-1, hidden_size).unsqueeze(0)

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, hidden_size)
        # A distinct loop variable avoids shadowing `num_patches` itself.
        image_feature_sizes = [
            patch_count * feature_size for patch_count in num_patches
        ]
        return image_embeds.split(image_feature_sizes)
854

855
    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
856
        if self.is_mono:
857
            self.visual_token_mask = (
858
859
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
860
            self.visual_token_mask = None
861

862
863
864
    def get_language_model(self) -> torch.nn.Module:
        """Return the module stored in ``self.language_model``."""
        language_model = self.language_model
        return language_model

865
    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute image embeddings from the multimodal keyword arguments.

        Returns:
            ``None`` when ``_parse_and_validate_image_input`` finds no image
            input in ``kwargs``; otherwise the result of
            ``_process_image_input`` for the parsed input.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        return self._process_image_input(image_input)
872
873
874
875

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in multimodal embeddings if given.

        The text embeddings come from the language model. When
        ``multimodal_embeddings`` is not ``None``, the visual token mask is
        refreshed for ``input_ids`` and ``merge_multimodal_embeddings`` is
        called with ``self.img_context_token_id`` as the placeholder token id
        (presumably substituting embeddings at those positions; see that
        helper for the exact semantics).

        Returns:
            The resulting input embeddings tensor.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            # The token id is recorded while parsing pixel-value inputs.
            assert self.img_context_token_id is not None
            self._set_visual_token_mask(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.img_context_token_id,
            )
        return inputs_embeds

890
891
892
893
894
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the language model on token ids or precomputed embeddings.

        When ``intermediate_tensors`` is given, both ``input_ids`` and
        ``inputs_embeds`` are dropped. Otherwise, if ``inputs_embeds`` is
        missing, it is built here from ``input_ids`` and the multimodal
        ``kwargs``, and ``input_ids`` is cleared.

        If ``self.visual_token_mask`` is set, it is forwarded to the language
        model as ``visual_token_mask`` and then reset to ``None``.

        Returns:
            Whatever ``self.language_model.model`` returns for these inputs.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs["visual_token_mask"] = self.visual_token_mask
            # Clear the mask so it is not reused by a later call.
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

927
928
929
930
931
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the wrapped language model.

        Both arguments are passed through unchanged to
        ``self.language_model.compute_logits`` and its result is returned.
        """
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

935
936
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into this model via ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``
            (presumably the names of the parameters that were loaded;
            confirm against the loader).
        """
        # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
        skip_prefixes = [
            "action_embed", "temporal_embed", "track_embed",
            "track_embed_decoder", "box_token", "cg_criterion", "cg_model",
            "loc_encoder", "loc_decoder", "sam", "temporal_token",
            "track_token"
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)