# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

6
7
from collections.abc import Iterable
from typing import Optional
8
9
10
11
12

import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
13
from transformers import Siglip2VisionConfig
14
15
from transformers.configuration_utils import PretrainedConfig

16
from vllm.attention.backends.registry import _Backend
17
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
18
19
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
20
21
22
23
24
25
26
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
27
from vllm.model_executor.layers.quantization import QuantizationConfig
28
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
29
30
31
32
33
34
35

from .vision import get_vit_attn_backend


class VisionRotaryEmbedding(nn.Module):
    """Table of rotary-embedding angles for 1-D integer positions.

    ``forward(seqlen)`` returns a tensor whose entry ``[p, i]`` is
    ``p * inv_freq[i]`` with ``inv_freq[i] = theta ** (-(2 * i) / dim)``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        # Non-persistent: derived purely from `dim`/`theta`, so it is neither
        # saved to nor expected in a state dict, but it still moves with
        # `.to(device)`.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the angle table for positions ``0 .. seqlen - 1``.

        Args:
            seqlen: Number of positions. ``Siglip2Encoder.rot_pos_emb`` passes
                a 0-dim tensor here, which ``torch.arange`` also accepts.

        Returns:
            Tensor of shape ``(seqlen, ceil(dim / 2))`` with the dtype and
            device of ``inv_freq``.
        """
        seq = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class Siglip2VisionEmbeddings(nn.Module):
    """Patch embedding plus optional interpolated absolute position embedding.

    The patch-embedding variant is selected by ``config.num_patches``:

    * ``num_patches > 0`` (siglip2 naflex): pixel values arrive already
      patchified and flattened, and a ``ReplicatedLinear`` maps each patch to
      ``hidden_size``.
    * otherwise: an ``nn.Conv2d`` with ``kernel_size == stride == patch_size``
      is applied to each patch reshaped to an image tile.

    If ``config.preserve_original_pe`` is set, a learned square grid of
    position embeddings is bicubically resized to every image's ``(h, w)``
    patch grid and added to the patch embeddings.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.image_size = config.image_size
        self.num_patches = config.num_patches
        self.preserve_original_pe = config.preserve_original_pe
        self.hidden_stride = config.hidden_stride

        # siglip2 naflex
        if self.num_patches > 0:
            self.patch_embedding = ReplicatedLinear(
                input_size=config.num_channels * self.patch_size * self.patch_size,
                output_size=self.embed_dim,
                return_bias=False,
            )
            if self.preserve_original_pe:
                # The learned table is a square grid of num_patches entries.
                self.position_embedding_size = int(self.num_patches**0.5)
                self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

        else:
            # NOTE(review): in_channels is num_channels, while forward() views
            # the input with num_channels * temporal_patch_size channels, so
            # this path presumably assumes temporal_patch_size == 1 — confirm.
            self.patch_embedding = nn.Conv2d(
                in_channels=config.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid",
            )
            if self.preserve_original_pe:
                self.num_patches = (self.image_size // self.patch_size) ** 2
                self.position_embedding_size = self.image_size // self.patch_size
                self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Embed already-patchified pixel values.

        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (
                    num_patches,
                    num_channels * temporal_patch_size * patch_size * patch_size
                )
            grid_thws (`torch.LongTensor`, optional):
                Tensor of shape (num_images, 3) holding each image's
                (t, h, w) patch-grid size. Required when
                ``preserve_original_pe`` is set.

        Returns:
            Tensor of shape (num_patches, hidden_size).
        """
        # Apply patch embeddings to already patchified pixel values
        target_dtype = self.patch_embedding.weight.dtype
        if isinstance(self.patch_embedding, LinearBase):
            patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        elif isinstance(self.patch_embedding, nn.Conv2d):
            pixel_values = pixel_values.view(
                -1,
                self.config.num_channels * self.config.temporal_patch_size,
                self.patch_size,
                self.patch_size,
            )
            patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
            # (N, embed_dim, 1, 1) -> (N, embed_dim)
            patch_embeds = patch_embeds.reshape(-1, self.embed_dim)

        if self.preserve_original_pe:
            assert grid_thws is not None
            pos_embed_new = torch.zeros_like(patch_embeds)
            # (S*S, D) -> (1, D, S, S) so the grid can be resized spatially.
            positional_embeddings = (
                self.position_embedding.weight.reshape(
                    self.position_embedding_size, self.position_embedding_size, -1
                )
                .unsqueeze(0)
                .permute(0, 3, 1, 2)
            )
            cnt = 0
            # Convert once to Python ints: avoids a device sync per element and
            # gives F.interpolate / reshape plain integer sizes.
            for t, h, w in grid_thws.tolist():
                volume = t * h * w
                pe = F.interpolate(
                    positional_embeddings,
                    size=(h, w),
                    mode="bicubic",
                    align_corners=False,
                )
                # (1, D, h, w) -> (h * w, D), then tiled over the t frames.
                pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
                pe = pe[0].repeat(t, 1)
                # Reorder rows so every hidden_stride x hidden_stride block of
                # patches is contiguous, matching the token order of the input.
                pe = pe.reshape(
                    t,
                    h // self.hidden_stride,
                    self.hidden_stride,
                    w // self.hidden_stride,
                    self.hidden_stride,
                    -1,
                )
                pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1)
                pos_embed_new[cnt : cnt + volume] = pe
                cnt += volume
            patch_embeds = patch_embeds + pos_embed_new

        return patch_embeds


