llava_next.py 25.1 KB
Newer Older
1
import itertools
2
3
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)
4
5
6

import torch
import torch.nn as nn
7
from PIL import Image
8
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
9
10
11
12
13
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
14
from vllm.config import CacheConfig, MultiModalConfig
15
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.quantization import QuantizationConfig
18
19
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
20
from vllm.multimodal import MULTIMODAL_REGISTRY
21
from vllm.sequence import IntermediateTensors, SamplerOutput
22
from vllm.utils import is_list_of
23

24
25
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
26
                   get_clip_patch_grid_length, input_processor_for_clip)
27
from .interfaces import SupportsMultiModal
28
from .llava import LlavaMultiModalProjector
29
30
31
32
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
33
                    merge_multimodal_embeddings)
34
35
36
37
38
39
40
41

# Logger created for this module's `__name__`.
logger = init_logger(__name__)

# Prefix-to-prefix mapping for weight names.
# NOTE(review): presumably used to translate checkpoint weight names into
# this module's layout; it is not referenced by any code visible in this
# file -- confirm whether it is still needed.
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

42
43
44
# Results in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448

45
46
47

class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
48
    data: Union[torch.Tensor, List[torch.Tensor]]
49
50
51
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

52
53
    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
54
    """
55
56

    image_sizes: NotRequired[torch.Tensor]
57
58
59
60
61
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """
62
63


64
65
66
67
68
69
70
71
72
73
74
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Image input passed to the model as precomputed embeddings."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """
    Shape: `(batch_size, image_feature_size, hidden_size)`

    The `hidden_size` dimension has to equal the hidden size of the
    language model backbone.
    """


# Any image input form the model accepts: pixel values or precomputed
# image embeddings (distinguished by their "type" field).
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]
75
76


77
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
78
def _get_llava_next_num_unpadded_features(
79
80
    original_height: int,
    original_width: int,
81
82
83
84
85
86
87
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

88
89
90
    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

91
    if aspect_ratio > current_aspect_ratio:
92
        new_height = (original_height * current_width) // original_width
93
94
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
95
    else:
96
        new_width = (original_width * current_height) // original_height
97
98
        padding = (current_width - new_width) // 2
        current_width -= padding * 2
99
100
101
102
103
104

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


105
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Compute the number of image features for an image of a given size.

    The result is the sum of three parts: the base feature size reported
    by the vision-encoder helper (reduced by one for the "default" select
    strategy), the unpadded grid features, and the newline features.

    Args:
        hf_config: The LLaVA-NeXT config; its `vision_config` must be a
            CLIP or SigLIP vision config.
        input_height: Height of the input image.
        input_width: Width of the input image.

    Raises:
        NotImplementedError: If the vision config type is unsupported.
        ValueError: If `vision_feature_select_strategy` is neither
            "default" nor "full".
    """
    vision_config = hf_config.vision_config

    # Choose the encoder-specific helpers once, then share the remaining
    # computation between both encoder types.
    if isinstance(vision_config, CLIPVisionConfig):
        grid_length_fn = get_clip_patch_grid_length
        feature_size_fn = get_clip_image_feature_size
    elif isinstance(vision_config, SiglipVisionConfig):
        grid_length_fn = get_siglip_patch_grid_length
        feature_size_fn = get_siglip_image_feature_size
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    num_patches = grid_length_fn(
        image_size=vision_config.image_size,
        patch_size=vision_config.patch_size,
    )
    base_feature_size = feature_size_fn(vision_config)

    strategy = hf_config.vision_feature_select_strategy
    if strategy not in ("default", "full"):
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
    if strategy == "default":
        # The "default" strategy keeps one feature fewer than "full".
        base_feature_size -= 1

    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    unpadded_feature_size, newline_feature_size = \
        _get_llava_next_num_unpadded_features(input_height, input_width,
                                              num_patches, num_patch_height,
                                              num_patch_width)

    return unpadded_feature_size + newline_feature_size + base_feature_size
152
153


