# siglip2navit.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

6
from collections.abc import Iterable
7
8
9
10
11

import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
12
from transformers import Siglip2VisionConfig
13
14
from transformers.configuration_utils import PretrainedConfig

15
from vllm.attention.backends.registry import AttentionBackendEnum
16
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
17
18
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
19
from vllm.model_executor.layers.conv import Conv2dLayer
20
21
22
23
24
25
26
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
27
from vllm.model_executor.layers.quantization import QuantizationConfig
28
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
29
from vllm.platforms import current_platform
30
31
32
33
34
35
36

from .vision import get_vit_attn_backend


class VisionRotaryEmbedding(nn.Module):
    """Table of rotary-embedding angles for integer positions.

    The module stores the inverse frequencies
    ``1 / theta ** (i / dim)`` for ``i`` in ``range(0, dim, 2)`` and, when
    called, returns the outer product of positions and those frequencies.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        """
        Args:
            dim: Rotary dimension; one frequency is created per even index
                in ``range(0, dim, 2)``.
            theta: Base of the geometric frequency progression.
        """
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        inv_freq = 1.0 / (theta**exponents)
        # Non-persistent: the buffer follows the module across devices but is
        # not written to / expected in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return angles of shape ``(seqlen, len(inv_freq))``.

        Row ``p`` holds ``p * inv_freq``; positions are created on the same
        device and with the same dtype as ``inv_freq``.
        """
        positions = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        return torch.outer(positions, self.inv_freq)


class Siglip2VisionEmbeddings(nn.Module):
    """Embed pre-patchified pixels and optionally add a learned 2-D
    position embedding that is resized to each item's patch grid."""

    def __init__(self, config: PretrainedConfig):
        """
        Args:
            config: Vision config. Reads ``hidden_size``, ``patch_size``,
                ``image_size``, ``num_patches``, ``preserve_original_pe``,
                ``hidden_stride`` and ``num_channels`` (``forward`` also
                reads ``temporal_patch_size`` on the conv path).
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.image_size = config.image_size
        self.num_patches = config.num_patches
        self.preserve_original_pe = config.preserve_original_pe
        self.hidden_stride = config.hidden_stride

        # siglip2 naflex
        if self.num_patches > 0:
            # Flattened patches are projected with a plain linear layer.
            self.patch_embedding = ReplicatedLinear(
                input_size=config.num_channels * self.patch_size * self.patch_size,
                output_size=self.embed_dim,
                return_bias=False,
            )
            if self.preserve_original_pe:
                # The learned table is treated as a square grid.
                self.position_embedding_size = int(self.num_patches**0.5)
                self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

        else:
            # One conv "pixel" per patch: kernel == stride == patch size.
            self.patch_embedding = Conv2dLayer(
                in_channels=config.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid",
            )
            if self.preserve_original_pe:
                self.num_patches = (self.image_size // self.patch_size) ** 2
                self.position_embedding_size = self.image_size // self.patch_size
                self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (
                    num_patches,
                    num_channels * temporal_patch_size * patch_size * patch_size
                )
            grid_thws: (`torch.LongTensor`):
                One ``(t, h, w)`` row per input item, in the same order as
                the patches; each row accounts for ``t * h * w`` consecutive
                patches. Required when ``preserve_original_pe`` is set.

        Returns:
            Patch embeddings of shape ``(num_patches, embed_dim)``.
        """

        # Apply patch embeddings to already patchified pixel values
        target_dtype = self.patch_embedding.weight.dtype
        if isinstance(self.patch_embedding, LinearBase):
            patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        elif isinstance(self.patch_embedding, Conv2dLayer):
            # Restore the (N, C, p, p) layout expected by the convolution.
            pixel_values = pixel_values.view(
                -1,
                self.config.num_channels * self.config.temporal_patch_size,
                self.patch_size,
                self.patch_size,
            )
            patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
            patch_embeds = patch_embeds.reshape(-1, self.embed_dim)

        if self.preserve_original_pe:
            assert grid_thws is not None
            stride = self.hidden_stride
            pos_embed_new = torch.zeros_like(patch_embeds)
            # (size * size, dim) -> (1, dim, size, size) for F.interpolate.
            positional_embeddings = (
                self.position_embedding.weight.reshape(
                    self.position_embedding_size, self.position_embedding_size, -1
                )
                .unsqueeze(0)
                .permute(0, 3, 1, 2)
            )
            offset = 0
            for t, h, w in grid_thws:
                volume = t * h * w
                # Resize the learned grid to this item's (h, w) patch grid.
                pe = F.interpolate(
                    positional_embeddings,
                    size=(h, w),
                    mode="bicubic",
                    align_corners=False,
                )
                pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
                # The same spatial embedding is reused for every frame.
                pe = pe[0].repeat(t, 1)
                # Reorder from row-major (h, w) to stride x stride blocks so
                # the embedding order matches patches grouped per block.
                pe = pe.reshape(
                    t,
                    h // stride,
                    stride,
                    w // stride,
                    stride,
                    -1,
                )
                pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1)
                pos_embed_new[offset : offset + volume] = pe
                offset += volume
            patch_embeds = patch_embeds + pos_embed_new

        return patch_embeds


# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
    """Rotate pairs of features in the last dimension by 90 degrees.

    Args:
        x: Tensor whose last dimension has even size.
        interleaved: If False, the pairs are ``(x[..., i], x[..., i + d/2])``
            and the result is ``cat(-second_half, first_half)``. If True, the
            pairs are adjacent elements ``(x[..., 2i], x[..., 2i + 1])`` and
            each becomes ``(-odd, even)`` in place.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        # Stack as (..., d, 2) and flatten so the pairs stay adjacent.
        return rearrange(
            torch.stack((-odd, even), dim=-1), "... d two -> ... (d two)", two=2
        )
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
160
161
162
163
164
165
166
167
168
169


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """Pure-PyTorch rotary embedding applied to the leading ``rotary_dim``
    features of ``x``; any remaining features pass through unchanged.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # Duplicate each angle to cover both members of a rotated pair and add a
    # singleton head axis so cos/sin broadcast over heads.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    rotary_part = x[..., :ro_dim]
    passthrough = x[..., ro_dim:]
    rotated = rotary_part * cos + rotate_half(rotary_part, interleaved) * sin
    return torch.cat([rotated, passthrough], dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_flash_attn_backend: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to queries and keys.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; only the first half of its last dimension is
            used (the two halves are expected to be duplicates).
        sin: Sine table, treated like ``cos``.
        is_flash_attn_backend: Use the ``flash_attn`` rotary kernel when
            True and the current platform is not XPU; otherwise use the
            PyTorch implementation in this file.

    Returns:
        ``(q_embed, k_embed)`` cast back to the dtypes of ``q`` and ``k``.
    """
    # Computation runs in float32; convert the shared tables only once.
    cos = cos.chunk(2, dim=-1)[0].contiguous().float()
    sin = sin.chunk(2, dim=-1)[0].contiguous().float()
    if is_flash_attn_backend and not current_platform.is_xpu():
        from flash_attn.layers.rotary import apply_rotary_emb

        apply_rotary_emb_func = apply_rotary_emb
    else:
        apply_rotary_emb_func = apply_rotary_emb_torch
    q_embed = apply_rotary_emb_func(q.float(), cos, sin).type_as(q)
    k_embed = apply_rotary_emb_func(k.float(), cos, sin).type_as(k)
    return q_embed, k_embed


class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper,
    operating on packed variable-length sequences."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        """
        Args:
            config: Vision config (``hidden_size``, ``num_attention_heads``,
                ``attention_dropout`` and ``use_rope`` are read).
            quant_config: Optional quantization config forwarded to the
                linear layers.
            prefix: Parameter-name prefix of this module.
            use_data_parallel: When True the per-partition head count is
                computed with a tensor-parallel size of 1.
            attn_backend_override: Optional attention backend to prefer.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
        self.use_rope = config.use_rope

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        # Anything other than the supported backends falls back to SDPA.
        supported_backends = {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.ROCM_AITER_FA,
        }
        if self.attn_backend not in supported_backends:
            self.attn_backend = AttentionBackendEnum.TORCH_SDPA
        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Attend within each packed sequence.

        Args:
            hidden_states: Packed tokens of shape ``(seq_length, embed_dim)``.
            cu_seqlens: Cumulative sequence boundaries; sequence ``i`` spans
                rows ``cu_seqlens[i]:cu_seqlens[i + 1]``.
            position_embeddings: ``(cos, sin)`` rotary tables; must be
                provided when ``config.use_rope`` is set.

        Returns:
            Tensor of shape ``(seq_length, embed_dim)`` after the output
            projection.
        """
        seq_length, _ = hidden_states.shape

        qkv_states, _ = self.qkv_proj(hidden_states)
        queries, keys, values = qkv_states.chunk(3, dim=-1)

        head_shape = (seq_length, self.num_heads_per_partition, self.head_dim)
        queries = queries.view(*head_shape)
        keys = keys.view(*head_shape)
        values = values.view(*head_shape)

        if self.use_rope:
            cos, sin = position_embeddings
            # The rotary helpers expect a leading batch dimension.
            queries, keys = apply_rotary_pos_emb(
                queries.unsqueeze(0),
                keys.unsqueeze(0),
                cos,
                sin,
                self.is_flash_attn_backend,
            )
            queries = queries.squeeze(0)
            keys = keys.squeeze(0)

        if self.is_flash_attn_backend:
            # Only the varlen kernel needs the longest sequence length, so
            # the host sync caused by ``.item()`` is confined to this path.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output = self.flash_attn_varlen_func(
                queries,
                keys,
                values,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
            ).reshape(seq_length, -1)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            boundaries = cu_seqlens.tolist()
            outputs = []
            for start_idx, end_idx in zip(boundaries[:-1], boundaries[1:]):
                # Each sequence is processed independently.
                # (seq_len, num_heads, head_dim) ->
                # (1, num_heads, seq_len, head_dim)
                q_i, k_i, v_i = (
                    x[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
                    for x in (queries, keys, values)
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
                output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
                outputs.append(output_i)

            attn_output = torch.cat(outputs, dim=0)
        else:
            raise RuntimeError(
                f"Unsupported attention backend for Siglip2Attention: "
                f"{self.attn_backend}"
            )
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class Siglip2MLP(nn.Module):
    """Two-layer feed-forward block: ``fc2(act(fc1(x)))``."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """
        Args:
            config: Vision config (``hidden_act``, ``hidden_size`` and
                ``intermediate_size`` are read).
            quant_config: Optional quantization config for both projections.
            prefix: Parameter-name prefix of this module.
            use_data_parallel: Accepted for interface parity with the other
                layers; not used here yet (see TODO below).
        """
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, apply the activation, and project back down."""
        # Both parallel linear layers return ``(output, bias)``.
        intermediate, _ = self.fc1(hidden_states)
        intermediate = self.activation_fn(intermediate)
        output, _ = self.fc2(intermediate)
        return output


class Siglip2EncoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each wrapped in
    a residual connection."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        """
        Args:
            config: Vision config (``hidden_size`` and ``layer_norm_eps``
                are read here; the rest is consumed by the sub-modules).
            quant_config: Optional quantization config for the sub-modules.
            prefix: Parameter-name prefix of this layer.
            use_data_parallel: Forwarded to attention and MLP.
            attn_backend_override: Forwarded to the attention module.
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Packed input tensor of shape (seq_len, embed_dim).
            cu_seqlens: Cumulative sequence lengths tensor.
            position_embeddings: ``(cos, sin)`` rotary tables passed through
                to the attention module.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        # Attention sub-block.
        attn_out = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        hidden_states = hidden_states + mlp_out
        return hidden_states


class Siglip2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers`
    self attention layers. Each layer is a [`Siglip2EncoderLayer`].

    Layers listed in ``config.fullatt_block_indexes`` attend over each whole
    frame; the remaining layers attend within local windows.

    Args:
        config: PretrainedConfig
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        """
        Args:
            config: Vision config (layer count, head geometry, patch size,
                ``hidden_stride``, ``window_size`` and
                ``fullatt_block_indexes`` are read).
            quant_config: Optional quantization config for the layers.
            prefix: Parameter-name prefix of this module.
            use_data_parallel: Forwarded to every layer.
            attn_backend_override: Forwarded to every layer.
        """
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Siglip2EncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for idx in range(config.num_hidden_layers)
            ]
        )

        # Half of the head dimension: height and width each get their own
        # block of rotary angles (see ``rot_pos_emb``).
        self.rotary_pos_emb = VisionRotaryEmbedding(
            config.hidden_size // config.num_attention_heads // 2
        )
        self.patch_size = config.patch_size
        self.hidden_stride = config.hidden_stride
        self.window_size = config.window_size
        self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
        if config.fullatt_block_indexes is None:
            self.fullatt_block_indexes = None
        else:
            # Stored in the config as a "|"-separated string, e.g. "7|15".
            self.fullatt_block_indexes = [
                int(i) for i in config.fullatt_block_indexes.split("|")
            ]

    def _to_merge_block_order(self, ids, h, w):
        """Flatten an (h, w) id grid so that each hidden_stride x
        hidden_stride block of patches is contiguous."""
        stride = self.hidden_stride
        ids = ids.reshape(h // stride, stride, w // stride, stride)
        return ids.permute(0, 2, 1, 3).flatten()

    # copied from qwen2.5_vl
    def rot_pos_emb(self, grid_thw):
        """Build per-patch rotary angles from (row, column) positions.

        Args:
            grid_thw: One ``(t, h, w)`` row per item.

        Returns:
            Tensor with one row per patch; the row concatenates the angles
            looked up for the patch's row index and column index.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = self._to_merge_block_order(hpos_ids, h, w)

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = self._to_merge_block_order(wpos_ids, h, w)

            # Every frame of the item shares the same spatial positions.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        """Compute the window-major ordering of merged patches.

        Args:
            grid_thw: One ``(t, h, w)`` row per item.

        Returns:
            ``(window_index, cu_window_seqlens)`` where ``window_index`` is a
            1-D tensor permuting merge units into window-major order and
            ``cu_window_seqlens`` is a Python list of cumulative window
            lengths measured in patches (it starts with 0 and may contain
            repeated values for windows that hold only padding).
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # patch (after merge) number in each window
        vit_merger_window_size = (
            self.window_size // self.hidden_stride // self.patch_size
        )

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.hidden_stride,  # number of patch after merge
                grid_w // self.hidden_stride,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            # Pad to a multiple of the window size with a -100 sentinel.
            # When a side is already a multiple, a full extra window of
            # padding is added; it contributes zero-length windows only.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Number of real (non-padding) merge units per window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Convert merge-unit counts to patch counts and continue the
            # running total from the previous item.
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Args:
            inputs_embeds: Packed patch embeddings of shape
                (sequence_length, hidden_size).
            grid_thws: Grid tensor with one (t, h, w) row per item,
                containing grid dimensions.

        Returns:
            Tensor of shape (sequence_length, hidden_size) in the same
            patch order as ``inputs_embeds``.
        """
        # Select dtype based on the following factors:
        #  - FA2 requires that cu_seqlens_q must have dtype int32
        #  - torch.onnx.export requires that cu_seqlens_q must have
        #    same dtype as grid_thw
        # See https://github.com/huggingface/transformers/pull/34852
        # for more information
        index_dtype = grid_thws.dtype if torch.jit.is_tracing() else torch.int32

        rotary_pos_emb = self.rot_pos_emb(grid_thws)
        window_index, cu_window_seqlens = self.get_window_index(grid_thws)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=inputs_embeds.device,
            dtype=index_dtype,
        )
        # Collapse repeated boundaries produced by padding-only windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = inputs_embeds.size()
        num_units = seq_len // self.spatial_merge_unit

        # Permute merge units into window-major order.
        inputs_embeds = inputs_embeds.reshape(num_units, self.spatial_merge_unit, -1)
        inputs_embeds = inputs_embeds[window_index, :, :]
        inputs_embeds = inputs_embeds.reshape(seq_len, -1)

        # Apply the same permutation to the rotary angles.
        rotary_pos_emb = rotary_pos_emb.reshape(
            num_units, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Full attention boundaries: one sequence of h * w patches per frame.
        cu_seqlens = torch.repeat_interleave(
            grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]
        ).cumsum(dim=0, dtype=index_dtype)
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

        reverse_indices = torch.argsort(window_index)

        hidden_states = inputs_embeds
        for index, block in enumerate(self.layers):
            # With no configured indexes every layer uses full attention.
            if not self.fullatt_block_indexes or index in self.fullatt_block_indexes:
                cu_seqlens_tmp = cu_seqlens
            else:
                cu_seqlens_tmp = cu_window_seqlens
            hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings)

        # Undo the window-major permutation.
        hidden_states = hidden_states.reshape(num_units, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)

        return hidden_states
628
629
630


class Siglip2VisionTransformer(nn.Module):
    """Vision tower: patch embeddings, transformer encoder, final
    layer norm."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        """
        Args:
            config: Vision config shared by all sub-modules.
            quant_config: Optional quantization config for the encoder.
            prefix: Parameter-name prefix of this module.
            use_data_parallel: Forwarded to the encoder.
            attn_backend_override: Forwarded to the encoder.
        """
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        r"""
        Args:
            pixel_values (`torch.FloatTensor`):
                Already patchified pixel values, one row per patch.
            grid_thws (`torch.LongTensor`):
                One (t, h, w) grid row per input item.

        Returns:
            Layer-normalized encoder output, one row per patch.
        """
        embeddings = self.embeddings(pixel_values, grid_thws)
        encoded = self.encoder(embeddings, grid_thws)
        return self.post_layernorm(encoded)


class Siglip2NavitModel(torch.nn.Module):
    """Top-level wrapper exposing the vision transformer as ``vision_model``
    and loading checkpoints with separate q/k/v projections."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        """
        Args:
            config: Vision config.
            quant_config: Optional quantization config.
            prefix: Parameter-name prefix of this module.
            use_data_parallel: Forwarded to the vision transformer.
            attn_backend_override: Forwarded to the vision transformer.
        """
        super().__init__()

        self.vision_model = Siglip2VisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        """Run the vision transformer on patchified pixels and their grids."""
        return self.vision_model(
            pixel_values=pixel_values,
            grid_thws=grid_thws,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint entries whose name contains ``q_proj`` / ``k_proj`` /
        ``v_proj`` are loaded as the matching shard of the fused
        ``qkv_proj`` parameter; all other entries are loaded by name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (after renaming) that were loaded.

        Raises:
            KeyError: If a (renamed) checkpoint name has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Redirect the separate projection to its fused parameter.
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load directly, using the parameter's
                # own loader when it defines one.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params