# Copied from flash_attn/layers/rotary.py (einops replaced by plain torch ops).
def rotate_half(x, interleaved=False):
    """Rotate feature pairs by 90 degrees: each pair ``(a, b)`` -> ``(-b, a)``.

    Args:
        x: Tensor whose last dimension holds the rotary features (expected to
            have even size).
        interleaved: If False, pairs are ``(x[i], x[i + d/2])`` (first half
            against second half). If True, pairs are adjacent elements
            ``(x[2i], x[2i + 1])``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x1, x2 = x[..., ::2], x[..., 1::2]
    # Same as einops "... d two -> ... (d two)": re-interleave the pairs.
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """Reference (pure PyTorch) rotary embedding.

    Rotates the first ``ro_dim = 2 * cos.shape[-1]`` features of ``x`` and
    passes the remaining features through unchanged.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    interleaved: must match the pairing convention of ``rotate_half``.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # Expand the half-size tables to ro_dim with the same layout rotate_half
    # uses (einops "... d -> ... (2 d)" resp. "... d -> ... (d 2)").
    if interleaved:
        cos = cos.repeat_interleave(2, dim=-1)
        sin = sin.repeat_interleave(2, dim=-1)
    else:
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
    # Insert a singleton head axis so the tables broadcast over nheads.
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    x_rot = x[..., :ro_dim]
    return torch.cat(
        [
            x_rot * cos + rotate_half(x_rot, interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_flash_attn_backend: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to queries and keys.

    Only the first half of ``cos``/``sin`` along the last dim is used (the
    encoder builds these tables by concatenating one frequency table with
    itself). The rotation is computed in float32 and cast back to the input
    dtypes.

    Args:
        q, k: Tensors laid out as (batch, seqlen, nheads, headdim).
        cos, sin: Rotary tables whose last dim is twice the rotated half.
        is_flash_attn_backend: Prefer flash-attn's fused rotary kernel. If the
            ``flash_attn`` package is not importable, the pure-torch reference
            implementation is used instead.

    Returns:
        ``(q_embed, k_embed)`` with the same dtypes as ``q`` and ``k``.
    """
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()

    apply_rotary_emb_func = apply_rotary_emb_torch
    if is_flash_attn_backend:
        # The attention backend may be vLLM's bundled FA or ROCm AITER, in
        # which case the upstream `flash_attn` package need not be installed.
        try:
            from flash_attn.layers.rotary import apply_rotary_emb
        except ImportError:
            pass
        else:
            apply_rotary_emb_func = apply_rotary_emb

    cos_f = cos.float()
    sin_f = sin.float()
    q_embed = apply_rotary_emb_func(q.float(), cos_f, sin_f).type_as(q)
    k_embed = apply_rotary_emb_func(k.float(), cos_f, sin_f).type_as(k)
    return q_embed, k_embed


class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Operates on a packed (unpadded) token sequence: ``cu_seqlens`` gives the
    cumulative boundaries of independent attention segments, and tokens only
    attend within their own segment.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # NOTE(review): per the TODO above the projections are always
        # tensor-parallel, so use_data_parallel=True with a TP world size > 1
        # presumably yields a head count here that does not match the sharded
        # qkv output — confirm before enabling data parallel.
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
        self.use_rope = config.use_rope

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.head_dim, dtype=torch.get_default_dtype()
        )
        self.use_upstream_fa = False

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
            )
        )

        # forward() only implements these backends; anything else uses SDPA.
        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.ROCM_AITER_FA,
        }:
            self.attn_backend = _Backend.TORCH_SDPA
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Self-attention within each ``cu_seqlens`` segment.

        Args:
            hidden_states: Packed tokens of shape (seq_len, embed_dim).
            cu_seqlens: 1-D cumulative segment boundaries, starting at 0 and
                ending at seq_len.
            position_embeddings: ``(cos, sin)`` rotary tables; required when
                ``config.use_rope`` is set.

        Returns:
            Tensor of shape (seq_len, embed_dim).
        """
        seq_length, _ = hidden_states.shape

        qkv_states, _ = self.qkv_proj(hidden_states)
        queries, keys, values = qkv_states.chunk(3, dim=-1)

        head_shape = (seq_length, self.num_heads_per_partition, self.head_dim)
        queries = queries.view(head_shape)
        keys = keys.view(head_shape)
        values = values.view(head_shape)

        if self.use_rope:
            if position_embeddings is None:
                raise ValueError(
                    "position_embeddings must be provided when use_rope is set"
                )
            cos, sin = position_embeddings
            # apply_rotary_pos_emb expects a leading batch axis.
            queries, keys = apply_rotary_pos_emb(
                queries.unsqueeze(0),
                keys.unsqueeze(0),
                cos,
                sin,
                self.is_flash_attn_backend,
            )
            queries = queries.squeeze(0)
            keys = keys.squeeze(0)

        if self.is_flash_attn_backend:
            # `.item()` forces a device sync; only the varlen kernel needs it.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output = self.flash_attn_varlen_func(
                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
            ).reshape(seq_length, -1)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []
            cu = cu_seqlens.tolist()
            for start_idx, end_idx in zip(cu[:-1], cu[1:]):
                # Each sequence is processed independently.
                # (seq_len_i, num_heads, head_dim) ->
                # (1, num_heads, seq_len_i, head_dim)
                q_i, k_i, v_i = (
                    t[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
                    for t in (queries, keys, values)
                )
                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                # (1, num_heads, seq_len_i, head_dim) -> (seq_len_i, embed_dim)
                output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1)
                outputs.append(output_i)

            attn_output = torch.cat(outputs, dim=0)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class Siglip2MLP(nn.Module):
    """Feed-forward block: fc1 -> activation -> fc2.

    fc1 is a ``ColumnParallelLinear`` (hidden -> intermediate) and fc2 a
    ``RowParallelLinear`` (intermediate -> hidden).
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        # `use_data_parallel` is accepted for signature parity with the other
        # layers but is currently unused (see TODO below).
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP; the last dim goes hidden -> intermediate -> hidden."""
        # The parallel linear layers return (output, bias); bias is discarded.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class Siglip2EncoderLayer(nn.Module):
    """Pre-norm transformer layer: x + attn(ln1(x)), then x + mlp(ln2(x))."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            use_data_parallel=use_data_parallel,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Packed input tokens of shape (seq_len, embed_dim).
            cu_seqlens: Cumulative sequence lengths delimiting the attention
                segments.
            position_embeddings: Rotary ``(cos, sin)`` tables forwarded to the
                attention layer.

        Returns:
            Tensor of shape (seq_len, embed_dim).
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class Siglip2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers`
    self attention layers. Each layer is a [`Siglip2EncoderLayer`].

    Tokens are processed as one packed sequence. Before the first layer they
    are permuted into window order (see ``get_window_index``). Layers listed
    in ``config.fullatt_block_indexes`` (or every layer, when that is None)
    attend across each whole frame; the others attend only within a window.
    The permutation is undone after the last layer.

    Args:
        config: PretrainedConfig
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Siglip2EncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{idx}",
                    use_data_parallel=use_data_parallel,
                )
                for idx in range(config.num_hidden_layers)
            ]
        )

        # Half of head_dim for rows and half for columns (2-D rotary).
        self.rotary_pos_emb = VisionRotaryEmbedding(
            config.hidden_size // config.num_attention_heads // 2
        )
        self.patch_size = config.patch_size
        self.hidden_stride = config.hidden_stride
        self.window_size = config.window_size
        self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
        # The config stores the full-attention layer indices as a
        # "|"-separated string.
        if config.fullatt_block_indexes is None:
            self.fullatt_block_indexes = None
        else:
            self.fullatt_block_indexes = [
                int(i) for i in config.fullatt_block_indexes.split("|")
            ]

    # copied from qwen2.5_vl
    def rot_pos_emb(self, grid_thw):
        """Build per-token 2-D rotary angles.

        For each image the (row, col) patch coordinates are ordered so that
        every ``hidden_stride x hidden_stride`` block of patches is
        contiguous, repeated for each of the ``t`` frames, and used to look up
        angles from ``self.rotary_pos_emb``; row and column angles are
        concatenated along the last dim.

        Args:
            grid_thw: Tensor of shape (num_images, 3) of (t, h, w) grid sizes.

        Returns:
            Tensor of shape (sum(t * h * w), 2 * F), where F is the number of
            frequencies produced by ``self.rotary_pos_emb``.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # One angle table large enough for the biggest row/column index.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        """Compute the window-order permutation and window boundaries.

        Works at the granularity of merged patches (``hidden_stride x
        hidden_stride`` groups). Each frame's merged-patch grid is padded up
        to whole windows of ``window_size // hidden_stride // patch_size``
        merged patches per side, then the indices are read out window by
        window with padding dropped.

        Args:
            grid_thw: Tensor of shape (num_images, 3) of (t, h, w) grid sizes.

        Returns:
            window_index: 1-D tensor; the merged-patch indices of all images
                in window order.
            cu_window_seqlens: List of cumulative window lengths in
                (unmerged) tokens, starting with 0. Windows that contain only
                padding repeat the previous boundary.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # patch (after merge) number in each window
        vit_merger_window_size = (
            self.window_size // self.hidden_stride // self.patch_size
        )

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.hidden_stride,  # number of patch after merge
                grid_w // self.hidden_stride,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            # -100 marks padding positions so they can be filtered out below.
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Convert merged-patch counts to token counts.
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        r"""Run all encoder layers over a packed sequence of patch embeddings.

        Args:
            inputs_embeds: Input tensor of shape (seq_len, hidden_size): the
                patch embeddings of all images concatenated along dim 0.
            grid_thws: Grid tensor of shape (num_images, 3) containing each
                image's (t, h, w) grid dimensions.

        Returns:
            Tensor of shape (seq_len, hidden_size), in the input token order.
        """
        rotary_pos_emb = self.rot_pos_emb(grid_thws)
        window_index, cu_window_seqlens = self.get_window_index(grid_thws)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=inputs_embeds.device,
            dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Padding-only windows repeat a boundary; collapse the duplicates.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Permute tokens into window order, moving whole groups of
        # spatial_merge_unit tokens so each merged patch stays together.
        seq_len, _ = inputs_embeds.size()
        inputs_embeds = inputs_embeds.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        inputs_embeds = inputs_embeds[window_index, :, :]
        inputs_embeds = inputs_embeds.reshape(seq_len, -1)
        # Apply the same permutation to the rotary angles.
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Full-attention segments: one per frame, h * w tokens each.
        cu_seqlens = torch.repeat_interleave(
            grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]
        ).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have
            #    same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852
            # for more information
            dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        reverse_indices = torch.argsort(window_index)

        hidden_states = inputs_embeds
        for index, block in enumerate(self.layers):
            if not self.fullatt_block_indexes or index in self.fullatt_block_indexes:
                cu_seqlens_tmp = cu_seqlens
            else:
                cu_seqlens_tmp = cu_window_seqlens
            hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings)

        # Undo the window permutation.
        hidden_states = hidden_states.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)

        return hidden_states


