llava_next.py 26.5 KB
Newer Older
1
from functools import cached_property
2
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
3
                    TypedDict, Union)
4
5
6

import torch
import torch.nn as nn
7
from PIL import Image
8
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
9
10
11
12
13
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
14
from vllm.config import VllmConfig
15
16
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext)
Joe Runde's avatar
Joe Runde committed
17
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
18
from vllm.model_executor.sampling_metadata import SamplingMetadata
19
from vllm.multimodal import MULTIMODAL_REGISTRY
20
from vllm.multimodal.inputs import NestedTensors
21
from vllm.sequence import IntermediateTensors
22
from vllm.utils import is_list_of
23

24
25
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
26
                   get_clip_patch_grid_length, input_processor_for_clip)
27
from .interfaces import SupportsMultiModal, SupportsPP
28
from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
29
30
31
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
Cyrus Leung's avatar
Cyrus Leung committed
32
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
33
                    init_vllm_registered_model, maybe_prefix)
34
35
36
37


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
38
    data: Union[torch.Tensor, List[torch.Tensor]]
39
    """
40
41
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
42

43
44
    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
45
    """
46
47

    image_sizes: NotRequired[torch.Tensor]
48
    """
49
    Shape: `(batch_size * num_images, 2)`
50
51
52

    This should be in `(height, width)` format.
    """
53
54


55
56
57
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Image input variant carrying precomputed image embeddings."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# An image input is either raw pixel patches or precomputed embeddings; the
# two variants are told apart by the literal value of their "type" key.
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]
66
67


68
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
69
def _get_llava_next_num_unpadded_features(
70
71
    original_height: int,
    original_width: int,
72
73
74
75
76
77
78
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

79
    original_aspect_ratio = original_width / original_height
80
81
    current_aspect_ratio = current_width / current_height

82
83
84
    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
85
        padding = (current_height - new_height) // 2
86
        current_height -= 2 * padding
87
    else:
88
89
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
90
        padding = (current_width - new_width) // 2
91
        current_width -= 2 * padding
92
93
94
95
96
97

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


98
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return the number of image features produced for one input image.

    The total is the sum of the unpadded grid features, one newline feature
    per grid row, and the feature size of the base image.

    Args:
        hf_config: LLaVA-NeXT config; its `vision_config` must be a CLIP or
            SigLIP vision config.
        input_height: Height of the input image.
        input_width: Width of the input image.

    Raises:
        NotImplementedError: If the vision config type is not supported.
        ValueError: If `vision_feature_select_strategy` is neither
            `"default"` nor `"full"`.
    """
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_clip_image_feature_size(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_patches = get_siglip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_siglip_image_feature_size(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        # The "default" strategy yields one feature fewer than "full".
        base_feature_size -= 1
    elif strategy != "full":
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    unpadded_feature_size, newline_feature_size = (
        _get_llava_next_num_unpadded_features(input_height, input_width,
                                              num_patches, num_patch_height,
                                              num_patch_width))

    return unpadded_feature_size + newline_feature_size + base_feature_size
145
146


147
def get_max_llava_next_image_tokens(ctx: InputContext) -> int:
    """Compute the max feature size for all possible image grid pinpoints."""
    # The helper returns (feature_size, pinpoint); only the size is needed.
    return _get_pinpoint_with_largest_features(ctx)[0]


def _get_pinpoint_with_largest_features(
        ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
    """Find the grid pinpoint that yields the most image features.

    Returns:
        `(feature_size, (height, width))` for the first pinpoint reaching
        the maximum feature size.

    Raises:
        ValueError: If no pinpoint yields a positive feature size.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)

    best_size = 0
    best_pinpoint = None
    for (pin_height, pin_width) in hf_config.image_grid_pinpoints:
        candidate = get_llava_next_image_feature_size(
            hf_config,
            input_height=pin_height,
            input_width=pin_width,
        )
        # Strict comparison: on ties the earliest pinpoint wins.
        if candidate <= best_size:
            continue
        best_size, best_pinpoint = candidate, (pin_height, pin_width)

    if best_pinpoint is None or not best_size:
        raise ValueError("Cannot have a largest feature size of 0!")
    return best_size, best_pinpoint