154
155
156
def get_max_llava_next_image_tokens(ctx: InputContext):
    """Return the image feature size for the module's maximum input size.

    The size is evaluated at `MAX_IMAGE_FEATURE_SIZE_HEIGHT` by
    `MAX_IMAGE_FEATURE_SIZE_WIDTH` using the LLaVA-NeXT config taken from
    `ctx`.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    return get_llava_next_image_feature_size(
        hf_config,
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


162
163
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    """Build dummy sequence data and dummy image data for LLaVA-NeXT.

    Both pieces are produced by the CLIP or SigLIP dummy-data helpers,
    depending on the type of the vision config. The image feature size is
    overridden with the maximum image token count, and the dummy image
    uses the module's maximum width and height.

    Args:
        ctx: Input context providing the LLaVA-NeXT HF config.
        seq_len: Sequence length forwarded to the sequence-data helper.
        mm_counts: Per-modality counts; only the "image" entry is read.

    Returns:
        A `(seq_data, mm_data)` tuple.

    Raises:
        NotImplementedError: If the vision config type is unsupported.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_next_image_tokens(ctx)

    # Select the encoder-specific builders; the calls are otherwise the same.
    if isinstance(vision_config, CLIPVisionConfig):
        build_seq_data = dummy_seq_data_for_clip
        build_image = dummy_image_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        build_seq_data = dummy_seq_data_for_siglip
        build_image = dummy_image_for_siglip
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    seq_data = build_seq_data(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )

    mm_data = build_image(
        vision_config,
        num_images,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    return seq_data, mm_data


209
210
211
212
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    """Process LLM inputs that may carry image data for LLaVA-NeXT.

    Inputs without multi-modal data, or without an "image" entry, are
    returned unchanged. Otherwise the image feature size is derived from
    the image data and the inputs are handed to the CLIP or SigLIP input
    processor with that size as an override.

    Raises:
        TypeError: If the image data is not a PIL image, a list of PIL
            images, or a tensor.
        NotImplementedError: If the vision config type is unsupported.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        img_width, img_height = image_data.size
        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=img_height,
            input_width=img_width,
        )
    elif is_list_of(image_data, Image.Image):
        # One feature size per image.
        image_feature_size = []
        for img in image_data:
            image_feature_size.append(
                get_llava_next_image_feature_size(hf_config,
                                                  input_height=img.height,
                                                  input_width=img.width))
    elif isinstance(image_data, torch.Tensor):
        # NOTE(review): the first dimension is taken as the feature size,
        # while LlavaNextImageEmbeddingInputs documents embeddings as
        # (batch_size, image_feature_size, hidden_size) -- confirm the
        # layout the tensor has at this stage.
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        process_inputs = input_processor_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        process_inputs = input_processor_for_siglip
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return process_inputs(
        model_config,
        vision_config,
        llm_inputs,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )


def _init_vision_tower(hf_config: LlavaNextConfig):
    """Create the CLIP or SigLIP vision tower described by `hf_config`.

    Raises:
        NotImplementedError: If the vision config type is unsupported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer.
    # A negative layer index counts back from the last hidden layer.
    feature_layer = hf_config.vision_feature_layer
    if feature_layer >= 0:
        num_hidden_layers = feature_layer + 1
    else:
        num_hidden_layers = (hf_config.vision_config.num_hidden_layers +
                             feature_layer + 1)

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return tower_cls(
        vision_config,
        num_hidden_layers_override=num_hidden_layers,
    )
286
287


288
@MULTIMODAL_REGISTRY.register_image_input_mapper()
289
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
290
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
291
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
292
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
293

294
295
    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the vision tower, projector, language model and newline.

        Args:
            config: LLaVA-NeXT config; its vision and text sub-configs
                size the components.
            multimodal_config: Stored on the instance as given.
            cache_config: Forwarded to the language model initializer.
            quant_config: Forwarded to the language model initializer.
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        text_config = config.text_config
        text_hidden_size = text_config.hidden_size

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=text_hidden_size,
            projector_hidden_act=config.projector_hidden_act,
        )

        self.language_model = init_vllm_registered_model(
            text_config, cache_config, quant_config)

        # Allocated uninitialized here; `load_weights` fills it from the
        # "image_newline" weight.
        self.image_newline = nn.Parameter(torch.empty(text_hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

325
326
327
328
    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

329
330
331
332
333
334
335
336
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
337
                raise ValueError(
338
339
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
340

341
342
        for d in data:
            _validate_shape(d)
343
344
345

        return data

346
    def _parse_and_validate_image_input(
347
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
348
349
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
350
        image_embeds = kwargs.pop("image_embeds", None)
351

352
        if pixel_values is None and image_embeds is None:
353
            return None
354

355
356
357
358
        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
359

360
361
362
            if not isinstance(image_sizes, torch.Tensor):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")
363

364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
                image_sizes=self._validate_image_sizes(image_sizes),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")
381

Cyrus Leung's avatar
Cyrus Leung committed
382
383
384
385
386
387
388
389
390
391
    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

392
393
394
395
396
    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
Cyrus Leung's avatar
Cyrus Leung committed
397

398
399
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
400
        image_features = vision_tower(pixel_values)
Cyrus Leung's avatar
Cyrus Leung committed
401
402
403
404
405
406

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

407
    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

427
428
429
                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

430
                # image_aspect_ratio == "anyres"
431
432
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
433
434
435
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
436
437
438
439
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches] \
440
                    .view(num_patch_height, num_patch_width, height, width, -1)
441
442
443
444
445
446

                if "unpad" in strategy:
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
447
                                                     (orig_height, orig_width))
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
478
479
        self,
        inputs: LlavaNextImagePixelInputs,
480
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
481
482
483
484
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

485
486
487
488
489
490
491
        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)
492

493
494
495
496
497
            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
498
499
500
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

501
502
503
504
        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]
