"vllm/vscode:/vscode.git/clone" did not exist on "afdabfbef4d3e7c404f75047f285da79c783cdb0"
llava_next.py 25.3 KB
Newer Older
1
from functools import cached_property
2
3
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)
4
5
6

import torch
import torch.nn as nn
7
from PIL import Image
8
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
9
10
11
12
13
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
14
from vllm.config import CacheConfig, MultiModalConfig
15
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
Cyrus Leung's avatar
Cyrus Leung committed
16
from vllm.model_executor.layers.pooler import Pooler, PoolingType
17
from vllm.model_executor.layers.quantization import QuantizationConfig
18
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
Cyrus Leung's avatar
Cyrus Leung committed
19
from vllm.model_executor.pooling_metadata import PoolingMetadata
20
from vllm.model_executor.sampling_metadata import SamplingMetadata
21
from vllm.multimodal import MULTIMODAL_REGISTRY
Cyrus Leung's avatar
Cyrus Leung committed
22
from vllm.sequence import IntermediateTensors, PoolerOutput
23
from vllm.utils import is_list_of
24

25
26
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
27
                   get_clip_patch_grid_length, input_processor_for_clip)
28
from .interfaces import SupportsMultiModal, SupportsPP
29
from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
30
31
32
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
Cyrus Leung's avatar
Cyrus Leung committed
33
34
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
                    init_vllm_registered_model)
35
36
37
38


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
39
    data: Union[torch.Tensor, List[torch.Tensor]]
40
    """
41
42
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
43

44
45
    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
46
    """
47
48

    image_sizes: NotRequired[torch.Tensor]
49
    """
50
    Shape: `(batch_size * num_images, 2)`
51
52
53

    This should be in `(height, width)` format.
    """
54
55


56
57
58
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that skip the vision tower entirely."""

    # Discriminator distinguishing this variant from pixel inputs.
    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
    # `hidden_size` has to equal the hidden size of the language model
    # backbone, since these vectors are used as token embeddings directly.
    data: torch.Tensor


# Either raw pixel patches or precomputed embeddings; consumers branch on the
# `type` key to tell the two apart.
LlavaNextImageInputs = Union[
    LlavaNextImagePixelInputs,
    LlavaNextImageEmbeddingInputs,
]
67
68


69
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
70
def _get_llava_next_num_unpadded_features(
71
72
    original_height: int,
    original_width: int,
73
74
75
76
77
78
79
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

80
    original_aspect_ratio = original_width / original_height
81
82
    current_aspect_ratio = current_width / current_height

83
84
85
    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
86
        padding = (current_height - new_height) // 2
87
        current_height -= 2 * padding
88
    else:
89
90
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
91
        padding = (current_width - new_width) // 2
92
        current_width -= 2 * padding
93
94
95
96
97
98

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


99
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return how many placeholder tokens an image of the given size needs.

    The total is the sum of three parts:

    * the base (single-tile) feature count of the vision encoder, minus one
      when ``vision_feature_select_strategy`` is ``"default"``;
    * the unpadded anyres grid features; and
    * one newline feature per kept grid row.

    Args:
        hf_config: HF LLaVA-NeXT config (vision config must be CLIP or SigLIP).
        input_height: Original image height in pixels.
        input_width: Original image width in pixels.

    Raises:
        NotImplementedError: for an unsupported vision config type.
        ValueError: for an unknown feature select strategy.
    """
    vision_config = hf_config.vision_config

    # Choose the encoder-specific helpers once, then call them uniformly.
    if isinstance(vision_config, CLIPVisionConfig):
        grid_length_fn = get_clip_patch_grid_length
        feature_size_fn = get_clip_image_feature_size
    elif isinstance(vision_config, SiglipVisionConfig):
        grid_length_fn = get_siglip_patch_grid_length
        feature_size_fn = get_siglip_image_feature_size
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    num_patches = grid_length_fn(
        image_size=vision_config.image_size,
        patch_size=vision_config.patch_size,
    )
    base_feature_size = feature_size_fn(vision_config)

    strategy = hf_config.vision_feature_select_strategy
    if strategy not in ("default", "full"):
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
    if strategy == "default":
        # The "default" strategy drops the first feature token.
        base_feature_size -= 1

    # Tile grid chosen from `image_grid_pinpoints` for this resolution.
    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    unpadded, newlines = _get_llava_next_num_unpadded_features(
        input_height,
        input_width,
        num_patches,
        num_patch_height,
        num_patch_width,
    )
    return unpadded + newlines + base_feature_size
