internvl.py 35.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
9
from abc import ABC, abstractmethod
10
from collections.abc import Iterable, Mapping, Sequence
11
from functools import cached_property
12
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
13
14
15
16
17

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
18
from transformers import BatchEncoding, PretrainedConfig, TensorType
19

20
from vllm.config import VllmConfig
21
22
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
Joe Runde's avatar
Joe Runde committed
23
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
25
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
28
29
30
31
32
33
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
34
                                        PromptUpdate, PromptUpdateDetails)
35
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
36
from vllm.sequence import IntermediateTensors
37
from vllm.transformers_utils.tokenizer import AnyTokenizer
38

39
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
40
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
41
                    maybe_prefix, merge_multimodal_embeddings)
42
from .vision import scatter_patch_features, select_patch_features
43
44
45
46
47
48
49
50
51
52
53

# Marker strings used to build the image placeholder text:
# `IMG_START + IMG_CONTEXT * feature_size + IMG_END`.
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel mean / std handed to `T.Normalize` in `build_transform`.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image inputs given as pixel values (tagged by `type`)."""

    type: Literal["pixel_values"]

    pixel_values_flat: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """

    num_patches: torch.Tensor
    """Shape: `(batch_size * num_images)`"""

    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
    A boolean mask indicating which image embeddings correspond
    to patch tokens.

    Shape: `(batch_size * num_images, num_embeds)`
    """

71

72
73
class InternVLImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed embeddings (tagged by `type`)."""

    type: Literal["image_embeds"]

    data: Union[torch.Tensor, list[torch.Tensor]]
    """
    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors of shape `(total_image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either form of image input: raw pixel values or precomputed embeddings.
# The `type` literal on each TypedDict tells the two apart.
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


87
88
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
    """Build the preprocessing pipeline applied to every image tile.

    Args:
        input_size: Side length, in pixels, of the square output image.

    Returns:
        A `T.Compose` that converts the image to RGB when it is in another
        mode, resizes it to `(input_size, input_size)` with bicubic
        interpolation, converts it to a tensor and normalizes it with
        `IMAGENET_MEAN` / `IMAGENET_STD`.
    """

    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        return img if img.mode == 'RGB' else img.convert('RGB')

    return T.Compose([
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


99
100
101
102
103
104
105
106
107
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the candidate grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Aspect ratio to match; compared against
            `ratio[0] / ratio[1]` of every candidate.
        target_ratios: Candidate ratio pairs, scanned in order.
        width: Image width; only used for the tie-break area test.
        height: Image height; only used for the tie-break area test.
        image_size: Side length of one tile.

    Returns:
        The best candidate, or `(1, 1)` when `target_ratios` is empty.
        When a later candidate ties exactly with the current best
        difference, it replaces the current choice only if the image area
        exceeds half of that candidate's total tile area.
    """
    image_area = width * height
    half_tile_area = 0.5 * image_size * image_size

    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            chosen = candidate
        elif gap == smallest_gap and image_area > half_tile_area * cols * rows:
            # Exact tie: switch to this candidate only when the image is
            # large enough to fill at least half of its tiles.
            chosen = candidate
    return chosen


