internvl.py 34.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
9
from abc import ABC, abstractmethod
10
from collections.abc import Iterable, Mapping, Sequence
11
from functools import cached_property
12
13
from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
                    Union)
14
15
16
17
18

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
19
from transformers import BatchFeature, PretrainedConfig, TensorType
20

21
from vllm.config import VllmConfig
22
23
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
Joe Runde's avatar
Joe Runde committed
24
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
25
26
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
29
30
31
32
33
34
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
35
                                        PromptUpdate, PromptUpdateDetails)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
from vllm.sequence import IntermediateTensors
38
from vllm.transformers_utils.tokenizer import AnyTokenizer
39

40
from .interfaces import SupportsMultiModal, SupportsPP
41
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
42
                    maybe_prefix, merge_multimodal_embeddings)
43
44
45
46
47
48
49
50
51
52
53

# Prompt markers: an image's features are wrapped in IMG_START ... IMG_END,
# with IMG_CONTEXT repeated once per feature token in between.
IMG_START, IMG_END = '<img>', '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel RGB normalization statistics applied by build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image inputs given as flattened pixel tiles (``type == "pixel_values"``)."""

    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """

    patches_per_image: List[int]
    """
    List of number of total patches for each image in the batch.
    """
63
64


65
66
class InternVLImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed embeddings (``type == "image_embeds"``)."""

    type: Literal["image_embeds"]
    data: NestedTensors
    """
    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors of shape `(total_image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Images reach the model either as pixel tiles or as precomputed embeddings.
InternVLImageInputs = Union[
    InternVLImagePixelInputs,
    InternVLImageEmbeddingInputs,
]


80
81
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
    """Build the torchvision pipeline applied to every image tile.

    The pipeline converts non-RGB images to RGB, resizes to a square
    ``input_size`` x ``input_size`` with bicubic interpolation, converts to a
    tensor, and normalizes with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Side length (pixels) of the square output.

    Returns:
        A ``T.Compose`` callable mapping a PIL image to a normalized tensor.
    """
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    return T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])


92
93
94
95
96
97
98
99
100
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the ``(cols, rows)`` tiling grid whose ratio best matches the image.

    Candidates are scanned in order. A strictly smaller
    ``|aspect_ratio - cols / rows|`` always wins. On an exact tie, the later
    candidate replaces the current best only if the image area
    (``width * height``) exceeds half of that grid's pixel area
    (``image_size**2 * cols * rows``). This favours larger grids only for
    images big enough to fill them.

    Args:
        aspect_ratio: ``width / height`` of the original image.
        target_ratios: Candidate ``(cols, rows)`` grids.
        width: Original image width in pixels.
        height: Original image height in pixels.
        image_size: Side length in pixels of one tile.

    Returns:
        The chosen ``(cols, rows)`` grid, or ``(1, 1)`` if ``target_ratios``
        is empty.
    """
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie: only move to this grid if the image is large enough.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


116
117
118
119
120
121
122
def resolve_internvl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the ``(min_num, max_num)`` tile-count bounds.

    When ``dynamic_image_size`` is False, both bounds collapse to 1. If a
    thumbnail is used and more than one tile is allowed, ``max_num`` is
    increased by one to leave room for the thumbnail tile.

    Returns:
        ``(min_num, max_num)``.
    """
    min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch

131

132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """List every ``(cols, rows)`` grid with ``min_num <= cols*rows <= max_num``.

    The grids are deduplicated and sorted by their tile count
    ``cols * rows`` (ascending).
    """
    grids: set[tuple[int, int]] = set()
    for n in range(min_num, max_num + 1):
        for cols in range(1, n + 1):
            for rows in range(1, n + 1):
                if min_num <= cols * rows <= max_num:
                    grids.add((cols, rows))

    def tile_count(grid: tuple[int, int]) -> int:
        return grid[0] * grid[1]

    return sorted(grids, key=tile_count)


def calculate_internvl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Choose a tiling grid for an image and compute the resize target.

    Args:
        orig_width: Original image width in pixels.
        orig_height: Original image height in pixels.
        target_ratios: Candidate ``(cols, rows)`` grids.
        image_size: Side length in pixels of one tile.
        use_thumbnail: Whether to count an extra thumbnail tile when the
            chosen grid has more than one tile.

    Returns:
        ``(blocks, target_width, target_height)``. ``blocks`` is the number
        of tiles (including the thumbnail, if counted), and the image is
        resized to ``target_width`` x ``target_height`` before tiling.
    """
    aspect_ratio = orig_width / orig_height

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # add thumbnail image if num_blocks != 1
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height
172
173


