# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
7
import re
8
from functools import cached_property, partial
9
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
10
                    TypedDict, Union)
11
12
13
14
15
16
17
18

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
19
from vllm.config import VllmConfig
20
21
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
22
23
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
Joe Runde's avatar
Joe Runde committed
24
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
25
26
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
29
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
30
from vllm.multimodal.utils import cached_get_tokenizer
31
from vllm.sequence import IntermediateTensors
32
from vllm.utils import is_list_of
33
34
35

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
36
from .interfaces import SupportsMultiModal, SupportsPP
37
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
38
                    maybe_prefix, merge_multimodal_embeddings)
39
40
41
42
43
44
45
46
47
48
49

# Marker strings used to build the expanded image prompt: the pipeline emits
# IMG_START, then IMG_CONTEXT repeated once per image feature, then IMG_END.
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel mean / std handed to T.Normalize in build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image input given as raw pixel patches (``type == "pixel_values"``)."""

    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """
    patches_per_image: List[int]
    """
    List of number of total patches for each image in the batch.
    """
59
60


61
62
class InternVLImageEmbeddingInputs(TypedDict):
    """Image input given as embeddings (``type == "image_embeds"``)."""

    type: Literal["image_embeds"]
    data: NestedTensors
    """
    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
    or a list of tensors of shape `(total_image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either form of image input accepted by the model: raw pixel patches or
# image embeddings. Consumers tell them apart via the "type" key.
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    """Build the per-tile preprocessing transform.

    The returned ``T.Compose`` converts non-RGB images to RGB, resizes to
    ``(input_size, input_size)`` with bicubic interpolation, converts to a
    tensor and normalises with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """

    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    """Pick the ``(cols, rows)`` grid whose ratio is nearest ``aspect_ratio``.

    Candidates are scanned in the order given. A strictly closer candidate
    always wins. A candidate that ties the current best replaces it only
    when the image area ``width * height`` exceeds half of the area the
    candidate grid would cover at ``image_size`` pixels per tile side.
    Returns ``(1, 1)`` when ``target_ratios`` is empty.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
        elif (gap == smallest_gap
              and image_area > 0.5 * image_size * image_size * cols * rows):
            # Tie: prefer the later (larger) grid only for big enough images.
            chosen = candidate

    return chosen


107
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                         max_num: int, image_size: int,
                         use_thumbnail: bool) -> Tuple[int, int, int]:
    """Choose a tile grid for an image and report its size.

    Args:
        orig_width: Width of the source image in pixels.
        orig_height: Height of the source image in pixels (must be non-zero).
        min_num: Minimum number of tiles a candidate grid may contain.
        max_num: Maximum number of tiles a candidate grid may contain.
        image_size: Side length of one square tile in pixels.
        use_thumbnail: If true and the grid has more than one tile, count
            one extra block for a thumbnail.

    Returns:
        ``(blocks, target_width, target_height)`` where the target size is
        the grid size in pixels (``image_size`` times columns / rows).
    """
    # aspect ratio of the existing image
    aspect_ratio = orig_width / orig_height

    # enumerate every (cols, rows) grid whose tile count is within bounds,
    # ordered by tile count
    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
                        for i in range(1, n + 1) for j in range(1, n + 1)
                        if min_num <= i * j <= max_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # add thumbnail image if num_blocks > 1
    if use_thumbnail and blocks > 1:
        blocks += 1

    return blocks, target_width, target_height