123
124
125
126
127
128
129
def resolve_internvl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective `(min, max)` number of patches per image.

    Args:
        min_dynamic_patch: Configured lower bound.
        max_dynamic_patch: Configured upper bound.
        dynamic_image_size: When false, both bounds collapse to 1.
        use_thumbnail: When true and the upper bound is not 1, the upper
            bound is increased by one.

    Returns:
        The `(min, max)` pair after applying the rules above.
    """
    if dynamic_image_size:
        lower, upper = min_dynamic_patch, max_dynamic_patch
    else:
        lower = upper = 1

    # No extra slot is added when the upper bound is exactly 1.
    if use_thumbnail and upper != 1:
        upper += 1

    return lower, upper

138

139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """List every `(i, j)` pair with `min_num <= i * j <= max_num`.

    The result is sorted by ascending `i * j`. Pairs with the same product
    keep the iteration order of the intermediate set, so the insertion
    sequence below is deliberately left as a triple loop.
    """
    grids: set[tuple[int, int]] = set()
    for limit in range(min_num, max_num + 1):
        for cols in range(1, limit + 1):
            for rows in range(1, limit + 1):
                if min_num <= cols * rows <= max_num:
                    grids.add((cols, rows))

    ordered = list(grids)
    ordered.sort(key=lambda grid: grid[0] * grid[1])
    return ordered


def calculate_internvl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Compute the tile count and resize target for an image.

    Args:
        orig_width: Original image width.
        orig_height: Original image height.
        target_ratios: Candidate ratio pairs to choose from.
        image_size: Side length of one tile.
        use_thumbnail: When true and more than one tile is produced, one
            extra block is counted for the thumbnail.

    Returns:
        `(blocks, target_width, target_height)`, where the target size is
        `image_size` times the selected ratio pair.
    """
    # find the closest aspect ratio to the target
    grid = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )
    cols, rows = grid[0], grid[1]

    # calculate the target width and height
    target_width = image_size * cols
    target_height = image_size * rows
    blocks = cols * rows

    # add thumbnail image if num_blocks != 1
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height
179
180


181
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
182
183
184
185
186
187
188
def dynamic_preprocess_internvl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Resize an image to its target grid and cut it into square tiles.

    Args:
        image: Source image.
        target_ratios: Candidate ratio pairs to choose from.
        image_size: Side length of every tile.
        use_thumbnail: When true and more than one tile was produced, the
            whole image resized to `(image_size, image_size)` is appended
            as a final entry.

    Returns:
        The tiles in row-major order, optionally followed by the thumbnail.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_internvl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    # resize the image
    resized_img = image.resize((target_width, target_height))

    # split the image, walking the grid left-to-right then top-to-bottom
    tiles_per_row = target_width // image_size
    processed_images = []
    for index in range(blocks):
        row, col = divmod(index, tiles_per_row)
        left, top = col * image_size, row * image_size
        box = (left, top, left + image_size, top + image_size)
        processed_images.append(resized_img.crop(box))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
