# "benchmarks/benchmark_latency.py" did not exist on "0f40557af6141ced118b81f2a04e651a0c6c9dbd"
# internvl.py 33.8 KB
# Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
9
from abc import ABC, abstractmethod
10
from collections.abc import Iterable, Mapping, Sequence
11
from functools import cached_property
12
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
13
14
15
16
17

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
18
from transformers import BatchEncoding, PretrainedConfig, TensorType
19

20
from vllm.config import VllmConfig
21
22
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
# Joe Runde's avatar
# Joe Runde committed
23
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
25
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
28
29
30
31
32
33
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
34
                                        PromptUpdate, PromptUpdateDetails)
35
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
36
from vllm.sequence import IntermediateTensors
37
from vllm.transformers_utils.tokenizer import AnyTokenizer
38

39
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
40
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
41
                    maybe_prefix, merge_multimodal_embeddings)
42
43
44
45
46
47
48
49
50
51
52

# Prompt markers: an image's placeholder span is IMG_START, then one
# IMG_CONTEXT per image feature token, then IMG_END.
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel (RGB) mean/std passed to T.Normalize in build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image inputs given as preprocessed pixel values (image tiles)."""

    type: Literal["pixel_values"]
    pixel_values_flat: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """

    num_patches: torch.Tensor
    """Shape: `(batch_size * num_images)`"""

62

63
64
class InternVLImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed image embeddings."""

    type: Literal["image_embeds"]
    data: Union[torch.Tensor, list[torch.Tensor]]
    """
    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors of shape `(total_image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either accepted image input format: pixel values or image embeddings.
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


78
79
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
    """Build the torchvision transform applied to every image tile.

    The pipeline converts the image to RGB if needed and resizes it to a
    square ``input_size`` x ``input_size`` with bicubic interpolation. It then
    converts it to a tensor and normalizes it with ``IMAGENET_MEAN`` /
    ``IMAGENET_STD``.

    Args:
        input_size: Side length, in pixels, of the square output.

    Returns:
        A ``T.Compose`` callable mapping a PIL image to a normalized
        ``(C, input_size, input_size)`` tensor.
    """
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    return T.Compose([
        # The 3-channel normalization below requires RGB input.
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])


90
91
92
93
94
95
96
97
98
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Pick the tiling grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: ``width / height`` of the original image.
        target_ratios: Candidate ``(columns, rows)`` grids, scanned in order.
        width: Original image width in pixels.
        height: Original image height in pixels.
        image_size: Side length of a single square tile.

    Returns:
        The best ``(columns, rows)`` grid, or ``(1, 1)`` if ``target_ratios``
        is empty.
    """
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # On a tie, switch to the later grid only when the image has more
            # than half the pixels that grid's tiles would cover.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


114
115
116
117
118
119
120
def resolve_internvl_min_max_num(
    *,
    min_dynamic_patch: int,
    max_dynamic_patch: int,
    dynamic_image_size: bool,
    use_thumbnail: bool,
) -> tuple[int, int]:
    """Resolve the effective ``(min_num, max_num)`` tile-count bounds.

    If ``dynamic_image_size`` is false, both bounds collapse to 1. If a
    thumbnail is used and more than one tile is allowed, the upper bound
    grows by one to leave room for the extra thumbnail tile.

    Returns:
        ``(min_num, max_num)``.
    """
    min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1

    if use_thumbnail and max_dynamic_patch != 1:
        max_dynamic_patch += 1

    return min_dynamic_patch, max_dynamic_patch

129

130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """List every ``(columns, rows)`` grid whose tile count is in range.

    A grid qualifies when ``min_num <= columns * rows <= max_num``. Grids are
    returned in ascending order of tile count.
    """
    candidates = set()
    for n in range(min_num, max_num + 1):
        for cols in range(1, n + 1):
            for rows in range(1, n + 1):
                if min_num <= cols * rows <= max_num:
                    candidates.add((cols, rows))

    def tile_count(grid):
        return grid[0] * grid[1]

    return sorted(candidates, key=tile_count)


def calculate_internvl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Compute the tile grid and resize target for an image.

    Args:
        orig_width: Original image width in pixels.
        orig_height: Original image height in pixels.
        target_ratios: Candidate ``(columns, rows)`` grids.
        image_size: Side length of a single square tile.
        use_thumbnail: Whether to count an extra thumbnail tile when the grid
            has more than one tile.

    Returns:
        ``(blocks, target_width, target_height)``. ``blocks`` is the number of
        tiles, including the thumbnail if applicable. The target width and
        height are the size the image is resized to before it is split into
        ``image_size`` x ``image_size`` tiles.

    Raises:
        ZeroDivisionError: If ``orig_height`` is 0.
    """
    aspect_ratio = orig_width / orig_height

    # find the closest aspect ratio to the target
    cols, rows = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    # calculate the target width and height
    target_width = image_size * cols
    target_height = image_size * rows
    blocks = cols * rows

    # add thumbnail image if num_blocks != 1
    if use_thumbnail and blocks != 1:
        blocks += 1

    return blocks, target_width, target_height
170
171


172
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess_internvl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Split an image into square tiles using its best-fitting grid.

    The image is resized to ``image_size * columns`` x ``image_size * rows``
    and cropped into tiles row by row, left to right. If ``use_thumbnail`` is
    set and more than one tile was produced, a thumbnail of the whole image,
    resized to ``image_size`` x ``image_size``, is appended.

    Args:
        image: The input image.
        target_ratios: Candidate ``(columns, rows)`` grids.
        image_size: Side length of each square tile.
        use_thumbnail: Whether to append the whole-image thumbnail.

    Returns:
        The list of tile images, thumbnail last if present.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_internvl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    # resize the image so it divides exactly into image_size tiles
    resized_img = image.resize((target_width, target_height))
    num_cols = target_width // image_size

    processed_images = []
    for i in range(blocks):
        col, row = i % num_cols, i // num_cols
        box = (col * image_size, row * image_size,
               (col + 1) * image_size, (row + 1) * image_size)
        # split the image
        processed_images.append(resized_img.crop(box))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values_internvl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
) -> torch.Tensor:
    """Tile an image and convert every tile to normalized pixel values.

    Args:
        image: The input image.
        input_size: Side length of each square tile.
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        use_thumbnail: Whether to append a whole-image thumbnail tile.

    Returns:
        A tensor stacking the transformed tiles along a new first dimension,
        so its first dimension is the number of tiles.
    """
    target_ratios = get_internvl_target_ratios(min_num, max_num)

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess_internvl(
        image,
        target_ratios=target_ratios,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    return torch.stack([transform(tile) for tile in tiles])


235
236
237
238
class BaseInternVLProcessor(ABC):
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: AnyTokenizer,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> None:
        """Store the config and tokenizer and resolve the tiling settings.

        Any of ``min_dynamic_patch``, ``max_dynamic_patch`` and
        ``dynamic_image_size`` left as ``None`` falls back to the attribute of
        the same name on ``config``.
        """
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        image_size: int = config.vision_config.image_size
        patch_size: int = config.vision_config.patch_size

        if min_dynamic_patch is None:
            min_dynamic_patch = config.min_dynamic_patch
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            max_dynamic_patch = config.max_dynamic_patch
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = config.dynamic_image_size
        assert isinstance(dynamic_image_size, bool)

        # Feature tokens per tile: patches per tile scaled by the squared
        # downsample ratio.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail: bool = config.use_thumbnail

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        """Vocabulary ID of the token marking image feature positions."""
        raise NotImplementedError

    @abstractmethod
    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Return the prompt text that replaces one ``<image>`` placeholder.

        Args:
            feature_size: Number of image feature tokens for the image.
            num_patches: Number of tiles for the image, if known.
        """
        raise NotImplementedError

    def resolve_min_max_num(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> tuple[int, int]:
        """Resolve ``(min_num, max_num)``; ``None`` args use instance values."""
        min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch
                             is None else min_dynamic_patch)
        max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
                             is None else max_dynamic_patch)
        dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
                              is None else dynamic_image_size)
        use_thumbnail = (self.use_thumbnail
                         if use_thumbnail is None else use_thumbnail)

        return resolve_internvl_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

    def resolve_target_ratios(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        use_thumbnail: Optional[bool] = None,
    ) -> list[tuple[int, int]]:
        """Return the candidate ``(columns, rows)`` grids for tiling."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=use_thumbnail,
        )

        return get_internvl_target_ratios(min_num, max_num)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image feature tokens for an image size."""
        target_ratios = self.resolve_target_ratios(
            use_thumbnail=False,  # Applied in calculate_targets
        )

        num_patches, _, _ = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=target_ratios,
            use_thumbnail=self.use_thumbnail,
        )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> list[torch.Tensor]:
        """Convert each image into its stacked tile pixel values."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # Applied in image_to_pixel_values
        )

        return [
            image_to_pixel_values_internvl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
            ) for image in images
        ]

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> Mapping[str, NestedTensors]:
        """Tokenize ``text`` and preprocess ``images``.

        For the k-th image, the first remaining ``<image>`` placeholder in
        every text is replaced by that image's expansion from
        ``get_image_repl``.

        Returns:
            The tokenizer outputs. If images are given, these are merged with
            ``pixel_values_flat`` (tiles of all images concatenated) and
            ``image_num_patches`` (number of tiles per image).
        """
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        image_inputs: dict[str, NestedTensors] = {}
        if images:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            image_inputs = {
                "pixel_values_flat":
                torch.cat(pixel_values_lst),
                "image_num_patches":
                torch.tensor([len(item) for item in pixel_values_lst]),
            }

            for pixel_values in pixel_values_lst:
                num_patches = pixel_values.shape[0]
                feature_size = num_patches * self.num_image_token

                image_repl = self.get_image_repl(feature_size, num_patches)
                text = [t.replace('<image>', image_repl.full, 1) for t in text]

        text_inputs = self.tokenizer(text)

        return {
            **BatchEncoding(text_inputs, tensor_type=return_tensors),
            **image_inputs,
        }
426
427


428
429
430
431
432
433
class InternVLProcessor(BaseInternVLProcessor):
    """Processor that expands each image into ``<IMG_CONTEXT>`` tokens."""

    @property
    def image_token_id(self) -> int:
        """Vocabulary ID of the ``IMG_CONTEXT`` token."""
        return self.tokenizer.get_vocab()[IMG_CONTEXT]

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: Optional[int],
    ) -> PromptUpdateDetails[str]:
        """Return ``IMG_START + IMG_CONTEXT * feature_size + IMG_END``.

        Only the ``IMG_CONTEXT`` tokens are selected as feature positions.
        ``num_patches`` is unused here.
        """
        repl_features = IMG_CONTEXT * feature_size
        repl_full = IMG_START + repl_features + IMG_END

        return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
443
444
445
446
447
448


class BaseInternVLProcessingInfo(BaseProcessingInfo):
    """Processing metadata shared by InternVL-style models.

    Subclasses implement ``get_hf_processor``. This class derives image token
    counts and the token-maximizing image size from that processor.
    """

    @abstractmethod
    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> BaseInternVLProcessor:
        """Return the processor, optionally overriding its tiling settings."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Image is the only supported modality; its limit is ``None``."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Max feature tokens per image; ``seq_len``/``mm_counts`` unused."""
        return {"image": self.get_max_image_tokens()}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[BaseInternVLProcessor],
    ) -> int:
        """Return the feature-token count for an image of the given size.

        If ``processor`` is ``None``, one is obtained via ``get_hf_processor``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        return processor.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_max_image_tokens(self) -> int:
        """Feature-token count for the size with the most features."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=None,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Find the image size that produces the most feature tokens.

        Each candidate grid from ``resolve_target_ratios`` is tried at its
        exact pixel size (``image_size * columns`` x ``image_size * rows``).
        The first size reaching the maximum wins.

        Raises:
            ValueError: If no candidate yields a positive token count.
        """
        processor = self.get_hf_processor()

        base_size = processor.image_size
        target_ratios = processor.resolve_target_ratios()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for wr, hr in target_ratios:
            width, height = base_size * wr, base_size * hr

            feat_size = self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=processor,
            )
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width,
                                                     height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint


# Generic over the processing-info class so that InternVLDummyInputsBuilder
# and InternVLMultiModalProcessor accept any BaseInternVLProcessingInfo subclass.
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)


class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy image prompts at the size yielding the most tokens."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return ``mm_counts["image"]`` max-size images plus placeholders.

        The prompt holds one ``<image>`` per dummy image. ``seq_len`` is
        unused.
        """
        target_width, target_height = (
            self.info.get_image_size_with_most_features())
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text="<image>" * num_images,
            mm_data=mm_data,
        )

544
545
546
547
548
549
550
551

class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor that expands ``<image>`` into image tokens."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, NestedTensors]:
        """Run the processor and attach the image token ID to its outputs."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        image_token_id = hf_processor.image_token_id

        # Since there may be extra tokens in the feature placeholders,
        # we need to pass the image token ID to the model to select the
        # tokens to merge from the vision encoder outputs
        processed_outputs["image_token_id"] = torch.tensor(image_token_id)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto image items."""
        image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
        num_images = len(image_num_patches)

        return dict(
            # Tiles of all images are concatenated; per-image tile counts
            # give the split sizes.
            pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_patches),
            image_num_patches=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
            # One scalar token ID shared across all image items.
            image_token_id=MultiModalFieldConfig.shared("image", num_images),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``<image>`` with that image's expanded token span."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        if "image_num_patches" in out_mm_kwargs:
            image_num_patches = out_mm_kwargs["image_num_patches"]
            assert isinstance(image_num_patches, torch.Tensor)
            image_num_patches = image_num_patches.tolist()
        elif "image_embeds" in out_mm_kwargs:
            # TODO: Use image size information in dictionary embedding inputs
            # to compute num_patches (similar to Qwen2-VL)
            image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
        else:
            image_num_patches = []

        def get_replacement_internvl(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            num_patches = image_num_patches[item_idx]
            if num_patches is not None:
                assert isinstance(num_patches, int)

            return hf_processor.get_image_repl(feature_size, num_patches)

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement_internvl,
            )
        ]
631
632


633
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info that builds an ``InternVLProcessor``."""

    def get_hf_processor(
        self,
        *,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
        **kwargs: object,
    ) -> InternVLProcessor:
        """Build an ``InternVLProcessor`` for this model's config/tokenizer.

        Tiling overrides are forwarded only when not ``None``. Otherwise the
        processor falls back to its own defaults.
        """
        if min_dynamic_patch is not None:
            kwargs["min_dynamic_patch"] = min_dynamic_patch
        if max_dynamic_patch is not None:
            kwargs["max_dynamic_patch"] = max_dynamic_patch
        if dynamic_image_size is not None:
            kwargs["dynamic_image_size"] = dynamic_image_size

        return self.ctx.init_processor(
            InternVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            **kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    InternVLMultiModalProcessor,
    info=InternVLProcessingInfo,
    dummy_inputs=InternVLDummyInputsBuilder)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
    """InternVL chat model: vision encoder + ``mlp1`` projector + a
    vLLM-registered language model.

    Vision features are pixel-shuffled, projected to the language model's
    hidden size by ``mlp1`` and merged into the token embeddings at the
    positions of ``img_context_token_id``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        # (image_size // patch_size)**2 ViT patches, reduced by
        # downsample_ratio**2 through pixel_shuffle().
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.llm_arch_name = config.text_config.architectures[0]
        # Mono models use InternVisionPatchModel and pass a visual token
        # mask to the language model (see forward()).
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'

        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.mlp1 = self._init_mlp1(config)

        # Set from the multimodal kwargs in _parse_and_validate_image_input().
        self.img_context_token_id = None
        # Only populated for mono models; consumed and reset in forward().
        self.visual_token_mask = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: Optional[QuantizationConfig]):
        """Exclude ``vision_model`` from AWQ conversion when needed.

        The AWQ models from OpenGVLab are missing ``modules_to_not_convert``.
        If the text config carries its own ``quantization_config`` and
        nothing is excluded yet, add ``vision_model`` back to the exclusion
        list (mutates ``quant_config`` in place).
        """
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config",
                                   None)
        if (not quant_config.modules_to_not_convert
                and llm_quant_config is not None):
            quant_config.modules_to_not_convert.append("vision_model")

    @cached_property
    def sampler(self):
        """Use the language model's sampler if it has one, else a default."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the vision encoder.

        Mono models get an ``InternVisionPatchModel``. Other models get an
        ``InternVisionModel`` whose ``num_hidden_layers_override`` is derived
        from ``config.select_layer``: negative values count from the end
        (``-1`` keeps all ``num_hidden_layers``), non-negative values keep
        ``select_layer + 1`` layers.
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        vision_feature_layer = config.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = (config.vision_config.num_hidden_layers +
                                 vision_feature_layer + 1)
        else:
            num_hidden_layers = vision_feature_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            prefix=prefix,
        )

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Build the projector (LayerNorm -> Linear -> GELU -> Linear) from
        pixel-shuffled vision features to the LLM hidden size."""
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size
        # pixel_shuffle() multiplies the channel dimension by
        # (1 / downsample_ratio)**2.
        shuffled_hidden_size = vit_hidden_size * int(
            1 / self.downsample_ratio)**2

        return nn.Sequential(
            nn.LayerNorm(shuffled_hidden_size),
            nn.Linear(shuffled_hidden_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels.

        Maps ``(N, W, H, C)`` to ``(N, H * s, W * s, C / s**2)`` when
        ``ps_version == 'v1'``; otherwise the two spatial axes are swapped
        back, giving ``(N, W * s, H * s, C / s**2)`` (``s = scale_factor``).
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version != 'v1':
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode pixel values into LLM-sized embeddings.

        Returns a ``(num_images, tokens, llm_hidden)`` tensor: vision
        features minus their first token, pixel-shuffled by
        ``downsample_ratio`` and projected by ``mlp1``.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        # Drop the first token (presumably the class token — TODO confirm).
        vit_embeds = vit_embeds[:, 1:, :]

        # Remaining tokens are laid out on a square grid.
        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check every patch in ``data`` has shape ``(3, S, S)``, where
        ``S = vision_config.image_size``, and return ``data`` unchanged.

        Raises:
            ValueError: if any patch has a different shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        for d in data:
            actual_dims = tuple(d.shape)
            if actual_dims != expected_dims:
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"per patch is {expected_dims}. "
                    f"You supplied {actual_dims}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Extract image inputs from the multimodal kwargs.

        Returns ``None`` when neither ``pixel_values_flat`` nor
        ``image_embeds`` is given. Records ``img_context_token_id`` from the
        ``image_token_id`` kwarg when present.

        Raises:
            ValueError: on unexpected input types, or when pixel inputs come
                without an ``image_token_id``.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        # get_input_embeddings() asserts img_context_token_id is set whenever
        # multimodal embeddings are present, so record it before branching:
        # precomputed image embeddings need it as well.
        image_token_id = kwargs.get("image_token_id")
        if image_token_id is not None:
            assert isinstance(image_token_id, torch.Tensor)
            self.img_context_token_id = (
                image_token_id.flatten().unique().item())

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        # From here on pixel_values_flat is not None.
        if image_token_id is None:
            raise ValueError("Missing `image_token_id` for pixel inputs.")

        if not isinstance(pixel_values_flat, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values_flat)}")

        if not isinstance(image_num_patches, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image_num_patches. "
                             f"Got type: {type(image_num_patches)}")

        pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
        image_num_patches = flatten_bn(image_num_patches, concat=True)

        return InternVLImagePixelInputs(
            type="pixel_values",
            pixel_values_flat=self._validate_pixel_values(pixel_values_flat),
            num_patches=image_num_patches,
        )

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
        """Turn parsed image inputs into per-image embeddings.

        Precomputed embeddings are returned unchanged. Pixel inputs are
        encoded with extract_feature(). A single image yields one
        ``(1, T, hidden)`` tensor. Several images yield a tuple holding one
        ``(num_patches_i * feature_size, hidden)`` tensor per image.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None

        image_embeds = self.extract_feature(image_input["pixel_values_flat"])
        hidden_size = self.config.text_config.hidden_size
        num_patches = image_input["num_patches"]

        # Only one image in the current batch
        if len(num_patches) == 1:
            return image_embeds.view(-1, hidden_size).unsqueeze(0)

        # NOTE: Image embeddings are split into separate tensors for each
        # image by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, hidden_size)
        # int() accepts Python ints as well as 0-d tensor elements.
        image_feature_sizes = [
            int(patches) * feature_size for patches in num_patches
        ]
        return image_embeds.split(image_feature_sizes)

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """For mono models, mark image-context positions as an
        ``(num_tokens, 1)`` boolean mask; otherwise clear the mask."""
        if self.is_mono:
            self.visual_token_mask = (
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
            self.visual_token_mask = None

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return image embeddings for the kwargs, or ``None`` if none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in ``multimodal_embeddings`` at the
        positions of ``img_context_token_id``."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            assert self.img_context_token_id is not None
            self._set_visual_token_mask(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.img_context_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run the language model backbone.

        Token/image embeddings are ignored on non-first pipeline stages,
        i.e. when ``intermediate_tensors`` is given.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture; the mask is
        # consumed once and reset.
        if self.visual_token_mask is not None:
            forward_kwargs["visual_token_mask"] = self.visual_token_mask
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, skipping modules this model lacks."""
        # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
        skip_prefixes = [
            "action_embed", "temporal_embed", "track_embed",
            "track_embed_decoder", "box_token", "cg_criterion", "cg_model",
            "loc_encoder", "loc_decoder", "sam", "temporal_token",
            "track_token"
        ]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)