internvl.py 33.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
9
from abc import ABC, abstractmethod
10
from collections.abc import Iterable, Mapping, Sequence
11
from functools import cached_property
12
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
13
14
15
16
17

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
18
from transformers import BatchEncoding, PretrainedConfig, TensorType
19

20
from vllm.config import VllmConfig
21
22
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
Joe Runde's avatar
Joe Runde committed
23
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
25
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
29
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, NestedTensors)
30
31
32
33
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
34
                                        PromptUpdate, PromptUpdateDetails)
35
from vllm.multimodal.profiling import BaseDummyInputsBuilder
36
from vllm.sequence import IntermediateTensors
37
from vllm.transformers_utils.tokenizer import AnyTokenizer
38

39
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
40
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
41
                    maybe_prefix, merge_multimodal_embeddings)
42
43
44
45
46
47
48
49
50
51
52

# Marker strings used to build the image placeholder inserted into prompts:
# `IMG_START`, then `IMG_CONTEXT` repeated once per image feature, then
# `IMG_END` (see `InternVLProcessor.get_image_repl`).
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel mean / std handed to `T.Normalize` in `build_transform`.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image input supplied as preprocessed pixel patches."""

    type: Literal["pixel_values"]

    pixel_values_flat: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """

    num_patches: torch.Tensor
    """Shape: `(batch_size * num_images)`"""

62

63
64
class InternVLImageEmbeddingInputs(TypedDict):
    """Image input supplied as precomputed embeddings."""

    type: Literal["image_embeds"]

    data: Union[torch.Tensor, list[torch.Tensor]]
    """
    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors of shape `(total_image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either form of image input; the two are told apart by their `type` key
# ("pixel_values" vs. "image_embeds").
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


78
79
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
    """Build the preprocessing pipeline applied to every image tile.

    The returned transform converts the image to RGB when it is in another
    mode, resizes it to `input_size` x `input_size` with bicubic
    interpolation, converts it to a tensor, and normalizes it with
    `IMAGENET_MEAN` / `IMAGENET_STD`.
    """

    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize(
            (input_size, input_size),
            interpolation=T.InterpolationMode.BICUBIC,
        ),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


90
91
92
93
94
95
96
97
98
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the candidate grid whose aspect ratio is nearest `aspect_ratio`.

    Candidates are scanned in the order given. A strictly closer candidate
    always wins. A candidate that is exactly as close as the current best
    replaces it only when the image area (`width * height`) exceeds half
    the area of that candidate's grid of `image_size` tiles.

    Returns `(1, 1)` when `target_ratios` is empty.
    """
    closest = (1, 1)
    smallest_gap = float('inf')
    pixel_count = width * height
    half_tile_area = 0.5 * image_size * image_size

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap = gap
            closest = candidate
            continue

        # Tie on closeness: prefer this grid only for large enough images.
        if (gap == smallest_gap
                and pixel_count > half_tile_area * candidate[0] * candidate[1]):
            closest = candidate

    return closest


114
115
116
117
118
119
120
def resolve_internvl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective `(min, max)` patch counts.

    Both bounds collapse to 1 when `dynamic_image_size` is falsy. When
    `use_thumbnail` is set and the maximum is not 1, the maximum is raised
    by one.
    """
    if not dynamic_image_size:
        min_dynamic_patch = 1
        max_dynamic_patch = 1

    wants_extra_patch = use_thumbnail and max_dynamic_patch != 1
    upper_bound = (max_dynamic_patch + 1
                   if wants_extra_patch else max_dynamic_patch)

    return min_dynamic_patch, upper_bound

129

130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """List every grid whose tile count lies in `[min_num, max_num]`.

    Each entry is a pair of positive integers whose product is within the
    inclusive range. The list is ordered by that product, smallest first.
    """
    grids: set[tuple[int, int]] = set()
    for n in range(min_num, max_num + 1):
        sides = range(1, n + 1)
        for first in sides:
            for second in sides:
                if min_num <= first * second <= max_num:
                    grids.add((first, second))

    def _tile_count(grid: tuple[int, int]) -> int:
        return grid[0] * grid[1]

    return sorted(grids, key=_tile_count)


def calculate_internvl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Work out the tile count and resize target for an image.

    Returns `(blocks, target_width, target_height)`. The target size is
    `image_size` times the chosen grid in each dimension. `blocks` is the
    product of the grid, plus one when `use_thumbnail` is set and the
    grid holds more than one tile.
    """
    # Choose the grid that best matches the original aspect ratio.
    grid = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )
    grid_w, grid_h = grid[0], grid[1]

    num_blocks = grid_w * grid_h
    # A single-tile image never gets an additional thumbnail block.
    if use_thumbnail and num_blocks != 1:
        num_blocks += 1

    return num_blocks, image_size * grid_w, image_size * grid_h