146
147


148
def get_max_llava_next_image_tokens(ctx: InputContext):
    """Compute the max feature size for all possible image grid pinpoints.

    This is the largest number of placeholder tokens any single image can
    occupy; the pinpoint that produces it is discarded here.
    """
    max_feature_size, _pinpoint = _get_pinpoint_with_largest_features(ctx)
    return max_feature_size


def _get_pinpoint_with_largest_features(
        ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
    """Get the grid pinpoint with the largest features & its feature size.

    Every ``(height, width)`` in ``image_grid_pinpoints`` is scored with
    :func:`get_llava_next_image_feature_size`. When several pinpoints share
    the top score, the earliest one in the config wins.

    Returns:
        ``(feature_size, (height, width))`` of the winning pinpoint.

    Raises:
        ValueError: if there are no pinpoints, or none yields a positive
            feature size.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)

    scored = [
        (
            get_llava_next_image_feature_size(
                hf_config,
                input_height=height,
                input_width=width,
            ),
            (height, width),
        )
        for height, width in hf_config.image_grid_pinpoints
    ]

    # max() returns the first of several equal maxima, so ties resolve to the
    # earliest pinpoint; `default` covers an empty pinpoint list.
    best_size, best_pinpoint = max(scored,
                                   key=lambda entry: entry[0],
                                   default=(0, None))

    if best_size <= 0 or best_pinpoint is None:
        raise ValueError("Cannot have a largest feature size of 0!")
    return best_size, best_pinpoint
171
172


173
174
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    """Build worst-case dummy prompt and image data for memory profiling.

    The dummy images use the grid pinpoint with the largest feature size, and
    the dummy sequence reserves that many image tokens per image.

    Args:
        ctx: Input context used to look up the HF config.
        seq_len: Length of the dummy token sequence.
        mm_counts: Number of items per modality; only ``"image"`` is read.

    Returns:
        ``(seq_data, mm_data)`` from the CLIP or SigLIP dummy-data helpers.

    Raises:
        NotImplementedError: for an unsupported vision config type.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size, (max_feat_height, max_feat_width) = \
        _get_pinpoint_with_largest_features(ctx)

    # Pick the encoder-specific helpers, then call them uniformly below.
    if isinstance(vision_config, CLIPVisionConfig):
        make_seq_data = dummy_seq_data_for_clip
        make_image = dummy_image_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        make_seq_data = dummy_seq_data_for_siglip
        make_image = dummy_image_for_siglip
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    seq_data = make_seq_data(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )
    mm_data = make_image(
        vision_config,
        num_images,
        image_width_override=max_feat_width,
        image_height_override=max_feat_height,
    )
    return seq_data, mm_data


221
222
223
def input_processor_for_llava_next(ctx: InputContext,
                                   inputs: DecoderOnlyInputs):
    """Size the image placeholder tokens in the prompt for each input image.

    The per-image feature count is computed from the image resolution (PIL
    input) or read from the embedding tensor shape (precomputed embeddings).
    It is then handed to the CLIP or SigLIP input processor as
    ``image_feature_size_override`` together with the image token id.

    Inputs without image data are returned unchanged.

    Raises:
        TypeError: if the image data is not a PIL image, a tensor, or a list
            of either.
        NotImplementedError: for an unsupported vision config type.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [
            get_llava_next_image_feature_size(hf_config,
                                              input_height=img.height,
                                              input_width=img.width)
            for img in image_data
        ]
    elif isinstance(image_data, torch.Tensor):
        # Precomputed embeddings of shape
        # (num_images, image_feature_size, hidden_size); the 3-way unpack
        # also rejects tensors of any other rank.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


277
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """LLaVA-NeXT model: vision tower + projector + language model.

    Images are encoded by ``vision_tower``, projected to the text hidden size
    by ``multi_modal_projector``, and merged per image with the learned
    ``image_newline`` embedding. The result is combined with the token
    embeddings at the ``config.image_token_index`` positions (through
    ``embed_multimodal``) before ``language_model`` runs.

    The registry decorators attach the module-level helpers that compute the
    max image token count, build dummy profiling data, and size the image
    placeholders in the prompt.
    """

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build all sub-modules from the HF ``LlavaNextConfig``.

        Args:
            config: HF config holding ``vision_config`` and ``text_config``.
            multimodal_config: vLLM multimodal settings; stored as-is.
            cache_config: Forwarded to the language model.
            quant_config: Forwarded to the vision tower and language model.
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config, quant_config, require_post_norm=False)
        # Allocated with torch.empty (uninitialized); presumably populated
        # from the checkpoint by load_weights.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        # Maps vision hidden size -> text hidden size.
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        # The same model class supports both language generation and embedding
        # because the architecture name is the same
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Reuse the language model's sampler if it has one, else a new one."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every per-image entry of ``data`` has shape ``(2,)``.

        Returns ``data`` unchanged; raises ``ValueError`` on a bad shape.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor) -> None:
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check each image's patches are ``(num_patches, 3, size, size)``.

        ``size`` is ``vision_config.image_size``; the leading patch count is
        not checked. Returns ``data`` unchanged; raises ``ValueError`` on a
        bad shape.
        """

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor) -> None:
            # Skip the leading num_patches dimension, which may vary.
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Extract image inputs from the forward kwargs, if any.

        Pops ``pixel_values``, ``image_sizes`` and ``image_embeds`` from
        ``kwargs``. Pixel values take precedence over embeddings when both
        are given. Note that pixel values currently require ``image_sizes``
        to be a tensor or list as well.

        Returns:
            ``None`` when no image data is present, otherwise a pixel or
            embedding input dict with batch dimensions flattened.

        Raises:
            ValueError: on unexpected types or shapes.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply ``vision_feature_select_strategy`` to encoder outputs.

        ``"default"`` drops the first token along dim 1; ``"full"`` keeps
        everything. Any other value raises ``ValueError``.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the feature select strategy."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge one image's per-tile embeddings into a single sequence.

        ``patch_embeddings[0]`` is the base (whole-image) tile; any remaining
        entries are the anyres tiles. For ``"spatial*"`` strategies the tiles
        are laid out on their 2-D grid; with ``"unpad"`` in the name, padding
        is cropped and ``image_newline`` is appended after every grid row.

        Args:
            image_size: Original ``(height, width)`` of the image.
            patch_embeddings: Projected features, one entry per tile.
            strategy: ``"flat"`` or a string starting with ``"spatial"``.

        Returns:
            A 2-D tensor of merged feature vectors.

        Raises:
            ValueError: on an unknown strategy or if the base tile's token
                count does not equal ``(image_size // patch_size) ** 2``.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            # Feature-grid side length of one tile.
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                # (tolist() also yields plain Python ints).
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                # Result: (num_patch_height, num_patch_width, height, width,
                # embed_dim) with the last dim inferred by view(-1).
                other_patch_embeds = other_patch_embeds[:num_patches] \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # -> (embed_dim, num_patch_height * height,
                    #     num_patch_width * width)
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    # Crop the padding using the original image size
                    # (transformers' unpad_image).
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # Append one image_newline column per row (last dim),
                    # then flatten rows and move embed_dim last:
                    # -> (rows * (cols + 1), embed_dim)
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    # Row-major over the full grid:
                    # -> (num_patches * height * width, embed_dim)
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                # Only the base tile is present.
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Encode and project all tiles of all images.

        A batched 5-D tensor is encoded in one pass and returned reshaped
        to ``(num_images, num_patches, ...)``. A list of per-image tensors
        (varying tile counts) is concatenated, encoded in one pass, split
        back per image, and projected image by image.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Turn parsed image input into per-image embedding sequences.

        Embedding inputs are returned as-is, wrapped in a one-element list.
        Pixel inputs are encoded and then merged per image with the
        ``"spatial_unpad"`` strategy. If ``image_sizes`` is missing, every
        image is assumed to be ``vision_config.image_size`` square.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LlaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor expands
        the single image token (denoted as `32000`) into repeated copies,
        resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions, kv_caches, attn_metadata: Passed straight through to
                the language model.
            intermediate_tensors: When given (presumably on a non-first
                pipeline stage), token ids and image inputs are ignored.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        See also:
            :class:`LlavaNextImageInputs`
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)

            if image_input is not None:
                # Build text embeddings and place the image features at the
                # image-token positions.
                inputs_embeds = embed_multimodal(
                    input_ids,
                    self.config.image_token_index,
                    self.language_model.model.get_input_embeddings,
                    lambda _: self._process_image_input(image_input),
                )
                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate token sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Pool hidden states (last token, normalized) for embedding mode."""
        return self._pooler(hidden_states, pooling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs via AutoWeightsLoader."""
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights)