siglip2navit.py 23.7 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

6
from collections.abc import Iterable
7
8
9
10

import torch
from torch import nn
from torch.nn import functional as F
11
from transformers import Siglip2VisionConfig
12
13
from transformers.configuration_utils import PretrainedConfig

14
15
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
16
17
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
18
from vllm.model_executor.layers.conv import Conv2dLayer
19
20
21
22
23
24
25
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
28
29
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
30
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
31
from vllm.platforms import current_platform
32
33
34
35
36


class VisionRotaryEmbedding(nn.Module):
    """Rotary-embedding angle table for vision tokens.

    Stores ``ceil(dim / 2)`` inverse frequencies ``theta ** (-2i / dim)`` as a
    non-persistent buffer. Calling the module with ``seqlen`` returns the
    outer product of positions ``0 .. seqlen - 1`` with those frequencies.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        # Non-persistent: fully derived from ``dim``/``theta``, never saved.
        self.register_buffer("inv_freq", 1.0 / (theta**exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return a ``(seqlen, len(inv_freq))`` table of position * frequency."""
        inv_freq = self.inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)


class Siglip2VisionEmbeddings(nn.Module):
    """Project already-patchified pixels to patch embeddings.

    Two projections are supported:

    * ``config.num_patches > 0`` (SigLIP2 NaFlex): a ``ReplicatedLinear`` over
      flattened patches of ``num_channels * patch_size**2`` values.
    * otherwise: a ``Conv2dLayer`` with ``kernel_size == stride == patch_size``.

    With ``config.preserve_original_pe`` a learned square position table is
    bicubically resized to every image's patch grid and added on top.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.image_size = config.image_size
        self.num_patches = config.num_patches
        self.preserve_original_pe = config.preserve_original_pe
        self.hidden_stride = config.hidden_stride

        naflex = self.num_patches > 0
        if naflex:
            # siglip2 naflex: patches arrive flattened, so use a linear layer.
            patch_dim = config.num_channels * self.patch_size * self.patch_size
            self.patch_embedding = ReplicatedLinear(
                input_size=patch_dim,
                output_size=self.embed_dim,
                return_bias=False,
            )
        else:
            self.patch_embedding = Conv2dLayer(
                in_channels=config.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid",
            )

        if not self.preserve_original_pe:
            return

        if naflex:
            self.position_embedding_size = int(self.num_patches**0.5)
        else:
            grid_side = self.image_size // self.patch_size
            self.num_patches = grid_side**2
            self.position_embedding_size = grid_side
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (
                    num_patches,
                    num_channels * temporal_patch_size * patch_size * patch_size
                )
                NOTE(review): the linear path's ``input_size`` is
                ``num_channels * patch_size**2`` and the conv path has
                ``in_channels=num_channels``, so both presumably assume
                ``temporal_patch_size == 1`` — confirm.
            grid_thws (`torch.LongTensor`):
                ``(num_images, 3)`` rows of ``(t, h, w)`` patch-grid sizes.
                Required when ``preserve_original_pe`` is set; ``h`` and ``w``
                must be divisible by ``hidden_stride``.

        Returns:
            ``(num_patches, embed_dim)`` patch embeddings.
        """
        weight_dtype = self.patch_embedding.weight.dtype
        if isinstance(self.patch_embedding, LinearBase):
            embeddings = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        elif isinstance(self.patch_embedding, Conv2dLayer):
            patches = pixel_values.view(
                -1,
                self.config.num_channels * self.config.temporal_patch_size,
                self.patch_size,
                self.patch_size,
            )
            embeddings = self.patch_embedding(patches.to(dtype=weight_dtype))
            embeddings = embeddings.reshape(-1, self.embed_dim)

        if not self.preserve_original_pe:
            return embeddings

        assert grid_thws is not None
        return embeddings + self._interpolated_position_embeddings(
            embeddings, grid_thws
        )

    def _interpolated_position_embeddings(
        self, like: torch.Tensor, grid_thws: torch.LongTensor
    ) -> torch.Tensor:
        """Build position embeddings aligned row-for-row with ``like``.

        For each ``(t, h, w)`` grid, the learned ``S x S`` table is resized to
        ``h x w`` (bicubic). It is then tiled over ``t`` frames and reordered
        so that every ``hidden_stride x hidden_stride`` block of tokens is
        contiguous, which is the same order ``Siglip2Encoder.rot_pos_emb`` uses.
        """
        side = self.position_embedding_size
        stride = self.hidden_stride
        # (S*S, D) table -> (1, D, S, S) "image" so it can be interpolated.
        table = self.position_embedding.weight.reshape(side, side, -1)
        table = table.unsqueeze(0).permute(0, 3, 1, 2)

        out = torch.zeros_like(like)
        offset = 0
        for t, h, w in grid_thws:
            n_tokens = t * h * w
            resized = F.interpolate(
                table,
                size=(h, w),
                mode="bicubic",
                align_corners=False,
            )
            # (1, D, h, w) -> (h*w, D), repeated once per temporal slice.
            flat = resized.permute(0, 2, 3, 1).reshape(1, h * w, -1)[0].repeat(t, 1)
            blocked = flat.reshape(t, h // stride, stride, w // stride, stride, -1)
            out[offset : offset + n_tokens] = blocked.permute(
                0, 1, 3, 2, 4, 5
            ).reshape(n_tokens, -1)
            offset += n_tokens
        return out


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_flash_attn_backend: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate ``q`` and ``k`` by the given rotary cosine/sine tables.

    Only the first half of the last dimension of ``cos``/``sin`` is passed on
    to the kernel. In this file, callers build those tables as a half-table
    concatenated with itself.

    The CUDA or HIP kernel is used only when ``is_flash_attn_backend`` is set
    and the platform matches; every other case uses the native
    implementation.

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    cos, sin = (t.chunk(2, dim=-1)[0].contiguous() for t in (cos, sin))

    rotary = ApplyRotaryEmb(
        enforce_enable=True,
        enable_fp32_compute=True,
    )

    if not is_flash_attn_backend:
        rotate = rotary.forward_native
    elif current_platform.is_cuda():
        rotate = rotary.forward_cuda
    elif current_platform.is_rocm():
        rotate = rotary.forward_hip
    else:
        rotate = rotary.forward_native

    q_rotated = rotate(q, cos, sin)
    k_rotated = rotate(k, cos, sin)
    return q_rotated, k_rotated


class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Works on packed, unbatched token sequences: ``hidden_states`` is 2-D
    ``(seq_len, embed_dim)`` and the sequence boundaries come from
    ``cu_seqlens``. Q/K/V come from one fused ``QKVParallelLinear``. When
    ``config.use_rope`` is set, rotary position embeddings are applied to
    Q and K.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """
        Args:
            config: Vision config (hidden_size, num_attention_heads,
                attention_dropout, use_rope).
            quant_config: Optional quantization config for the projections.
            multimodal_config: When its ``mm_encoder_tp_mode`` is ``"data"``,
                tensor parallelism is disabled for this layer.
            prefix: Parameter-name prefix for this module.
            use_data_parallel: Force replicated (data-parallel) projections
                even if ``multimodal_config`` does not request them.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Previously the ``use_data_parallel`` argument was silently
        # overwritten here. An explicit True from the caller is now honored;
        # the default (False) keeps the old config-driven behavior.
        config_wants_dp = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        use_data_parallel = use_data_parallel or config_wants_dp

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
        self.use_rope = config.use_rope

        self.attn = MMEncoderAttention(
            num_heads=self.num_heads_per_partition,
            head_size=self.head_dim,
            scale=self.scale,
            prefix=f"{prefix}.attn",
            multimodal_config=multimodal_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run self-attention over packed sequences.

        Args:
            hidden_states: ``(seq_len, embed_dim)`` tokens of all sequences
                concatenated along the first dimension.
            cu_seqlens: Cumulative sequence lengths delimiting the packed
                sequences; forwarded to ``MMEncoderAttention``.
            position_embeddings: ``(cos, sin)`` rotary tables. Required when
                ``config.use_rope`` is set.

        Returns:
            ``(seq_len, embed_dim)`` output of ``out_proj``.

        Raises:
            ValueError: If rope is enabled and ``position_embeddings`` is None.
        """
        # Unpacking also enforces the 2-D packed layout.
        seq_length, _ = hidden_states.shape

        qkv_states, _ = self.qkv_proj(hidden_states)
        queries, keys, values = qkv_states.chunk(3, dim=-1)

        queries = queries.view(seq_length, self.num_heads_per_partition, self.head_dim)
        keys = keys.view(seq_length, self.num_heads_per_partition, self.head_dim)
        values = values.view(seq_length, self.num_heads_per_partition, self.head_dim)

        if self.use_rope:
            if position_embeddings is None:
                raise ValueError(
                    "position_embeddings must be provided when use_rope is enabled"
                )
            cos, sin = position_embeddings
            # The rotary helper works on a leading batch dim of size 1.
            queries, keys = apply_rotary_pos_emb(
                queries.unsqueeze(0),
                keys.unsqueeze(0),
                cos,
                sin,
                self.attn.is_flash_attn_backend,
            )
            queries = queries.squeeze(0)
            keys = keys.squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        attn_output = self.attn(
            query=queries.unsqueeze(0),
            key=keys.unsqueeze(0),
            value=values.unsqueeze(0),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        attn_output = attn_output.reshape(
            seq_length, self.num_heads_per_partition * self.head_dim
        )

        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class Siglip2MLP(nn.Module):
    """Feed-forward block: ``fc2(activation(fc1(x)))``.

    ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a ``RowParallelLinear``.
    Tensor parallelism is disabled for both (``disable_tp``) when the
    multimodal config sets ``mm_encoder_tp_mode == "data"``.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        if multimodal_config:
            data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        else:
            data_parallel = False

        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=data_parallel,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation, then fc2 (biases are discarded)."""
        projected, _ = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        output, _ = self.fc2(activated)
        return output


class Siglip2EncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``x = x + self_attn(layer_norm1(x))`` and then
    ``x = x + mlp(layer_norm2(x))``.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Packed input tokens of shape ``(seq_len, embed_dim)``.
            cu_seqlens: Cumulative sequence lengths delimiting the packed
                sequences for attention.
            position_embeddings: ``(cos, sin)`` rotary tables forwarded to the
                attention layer.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class Siglip2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers`
    self attention layers. Each layer is a [`Siglip2EncoderLayer`].

    Before the layers run, tokens are regrouped window by window (see
    ``get_window_index``); afterwards the original token order is restored.
    Full attention within each frame is used for every layer whose index is
    listed in ``config.fullatt_block_indexes`` (a ``"|"``-separated string).
    When that field is unset, every layer uses full attention. All other
    layers attend only within a single window.

    Args:
        config: PretrainedConfig
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Siglip2EncoderLayer(
                    config,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.layers.{idx}",
                )
                for idx in range(config.num_hidden_layers)
            ]
        )

        # Row and column angles each use ceil(head_dim / 4) frequencies.
        # rot_pos_emb concatenates them (about head_dim / 2 angles per token),
        # and forward duplicates that table to reach head_dim.
        self.rotary_pos_emb = VisionRotaryEmbedding(
            config.hidden_size // config.num_attention_heads // 2
        )
        self.patch_size = config.patch_size
        self.hidden_stride = config.hidden_stride
        self.window_size = config.window_size
        # Patches per hidden_stride x hidden_stride group. Groups stay
        # contiguous and move as a unit when tokens are windowed.
        self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
        if config.fullatt_block_indexes is None:
            self.fullatt_block_indexes = None
        else:
            self.fullatt_block_indexes = [
                int(i) for i in config.fullatt_block_indexes.split("|")
            ]

    # copied from qwen2.5_vl
    def rot_pos_emb(self, grid_thw):
        """Return per-token rotary angles for every ``(t, h, w)`` grid.

        Tokens follow the same order as the patch stream: each
        ``hidden_stride x hidden_stride`` block is contiguous, and the frame
        layout repeats ``t`` times. Each row holds the row-position angles
        followed by the column-position angles.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        """Compute the window permutation and window boundaries.

        Returns:
            window_index: Permutation over merged tokens (groups of
                ``spatial_merge_unit`` patches). It places the tokens of each
                ``vit_merger_window_size``-square window next to each other,
                frame by frame, across all grids.
            cu_window_seqlens: Python list of cumulative window lengths,
                measured in patches and starting with 0. The padding always
                adds at least one row/column, so fully padded windows add
                zero-length entries (repeated values); the caller removes
                these with ``unique_consecutive``.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # patch (after merge) number in each window
        vit_merger_window_size = (
            self.window_size // self.hidden_stride // self.patch_size
        )

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.hidden_stride,  # number of patch after merge
                grid_w // self.hidden_stride,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            # -100 marks padding slots; they are filtered out below.
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Args:
            inputs_embeds: Patch embeddings of shape ``(seq_len, hidden_size)``,
                with all grids packed along the first dimension. ``seq_len``
                must be a multiple of ``hidden_stride ** 2``.
            grid_thws: ``(num_grids, 3)`` tensor of ``(t, h, w)`` patch-grid
                sizes, one row per image (or video).

        Returns:
            ``(seq_len, hidden_size)`` hidden states in the original token
            order.
        """
        rotary_pos_emb = self.rot_pos_emb(grid_thws)
        window_index, cu_window_seqlens = self.get_window_index(grid_thws)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=inputs_embeds.device,
            dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop the repeated boundaries produced by fully padded windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Permute whole merge groups (not single patches) into window order.
        seq_len, _ = inputs_embeds.size()
        inputs_embeds = inputs_embeds.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        inputs_embeds = inputs_embeds[window_index, :, :]
        inputs_embeds = inputs_embeds.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # One full-attention segment of h * w tokens per frame.
        cu_seqlens = torch.repeat_interleave(
            grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]
        ).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have
            #    same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852
            # for more information
            dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

        reverse_indices = torch.argsort(window_index)

        hidden_states = inputs_embeds
        for index, block in enumerate(self.layers):
            if not self.fullatt_block_indexes or index in self.fullatt_block_indexes:
                cu_seqlens_tmp = cu_seqlens
            else:
                cu_seqlens_tmp = cu_window_seqlens
            hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings)

        # Undo the window permutation, again at merge-group granularity.
        hidden_states = hidden_states.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)

        return hidden_states


class Siglip2VisionTransformer(nn.Module):
    """Runs the patch embeddings, then the SigLIP2 encoder, then a final LayerNorm."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.encoder",
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        r"""
        Args:
            pixel_values: Already-patchified pixel values, one row per patch,
                passed to ``Siglip2VisionEmbeddings``.
            grid_thws: ``(num_grids, 3)`` tensor of ``(t, h, w)`` patch-grid
                sizes, one row per image (or video).

        Returns:
            Post-LayerNorm hidden states of shape
            ``(num_patches, hidden_size)``.
        """
        hidden_states = self.embeddings(pixel_values, grid_thws)
        last_hidden_state = self.encoder(hidden_states, grid_thws)
        last_hidden_state = self.post_layernorm(last_hidden_state)

        return last_hidden_state


class Siglip2NavitModel(torch.nn.Module):
    """Top-level wrapper: holds the vision tower and loads its checkpoint."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.vision_model = Siglip2VisionTransformer(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.vision_model",
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        """Delegate directly to ``self.vision_model``."""
        return self.vision_model(pixel_values=pixel_values, grid_thws=grid_thws)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        A tensor whose name contains ``q_proj``, ``k_proj`` or ``v_proj`` (the
        first match in that order) is renamed to ``qkv_proj``. It is then
        loaded into that fused parameter as the ``"q"``, ``"k"`` or ``"v"``
        shard via the parameter's ``weight_loader``. Any other tensor is
        loaded by the parameter's ``weight_loader`` if it has one, and by
        ``default_weight_loader`` otherwise. An unknown name raises
        ``KeyError``.

        Returns:
            The (possibly renamed) parameter names that were loaded.
        """
        shard_routes = (
            # (fused param name, checkpoint shard name, shard id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        params = dict(self.named_parameters())
        loaded: set[str] = set()

        for name, tensor in weights:
            route = next(
                (
                    (fused_name, shard_name, shard_id)
                    for fused_name, shard_name, shard_id in shard_routes
                    if shard_name in name
                ),
                None,
            )
            if route is None:
                param = params[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, tensor)
            else:
                fused_name, shard_name, shard_id = route
                name = name.replace(shard_name, fused_name)
                param = params[name]
                param.weight_loader(param, tensor, shard_id)
            loaded.add(name)
        return loaded