internvl.py 27 KB
Newer Older
1
2
3
4
5
6
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
7
import re
8
from functools import cached_property, partial
9
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
10
                    TypedDict, Union)
11
12
13
14
15
16
17
18

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
19
from vllm.config import VllmConfig
20
21
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
22
23
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
Joe Runde's avatar
Joe Runde committed
24
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
25
26
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
29
from vllm.multimodal.inputs import NestedTensors
30
from vllm.multimodal.utils import cached_get_tokenizer
31
from vllm.sequence import IntermediateTensors
32
from vllm.utils import is_list_of
33
34
35

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
36
from .interfaces import SupportsMultiModal, SupportsPP
37
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
38
                    maybe_prefix, merge_multimodal_embeddings)
39
40
41
42
43
44
45
46
47
48
49

# Marker strings used when expanding an image placeholder in the prompt:
# IMG_START and IMG_END wrap the span, and IMG_CONTEXT is repeated once per
# image feature (see InternVLInputPipeline._create_image_prompt).
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel mean / std handed to T.Normalize in build_transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image input variant that carries raw pixel values."""

    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """


57
58
class InternVLImageEmbeddingInputs(TypedDict):
    """Image input variant that carries precomputed image embeddings."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either form of image input accepted by the model: raw pixel values or
# precomputed embeddings (discriminated by the "type" key).
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    """Build the per-tile preprocessing pipeline.

    The returned transform converts a PIL image to RGB when needed, resizes
    it to `(input_size, input_size)` with bicubic interpolation, converts it
    to a tensor and normalizes it with IMAGENET_MEAN / IMAGENET_STD.
    """

    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        if img.mode == 'RGB':
            return img
        return img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    """Pick the grid from `target_ratios` whose ratio best matches the image.

    Each candidate is indexed as `(candidate[0], candidate[1])` and its ratio
    is `candidate[0] / candidate[1]`. The candidate with the smallest
    absolute difference to `aspect_ratio` wins. On an exact tie, a later
    candidate replaces the current one only if the image area
    (`width * height`) exceeds half the area of that candidate's grid of
    `image_size` x `image_size` tiles. Returns `(1, 1)` when `target_ratios`
    is empty.
    """
    closest = (1, 1)
    smallest_gap = float('inf')
    pixel_count = width * height
    half_tile_area = 0.5 * image_size * image_size

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap = gap
            closest = candidate
            continue
        # Tie-break: prefer the later candidate only for large images.
        if (gap == smallest_gap
                and pixel_count > half_tile_area * candidate[0] * candidate[1]):
            closest = candidate

    return closest


101
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                         max_num: int, image_size: int,
                         use_thumbnail: bool) -> Tuple[int, int, int]:
    """Choose a tile grid for an image and report the resulting sizes.

    Returns:
        `(blocks, target_width, target_height)`, where `blocks` is the number
        of tiles in the chosen grid (plus one when `use_thumbnail` is set and
        the grid has more than one tile) and the target dimensions are the
        grid dimensions multiplied by `image_size`.
    """
    image_ratio = orig_width / orig_height

    # Every (i, j) grid whose tile count lies within [min_num, max_num],
    # ordered by tile count.
    candidate_grids = {(i, j)
                       for n in range(min_num, max_num + 1)
                       for i in range(1, n + 1)
                       for j in range(1, n + 1)
                       if min_num <= i * j <= max_num}
    candidate_grids = sorted(candidate_grids,
                             key=lambda grid: grid[0] * grid[1])

    # Grid whose aspect ratio is closest to the image's.
    best_grid = find_closest_aspect_ratio(image_ratio, candidate_grids,
                                          orig_width, orig_height, image_size)
    grid_w, grid_h = best_grid[0], best_grid[1]

    blocks = grid_w * grid_h
    # A thumbnail tile is only added when the image is actually split.
    if use_thumbnail and blocks > 1:
        blocks += 1

    return blocks, image_size * grid_w, image_size * grid_h


127
128
129
130
131
132
133
134
135
def calculate_num_blocks_wrapper(
    hf_config: PretrainedConfig,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
):
    """Return `calculate_num_blocks` pre-bound with settings from `hf_config`.

    `dynamic_image_size` falls back to the config value when None. When
    dynamic sizing is off, the maximum tile count is forced to 1; otherwise
    an unset `max_dynamic_patch` falls back to the config value. The
    returned partial still expects the image width and height.
    """
    if dynamic_image_size is None:
        dynamic_image_size = hf_config.dynamic_image_size

    if not dynamic_image_size:
        max_dynamic_patch = 1
    elif max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

    bound_settings = {
        "min_num": hf_config.min_dynamic_patch,
        "max_num": max_dynamic_patch,
        "image_size": hf_config.vision_config.image_size,
        "use_thumbnail": hf_config.use_thumbnail,
    }
    return partial(calculate_num_blocks, **bound_settings)


148
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
149
150
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                       image_size: int,
                       use_thumbnail: bool) -> List[Image.Image]:
    """Split an image into `image_size` x `image_size` tiles.

    The image is resized to the grid chosen by `calculate_num_blocks` and
    cropped tile by tile in row-major order. When `use_thumbnail` is set and
    more than one tile was produced, a copy of the original image resized to
    a single tile is appended at the end.
    """
    orig_width, orig_height = image.size

    # The thumbnail is handled below, so ask for the tile count without it.
    blocks, target_width, target_height = calculate_num_blocks(
        orig_width,
        orig_height,
        min_num,
        max_num,
        image_size,
        use_thumbnail=False)

    resized_img = image.resize((target_width, target_height))

    tiles = []
    for index in range(blocks):
        # Row-major position of this tile inside the grid.
        row, col = divmod(index, target_width // image_size)
        left, top = col * image_size, row * image_size
        tiles.append(
            resized_img.crop((left, top, left + image_size,
                              top + image_size)))
    assert len(tiles) == blocks

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
181
182
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
    """Tile an image and return the stacked, transformed tiles.

    The tiles come from `dynamic_preprocess`; each one is passed through the
    transform from `build_transform` and the results are stacked along a new
    leading dimension.
    """
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image,
                               min_num=min_num,
                               max_num=max_num,
                               image_size=input_size,
                               use_thumbnail=use_thumbnail)
    return torch.stack([to_tensor(tile) for tile in tiles])


194
195
196
197
198
def image_to_pixel_values_wrapper(
    hf_config: PretrainedConfig,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
):
    """Return `image_to_pixel_values` pre-bound with settings from `hf_config`.

    `dynamic_image_size` falls back to the config value when None. When
    dynamic sizing is off, the maximum tile count is forced to 1; otherwise
    an unset `max_dynamic_patch` falls back to the config value. The
    returned partial still expects the image.
    """
    image_size = hf_config.vision_config.image_size
    min_num = hf_config.min_dynamic_patch

    if dynamic_image_size is None:
        dynamic_image_size = hf_config.dynamic_image_size

    if not dynamic_image_size:
        max_dynamic_patch = 1
    elif max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

    bound_settings = {
        "input_size": image_size,
        "min_num": min_num,
        "max_num": max_dynamic_patch,
        "use_thumbnail": hf_config.use_thumbnail,
    }
    return partial(image_to_pixel_values, **bound_settings)


215
def get_internvl_num_patches(hf_config: PretrainedConfig):
    """Return the per-tile patch count scaled by `downsample_ratio` squared.

    The base count comes from `get_clip_num_patches` for the vision config's
    image and patch sizes; the product is truncated to an int.
    """
    vision_cfg = hf_config.vision_config
    scale = hf_config.downsample_ratio**2
    base_patches = get_clip_num_patches(image_size=vision_cfg.image_size,
                                        patch_size=vision_cfg.patch_size)
    return int(base_patches * scale)


225
226
227
228
229
230
def get_max_internvl_image_tokens(
    ctx: InputContext,
    *,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
):
    """Return the largest number of image tokens a single image can use.

    The result is the per-tile count from `get_internvl_num_patches` times
    the maximum tile count, which includes one extra thumbnail tile when the
    config enables thumbnails and more than one tile is allowed.
    """
    hf_config = ctx.get_hf_config()

    if dynamic_image_size is None:
        dynamic_image_size = hf_config.dynamic_image_size

    # Without dynamic sizing, every image is exactly one tile.
    if not dynamic_image_size:
        max_dynamic_patch = 1
    elif max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

    if hf_config.use_thumbnail and max_dynamic_patch > 1:
        max_dynamic_patch += 1

    return get_internvl_num_patches(hf_config) * max_dynamic_patch
244
245


246
247
248
249
250
251
def get_max_internvl_image_size(
    ctx: InputContext,
    *,
    max_dynamic_patch: Optional[int] = None,
    dynamic_image_size: Optional[bool] = None,
):
    """Return `(width, height)` of an image that uses the maximum tile count.

    The height is one tile (`image_size`) and the width is `image_size`
    times the maximum tile count, which includes one extra thumbnail tile
    when the config enables thumbnails and more than one tile is allowed.
    """
    hf_config = ctx.get_hf_config()
    tile_side = hf_config.vision_config.image_size

    if dynamic_image_size is None:
        dynamic_image_size = hf_config.dynamic_image_size

    # Without dynamic sizing, every image is exactly one tile.
    if not dynamic_image_size:
        max_dynamic_patch = 1
    elif max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch

    if hf_config.use_thumbnail and max_dynamic_patch > 1:
        max_dynamic_patch += 1

    return tile_side * max_dynamic_patch, tile_side


268
269
270
271
272
273
274
275
276
class InternVLInputPipeline:
    """Prompt expansion, input mapping and dummy data for InternVL inputs.

    The three marker strings given to the constructor are used to build the
    text that replaces each `<image>` placeholder in a prompt.
    """

    def __init__(
        self,
        img_start_token: str,
        img_end_token: str,
        img_context_token: str,
    ) -> None:
        super().__init__()

        self.img_start_token = img_start_token
        self.img_end_token = img_end_token
        self.img_context_token = img_context_token

    def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
        """Return the replacement text for one image placeholder.

        The context token is repeated `feature_size` times between the start
        and end tokens. `num_patches` is not used by this implementation.
        """
        return (self.img_start_token + self.img_context_token * feature_size +
                self.img_end_token)

    def _expand_image_prompt(
        self,
        prompt: str,
        feature_sizes: List[int],
        num_patches: int,
    ) -> str:
        """Replace `<image>` placeholders in `prompt`, one per feature size.

        Placeholders are replaced left to right. If the prompt contains no
        `Image-N: <image>\\n` label at all, each replacement is prefixed with
        `Image-{idx}: ` (1-based); otherwise the existing labels are kept.
        """
        image_idx = sorted(
            map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))

        new_prompt = prompt
        for idx, feature_size in enumerate(feature_sizes, start=1):
            image_prompt = self._create_image_prompt(feature_size, num_patches)
            if not image_idx:
                image_prompt = f"Image-{idx}: {image_prompt}"

            new_prompt = new_prompt.replace('<image>', image_prompt, 1)

        return new_prompt

    @staticmethod
    def _embedding_feature_sizes(image_embeds: torch.Tensor) -> List[int]:
        """Return one feature size per image for precomputed embeddings.

        `image_embeds.shape` is unpacked as
        `(num_images, image_feature_size, hidden_size)`.
        """
        num_images, image_feature_size, _ = image_embeds.shape
        # One entry per image so that every `<image>` placeholder is
        # expanded, not just the first one.
        return [image_feature_size] * num_images

    def input_processor(
        self,
        ctx: InputContext,
        inputs: DecoderOnlyInputs,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ) -> DecoderOnlyInputs:
        """Expand image placeholders and re-tokenize the prompt.

        Inputs without image data are returned unchanged. Otherwise the
        number of context tokens per image is derived from the image size
        (PIL images) or from the embedding shape (tensor), the prompt is
        expanded and encoded again, and new token inputs are returned. The
        returned `prompt` is the prompt text before expansion.

        Raises:
            TypeError: If the image data is not a PIL image, a list of PIL
                images, or a tensor.
        """
        multi_modal_data = inputs.get("multi_modal_data")
        if multi_modal_data is None or "image" not in multi_modal_data:
            return inputs

        model_config = ctx.model_config
        hf_config = ctx.get_hf_config()

        image_data = multi_modal_data["image"]
        num_patches = get_internvl_num_patches(hf_config)
        num_blocks_calculator = calculate_num_blocks_wrapper(
            hf_config, max_dynamic_patch, dynamic_image_size)
        if isinstance(image_data, Image.Image):
            width, height = image_data.size
            num_blocks, _, _ = num_blocks_calculator(width, height)
            image_feature_sizes = [num_blocks * num_patches]
        elif is_list_of(image_data, Image.Image):
            image_feature_sizes = []
            for image in image_data:
                width, height = image.size
                num_blocks, _, _ = num_blocks_calculator(width, height)
                image_feature_sizes.append(num_blocks * num_patches)
        elif isinstance(image_data, torch.Tensor):
            image_feature_sizes = self._embedding_feature_sizes(image_data)
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")

        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        prompt = inputs.get("prompt")
        prompt_token_ids = inputs["prompt_token_ids"]
        if prompt is None:
            # Only token ids were supplied; recover text to expand.
            prompt = tokenizer.decode(prompt_token_ids)

        new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
                                               num_patches)
        new_prompt_token_ids = tokenizer.encode(new_prompt)

        return token_inputs(prompt=prompt,
                            prompt_token_ids=new_prompt_token_ids,
                            multi_modal_data=multi_modal_data)

    def input_mapper(
        self,
        ctx: InputContext,
        data: object,
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ):
        """Convert image data into the keyword arguments the model consumes.

        PIL images (single or list) become `pixel_values` plus the id of the
        image context token. Any other data is passed through unchanged as
        `image_embeds`.
        """
        hf_config = ctx.get_hf_config()

        image_pixel_values_mapper = image_to_pixel_values_wrapper(
            hf_config, max_dynamic_patch, dynamic_image_size)
        if isinstance(data, Image.Image):
            data = image_pixel_values_mapper(data)
            # Add an N dimension for number of images per prompt (currently 1).
            data = data.unsqueeze(0)
        elif is_list_of(data, Image.Image):
            # we can't stack here because images may have different num_patches
            data = [image_pixel_values_mapper(img) for img in data]
        else:
            return MultiModalKwargs({"image_embeds": data})
        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)
        image_token_id = tokenizer.encode(self.img_context_token,
                                          add_special_tokens=False,
                                          return_tensors="pt")[0]

        return MultiModalKwargs({
            "pixel_values": data,
            "image_token_id": image_token_id
        })

    def dummy_data(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        *,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
    ):
        """Build dummy sequence and image data for `mm_counts["image"]` images.

        The sequence data uses the maximum image token count and the dummy
        images use the maximum image size for the given settings.
        """
        num_images = mm_counts["image"]

        hf_config = ctx.get_hf_config()

        image_feature_size = get_max_internvl_image_tokens(
            ctx,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )
        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        seq_data, ranges = dummy_seq_data_for_clip(
            hf_config.vision_config,
            seq_len,
            num_images,
            image_token_id=tokenizer.encode(self.img_context_token,
                                            add_special_tokens=False)[0],
            image_feature_size_override=image_feature_size,
        )

        max_image_width, max_image_height = get_max_internvl_image_size(
            ctx,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )

        mm_data = dummy_image_for_clip(
            hf_config.vision_config,
            num_images,
            image_width_override=max_image_width,
            image_height_override=max_image_height,
        )

        return DummyData(seq_data, mm_data, ranges)
437
438
439
440
441
442


# Shared pipeline instance; its bound methods are registered on the model
# class through the registry decorators below.
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
443
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
444
445
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
446
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
447

448
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision model, the language model and the `mlp1` projector.

        Args:
            vllm_config: Provides the HF config, the quantization config and
                the multimodal config.
            prefix: Prefix combined with each submodule name via
                `maybe_prefix`.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config
        # May mutate quant_config in place (AWQ only) before it is used below.
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        # Patches per tile, (image_size // patch_size) squared, scaled by
        # downsample_ratio squared.
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.llm_arch_name = config.text_config.architectures[0]
        # "Mono" selects InternVisionPatchModel and enables the visual token
        # mask passed to the language model in forward().
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.mlp1 = self._init_mlp1(config)

        # Per-request state: set while parsing image inputs / computing input
        # embeddings, consumed (and for the mask, reset) in forward().
        self.img_context_token_id = None
        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)
488

489
490
491
492
493
494
495
496
497
498
499
500
    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: QuantizationConfig):
        """Add "vision_model" to an AWQ config's `modules_to_not_convert`.

        Applies only when `quant_config` is an `AWQConfig`, its
        `modules_to_not_convert` is empty, and the text config carries a
        `quantization_config`. `quant_config` is modified in place.
        """
        # the awq models from OpenGVLab missing `modules_to_not_convert`
        # patch the quant_config to add `modules_to_not_convert` back
        if not isinstance(quant_config, AWQConfig):
            return

        llm_quant_config = getattr(config.text_config, "quantization_config",
                                   None)
        if quant_config.modules_to_not_convert or llm_quant_config is None:
            return

        quant_config.modules_to_not_convert.append("vision_model")

501
502
    @cached_property
    def sampler(self):
        """Return the language model's sampler, or a default from `get_sampler`."""
        language_model = self.language_model
        return (language_model.sampler
                if hasattr(language_model, "sampler") else get_sampler())
