llava_next.py 23.5 KB
Newer Older
1
import itertools
2
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
3
4
5

import torch
import torch.nn as nn
6
from PIL import Image
7
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
8
9
10
11
12
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
13
from vllm.config import CacheConfig, MultiModalConfig
14
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
15
16
17
18
19
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
20
from vllm.multimodal import MULTIMODAL_REGISTRY
21
from vllm.sequence import IntermediateTensors, SamplerOutput
22

23
24
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
25
                   get_clip_patch_grid_length, input_processor_for_clip)
26
from .interfaces import SupportsVision
27
from .llava import LlavaMultiModalProjector
28
29
30
31
32
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
                    merge_vision_embeddings)
33
34
35
36
37
38
39
40

logger = init_logger(__name__)

# Presumably maps checkpoint key prefixes to the prefixes used by this
# module. NOTE(review): nothing in this file reads this table (`load_weights`
# routes weights with `filter_weights` instead); confirm it is unused
# elsewhere before removing it.
_KEYS_TO_MODIFY_MAPPING = dict([
    ("language_model.lm_head", "lm_head"),
    ("language_model.model", "language_model"),
])

41
42
43
# Input resolution used by `get_max_llava_next_image_tokens` and
# `dummy_data_for_llava_next`. Chosen to yield the max possible feature size
# (a 2x2 grid of 336x336px tiles). NOTE(review): that depends on the
# checkpoint's `image_grid_pinpoints`; re-check for configs with larger grids.
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 448
MAX_IMAGE_FEATURE_SIZE_WIDTH = MAX_IMAGE_FEATURE_SIZE_HEIGHT

44
45
46

class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
47
    data: Union[torch.Tensor, List[torch.Tensor]]
48
49
50
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

51
52
    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
53
    """
54
55

    image_sizes: NotRequired[torch.Tensor]
56
57
58
59
60
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """
61
62


63
LlavaNextImageInputs = LlavaNextImagePixelInputs
64
65