222
223
224
225
226
227
228
229
230
231
def image_to_pixel_values_internvl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Convert one image into a stacked tensor of preprocessed tiles.

    Args:
        image: Source image.
        input_size: Side length of every tile.
        min_num: Lower bound used to enumerate the candidate grids.
        max_num: Upper bound used to enumerate the candidate grids.
        use_thumbnail: Forwarded to `dynamic_preprocess_internvl`.

    Returns:
        The transformed tiles stacked along a new leading dimension.
    """
    target_ratios = get_internvl_target_ratios(min_num, max_num)
    transform = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_internvl(
        image,
        target_ratios=target_ratios,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([transform(tile) for tile in tiles])


244
245
246
247
class BaseInternVLProcessor(ABC):
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> None:
        """Store the config/tokenizer and resolve the tiling settings.

        Each keyword argument left as `None` falls back to the attribute of
        the same name on `config`.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        image_size: int = config.vision_config.image_size
        patch_size: int = config.vision_config.patch_size

        if min_dynamic_patch is None:
            min_dynamic_patch = config.min_dynamic_patch
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        # Multiplied by the number of tiles to get an image's feature size.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        """Token ID compared against the encoded feature placeholder to
        build the `embed_is_patch` mask."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the text that replaces one `<image>` placeholder.

        `__call__` reads `.full` (inserted into the prompt) and
        `.features` (tokenized for the patch mask) from the result.
        """
        raise NotImplementedError

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve `(min, max)` patch counts; `None` arguments fall back to
        the values stored on this processor."""
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_internvl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> list[tuple[int, int]]:
        """Enumerate the candidate grids for the resolved `(min, max)`
        patch counts."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        return get_internvl_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the feature size for an image of the given dimensions:
        its patch count times `num_image_token`."""
        target_ratios = self.resolve_target_ratios(
            use_thumbnail=False,  # Applied in calculate_targets
        )

        num_patches, _, _ = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=target_ratios,
            use_thumbnail=self.use_thumbnail,
        )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert every image into its stacked tile tensor, one tensor per
        image."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_internvl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
            ) for image in images
        ]

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        """Preprocess images, expand `<image>` placeholders and tokenize.

        `text` and `images` may each be a single item, a list, or `None`
        (treated as empty). For every image, in order, the first remaining
        `<image>` in each text entry is replaced by that image's full
        placeholder text.

        Returns:
            The tokenizer outputs merged with, when images are present,
            `pixel_values_flat`, `image_num_patches` and `embed_is_patch`.
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        image_inputs: dict[str, NestedTensors] = {}
        if images:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            image_inputs["pixel_values_flat"] = torch.cat(pixel_values_lst)
            image_inputs["image_num_patches"] = torch.tensor(
                [len(item) for item in pixel_values_lst])

            tokenizer = self.tokenizer
            image_token_id = self.image_token_id

            embed_is_patch = list[torch.Tensor]()

            for pixel_values in pixel_values_lst:
                num_patches = pixel_values.shape[0]
                feature_size = num_patches * self.num_image_token

                image_repl = self.get_image_repl(feature_size, num_patches)
                feature_tokens = tokenizer.encode(image_repl.features,
                                                  add_special_tokens=False)

                # Consume one placeholder per image in every text entry.
                text = [t.replace('<image>', image_repl.full, 1) for t in text]
                # True where the feature token is the image token itself.
                embed_is_patch.append(
                    torch.tensor(feature_tokens) == image_token_id)

            image_inputs["embed_is_patch"] = embed_is_patch

        text_inputs = self.tokenizer(text)

        return {
            **BatchEncoding(text_inputs, tensor_type=return_tensors),
            **image_inputs,
        }
447
448


449
450
451
452
453
454
class InternVLProcessor(BaseInternVLProcessor):
    """Processor whose image placeholder is `IMG_CONTEXT` repeated
    `feature_size` times, wrapped in `IMG_START` / `IMG_END`."""

    @property
    def image_token_id(self) -> int:
        """Vocabulary ID of the `IMG_CONTEXT` token."""
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the placeholder text for one image.

        `num_patches` is accepted for interface compatibility and is not
        used here.
        """
        repl_features = IMG_CONTEXT * feature_size
        repl_full = IMG_START + repl_features + IMG_END

        return PromptUpdateDetails(full=repl_full, features=repl_features)
464
465
466
467
468
469


class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Processing info shared by InternVL-style processors."""

    @abstractmethod
    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> BaseInternVLProcessor:
        """Create the processor, optionally overriding tiling settings."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Declare the `image` modality with a limit of `None`."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum token count per image.

        `seq_len` and `mm_counts` are not used by this implementation.
        """
        return {"image": self.get_max_image_tokens()}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[BaseInternVLProcessor],
    ) -> int:
        """Return the feature size of an image with the given dimensions.

        When `processor` is `None`, one is created via `get_hf_processor()`.
        """
        if processor is None:
            processor = self.get_hf_processor()

        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_max_image_tokens(self) -> int:
        """Return the feature size of the image size that yields the most
        features."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=None,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Find the candidate image size with the largest feature size.

        One size is tried per target ratio (`image_size` times the ratio
        pair); the first size reaching the maximum is kept.

        Raises:
            ValueError: If no candidate yields a positive feature size.
        """
        processor = self.get_hf_processor()

        base_size = processor.image_size
        target_ratios = processor.resolve_target_ratios()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for wr, hr in target_ratios:
            width, height = base_size * wr, base_size * hr

            feat_size = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            # Strict comparison: earlier candidates win ties.
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width,
                                                     height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint


# Type parameter for the dummy-inputs builder and multimodal processor below.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy processor inputs using the image size with the most
    features."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create a prompt of `<image>` placeholders plus matching dummy
        images.

        The number of images comes from `mm_counts["image"]` (default 0);
        `seq_len` is not used by this implementation.
        """
        target_width, target_height = \
            self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text="<image>" * num_images,
            mm_data=mm_data,
        )