507

508
509
510
511
512
513
514
515
    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Create the vision model.

        For mono models an `InternVisionPatchModel` is built from the vision
        config alone. Otherwise an `InternVisionModel` is built with its
        layer count limited so that the layer selected by
        `config.select_layer` is the last one.
        """
        if is_mono:
            return InternVisionPatchModel(config.vision_config)

        selected_layer = config.select_layer
        if selected_layer < 0:
            # Negative indices count back from the last hidden layer.
            layers_to_keep = (config.vision_config.num_hidden_layers +
                              selected_layer + 1)
        else:
            layers_to_keep = selected_layer + 1

        return InternVisionModel(
            config.vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=layers_to_keep,
            prefix=prefix,
        )
532
533
534
535
536
537
538
539
540
541
542
543
544

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Build the projector from vision features to the LLM hidden size.

        The input width is the vision hidden size times
        `int(1 / downsample_ratio)` squared.
        """
        llm_hidden_size = config.text_config.hidden_size
        projector_in = (config.vision_config.hidden_size *
                        int(1 / self.downsample_ratio)**2)

        layers = [
            nn.LayerNorm(projector_in),
            nn.Linear(projector_in, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        ]
        return nn.Sequential(*layers)

545
546
547
548
549
550
551
552
553
554
555
556
557
558
    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on an (N, W, H, C) tensor.

        Both spatial dimensions are scaled by `scale_factor` and the channel
        dimension by `1 / scale_factor**2`. Unless `ps_version` is 'v1', the
        two spatial dimensions are swapped once more at the end.
        """
        n, w, h, c = x.size()
        scaled_h = int(h * scale_factor)
        scaled_w = int(w * scale_factor)

        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, scaled_h, int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, scaled_h, scaled_w,
                   int(c / (scale_factor * scale_factor)))

        if self.ps_version != 'v1':
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

559
    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision model and project its output with `mlp1`.

        The first token of the vision output is dropped, the remaining
        tokens are arranged on a square grid, pixel-shuffled with
        `downsample_ratio`, flattened back to a sequence and projected.
        """
        features = self.vision_model(pixel_values=pixel_values)[:, 1:, :]

        # The remaining tokens are assumed to form a square grid.
        side = int(features.shape[1]**0.5)
        grid = features.reshape(features.shape[0], side, side, -1)
        grid = self.pixel_shuffle(grid, scale_factor=self.downsample_ratio)

        flattened = grid.reshape(grid.shape[0], -1, grid.shape[-1])
        return self.mlp1(flattened)

572
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every item of `data` has shape `(3, image_size, image_size)`.

        Returns `data` unchanged.

        Raises:
            ValueError: On the first item whose shape differs.
        """
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)

        for patch in data:
            actual_dims = tuple(patch.shape)
            if actual_dims != expected_dims:
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_dims}. "
                    f"You supplied {actual_dims}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Extract the image input from the forward keyword arguments.

        Pops `pixel_values`, `image_token_id` and `image_embeds` from
        `kwargs`. Embeddings take precedence over pixel values when both are
        present. For pixel values, `self.img_context_token_id` is set from
        `image_token_id[0]`.

        Returns:
            None when neither pixel values nor embeddings are given,
            otherwise the matching typed input dict.

        Raises:
            ValueError: If the embeddings are not a tensor, if the pixel
                values are neither a tensor nor a list, or if pixel values
                are given without `image_token_id`.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # NOTE(review): this path returns without setting
            # `self.img_context_token_id` - confirm that callers relying on
            # it handle embedding inputs.
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        # Fail with a clear message instead of "NoneType is not
        # subscriptable" when the token id did not accompany the pixels.
        if image_token_id is None:
            raise ValueError(
                "`image_token_id` must be provided together with "
                "`pixel_values`.")
        self.img_context_token_id = image_token_id[0]

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            # We need to flatten (B, N, P) to (B*N*P),
            # so we call flatten_bn twice.
            return InternVLImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(flatten_bn(pixel_values), concat=True)),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> torch.Tensor:
        """Return embeddings for a parsed image input.

        Precomputed embeddings are returned as-is; pixel values are encoded
        with `extract_feature`.
        """
        if image_input["type"] != "image_embeds":
            assert self.vision_model is not None
            return self.extract_feature(image_input["data"])

        return image_input["data"]
638

639
    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Update `self.visual_token_mask` for the given token ids.

        For mono models the mask marks positions equal to
        `self.img_context_token_id`, reshaped to a single column; otherwise
        the mask is cleared. Nothing is returned.
        """
        if self.is_mono:
            self.visual_token_mask = (
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
            self.visual_token_mask = None
645

646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for the inputs in `kwargs`, or None if absent."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        return (None if image_input is None else
                self._process_image_input(image_input))

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge in the image embeddings, if any.

        When image embeddings are given, the visual token mask is updated and
        the embeddings are merged at the positions of
        `self.img_context_token_id`.
        """
        text_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return text_embeds

        assert self.img_context_token_id is not None
        self._set_visual_token_mask(input_ids)
        return merge_multimodal_embeddings(input_ids, text_embeds,
                                           multimodal_embeddings,
                                           self.img_context_token_id)

667
668
669
670
671
672
673
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run the inner language model, computing input embeddings if needed.

        When `intermediate_tensors` is given, both `input_ids` and
        `inputs_embeds` are dropped. Otherwise, if no `inputs_embeds` were
        supplied, they are built from `input_ids` and any image inputs in
        `kwargs`. Returns whatever `self.language_model.model` returns.
        NOTE(review): the return annotation mentions SamplerOutput although
        no sampling happens here - confirm the intended type.
        """

        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # The mask is only non-None for mono models
        # (see _set_visual_token_mask).
        if self.visual_token_mask is not None:
            # overwrite visual_token_mask and img_context_token_id back to None,
            # so that this doesn't need to depend on encoder output
            forward_kwargs.update(
                {"visual_token_mask": self.visual_token_mask})
            self.visual_token_mask = None
            self.img_context_token_id = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

710
711
712
713
714
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)

725
726
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `weights` into this module via `AutoWeightsLoader`.

        Returns whatever the loader's `load_weights` returns.
        """
        return AutoWeightsLoader(self).load_weights(weights)