170
171


172
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
173
174
175
176
177
178
179
def dynamic_preprocess_internvl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Resize `image` to the chosen grid and cut it into square tiles.

    Tiles are `image_size` x `image_size` crops taken row by row, left to
    right. When `use_thumbnail` is set and more than one tile was
    produced, the whole image resized to a single tile is appended last.
    """
    width, height = image.size

    # The thumbnail is appended below, so it is excluded from this count.
    num_tiles, resized_w, resized_h = calculate_internvl_targets(
        orig_width=width,
        orig_height=height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    resized = image.resize((resized_w, resized_h))
    tiles_per_row = resized_w // image_size

    tiles = []
    for index in range(num_tiles):
        row, col = divmod(index, tiles_per_row)
        left, top = col * image_size, row * image_size
        tiles.append(
            resized.crop((left, top, left + image_size, top + image_size)))

    assert len(tiles) == num_tiles

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))

    return tiles


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
213
214
215
216
217
218
219
220
221
222
def image_to_pixel_values_internvl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Turn one image into a stacked tensor of preprocessed tiles.

    The image is tiled with `dynamic_preprocess_internvl` using the grids
    allowed by `min_num` / `max_num`. Each tile is then passed through the
    transform from `build_transform`, and the results are stacked along a
    new leading dimension.
    """
    allowed_grids = get_internvl_target_ratios(min_num, max_num)
    to_tensor = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_internvl(
        image,
        target_ratios=allowed_grids,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([to_tensor(tile) for tile in tiles])


235
236
237
238
class BaseInternVLProcessor(ABC):
    """
    Hand-written processor for this model family, which doesn't define
    its own HF processor.

    Calling an instance tiles each image into pixel patches, expands
    `<image>` placeholders in the text into the replacement produced by
    `get_image_repl`, and tokenizes the resulting text.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> None:
        """Store the config/tokenizer and resolve the tiling settings.

        Each keyword override falls back to the attribute of the same name
        on `config` when left as `None`.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        vision_config = config.vision_config
        image_size: int = vision_config.image_size
        patch_size: int = vision_config.patch_size

        def _override_or_config(override, attr_name, expected_type):
            # An explicit argument wins; otherwise read it from the config.
            resolved = (getattr(config, attr_name)
                        if override is None else override)
            assert isinstance(resolved, expected_type)
            return resolved

        min_dynamic_patch = _override_or_config(min_dynamic_patch,
                                                "min_dynamic_patch", int)
        max_dynamic_patch = _override_or_config(max_dynamic_patch,
                                                "max_dynamic_patch", int)
        dynamic_image_size = _override_or_config(dynamic_image_size,
                                                 "dynamic_image_size", bool)

        patches_per_side = image_size // patch_size
        self.num_image_token = int(
            patches_per_side**2 * (config.downsample_ratio**2))
        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        """Token ID of the image feature placeholder token."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the text that replaces one `<image>` placeholder."""
        raise NotImplementedError

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve the `(min, max)` patch counts.

        Arguments left as `None` take the value stored on this processor.
        """
        if min_dynamic_patch is None:
            min_dynamic_patch = self.min_dynamic_patch
        if max_dynamic_patch is None:
            max_dynamic_patch = self.max_dynamic_patch
        if dynamic_image_size is None:
            dynamic_image_size = self.dynamic_image_size
        if use_thumbnail is None:
            use_thumbnail = self.use_thumbnail

        return resolve_internvl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> list[tuple[int, int]]:
        """List the tile grids allowed by the resolved patch-count bounds."""
        bounds = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        return get_internvl_target_ratios(*bounds)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Count the image tokens an image of the given size expands to."""
        # The thumbnail is accounted for by calculate_internvl_targets, so
        # it must not also widen the set of candidate grids.
        allowed_grids = self.resolve_target_ratios(use_thumbnail=False)

        patch_count, _, _ = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=allowed_grids,
            use_thumbnail=self.use_thumbnail,
        )

        return patch_count * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each image into its own tensor of stacked tiles."""
        # The thumbnail is added by image_to_pixel_values_internvl itself.
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,
        )

        pixel_values_lst = []
        for image in images:
            pixel_values_lst.append(
                image_to_pixel_values_internvl(
                    image,
                    input_size=self.image_size,
                    min_num=min_num,
                    max_num=max_num,
                    use_thumbnail=self.use_thumbnail,
                ))
        return pixel_values_lst

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        """Process text and images into model inputs.

        Returns the tokenizer outputs merged with, when images are given,
        `pixel_values_flat` (all tiles concatenated) and
        `image_num_patches` (tile count per image).
        """
        # Normalize both inputs to lists.
        if text is None:
            text = []
        elif not isinstance(text, list):
            text = [text]

        if images is None:
            images = []
        elif not isinstance(images, list):
            images = [images]

        image_inputs: dict[str, NestedTensors] = {}
        if len(images) != 0:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )

            image_inputs["pixel_values_flat"] = torch.cat(pixel_values_lst)
            image_inputs["image_num_patches"] = torch.tensor(
                [len(item) for item in pixel_values_lst])

            # For every image, expand the first remaining `<image>`
            # placeholder in each text entry.
            for pixel_values in pixel_values_lst:
                num_patches = pixel_values.shape[0]
                image_repl = self.get_image_repl(
                    num_patches * self.num_image_token, num_patches)
                text = [
                    entry.replace('<image>', image_repl.full, 1)
                    for entry in text
                ]

        text_inputs = self.tokenizer(text)

        return {
            **BatchEncoding(text_inputs, tensor_type=return_tensors),
            **image_inputs,
        }
426
427


428
429
430
431
432
433
class InternVLProcessor(BaseInternVLProcessor):
    """Processor whose image placeholder is built from `IMG_CONTEXT`."""

    @property
    def image_token_id(self) -> int:
        """ID of the `IMG_CONTEXT` token in the tokenizer vocabulary."""
        vocab = self.tokenizer.get_vocab()
        return vocab[IMG_CONTEXT]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Build the replacement text for one image.

        The text is `IMG_START`, then `IMG_CONTEXT` repeated `feature_size`
        times, then `IMG_END`. `num_patches` is not used here.
        """
        placeholder = f"{IMG_START}{IMG_CONTEXT * feature_size}{IMG_END}"

        return PromptUpdateDetails.select_text(placeholder, IMG_CONTEXT)