565
566
567
568
569
570
571
572

class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor for InternVL-style image inputs."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the parent implementation, then attach `image_token_id` to
        its outputs."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        image_token_id = hf_processor.image_token_id

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        processed_outputs["image_token_id"] = torch.tensor(image_token_id)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto image items.

        `pixel_values_flat` is sliced by `image_num_patches`;
        `image_token_id` is shared across all images; every other field is
        batched per image.
        """
        # Missing `image_num_patches` is treated as zero images.
        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
        num_images = len(image_num_patches)

        return dict(
            pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_patches),
            image_num_patches=MultiModalFieldConfig.batched("image"),
            embed_is_patch=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            image_token_id=MultiModalFieldConfig.shared("image", num_images),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement rule that expands each `<image>`
        placeholder into that image's placeholder text."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "image_num_patches" in out_mm_kwargs:
            image_num_patches = out_mm_kwargs["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            image_num_patches = []

        def get_replacement_internvl(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                # Precomputed embeddings report their own feature size.
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_image_repl(feature_size, num_patches)

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]
653
654


655
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that creates `InternVLProcessor` instances."""

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> InternVLProcessor:
        """Create an `InternVLProcessor` from the HF config and tokenizer.

        Only overrides that are not `None` are forwarded, so the processor
        falls back to the config values for the rest.
        """
        if min_dynamic_patch is not None:
            kwargs["min_dynamic_patch"] = min_dynamic_patch
        if max_dynamic_patch is not None:
            kwargs["max_dynamic_patch"] = max_dynamic_patch
        if dynamic_image_size is not None:
            kwargs["dynamic_image_size"] = dynamic_image_size

        return self.ctx.init_processor(
            InternVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
684
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
685

686
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision model, language model and projector.

        Args:
            vllm_config: Engine configuration; the HF config, quantization
                config and multimodal config are read from it.
            prefix: Prefix under which the sub-modules are registered.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        # Must run before the sub-modules are built, because the quant config
        # may be modified in place.
        self._patch_quant_config(config, quant_config)

        # `force_image_size` takes precedence over the vision config when it
        # is set to a truthy value.
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.llm_arch_name = config.text_config.architectures[0]
        # The "mono" variant is recognised purely by the LLM architecture name.
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # `_init_mlp1` reads `self.downsample_ratio`, which is set above.
        self.mlp1 = self._init_mlp1(config)

        # Filled in by `_parse_and_validate_image_input` once pixel inputs
        # have been seen.
        self.img_context_token_id = None
        # Set by `_set_visual_token_mask` and consumed (then reset) by
        # `forward`.
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)
726

