siglip2navit.py 23.7 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

6
from collections.abc import Iterable
7
8
9
10

import torch
from torch import nn
from torch.nn import functional as F
11
from transformers import Siglip2VisionConfig
12
13
from transformers.configuration_utils import PretrainedConfig

14
15
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
16
17
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
18
from vllm.model_executor.layers.conv import Conv2dLayer
19
20
21
22
23
24
25
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
28
29
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
30
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
31
from vllm.platforms import current_platform
32
33
34
35
36


class VisionRotaryEmbedding(nn.Module):
    """Rotary frequency table generator for vision position ids.

    Stores the inverse frequencies ``theta ** -(i / dim)`` for every even
    ``i`` in ``range(0, dim, 2)`` in a non-persistent buffer (it is therefore
    excluded from the state dict) and expands them into per-position angles
    on demand.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        """
        Args:
            dim: Rotary dimension; one inverse frequency is created for each
                even index below ``dim``.
            theta: Base of the geometric frequency progression.
        """
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the outer product of ``arange(seqlen)`` and ``inv_freq``.

        Args:
            seqlen: Number of positions to generate angles for.

        Returns:
            Tensor of shape ``(seqlen, len(inv_freq))`` on the same device and
            with the same dtype as the ``inv_freq`` buffer.
        """
        seq = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        return torch.outer(seq, self.inv_freq)


