from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix,
                    init_vllm_registered_model, merge_multimodal_embeddings)


# Side length (in pixels) of the dummy image used to size the largest image
# prompt. The intent is that a 448x448 image is mapped onto a 2x2 grid of
# 336x336px tiles, which results in the max possible feature size.
# NOTE(review): the grid chosen depends on `image_grid_pinpoints`; confirm this
# still yields the maximum for checkpoints with larger pinpoints.
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 448
MAX_IMAGE_FEATURE_SIZE_WIDTH = MAX_IMAGE_FEATURE_SIZE_HEIGHT


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
41
    data: Union[torch.Tensor, List[torch.Tensor]]
42
    """
43
44
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
45

46
47
    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
48
    """
49
50

    image_sizes: NotRequired[torch.Tensor]
51
    """
52
    Shape: `(batch_size * num_images, 2)`
53
54
55

    This should be in `(height, width)` format.
    """
56
57


58
59
60
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Image inputs passed as precomputed image embeddings."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either form of image input that the model's input parsing can produce.
LlavaNextImageInputs = Union[
    LlavaNextImagePixelInputs,
    LlavaNextImageEmbeddingInputs,
]


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
72
def _get_llava_next_num_unpadded_features(
73
74
    original_height: int,
    original_width: int,
75
76
77
78
79
80
81
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

82
    original_aspect_ratio = original_width / original_height
83
84
    current_aspect_ratio = current_width / current_height

85
86
87
    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
88
        padding = (current_height - new_height) // 2
89
        current_height -= 2 * padding
90
    else:
91
92
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
93
        padding = (current_width - new_width) // 2
94
        current_width -= 2 * padding
95
96
97
98
99
100

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


101
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return the number of image tokens needed for one input image.

    The total is the sum of:

    - the base-image feature count reported by the CLIP/SigLIP helper, minus
      one leading token when ``vision_feature_select_strategy == "default"``;
    - the features of the high-resolution tile grid chosen for
      ``(input_height, input_width)`` after padding is removed; and
    - one newline feature per remaining row of that grid.

    Args:
        hf_config: The LLaVA-NeXT HuggingFace config.
        input_height: Height of the input image in pixels.
        input_width: Width of the input image in pixels.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
        ValueError: If ``vision_feature_select_strategy`` is unknown.
    """
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_clip_image_feature_size(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_patches = get_siglip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_siglip_image_feature_size(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        # The model's `_select_image_features` drops the first token here.
        base_feature_size -= 1
    elif strategy == "full":
        pass
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    # (rows, cols) of the tile grid selected from the configured pinpoints.
    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    (
        unpadded_feature_size,
        newline_feature_size,
    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                              num_patches, num_patch_height,
                                              num_patch_width)

    return unpadded_feature_size + newline_feature_size + base_feature_size


def get_max_llava_next_image_tokens(ctx: InputContext) -> int:
    """Return the image token count for the maximum-size dummy image.

    Evaluates :func:`get_llava_next_image_feature_size` on an image of
    ``MAX_IMAGE_FEATURE_SIZE_HEIGHT`` x ``MAX_IMAGE_FEATURE_SIZE_WIDTH``
    pixels, which is intended to be the largest count a single image needs.
    """
    return get_llava_next_image_feature_size(
        ctx.get_hf_config(LlavaNextConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    """Build a dummy prompt and dummy images of maximal size.

    The dummy prompt contains ``mm_counts["image"]`` images, each reserving
    :func:`get_max_llava_next_image_tokens` placeholder tokens. The dummy
    images are ``MAX_IMAGE_FEATURE_SIZE_WIDTH`` x
    ``MAX_IMAGE_FEATURE_SIZE_HEIGHT`` pixels.

    Args:
        ctx: Input context used to look up the HF config.
        seq_len: Length of the dummy token sequence.
        mm_counts: Number of items per modality; only ``"image"`` is read.

    Returns:
        ``(seq_data, mm_data)`` as produced by the CLIP or SigLIP helpers.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_next_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    """Compute per-image token counts and delegate to the vision processor.

    For each image in ``llm_inputs["multi_modal_data"]["image"]`` the number
    of image tokens is determined, then the CLIP or SigLIP input processor is
    called with that value as ``image_feature_size_override``.

    Accepted image data:

    - a ``PIL.Image.Image`` or a list of them: the count is computed from each
      image's height and width;
    - a 3-D ``torch.Tensor`` of embeddings: the count is its second dimension;
    - a list of tensors: the count is ``item.shape[1]`` for each item.

    Returns ``llm_inputs`` unchanged when it carries no image data.

    Raises:
        TypeError: If the image data has an unsupported type.
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [
            get_llava_next_image_feature_size(hf_config,
                                              input_height=img.height,
                                              input_width=img.width)
            for img in image_data
        ]
    elif isinstance(image_data, torch.Tensor):
        # (num_images, image_feature_size, hidden_size); unpacking also
        # rejects tensors that are not 3-D.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaNextConfig):
    """Build the vision tower truncated at ``vision_feature_layer``.

    Only layers up to and including the selected feature layer are created,
    so the tower's output is already the requested hidden state. A negative
    ``vision_feature_layer`` counts from the end (``-1`` keeps every layer).

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = (vision_config.num_hidden_layers +
                             vision_feature_layer + 1)
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """LLaVA-NeXT: vision tower + projector feeding a language model.

    Image features are projected to the language model's hidden size and
    merged into the input embeddings at the image placeholder tokens.
    """

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        # Appended after every row of tile features by the "unpad" merge
        # strategy; its values come from the checkpoint (see load_weights).
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Reuse the language model's sampler, or create a default one."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every per-image entry of ``data`` has shape ``(2,)``.

        Returns ``data`` unchanged.

        Raises:
            ValueError: If an entry does not have shape ``(2,)``.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that each image's pixels are ``(num_patches, 3, H, W)``.

        ``H`` and ``W`` must both equal ``vision_config.image_size``; the
        number of patches is not constrained. Returns ``data`` unchanged.

        Raises:
            ValueError: If an image has an unexpected trailing shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Extract image inputs from the forward kwargs.

        ``pixel_values`` (which requires ``image_sizes``) takes precedence
        over ``image_embeds``. Returns ``None`` when neither is given.

        Raises:
            ValueError: If an input has an unsupported type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply ``vision_feature_select_strategy`` to the tower output.

        ``"default"`` drops the first token along dim 1; ``"full"`` keeps
        everything.

        Raises:
            ValueError: If ``strategy`` is unknown.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``vision_tower`` on ``pixel_values`` and select features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge the per-tile embeddings of one image into a flat sequence.

        Args:
            image_size: The original ``(height, width)`` of the image.
            patch_embeddings: Tile embeddings of shape
                ``(num_tiles, grid * grid, embed_dim)`` where ``grid`` is
                ``image_size // patch_size`` of the vision config. Tile 0 is
                the base image; the rest are the high-resolution tiles,
                possibly padded for batch processing.
            strategy: ``"flat"`` concatenates all tiles unchanged. Strategies
                starting with ``"spatial"`` lay the high-resolution tiles out
                on their grid in row-major order; if the strategy also
                contains ``"unpad"``, padding is removed and
                ``image_newline`` is appended after every row.

        Returns:
            A 2-D ``(num_features, embed_dim)`` tensor for ``"spatial*"``
            strategies, or the tiles flattened over dims 0-1 for ``"flat"``.

        Raises:
            ValueError: If the base tile does not have ``grid * grid``
                patches, or ``strategy`` is unknown.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = (self.config.vision_config.image_size //
                              self.config.vision_config.patch_size)

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches].view(
                    num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # (rows, cols, h, w, dim) -> (dim, rows * h, cols * w)
                    other_patch_embeds = (other_patch_embeds.permute(
                        4, 0, 2, 1, 3).contiguous().flatten(1, 2).flatten(
                            2, 3))
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # Append one image_newline column at the end of each row.
                    newline_column = self.image_newline[:, None, None].expand(
                        *other_patch_embeds.shape[:-1],
                        1).to(other_patch_embeds.device)
                    other_patch_embeds = torch.cat(
                        (other_patch_embeds, newline_column), dim=-1)
                    # (dim, H', W' + 1) -> (H' * (W' + 1), dim)
                    other_patch_embeds = other_patch_embeds.flatten(
                        1, 2).transpose(0, 1)
                else:
                    # (rows, cols, h, w, dim) -> (rows * h * cols * w, dim)
                    other_patch_embeds = other_patch_embeds.permute(
                        0, 2, 1, 3, 4).contiguous().flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            elif "unpad" in strategy:
                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds,
                     self.image_newline[None].to(base_patch_embeds.device)),
                    dim=0)
            else:
                merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Encode every tile with the vision tower and projector.

        Returns a ``(b, num_patches, ...)`` tensor when the pixel values are
        one batched tensor, otherwise a list with one tensor per image.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Ragged input: run all tiles through the tower in one call, then
        # split the result back per image before projecting.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Turn parsed image inputs into embeddings for the language model.

        Precomputed embeddings are returned as a one-element list. Pixel
        inputs are encoded and merged per image with the ``"spatial_unpad"``
        strategy; when ``image_sizes`` is absent every image is assumed to be
        ``image_size`` x ``image_size`` of the vision config.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor repeats
        the image token (denoted as `32000`) in place, resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        When `intermediate_tensors` is given, image inputs are ignored and
        both `input_ids` and `inputs_embeds` passed to the language model
        are `None`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.
            image_embeds: Precomputed image embeddings, used only when
                `pixel_values` is not given.

        See also:
            :class:`LlavaNextImageInputs`
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)

            if image_input is not None:
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)

                # Overwrite the placeholder positions with image embeddings.
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.config.image_token_index)

                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, routed by top-level name prefix.

        ``vision_tower`` and ``language_model`` weights go to those
        submodules' own loaders; ``multi_modal_projector`` and
        ``image_newline`` are loaded here.

        Raises:
            KeyError: If a projector weight has no matching parameter.
        """
        # prepare weight iterators for components
        weights_group = group_weights_with_prefix(weights)

        # load vision encoder
        self.vision_tower.load_weights(weights_group["vision_tower"])

        # load mlp projector
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in weights_group["multi_modal_projector"]:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load newline; it is a bare parameter, so no suffix is expected
        for name, loaded_weight in weights_group["image_newline"]:
            assert name == ""
            param = self.image_newline
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        self.language_model.load_weights(weights_group["language_model"])