727
728
729
730
731
732
733
734
735
736
737
738
    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: QuantizationConfig):
        """Restore `modules_to_not_convert` on AWQ configs that lack it.

        The AWQ checkpoints from OpenGVLab are missing
        `modules_to_not_convert`; when the text config carries a
        `quantization_config` and the list is empty, "vision_model" is added
        to it. `quant_config` is modified in place.
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config",
                                   None)
        # Nothing to patch when the list is already populated or the text
        # config has no quantization config.
        if quant_config.modules_to_not_convert or llm_quant_config is None:
            return

        quant_config.modules_to_not_convert.append("vision_model")

739
740
    @cached_property
    def sampler(self):
        """Return the sampler to use, computed once per instance.

        The language model's own `sampler` attribute is used when it has
        one; otherwise a sampler is obtained from `get_sampler()`.
        """
        language_model = self.language_model
        if hasattr(language_model, "sampler"):
            return language_model.sampler

        return get_sampler()
745

746
747
748
749
750
751
752
753
    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Instantiate the vision model.

        Args:
            config: Model config; `vision_config` and `select_layer` are read.
            quant_config: Quantization config for the non-mono vision model.
            is_mono: Selects `InternVisionPatchModel` when true, otherwise
                `InternVisionModel`.
            prefix: Module prefix for the non-mono vision model.

        Returns:
            The constructed vision model.
        """
        if is_mono:
            # NOTE(review): `quant_config` and `prefix` are not forwarded on
            # this path - confirm this is intended.
            return InternVisionPatchModel(config.vision_config)

        vision_feature_layer = config.select_layer
        # Convert the selected layer index (negative values count from the
        # end) into the number of layers to build.
        if vision_feature_layer < 0:
            num_hidden_layers = (config.vision_config.num_hidden_layers +
                                 vision_feature_layer + 1)
        else:
            num_hidden_layers = vision_feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )
770
771
772
773
774
775
776
777
778
779
780
781
782

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        return nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