443
444
445
446
447
448


class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Processing info shared by models that use a `BaseInternVLProcessor`."""

    @abstractmethod
    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> BaseInternVLProcessor:
        """Return the processor, optionally overriding tiling settings."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits; only `"image"` is listed.

        NOTE(review): the `None` value presumably means "no fixed cap" --
        confirm against `BaseProcessingInfo`.
        """
        return dict(image=None)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[BaseInternVLProcessor],
    ) -> int:
        """Count image tokens for an image of the given size.

        A processor is created via `get_hf_processor()` when `processor`
        is `None`.
        """
        active = self.get_hf_processor() if processor is None else processor

        return active.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Find the image size that yields the most image tokens.

        Every allowed grid is tried at exactly `image_size` pixels per
        tile. The first size with the highest token count is returned.

        Raises:
            ValueError: If no candidate produces a positive token count.
        """
        processor = self.get_hf_processor()

        tile_side = processor.image_size
        candidates = processor.resolve_target_ratios()

        best_tokens = 0
        best_size = None
        for grid_w, grid_h in candidates:
            width = tile_side * grid_w
            height = tile_side * grid_h

            tokens = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if tokens <= best_tokens:
                continue

            best_tokens = tokens
            best_size = ImageSize(width=width, height=height)

        if best_tokens == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size


