"tests/entrypoints/openai/test_vision.py" did not exist on "bc34937d68e9715d8416457539fb528301cf6269"
llava_next.py 25.6 KB
Newer Older
1
from functools import cached_property
2
3
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)
4
5
6

import torch
import torch.nn as nn
7
from PIL import Image
8
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
9
10
11
12
13
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
14
from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig
15
16
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext)
Cyrus Leung's avatar
Cyrus Leung committed
17
from vllm.model_executor.layers.pooler import Pooler, PoolingType
18
from vllm.model_executor.layers.quantization import QuantizationConfig
19
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
Cyrus Leung's avatar
Cyrus Leung committed
20
from vllm.model_executor.pooling_metadata import PoolingMetadata
21
from vllm.model_executor.sampling_metadata import SamplingMetadata
22
from vllm.multimodal import MULTIMODAL_REGISTRY
Cyrus Leung's avatar
Cyrus Leung committed
23
from vllm.sequence import IntermediateTensors, PoolerOutput
24
from vllm.utils import is_list_of
25

26
27
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
28
                   get_clip_patch_grid_length, input_processor_for_clip)
29
from .interfaces import SupportsMultiModal, SupportsPP
30
from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
31
32
33
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
Cyrus Leung's avatar
Cyrus Leung committed
34
35
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
                    init_vllm_registered_model)
36
37
38
39


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
40
    data: Union[torch.Tensor, List[torch.Tensor]]
41
    """
42
43
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
44

45
46
    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
47
    """
48
49

    image_sizes: NotRequired[torch.Tensor]
50
    """
51
    Shape: `(batch_size * num_images, 2)`
52
53
54

    This should be in `(height, width)` format.
    """
55
56


57
58
59
class LlavaNextImageEmbeddingInputs(TypedDict):
    # Discriminator tag distinguishing this variant from the pixel input.
    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
    #
    # `hidden_size` must match the hidden size of the language model backbone.
    data: torch.Tensor


# Either raw pixel patches or precomputed image embeddings; consumers tell the
# two apart at runtime through the `type` key.
LlavaNextImageInputs = Union[
    LlavaNextImagePixelInputs,
    LlavaNextImageEmbeddingInputs,
]
68
69


70
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
71
def _get_llava_next_num_unpadded_features(
72
73
    original_height: int,
    original_width: int,
74
75
76
77
78
79
80
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

81
    original_aspect_ratio = original_width / original_height
82
83
    current_aspect_ratio = current_width / current_height

84
85
86
    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
87
        padding = (current_height - new_height) // 2
88
        current_height -= 2 * padding
89
    else:
90
91
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
92
        padding = (current_width - new_width) // 2
93
        current_width -= 2 * padding
94
95
96
97
98
99

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


100
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
101
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return how many image features (placeholder tokens) one image needs.

    The total is the sum of:
      * the anyres grid features left after removing padding,
      * one newline feature per row of that trimmed grid, and
      * the base image features reported for the vision tower, minus one
        when ``vision_feature_select_strategy`` is ``"default"``.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
        ValueError: If the feature select strategy is neither ``"default"``
            nor ``"full"``.
    """
    vision_config = hf_config.vision_config

    # Pick the tower-specific helpers first; both towers share the same math.
    if isinstance(vision_config, CLIPVisionConfig):
        grid_length_fn = get_clip_patch_grid_length
        feature_size_fn = get_clip_image_feature_size
    elif isinstance(vision_config, SiglipVisionConfig):
        grid_length_fn = get_siglip_patch_grid_length
        feature_size_fn = get_siglip_image_feature_size
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    num_patches = grid_length_fn(
        image_size=vision_config.image_size,
        patch_size=vision_config.patch_size,
    )
    base_feature_size = feature_size_fn(vision_config)

    strategy = hf_config.vision_feature_select_strategy
    if strategy not in ("default", "full"):
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
    if strategy == "default":
        # The "default" strategy drops the first feature of each image.
        base_feature_size -= 1

    grid_shape = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )
    num_patch_height, num_patch_width = grid_shape

    unpadded_feature_size, newline_feature_size = (
        _get_llava_next_num_unpadded_features(
            input_height,
            input_width,
            num_patches,
            num_patch_height,
            num_patch_width,
        ))

    return unpadded_feature_size + newline_feature_size + base_feature_size
147
148


149
def get_max_llava_next_image_tokens(ctx: InputContext):
    """Return the largest feature size over all configured grid pinpoints.

    This is the value registered as the model's max image token count.
    """
    max_feature_size, _ = _get_pinpoint_with_largest_features(ctx)
    return max_feature_size


def _get_pinpoint_with_largest_features(
        ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
    """Find the grid pinpoint producing the most image features.

    Returns:
        ``(feature_size, (height, width))`` for the entry of
        ``hf_config.image_grid_pinpoints`` with the largest feature size.
        On ties, the earliest pinpoint wins.

    Raises:
        ValueError: If no pinpoint yields a positive feature size, including
            when the pinpoint list is empty.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)

    # Lazily pair every pinpoint with its feature size, in config order.
    scored_pinpoints = ((
        get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        ),
        (height, width),
    ) for height, width in hf_config.image_grid_pinpoints)

    # max() returns the first maximal item, matching a strict ">" scan.
    best = max(scored_pinpoints, key=lambda item: item[0], default=None)
    if best is None or best[0] <= 0:
        raise ValueError("Cannot have a largest feature size of 0!")
    return best
