# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
7
import re
8
9
10
from functools import cached_property, partial
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)
11
12
13
14
15
16
17
18
19

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
20
21
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                         token_inputs)
22
23
from vllm.model_executor.layers.quantization import (AWQConfig,
                                                     QuantizationConfig)
24
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
25
26
from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                   InternVisionPatchModel)
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
from vllm.multimodal.base import MultiModalInputs
30
from vllm.multimodal.utils import cached_get_tokenizer
31
from vllm.sequence import IntermediateTensors
32
from vllm.utils import is_list_of
33
34
35

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
36
from .interfaces import SupportsMultiModal, SupportsPP
37
38
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    merge_multimodal_embeddings)
39
40
41
42
43
44
45
46
47
48
49

# Special prompt tokens for an image span: IMG_CONTEXT is repeated once per
# image feature, wrapped between IMG_START and IMG_END.
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

# Per-channel mean / std handed to T.Normalize when preprocessing images.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel values."""

    # Discriminator distinguishing this variant from the embeddings variant.
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """


57
58
class InternVLImageEmbeddingInputs(TypedDict):
    """Image input supplied as precomputed embeddings."""

    # Discriminator distinguishing this variant from the pixel-value variant.
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either accepted form of image input: raw pixel values or precomputed
# image embeddings (told apart by their "type" key).
InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    """Return the torchvision pipeline applied to every image tile.

    The pipeline converts non-RGB images to RGB, resizes to an
    ``input_size`` x ``input_size`` square with bicubic interpolation,
    converts to a tensor and normalizes with the ImageNet mean / std.
    """

    def _ensure_rgb(img):
        # Leave RGB images untouched; convert every other mode.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    """Return the ``(cols, rows)`` entry of *target_ratios* nearest to
    *aspect_ratio*.

    Candidates are scanned in order. A strictly closer candidate always
    wins. A candidate that exactly ties the current best replaces it only
    when the image area exceeds half of the area that candidate's grid of
    ``image_size`` tiles would cover. ``(1, 1)`` is returned when
    *target_ratios* is empty.
    """
    image_area = width * height
    half_tile_area = 0.5 * image_size * image_size

    closest = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < smallest_gap:
            smallest_gap = gap
            closest = candidate
        elif gap == smallest_gap and image_area > half_tile_area * cols * rows:
            # Tie: prefer the larger grid only for sufficiently big images.
            closest = candidate
    return closest


101
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                         max_num: int, image_size: int,
                         use_thumbnail: bool) -> Tuple[int, int, int]:
    """Choose the tile grid for an image of the given size.

    Candidate grids are every ``(i, j)`` with ``min_num <= i * j <= max_num``,
    ordered by tile count; the one whose ``i / j`` is closest to
    ``orig_width / orig_height`` is selected via
    ``find_closest_aspect_ratio``.

    Returns:
        ``(blocks, target_width, target_height)`` where the target size is
        ``image_size`` times the chosen grid and ``blocks`` is the tile
        count, plus one when ``use_thumbnail`` is set and there is more than
        one tile.
    """
    aspect_ratio = orig_width / orig_height

    # enumerate the candidate (cols, rows) grids
    target_ratios = {(i, j)
                     for n in range(min_num, max_num + 1)
                     for i in range(1, n + 1) for j in range(1, n + 1)
                     if i * j <= max_num and i * j >= min_num}
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # add thumbnail image if num_blocks > 1
    if use_thumbnail and blocks > 1:
        blocks += 1

    return blocks, target_width, target_height


127
def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
                                 max_dynamic_patch: Optional[int] = None):
    """Bind ``calculate_num_blocks`` to the tiling settings of *hf_config*.

    ``max_dynamic_patch`` overrides ``hf_config.max_dynamic_patch`` when it
    is not ``None``. The returned partial still expects the original image
    width and height.
    """
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    min_num = hf_config.min_dynamic_patch
    image_size = hf_config.vision_config.image_size
    use_thumbnail = hf_config.use_thumbnail
    return partial(calculate_num_blocks,
                   min_num=min_num,
                   max_num=max_dynamic_patch,
                   image_size=image_size,
                   use_thumbnail=use_thumbnail)


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                       image_size: int,
                       use_thumbnail: bool) -> List[Image.Image]:
    """Split *image* into ``image_size`` x ``image_size`` tiles.

    The image is resized to the grid chosen by ``calculate_num_blocks`` and
    cropped into tiles in row-major order. When ``use_thumbnail`` is set and
    more than one tile was produced, the whole image resized to a single
    tile is appended as the last element.
    """
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_num_blocks(
        orig_width,
        orig_height,
        min_num,
        max_num,
        image_size,
        use_thumbnail=False)

    # resize the image
    resized_img = image.resize((target_width, target_height))

    # number of tiles per row; constant across the loop
    tiles_per_row = target_width // image_size
    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, tiles_per_row)
        box = (col * image_size, row * image_size,
               (col + 1) * image_size, (row + 1) * image_size)
        # split the image
        processed_images.append(resized_img.crop(box))

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
    """Tile *image* and return the transformed tiles stacked into one tensor.

    Tiles come from ``dynamic_preprocess``; each is passed through the
    transform built by ``build_transform(input_size)`` and the results are
    stacked along a new leading dimension.
    """
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image,
                               min_num=min_num,
                               max_num=max_num,
                               image_size=input_size,
                               use_thumbnail=use_thumbnail)
    # Distinct loop name so the `image` parameter is not shadowed.
    return torch.stack([transform(tile) for tile in tiles])


187
def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
                                  max_dynamic_patch: Optional[int] = None):
    """Bind ``image_to_pixel_values`` to the tiling settings of *hf_config*.

    ``max_dynamic_patch`` overrides ``hf_config.max_dynamic_patch`` when it
    is not ``None``. The returned partial still expects the image.
    """
    image_size = hf_config.vision_config.image_size
    min_num = hf_config.min_dynamic_patch
    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    return partial(image_to_pixel_values,
                   input_size=image_size,
                   min_num=min_num,
                   max_num=max_dynamic_patch,
                   use_thumbnail=use_thumbnail)


201
def get_internvl_num_patches(hf_config: PretrainedConfig):
    """Return the CLIP patch count for one tile scaled by
    ``downsample_ratio ** 2``, truncated to ``int``.

    Callers in this file multiply the result by the number of tiles to get
    an image's feature size.
    """
    vision_config = hf_config.vision_config
    downsample_ratio = hf_config.downsample_ratio
    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    return int(
        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
        (downsample_ratio**2))


211
212
213
def get_max_internvl_image_tokens(ctx: InputContext,
                                  *,
                                  max_dynamic_patch: Optional[int] = None):
    """Return the largest number of image feature tokens one image can use.

    This is ``get_internvl_num_patches(hf_config)`` times the maximum tile
    count. The tile count is ``max_dynamic_patch`` (defaulting to the value
    in the HF config), plus one for the thumbnail when ``use_thumbnail`` is
    set and more than one tile is allowed.
    """
    hf_config = ctx.get_hf_config()

    if max_dynamic_patch is None:
        max_dynamic_patch = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    if use_thumbnail and max_dynamic_patch > 1:
        max_dynamic_patch += 1

    num_patches = get_internvl_num_patches(hf_config)
    return num_patches * max_dynamic_patch
224
225


226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def get_max_internvl_image_size(ctx: InputContext,
                                *,
                                max_dynamic_patch: Optional[int] = None):
    """Return ``(width, height)`` of the widest single-row image layout.

    The height is one tile (``vision_config.image_size``). The width is one
    tile per allowed patch, where the patch count is ``max_dynamic_patch``
    (falling back to the HF config value) plus one when ``use_thumbnail`` is
    set and more than one patch is allowed.
    """
    hf_config = ctx.get_hf_config()
    tile_side = hf_config.vision_config.image_size

    num_tiles = (hf_config.max_dynamic_patch
                 if max_dynamic_patch is None else max_dynamic_patch)
    if hf_config.use_thumbnail and num_tiles > 1:
        num_tiles += 1

    return tile_side * num_tiles, tile_side


242
243
244
245
246
247
248
249
250
class InternVLInputPipeline:
    """Input processing for InternVL: prompt expansion, input mapping and
    dummy-data generation.

    The three constructor arguments are the literal prompt strings used to
    build the expanded image placeholder
    ``<start><context>*feature_size<end>``.
    """

    def __init__(
        self,
        img_start_token: str,
        img_end_token: str,
        img_context_token: str,
    ) -> None:
        super().__init__()

        self.img_start_token = img_start_token
        self.img_end_token = img_end_token
        self.img_context_token = img_context_token

    def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
        """Return the expanded placeholder text for one image.

        ``num_patches`` is accepted but not used by this implementation.
        """
        return (self.img_start_token + self.img_context_token * feature_size +
                self.img_end_token)

    def _expand_image_prompt(
        self,
        prompt: str,
        feature_sizes: List[int],
        num_patches: int,
    ) -> str:
        """Replace successive ``<image>`` markers in *prompt* with expanded
        image prompts, one per entry of *feature_sizes*.

        If the prompt contains no ``Image-N: <image>\\n`` labels, each
        expansion is prefixed with ``Image-{idx}: `` (1-based).
        """
        image_idx = sorted(
            map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))

        new_prompt = prompt
        for idx, feature_size in enumerate(feature_sizes, start=1):
            image_prompt = self._create_image_prompt(feature_size, num_patches)
            if not image_idx:
                image_prompt = f"Image-{idx}: {image_prompt}"

            # Replace only the first remaining marker on each pass.
            new_prompt = new_prompt.replace('<image>', image_prompt, 1)

        return new_prompt

    def _get_image_feature_sizes(
        self,
        image_data: object,
        num_patches: int,
        num_blocks_calculator,
    ) -> List[int]:
        """Return the number of context tokens needed for each image.

        Args:
            image_data: a 3-D embeddings tensor
                ``(num_images, image_feature_size, hidden_size)``, a single
                PIL image, or a list of PIL images.
            num_patches: feature count contributed by one image block.
            num_blocks_calculator: callable mapping ``(width, height)`` to
                ``(num_blocks, target_width, target_height)``; only used
                for PIL inputs.

        Raises:
            TypeError: if *image_data* is none of the supported types.
        """
        if isinstance(image_data, torch.Tensor):
            num_images, image_feature_size, _ = image_data.shape
            # One entry per embedded image so that every `<image>` marker
            # in a multi-image prompt gets expanded, not just the first.
            return [image_feature_size] * num_images

        if isinstance(image_data, Image.Image):
            width, height = image_data.size
            num_blocks, _, _ = num_blocks_calculator(width, height)
            return [num_blocks * num_patches]

        if is_list_of(image_data, Image.Image):
            image_feature_sizes = []
            for image in image_data:
                width, height = image.size
                num_blocks, _, _ = num_blocks_calculator(width, height)
                image_feature_sizes.append(num_blocks * num_patches)
            return image_feature_sizes

        raise TypeError(f"Invalid image type: {type(image_data)}")

    def input_processor(
        self,
        ctx: InputContext,
        inputs: DecoderOnlyInputs,
        *,
        max_dynamic_patch: Optional[int] = None,
    ) -> DecoderOnlyInputs:
        """Expand image placeholders in the prompt and re-tokenize it.

        Inputs without image data are returned unchanged. Otherwise the
        prompt (decoded from ``prompt_token_ids`` when absent) has each
        ``<image>`` marker expanded to the per-image feature size, and the
        result is encoded into new ``prompt_token_ids``. The returned
        ``prompt`` is the unexpanded text.

        Raises:
            TypeError: if the image data has an unsupported type.
        """
        multi_modal_data = inputs.get("multi_modal_data")
        if multi_modal_data is None or "image" not in multi_modal_data:
            return inputs

        model_config = ctx.model_config
        hf_config = ctx.get_hf_config()

        image_data = multi_modal_data["image"]
        num_patches = get_internvl_num_patches(hf_config)
        num_blocks_calculator = calculate_num_blocks_wrapper(
            hf_config, max_dynamic_patch)
        image_feature_sizes = self._get_image_feature_sizes(
            image_data, num_patches, num_blocks_calculator)

        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        prompt = inputs.get("prompt")
        prompt_token_ids = inputs["prompt_token_ids"]
        if prompt is None:
            prompt = tokenizer.decode(prompt_token_ids)

        new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
                                               num_patches)
        new_prompt_token_ids = tokenizer.encode(new_prompt)

        return token_inputs(prompt=prompt,
                            prompt_token_ids=new_prompt_token_ids,
                            multi_modal_data=multi_modal_data)

    def input_mapper(
        self,
        ctx: InputContext,
        data: object,
        *,
        max_dynamic_patch: Optional[int] = None,
    ):
        """Convert image data into the keyword inputs consumed by the model.

        PIL images (single or list) become ``pixel_values`` together with
        the encoded id of the image context token. Any other input is
        passed through unchanged as ``image_embeds``.
        """
        hf_config = ctx.get_hf_config()

        image_pixel_values_mapper = image_to_pixel_values_wrapper(
            hf_config, max_dynamic_patch)
        if isinstance(data, Image.Image):
            data = image_pixel_values_mapper(data)
            # Add an N dimension for number of images per prompt (currently 1).
            data = data.unsqueeze(0)
        elif is_list_of(data, Image.Image):
            # we can't stack here because images may have different num_patches
            data = [image_pixel_values_mapper(img) for img in data]
        else:
            return MultiModalInputs({"image_embeds": data})

        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)
        image_token_id = tokenizer.encode(self.img_context_token,
                                          add_special_tokens=False,
                                          return_tensors="pt")[0]

        return MultiModalInputs({
            "pixel_values": data,
            "image_token_id": image_token_id
        })

    def dummy_data(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        *,
        max_dynamic_patch: Optional[int] = None,
    ):
        """Build dummy sequence data and dummy images for ``mm_counts["image"]``
        images, sized by the maximum image token count and maximum image
        size for this configuration.

        Returns:
            ``(seq_data, mm_data)`` as produced by the CLIP dummy helpers.
        """
        num_images = mm_counts["image"]

        hf_config = ctx.get_hf_config()

        image_feature_size = get_max_internvl_image_tokens(
            ctx, max_dynamic_patch=max_dynamic_patch)
        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        seq_data = dummy_seq_data_for_clip(
            hf_config.vision_config,
            seq_len,
            num_images,
            image_token_id=tokenizer.encode(self.img_context_token,
                                            add_special_tokens=False)[0],
            image_feature_size_override=image_feature_size,
        )

        max_image_width, max_image_height = get_max_internvl_image_size(
            ctx, max_dynamic_patch=max_dynamic_patch)

        mm_data = dummy_image_for_clip(
            hf_config.vision_config,
            num_images,
            image_width_override=max_image_width,
            image_height_override=max_image_height,
        )

        return seq_data, mm_data


# Module-level pipeline instance; its bound methods are passed to the
# registry decorators applied to InternVLChatModel.
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
    """InternVL chat model: a vision model, an MLP projector (``mlp1``) and a
    registered language model whose input embeddings receive the projected
    image features at image-context token positions."""

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        self.llm_arch_name = config.text_config.architectures[0]
        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
        self.vision_model = self._init_vision_model(
            config,
            quant_config=quant_config,
            is_mono=self.is_mono,
            prefix="vision_model",
        )

        self.language_model = init_vllm_registered_model(
            config.text_config,
            cache_config,
            quant_config,
            prefix="language_model")

        self.mlp1 = self._init_mlp1(config)

        # Set from the `image_token_id` input when image inputs are parsed.
        self.img_context_token_id = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _patch_quant_config(self, config: PretrainedConfig,
                            quant_config: QuantizationConfig):
        """Add ``"vision_model"`` to an AWQ config's ``modules_to_not_convert``
        when that list is empty and the text config carries a
        ``quantization_config``."""
        # the awq models from OpenGVLab missing `modules_to_not_convert`
        # patch the quant_config to add `modules_to_not_convert` back
        if isinstance(quant_config, AWQConfig):
            text_config = config.text_config
            llm_quant_config = getattr(text_config, "quantization_config",
                                       None)
            if (not quant_config.modules_to_not_convert) and \
                (llm_quant_config is not None):
                quant_config.modules_to_not_convert.append("vision_model")

    @cached_property
    def sampler(self):
        """The language model's sampler if it has one, else a new Sampler."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        is_mono: bool,
        prefix: str,
    ):
        """Build the vision model.

        Non-mono models get an ``InternVisionModel`` limited to the layers
        needed for ``config.select_layer`` (negative values count from the
        end); mono models get an ``InternVisionPatchModel``.
        """
        if not is_mono:
            vision_feature_layer = config.select_layer
            if vision_feature_layer < 0:
                num_hidden_layers = config.vision_config.num_hidden_layers \
                    + vision_feature_layer + 1
            else:
                num_hidden_layers = vision_feature_layer + 1

            return InternVisionModel(
                config.vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                prefix=prefix,
            )
        else:
            return InternVisionPatchModel(config.vision_config)

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        """Build the projector from pixel-shuffled vision features to the
        language model hidden size: LayerNorm, Linear, GELU, Linear."""
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        # Channel width after pixel shuffle with the configured ratio.
        shuffled_size = vit_hidden_size * int(1 / self.downsample_ratio)**2

        return nn.Sequential(
            nn.LayerNorm(shuffled_size),
            nn.Linear(shuffled_size, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on a 4-D tensor.

        Both spatial dimensions are scaled by ``scale_factor`` and the
        channel dimension by ``1 / scale_factor ** 2``. For any
        ``ps_version`` other than ``'v1'`` the two spatial dimensions are
        swapped once more at the end.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version != 'v1':
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision model on *pixel_values* and project the result
        through ``mlp1``.

        The first token of the vision output is dropped, the remainder is
        reshaped to a square grid, pixel-shuffled with
        ``self.downsample_ratio`` and flattened again before projection.
        """
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every item of *data* has shape ``(3, h, w)`` with
        ``h == w == vision_config.image_size`` and return *data* unchanged.

        Raises:
            ValueError: if any item has a different shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        """Extract image inputs from the forward keyword arguments.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is present. ``image_embeds`` takes precedence over ``pixel_values``.
        Whenever ``image_token_id`` is supplied, its first element is stored
        in ``self.img_context_token_id``.

        Raises:
            ValueError: if ``image_embeds`` is not a tensor, if
                ``pixel_values`` is neither a tensor nor a list, or if
                ``pixel_values`` is given without ``image_token_id``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        # Record the placeholder id for either input form when it is
        # available, so the embeddings path can use it too.
        if image_token_id is not None:
            self.img_context_token_id = image_token_id[0]

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        if pixel_values is not None:
            # Fail with a clear message instead of subscripting None.
            if image_token_id is None:
                raise ValueError(
                    "`image_token_id` must be provided together with "
                    "`pixel_values`.")

            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            # We need to flatten (B, N, P) to (B*N*P),
            # so we call flatten_bn twice.
            return InternVLImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(flatten_bn(pixel_values), concat=True)),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> torch.Tensor:
        """Return image embeddings: pass-through for ``image_embeds`` inputs,
        otherwise computed from pixel values via ``extract_feature``."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_embeds = self.extract_feature(image_input["data"])

        return image_embeds

    def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """For mono models, return a column mask marking image-context token
        positions in *input_ids*; otherwise return ``None``."""
        if self.is_mono:
            visual_token_mask = (
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
            visual_token_mask = None
        return visual_token_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[SamplerOutput, IntermediateTensors]:
        """Run the language model, merging image embeddings into the input
        embeddings when image inputs are present.

        When ``intermediate_tensors`` is given, token ids and image inputs
        are ignored. The ``visual_token_mask`` keyword is only forwarded to
        the language model for mono models.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
            visual_token_mask = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is not None:
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.img_context_token_id)
                # Must be computed before input_ids is cleared below.
                visual_token_mask = self._get_visual_token_mask(input_ids)
                input_ids = None
            else:
                inputs_embeds = None
                visual_token_mask = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }
        if self.is_mono:
            forward_kwargs.update({"visual_token_mask": visual_token_mask})

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load *weights* into this module via ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights)