# llava_next.py
import itertools
2
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
3
4
5

import torch
import torch.nn as nn
6
from PIL import Image
7
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
8
9
10
11
12
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
13
from vllm.config import CacheConfig, MultiModalConfig
14
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
15
16
17
18
19
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
20
from vllm.multimodal import MULTIMODAL_REGISTRY
21
from vllm.sequence import IntermediateTensors, SamplerOutput
22

23
24
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
25
                   get_clip_patch_grid_length, input_processor_for_clip)
26
from .interfaces import SupportsMultiModal
27
from .llava import LlavaMultiModalProjector
28
29
30
31
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
32
                    merge_multimodal_embeddings)
33
34
35
36
37
38
39
40

# Module-level logger, named after this module.
logger = init_logger(__name__)

# Substring -> replacement pairs, presumably for renaming checkpoint weight
# keys. NOTE(review): nothing in the visible part of this file reads this
# mapping -- confirm it is still needed.
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

41
42
43
# Input image height/width used both for the dummy profiling image and for
# computing the maximum number of image tokens. Per the original note, this
# results in the max possible feature size (2x2 grid of 336x336px tiles).
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448

44
45
46

class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
47
    data: Union[torch.Tensor, List[torch.Tensor]]
48
49
50
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

51
52
    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
53
    """
54
55

    image_sizes: NotRequired[torch.Tensor]
56
57
58
59
60
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """
61
62


63
64
65
66
67
68
69
70
71
72
73
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings (tagged `"image_embeds"`)."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """
    Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Union of the two supported image input forms (raw pixel values or
# precomputed embeddings); the "type" key tells them apart.
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]
74
75


76
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
77
def _get_llava_next_num_unpadded_features(
78
79
    original_height: int,
    original_width: int,
80
81
82
83
84
85
86
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

87
88
89
    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

90
    if aspect_ratio > current_aspect_ratio:
91
        new_height = (original_height * current_width) // original_width
92
93
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
94
    else:
95
        new_width = (original_width * current_height) // original_height
96
97
        padding = (current_width - new_width) // 2
        current_width -= padding * 2
98
99
100
101
102
103

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


104
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return the number of image features produced for one input image.

    The result is the sum of three parts: the unpadded features of the
    tiled image, one newline feature per remaining row, and the base
    feature size of the vision encoder (minus one when the feature select
    strategy is `"default"`).

    Args:
        hf_config: The LLaVA-NeXT config; its `vision_config` must be a
            CLIP or SigLIP vision config.
        input_height: Height of the input image.
        input_width: Width of the input image.

    Returns:
        The total number of image features.

    Raises:
        NotImplementedError: If the vision config type is unsupported.
        ValueError: If `vision_feature_select_strategy` is neither
            `"default"` nor `"full"`.
    """
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_clip_image_feature_size(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_patches = get_siglip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_siglip_image_feature_size(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        # Mirrors `_select_image_features`, which drops the first feature.
        base_feature_size -= 1
    elif strategy == "full":
        pass
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    # NOTE: the helper's `patch_size` argument receives the vision encoder's
    # input image size here, not the encoder's own `patch_size`.
    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    (
        unpadded_feature_size,
        newline_feature_size,
    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                              num_patches, num_patch_height,
                                              num_patch_width)

    return unpadded_feature_size + newline_feature_size + base_feature_size
151
152


153
154
155
def get_max_llava_next_image_tokens(ctx: InputContext) -> int:
    """Return the image feature size for the reference maximum-size image.

    The size is computed for an image of `MAX_IMAGE_FEATURE_SIZE_HEIGHT` by
    `MAX_IMAGE_FEATURE_SIZE_WIDTH`, using the LLaVA-NeXT config from `ctx`.
    """
    return get_llava_next_image_feature_size(
        ctx.get_hf_config(LlavaNextConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


161
162
163
164
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
    """Build dummy sequence data and a dummy image for the configured encoder.

    Both are sized for the reference maximum-size image: the sequence data
    uses the maximum image token count, and the image uses
    `MAX_IMAGE_FEATURE_SIZE_WIDTH` x `MAX_IMAGE_FEATURE_SIZE_HEIGHT`.

    Args:
        ctx: Input context giving access to the LLaVA-NeXT HF config.
        seq_len: Sequence length passed through to the encoder-specific
            dummy sequence builder.

    Returns:
        A `(seq_data, mm_data)` tuple from the CLIP or SigLIP dummy helpers.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_feature_size = get_max_llava_next_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(
            vision_config,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


202
203
204
205
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    """Process LLM inputs that carry image data for LLaVA-NeXT.

    Inputs without `"image"` multi-modal data are returned unchanged.
    Otherwise the image feature size is determined and the work is
    delegated to the CLIP or SigLIP input processor.

    Args:
        ctx: Input context giving access to the model and HF configs.
        llm_inputs: The inputs to process.

    Returns:
        `llm_inputs` itself when there is no image data, else the result of
        the encoder-specific input processor.

    Raises:
        TypeError: If the image data is neither a PIL image nor a tensor.
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        # PIL reports the size as (width, height).
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif isinstance(image_data, torch.Tensor):
        # A tensor is taken as-is; its leading dimension is used as the
        # feature count (presumably these are precomputed embeddings).
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaNextConfig):
    """Create the vision encoder, truncated at the configured feature layer.

    Args:
        hf_config: The LLaVA-NeXT config. `vision_feature_layer` may be
            negative, in which case it counts back from the last layer.

    Returns:
        A `CLIPVisionModel` or `SiglipVisionModel`, matching the type of
        `hf_config.vision_config`.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        # For example, -1 keeps every layer and -2 drops the last one.
        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
            + vision_feature_layer + 1
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
272
273


274
@MULTIMODAL_REGISTRY.register_image_input_mapper()
275
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
276
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
277
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
278
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
279

280
281
    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the vision tower, projector, language model and newline.

        Args:
            config: The LLaVA-NeXT HF config.
            multimodal_config: Multi-modal settings, stored on the instance.
            cache_config: Passed through to the language model constructor.
            quant_config: Passed through to the language model constructor.
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        # Allocated uninitialized; `load_weights` fills it from the
        # checkpoint's "image_newline" entry.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

311
312
313
314
    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

315
316
317
318
319
320
321
322
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
323
                raise ValueError(
324
325
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
326

327
328
        for d in data:
            _validate_shape(d)
329
330
331

        return data

332
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Extract the image input from keyword arguments and validate it.

        Reads `pixel_values`, `image_sizes` and `image_embeds` from
        `kwargs`. Pixel values take precedence over embeddings when both
        are given.

        Returns:
            `None` if neither pixel values nor embeddings are present,
            otherwise the matching typed input dict.

        Raises:
            ValueError: If a provided value has the wrong type or shape.
                Note that `image_sizes` must be a tensor whenever
                `pixel_values` is given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, torch.Tensor):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
                image_sizes=self._validate_image_sizes(image_sizes),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")
367

Cyrus Leung's avatar
Cyrus Leung committed
368
369
370
371
372
373
374
375
376
377
    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

378
379
380
381
382
    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the feature select strategy.

        Args:
            vision_tower: The vision encoder to call.
            pixel_values: Pixel values passed straight to `vision_tower`.

        Returns:
            The vision tower output, filtered by the configured
            `vision_feature_select_strategy`.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

393
    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

413
414
415
                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

416
                # image_aspect_ratio == "anyres"
417
418
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
419
420
421
422
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                other_patch_embeds = other_patch_embeds \
423
                    .view(num_patch_height, num_patch_width, height, width, -1)
424
425
426
427
428
429

                if "unpad" in strategy:
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
430
                                                     (orig_height, orig_width))
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Encode pixel values into projected patch embeddings.

        All patches of all images go through the vision tower in a single
        call. The result is then regrouped per image.

        Args:
            inputs: Pixel inputs whose `"data"` is either a 5-D batched
                tensor or a list of per-image 4-D tensors.

        Returns:
            A tensor indexed `[image, patch, ...]` for batched input, or a
            list with one tensor per image for list input.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            # Fold the patch dimension into the batch dimension for the
            # encoder, then unfold it again afterwards.
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # List input: images may have different patch counts, so remember
        # each count in order to split the stacked features back up.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]
488
489

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Turn a validated image input into per-image embeddings.

        Precomputed embeddings are returned as they are, wrapped in a list.
        Pixel inputs are encoded and then merged per image with the
        `"spatial_unpad"` strategy.

        Args:
            image_input: The validated image input.

        Returns:
            A list of embedding tensors.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # No original sizes were given: fall back to the vision
            # encoder's square input size for every image.
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass for LlaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor expands
        the image token (denoted as `32000`) into additional image tokens,
        resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model.
            kv_caches: Passed through to the language model.
            attn_metadata: Passed through to the language model.
            intermediate_tensors: Accepted for interface compatibility; it
                is currently not forwarded to the language model.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        Returns:
            The hidden states produced by the language model backbone.

        See also:
            :class:`LlavaNextImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The merged embeddings replace the token ids as model input.
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

591
592
593
594
595
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits by delegating to the language model backbone."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)
598
599
600
601
602
603

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample from `logits` by delegating to the language model backbone."""
        return self.language_model.sample(logits, sampling_metadata)
605
606

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the four model components.

        The weight stream is duplicated once per component. Each copy is
        filtered by that component's name prefix: `vision_tower`,
        `multi_modal_projector`, `image_newline` and `language_model`.

        Args:
            weights: An iterable of `(name, tensor)` pairs.
        """
        # prepare weight iterators for components
        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
            weights, 4)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load newline
        newline_weights = filter_weights(newline_weights, "image_newline")
        for name, loaded_weight in newline_weights:
            # `image_newline` is a bare parameter, so no name should remain
            # after filtering.
            assert name == ""
            param = self.image_newline
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)