class Siglip2VisionTransformer(nn.Module):
    """Embeddings -> encoder -> final LayerNorm."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            use_data_parallel=use_data_parallel,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        r"""
        Args:
            pixel_values (`torch.FloatTensor`):
                Already-patchified pixel values, one row per patch.
            grid_thws (`torch.LongTensor` of shape `(num_images, 3)`):
                Per-image (t, h, w) patch-grid dimensions.

        Returns:
            Layer-normalized hidden states of shape (num_patches, hidden_size).
        """
        hidden_states = self.embeddings(pixel_values, grid_thws)
        last_hidden_state = self.encoder(hidden_states, grid_thws)
        last_hidden_state = self.post_layernorm(last_hidden_state)

        return last_hidden_state


class Siglip2NavitModel(torch.nn.Module):
    """Top-level Siglip2 NaViT vision model wrapper with weight loading."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()

        self.vision_model = Siglip2VisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
            use_data_parallel=use_data_parallel,
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        """Encode patchified pixel values; see Siglip2VisionTransformer."""
        return self.vision_model(
            pixel_values=pixel_values,
            grid_thws=grid_thws,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` checkpoint tensors are
        routed into the fused ``qkv_proj`` parameter through its
        ``weight_loader`` with the matching shard id. All other tensors use
        the parameter's own ``weight_loader`` if it has one, otherwise
        ``default_weight_loader``.

        Args:
            weights: Iterable of (checkpoint_name, tensor) pairs.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.

        Raises:
            KeyError: if a checkpoint name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: load the tensor directly.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params