# Type variable for the processing-info class used to parameterize the
# generic dummy-input builder and multimodal processor defined below.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy text and image inputs for this model family."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one `<image>` placeholder per requested image."""
        return "<image>" * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized to yield the most image tokens.

        `seq_len` is not used here.
        """
        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=mm_counts.get("image", 0),
        )

        return {"image": dummy_images}


class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor that expands `<image>` placeholders."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the parent processing, then attach the image token ID."""
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        processor = self.info.get_hf_processor(**mm_kwargs)
        outputs["image_token_id"] = torch.tensor(processor.image_token_id)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output field maps onto images."""
        patches_per_image = hf_inputs.get("image_num_patches",
                                          torch.empty(0))
        image_count = len(patches_per_image)

        return {
            "pixel_values_flat":
            MultiModalFieldConfig.flat_from_sizes("image",
                                                  patches_per_image),
            "image_num_patches":
            MultiModalFieldConfig.batched("image"),
            "image_embeds":
            MultiModalFieldConfig.batched("image"),
            "image_token_id":
            MultiModalFieldConfig.shared("image", image_count),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement rule for `<image>` placeholders."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "image_num_patches" in out_mm_kwargs:
            patch_tensor = out_mm_kwargs["image_num_patches"]
            assert isinstance(patch_tensor, torch.Tensor)
            image_num_patches = patch_tensor.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            image_num_patches = []

        def get_replacement_internvl(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if not isinstance(items, ImageEmbeddingItems):
                size = items.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )
            else:
                feature_size = items.get_feature_size(item_idx)

            patch_count = image_num_patches[item_idx]
            assert patch_count is None or isinstance(patch_count, int)

            return hf_processor.get_image_repl(feature_size, patch_count)

        image_rule = PromptReplacement(
            modality="image",
            target="<image>",
            replacement=get_replacement_internvl,
        )
        return [image_rule]
615
616


617
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that builds an `InternVLProcessor`."""

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> InternVLProcessor:
        """Create the processor via `self.ctx.init_processor`.

        Only overrides that are not `None` are forwarded.
        """
        overrides = {
            "min_dynamic_patch": min_dynamic_patch,
            "max_dynamic_patch": max_dynamic_patch,
            "dynamic_image_size": dynamic_image_size,
        }
        for name, value in overrides.items():
            if value is not None:
                kwargs[name] = value

        return self.ctx.init_processor(
            InternVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
646
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
647

648
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision model, language model and projector.

        Submodules are created in the order vision model, language model,
        `mlp1`.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config
        self._patch_quant_config(config, quant_config)

        vision_config = config.vision_config
        image_size = config.force_image_size or vision_config.image_size
        patch_size = vision_config.patch_size
        patches_per_side = image_size // patch_size

        self.patch_size = patch_size
        self.num_image_token = int(
            patches_per_side**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        text_config = config.text_config
        self.llm_arch_name = text_config.architectures[0]
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'

        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # _init_mlp1 reads self.downsample_ratio, which is set above.
        self.mlp1 = self._init_mlp1(config)

        self.img_context_token_id = None
        self.visual_token_mask = None

        language_model = self.language_model
        self.make_empty_intermediate_tensors = (
            language_model.make_empty_intermediate_tensors)
688

689
690
691
692
693
694
695
696
697
698
699
700
    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: QuantizationConfig):
        """Add `"vision_model"` to the AWQ skip list when it is empty.

        Nothing happens unless `quant_config` is an `AWQConfig`, its
        `modules_to_not_convert` is empty, and the text config carries a
        `quantization_config`.
        """
        # the awq models from OpenGVLab missing `modules_to_not_convert`
        # patch the quant_config to add `modules_to_not_convert` back
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config",
                                   None)
        if quant_config.modules_to_not_convert or llm_quant_config is None:
            return

        quant_config.modules_to_not_convert.append("vision_model")

701
702
    @cached_property
    def sampler(self):
        """Sampler of the language model, or a fresh one from `get_sampler`.

        The value is computed once per instance and then cached.
        """
        language_model = self.language_model
        if not hasattr(language_model, "sampler"):
            return get_sampler()

        return language_model.sampler
707

708
709
710
711
712
713
714
715
    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the vision model.

        For mono models an `InternVisionPatchModel` is returned. Otherwise
        an `InternVisionModel` is built with its layer count overridden so
        that it stops at `config.select_layer`. A negative value counts
        back from the configured number of hidden layers.
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        feature_layer = config.select_layer
        layer_count = feature_layer + 1
        if feature_layer < 0:
            layer_count += config.vision_config.num_hidden_layers

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=layer_count,
            prefix=prefix,
        )