133
134
135
136
137
138
139
140
141
def calculate_num_blocks_wrapper(
    hf_config: PretrainedConfig,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
):
    """Bind ``calculate_num_blocks`` to the settings of ``hf_config``.

    Args:
        hf_config: Model config providing ``dynamic_image_size``,
            ``max_dynamic_patch``, ``min_dynamic_patch``, ``use_thumbnail``
            and ``vision_config.image_size``.
        max_dynamic_patch: Override for the maximum tile count; ``None``
            falls back to ``hf_config.max_dynamic_patch``.
        dynamic_image_size: Override for dynamic tiling; ``None`` falls back
            to ``hf_config.dynamic_image_size``.

    Returns:
        A ``functools.partial`` of ``calculate_num_blocks`` that still needs
        ``orig_width`` and ``orig_height``.
    """
    if dynamic_image_size is None:
        dynamic_image_size = hf_config.dynamic_image_size

    # Without dynamic tiling every image is a single tile, whatever the
    # caller or the config asked for.
    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

    min_num = hf_config.min_dynamic_patch
    image_size = hf_config.vision_config.image_size
    use_thumbnail = hf_config.use_thumbnail
    return partial(calculate_num_blocks,
                   min_num=min_num,
                   max_num=max_dynamic_patch,
                   image_size=image_size,
                   use_thumbnail=use_thumbnail)