170
171


172
173
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    """Build dummy inputs sized for the largest possible image.

    The dummy sequence and image use the grid pinpoint that produces the
    most image features, as found by `_get_pinpoint_with_largest_features`.

    Args:
        ctx: Input context used to fetch the LLaVA-NeXT HF config.
        seq_len: Sequence length forwarded to the dummy sequence builder.
        mm_counts: Per-modality item counts; only `"image"` is read.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx)
    max_feat_height, max_feat_width = pinpoint

    # Both backbones expose the same pair of helpers with the same signature.
    if isinstance(vision_config, CLIPVisionConfig):
        make_seq_data = dummy_seq_data_for_clip
        make_image = dummy_image_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        make_seq_data = dummy_seq_data_for_siglip
        make_image = dummy_image_for_siglip
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    seq_data, ranges = make_seq_data(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )

    mm_data = make_image(
        vision_config,
        num_images,
        image_width_override=max_feat_width,
        image_height_override=max_feat_height,
    )

    return DummyData(seq_data, mm_data, ranges)


220
221
222
def input_processor_for_llava_next(ctx: InputContext,
                                   inputs: DecoderOnlyInputs):
    """Hand the inputs to the CLIP/SigLIP processor with the right feature size.

    Inputs without image data are returned unchanged. Otherwise the image
    feature size (a single int, or a list with one entry per image) is
    derived from the image data and passed to the backbone-specific input
    processor as `image_feature_size_override`.

    Raises:
        TypeError: If the image data is not a PIL image, a tensor, or a list
            of either.
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [
            get_llava_next_image_feature_size(hf_config,
                                              input_height=img.height,
                                              input_width=img.width)
            for img in image_data
        ]
    elif isinstance(image_data, torch.Tensor):
        # Must be exactly 3-D; only the middle dimension is used.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        process_inputs = input_processor_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        process_inputs = input_processor_for_siglip
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return process_inputs(
        model_config,
        vision_config,
        inputs,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )


276
@MULTIMODAL_REGISTRY.register_image_input_mapper()
277
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
278
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
279
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
280
281
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
282

283
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Supplies the HF config, the quantization config and
                the multimodal config.
            prefix: Prefix prepended (via `maybe_prefix`) to the names of the
                vision tower and the language model.

        Raises:
            TypeError: If `config.vision_feature_layer` is neither an int nor
                a list/tuple.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        vision_feature_layer = config.vision_feature_layer
        # A single layer index keeps the vision hidden size as-is; a sequence
        # of layers multiplies it by the number of sampled layers.
        if isinstance(vision_feature_layer, int):
            vision_hidden_size = config.vision_config.hidden_size
            self.feature_sample_layers = None
        # Used for multimodal granite models to control encoder outputs
        elif isinstance(vision_feature_layer, (list, tuple)):
            vision_hidden_size = config.vision_config.hidden_size * len(
                vision_feature_layer)
            self.feature_sample_layers = vision_feature_layer
        else:
            raise TypeError(
                f"vision_feature_layer type: {type(vision_feature_layer)}"
                " is not supported")

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        # Left uninitialised here (torch.empty).
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

Joe Runde's avatar
Joe Runde committed
334
        return get_sampler()
335
336

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
337
338
339
340
341
342
343
344
345
346
347
348
349
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)
350
351
352

        return data

353
354
355
356
    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

357
358
359
360
361
362
363
364
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
365
                raise ValueError(
366
                    "The expected shape of pixel values per image per batch "
367
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
368

369
370
        for d in data:
            _validate_shape(d)
371
372
373

        return data

374
    def _parse_and_validate_image_input(
375
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
376
377
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
378
        image_embeds = kwargs.pop("image_embeds", None)
379

380
        if pixel_values is None and image_embeds is None:
381
            return None
382

383
384
385
386
        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
387

388
            if not isinstance(image_sizes, (torch.Tensor, list)):
389
390
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")
391

392
393
            return LlavaNextImagePixelInputs(
                type="pixel_values",
394
395
396
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
397
398
399
400
401
402
403
404
405
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
406
                data=flatten_bn(image_embeds),
407
408
409
            )

        raise AssertionError("This line should be unreachable.")
410

Cyrus Leung's avatar
Cyrus Leung committed
411
412
413
414
415
416
417
418
419
420
    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

421
422
423
424
425
    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
Cyrus Leung's avatar
Cyrus Leung committed
426

427
428
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
429
430
        image_features = vision_tower(
            pixel_values, feature_sample_layers=self.feature_sample_layers)
Cyrus Leung's avatar
Cyrus Leung committed
431
432
433
434
435
436

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

437
    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

457
458
459
                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

460
                # image_aspect_ratio == "anyres"
461
462
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
463
464
465
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
466
467
468
469
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches] \
470
                    .view(num_patch_height, num_patch_width, height, width, -1)
471
472
473
474
475
476

                if "unpad" in strategy:
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
477
                                                     (orig_height, orig_width))
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
508
509
        self,
        inputs: LlavaNextImagePixelInputs,
510
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
511
512
513
514
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

515
516
517
518
519
520
521
        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)
522

523
524
525
526
527
            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
528
529
530
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

531
532
533
534
        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]
535
536

    def _process_image_input(
537
538
539
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
540
541
542
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

543
        patch_embeddings = self._process_image_pixels(image_input)
544
545
546

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
547
            batch_size = len(image_input["data"])
548
            vision_config = self.config.vision_config
549
550
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
551
552
                                           for _ in range(batch_size)])

553
        return [
554
            self._merge_image_patch_embeddings(image_sizes[i],
555
                                               patch_features_batch,
556
                                               strategy="spatial_unpad")
557
            for i, patch_features_batch in enumerate(patch_embeddings)
558
559
        ]

560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:

        if multimodal_embeddings is None:
            return self.language_model.get_input_embeddings(input_ids)

        inputs_embeds = embed_multimodal(
            input_ids,
            self.config.image_token_index,
            self.language_model.model.get_input_embeddings,
            multimodal_embeddings,
        )
        return inputs_embeds

584
585
586
587
588
589
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
590
        intermediate_tensors: Optional[IntermediateTensors] = None,
591
        inputs_embeds: Optional[torch.Tensor] = None,
592
        **kwargs: object,
593
    ) -> Union[torch.Tensor, IntermediateTensors]:
Cyrus Leung's avatar
Cyrus Leung committed
594
        """Run forward pass for LlaVA-NeXT.
595
596
597

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.
598

599
        Concretely, consider a text prompt:
600
601
602
603
604
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

605
        Tokenizer outputs:
606
607
608
609
610
611
612
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
613
        before they are inputted to the model, so the input processor prepends
614
615
616
617
618
619
620
621
622
623
624
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.
625
626
627
628
629
630
631

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
Cyrus Leung's avatar
Cyrus Leung committed
632
            pixel_values: The pixels in each grid patch for each input image.
633
            image_sizes: The original `(height, width)` for each input image.
634

Cyrus Leung's avatar
Cyrus Leung committed
635
        See also:
636
            :class:`LlavaNextImageInputs`
637
        """
638
639
        if intermediate_tensors is not None:
            inputs_embeds = None
640

641
642
643
644
645
646
647
        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None
648

649
650
651
652
        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
653
                                                  intermediate_tensors,
654
                                                  inputs_embeds=inputs_embeds)
655
656
        return hidden_states

657
658
659
660
661
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
662
663
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)
664
665
666
667
668
669

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
670
        return self.language_model.sample(logits, sampling_metadata)
671

672
673
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `(name, tensor)` weight pairs via `AutoWeightsLoader`.

        Returns:
            Whatever `AutoWeightsLoader.load_weights` returns (annotated as a
            set of names).
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)