732
733
734
735
736
737
738
739
740
741
742
743
744

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Create the projector from vision features to LLM hidden size.

        The input width is the vision hidden size multiplied by
        `int(1 / downsample_ratio) ** 2`. The layers are LayerNorm, Linear,
        GELU, Linear.
        """
        channel_factor = int(1 / self.downsample_ratio)**2
        in_features = config.vision_config.hidden_size * channel_factor
        out_features = config.text_config.hidden_size

        layers = [
            nn.LayerNorm(in_features),
            nn.Linear(in_features, out_features),
            nn.GELU(),
            nn.Linear(out_features, out_features),
        ]
        return nn.Sequential(*layers)

745
746
747
748
749
750
751
752
753
754
755
756
757
758
    def pixel_shuffle(self, x, scale_factor=0.5):
        """Rescale the two spatial dims of `x`, moving data into channels.

        `x` is laid out as (N, W, H, C). Both spatial dimensions are
        multiplied by `scale_factor` and the channel dimension is divided
        by `scale_factor ** 2`. Unless `ps_version` is `'v1'`, the two
        spatial dimensions are swapped once more at the end.
        """
        batch, width, height, channels = x.size()
        scaled_h = int(height * scale_factor)
        scaled_w = int(width * scale_factor)

        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(batch, width, scaled_h, int(channels / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch, scaled_h, scaled_w,
                   int(channels / (scale_factor * scale_factor)))

        if self.ps_version != 'v1':
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

759
    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode pixel tiles into projected image embeddings.

        The vision model output has its first position along dim 1
        dropped. The rest is reshaped to a square grid, passed through
        `pixel_shuffle` with `downsample_ratio`, flattened back to a
        sequence and projected with `mlp1`.
        """
        features = self.vision_model(pixel_values=pixel_values)[:, 1:, :]

        # The remaining positions are treated as a square grid.
        grid_side = int(features.shape[1]**0.5)
        features = features.reshape(features.shape[0], grid_side, grid_side,
                                    -1)
        features = self.pixel_shuffle(features,
                                      scale_factor=self.downsample_ratio)

        num_tiles, channels = features.shape[0], features.shape[-1]
        features = features.reshape(num_tiles, -1, channels)
        return self.mlp1(features)

772
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
773
774
775
776
777
778
779
780

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
781
                expected_expr = str(expected_dims)