66
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
67
def _get_llava_next_num_unpadded_features(
68
69
    original_height: int,
    original_width: int,
70
71
72
73
74
75
76
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

77
78
79
    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

80
    if aspect_ratio > current_aspect_ratio:
81
        new_height = (original_height * current_width) // original_width
82
83
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
84
    else:
85
        new_width = (original_width * current_height) // original_height
86
87
        padding = (current_width - new_width) // 2
        current_width -= padding * 2
88
89
90
91
92
93

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


94
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return the number of image placeholder tokens for one input image.

    The total is the sum of three parts:
    - the feature count of one vision-tower pass (minus the first token
      under the ``"default"`` select strategy);
    - the unpadded anyres grid features;
    - one newline feature per grid row.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
        ValueError: If ``vision_feature_select_strategy`` is not
            ``"default"`` or ``"full"``.
    """
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_clip_image_feature_size(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_patches = get_siglip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_siglip_image_feature_size(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        # Matches the model's `_select_image_features`, which drops the
        # first token of the vision-tower output for this strategy.
        base_feature_size -= 1
    elif strategy == "full":
        pass
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    # Each grid tile has the vision tower's input resolution.
    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    (
        unpadded_feature_size,
        newline_feature_size,
    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                              num_patches, num_patch_height,
                                              num_patch_width)

    return unpadded_feature_size + newline_feature_size + base_feature_size
141
142


143
144
145
def get_max_llava_next_image_tokens(ctx: InputContext) -> int:
    """Return the image token count for an input of
    ``MAX_IMAGE_FEATURE_SIZE_HEIGHT x MAX_IMAGE_FEATURE_SIZE_WIDTH`` pixels.
    """
    return get_llava_next_image_feature_size(
        ctx.get_hf_config(LlavaNextConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


151
152
153
154
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
    """Build dummy sequence data and a dummy image for this model.

    The image feature size passed to the CLIP/SigLIP helpers is
    ``get_max_llava_next_image_tokens(ctx)``. The dummy image is
    ``MAX_IMAGE_FEATURE_SIZE_WIDTH x MAX_IMAGE_FEATURE_SIZE_HEIGHT`` pixels.

    Returns:
        A ``(seq_data, mm_data)`` pair built by the CLIP or SigLIP helpers.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_feature_size = get_max_llava_next_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(
            vision_config,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


192
193
194
195
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    """Prepare ``llm_inputs`` that carry a single PIL image.

    Inputs without ``"image"`` multi-modal data are returned unchanged.
    Otherwise the image-dependent feature size is computed and passed to the
    CLIP/SigLIP input processor as ``image_feature_size_override``.

    Raises:
        NotImplementedError: For tensor (embedding) image inputs, or for an
            unsupported vision config.
        TypeError: For any other image data type.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        # Unlike LLaVA-1.5, the number of image tokens depends on the
        # original image resolution.
        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaNextConfig):
    """Build the vision tower, keeping only layers up to
    ``hf_config.vision_feature_layer``.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        # A negative index counts from the end: -1 keeps every layer.
        num_hidden_layers = vision_config.num_hidden_layers \
            + vision_feature_layer + 1
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
262
263


264
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA-NeXT model with anyres (multi-tile) image input.

    Composes a CLIP or SigLIP vision tower, an MLP projector into the text
    hidden size, and a vLLM-registered language model.
    """

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        # Learned embedding appended after every row of the unpadded tile
        # grid; its values come from the checkpoint via `load_weights`.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` has shape ``(batch_size, 2)``; return it."""
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that each batch element has shape
        ``(num_patches, 3, image_size, image_size)``; return ``data``."""
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        # Iterating a batched tensor yields per-image slices, so this covers
        # both the tensor and the list form.
        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
        """Extract and validate ``pixel_values`` / ``image_sizes`` kwargs.

        Returns ``None`` when no ``pixel_values`` were passed.

        Raises:
            ValueError: If either input has the wrong type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")

        return LlavaNextImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
            image_sizes=self._validate_image_sizes(image_sizes),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply ``vision_feature_select_strategy`` to the tower output.

        ``"default"`` drops the first token along dim 1; ``"full"`` keeps
        every token.

        Raises:
            ValueError: On an unknown strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the configured select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge the per-tile embeddings of one image into a token sequence.

        ``patch_embeddings[0]`` is the base tile; any remaining entries are
        anyres grid tiles. How they are merged depends on ``strategy``:
        - ``"flat"`` concatenates all tiles.
        - ``"spatial*"`` lays the grid tiles out spatially.
        - If the strategy also contains ``"unpad"``, the grid is unpadded and
          ``image_newline`` is appended after every row.

        Raises:
            ValueError: On an unknown strategy, or if the base tile's token
                count does not match the vision tower's patch grid.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                # -> (grid_h, grid_w, height, width, embed_dim)
                other_patch_embeds = other_patch_embeds \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # -> (embed_dim, grid_h * height, grid_w * width), then
                    # crop back to the original aspect ratio (the same shape
                    # `_get_llava_next_num_unpadded_features` counts).
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # Append one newline column, then flatten row-major to
                    # (rows * (cols + 1), embed_dim).
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Encode and project every tile of every image.

        Returns a tensor shaped ``(batch, num_patches, ...)`` for batched
        input. For list input (variable ``num_patches``) it returns one
        tensor per image.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Variable tile counts: run the tower once on all tiles, then split
        # the features back per image.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Turn validated image input into one embedding sequence per image.

        When ``image_sizes`` is missing, every image is assumed to be
        ``image_size x image_size`` of the vision config.
        """
        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            intermediate_tensors: Passed through to the language model
                backbone.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        See also:
            :class:`LlavaNextImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            # Overwrite the placeholder-token positions with image embeddings
            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        # Previously a literal `None` was passed here, silently discarding the
        # caller's `intermediate_tensors`.
        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into each sub-module by name prefix.

        Entries are routed with ``filter_weights`` to the ``vision_tower``,
        ``multi_modal_projector``, ``image_newline`` and ``language_model``
        components. NOTE(review): the ``name == ""`` assertion for
        ``image_newline`` suggests ``filter_weights`` strips the matched
        prefix; confirm in ``.utils``.
        """
        # prepare weight iterators for components
        # NOTE: `itertools.tee` buffers items consumed by one copy until the
        # others reach them, so most of the weight stream may be held in
        # memory while the earlier components load.
        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
            weights, 4)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load newline
        newline_weights = filter_weights(newline_weights, "image_newline")
        for name, loaded_weight in newline_weights:
            assert name == ""
            param = self.image_newline
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)