siglip.py 40.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

import math
7
8
9
from collections.abc import Iterable, Mapping
from functools import cached_property
from typing import Annotated, Literal
10
11
12

import torch
from torch import nn
13
14
15
16
17
18
19
from transformers import (
    BatchFeature,
    SiglipConfig,
    SiglipProcessor,
    SiglipTextConfig,
    SiglipVisionConfig,
)
20

21
from vllm.attention.layer import MultiHeadAttention
22
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
23
24
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
25
from vllm.distributed import divide, get_tensor_model_parallel_world_size
26
from vllm.model_executor.layers.activation import get_act_fn
27
from vllm.model_executor.layers.conv import Conv2dLayer
28
29
30
31
32
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
33
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
36
from vllm.model_executor.model_loader.weight_utils import (
37
38
39
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptReplacement,
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
59

60
61
62
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, maybe_prefix
63
64
65
from .vision import (
    VisionEncoderInfo,
    VisionFeatureSelectStrategy,
66
67
    VisionFeatureSelectStrategyStr,
    get_num_selected_vision_tokens,
68
69
    resolve_visual_encoder_outputs,
)
70

71

72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
class SiglipImagePixelInputs(TensorSchema):
    """
    Raw pixel input accepted by the SigLIP vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Literal tag identifying this input variant.
    type: Literal["pixel_values"]
    # Three-channel images; the spatial size stays symbolic ("h", "w").
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


# Pooler ``pooling_type`` -> vision feature selection strategy used when
# counting/selecting image tokens. MEAN and ALL keep every token; CLS keeps
# only the "class" selection.
_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = dict(
    MEAN="full",
    ALL="full",
    CLS="class",
)


def _get_vision_feature_select_strategy(
    pooling_type: str,
) -> VisionFeatureSelectStrategyStr:
    """Translate a pooler ``pooling_type`` into a vision feature strategy.

    Raises:
        ValueError: if ``pooling_type`` has no entry in
            ``_POOLING_TYPE_TO_STRATEGY``.
    """
    strategy = _POOLING_TYPE_TO_STRATEGY.get(pooling_type)
    if strategy is None:
        raise ValueError(
            f"No feature selection strategy is defined for "
            f"pooling_type: {pooling_type!r}"
        )
    return strategy


class SiglipProcessingInfo(BaseProcessingInfo):
    """Processing metadata for SigLIP: config/processor access and token counts."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(SiglipConfig)

    def get_vision_encoder_info(self):
        return SiglipEncoderInfo(self.get_hf_config())

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(SiglipProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # At most one image per request.
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of placeholder tokens an image expands to.

        The raw patch count from the encoder is reduced according to the
        feature selection strategy derived from the pooler's pooling type.
        """
        encoder_info = self.get_vision_encoder_info()

        pooler_config = self.ctx.model_config.pooler_config
        assert pooler_config is not None

        patch_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = _get_vision_feature_select_strategy(pooler_config.pooling_type)
        return get_num_selected_vision_tokens(patch_tokens, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        # Images are square at the encoder's configured resolution.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        widest, tallest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=widest, image_height=tallest)


class SiglipDummyInputsBuilder(BaseDummyInputsBuilder[SiglipProcessingInfo]):
    """Builds profiling inputs: an empty prompt plus max-size dummy images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # SigLIP image inputs are paired with an empty text prompt.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()
        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
    """Multimodal processor for SigLIP.

    A request is either text-only or image-only. An image-only prompt is
    replaced by a run of placeholder tokens, one per image feature.
    """

    @cached_property
    def image_token_id(self) -> int:
        """Token ID used as the image placeholder.

        This is the first ID in ``range(tokenizer.vocab_size)`` that is not a
        special token.

        Raises:
            ValueError: if every ID in the vocabulary is a special token.
        """
        tokenizer = self.info.get_tokenizer()

        # Materialize the special IDs once. The attribute may be recomputed on
        # every access, and set membership is O(1) instead of a list scan.
        special_ids = set(tokenizer.all_special_ids)
        dummy_token_id = next(
            (
                token_id
                for token_id in range(tokenizer.vocab_size)
                if token_id not in special_ids
            ),
            None,
        )
        if dummy_token_id is None:
            raise ValueError(
                "Could not find a non-special token ID to use as the "
                "image placeholder token"
            )

        return dummy_token_id

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Process a text-only or image-only request.

        Raises:
            ValueError: if both a non-empty prompt and multimodal data are given.
        """
        if prompt and mm_data:
            raise ValueError(
                "Siglip accepts text-only or image-only inputs, not both! "
                "Image-only inputs means passing an image with an empty text "
                "prompt."
            )

        if mm_data:
            # For multi-modal data, the prompt after processing should
            # only contain the image token
            tokenization_kwargs = {
                **(tokenization_kwargs or {}),
                "add_special_tokens": False,
            }

        return super().apply(
            prompt=prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # Placeholder insertion is always done by _get_prompt_updates below.
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        image_token_id = self.image_token_id

        def get_replacement(item_idx: int):
            # Looked up lazily so text-only requests never touch image items.
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            num_image_tokens = self.info.get_num_image_tokens(
                image_width=image_size.width, image_height=image_size.height
            )
            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=PromptIndexTargets.start(),
                replacement=get_replacement,
            ),
        ]


261
262
263
264
265
266
267
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
    """Geometry of the SigLIP vision encoder: one token per patch."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        # Depends only on the configured image/patch size; the width and
        # height arguments are not used.
        side = self.get_patch_grid_length()
        return side * side

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Number of patches along one side of the (square) image."""
        return self.get_image_size() // self.get_patch_size()
279
280


281
282
283
284
285
286
287
288
289
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
    """Patchify images with a strided conv and add learned position embeddings."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # kernel_size == stride == patch_size: non-overlapping patches.
        self.patch_embedding = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # One learned position per patch; there is no class-token slot.
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim
        )
        position_ids = torch.arange(self.num_positions, dtype=torch.int64)
        self.register_buffer(
            "position_ids",
            position_ids.expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        Resample the learned position table for a different patch grid.

        SigLIP has no class embedding, unlike many other ViTs, so the whole
        table is treated as a square grid and resized bicubically. This lets
        the pre-trained encodings be used on higher-resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings: patch embeddings; only ``shape[1]`` (patch count) and
                ``shape[-1]`` (hidden size) are read.
            height: input image height in pixels.
            width: input image width in pixels.

        Returns:
            Position embeddings shaped ``(1, grid_h * grid_w, dim)``.

        Raises:
            ValueError: if the interpolated grid does not match the target.
        """
        table = self.position_embedding.weight.unsqueeze(0)
        num_positions = table.shape[1]
        if embeddings.shape[1] == num_positions and height == width:
            # Already the trained grid: no resampling needed.
            return table

        dim = embeddings.shape[-1]
        src_side = math.sqrt(num_positions)
        src_grid = int(src_side)
        # A small offset avoids floating point error in the interpolation,
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        dst_h = height // self.patch_size + 0.1
        dst_w = width // self.patch_size + 0.1

        grid = table.reshape(1, src_grid, src_grid, dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(
            grid,
            scale_factor=(dst_h / src_side, dst_w / src_side),
            mode="bicubic",
            align_corners=False,
        )
        if grid.shape[-2] != int(dst_h) or grid.shape[-1] != int(dst_w):
            raise ValueError(
                "Width or height does not match with the interpolated "
                "position embeddings"
            )

        return grid.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(
        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
    ) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        conv_dtype = self.patch_embedding.weight.dtype

        # [*, embed_dim, grid, grid] -> [*, grid * grid, embed_dim]
        patches = self.patch_embedding(pixel_values.to(dtype=conv_dtype))
        embeddings = patches.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            positions = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            positions = self.position_embedding(self.position_ids)
        embeddings += positions

        return embeddings


377
class SiglipAttention(nn.Module):
    """Multi-head self-attention used by SigLIP encoder layers.

    Q, K and V come from one fused tensor-parallel projection. The attention
    kernel is supplied by the caller through ``attn_cls``.
    """

    def __init__(
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
    ) -> None:
        """
        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            # Every segment needs the f prefix; previously the middle segment
            # lacked it and printed "{self.embed_dim}" literally.
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got "
                f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5
        # Stored from the config; forward() does not apply dropout.
        self.dropout = config.attention_dropout
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Heads are split evenly across tensor-parallel ranks.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = attn_cls(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel

        Returns ``(attn_output, None)``; the second slot is always ``None``.
        """
        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
437
438
439
440
441


class SiglipMLP(nn.Module):
    """Feed-forward block: fc1 -> activation -> fc2, both tensor-parallel."""

    def __init__(
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # bitsandbytes / torchao quantization is always applied. Any other
        # method is applied only when both the hidden and intermediate sizes
        # are multiples of 64; otherwise both linears run unquantized.
        always_quantized = bool(quant_config) and quant_config.get_name() in (
            "bitsandbytes",
            "torchao",
        )
        use_quant = always_quantized or (
            config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
        )
        linear_quant_config = quant_config if use_quant else None

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=linear_quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=linear_quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded, _ = self.fc1(hidden_states)
        projected, _ = self.fc2(self.activation_fn(expanded))
        return projected


class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer layer: LN -> attention -> +res, LN -> MLP -> +res."""

    def __init__(
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size
        eps = config.layer_norm_eps

        self.self_attn = SiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_cls=attn_cls,
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Returns ``(output, None)``; the second slot is always ``None``."""
        attn_out, _ = self.self_attn(hidden_states=self.layer_norm1(hidden_states))
        attn_out += hidden_states

        mlp_out = self.mlp(self.layer_norm2(attn_out))
        mlp_out += attn_out

        return mlp_out, None


class SiglipEncoder(nn.Module):
    """Stack of ``SiglipEncoderLayer``s, optionally truncated."""

    def __init__(
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        num_hidden_layers_override: int | None = None,
        *,
        prefix: str = "",
        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
    ) -> None:
        super().__init__()

        self.config = config

        depth = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )

        self.layers = nn.ModuleList(
            SiglipEncoderLayer(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{idx}",
                attn_cls=attn_cls,
            )
            for idx in range(depth)
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Run every layer.

        Returns the final hidden states. When ``return_all_hidden_states`` is
        set, it instead returns ``[inputs_embeds, layer_1_out, ..., layer_N_out]``
        so callers can pick feature layers by index.
        """
        collected = [inputs_embeds]
        current = inputs_embeds

        for layer in self.layers:
            current, _ = layer(current)
            if return_all_hidden_states:
                collected.append(current)

        return collected if return_all_hidden_states else current


574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
class SiglipTextTransformer(nn.Module):
    """SigLIP text tower: token/position embeddings, encoder, final LayerNorm."""

    def __init__(
        self,
        config: SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        width = config.hidden_size

        self.embeddings = SiglipTextEmbeddings(config)

        self.encoder = SiglipEncoder(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_cls=EncoderOnlyAttention,
        )

        self.final_layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        # Created here; forward() does not apply it.
        self.head = nn.Linear(width, config.projection_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        token_embedding = self.embeddings.token_embedding
        return token_embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the final-layer-normed last hidden state."""
        embedded = self.embeddings(input_ids, position_ids, inputs_embeds)
        encoded = self.encoder(inputs_embeds=embedded, return_all_hidden_states=False)
        return self.final_layer_norm(encoded)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v projections into ``qkv_proj``.

        Returns the set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            stacked = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if stacked is None:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        return loaded_params


645
646
647
648
649
650
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A single learned query (``probe``) attends over all input tokens. The
    attended vector is then refined by a residual LayerNorm + MLP block.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        hidden = config.hidden_size
        self.probe = nn.Parameter(torch.randn(1, 1, hidden))
        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(
            hidden, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_state`` (batch-first) into one token per sample.

        Returns a tensor of shape ``(batch, 1, hidden)``.
        """
        queries = self.probe.expand(hidden_state.size(0), -1, -1)
        attended, _ = self.attention(queries, hidden_state, hidden_state)

        refined = self.mlp(self.layernorm(attended))
        refined += attended

        return refined[:, 0].unsqueeze(1)
681
682
683
684
685
686


class SiglipVisionTransformer(nn.Module):
    """SigLIP vision tower.

    Patch embeddings feed an encoder stack, followed by an optional
    post-LayerNorm and an optional attention-pooling head.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            num_hidden_layers_override: load only this many encoder layers.
            require_post_norm: force the post-LayerNorm on or off. By default
                it is kept only when all encoder layers are loaded.

        Raises:
            ValueError: if more layers are requested than the config has.
        """
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)

        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MultiHeadAttention,
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers

        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

        self.use_head = (
            True if not hasattr(config, "vision_use_head") else config.vision_use_head
        )
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.head",
            )

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        interpolate_pos_encoding: bool = False,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Encode images.

        Without ``select_layers``, the final hidden state is post-normed (if
        enabled) and pooled (if ``use_head``). With ``select_layers``, the
        encoder yields one tensor per layer. The post-norm is then applied
        only to the final entry, and the pooling head, which collapses the
        sequence to a single token, is not applied.
        """
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        # Produces either the last layer output or all of the hidden states,
        # depending on if we have select_layers or not
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=select_layers is not None,
        )

        if isinstance(encoder_outputs, torch.Tensor):
            if self.post_layernorm is not None:
                encoder_outputs = self.post_layernorm(encoder_outputs)

            if self.use_head:
                encoder_outputs = self.head(encoder_outputs)
        elif self.post_layernorm is not None:
            # LayerNorm cannot take a list; normalize only the final layer's
            # output, matching the single-tensor path above.
            encoder_outputs = [
                *encoder_outputs[:-1],
                self.post_layernorm(encoder_outputs[-1]),
            ]

        # stacks feature layers if needed
        encoder_outputs = resolve_visual_encoder_outputs(
            encoder_outputs,
            None,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

        return encoder_outputs

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v into ``qkv_proj``.

        Weights for modules that were not built are skipped: post_layernorm,
        the pooling head, and encoder layers beyond the loaded depth.
        Returns the set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in SiglipVisionTransformer
            if name.startswith("post_layernorm") and self.post_layernorm is None:
                continue

            # The pooling head is only built when use_head is set; otherwise
            # its checkpoint weights have no matching parameter.
            if name.startswith("head.") and not self.use_head:
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("encoder.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

815
816
817
818
819
820
821
822

class SiglipVisionModel(nn.Module):
    """Wrapper that exposes a ``SiglipVisionTransformer`` as ``vision_model``.

    Every parameter name therefore starts with ``vision_model.``, which is the
    layout ``load_weights`` below expects from the checkpoint.
    """

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Kept for load_weights, which forwards it to maybe_swap_ffn_param.
        self.quant_config = quant_config

        # maybe_prefix avoids producing ".vision_model" (leading dot) when
        # prefix is empty; for non-empty prefixes it yields
        # f"{prefix}.vision_model" exactly as before.
        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the wrapped transformer."""
        return self.vision_model.embeddings.patch_embedding

    @property
    def dtype(self):
        """Dtype reported by the wrapped transformer."""
        return self.vision_model.dtype

    @property
    def device(self):
        """Device reported by the wrapped transformer."""
        return self.vision_model.device

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Run the wrapped transformer; all arguments are forwarded unchanged."""
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors whose names are rooted at ``vision_model.``.

        - ``post_layernorm`` weights are dropped when that layer was not built.
        - Encoder layers beyond those instantiated are dropped.
        - Scale tensors are remapped via ``maybe_remap_kv_scale_name``.
        - ``q/k/v_proj`` tensors are loaded as shards of ``qkv_proj``.
        - Remaining tensors may be redirected by ``maybe_swap_ffn_param``
          (GGUF fc1/fc2 swap).

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is optional in SiglipVisionModel
            if (
                name.startswith("vision_model.post_layernorm")
                and self.vision_model.post_layernorm is None
            ):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                # "vision_model.encoder.layers.<idx>...." -> field 3 is the index
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            # Check if this is a scale parameter that needs remapping first
            if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                if remapped_name is not None and remapped_name in params_dict:
                    param = params_dict[remapped_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(remapped_name)
                    continue
                # If remapping failed, continue with normal processing

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                param = maybe_swap_ffn_param(
                    name, param, loaded_weight, params_dict, self.quant_config
                )
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
923
924


925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
def maybe_swap_ffn_param(
    name: str,
    param: torch.Tensor,
    loaded_weight: torch.Tensor,
    params_dict: dict[str, torch.Tensor],
    quant_config: QuantizationConfig | None,
) -> torch.Tensor:
    """Return the parameter an MLP weight should be loaded into (GGUF only).

    Some GGUF checkpoints store the ``fc1`` and ``fc2`` weights swapped. When
    the model is GGUF-quantized and ``name`` refers to an ``.fc`` weight whose
    output size disagrees with ``loaded_weight``, the counterpart parameter
    (``.fc1.`` <-> ``.fc2.``) is returned instead. In every other case
    ``param`` is returned unchanged.

    Args:
        name: Checkpoint name of the weight being loaded.
        param: Parameter currently selected for ``name``.
        loaded_weight: Tensor read from the checkpoint.
        params_dict: Mapping from parameter names to parameters.
        quant_config: Active quantization config; ``None`` disables the swap.

    Returns:
        ``param`` itself, or the swapped ``fc1``/``fc2`` parameter.
    """
    if not quant_config or quant_config.get_name() != "gguf" or ".fc" not in name:
        return param

    # The local parameter size is scaled by the TP world size, i.e. it is
    # treated as one rank's shard along output_dim of the full tensor.
    tp_size = get_tensor_model_parallel_world_size()
    output_dim = getattr(param, "output_dim", 0)
    output_size = param.size(output_dim) * tp_size
    if output_size == loaded_weight.size(output_dim):
        return param

    # Output sizes disagree: the checkpoint has fc1/fc2 swapped.
    if ".fc1." in name:
        return params_dict[name.replace(".fc1.", ".fc2.")]
    if ".fc2." in name:
        return params_dict[name.replace(".fc2.", ".fc1.")]
    return param


948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
# Adapted from: https://github.com/huggingface/transformers/blob/v4.54.1/src/transformers/models/siglip/modeling_siglip.py#L200
class SiglipTextEmbeddings(nn.Module):
    """Sum of token embeddings and absolute position embeddings.

    Both tables are ``VocabParallelEmbedding`` lookups sized from the text
    config (``vocab_size`` and ``max_position_embeddings`` rows respectively).
    """

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        max_positions = config.max_position_embeddings

        self.token_embedding = VocabParallelEmbedding(config.vocab_size, hidden_size)
        self.position_embedding = VocabParallelEmbedding(max_positions, hidden_size)

        # Non-persistent buffer holding [[0, 1, ..., max_positions - 1]];
        # forward() itself always receives position_ids explicitly.
        default_positions = torch.arange(max_positions).expand((1, -1))
        self.register_buffer("position_ids", default_positions, persistent=False)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return token embeddings plus position embeddings.

        When ``inputs_embeds`` is provided it is used directly and
        ``input_ids`` is not looked up.
        """
        token_embeds = (
            self.token_embedding(input_ids) if inputs_embeds is None else inputs_embeds
        )
        return token_embeds + self.position_embedding(position_ids)


# Assume EOS token corresponds to CLS token in text model
@default_pooling_type("CLS")
@MULTIMODAL_REGISTRY.register_processor(
    SiglipMultiModalProcessor,
    info=SiglipProcessingInfo,
    dummy_inputs=SiglipDummyInputsBuilder,
)
class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
    is_pooling_model = True

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    merge_by_field_config = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the text and vision towers plus the embedding poolers.

        Args:
            vllm_config: Global vLLM config; its ``model_config.hf_config`` is
                expected to be a ``SiglipConfig``.
            prefix: Module-name prefix used for the two towers.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config: SiglipConfig = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # Force num_labels to 0 whenever the config defines the attribute.
        if hasattr(config, "num_labels"):
            config.num_labels = 0

        text_config, vision_config = config.text_config, config.vision_config
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        # Submodules are created text-first, then vision.
        self.text_model = SiglipTextTransformer(
            text_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "text_model"),
        )
        self.vision_model = SiglipVisionTransformer(
            vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.text_projection_size = text_config.projection_size

        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        self.pooler_config = pooler_config

        self.pooler = DispatchPooler(
            {
                "token_embed": Pooler.for_token_embed(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            }
        )

        # Updated by embed_input_ids; forward() reads it to choose between
        # the text path and returning inputs_embeds unchanged.
        self._is_text_input = True

    def get_text_features(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode text and apply the text head, reordered for CLS pooling.

        The head outputs are passed through ``_flip_sequences_by_position_ids``,
        which reverses every packed sequence so that its last (EOS) token ends
        up in the first slot, where CLS pooling reads it.
        """
        hidden_states = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        head_output = self.text_model.head(hidden_states)
        return self._flip_sequences_by_position_ids(head_output, position_ids)

1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
    def _flip_sequences_by_position_ids(
        self,
        features: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Flip sequences so EOS token moves to first position for CLS pooling.

        SigLIP position_ids are reversed within each sequence. This method detects
        sequence boundaries and flips each sequence individually.
        """
        if len(features) == 1:
            return features

        # Detect sequence boundaries where position_ids decrease
        position_diffs = position_ids[1:] - position_ids[:-1]
        boundary_mask = position_diffs <= 0

        boundary_indices = torch.cat(
            [
                torch.tensor([0], device=features.device),
                torch.where(boundary_mask)[0] + 1,
                torch.tensor([len(features)], device=features.device),
            ]
        )

        # For each sequence [start, end), position i flips to: start + end - 1 - i
        lengths = boundary_indices[1:] - boundary_indices[:-1]
        starts = boundary_indices[:-1]
        ends = boundary_indices[1:]

        # Assign sequence ID to each element
        sequence_ids = torch.arange(
            len(lengths), device=features.device
        ).repeat_interleave(lengths)

        # Calculate flipped indices for all positions at once
        current_positions = torch.arange(len(features), device=features.device)
        flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions

        return features[flip_indices]

1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Run the vision tower over ``pixel_values``.

        If ``feature_select_strategy`` is omitted it is derived from the
        configured pooling type via ``_get_vision_feature_select_strategy``.
        """
        strategy = (
            feature_select_strategy
            if feature_select_strategy is not None
            else _get_vision_feature_select_strategy(self.pooler_config.pooling_type)
        )
        return self.vision_model(
            pixel_values=pixel_values,
            select_layers=None,
            feature_select_strategy=strategy,
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> SiglipImagePixelInputs | None:
        """Wrap ``pixel_values`` from ``kwargs`` in ``SiglipImagePixelInputs``.

        Returns ``None`` when ``pixel_values`` is absent or ``None``. Both the
        ``h`` and ``w`` bindings are set to ``vision_config.image_size``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is not None:
            image_size = self.config.vision_config.image_size
            return SiglipImagePixelInputs(
                type="pixel_values",
                data=pixel_values,
                resolve_bindings={"h": image_size, "w": image_size},
            )
        return None

    def _process_image_inputs(self, inputs: SiglipImagePixelInputs) -> torch.Tensor:
        """Compute vision features for already-validated pixel inputs."""
        return self.get_image_features(inputs["data"])

    def get_language_model(self) -> torch.nn.Module:
        return self.text_model

1148
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging multimodal embeddings when provided.

        Also records whether the current input is text-only (no or empty
        ``multimodal_embeddings``) in ``self._is_text_input``, which
        ``forward`` reads.
        """
        has_mm = multimodal_embeddings is not None and len(multimodal_embeddings) > 0
        self._is_text_input = not has_mm

        if multimodal_embeddings is not None and is_multimodal is not None:
            return super().embed_input_ids(
                input_ids,
                multimodal_embeddings=multimodal_embeddings,
                is_multimodal=is_multimodal,
                handle_oov_mm_token=handle_oov_mm_token,
            )
        return super().embed_input_ids(input_ids)

1170
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_inputs(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        if intermediate_tensors is not None:
            raise RuntimeError("PP is not supported for this model")

        # Multimodal inputs (image embeddings)
        if not self._is_text_input:
            return inputs_embeds

        return self.get_text_features(input_ids, positions, inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights through ``AutoWeightsLoader``.

        ``.position_ids`` entries are passed as ``skip_substrs`` and the
        ``logit_scale.``/``logit_bias.`` prefixes as
        ``ignore_unexpected_prefixes``.
        """
        return AutoWeightsLoader(
            self,
            skip_substrs=[".position_ids"],
            ignore_unexpected_prefixes=["logit_scale.", "logit_bias."],
        ).load_weights(weights)