siglip2navit.py 26.9 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

6
7
from collections.abc import Iterable
from typing import Optional
8
9
10
11
12

import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
13
from transformers import Siglip2VisionConfig
14
15
from transformers.configuration_utils import PretrainedConfig

16
17
18
19
20
21
22
23
from vllm.config import QuantizationConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase, QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.platforms import _Backend

from .vision import get_vit_attn_backend


class VisionRotaryEmbedding(nn.Module):
    """Produce rotary-embedding angle tables for a range of positions.

    The inverse frequencies ``theta ** -(2i / dim)`` for ``i`` in
    ``range(dim // 2)`` are computed once at construction time and kept in a
    non-persistent buffer, so they are not written to the state dict.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq",
                             1.0 / (theta**exponents),
                             persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the outer product of positions ``0..seqlen-1`` with the
        inverse frequencies, i.e. a ``(seqlen, len(inv_freq))`` tensor."""
        inv_freq = self.inv_freq
        positions = torch.arange(seqlen,
                                 device=inv_freq.device,
                                 dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)


class Siglip2VisionEmbeddings(nn.Module):
    """Embed pre-patchified pixels and optionally add position embeddings.

    When ``config.num_patches > 0`` the patch projection is a
    ``ReplicatedLinear`` over flattened patches; otherwise it is an
    ``nn.Conv2d`` whose kernel and stride both equal the patch size.
    When ``config.preserve_original_pe`` is set, a learned square grid of
    position embeddings is created and, in ``forward``, bicubically
    resized to every input grid before being added to the patch embeddings.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.image_size = config.image_size
        self.num_patches = config.num_patches
        self.preserve_original_pe = config.preserve_original_pe
        self.hidden_stride = config.hidden_stride

        # siglip2 naflex
        if self.num_patches > 0:
            self.patch_embedding = ReplicatedLinear(
                input_size=config.num_channels * self.patch_size *
                self.patch_size,
                output_size=self.embed_dim,
                return_bias=False,
            )
            if self.preserve_original_pe:
                # The table is treated as a square grid with this side.
                self.position_embedding_size = int(self.num_patches**0.5)
                self.position_embedding = nn.Embedding(self.num_patches,
                                                       self.embed_dim)

        else:
            self.patch_embedding = nn.Conv2d(
                in_channels=config.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid",
            )
            if self.preserve_original_pe:
                self.num_patches = (self.image_size // self.patch_size)**2
                self.position_embedding_size = (self.image_size //
                                                self.patch_size)
                self.position_embedding = nn.Embedding(self.num_patches,
                                                       self.embed_dim)

    def forward(self,
                pixel_values: torch.FloatTensor,
                grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Already patchified pixel values, one row per patch. The
                linear projection consumes rows as they are; the conv
                projection first views them as
                (-1, num_channels * temporal_patch_size,
                 patch_size, patch_size).
            grid_thws: (`torch.LongTensor`):
                One (t, h, w) row per input item, with t * h * w being the
                number of rows of `pixel_values` that belong to that item.
                Required when `preserve_original_pe` is enabled.

        Returns:
            Patch embeddings with `embed_dim` features per patch, with the
            interpolated position embeddings added when enabled.
        """

        # Apply patch embeddings to already patchified pixel values
        target_dtype = self.patch_embedding.weight.dtype
        if isinstance(self.patch_embedding, LinearBase):
            patch_embeds = self.patch_embedding(
                pixel_values.to(dtype=target_dtype))
        elif isinstance(self.patch_embedding, nn.Conv2d):
            pixel_values = pixel_values.view(
                -1, self.config.num_channels * self.config.temporal_patch_size,
                self.patch_size, self.patch_size)
            patch_embeds = self.patch_embedding(
                pixel_values.to(dtype=target_dtype))
            patch_embeds = patch_embeds.reshape(-1, self.embed_dim)
        else:
            # Fail with a clear message instead of an unbound local below.
            raise TypeError("Unsupported patch_embedding type: "
                            f"{type(self.patch_embedding).__name__}")

        if self.preserve_original_pe:
            assert grid_thws is not None, (
                "grid_thws is required when preserve_original_pe is set")
            pos_embed_new = torch.zeros_like(patch_embeds)
            # (size * size, dim) table -> (1, dim, size, size) image layout
            # expected by F.interpolate.
            positional_embeddings = self.position_embedding.weight.reshape(
                self.position_embedding_size, self.position_embedding_size,
                -1).unsqueeze(0).permute(0, 3, 1, 2)
            stride = self.hidden_stride
            # Convert the grid to Python ints once so that sizes and slice
            # bounds below are plain integers rather than 0-d tensors.
            if isinstance(grid_thws, torch.Tensor):
                grids = grid_thws.tolist()
            else:
                grids = grid_thws
            cnt = 0
            for t, h, w in grids:
                volume = t * h * w
                pe = F.interpolate(positional_embeddings,
                                   size=(h, w),
                                   mode='bicubic',
                                   align_corners=False)
                # Back to one row per (row-major) grid position, then the
                # same grid is reused for each of the t temporal slices.
                pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
                pe = pe[0].repeat(t, 1)
                # Reorder so that each stride x stride tile of the grid is
                # contiguous (presumably the order in which the patches
                # arrive -- verify against the image processor).
                pe = pe.reshape(t, h // stride, stride, w // stride, stride,
                                -1)
                pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1)
                pos_embed_new[cnt:cnt + volume] = pe
                cnt += volume
            patch_embeds = patch_embeds + pos_embed_new

        return patch_embeds


# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
    """Rotate feature pairs of ``x`` along its last dimension.

    With ``interleaved=False`` the last dimension is split into two halves
    ``(a, b)`` and ``(-b, a)`` is returned. With ``interleaved=True`` the
    even-indexed features act as ``a`` and the odd-indexed ones as ``b``,
    and the rotated pairs are interleaved again in the output.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        paired = torch.stack((-odd, even), dim=-1)
        return rearrange(paired, "... d two -> ... (d two)", two=2)

    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    Pure-PyTorch rotary embedding.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``rotary_dim`` features of the last dimension of ``x``
    are rotated; any remaining features are passed through unchanged.
    """
    ro_dim = 2 * cos.shape[-1]
    assert ro_dim <= x.shape[-1]

    # Widen cos/sin from rotary_dim / 2 to rotary_dim and insert a singleton
    # axis before the feature axis so they broadcast against ``x``.
    if interleaved:
        pattern = "... d -> ... 1 (d 2)"
    else:
        pattern = "... d -> ... 1 (2 d)"
    cos_full = repeat(cos, pattern)
    sin_full = repeat(sin, pattern)

    x_rot, x_pass = x[..., :ro_dim], x[..., ro_dim:]
    rotated = x_rot * cos_full + rotate_half(x_rot, interleaved) * sin_full
    return torch.cat([rotated, x_pass], dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_flash_attn_backend: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to queries and keys.

    Only the first half of the last dimension of ``cos`` / ``sin`` is used.
    The rotation is computed in float32 and each result is cast back to the
    dtype of its input.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table.
        sin: Sine table.
        is_flash_attn_backend: Use ``flash_attn``'s rotary kernel instead of
            the PyTorch implementation in this file.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos_half = cos.chunk(2, dim=-1)[0].contiguous()
    sin_half = sin.chunk(2, dim=-1)[0].contiguous()

    if is_flash_attn_backend:
        # Imported lazily so flash_attn is only needed when it is selected.
        from flash_attn.layers.rotary import apply_rotary_emb as rotary_fn
    else:
        rotary_fn = apply_rotary_emb_torch

    cos_f32 = cos_half.float()
    sin_f32 = sin_half.float()
    q_embed = rotary_fn(q.float(), cos_f32, sin_f32).type_as(q)
    k_embed = rotary_fn(k.float(), cos_f32, sin_f32).type_as(k)
    return q_embed, k_embed


class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Operates on packed sequences: all tokens are stacked along the first
    dimension and `cu_seqlens` marks where each sequence starts and ends.
    Attention runs through a flash-attention varlen kernel when the
    detected backend is FLASH_ATTN or ROCM_AITER_FA, and through
    `F.scaled_dot_product_attention` one sequence at a time otherwise.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        # NOTE: scale, dropout and is_causal are stored but not read by
        # forward() below.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = (1 if use_data_parallel else
                        get_tensor_model_parallel_world_size())
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
        self.use_rope = config.use_rope

        # Detect attention implementation.
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
        # Anything other than the supported backends falls back to SDPA.
        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
                _Backend.ROCM_AITER_FA
        }:
            self.attn_backend = _Backend.TORCH_SDPA
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor,
                                            torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run self-attention over packed sequences.

        Args:
            hidden_states: Packed tokens of shape (seq_length, embed_dim).
            cu_seqlens: Cumulative sequence boundaries; sequence `i` covers
                rows `cu_seqlens[i]:cu_seqlens[i + 1]`.
            position_embeddings: `(cos, sin)` pair; must be provided when
                `config.use_rope` is enabled.

        Returns:
            The output of `out_proj`, one row per input token.
        """
        seq_length, _ = hidden_states.shape

        qkv_states, _ = self.qkv_proj(hidden_states)
        queries, keys, values = qkv_states.chunk(3, dim=-1)

        queries = queries.view(seq_length, self.num_heads_per_partition,
                               self.head_dim)
        keys = keys.view(seq_length, self.num_heads_per_partition,
                         self.head_dim)
        values = values.view(seq_length, self.num_heads_per_partition,
                             self.head_dim)

        if self.use_rope:
            cos, sin = position_embeddings
            # The rotary helper is called with a leading batch axis of 1.
            queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0),
                                                 keys.unsqueeze(0), cos, sin,
                                                 self.is_flash_attn_backend)
            queries = queries.squeeze(0)
            keys = keys.squeeze(0)

        if self.is_flash_attn_backend:
            if self.attn_backend == _Backend.ROCM_AITER_FA:
                from aiter import flash_attn_varlen_func
            else:
                from flash_attn import flash_attn_varlen_func
            # Only the varlen kernel needs max_seqlen, so the host sync
            # from .item() is confined to this branch.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            attn_output = flash_attn_varlen_func(
                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
                max_seqlen).reshape(seq_length, -1)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []
            bounds = cu_seqlens.tolist()
            for start_idx, end_idx in zip(bounds[:-1], bounds[1:]):
                # Each sequence is processed independently.
                # (seq_len, num_heads, head_dim) ->
                # (1, num_heads, seq_len, head_dim)
                q_i, k_i, v_i = (x[start_idx:end_idx].unsqueeze(0).transpose(
                    1, 2) for x in (queries, keys, values))

                output_i = F.scaled_dot_product_attention(q_i,
                                                          k_i,
                                                          v_i,
                                                          dropout_p=0.0)
                # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
                output_i = output_i.transpose(1, 2).reshape(
                    end_idx - start_idx, -1)
                outputs.append(output_i)

            attn_output = torch.cat(outputs, dim=0)
        else:
            raise RuntimeError(
                f"Unsupported attention backend: {self.attn_backend}")

        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class Siglip2MLP(nn.Module):
    """Feed-forward block: `fc1` -> activation -> `fc2`.

    `fc1` maps `hidden_size` to `intermediate_size` and `fc2` maps it back;
    the activation is looked up from `config.hidden_act`.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # TODO(Isotr0py): Enable data parallel after we support
        # disabling TP on parallel linear layer
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply both projections; each layer returns a tuple of which only
        the first element is used."""
        intermediate = self.activation_fn(self.fc1(hidden_states)[0])
        return self.fc2(intermediate)[0]


class Siglip2EncoderLayer(nn.Module):
    """Pre-norm transformer block.

    Computes `x + self_attn(layer_norm1(x))` followed by
    `x + mlp(layer_norm2(x))`.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(config,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.self_attn",
                                          use_data_parallel=use_data_parallel)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.mlp",
                              use_data_parallel=use_data_parallel)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer; passed to the attention module, which
                unpacks it as `(seq_len, embed_dim)`.
            cu_seqlens (`torch.Tensor`):
                Cumulative sequence boundaries, forwarded unchanged to the
                attention module.
            position_embeddings:
                `(cos, sin)` pair, forwarded unchanged to the attention
                module.

        Returns:
            The transformed hidden states (a single tensor).
        """
        # Attention sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       cu_seqlens=cu_seqlens,
                                       position_embeddings=position_embeddings)
        hidden_states = residual + hidden_states

        # Feed-forward sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class Siglip2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers`
    self attention layers. Each layer is a [`Siglip2EncoderLayer`].

    Layers whose index appears in `config.fullatt_block_indexes` (a
    '|'-separated string of integers) attend over each whole h x w frame;
    all other layers attend within local windows. When that setting is
    `None`, every layer uses full attention.

    Args:
        config: PretrainedConfig
        quant_config: Forwarded to every encoder layer.
        prefix: Name prefix; layer `i` receives `{prefix}.layers.{i}`.
        use_data_parallel: Forwarded to every encoder layer.
    """

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([
            Siglip2EncoderLayer(config,
                                quant_config=quant_config,
                                prefix=f"{prefix}.layers.{idx}",
                                use_data_parallel=use_data_parallel)
            for idx in range(config.num_hidden_layers)
        ])

        # Half of the head dimension: row and column angles are produced
        # separately and concatenated in rot_pos_emb().
        self.rotary_pos_emb = VisionRotaryEmbedding(
            config.hidden_size // config.num_attention_heads // 2)
        self.patch_size = config.patch_size
        self.hidden_stride = config.hidden_stride
        self.window_size = config.window_size
        self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
        if config.fullatt_block_indexes is None:
            self.fullatt_block_indexes = None
        else:
            self.fullatt_block_indexes = [
                int(i) for i in config.fullatt_block_indexes.split('|')
            ]

    # copied from qwen2.5_vl
    def rot_pos_emb(self, grid_thw):
        """Return per-token rotary angles for every (t, h, w) grid.

        For each token the (row, column) position inside its frame is
        computed, with tokens ordered so that each
        `hidden_stride x hidden_stride` tile is contiguous. The row and
        column angle vectors are then looked up and concatenated.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            # Row index of every grid cell, regrouped into stride tiles.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Column index of every grid cell, regrouped the same way.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.hidden_stride,
                self.hidden_stride,
                w // self.hidden_stride,
                self.hidden_stride,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            # The same spatial positions are reused for all t frames.
            pos_ids.append(
                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # One angle table large enough for the largest h or w in the batch.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        """Compute the token permutation and boundaries for window attention.

        Returns:
            A tuple `(window_index, cu_window_seqlens)`. `window_index` is a
            1-D tensor that permutes merged-patch indices so that members of
            the same window are adjacent. `cu_window_seqlens` is a Python
            list of cumulative window lengths counted in un-merged patches,
            starting at 0; it may contain repeated values for windows that
            consist only of padding.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # patch (after merge) number in each window
        vit_merger_window_size = (self.window_size // self.hidden_stride //
                                  self.patch_size)

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.hidden_stride,  # number of patch after merge
                grid_w // self.hidden_stride,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w)
            # Pad both axes up to a multiple of the window size; -100 marks
            # padding cells and is filtered out again below.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Number of real (non-padding) merged patches in each window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            # Offset by the number of merged patches of previous grids.
            window_index.append(index_new + window_index_id)
            # Convert merged-patch counts to un-merged patch counts and
            # continue the running total from the previous grid.
            cu_seqlens_tmp = seqlens.cumsum(
                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Args:
            inputs_embeds (`torch.Tensor` of shape
                `(sequence_length, hidden_size)`):
                Packed patch embeddings of all inputs.
            grid_thws (`torch.LongTensor`):
                One (t, h, w) row per input item.

        Returns:
            Hidden states of the last layer, in the same token order as
            `inputs_embeds`.
        """
        rotary_pos_emb = self.rot_pos_emb(grid_thws)
        window_index, cu_window_seqlens = self.get_window_index(grid_thws)

        # Select dtype based on the following factors:
        #  - FA2 requires that cu_seqlens_q must have dtype int32
        #  - torch.onnx.export requires that cu_seqlens_q must have
        #    same dtype as grid_thw
        # See https://github.com/huggingface/transformers/pull/34852
        # for more information
        seqlens_dtype = (grid_thws.dtype
                         if torch.jit.is_tracing() else torch.int32)

        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=inputs_embeds.device,
            dtype=seqlens_dtype,
        )
        # Drop repeated boundaries (zero-length windows).
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Permute tokens into window order. The permutation is defined on
        # groups of spatial_merge_unit consecutive tokens.
        seq_len, _ = inputs_embeds.size()
        inputs_embeds = inputs_embeds.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        inputs_embeds = inputs_embeds[window_index, :, :]
        inputs_embeds = inputs_embeds.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Full-attention boundaries: one sequence of h * w tokens per frame.
        cu_seqlens = torch.repeat_interleave(
            grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(
                dim=0,
                dtype=seqlens_dtype,
            )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        # Inverse of the window permutation, used to restore token order.
        reverse_indices = torch.argsort(window_index)

        hidden_states = inputs_embeds
        for index, block in enumerate(self.layers):
            if (not self.fullatt_block_indexes
                    or index in self.fullatt_block_indexes):
                cu_seqlens_tmp = cu_seqlens
            else:
                cu_seqlens_tmp = cu_window_seqlens
            hidden_states = block(hidden_states, cu_seqlens_tmp,
                                  position_embeddings)

        hidden_states = hidden_states.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)

        return hidden_states
593
594
595
596


class Siglip2VisionTransformer(nn.Module):
    """Vision tower: embeddings -> encoder -> final layer norm."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.encoder",
                                      use_data_parallel=use_data_parallel)
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        r"""
        Args:
            pixel_values (`torch.FloatTensor`):
                Patchified pixel values, passed to the embedding module.
            grid_thws (`torch.LongTensor`):
                One (t, h, w) row per input item; passed to both the
                embedding module and the encoder.

        Returns:
            The encoder output after the final layer norm.
        """
        hidden_states = self.embeddings(pixel_values, grid_thws)

        last_hidden_state = self.encoder(hidden_states, grid_thws)
        last_hidden_state = self.post_layernorm(last_hidden_state)

        return last_hidden_state


class Siglip2NavitModel(torch.nn.Module):
    """Top-level wrapper holding the vision transformer under the
    `vision_model` attribute and loading checkpoint weights into it."""

    def __init__(
        self,
        config: Siglip2VisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()

        self.vision_model = Siglip2VisionTransformer(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.vision_model",
            use_data_parallel=use_data_parallel)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
    ) -> torch.Tensor:
        """Delegate to the wrapped vision transformer."""
        return self.vision_model(
            pixel_values=pixel_values,
            grid_thws=grid_thws,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load `(name, tensor)` pairs into this module's parameters.

        Weights whose name contains `q_proj`, `k_proj` or `v_proj` are
        loaded into the corresponding shard of the fused `qkv_proj`
        parameter. Every other weight is loaded into the parameter of the
        same name, using the parameter's own `weight_loader` if it has one.

        Returns:
            The set of parameter names that were loaded (after renaming).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # First mapping entry whose shard name occurs in the weight name.
            match = next((entry for entry in stacked_params_mapping
                          if entry[1] in name), None)
            if match is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = match
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params