783
784
785
786
787
788
789
790
791
792
793
794
795
796
    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on an (N, W, H, C) tensor.

        Both spatial dimensions are multiplied by `scale_factor` and the
        channel dimension is divided by `scale_factor ** 2`. Unless
        `self.ps_version` is 'v1', the two spatial axes are swapped once more
        before returning.
        """
        batch, width, height, channels = x.size()
        new_height = int(height * scale_factor)
        new_width = int(width * scale_factor)

        # N, W, H, C --> N, W, H * scale, C // scale
        folded = x.view(batch, width, new_height, int(channels / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        folded = folded.permute(0, 2, 1, 3).contiguous()
        shuffled = folded.view(batch, new_height, new_width,
                               int(channels / (scale_factor * scale_factor)))

        if self.ps_version != 'v1':
            shuffled = shuffled.permute(0, 2, 1, 3).contiguous()
        return shuffled

797
    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision model and project its output through `mlp1`.

        Args:
            pixel_values: Input passed to `self.vision_model`.

        Returns:
            The projected embeddings, with all positions flattened into
            dimension 1.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        # The first position along dim 1 is discarded (presumably the class
        # token - confirm against the vision model).
        vit_embeds = vit_embeds[:, 1:, :]

        # The remaining positions are laid out as a square grid.
        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        # Flatten the grid back into a sequence before projecting.
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        return self.mlp1(vit_embeds)

810
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
811
812
813
814
815
816
817
818

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
819
                expected_expr = str(expected_dims)
820
                raise ValueError(
821
822
823
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")
824
825
826
827
828
829
830

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
831
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
832
833
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
834
        embed_is_patch = kwargs.pop("embed_is_patch", None)
835
        image_embeds = kwargs.pop("image_embeds", None)
836

837
        if pixel_values_flat is None and image_embeds is None:
838
839
            return None

840
        if image_embeds is not None:
841
            if not isinstance(image_embeds, (torch.Tensor, list)):
842
843
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
844

845
846
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
847
                data=flatten_bn(image_embeds),
848
849
            )

850
851
852
        image_token_id = kwargs["image_token_id"]
        assert isinstance(image_token_id, torch.Tensor)
        self.img_context_token_id = image_token_id.flatten().unique().item()
853

854
855
        if pixel_values_flat is not None:
            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
856
                raise ValueError("Incorrect type of pixel values. "
857
858
                                 f"Got type: {type(pixel_values_flat)}")

859
860
            if not isinstance(image_num_patches, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image_num_patches. "
861
862
                                 f"Got type: {type(image_num_patches)}")

863
864
865
866
            if not isinstance(embed_is_patch, (torch.Tensor, list)):
                raise ValueError("Incorrect type of embed_is_patch. "
                                 f"Got type: {type(embed_is_patch)}")

867
868
            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
            image_num_patches = flatten_bn(image_num_patches, concat=True)
869
            embed_is_patch = flatten_bn(embed_is_patch)
870

871
872
            return InternVLImagePixelInputs(
                type="pixel_values",
873
874
875
                pixel_values_flat=self._validate_pixel_values(
                    pixel_values_flat),
                num_patches=image_num_patches,
876
                embed_is_patch=embed_is_patch,
877
            )
878
879
880
881
882
883

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
884
    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
885
886
887
888
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
889

890
        image_embeds = self.extract_feature(image_input["pixel_values_flat"])
891

892
        num_patches = image_input["num_patches"]
893
894

        # Only one image in the current batch
895
896
        if len(num_patches) == 1:
            return image_embeds.view(
897
                -1, self.config.text_config.hidden_size).unsqueeze(0)
898
899
900
901
902
903
904

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1,
                                         self.config.text_config.hidden_size)
        image_feature_sizes = [
905
            num_patches * feature_size for num_patches in num_patches
906
        ]
907
        return image_embeds.split(image_feature_sizes)
908

909
    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
910
        if self.is_mono:
911
            self.visual_token_mask = (
912
913
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
914
            self.visual_token_mask = None
915

916
    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Compute the image embeddings for the given keyword arguments.

        Returns:
            None when `kwargs` carries no image input. For inputs that are
            not of type "pixel_values" the processed features are returned
            directly; for pixel inputs they are passed through
            `scatter_patch_features` together with `embed_is_patch`.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        image_features = self._process_image_input(image_input)

        if image_input["type"] != "pixel_values":
            return image_features

        return scatter_patch_features(
            image_features,
            image_input["embed_is_patch"],
        )
931
932
933
934

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging in multimodal embeddings if given.

        Args:
            input_ids: Token ids to embed with the language model.
            multimodal_embeddings: Optional image embeddings; when present
                they are merged at positions equal to
                `self.img_context_token_id`.

        Returns:
            The input embeddings for the language model.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        assert self.img_context_token_id is not None
        # Updates `self.visual_token_mask`, which `forward` reads afterwards.
        self._set_visual_token_mask(input_ids)
        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            select_patch_features(multimodal_embeddings),
            self.img_context_token_id,
        )

949
950
951
952
953
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run the language model on text and (optionally) image inputs.

        Args:
            input_ids: Token ids for the current step.
            positions: Position ids, forwarded to the language model.
            intermediate_tensors: When given, `input_ids` and `inputs_embeds`
                are both dropped and only these tensors are forwarded.
            inputs_embeds: Precomputed input embeddings; computed here from
                `input_ids` and `kwargs` when absent.
            **kwargs: Multimodal inputs used to build the embeddings.

        Returns:
            Whatever `self.language_model.model` returns.
            NOTE(review): the return annotation does not name the hidden
            states that are returned here - confirm.
        """

        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs["visual_token_mask"] = self.visual_token_mask
            # Reset so the mask is used for a single forward pass only.
            self.visual_token_mask = None

        return self.language_model.model(**forward_kwargs)

986
987
988
989
990
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits by delegating to the language model.

        Both arguments are passed unchanged to
        `self.language_model.compute_logits`, whose result is returned.
        """
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample by delegating to the language model's `sample` method."""
        sampler_output = self.language_model.sample(logits, sampling_metadata)
        return sampler_output

1001
1002
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into this model via `AutoWeightsLoader`.

        Args:
            weights: Iterable of `(name, tensor)` pairs.

        Returns:
            The result of `AutoWeightsLoader.load_weights`.
        """
        # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
        skip_prefixes = [
            "action_embed", "temporal_embed", "track_embed",
            "track_embed_decoder", "box_token", "cg_criterion", "cg_model",
            "loc_encoder", "loc_decoder", "sam", "temporal_token",
            "track_token"
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)