# llava_next.py: LLaVA-NeXT multimodal model definition.
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                    merge_multimodal_embeddings)

logger = init_logger(__name__)

# Prefix remapping for checkpoint parameter names.
# NOTE(review): not referenced by any code visible in this file; confirm
# whether it is still needed.
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

# Input size used when computing the maximum number of image tokens.
# Results in the max possible feature size (2x2 grid of 336x336px tiles).
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448

46
47
48

class LlavaNextImagePixelInputs(TypedDict):
    """Image input given as raw pixel values, tagged `type="pixel_values"`."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """
64
65


66
67
68
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings (`type="image_embeds"`)."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# An image input is either raw pixel values or precomputed embeddings; the
# two variants are told apart by their `type` field.
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]
77
78


79
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
80
def _get_llava_next_num_unpadded_features(
81
82
    original_height: int,
    original_width: int,
83
84
85
86
87
88
89
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

90
    original_aspect_ratio = original_width / original_height
91
92
    current_aspect_ratio = current_width / current_height

93
94
95
    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
96
        padding = (current_height - new_height) // 2
97
        current_height -= 2 * padding
98
    else:
99
100
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
101
        padding = (current_width - new_width) // 2
102
        current_width -= 2 * padding
103
104
105
106
107
108

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


109
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
110
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return the number of image features for an image of the given size.

    The result is the sum of the unpadded high-resolution grid features, one
    newline feature per remaining grid row, and the base image features.

    Args:
        hf_config: LLaVA-NeXT config providing the vision config, the feature
            select strategy and the grid pinpoints.
        input_height: Height of the input image.
        input_width: Width of the input image.

    Raises:
        NotImplementedError: If the vision config is neither a CLIP nor a
            SigLIP vision config.
        ValueError: If `vision_feature_select_strategy` is neither
            `"default"` nor `"full"`.
    """
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_clip_image_feature_size(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_patches = get_siglip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_siglip_image_feature_size(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    # The "default" strategy drops one feature from the base count; this
    # mirrors `_select_image_features`, which slices off the first feature.
    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        base_feature_size -= 1
    elif strategy == "full":
        pass
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    (
        unpadded_feature_size,
        newline_feature_size,
    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                              num_patches, num_patch_height,
                                              num_patch_width)

    return unpadded_feature_size + newline_feature_size + base_feature_size
156
157


158
159
160
def get_max_llava_next_image_tokens(ctx: InputContext) -> int:
    """Return the image feature size for the module's maximum input size.

    Evaluates `get_llava_next_image_feature_size` at
    `MAX_IMAGE_FEATURE_SIZE_HEIGHT` x `MAX_IMAGE_FEATURE_SIZE_WIDTH` using the
    LLaVA-NeXT HF config obtained from `ctx`.
    """
    return get_llava_next_image_feature_size(
        ctx.get_hf_config(LlavaNextConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


166
167
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    """Build dummy sequence data and dummy images for LLaVA-NeXT.

    The dummy inputs use the maximum image token count and the maximum image
    size defined in this module, and delegate to the CLIP or SigLIP helpers
    depending on the vision config type.

    Args:
        ctx: Input context giving access to the HF config.
        seq_len: Sequence length passed through to the dummy sequence helper.
        mm_counts: Per-modality item counts; the `"image"` entry is used.

    Returns:
        A `(seq_data, mm_data)` pair produced by the vision-specific helpers.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_next_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


213
214
215
216
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    """Work out the per-image feature size and delegate to the vision helper.

    Inputs without `"image"` multi-modal data are returned unchanged.
    Otherwise the image feature size is derived from the image data and passed
    as an override to the CLIP or SigLIP input processor:

    * a PIL image: size computed from its width and height;
    * a list of PIL images: one size per image;
    * a tensor: the middle entry of its 3-dimensional shape;
    * a list of tensors: `shape[1]` of each item.

    Raises:
        TypeError: If the image data is none of the supported types.
        NotImplementedError: If the vision config type is not supported.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [
            get_llava_next_image_feature_size(hf_config,
                                              input_height=img.height,
                                              input_width=img.width)
            for img in image_data
        ]
    elif isinstance(image_data, torch.Tensor):
        # Three-way unpacking is kept on purpose: it still rejects tensors
        # that are not exactly 3-dimensional.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaNextConfig):
    """Create the CLIP or SigLIP vision tower for the given config.

    The tower is built with only as many hidden layers as are needed to reach
    `hf_config.vision_feature_layer`; a negative layer index counts from the
    end of the vision encoder.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
            + vision_feature_layer + 1
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
292
293


294
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
    """LLaVA-NeXT: vision tower + multi-modal projector + language model.

    Image inputs are encoded by the vision tower, projected to the text hidden
    size, merged per image (with a learned `image_newline` embedding) and
    substituted into the input embeddings at the image token positions.
    """

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the vision tower, projector, language model and newline.

        Args:
            config: LLaVA-NeXT HF config (vision, text and projector options).
            multimodal_config: Multi-modal settings, stored on the instance.
            cache_config: Passed through to the language model initializer.
            quant_config: Passed through to the language model initializer.
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        # Uninitialized here; filled in by `load_weights`.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every entry of `data` has shape `(2,)`.

        Returns:
            `data`, unchanged.

        Raises:
            ValueError: If any entry has a different shape.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check the trailing dimensions of each image's pixel values.

        Each element of `data` must have shape `(num_patches, 3, h, w)` where
        `h` and `w` both equal the vision config's `image_size`.

        Returns:
            `data`, unchanged.

        Raises:
            ValueError: If any element's trailing dimensions differ.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            # The leading dimension (number of patches) may vary per image.
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Extract and validate the image input from forward kwargs.

        Reads `pixel_values`, `image_sizes` and `image_embeds` from `kwargs`.
        Pixel values take precedence over embeddings when both are given.

        Returns:
            `None` when neither pixel values nor embeddings are supplied,
            otherwise the corresponding typed input dict.

        Raises:
            ValueError: If a supplied value has an unexpected type, or if
                validation of pixel values / image sizes fails.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # NOTE: image sizes are required whenever pixel values are given.
            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy to the vision tower output.

        `"default"` drops the first feature along dimension 1; `"full"` keeps
        all features.

        Raises:
            ValueError: If `strategy` is neither `"default"` nor `"full"`.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the configured select strategy."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge the patch embeddings of one image into a single sequence.

        Args:
            image_size: Original `(height, width)` of the image.
            patch_embeddings: Embeddings of the image's patches; index 0 is
                treated as the base patch and the rest as grid patches.
            strategy: `"flat"` flattens the first two dimensions. Strategies
                starting with `"spatial"` arrange the grid patches spatially;
                if the name also contains `"unpad"`, padding is removed and an
                `image_newline` embedding is appended to each row.

        Raises:
            ValueError: If the base patch length does not match the patch grid
                implied by the vision config, or if `strategy` is unknown.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches] \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # Rearrange to (hidden, grid_height, grid_width) so the
                    # padding can be cropped spatially.
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # Append one newline embedding at the end of every row.
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                # Only the base patch is present.
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Encode pixel values into projected patch embeddings.

        A batched 5-D tensor is flattened over its first two dimensions,
        encoded in one pass and reshaped back. A list of per-image tensors is
        concatenated for the vision tower, then split back per image before
        projection.

        Returns:
            A tensor when the input is a tensor, otherwise a list with one
            tensor per image.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Variable number of patches per image: remember the split sizes.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Turn a validated image input into embeddings ready for merging.

        Precomputed embeddings are returned as a single-element list. Pixel
        inputs are encoded and then merged per image using the
        `"spatial_unpad"` strategy.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # Fall back to the vision config's square image size per image.
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LlaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        See also:
            :class:`LlavaNextImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The language model receives embeddings instead of token ids.
            input_ids = None
        else:
            inputs_embeds = None

        # NOTE: `intermediate_tensors` is accepted but not forwarded; `None`
        # is passed in its place.
        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into each component of the model.

        The weight iterable is duplicated four ways and each copy is filtered
        by prefix: `vision_tower`, `multi_modal_projector`, `image_newline`
        and `language_model`.

        Args:
            weights: Iterable of `(name, tensor)` pairs.
        """
        # prepare weight iterators for components
        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
            weights, 4)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load newline
        newline_weights = filter_weights(newline_weights, "image_newline")
        for name, loaded_weight in newline_weights:
            # Name must be empty after prefix filtering (bare parameter).
            assert name == ""
            param = self.image_newline
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)