172
173


174
175
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    """Build dummy sequence and image data for LLaVA-NeXT.

    The dummy images use the grid pinpoint with the largest feature size, and
    the dummy sequence reserves that many placeholder tokens per image.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    feature_size, (target_height, target_width) = \
        _get_pinpoint_with_largest_features(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        make_seq_data = dummy_seq_data_for_clip
        make_images = dummy_image_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        make_seq_data = dummy_seq_data_for_siglip
        make_images = dummy_image_for_siglip
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    seq_data, ranges = make_seq_data(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=feature_size,
    )
    mm_data = make_images(
        vision_config,
        num_images,
        image_width_override=target_width,
        image_height_override=target_height,
    )
    return DummyData(seq_data, mm_data, ranges)


222
223
224
def input_processor_for_llava_next(ctx: InputContext,
                                   inputs: DecoderOnlyInputs):
    """Expand image placeholder tokens in ``inputs`` for LLaVA-NeXT.

    The number of placeholder tokens per image is computed here and then
    passed as ``image_feature_size_override`` to the CLIP/SigLIP input
    processor that performs the actual expansion:

    * PIL image(s): computed from each image's size via
      :func:`get_llava_next_image_feature_size`.
    * Embedding tensor(s): taken from dim 1 (a single tensor is unpacked as
      ``(num_images, feature_size, hidden_size)``).

    Returns ``inputs`` unchanged when there is no ``"image"`` multimodal data.

    Raises:
        TypeError: If the image data is of an unsupported type.
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size
        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [
            get_llava_next_image_feature_size(hf_config,
                                              input_height=img.height,
                                              input_width=img.width)
            for img in image_data
        ]
    elif isinstance(image_data, torch.Tensor):
        # Unpacking (rather than indexing) keeps the 3-D shape requirement.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        tower_input_processor = input_processor_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_input_processor = input_processor_for_siglip
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return tower_input_processor(
        model_config,
        vision_config,
        inputs,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )


278
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """LLaVA-NeXT: a vision tower and projector feeding a language model.

    Every image patch is encoded by the vision tower and projected to the
    language model's hidden size. The per-image patch features are then
    merged spatially, with padding removed and a learned ``image_newline``
    feature appended per row. Finally, they are spliced into the token
    embeddings at the image placeholder positions.
    """

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 pooler_config: Optional[PoolerConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        text_hidden_size = config.text_config.hidden_size

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix="vision_tower",
        )
        self.image_newline = nn.Parameter(torch.empty(text_hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=text_hidden_size,
            projector_hidden_act=config.projector_hidden_act,
        )
        self.language_model = init_vllm_registered_model(
            config.text_config,
            cache_config,
            quant_config,
            prefix="language_model",
        )

        # The same model class supports both language generation and embedding
        # because the architecture name is the same
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=True,
            softmax=False,
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """The language model's own sampler if it has one, else a new one."""
        return (self.language_model.sampler
                if hasattr(self.language_model, "sampler") else Sampler())

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check every per-image size entry has shape ``(2,)``."""
        expected_dims = (2, )
        expected_expr = str(expected_dims)

        for size_entry in data:
            if tuple(size_entry.shape) == expected_dims:
                continue
            raise ValueError(
                f"The expected shape of image sizes per image per batch "
                f"is {expected_expr}. You supplied {tuple(size_entry.shape)}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check each image's patches have shape ``(*, 3, size, size)``.

        ``size`` is ``config.vision_config.image_size``; the leading patch
        count is not checked.
        """
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)

        for image_patches in data:
            if tuple(image_patches.shape[1:]) == expected_dims:
                continue
            expected_expr = ("num_patches", *map(str, expected_dims))
            raise ValueError(
                "The expected shape of pixel values per image per batch "
                f"is {expected_expr}. You supplied {tuple(image_patches.shape)}."
            )

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Pop image kwargs and wrap them in the matching input TypedDict.

        ``pixel_values`` takes precedence over ``image_embeds``; ``None`` is
        returned when neither is given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            checked_pixels = self._validate_pixel_values(
                flatten_bn(pixel_values))
            checked_sizes = self._validate_image_sizes(
                flatten_bn(image_sizes, concat=True))
            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=checked_pixels,
                image_sizes=checked_sizes,
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")
            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        return None

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the vision feature select strategy.

        ``"default"`` drops the first feature along dim 1; ``"full"`` keeps
        all of them.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode pixels with the vision tower and select features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        tower_output = vision_tower(pixel_values)
        return self._select_image_features(
            tower_output,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge one image's patch embeddings into a single feature sequence.

        ``patch_embeddings[0]`` holds the base image features and any further
        entries hold the anyres grid tiles. With a ``"spatial..."`` strategy
        the tiles are arranged back into a 2-D grid. When the strategy also
        contains ``"unpad"``, the grid is un-padded to ``image_size``
        ``(height, width)`` and ``image_newline`` is appended per row.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if not strategy.startswith("spatial"):
            raise ValueError(f"Unexpected patch merge strategy: {strategy}")

        vision_config = self.config.vision_config
        grid_len = vision_config.image_size // vision_config.patch_size
        use_unpad = "unpad" in strategy

        base_patch_embeds = patch_embeddings[0]
        if grid_len * grid_len != base_patch_embeds.shape[0]:
            raise ValueError(
                "The number of patches is not consistent with the "
                "image size.")

        # Only the base image: optionally terminate it with a newline.
        if patch_embeddings.shape[0] <= 1:
            if not use_unpad:
                return base_patch_embeds
            newline = self.image_newline[None].to(base_patch_embeds.device)
            return torch.cat((base_patch_embeds, newline), dim=0)

        # Move to CPU to avoid floating-point errors
        orig_height, orig_width = image_size.tolist()

        # image_aspect_ratio == "anyres"
        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            (orig_height, orig_width),
            self.config.image_grid_pinpoints,
            vision_config.image_size,
        )
        num_patches = num_patch_height * num_patch_width

        # Image patches might be padded for batch processing
        tiles = patch_embeddings[1:][:num_patches].view(
            num_patch_height, num_patch_width, grid_len, grid_len, -1)

        if use_unpad:
            tiles = tiles.permute(4, 0, 2, 1, 3).contiguous()
            tiles = tiles.flatten(1, 2).flatten(2, 3)
            tiles = unpad_image(tiles, (orig_height, orig_width))
            row_newlines = self.image_newline[:, None, None] \
                .expand(*tiles.shape[:-1], 1) \
                .to(tiles.device)
            tiles = torch.cat((tiles, row_newlines), dim=-1)
            tiles = tiles.flatten(1, 2).transpose(0, 1)
        else:
            tiles = tiles.permute(0, 2, 1, 3, 4).contiguous().flatten(0, 3)

        return torch.cat((base_patch_embeds, tiles), dim=0)

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Encode and project every patch of every image.

        A batched tensor input yields a tensor shaped
        ``(batch, num_patches, ...)``; a list input yields one tensor per
        image.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            flat_pixels = pixel_values.view(b * num_patches, c, h, w)
            flat_features = self._image_pixels_to_features(
                self.vision_tower, flat_pixels)
            flat_embeds = self.multi_modal_projector(flat_features)
            return flat_embeds.view(b, num_patches, *flat_embeds.shape[1:])

        # Variable patch counts: encode everything in one pass, then split.
        patch_counts = [item.shape[0] for item in pixel_values]
        flat_features = self._image_pixels_to_features(
            self.vision_tower, torch.cat(pixel_values))
        per_image_features = torch.split(flat_features, patch_counts)
        return [
            self.multi_modal_projector(features)
            for features in per_image_features
        ]

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Turn parsed image inputs into per-image embedding sequences."""
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # Fall back to the vision tower's native square resolution.
            batch_size = len(image_input["data"])
            default_side = self.config.vision_config.image_size
            image_sizes = torch.as_tensor(
                [[default_side, default_side] for _ in range(batch_size)])

        merged = []
        for idx, per_image_patches in enumerate(patch_embeddings):
            merged.append(
                self._merge_image_patch_embeddings(image_sizes[idx],
                                                   per_image_patches,
                                                   strategy="spatial_unpad"))
        return merged

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the LLaVA-NeXT forward pass.

        ``input_ids`` already account for the image embeddings: to reserve
        KV-cache space, the input processor inserts additional image
        placeholder tokens (``config.image_token_index``) before the model
        runs. This keeps ``positions`` and ``attn_metadata`` consistent with
        ``input_ids``. Unlike LLaVA-1.5, the number of placeholder tokens
        depends on the original image size and is given by
        :func:`get_llava_next_image_feature_size`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions, kv_caches, attn_metadata: Forwarded to the language
                model.
            intermediate_tensors: When provided, ``input_ids`` and image
                kwargs are ignored and these tensors are forwarded instead.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original ``(height, width)`` for each input
                image.
            image_embeds: Precomputed image embeddings, used when
                ``pixel_values`` is absent.

        See also:
            :class:`LlavaNextImageInputs`
        """
        inputs_embeds = None

        if intermediate_tensors is not None:
            input_ids = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is not None:
                inputs_embeds = embed_multimodal(
                    input_ids,
                    self.config.image_token_index,
                    self.language_model.model.get_input_embeddings,
                    lambda _: self._process_image_input(image_input),
                )
                # The embeddings replace the ids entirely.
                input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         kv_caches,
                                         attn_metadata,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool hidden states with the configured pooler."""
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into all submodules."""
        AutoWeightsLoader(self).load_weights(weights)