siglip2navit.py 26.1 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

6
from collections.abc import Iterable
7
8
9
10
11

import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
12
from transformers import Siglip2VisionConfig
13
14
from transformers.configuration_utils import PretrainedConfig

15
from vllm.attention.backends.registry import _Backend
16
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
17
18
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
19
20
21
22
23
24
25
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
28
29
30
31
32
33
34

from .vision import get_vit_attn_backend


class VisionRotaryEmbedding(nn.Module):
    """Table of rotary angles ``position * inv_freq`` for 1-D positions.

    Args:
        dim: Rotary dimension; ``ceil(dim / 2)`` frequencies are generated.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer("inv_freq", 1.0 / (theta**exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return angles of shape ``(seqlen, len(inv_freq))`` for positions
        ``0 .. seqlen - 1``."""
        inv_freq = self.inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        # Broadcasted outer product: angles[i, j] = positions[i] * inv_freq[j].
        return positions[:, None] * inv_freq[None, :]


class Siglip2VisionEmbeddings(nn.Module):
    """Embed already-patchified pixels and optionally add position embeddings.

    When ``config.num_patches > 0`` (siglip2 naflex) patches are projected with
    a linear layer; otherwise a stride-``patch_size`` ``Conv2d`` is used. With
    ``config.preserve_original_pe`` a square learned position table is
    bicubically resized to each item's ``(h, w)`` grid and added.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.image_size = config.image_size
        self.num_patches = config.num_patches
        self.preserve_original_pe = config.preserve_original_pe
        self.hidden_stride = config.hidden_stride

        # siglip2 naflex
        linear_patches = self.num_patches > 0
        if linear_patches:
            # NOTE(review): input size omits temporal_patch_size, unlike the
            # Conv2d path's view in forward(); presumably it is 1 for naflex —
            # confirm.
            self.patch_embedding = ReplicatedLinear(
                input_size=config.num_channels * self.patch_size * self.patch_size,
                output_size=self.embed_dim,
                return_bias=False,
            )
        else:
            self.patch_embedding = nn.Conv2d(
                in_channels=config.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid",
            )

        if self.preserve_original_pe:
            if linear_patches:
                # The learned table is treated as a square grid of side sqrt(N).
                self.position_embedding_size = int(self.num_patches**0.5)
            else:
                grid_side = self.image_size // self.patch_size
                self.num_patches = grid_side**2
                self.position_embedding_size = grid_side
            self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (
                    num_patches,
                    num_channels * temporal_patch_size * patch_size * patch_size
                )
            grid_thws: (`torch.LongTensor`):
                One (t, h, w) patch-grid row per image/video; required when
                ``preserve_original_pe`` is set.

        Returns:
            Patch embeddings of shape (num_patches, embed_dim).
        """
        weight_dtype = self.patch_embedding.weight.dtype
        if isinstance(self.patch_embedding, LinearBase):
            patch_embeds = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        elif isinstance(self.patch_embedding, nn.Conv2d):
            patches = pixel_values.view(
                -1,
                self.config.num_channels * self.config.temporal_patch_size,
                self.patch_size,
                self.patch_size,
            )
            conv_out = self.patch_embedding(patches.to(dtype=weight_dtype))
            patch_embeds = conv_out.reshape(-1, self.embed_dim)

        if not self.preserve_original_pe:
            return patch_embeds

        assert grid_thws is not None
        side = self.position_embedding_size
        stride = self.hidden_stride
        # Lay the table out as a (1, embed_dim, side, side) image so it can be
        # resized with F.interpolate.
        base_table = (
            self.position_embedding.weight.reshape(side, side, -1)
            .unsqueeze(0)
            .permute(0, 3, 1, 2)
        )

        pos_embeds = torch.zeros_like(patch_embeds)
        offset = 0
        for t, h, w in grid_thws:
            n_tokens = t * h * w
            resized = F.interpolate(
                base_table,
                size=(h, w),
                mode="bicubic",
                align_corners=False,
            )
            # (1, D, h, w) -> (h * w, D), repeated for each of the t frames.
            per_frame = resized.permute(0, 2, 3, 1).reshape(1, h * w, -1)[0]
            tiled = per_frame.repeat(t, 1)
            # Reorder so each hidden_stride x hidden_stride block is contiguous,
            # matching the patch order of the incoming embeddings.
            blocked = tiled.reshape(
                t,
                h // stride,
                stride,
                w // stride,
                stride,
                -1,
            ).permute(0, 1, 3, 2, 4, 5)
            pos_embeds[offset : offset + n_tokens] = blocked.reshape(n_tokens, -1)
            offset += n_tokens
        return patch_embeds + pos_embeds


# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
    """Rotate feature pairs of ``x`` by 90 degrees: ``(a, b) -> (-b, a)``.

    With ``interleaved=False`` each pair is ``(x[i], x[i + d/2])`` (the two
    halves of the last dimension); with ``interleaved=True`` each pair is
    ``(x[2i], x[2i + 1])`` (adjacent even/odd elements).
    """
    if interleaved:
        evens, odds = x[..., ::2], x[..., 1::2]
        # Stack to (..., d, 2) and merge the last two axes so each rotated
        # pair lands back in its original even/odd slots.
        return torch.stack((-odds, evens), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """Apply rotary position embedding in pure PyTorch.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``rotary_dim = 2 * cos.shape[-1]`` features of ``x`` are
    rotated; any remaining features are passed through unchanged.
    """
    rotary_dim = 2 * cos.shape[-1]
    assert rotary_dim <= x.shape[-1]

    def _widen(table):
        # (..., d) -> (..., 1, 2d): duplicate the frequencies in the same
        # layout rotate_half pairs features, then add a singleton axis that
        # broadcasts over heads.
        if interleaved:
            doubled = table.repeat_interleave(2, dim=-1)
        else:
            doubled = torch.cat((table, table), dim=-1)
        return doubled.unsqueeze(-2)

    cos_full = _widen(cos)
    sin_full = _widen(sin)
    rotated = x[..., :rotary_dim]
    passthrough = x[..., rotary_dim:]
    return torch.cat(
        [
            rotated * cos_full + rotate_half(rotated, interleaved) * sin_full,
            passthrough,
        ],
        dim=-1,
    )


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_flash_attn_backend: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys with the given rotary tables.

    Only the first half of the last dimension of ``cos``/``sin`` is used; the
    rotary helpers expand it back to full width. The math runs in float32 and
    the results are cast back to the input dtypes of ``q`` and ``k``.

    Args:
        q, k: Query/key tensors passed straight to the rotary kernel.
        cos, sin: Rotary tables whose last dimension is split in half.
        is_flash_attn_backend: Use flash-attn's rotary kernel instead of the
            pure-PyTorch fallback.
    """
    cos_half = cos.chunk(2, dim=-1)[0].contiguous().float()
    sin_half = sin.chunk(2, dim=-1)[0].contiguous().float()

    if is_flash_attn_backend:
        # Imported lazily so flash-attn is only required on that backend.
        from flash_attn.layers.rotary import apply_rotary_emb as rotary_fn
    else:
        rotary_fn = apply_rotary_emb_torch

    q_embed = rotary_fn(q.float(), cos_half, sin_half).type_as(q)
    k_embed = rotary_fn(k.float(), cos_half, sin_half).type_as(k)
    return q_embed, k_embed


class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Operates on a packed batch: all sequences are concatenated along the first
    dimension of ``hidden_states`` and attention is computed independently
    within each ``cu_seqlens[i]:cu_seqlens[i + 1]`` span.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # NOTE(review): with use_data_parallel=True the head count below is not
        # divided by the TP world size, while (per the TODO above) the
        # projections are still TP-sharded; this looks consistent only when
        # the TP world size is 1 — confirm.
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
        self.use_rope = config.use_rope

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )
        # Anything other than the supported backends falls back to SDPA.
        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.ROCM_AITER_FA,
        }:
            self.attn_backend = _Backend.TORCH_SDPA
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run self-attention over a packed batch of sequences.

        Args:
            hidden_states: Packed tokens of shape ``(seq_length, embed_dim)``.
            cu_seqlens: 1-D cumulative boundaries; sequence ``i`` spans
                ``cu_seqlens[i]:cu_seqlens[i + 1]``.
            position_embeddings: ``(cos, sin)`` rotary tables; required when
                ``config.use_rope`` is set.

        Returns:
            Attention output after ``out_proj``, one row per input token.

        Raises:
            ValueError: if rotary embeddings are enabled but
                ``position_embeddings`` is not given.
        """
        seq_length, _ = hidden_states.shape

        qkv_states, _ = self.qkv_proj(hidden_states)
        queries, keys, values = qkv_states.chunk(3, dim=-1)
        queries = queries.view(seq_length, self.num_heads_per_partition, self.head_dim)
        keys = keys.view(seq_length, self.num_heads_per_partition, self.head_dim)
        values = values.view(seq_length, self.num_heads_per_partition, self.head_dim)

        if self.use_rope:
            if position_embeddings is None:
                raise ValueError(
                    "position_embeddings (cos, sin) are required when "
                    "use_rope is enabled"
                )
            cos, sin = position_embeddings
            # The rotary helpers expect a leading batch dimension.
            queries, keys = apply_rotary_pos_emb(
                queries.unsqueeze(0),
                keys.unsqueeze(0),
                cos,
                sin,
                self.is_flash_attn_backend,
            )
            queries = queries.squeeze(0)
            keys = keys.squeeze(0)

        if self.is_flash_attn_backend:
            # Only the varlen kernel needs the max length; computing it here
            # avoids an extra reduction and .item() host sync on the SDPA path.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output = self.flash_attn_varlen_func(
                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
            ).reshape(seq_length, -1)
        else:
            # TORCH_SDPA: __init__ guarantees this is the only other backend.
            # Execute attention entry by entry for speed & less VRAM.
            batch_size = cu_seqlens.shape[0] - 1
            outputs = []
            cu = cu_seqlens.tolist()
            for i in range(batch_size):
                start_idx = cu[i]
                end_idx = cu[i + 1]

                # Each sequence is processed independently.
                q_i = queries[start_idx:end_idx].unsqueeze(0)
                k_i = keys[start_idx:end_idx].unsqueeze(0)
                v_i = values[start_idx:end_idx].unsqueeze(0)

                # (1, seq_len, num_heads, head_dim) ->
                # (1, num_heads, seq_len, head_dim)
                q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]

                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                # (1, num_heads, seq_len, head_dim) -> (seq_len, num_heads * head_dim)
                output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
                outputs.append(output_i)

            attn_output = torch.cat(outputs, dim=0)

        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class Siglip2MLP(nn.Module):
    """Two-layer feed-forward block computing ``fc2(act(fc1(x)))``.

    ``fc1`` is column-parallel and ``fc2`` row-parallel. ``use_data_parallel``
    is accepted for interface parity but currently unused (see TODO).
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        width = config.hidden_size
        inner_width = config.intermediate_size
        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.fc1 = ColumnParallelLinear(
            width,
            inner_width,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            inner_width,
            width,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP; the parallel layers' bias outputs are discarded."""
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class Siglip2EncoderLayer(nn.Module):
    """Pre-norm transformer block.

    Computes ``x = x + attn(layer_norm1(x))`` followed by
    ``x = x + mlp(layer_norm2(x))``.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Packed tokens of shape (seq_len, embed_dim); all
                sequences are concatenated along the first dimension.
            cu_seqlens: Cumulative sequence boundaries used to restrict
                attention to each packed sequence (or window).
            position_embeddings: ``(cos, sin)`` rotary tables forwarded to the
                attention layer.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class Siglip2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers`
    self attention layers. Each layer is a [`Siglip2EncoderLayer`].

    Before the layers run, tokens (in groups of ``hidden_stride**2``) are
    reordered into attention windows; afterwards the original order is
    restored. Layers whose index appears in ``config.fullatt_block_indexes``
    (a ``"|"``-separated string) attend over each whole frame, while the other
    layers attend within windows. When that option is ``None`` or empty, every
    layer uses full per-frame attention.

    Args:
        config: PretrainedConfig
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Siglip2EncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{idx}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for idx in range(config.num_hidden_layers)
            ]
        )

        self.rotary_pos_emb = VisionRotaryEmbedding(
            config.hidden_size // config.num_attention_heads // 2
        )
        self.patch_size = config.patch_size
        self.hidden_stride = config.hidden_stride
        self.window_size = config.window_size
        self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
        if config.fullatt_block_indexes is None:
            self.fullatt_block_indexes = None
        else:
            # Skip empty fields so "" (or a trailing "|") does not crash in int().
            self.fullatt_block_indexes = [
                int(i) for i in config.fullatt_block_indexes.split("|") if i.strip()
            ]

    # copied from qwen2.5_vl
    def rot_pos_emb(self, grid_thw):
        """Return rotary angles for every patch, one row per patch.

        Patches are ordered so each ``hidden_stride x hidden_stride`` block is
        contiguous, and each row concatenates the angles for the patch's
        height and width coordinates. Every frame of an item reuses the same
        positions.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        """Compute the window permutation and window boundaries.

        Returns:
            window_index: 1-D tensor permuting merged tokens (groups of
                ``spatial_merge_unit`` patches) into window-major order.
            cu_window_seqlens: list of cumulative window lengths measured in
                un-merged patches, starting with 0. Padding-only windows add
                zero-length entries, so consecutive duplicates can occur.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # patch (after merge) number in each window
        vit_merger_window_size = (
            self.window_size // self.hidden_stride // self.patch_size
        )

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.hidden_stride,  # number of patch after merge
                grid_w // self.hidden_stride,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            # Pad the merged grid up to whole windows; -100 marks padding.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Args:
            inputs_embeds: Packed patch embeddings of shape
                (sequence_length, hidden_size), all items concatenated along
                the first dimension.
            grid_thws: Grid tensor with one (t, h, w) patch-grid row per
                image/video.

        Returns:
            Tensor of shape (sequence_length, hidden_size) in the same token
            order as ``inputs_embeds``.
        """
        # Select dtype based on the following factors:
        #  - FA2 requires that cu_seqlens_q must have dtype int32
        #  - torch.onnx.export requires that cu_seqlens_q must have
        #    same dtype as grid_thw
        # See https://github.com/huggingface/transformers/pull/34852
        # for more information
        cu_dtype = grid_thws.dtype if torch.jit.is_tracing() else torch.int32

        rotary_pos_emb = self.rot_pos_emb(grid_thws)
        window_index, cu_window_seqlens = self.get_window_index(grid_thws)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=inputs_embeds.device,
            dtype=cu_dtype,
        )
        # Padding-only windows produce repeated boundaries; drop them.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = inputs_embeds.size()
        num_groups = seq_len // self.spatial_merge_unit

        # Permute tokens (and their rotary angles) into window order, moving
        # each spatial_merge_unit group as a unit.
        inputs_embeds = inputs_embeds.reshape(num_groups, self.spatial_merge_unit, -1)
        inputs_embeds = inputs_embeds[window_index, :, :]
        inputs_embeds = inputs_embeds.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(num_groups, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Full-attention boundaries: one segment of h * w patches per frame.
        cu_seqlens = torch.repeat_interleave(
            grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]
        ).cumsum(dim=0, dtype=cu_dtype)
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

        reverse_indices = torch.argsort(window_index)

        hidden_states = inputs_embeds
        for index, block in enumerate(self.layers):
            use_full_attn = (
                not self.fullatt_block_indexes or index in self.fullatt_block_indexes
            )
            cu_seqlens_tmp = cu_seqlens if use_full_attn else cu_window_seqlens
            hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings)

        # Undo the window permutation.
        hidden_states = hidden_states.reshape(num_groups, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
        return hidden_states


class Siglip2VisionTransformer(nn.Module):
    """Embeddings -> windowed encoder -> final LayerNorm."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        r"""
        Args:
            pixel_values: Already-patchified pixel values, one row per patch,
                for all images/videos concatenated.
            grid_thws: One (t, h, w) patch-grid row per image/video.

        Returns:
            Post-LayerNorm hidden states, one row per patch.
        """
        hidden_states = self.embeddings(pixel_values, grid_thws)
        last_hidden_state = self.encoder(hidden_states, grid_thws)
        last_hidden_state = self.post_layernorm(last_hidden_state)
        return last_hidden_state


class Siglip2NavitModel(torch.nn.Module):
    """Top-level Siglip2 NaViT vision tower.

    Wraps a ``Siglip2VisionTransformer`` registered as ``vision_model`` and
    maps checkpoint weights onto it, fusing separate q/k/v projections into
    ``qkv_proj``.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: _Backend | None = None,
    ):
        super().__init__()
        self.vision_model = Siglip2VisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        """Encode patchified pixels; see ``Siglip2VisionTransformer.forward``."""
        return self.vision_model(pixel_values=pixel_values, grid_thws=grid_thws)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this module.

        Names containing ``q_proj``/``k_proj``/``v_proj`` are redirected to the
        matching shard of the fused ``qkv_proj`` parameter; all others are
        loaded directly (via the parameter's ``weight_loader`` if present).

        Returns:
            The set of parameter names (after renaming) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # First mapping entry whose shard name occurs in this weight name.
            fused = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if fused is None:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = fused
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params