174
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
175
176
177
178
179
180
181
def dynamic_preprocess_internvl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Split ``image`` into ``image_size`` x ``image_size`` tiles.

    The image is resized to the grid chosen by
    ``calculate_internvl_targets`` and cropped tile by tile in row-major
    order. If ``use_thumbnail`` is set and more than one tile was produced,
    a resized copy of the whole image is appended as a final tile.

    Returns:
        The list of tile images.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_internvl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    # resize the image, then crop tiles in row-major order
    resized_img = image.resize((target_width, target_height))
    num_cols = target_width // image_size  # loop-invariant
    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, num_cols)
        box = (col * image_size,
               row * image_size,
               (col + 1) * image_size,
               (row + 1) * image_size)
        # split the image
        processed_images.append(resized_img.crop(box))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
215
216
217
218
219
220
221
222
223
224
def image_to_pixel_values_internvl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Tile ``image`` and stack the transformed tiles into one tensor.

    Args:
        image: Input image.
        input_size: Tile side length in pixels.
        min_num: Minimum number of tiles for the candidate grids.
        max_num: Maximum number of tiles for the candidate grids.
        use_thumbnail: Whether to append a thumbnail tile (see
            ``dynamic_preprocess_internvl``).

    Returns:
        ``torch.stack`` of the per-tile tensors, with one leading entry per
        tile.
    """
    target_ratios = get_internvl_target_ratios(min_num, max_num)

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess_internvl(
        image,
        target_ratios=target_ratios,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    # Use a distinct name so the ``image`` parameter is not shadowed.
    pixel_values = torch.stack([transform(tile) for tile in tiles])
    return pixel_values


237
238
239
240
class BaseInternVLProcessor(ABC):
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252

    Subclasses provide the image token ID and the text that replaces each
    ``<image>`` placeholder in the prompt.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> None:
        """Store tiling settings, falling back to ``config`` for unset ones.

        Args:
            config: HF config. ``vision_config`` supplies ``image_size`` and
                ``patch_size``; the config also supplies
                ``downsample_ratio``, ``use_thumbnail`` and the default
                dynamic-patch settings.
            tokenizer: Tokenizer used to encode the expanded prompt text.
            min_dynamic_patch: Override for ``config.min_dynamic_patch``.
            max_dynamic_patch: Override for ``config.max_dynamic_patch``.
            dynamic_image_size: Override for ``config.dynamic_image_size``.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        image_size: int = config.vision_config.image_size
        patch_size: int = config.vision_config.patch_size

        if min_dynamic_patch is None:
            min_dynamic_patch = config.min_dynamic_patch
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        # Tokens per tile: the (image_size // patch_size)**2 patch grid
        # scaled by downsample_ratio**2.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        """Token ID of the per-feature image placeholder token."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl_features(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        """Return the placeholder text that carries the image features."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl_full(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        """Return the full text that replaces one ``<image>`` placeholder."""
        raise NotImplementedError

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve tile-count bounds; ``None`` arguments use instance defaults."""
        min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch
                             is None else min_dynamic_patch)
        max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
                             is None else max_dynamic_patch)
        dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
                              is None else dynamic_image_size)
        use_thumbnail = (self.use_thumbnail
                         if use_thumbnail is None else use_thumbnail)

        return resolve_internvl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> list[tuple[int, int]]:
        """Return candidate ``(cols, rows)`` grids for the resolved bounds."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        return get_internvl_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of image-feature tokens for an image of the given size."""
        target_ratios = self.resolve_target_ratios(
            use_thumbnail=False,  # Applied in calculate_targets
        )

        num_patches, _, _ = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=target_ratios,
            use_thumbnail=self.use_thumbnail,
        )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each image into its stacked tile tensor."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_internvl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
            ) for image in images
        ]

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Preprocess images and expand ``<image>`` placeholders in ``text``.

        For each image, in order, the first remaining ``<image>`` in every
        string of ``text`` is replaced with ``get_image_repl_full(...)``.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs. When images are
            given, it also holds ``pixel_values_flat`` (all tiles
            concatenated) and ``image_num_patches`` (tiles per image).
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if len(images) == 0:
            image_inputs = {}
        else:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            image_inputs = {
                "pixel_values_flat": torch.cat(pixel_values_lst),
                "image_num_patches": list(map(len, pixel_values_lst)),
            }

            for pixel_values in pixel_values_lst:
                num_patches = pixel_values.shape[0]
                feature_size = num_patches * self.num_image_token

                image_repl = self.get_image_repl_full(feature_size,
                                                      num_patches)
                text = [t.replace('<image>', image_repl, 1) for t in text]

        text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )
438
439


440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
class InternVLProcessor(BaseInternVLProcessor):
    """InternVL processor: an image becomes IMG_START, then one IMG_CONTEXT
    per feature token, then IMG_END."""

    @property
    def image_token_id(self) -> int:
        vocab = self.tokenizer.get_vocab()
        return vocab[IMG_CONTEXT]

    def get_image_repl_features(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        # One context token per feature; num_patches does not affect the text.
        return IMG_CONTEXT * feature_size

    def get_image_repl_full(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> str:
        body = self.get_image_repl_features(feature_size, num_patches)
        return "".join((IMG_START, body, IMG_END))


class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Processing metadata shared by InternVL-style models.

    Subclasses only need to implement ``get_hf_processor``.
    """

    @abstractmethod
    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> BaseInternVLProcessor:
        """Build the processor, optionally overriding tiling settings."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # ``None``: no fixed per-prompt limit on the number of images.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Maximum image tokens per item; ``seq_len``/``mm_counts`` unused."""
        return {"image": self.get_max_image_tokens()}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[BaseInternVLProcessor],
    ) -> int:
        """Image-token count for one image; builds a processor if not given."""
        if processor is None:
            processor = self.get_hf_processor()

        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_max_image_tokens(self) -> int:
        """Image-token count for the size returned by
        ``get_image_size_with_most_features``."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=None,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Find the candidate image size that yields the most image tokens.

        Each candidate grid ``(wr, hr)`` is evaluated at
        ``(processor.image_size * wr, processor.image_size * hr)``.

        Raises:
            ValueError: If no candidate yields a positive token count.
        """
        processor = self.get_hf_processor()

        base_size = processor.image_size
        target_ratios = processor.resolve_target_ratios()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for wr, hr in target_ratios:
            width, height = base_size * wr, base_size * hr

            feat_size = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width,
                                                     height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint


# Generic parameter for the shared dummy-input builder and multimodal
# processor; restricted to InternVL-style processing-info classes.
_I = TypeVar(
    "_I",
    bound=BaseInternVLProcessingInfo,
)


class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Build dummy prompts/images for InternVL-style models."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return one ``<image>`` placeholder per requested image.

        Each dummy image has the size returned by
        ``get_image_size_with_most_features``.
        """
        target_width, target_height = \
            self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text="<image>" * num_images,
            mm_data=mm_data,
        )

561
562
563
564
565
566
567
568
569
570
571
572
573

class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor shared by InternVL-style models."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the base processing and attach ``image_token_id``."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
        # Sanity check only; the image list itself is not used here.
        assert isinstance(mm_data.get("images", []), list)

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        processed_outputs["image_token_id"] = torch.tensor(image_token_id)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Map processor outputs to per-image fields.

        ``pixel_values_flat`` is split by ``image_num_patches``, and
        ``image_token_id`` is shared across all images.
        """
        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
        num_images = len(image_num_patches)

        return dict(
            pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_patches),
            image_num_patches=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            image_token_id=MultiModalFieldConfig.shared("image", num_images),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``<image>`` with that image's placeholder text."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "image_num_patches" in out_mm_kwargs:
            image_num_patches = out_mm_kwargs["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            image_num_patches = []

        def get_replacement_internvl(item_idx: int):
            # Looked up lazily, only when a replacement is actually made.
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return PromptUpdateDetails(
                full=hf_processor.get_image_repl_full(feature_size,
                                                      num_patches),
                features=hf_processor.get_image_repl_features(
                    feature_size, num_patches),
            )

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]
654
655


656
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that builds an ``InternVLProcessor``."""

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> InternVLProcessor:
        """Create an ``InternVLProcessor`` for this model's config/tokenizer.

        Only overrides that are not ``None`` are forwarded. Unset ones fall
        back to the config values inside the processor's constructor.
        """
        if min_dynamic_patch is not None:
            kwargs["min_dynamic_patch"] = min_dynamic_patch
        if max_dynamic_patch is not None:
            kwargs["max_dynamic_patch"] = max_dynamic_patch
        if dynamic_image_size is not None:
            kwargs["dynamic_image_size"] = dynamic_image_size

        return self.ctx.init_processor(
            InternVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
    """InternVL chat model.

    Composes three parts:

    * a vision encoder (``vision_model``);
    * an MLP projector (``mlp1``) that maps pixel-shuffled vision features
      to the language model's hidden size;
    * a language model built with ``init_vllm_registered_model``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        # Tokens per image tile: the (image_size // patch_size)**2 patch grid
        # shrinks by downsample_ratio**2 after pixel_shuffle().
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.llm_arch_name = config.text_config.architectures[0]
        # InternLM2VE ("mono") language models additionally receive a
        # visual_token_mask in forward(); see _set_visual_token_mask().
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'

        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.mlp1 = self._init_mlp1(config)

        # Filled in from the processor's "image_token_id" kwarg once
        # multimodal inputs are seen.
        self.img_context_token_id = None
        # Set by get_input_embeddings() for mono models; consumed and
        # cleared by forward().
        self.visual_token_mask = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: Optional[QuantizationConfig]):
        """Keep the vision tower unquantized for OpenGVLab AWQ checkpoints.

        Those checkpoints do not list ``modules_to_not_convert``. When the
        text config carries its own ``quantization_config`` and the AWQ
        config lists nothing to skip, ``"vision_model"`` is added so the
        vision tower is not converted. Mutates ``quant_config`` in place.
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config",
                                   None)
        if (not quant_config.modules_to_not_convert
                and llm_quant_config is not None):
            quant_config.modules_to_not_convert.append("vision_model")

    @cached_property
    def sampler(self):
        """Reuse the language model's sampler if it has one."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the vision encoder.

        Mono models use ``InternVisionPatchModel``. All others use
        ``InternVisionModel``, truncated to the layer selected by
        ``config.select_layer``.
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        # select_layer is the index of the ViT layer whose output is used;
        # negative values count from the end. Only that many layers are
        # built (via num_hidden_layers_override).
        vision_feature_layer = config.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = (config.vision_config.num_hidden_layers +
                                 vision_feature_layer + 1)
        else:
            num_hidden_layers = vision_feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Build the vision-to-language projector.

        The input width is the ViT hidden size times
        ``int(1 / downsample_ratio)**2``: pixel_shuffle() folds that many
        neighbouring tokens into the channel dimension.
        """
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size
        mlp_in_size = vit_hidden_size * int(1 / self.downsample_ratio)**2

        return nn.Sequential(
            nn.LayerNorm(mlp_in_size),
            nn.Linear(mlp_in_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on a 4-D feature map.

        ``x`` is unpacked as ``(n, w, h, c)``. The result has spatial dims
        scaled by ``scale_factor`` and channels divided by
        ``scale_factor**2``. For ``ps_version`` other than ``'v1'``, the two
        spatial axes are swapped back at the end.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // scale^2
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        # 'v1' keeps the transposed layout; later versions undo it.
        if self.ps_version != 'v1':
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode image tiles and project them to the LLM hidden size.

        Returns a 3-D tensor ``(num_tiles, tokens_per_tile, llm_hidden)``,
        per the final reshape and ``mlp1``.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        # Drop the leading token (presumably CLS; TODO confirm) so the
        # remainder forms a square patch grid.
        vit_embeds = vit_embeds[:, 1:, :]

        # Assumes a square grid of patch tokens.
        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check every patch in ``data`` has shape ``(3, size, size)``.

        ``size`` is ``config.vision_config.image_size``.

        Raises:
            ValueError: if any patch has a different shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        for d in data:
            actual_dims = tuple(d.shape)
            if actual_dims != expected_dims:
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"per patch is {expected_dims}. "
                    f"You supplied {actual_dims}.")

        return data

    def _update_img_context_token_id(self, image_token_id: object) -> None:
        """Cache the image context token id passed by the processor.

        ``.item()`` requires the flattened tensor to hold a single unique id.
        """
        assert isinstance(image_token_id, torch.Tensor)
        self.img_context_token_id = image_token_id.flatten().unique().item()

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Pop image inputs from ``kwargs`` and wrap them in a typed dict.

        Returns:
            ``None`` if neither ``pixel_values_flat`` nor ``image_embeds``
            is present. Otherwise an embeddings or pixel-values input.

        Raises:
            ValueError: if an input has an unexpected type or shape.
            KeyError: if pixel values are given without ``image_token_id``.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # get_input_embeddings() asserts img_context_token_id is set, so
            # record it on this path too rather than relying on an earlier
            # pixel-value request having set it.
            if "image_token_id" in kwargs:
                self._update_img_context_token_id(kwargs["image_token_id"])

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        # From here on pixel_values_flat is not None.
        self._update_img_context_token_id(kwargs["image_token_id"])

        if not isinstance(pixel_values_flat, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values_flat)}")
        if not isinstance(image_num_patches, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image_num_patches. "
                             f"Got type: {type(image_num_patches)}")

        return InternVLImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(
                flatten_bn(pixel_values_flat, concat=True)),
            patches_per_image=flatten_bn(image_num_patches,
                                         concat=True).tolist())

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Turn parsed image input into per-image embeddings.

        Precomputed embeddings are returned as-is. Pixel inputs go through
        extract_feature().

        * For a single image, the result is one tensor of shape
          ``(1, total_tokens, hidden)``.
        * Otherwise, the result is a tuple with one 2-D tensor per image,
          each holding ``num_patches * tokens_per_patch`` rows.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None

        image_embeds = self.extract_feature(image_input["data"])
        patches_per_image = image_input["patches_per_image"]
        hidden_size = self.config.text_config.hidden_size

        # Only one image in the current batch
        if len(patches_per_image) == 1:
            return image_embeds.view(-1, hidden_size).unsqueeze(0)

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, hidden_size)
        image_feature_sizes = [
            num_patches * feature_size for num_patches in patches_per_image
        ]
        return image_embeds.split(image_feature_sizes)

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """For mono models, mark positions holding the image context token.

        The mask is reshaped to a column ``(-1, 1)``. Non-mono models get
        ``None``.
        """
        if self.is_mono:
            self.visual_token_mask = (
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
            self.visual_token_mask = None

    def get_multimodal_embeddings(
        self, **kwargs
    ) -> Optional[Union[list[torch.Tensor], torch.Tensor,
                        tuple[torch.Tensor, ...]]]:
        """Return image embeddings for ``kwargs``, or ``None`` if no image."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in ``multimodal_embeddings``.

        Merging is done by ``merge_multimodal_embeddings``, keyed on
        ``img_context_token_id``. For mono models it also records
        ``visual_token_mask`` for the next forward() call.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            assert self.img_context_token_id is not None
            self._set_visual_token_mask(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.img_context_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run the language model backbone.

        On non-first pipeline stages, ``intermediate_tensors`` is used and
        the token inputs are dropped. Otherwise, embeddings are computed
        here if the caller did not supply them.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture. The mask is
        # one-shot: clear it so it is not reused by a later step.
        if self.visual_token_mask is not None:
            forward_kwargs.update(
                {"visual_token_mask": self.visual_token_mask})
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, skipping modules this model does not use.

        Returns the set returned by ``AutoWeightsLoader.load_weights``.
        """
        # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
        skip_prefixes = [
            "action_embed", "temporal_embed", "track_embed",
            "track_embed_decoder", "box_token", "cg_criterion", "cg_model",
            "loc_encoder", "loc_decoder", "sam", "temporal_token",
            "track_token"
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)