782
                raise ValueError(
783
784
785
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")
786
787
788
789
790
791
792

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Turn raw keyword inputs into a typed image input, if any.

        Returns `None` when neither pixel values nor image embeddings are
        present. Embeddings take precedence over pixel values. For pixel
        inputs, the image context token ID is also recorded on `self` from
        the required `image_token_id` entry.

        Raises:
            ValueError: If an input is neither a tensor nor a list, or the
                pixel values have an unexpected shape.
        """

        def _require_tensor_or_list(value: object, prefix: str) -> None:
            if not isinstance(value, (torch.Tensor, list)):
                raise ValueError(prefix + f"Got type: {type(value)}")

        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        if image_embeds is not None:
            _require_tensor_or_list(image_embeds,
                                    "Incorrect type of image embeddings. ")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        image_token_id = kwargs["image_token_id"]
        assert isinstance(image_token_id, torch.Tensor)
        self.img_context_token_id = image_token_id.flatten().unique().item()

        if pixel_values_flat is not None:
            _require_tensor_or_list(pixel_values_flat,
                                    "Incorrect type of pixel values. ")
            _require_tensor_or_list(image_num_patches,
                                    "Incorrect type of image_num_patches. ")

            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
            image_num_patches = flatten_bn(image_num_patches, concat=True)

            validated = self._validate_pixel_values(pixel_values_flat)
            return InternVLImagePixelInputs(
                type="pixel_values",
                pixel_values_flat=validated,
                num_patches=image_num_patches,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
        """Produce image embeddings for a parsed image input.

        Inputs of type ``"image_embeds"`` are returned as-is from their
        ``"data"`` entry. Otherwise ``self.extract_feature`` is run on
        ``"pixel_values_flat"`` and the result is reshaped to rows of
        ``self.config.text_config.hidden_size``.

        Returns:
            With a single entry in ``"num_patches"``, one tensor with a
            leading dimension of 1. With several entries, the pieces obtained
            by splitting the flattened features per image.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None

        features = self.extract_feature(image_input["pixel_values_flat"])
        patch_counts = image_input["num_patches"]
        hidden_size = self.config.text_config.hidden_size

        # A single image needs no splitting.
        if len(patch_counts) == 1:
            return features.view(-1, hidden_size).unsqueeze(0)

        # NOTE: The flattened embeddings are cut into one tensor per image;
        # each image owns (its patch count) * (size of dim 1) rows.
        rows_per_patch = features.shape[1]
        flat_features = features.view(-1, hidden_size)
        rows_per_image = [count * rows_per_patch for count in patch_counts]
        return flat_features.split(rows_per_image)
863

864
    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Store the image-token mask for ``input_ids`` on the instance.

        When ``self.is_mono`` is truthy, ``self.visual_token_mask`` becomes a
        column (``reshape(-1, 1)``) marking positions equal to
        ``self.img_context_token_id``. Otherwise it is set to ``None``.
        """
        mask = None
        if self.is_mono:
            is_image_token = input_ids == self.img_context_token_id
            mask = is_image_token.reshape(-1, 1)
        self.visual_token_mask = mask
870

871
872
873
    def get_language_model(self) -> torch.nn.Module:
        """Return the module stored in ``self.language_model``."""
        language_model = self.language_model
        return language_model

874
    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Parse the image keyword arguments and embed them.

        Returns:
            ``None`` when ``_parse_and_validate_image_input`` finds no image
            input, otherwise the result of ``_process_image_input``.
        """
        parsed = self._parse_and_validate_image_input(**kwargs)
        return None if parsed is None else self._process_image_input(parsed)
881
882
883
884

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in multimodal embeddings if given.

        Args:
            input_ids: Token ids passed to the language model's embedding
                lookup.
            multimodal_embeddings: Embeddings to merge at positions holding
                ``self.img_context_token_id``; ``None`` skips the merge.

        Returns:
            The language-model embeddings, merged with
            ``multimodal_embeddings`` when those are provided.
        """
        text_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return text_embeds

        assert self.img_context_token_id is not None
        # Refresh the stored mask before merging.
        self._set_visual_token_mask(input_ids)
        return merge_multimodal_embeddings(
            input_ids,
            text_embeds,
            multimodal_embeddings,
            self.img_context_token_id,
        )

899
900
901
902
903
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run ``self.language_model.model`` on the prepared inputs.

        Args:
            input_ids: Token ids; replaced by ``None`` whenever embeddings
                are computed here or ``intermediate_tensors`` is given.
            positions: Forwarded unchanged to the language model.
            intermediate_tensors: When not ``None``, both ``input_ids`` and
                ``inputs_embeds`` are forwarded as ``None``.
            inputs_embeds: Precomputed input embeddings, if any.
            **kwargs: Passed to ``get_multimodal_embeddings`` when the
                embeddings have to be computed here.

        Returns:
            Whatever ``self.language_model.model`` returns.
        """
        if intermediate_tensors is not None:
            input_ids, inputs_embeds = None, None
        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            image_embeds = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, image_embeds)
            input_ids = None

        model_kwargs = dict(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        # Only required if the model is mono-architecture. The stored mask
        # is handed over once and then cleared.
        pending_mask = self.visual_token_mask
        if pending_mask is not None:
            model_kwargs["visual_token_mask"] = pending_mask
            self.visual_token_mask = None

        return self.language_model.model(**model_kwargs)

936
937
938
939
940
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to ``self.language_model``.

        Both arguments are passed through positionally and the language
        model's result is returned unchanged.
        """
        language_model = self.language_model
        return language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to ``self.language_model``.

        Both arguments are passed through positionally and the language
        model's result is returned unchanged.
        """
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)

951
952
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load ``weights`` into this model through ``AutoWeightsLoader``.

        Args:
            weights: ``(name, tensor)`` pairs to load.

        Returns:
            The value returned by the loader's ``load_weights``.
        """
        # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
        unused_prefixes = [
            "action_embed",
            "temporal_embed",
            "track_embed",
            "track_embed_decoder",
            "box_token",
            "cg_criterion",
            "cg_model",
            "loc_encoder",
            "loc_decoder",
            "sam",
            "temporal_token",
            "track_token",
        ]
        return AutoWeightsLoader(
            self, skip_prefixes=unused_prefixes).load_weights(weights)