class Siglip2VisionEmbeddings(nn.Module):
    """Patch embedding plus optional interpolated learned position embedding.

    The patch projector is a `ReplicatedLinear` when `config.num_patches > 0`
    and a `Conv2dLayer` otherwise. When `config.preserve_original_pe` is set,
    a square learned position-embedding table is bicubically resized to each
    item's (h, w) grid and added to the patch embeddings.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.image_size = config.image_size
        self.num_patches = config.num_patches
        self.preserve_original_pe = config.preserve_original_pe
        self.hidden_stride = config.hidden_stride

        # siglip2 naflex
        if self.num_patches > 0:
            self.patch_embedding = ReplicatedLinear(
                input_size=config.num_channels * self.patch_size * self.patch_size,
                output_size=self.embed_dim,
                return_bias=False,
            )
            if self.preserve_original_pe:
                self.position_embedding_size = int(self.num_patches**0.5)
                self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
        else:
            self.patch_embedding = Conv2dLayer(
                in_channels=config.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid",
            )
            if self.preserve_original_pe:
                # Derive the table size from the nominal image resolution.
                grid_side = self.image_size // self.patch_size
                self.num_patches = grid_side**2
                self.position_embedding_size = grid_side
                self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Already patchified pixel values; the conv path views them as
                (-1, num_channels * temporal_patch_size, patch_size, patch_size).
            grid_thws: (`torch.LongTensor`):
                Rows of (t, h, w) grid sizes; required when
                `preserve_original_pe` is enabled.
                NOTE(review): presumably one row per image/video - confirm.

        Returns:
            Patch embeddings with `embed_dim` features per patch, with the
            interpolated position embeddings added when enabled.
        """
        embedder = self.patch_embedding
        target_dtype = embedder.weight.dtype
        if isinstance(embedder, LinearBase):
            patch_embeds = embedder(pixel_values.to(dtype=target_dtype))
        elif isinstance(embedder, Conv2dLayer):
            patches = pixel_values.view(
                -1,
                self.config.num_channels * self.config.temporal_patch_size,
                self.patch_size,
                self.patch_size,
            )
            conv_out = embedder(patches.to(dtype=target_dtype))
            patch_embeds = conv_out.reshape(-1, self.embed_dim)

        if not self.preserve_original_pe:
            return patch_embeds

        assert grid_thws is not None
        pos_embed_new = torch.zeros_like(patch_embeds)
        side = self.position_embedding_size
        stride = self.hidden_stride
        # (side * side, C) table -> (1, C, side, side) image for interpolation.
        pe_image = (
            self.position_embedding.weight.reshape(side, side, -1)
            .unsqueeze(0)
            .permute(0, 3, 1, 2)
        )

        offset = 0
        for t, h, w in grid_thws:
            volume = t * h * w
            resized = F.interpolate(
                pe_image,
                size=(h, w),
                mode="bicubic",
                align_corners=False,
            )
            # Back to channel-last, one row per spatial position, then tile
            # the same spatial embedding over the t temporal slices.
            per_frame = resized.permute(0, 2, 3, 1).reshape(1, h * w, -1)[0]
            tiled = per_frame.repeat(t, 1).reshape(
                t,
                h // stride,
                stride,
                w // stride,
                stride,
                -1,
            )
            # Reorder so every stride x stride block of positions is contiguous.
            block_major = tiled.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1)
            pos_embed_new[offset : offset + volume] = block_major
            offset += volume

        return patch_embeds + pos_embed_new


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_flash_attn_backend: bool,
    apply_rotary_emb: ApplyRotaryEmb,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys with the given cos/sin tables.

    Only the first half of the last dimension of `cos` and `sin` is used.
    The platform-specific kernel of `apply_rotary_emb` is chosen when the
    flash-attention backend is active on CUDA or ROCm; otherwise the native
    implementation is used.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table.
        sin: Sine table.
        is_flash_attn_backend: Whether the attention backend is flash-attn.
        apply_rotary_emb: Object exposing `forward_cuda`, `forward_hip` and
            `forward_native`.

    Returns:
        The rotated `(q, k)` pair.
    """
    half_cos = cos.chunk(2, dim=-1)[0].contiguous()
    half_sin = sin.chunk(2, dim=-1)[0].contiguous()

    if not is_flash_attn_backend:
        rotate = apply_rotary_emb.forward_native
    elif current_platform.is_cuda():
        rotate = apply_rotary_emb.forward_cuda
    elif current_platform.is_rocm():
        rotate = apply_rotary_emb.forward_hip
    else:
        rotate = apply_rotary_emb.forward_native

    q_rotated = rotate(q, half_cos, half_sin)
    k_rotated = rotate(k, half_cos, half_sin)
    return q_rotated, k_rotated


class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """
        Args:
            config: Vision config providing `hidden_size`,
                `num_attention_heads`, `attention_dropout` and `use_rope`.
            quant_config: Optional quantization config forwarded to the
                linear projections.
            multimodal_config: Optional multimodal config; when its
                `mm_encoder_tp_mode` equals "data" the layer runs without
                tensor parallelism.
            prefix: Parameter-name prefix for this module.
            use_data_parallel: Explicitly disable tensor parallelism for this
                layer. Combined (logical OR) with the mode derived from
                `multimodal_config`.

        Raises:
            ValueError: If `hidden_size` is not divisible by
                `num_attention_heads`.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Honor the explicit argument instead of silently overwriting it; the
        # default (False) keeps the config-driven behavior unchanged.
        config_requests_dp = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        use_data_parallel = use_data_parallel or config_requests_dp

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        # With tensor parallelism disabled every rank holds all heads.
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
        self.use_rope = config.use_rope

        self.attn = MMEncoderAttention(
            num_heads=self.num_heads_per_partition,
            head_size=self.head_dim,
            scale=self.scale,
            prefix=f"{prefix}.attn",
            multimodal_config=multimodal_config,
        )

        self.apply_rotary_emb = ApplyRotaryEmb(
            enforce_enable=True,
            enable_fp32_compute=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run variable-length self-attention over packed sequences.

        Args:
            hidden_states: Packed 2-D input of shape (seq_length, embed_dim).
            cu_seqlens: Cumulative sequence boundaries; consecutive
                differences give the individual sequence lengths.
            position_embeddings: `(cos, sin)` rotary tables; required when
                `config.use_rope` is enabled.

        Returns:
            Attention output after the output projection, one row per input
            row.

        Raises:
            ValueError: If rotary embeddings are enabled but
                `position_embeddings` is not provided.
        """
        # Unpacking also enforces that the input is 2-D (packed, no batch dim).
        seq_length, _ = hidden_states.shape

        qkv_states, _ = self.qkv_proj(hidden_states)
        queries, keys, values = qkv_states.chunk(3, dim=-1)

        queries = queries.view(seq_length, self.num_heads_per_partition, self.head_dim)
        keys = keys.view(seq_length, self.num_heads_per_partition, self.head_dim)
        values = values.view(seq_length, self.num_heads_per_partition, self.head_dim)

        if self.use_rope:
            if position_embeddings is None:
                raise ValueError(
                    "position_embeddings (cos, sin) must be provided when "
                    "config.use_rope is enabled."
                )
            cos, sin = position_embeddings
            # The rotary kernels are fed a leading batch dimension of 1.
            queries, keys = apply_rotary_pos_emb(
                queries.unsqueeze(0),
                keys.unsqueeze(0),
                cos,
                sin,
                self.attn.is_flash_attn_backend,
                self.apply_rotary_emb,
            )
            queries = queries.squeeze(0)
            keys = keys.squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        attn_output = self.attn(
            query=queries.unsqueeze(0),
            key=keys.unsqueeze(0),
            value=values.unsqueeze(0),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        attn_output = attn_output.reshape(
            seq_length, self.num_heads_per_partition * self.head_dim
        )

        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class Siglip2MLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Vision config providing `hidden_size`,
                `intermediate_size` and `hidden_act`.
            quant_config: Optional quantization config for both projections.
            multimodal_config: Optional multimodal config; tensor parallelism
                is disabled when its `mm_encoder_tp_mode` equals "data".
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.config = config

        use_data_parallel = (
            bool(multimodal_config) and multimodal_config.mm_encoder_tp_mode == "data"
        )
        # Keyword arguments shared by both projections.
        shared_kwargs = dict(quant_config=quant_config, disable_tp=use_data_parallel)

        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            prefix=f"{prefix}.fc1",
            **shared_kwargs,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            prefix=f"{prefix}.fc2",
            **shared_kwargs,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2 to `hidden_states`."""
        expanded, _ = self.fc1(hidden_states)
        projected, _ = self.fc2(self.activation_fn(expanded))
        return projected


class Siglip2EncoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each with a residual."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Vision config providing `hidden_size` and
                `layer_norm_eps`; also forwarded to the sub-modules.
            quant_config: Optional quantization config for the sub-modules.
            multimodal_config: Optional multimodal config for the sub-modules.
            prefix: Parameter-name prefix for this layer.
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Packed input tensor of shape (seq_len, embed_dim);
                the attention sub-module unpacks exactly two dimensions.
            cu_seqlens: Cumulative sequence lengths tensor.
            position_embeddings: `(cos, sin)` rotary tables passed through to
                the attention sub-module.

        Returns:
            The transformed hidden states as a single tensor.
        """
        # Attention sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class Siglip2Encoder(nn.Module):
    """
    Transformer encoder made of `config.num_hidden_layers`
    [`Siglip2EncoderLayer`] blocks, fed with 2D rotary position embeddings.
    Layers listed in `config.fullatt_block_indexes` (or all layers when it is
    unset) attend over whole frames; the others attend within windows.

    Args:
        config: PretrainedConfig
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        blocks = [
            Siglip2EncoderLayer(
                config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                prefix=f"{prefix}.layers.{idx}",
            )
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(blocks)

        # Half of the head dimension is rotated along each of the two axes.
        head_dim = config.hidden_size // config.num_attention_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.patch_size = config.patch_size
        self.hidden_stride = config.hidden_stride
        self.window_size = config.window_size
        self.spatial_merge_unit = config.hidden_stride * config.hidden_stride

        # `fullatt_block_indexes` is a "|"-separated string of layer indexes.
        raw_indexes = config.fullatt_block_indexes
        if raw_indexes is None:
            self.fullatt_block_indexes = None
        else:
            self.fullatt_block_indexes = [int(i) for i in raw_indexes.split("|")]

    # copied from qwen2.5_vl
    def rot_pos_emb(self, grid_thw):
        """Build rotary angles for every patch described by `grid_thw`.

        For each (t, h, w) row, (row, column) position ids are generated in
        stride-block-major order, repeated `t` times, and used to look up the
        rotary frequency table.
        """
        stride = self.hidden_stride

        def block_major(ids, h, w):
            # Regroup an (h, w) id map so each stride x stride block is contiguous.
            blocks = ids.reshape(h // stride, stride, w // stride, stride)
            return blocks.permute(0, 2, 1, 3).flatten()

        per_item_ids = []
        for t, h, w in grid_thw:
            row_ids = block_major(torch.arange(h).unsqueeze(1).expand(-1, w), h, w)
            col_ids = block_major(torch.arange(w).unsqueeze(0).expand(h, -1), h, w)
            per_item_ids.append(torch.stack([row_ids, col_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(per_item_ids, dim=0)
        # One table large enough for the biggest height/width in the batch.
        freq_table = self.rotary_pos_emb(grid_thw[:, 1:].max())
        return freq_table[pos_ids].flatten(1)

    def get_window_index(self, grid_thw):
        """Compute the window-major ordering of merged patches.

        Returns:
            A tuple `(window_index, cu_window_seqlens)`: a tensor of merged
            patch indexes in window order and a Python list of cumulative
            window lengths (in un-merged patches) starting at 0.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        index_base = 0
        # patch (after merge) number in each window
        win = self.window_size // self.hidden_stride // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            # number of patch after merge
            merged_h = grid_h // self.hidden_stride
            merged_w = grid_w // self.hidden_stride
            num_merged = grid_t * merged_h * merged_w
            index = torch.arange(num_merged).reshape(grid_t, merged_h, merged_w)

            pad_h = win - merged_h % win
            pad_w = win - merged_w % win
            num_windows_h = (merged_h + pad_h) // win
            num_windows_w = (merged_w + pad_w) // win

            # Pad with -100 so padded slots can be filtered out afterwards.
            padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            padded = padded.reshape(
                grid_t,
                num_windows_h,
                win,
                num_windows_w,
                win,
            )
            windows = padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                win,
                win,
            )

            seqlens = (windows != -100).sum([2, 3]).reshape(-1)
            flat = windows.reshape(-1)
            window_index.append(flat[flat != -100] + index_base)

            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            index_base += num_merged.item()

        return torch.cat(window_index, dim=0), cu_window_seqlens

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Args:
            inputs_embeds: Packed 2-D input embeddings; the sequence length
                must be divisible by `spatial_merge_unit`.
            grid_thws: Grid tensor whose rows are (t, h, w) grid dimensions.

        Returns:
            Hidden states in the original (pre-window-reordering) order.
        """
        rotary_pos_emb = self.rot_pos_emb(grid_thws)
        window_index, cu_window_seqlens = self.get_window_index(grid_thws)

        # Select dtype based on the following factors:
        #  - FA2 requires that cu_seqlens_q must have dtype int32
        #  - torch.onnx.export requires that cu_seqlens_q must have
        #    same dtype as grid_thw
        # See https://github.com/huggingface/transformers/pull/34852
        # for more information
        seqlens_dtype = grid_thws.dtype if torch.jit.is_tracing() else torch.int32

        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=inputs_embeds.device,
            dtype=seqlens_dtype,
        )
        # Empty (fully padded) windows produce repeated boundaries; drop them.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = inputs_embeds.size()
        unit = self.spatial_merge_unit
        num_groups = seq_len // unit

        def regroup(x, order):
            # Permute groups of `unit` consecutive rows according to `order`.
            return x.reshape(num_groups, unit, -1)[order, :, :].reshape(seq_len, -1)

        hidden_states = regroup(inputs_embeds, window_index)
        rotary_pos_emb = regroup(rotary_pos_emb, window_index)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Full-attention boundaries: one sequence of h * w patches per frame.
        frame_lengths = torch.repeat_interleave(
            grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]
        )
        cu_seqlens = frame_lengths.cumsum(dim=0, dtype=seqlens_dtype)
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

        reverse_indices = torch.argsort(window_index)

        for index, block in enumerate(self.layers):
            full_attention = (
                not self.fullatt_block_indexes or index in self.fullatt_block_indexes
            )
            boundaries = cu_seqlens if full_attention else cu_window_seqlens
            hidden_states = block(hidden_states, boundaries, position_embeddings)

        # Undo the window-major reordering.
        return regroup(hidden_states, reverse_indices)
567
568
569


class Siglip2VisionTransformer(nn.Module):
    """Embeddings -> encoder -> final layer norm."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Vision config shared by the embeddings and the encoder.
            quant_config: Optional quantization config for the encoder.
            multimodal_config: Optional multimodal config for the encoder.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.config = config

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.encoder",
        )
        self.post_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        r"""
        Args:
            pixel_values (`torch.FloatTensor`):
                Patchified pixel values passed to the embedding module.
            grid_thws (`torch.LongTensor`):
                Grid sizes passed to both the embeddings and the encoder.

        Returns:
            The encoder output after the final layer norm.
        """
        embedded = self.embeddings(pixel_values, grid_thws)
        encoded = self.encoder(embedded, grid_thws)
        return self.post_layernorm(encoded)


class Siglip2NavitModel(torch.nn.Module):
    """Top-level wrapper exposing the vision transformer and weight loading."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Vision config forwarded to the vision transformer.
            quant_config: Optional quantization config.
            multimodal_config: Optional multimodal config.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()

        self.vision_model = Siglip2VisionTransformer(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.vision_model",
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        """Delegate to the wrapped vision transformer."""
        return self.vision_model(
            pixel_values=pixel_values,
            grid_thws=grid_thws,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Weights whose name contains `q_proj`, `k_proj` or `v_proj` are
        redirected into the fused `qkv_proj` parameter with the matching
        shard id; all other weights are loaded by name.

        Args:
            weights: Iterable of `(name, tensor)` pairs.

        Returns:
            The set of (possibly remapped) parameter names that were loaded.

        Raises:
            KeyError: If a (remapped) name is not a parameter of this module.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # First mapping entry whose shard name occurs in the weight name.
            stacked = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if stacked is None:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        return loaded_params