505
506

    def _process_image_input(
507
508
509
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
510
511
512
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

513
        patch_embeddings = self._process_image_pixels(image_input)
514
515
516

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
517
            batch_size = len(image_input["data"])
518
            vision_config = self.config.vision_config
519
520
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
521
522
                                           for _ in range(batch_size)])

523
        return [
524
            self._merge_image_patch_embeddings(image_sizes[i],
525
                                               patch_features_batch,
526
                                               strategy="spatial_unpad")
527
            for i, patch_features_batch in enumerate(patch_embeddings)
528
529
530
531
532
533
534
535
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
536
        intermediate_tensors: Optional[IntermediateTensors] = None,
537
538
        **kwargs: object,
    ) -> SamplerOutput:
Cyrus Leung's avatar
Cyrus Leung committed
539
        """Run forward pass for LlaVA-NeXT.
540
541
542

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.
543

544
        Concretely, consider a text prompt:
545
546
547
548
549
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

550
        Tokenizer outputs:
551
552
553
554
555
556
557
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
558
        before they are inputted to the model, so the input processor prepends
559
560
561
562
563
564
565
566
567
568
569
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.
570
571
572
573
574
575
576

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
Cyrus Leung's avatar
Cyrus Leung committed
577
            pixel_values: The pixels in each grid patch for each input image.
578
            image_sizes: The original `(height, width)` for each input image.
579

Cyrus Leung's avatar
Cyrus Leung committed
580
        See also:
581
            :class:`LlavaNextImageInputs`
582
583
584
585
586
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
587
588
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
589

590
            inputs_embeds = merge_multimodal_embeddings(
591
                input_ids, inputs_embeds, vision_embeddings,
592
                self.config.image_token_index)
593
594
595
596
597

            input_ids = None
        else:
            inputs_embeds = None

598
599
600
601
602
603
        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)
604
605
606

        return hidden_states

607
608
609
610
611
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
612
613
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)
614
615
616
617
618
619

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
620
        return self.language_model.sample(logits, sampling_metadata)
621
622

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the weights into the four components of the model.

        The stream of `(name, tensor)` pairs is duplicated once per
        component: vision tower, multi-modal projector, image newline
        parameter and language model. Each copy is narrowed with
        `filter_weights` using that component's prefix.
        """

        def _load_into(param, loaded_weight):
            # Use the parameter's own loader when it defines one.
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, loaded_weight)

        # prepare weight iterators for components
        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
            weights, 4)

        # load vision encoder
        self.vision_tower.load_weights(
            filter_weights(vit_weights, "vision_tower"))

        # load mlp projector
        projector_weights = filter_weights(mlp_weights,
                                           "multi_modal_projector")
        projector_params = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in projector_weights:
            _load_into(projector_params[name], loaded_weight)

        # load newline
        # The newline is a single parameter, so the name left after
        # filtering is expected to be empty.
        for name, loaded_weight in filter_weights(newline_weights,
                                                  "image_newline"):
            assert name == ""
            _load_into(self.image_newline, loaded_weight)

        # load llm backbone
        self.language_model.load_weights(
            filter_weights(llm_weights, "language_model"))