154
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
155
156
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                       image_size: int,
                       use_thumbnail: bool) -> List[Image.Image]:
    """Resize an image to its chosen tile grid and cut it into square tiles.

    Args:
        image: Source image.
        min_num: Minimum number of tiles.
        max_num: Maximum number of tiles.
        image_size: Side length of each tile in pixels.
        use_thumbnail: If true and more than one tile was produced, append
            the whole image resized to ``(image_size, image_size)``.

    Returns:
        The tiles in row-major order, optionally followed by the thumbnail.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_num_blocks(
        orig_width,
        orig_height,
        min_num,
        max_num,
        image_size,
        use_thumbnail=False)

    # resize the image
    resized_img = image.resize((target_width, target_height))

    # split the image, walking the grid row by row
    tiles_per_row = target_width // image_size
    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, tiles_per_row)
        box = (col * image_size, row * image_size, (col + 1) * image_size,
               (row + 1) * image_size)
        processed_images.append(resized_img.crop(box))

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
187
188
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
    """Tile an image and convert every tile to a normalised tensor.

    Args:
        image: Source image.
        input_size: Side length of each tile in pixels.
        min_num: Minimum number of tiles.
        max_num: Maximum number of tiles.
        use_thumbnail: Forwarded to ``dynamic_preprocess``.

    Returns:
        The transformed tiles stacked along a new leading dimension.
    """
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image,
                               min_num=min_num,
                               max_num=max_num,
                               image_size=input_size,
                               use_thumbnail=use_thumbnail)
    return torch.stack([transform(tile) for tile in tiles])


200
201
202
203
204
def image_to_pixel_values_wrapper(
    hf_config: PretrainedConfig,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
):
    """Bind ``image_to_pixel_values`` to the settings of ``hf_config``.

    Args:
        hf_config: Model config providing ``dynamic_image_size``,
            ``max_dynamic_patch``, ``min_dynamic_patch``, ``use_thumbnail``
            and ``vision_config.image_size``.
        max_dynamic_patch: Override for the maximum tile count; ``None``
            falls back to ``hf_config.max_dynamic_patch``.
        dynamic_image_size: Override for dynamic tiling; ``None`` falls back
            to ``hf_config.dynamic_image_size``.

    Returns:
        A ``functools.partial`` of ``image_to_pixel_values`` that still
        needs the ``image`` argument.
    """
    image_size = hf_config.vision_config.image_size
    min_num = hf_config.min_dynamic_patch

    if dynamic_image_size is None:
        dynamic_image_size = hf_config.dynamic_image_size

    # Without dynamic tiling every image is a single tile.
    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

    use_thumbnail = hf_config.use_thumbnail
    return partial(image_to_pixel_values,
                   input_size=image_size,
                   min_num=min_num,
                   max_num=max_dynamic_patch,
                   use_thumbnail=use_thumbnail)


221
def get_internvl_num_patches(hf_config: PretrainedConfig):
    """Return the patch count for one tile after downsampling.

    Computed as ``get_clip_num_patches(image_size, patch_size)`` scaled by
    ``hf_config.downsample_ratio ** 2`` and truncated to ``int``.
    """
    vision_config = hf_config.vision_config
    downsample_ratio = hf_config.downsample_ratio
    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    return int(
        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
        (downsample_ratio**2))


231
232
233
234
235
236
def get_max_internvl_image_tokens(
    ctx: InputContext,
    *,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
):
    """Return the largest number of image tokens one image can occupy.

    Args:
        ctx: Input context; only ``ctx.get_hf_config()`` is used.
        max_dynamic_patch: Override for the maximum tile count; ``None``
            falls back to ``hf_config.max_dynamic_patch``.
        dynamic_image_size: Override for dynamic tiling; ``None`` falls back
            to ``hf_config.dynamic_image_size``.

    Returns:
        Per-tile patch count times the maximum tile count, where the tile
        count includes one thumbnail tile when thumbnails are enabled and
        more than one tile is allowed.
    """
    hf_config = ctx.get_hf_config()

    if dynamic_image_size is None:
        dynamic_image_size = hf_config.dynamic_image_size

    # Without dynamic tiling every image is a single tile.
    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

    use_thumbnail = hf_config.use_thumbnail
    if use_thumbnail and max_dynamic_patch > 1:
        max_dynamic_patch += 1

    num_patches = get_internvl_num_patches(hf_config)
    return num_patches * max_dynamic_patch
250
251


252
253
254
255
256
257
def get_max_internvl_image_size(
    ctx: InputContext,
    *,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
):
    """Return a ``(width, height)`` for the largest image to account for.

    Args:
        ctx: Input context; only ``ctx.get_hf_config()`` is used.
        max_dynamic_patch: Override for the maximum tile count; ``None``
            falls back to ``hf_config.max_dynamic_patch``.
        dynamic_image_size: Override for dynamic tiling; ``None`` falls back
            to ``hf_config.dynamic_image_size``.

    Returns:
        ``(image_size * max_tiles, image_size)``, i.e. all tiles laid out in
        a single row. ``max_tiles`` includes one thumbnail tile when
        thumbnails are enabled and more than one tile is allowed.
    """
    hf_config = ctx.get_hf_config()
    image_size = hf_config.vision_config.image_size

    if dynamic_image_size is None:
        dynamic_image_size = hf_config.dynamic_image_size

    # Without dynamic tiling every image is a single tile.
    max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

    use_thumbnail = hf_config.use_thumbnail
    if use_thumbnail and max_dynamic_patch > 1:
        max_dynamic_patch += 1

    width = image_size * max_dynamic_patch
    height = image_size
    return width, height


274
275
276
277
278
279
280
281
282
class InternVLInputPipeline:
    """Input processor, input mapper and dummy-data factory for InternVL.

    The three marker strings are stored on the instance so that the prompt
    expansion and token lookups all use the same vocabulary.
    """

    def __init__(
        self,
        img_start_token: str,
        img_end_token: str,
        img_context_token: str,
    ) -> None:
        """Store the marker strings used to build image prompts."""
        super().__init__()

        self.img_start_token = img_start_token
        self.img_end_token = img_end_token
        self.img_context_token = img_context_token

    def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
        """Return start marker + ``feature_size`` context markers + end marker.

        ``num_patches`` is accepted but not used by this implementation.
        """
        return (self.img_start_token + self.img_context_token * feature_size +
                self.img_end_token)

    def _expand_image_prompt(
        self,
        prompt: str,
        feature_sizes: List[int],
        num_patches: int,
    ) -> str:
        """Replace successive ``<image>`` tags with expanded image prompts.

        One ``<image>`` occurrence is replaced per entry of
        ``feature_sizes``, left to right. If the prompt contains no
        ``"Image-<n>: <image>\\n"`` label at all, each replacement is
        prefixed with ``"Image-<idx>: "`` (1-based).
        """
        already_labelled = re.search(r"Image-(\d+): <image>\n",
                                     prompt) is not None

        new_prompt = prompt
        for idx, feature_size in enumerate(feature_sizes, start=1):
            image_prompt = self._create_image_prompt(feature_size, num_patches)
            if not already_labelled:
                image_prompt = f"Image-{idx}: {image_prompt}"

            new_prompt = new_prompt.replace('<image>', image_prompt, 1)

        return new_prompt

    def input_processor(
        self,
        ctx: InputContext,
        inputs: DecoderOnlyInputs,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> DecoderOnlyInputs:
        """Expand image tags in the prompt and record placeholder ranges.

        Args:
            ctx: Input context providing the model and HF configs.
            inputs: Decoder-only inputs; returned unchanged when they carry
                no ``"image"`` multi-modal data.
            max_dynamic_patch: Optional override of the maximum tile count.
            dynamic_image_size: Optional override of dynamic tiling.

        Returns:
            New token inputs whose ``prompt_token_ids`` encode the expanded
            prompt and whose ``multi_modal_placeholders["image"]`` lists one
            ``PlaceholderRange`` per run of image context tokens.

        Raises:
            TypeError: If the image data is neither a PIL image, a list of
                PIL images, nor a ``torch.Tensor``.
        """
        multi_modal_data = inputs.get("multi_modal_data")
        if multi_modal_data is None or "image" not in multi_modal_data:
            return inputs

        model_config = ctx.model_config
        hf_config = ctx.get_hf_config()

        image_data = multi_modal_data["image"]
        num_patches = get_internvl_num_patches(hf_config)
        num_blocks_calculator = calculate_num_blocks_wrapper(
            hf_config, max_dynamic_patch, dynamic_image_size)

        if isinstance(image_data, Image.Image):
            width, height = image_data.size
            num_blocks, _, _ = num_blocks_calculator(width, height)
            image_feature_sizes = [num_blocks * num_patches]
        elif is_list_of(image_data, Image.Image):
            image_feature_sizes = []
            for image in image_data:
                width, height = image.size
                num_blocks, _, _ = num_blocks_calculator(width, height)
                image_feature_sizes.append(num_blocks * num_patches)
        elif isinstance(image_data, torch.Tensor):
            # Pre-computed embeddings: every image in the tensor has the
            # same feature size, so emit one entry per image to make sure
            # each `<image>` tag gets expanded.
            num_images, image_feature_size, _ = image_data.shape
            image_feature_sizes = [image_feature_size] * num_images
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")

        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        prompt = inputs.get("prompt")
        prompt_token_ids = inputs["prompt_token_ids"]
        if prompt is None:
            prompt = tokenizer.decode(prompt_token_ids)

        new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
                                               num_patches)
        new_prompt_token_ids = tokenizer.encode(new_prompt)
        img_context_token_id = tokenizer.encode(self.img_context_token,
                                                add_special_tokens=False)
        assert len(img_context_token_id) == 1, \
            (f"Invalid image token '{self.img_context_token}': A valid image "
            f"token encodes to a single token ID, got {img_context_token_id}.")
        img_context_token_id = img_context_token_id[0]

        # Get precise tracking of placeholder positions: each time a context
        # token is met, a whole image's worth of tokens is skipped at once.
        token_idx = image_idx = 0
        placeholder_ranges = []
        while token_idx < len(new_prompt_token_ids):
            if new_prompt_token_ids[token_idx] == img_context_token_id:
                curr_image_feature_size = image_feature_sizes[image_idx]
                placeholder_ranges.append(
                    PlaceholderRange(offset=token_idx,
                                     length=curr_image_feature_size))
                image_idx += 1
                token_idx += curr_image_feature_size
            else:
                token_idx += 1

        return token_inputs(
            prompt=prompt,
            prompt_token_ids=new_prompt_token_ids,
            multi_modal_data=multi_modal_data,
            multi_modal_placeholders={"image": placeholder_ranges})

    def input_mapper(
        self,
        ctx: InputContext,
        data: object,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ):
        """Convert image data into the keyword arguments the model consumes.

        PIL images (single or list) are tiled into pixel values and returned
        together with the image context token id. Any other ``data`` is
        passed through unchanged under the ``"image_embeds"`` key.
        """
        hf_config = ctx.get_hf_config()

        image_pixel_values_mapper = image_to_pixel_values_wrapper(
            hf_config, max_dynamic_patch, dynamic_image_size)
        if isinstance(data, Image.Image):
            data = image_pixel_values_mapper(data)
            # Add an N dimension for number of images per prompt (currently 1).
            data = data.unsqueeze(0)
        elif is_list_of(data, Image.Image):
            # we can't stack here because images may have different num_patches
            data = [image_pixel_values_mapper(img) for img in data]
        else:
            return MultiModalKwargs({"image_embeds": data})

        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)
        image_token_id = tokenizer.encode(self.img_context_token,
                                          add_special_tokens=False,
                                          return_tensors="pt")[0]

        return MultiModalKwargs({
            "pixel_values": data,
            "image_token_id": image_token_id
        })

    def dummy_data(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ):
        """Build dummy sequence and image data sized for the largest image.

        Uses ``get_max_internvl_image_tokens`` for the per-image feature
        size and ``get_max_internvl_image_size`` for the dummy image
        dimensions; ``mm_counts["image"]`` gives the number of images.
        """
        num_images = mm_counts["image"]

        hf_config = ctx.get_hf_config()

        image_feature_size = get_max_internvl_image_tokens(
            ctx,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )
        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        seq_data, ranges = dummy_seq_data_for_clip(
            hf_config.vision_config,
            seq_len,
            num_images,
            image_token_id=tokenizer.encode(self.img_context_token,
                                            add_special_tokens=False)[0],
            image_feature_size_override=image_feature_size,
        )

        max_image_width, max_image_height = get_max_internvl_image_size(
            ctx,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        mm_data = dummy_image_for_clip(
            hf_config.vision_config,
            num_images,
            image_width_override=max_image_width,
            image_height_override=max_image_height,
        )

        return DummyData(seq_data, mm_data, ranges)
465
466
467
468
469
470


# Module-level pipeline instance; its bound methods are what get registered
# on InternVLChatModel through the registry decorators.
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
    """InternVL chat model: vision encoder + ``mlp1`` projector + language model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision model, the language model and the projector.

        Args:
            vllm_config: Source of the HF config, quantization config and
                multimodal config.
            prefix: Parameter-name prefix applied to the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.llm_arch_name = config.text_config.architectures[0]
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.mlp1 = self._init_mlp1(config)

        # Set lazily from the `image_token_id` kwarg of pixel-value inputs.
        self.img_context_token_id = None
        # Only populated for the mono architecture; consumed by `forward`.
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: Optional[QuantizationConfig]):
        """Add ``"vision_model"`` to an AWQ config's unconverted modules.

        Only applies when ``quant_config`` is an ``AWQConfig`` whose
        ``modules_to_not_convert`` is empty and the text config carries a
        ``quantization_config`` attribute.
        """
        # the awq models from OpenGVLab missing `modules_to_not_convert`
        # patch the quant_config to add `modules_to_not_convert` back
        if isinstance(quant_config, AWQConfig):
            text_config = config.text_config
            llm_quant_config = getattr(text_config, "quantization_config",
                                       None)
            if (not quant_config.modules_to_not_convert) and \
                (llm_quant_config is not None):
                quant_config.modules_to_not_convert.append("vision_model")

    @cached_property
    def sampler(self):
        """The language model's sampler if it has one, else a default one."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the vision module for this architecture.

        For the non-mono case an ``InternVisionModel`` is built with its
        layer count truncated according to ``config.select_layer`` (negative
        values count from the end). For the mono case an
        ``InternVisionPatchModel`` is built from the vision config alone.
        """
        if not is_mono:
            vision_feature_layer = config.select_layer
            if vision_feature_layer < 0:
                num_hidden_layers = config.vision_config.num_hidden_layers \
                    + vision_feature_layer + 1
            else:
                num_hidden_layers = vision_feature_layer + 1

            return InternVisionModel(
                config.vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                prefix=prefix,
            )
        else:
            return InternVisionPatchModel(config.vision_config)

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Build the projector: LayerNorm -> Linear -> GELU -> Linear.

        The input width is the vision hidden size multiplied by
        ``int(1 / downsample_ratio) ** 2``; the output width is the language
        model hidden size.
        """
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size
        projector_in = vit_hidden_size * int(1 / self.downsample_ratio)**2

        return nn.Sequential(
            nn.LayerNorm(projector_in),
            nn.Linear(projector_in, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channel depth.

        ``x`` is unpacked as ``(n, w, h, c)``. Both spatial dimensions are
        scaled by ``scale_factor`` and the channel dimension by
        ``1 / scale_factor ** 2``. Unless ``self.ps_version == 'v1'`` the two
        spatial axes are swapped once more at the end.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version != 'v1':
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision model and project its output through ``mlp1``.

        The first position along dim 1 of the vision output is dropped, the
        remainder is arranged as a square grid, pixel-shuffled with
        ``self.downsample_ratio``, flattened again and projected.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

        # The remaining positions are treated as a square h x w grid.
        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check every patch in ``data`` has shape ``(3, image_size, image_size)``.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: If any element along the first dimension has a
                different shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"per patch is {expected_dims}. "
                    f"You supplied {actual_dims}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Extract image inputs from the forward kwargs.

        Returns:
            ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
            present; otherwise the matching typed dict. ``image_embeds``
            takes precedence when both are given.

        Raises:
            ValueError: If ``image_embeds`` is not a tensor, if
                ``pixel_values`` is neither a tensor nor a list, if
                ``pixel_values`` comes without ``image_token_id``, or if a
                patch has an unexpected shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # NOTE(review): this path never sets `img_context_token_id`,
            # which `get_input_embeddings` asserts on; it only works if an
            # earlier pixel-value request already set it. Confirm intent.
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        if image_token_id is None:
            raise ValueError(
                "`image_token_id` must be provided together with "
                "`pixel_values`.")
        self.img_context_token_id = image_token_id[0]

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Iterate requests, then images within a request; the leading
            # dimension of each image entry is its patch count.
            patches_per_image = [
                image_pixel_values.shape[0]
                for request_pixel_values in pixel_values
                for image_pixel_values in request_pixel_values
            ]
            # We need to flatten (B, N, P) to (B*N*P),
            # so we call flatten_bn twice.
            return InternVLImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(flatten_bn(pixel_values), concat=True)),
                patches_per_image=patches_per_image)

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> Union[NestedTensors, Tuple[torch.Tensor, ...]]:
        """Turn parsed image input into embeddings for the language model.

        Embedding inputs are returned as-is. Pixel inputs are encoded with
        ``extract_feature``; with a single image the result is one tensor
        with a leading dimension of 1, otherwise a tuple with one tensor
        per image.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None

        image_embeds = self.extract_feature(image_input["data"])

        patches_per_image = image_input["patches_per_image"]

        # Only one image in the current batch
        if len(patches_per_image) == 1:
            image_embeds = image_embeds.view(
                -1, self.config.text_config.hidden_size).unsqueeze(0)
            return image_embeds

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1,
                                         self.config.text_config.hidden_size)
        image_feature_sizes = [
            num_patches * feature_size for num_patches in patches_per_image
        ]
        image_embeds = image_embeds.split(image_feature_sizes)
        return image_embeds

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Store a column mask of image-context positions (mono models only).

        For non-mono models the stored mask is reset to ``None``.
        """
        if self.is_mono:
            self.visual_token_mask = (
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
            self.visual_token_mask = None

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for the given kwargs, or ``None`` if absent."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge image embeddings at context tokens.

        When ``multimodal_embeddings`` is given, the visual token mask is
        refreshed as a side effect.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            assert self.img_context_token_id is not None
            self._set_visual_token_mask(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.img_context_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run the language model, computing input embeddings if needed.

        When ``intermediate_tensors`` is given, both ``input_ids`` and
        ``inputs_embeds`` are discarded. Otherwise, if no ``inputs_embeds``
        were supplied, they are computed from ``input_ids`` and the image
        kwargs. A pending visual token mask is forwarded once and cleared.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs.update(
                {"visual_token_mask": self.visual_token_mask})
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load ``weights`` via ``AutoWeightsLoader`` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)