siglip.py 43 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable, Iterable, Mapping
5
from functools import cached_property, partial
6
from typing import Annotated, Literal
7
8
9

import torch
from torch import nn
10
11
12
13
14
15
16
from transformers import (
    BatchFeature,
    SiglipConfig,
    SiglipProcessor,
    SiglipTextConfig,
    SiglipVisionConfig,
)
17

18
from vllm.config import VllmConfig
19
from vllm.config.multimodal import BaseDummyOptions
20
from vllm.distributed import divide, get_tensor_model_parallel_world_size
21
from vllm.model_executor.layers.activation import get_act_fn
22
from vllm.model_executor.layers.attention import (
23
    EncoderOnlyAttention,
24
    MMEncoderAttention,
25
)
26
from vllm.model_executor.layers.conv import Conv2dLayer
27
28
29
30
31
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
32
from vllm.model_executor.layers.pooler import DispatchPooler
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
35
from vllm.model_executor.model_loader.weight_utils import (
36
37
38
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
39
40
41
42
43
44
45
46
47
48
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
49
    BaseDummyInputsBuilder,
50
51
52
53
54
55
56
57
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptReplacement,
    PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
58

59
60
61
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, maybe_prefix
62
63
64
from .vision import (
    VisionEncoderInfo,
    VisionFeatureSelectStrategy,
65
66
    VisionFeatureSelectStrategyStr,
    get_num_selected_vision_tokens,
67
    is_vit_use_data_parallel,
68
69
    resolve_visual_encoder_outputs,
)
70

71

72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class SiglipImagePixelInputs(TensorSchema):
    """
    Schema for raw pixel-value image inputs to the SigLIP vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Tag identifying this input variant.
    type: Literal["pixel_values"]
    # Image batch; the channel dim is pinned to 3 in the TensorShape spec
    # (presumably checked by TensorSchema at construction - confirm).
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = {
    "MEAN": "full",
    "ALL": "full",
    "CLS": "class",
}


def _get_vision_feature_select_strategy(
    pooling_type: str,
) -> VisionFeatureSelectStrategyStr:
    """Map a pooler ``seq_pooling_type`` to a vision feature-select strategy.

    MEAN and ALL pooling need every token ("full"); CLS pooling needs only
    the class token ("class").

    Raises:
        ValueError: if ``pooling_type`` has no registered strategy.
    """
    strategy = _POOLING_TYPE_TO_STRATEGY.get(pooling_type)
    if strategy is None:
        raise ValueError(
            f"No feature selection strategy is defined for "
            f"pooling_type: {pooling_type!r}"
        )
    return strategy


class SiglipProcessingInfo(BaseProcessingInfo):
    """Processing metadata for SigLIP: HF config/processor access and
    image-token accounting."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(SiglipConfig)

    def get_vision_encoder_info(self):
        return SiglipEncoderInfo(self.get_hf_config())

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(SiglipProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # At most one image per request.
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of placeholder tokens an image expands to.

        The raw patch count from the encoder is reduced according to the
        feature-select strategy implied by the pooler's ``seq_pooling_type``.
        """
        encoder_info = self.get_vision_encoder_info()

        pooler_config = self.ctx.model_config.pooler_config
        assert pooler_config is not None

        patch_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = _get_vision_feature_select_strategy(
            pooler_config.seq_pooling_type
        )
        return get_num_selected_vision_tokens(patch_tokens, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        # The encoder uses a single square side length.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        widest, tallest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=widest, image_height=tallest)


class SiglipDummyInputsBuilder(BaseDummyInputsBuilder[SiglipProcessingInfo]):
    """Builds dummy inputs for SigLIP: an empty prompt plus max-size images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # SigLIP rejects text together with images, so the prompt is empty.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images at the largest size.

        Per-modality overrides are taken from ``mm_options["image"]`` when
        ``mm_options`` is non-empty.
        """
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()
        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
    """Multimodal processor for SigLIP.

    A request is either text-only or image-only (an image with an empty text
    prompt). For images, ``num_image_tokens`` copies of a placeholder token id
    are inserted at the start of the prompt.
    """

    @cached_property
    def image_token_id(self) -> int:
        """Dummy token id used as the image placeholder.

        This is the first id in ``range(tokenizer.vocab_size)`` that is not a
        special token.

        Raises:
            ValueError: if every id in the vocabulary range is special.
        """
        tokenizer = self.info.get_tokenizer()
        # `all_special_ids` may be recomputed on every access (it is a
        # property on HF tokenizers), so materialize it once as a set instead
        # of re-reading it and scanning a list for every candidate id.
        special_ids = set(tokenizer.all_special_ids)
        for token_id in range(tokenizer.vocab_size):
            if token_id not in special_ids:
                return token_id
        # Previously a bare StopIteration escaped from next(); report clearly.
        raise ValueError(
            "Could not find a non-special token id in the tokenizer "
            "vocabulary to use as the SigLIP image placeholder."
        )

    def apply(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Reject mixed text+image input, then defer to the base processor.

        Raises:
            ValueError: if both a non-empty prompt and multimodal items are
                passed.
        """
        if prompt and mm_items:
            raise ValueError(
                "Siglip accepts text-only or image-only inputs, not both! "
                "Image-only inputs means passing an image with an empty text "
                "prompt."
            )

        if mm_items:
            # For multi-modal data, the prompt after processing should
            # only contain the image token
            tokenization_kwargs = {
                **(tokenization_kwargs or {}),
                "add_special_tokens": False,
            }

        return super().apply(
            prompt=prompt,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # Placeholder insertion is always done by vLLM (_get_prompt_updates).
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        """Insert the image placeholder run at the start of the prompt."""
        image_token_id = self.image_token_id

        def get_replacement(item_idx: int):
            # Looked up lazily so text-only requests never query images.
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            num_image_tokens = self.info.get_num_image_tokens(
                image_width=image_size.width, image_height=image_size.height
            )
            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=PromptIndexTargets.start(),
                replacement=get_replacement,
            ),
        ]


261
262
263
264
265
266
267
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
    """Patch-grid geometry for a SigLIP vision encoder, read from its config."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of patch tokens per image.

        ``image_width`` and ``image_height`` are ignored: the count depends
        only on the configured ``image_size`` and ``patch_size``.
        """
        grid = self.get_patch_grid_length()
        return grid * grid

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        # Patches along one side of the square input.
        return self.get_image_size() // self.get_patch_size()
279
280


281
# Adapted from https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/siglip/modeling_siglip.py#L216
282
283
284
285
286
287
288
289
class SiglipVisionEmbeddings(nn.Module):
    """Patch embeddings plus learned position embeddings for SigLIP vision.

    A strided conv turns the image into a grid of patch vectors. A learned
    table of ``(image_size // patch_size) ** 2`` positions is added to them,
    optionally bicubic-resampled to a different grid size.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Non-persistent: rebuilt here rather than loaded from checkpoints.
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """Resample the position table to the patch grid of a (height, width) image.

        Args:
            embeddings: Patch embeddings, ``(batch, num_patches, dim)``.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            Position embeddings of shape ``(1, num_patches', dim)``.
        """
        num_patches = embeddings.shape[1]
        # The table is (num_positions, embed_dim): the position count is dim 0.
        # (Dim 1 is the hidden size and made the reshape below fail.)
        num_positions = self.position_embedding.weight.shape[0]
        # Fast path: same patch count on a square image needs no resampling.
        if num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # The stored table is assumed to be a square grid of positions.
        sqrt_num_positions = int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(
            1, sqrt_num_positions, sqrt_num_positions, dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(
        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
    ) -> torch.Tensor:
        """Embed ``(batch, channels, height, width)`` pixels into patch tokens."""
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        # [batch, embed_dim, grid_h, grid_w], assuming Conv2dLayer behaves
        # like nn.Conv2d - confirm.
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings += self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings += self.position_embedding(self.position_ids)

        return embeddings


355
class SiglipAttention(nn.Module):
    """Multi-head self-attention used by the SigLIP text and vision encoders.

    Q/K/V come from one fused ``QKVParallelLinear``. The attention kernel
    class is chosen by the caller through ``attn_cls``.
    """

    def __init__(
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
    ) -> None:
        """Build the projections and attention kernel.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5

        # With ViT data parallelism the projections are built with
        # disable_tp and the heads are not split (tp_size == 1 below).
        use_data_parallel = is_vit_use_data_parallel()
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        # Both supported attention classes are constructed identically here;
        # the former if/else on attn_cls had two byte-identical branches.
        self.attn = attn_cls(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel"""
        qkv_states, _ = self.qkv_proj(hidden_states)
        # Q, K and V have equal per-partition widths, so an even 3-way split.
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
428
429
430
431
432


class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        use_data_parallel = is_vit_use_data_parallel()
        self.activation_fn = get_act_fn(config.hidden_act)

        # bitsandbytes/torchao quantize any shape; every other method is only
        # used when BOTH hidden_size and intermediate_size are multiples of 64.
        always_quantizable = bool(quant_config) and quant_config.get_name() in (
            "bitsandbytes",
            "torchao",
        )
        quantizable = always_quantizable or (
            config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
        )
        layer_quant_config = quant_config if quantizable else None

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=layer_quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=layer_quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected, _ = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        output, _ = self.fc2(activated)
        return output


class SiglipEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block.

    Self-attention and MLP sub-blocks, each wrapped in a residual add.
    """

    def __init__(
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.self_attn = SiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_cls=attn_cls,
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Return ``(output, None)``; residual adds are done in place."""
        attn_out, _ = self.self_attn(hidden_states=self.layer_norm1(hidden_states))
        attn_out += hidden_states

        mlp_out = self.mlp(self.layer_norm2(attn_out))
        mlp_out += attn_out

        return mlp_out, None


class SiglipEncoder(nn.Module):
    """Stack of ``SiglipEncoderLayer`` blocks applied in sequence."""

    def __init__(
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        num_hidden_layers_override: int | None = None,
        *,
        prefix: str = "",
        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.config = config

        depth = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )
        self.layers = nn.ModuleList(
            SiglipEncoderLayer(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{idx}",
                attn_cls=attn_cls,
            )
            for idx in range(depth)
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Run every layer.

        Returns:
            The last hidden state, or, when ``return_all_hidden_states`` is
            set, the list ``[inputs_embeds, layer_1_out, ..., layer_n_out]``
            so callers can pick feature layers by index.
        """
        if not return_all_hidden_states:
            hidden_states = inputs_embeds
            for layer in self.layers:
                hidden_states, _ = layer(hidden_states)
            return hidden_states

        collected = [inputs_embeds]
        for layer in self.layers:
            next_states, _ = layer(collected[-1])
            collected.append(next_states)
        return collected


570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
class SiglipTextTransformer(nn.Module):
    """SigLIP text tower: embeddings -> encoder -> final LayerNorm.

    ``head`` (a projection to ``config.projection_size``) is built and loaded
    here but is not applied in ``forward``.
    """

    def __init__(
        self,
        config: SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        hidden_size = config.hidden_size

        self.embeddings = SiglipTextEmbeddings(config)
        self.encoder = SiglipEncoder(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_cls=EncoderOnlyAttention,
        )
        self.final_layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.head = nn.Linear(hidden_size, config.projection_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        token_embedding = self.embeddings.token_embedding
        return token_embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the final-LayerNorm'd last hidden state."""
        embedded = self.embeddings(input_ids, position_ids, inputs_embeds)
        encoded = self.encoder(inputs_embeds=embedded, return_all_hidden_states=False)
        return self.final_layer_norm(encoded)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, fusing q/k/v projections into ``qkv_proj``.

        Returns:
            Names of the parameters that received weights.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    param.weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


641
642
643
644
645
646
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A learned probe query attends over the encoder sequence. The pooled
    vector then passes through a residual LayerNorm + MLP block.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        hidden_size = config.hidden_size
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(
            hidden_size, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool ``(batch, seq, hidden)`` into ``(batch, 1, hidden)``.

        The ``[:, 0]`` squeeze is left to resolve_visual_encoder_outputs.
        """
        queries = self.probe.expand(hidden_state.size(0), -1, -1)
        pooled = self.attention(queries, hidden_state, hidden_state)[0]

        refined = self.mlp(self.layernorm(pooled))
        refined += pooled
        return refined
679
680
681
682
683
684


class SiglipVisionTransformer(nn.Module):
    """SigLIP vision tower.

    Pipeline: patch/position embeddings -> encoder layers -> optional
    post-LayerNorm -> optional attention-pooling head. The last two run via
    ``last_hs_proc`` inside ``resolve_visual_encoder_outputs``, before the
    feature-selection strategy is applied.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
        use_head: bool | None = False,
    ) -> None:
        """Build embeddings, encoder and the optional post-norm / head.

        Args:
            config: Vision config.
            quant_config: Optional quantization config for linear layers.
            num_hidden_layers_override: Build only this many encoder layers.
            require_post_norm: Force the post-LayerNorm on/off. ``None``
                keeps it only when the full encoder depth is built.
            prefix: Parameter-name prefix.
            use_head: Build the pooling head. A non-bool falls back to
                ``config.vision_use_head``, or ``True`` if that is absent.

        Raises:
            ValueError: if more layers are requested than the config has.
        """
        super().__init__()

        self.config = config
        hidden_size = config.hidden_size
        full_depth = config.num_hidden_layers

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MMEncoderAttention,
        )

        built_depth = len(self.encoder.layers)
        if built_depth > full_depth:
            raise ValueError(
                f"The original encoder only has {full_depth} "
                f"layers, but you requested {built_depth} layers."
            )

        # By default skip the post-LayerNorm for a truncated encoder, to
        # conserve memory.
        if require_post_norm is None:
            require_post_norm = built_depth == full_depth
        self.post_layernorm = (
            nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
            if require_post_norm
            else None
        )

        # Many config types (SiglipVisionConfig included) do not define
        # `vision_use_head`, hence the default of True.
        self.use_head = (
            use_head
            if isinstance(use_head, bool)
            else getattr(config, "vision_use_head", True)
        )

        # Only create (and later load) the head when it will be used.
        self.head = (
            SiglipMultiheadAttentionPoolingHead(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.head",
            )
            if self.use_head
            else None
        )
        self.last_hs_proc = partial(self.maybe_layer_norm_and_apply_head)

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        interpolate_pos_encoding: bool = False,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Encode images and resolve the requested feature layers.

        All hidden states are collected only when ``select_layers`` is given;
        otherwise just the last layer's output is kept.
        """
        embedded = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=select_layers is not None,
        )
        return resolve_visual_encoder_outputs(
            encoder_outputs,
            None,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            last_hs_proc=self.last_hs_proc,
            feature_select_strategy=feature_select_strategy,
        )

    def maybe_layer_norm_and_apply_head(
        self, encoder_outputs: torch.Tensor
    ) -> torch.Tensor:
        """Apply post_layernorm and then head to the last hidden states,
        skipping whichever of the two was not built.

        args:
            encoder_outputs: The last hidden states from the visual encoder.
        """
        for stage in (self.post_layernorm, self.head):
            if stage is not None:
                encoder_outputs = stage(encoder_outputs)
        return encoder_outputs

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, fusing q/k/v into ``qkv_proj``.

        Tensors for modules not built here are skipped: a disabled
        post-LayerNorm, a disabled head, and encoder layers beyond
        ``num_hidden_layers_override``.

        Returns:
            Names of the parameters that received weights.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.encoder.layers)

        for name, loaded_weight in weights:
            if self.post_layernorm is None and name.startswith("post_layernorm"):
                continue
            if self.head is None and name.startswith("head"):
                continue
            if name.startswith("encoder.layers"):
                # "encoder.layers.<idx>.<...>"
                if int(name.split(".")[2]) >= layer_count:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    param.weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

847
848
849
850
851

class SiglipVisionModel(nn.Module):
    """Wrapper that holds a ``SiglipVisionTransformer`` as ``vision_model``.

    Parameter names therefore carry a ``vision_model.`` prefix, which is the
    naming ``load_weights`` expects from checkpoints.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
        use_head: bool | None = False,
    ) -> None:
        super().__init__()

        # Kept for load_weights, which passes it to maybe_swap_ffn_param
        # to detect GGUF checkpoints.
        self.quant_config = quant_config

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            # maybe_prefix avoids a leading "." when prefix is empty.
            prefix=maybe_prefix(prefix, "vision_model"),
            use_head=use_head,
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the wrapped transformer."""
        return self.vision_model.embeddings.patch_embedding

    @property
    def dtype(self):
        """Dtype reported by the wrapped transformer."""
        return self.vision_model.dtype

    @property
    def device(self):
        """Device reported by the wrapped transformer."""
        return self.vision_model.device

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Run the wrapped transformer, forwarding all arguments unchanged."""
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights whose names start with ``vision_model.``.

        Weights for modules that were not built are skipped:
        - the post-layernorm and pooling head when those are ``None``;
        - encoder layers past ``len(self.vision_model.encoder.layers)``.

        Weights that are loaded are handled as follows:
        - Quantization scale names (``.k_scale`` etc.) are first remapped
          through ``maybe_remap_kv_scale_name``.
        - q/k/v projections are routed into the fused ``qkv_proj``.
        - ``maybe_swap_ffn_param`` may redirect fc1/fc2 weights for GGUF
          checkpoints.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is optional in SiglipVisionModel
            if (
                name.startswith("vision_model.post_layernorm")
                and self.vision_model.post_layernorm is None
            ):
                continue

            # if the model configuration is not going to use
            # the pooling head for inference, don't load its weights
            if self.vision_model.head is None and name.startswith("vision_model.head"):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            # Check if this is a scale parameter that needs remapping first
            if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
                # Try to remap the scale name first
                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                if remapped_name is not None and remapped_name in params_dict:
                    # Successfully remapped, use the remapped name
                    param = params_dict[remapped_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(remapped_name)
                    continue
                # If remapping failed, continue with normal processing

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Fused parameter: load this shard into its slot of qkv_proj.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                # GGUF checkpoints may store fc1/fc2 swapped; pick the right one.
                param = maybe_swap_ffn_param(
                    name, param, loaded_weight, params_dict, self.quant_config
                )
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


def maybe_swap_ffn_param(
    name: str,
    param: torch.Tensor,
    loaded_weight: torch.Tensor,
    params_dict: dict[str, torch.Tensor],
    quant_config: QuantizationConfig | None,
) -> torch.Tensor:
    """Return the parameter that ``loaded_weight`` should actually load into.

    Some GGUF checkpoints store the MLP ``fc1`` and ``fc2`` weights swapped.
    The check applies only when ``quant_config`` is GGUF and ``name``
    contains ``.fc``. It compares ``loaded_weight``'s size along the
    parameter's ``output_dim`` (default 0) with ``param``'s size there,
    multiplied by the tensor-parallel world size. On a mismatch, the sibling
    ``fc2``/``fc1`` parameter is taken from ``params_dict``.

    Args:
        name: Parameter name the checkpoint weight was mapped to.
        param: Parameter currently selected for ``name``.
        loaded_weight: Checkpoint tensor about to be loaded.
        params_dict: All model parameters keyed by name.
        quant_config: Active quantization config, or ``None``.

    Returns:
        ``param`` unchanged, or the swapped sibling parameter.
    """
    if not (quant_config and quant_config.get_name() == "gguf") or ".fc" not in name:
        return param
    # Some GGUF models have fc1 and fc2 weights swapped
    tp_size = get_tensor_model_parallel_world_size()
    output_dim = getattr(param, "output_dim", 0)
    # NOTE(review): scaling by tp_size assumes param is sharded along
    # output_dim on each rank; confirm this holds for row-parallel layers.
    output_size = param.size(output_dim) * tp_size
    weight_out_size = loaded_weight.size(output_dim)
    if output_size == weight_out_size:
        return param
    if ".fc1." in name:
        return params_dict[name.replace(".fc1.", ".fc2.")]
    if ".fc2." in name:
        return params_dict[name.replace(".fc2.", ".fc1.")]
    return param


# Adapted from: https://github.com/huggingface/transformers/blob/v4.54.1/src/transformers/models/siglip/modeling_siglip.py#L200
class SiglipTextEmbeddings(nn.Module):
    """Token embeddings plus learned position embeddings for the text tower."""

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        self.config = config

        self.token_embedding = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )

        self.position_embedding = VocabParallelEmbedding(
            config.max_position_embeddings, config.hidden_size
        )

        # Non-persistent (excluded from state_dict). forward() below uses the
        # caller-supplied position_ids rather than this buffer.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return token embeddings plus position embeddings.

        Args:
            input_ids: Token ids; only embedded when ``inputs_embeds`` is None.
            position_ids: Position index of each token.
            inputs_embeds: Precomputed token embeddings used in place of
                ``input_ids`` when given.
        """
        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings
        return embeddings


# Assume EOS token corresponds to CLS token in text model
@default_pooling_type(seq_pooling_type="CLS")
@MULTIMODAL_REGISTRY.register_processor(
    SiglipMultiModalProcessor,
    info=SiglipProcessingInfo,
    dummy_inputs=SiglipDummyInputsBuilder,
)
class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
    is_pooling_model = True

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the SigLIP text and vision towers plus the embedding pooler.

        Args:
            vllm_config: Engine config; ``model_config.hf_config`` must be a
                ``SiglipConfig`` and ``model_config.pooler_config`` must be set.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()

        config: SiglipConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        # Zero num_labels when the config defines it (mutates hf_config in place).
        if hasattr(config, "num_labels"):
            config.num_labels = 0

        text_config = config.text_config
        vision_config = config.vision_config

        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size
        # Width of the runner's inputs_embeds buffer, so that image embeddings
        # fit alongside text embeddings (see _embed_text_input_ids / forward).
        self.text_projection_size = text_config.projection_size

        with self._mark_language_model(vllm_config):
            self.text_model = SiglipTextTransformer(
                text_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "text_model"),
            )

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = SiglipVisionTransformer(
                vision_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_model"),
                use_head=None,  # Allows potential pooling head
            )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler_config = pooler_config

        self.pooler = DispatchPooler.for_embedding(pooler_config)

        # Set by embed_input_ids and read by forward to choose the text or
        # image path for the current batch.
        self._is_text_input = True

    def get_text_features(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the text tower and its head on a packed batch of sequences.

        Args:
            input_ids: Token ids (may be None when ``inputs_embeds`` is given).
            position_ids: Per-token positions; also used to find where each
                packed sequence starts.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            Head outputs with every sequence reversed, so that its EOS token
            comes first.
        """
        last_hidden_state = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        text_features = self.text_model.head(last_hidden_state)

        # The EOS token (last in each sequence) plays the CLS role for SigLIP
        # text; CLS pooling reads the first position, so reverse each sequence.
        text_features = self._flip_sequences_by_position_ids(
            text_features, position_ids
        )

        return text_features

    def _flip_sequences_by_position_ids(
        self,
        features: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Flip sequences so EOS token moves to first position for CLS pooling.

        SigLIP position_ids are reversed within each sequence. This method detects
        sequence boundaries and flips each sequence individually.
        """
        if len(features) == 1:
            return features

        # Detect sequence boundaries where position_ids decrease
        position_diffs = position_ids[1:] - position_ids[:-1]
        boundary_mask = position_diffs <= 0

        boundary_indices = torch.cat(
            [
                torch.tensor([0], device=features.device),
                torch.where(boundary_mask)[0] + 1,
                torch.tensor([len(features)], device=features.device),
            ]
        )

        # For each sequence [start, end), position i flips to: start + end - 1 - i
        lengths = boundary_indices[1:] - boundary_indices[:-1]
        starts = boundary_indices[:-1]
        ends = boundary_indices[1:]

        # Assign sequence ID to each element
        sequence_ids = torch.arange(
            len(lengths), device=features.device
        ).repeat_interleave(lengths)

        # Calculate flipped indices for all positions at once
        current_positions = torch.arange(len(features), device=features.device)
        flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions

        return features[flip_indices]

    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` with the vision tower.

        Args:
            pixel_values: Preprocessed image tensor.
            feature_select_strategy: How to select output features. Defaults
                to the strategy derived from the pooler config's
                ``seq_pooling_type``.

        Returns:
            The vision tower output for the chosen strategy.
        """
        if feature_select_strategy is None:
            feature_select_strategy = _get_vision_feature_select_strategy(
                self.pooler_config.seq_pooling_type
            )

        pooled_output = self.vision_model(
            pixel_values=pixel_values,
            select_layers=None,
            feature_select_strategy=feature_select_strategy,
        )

        return pooled_output

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> SiglipImagePixelInputs | None:
        """Wrap ``pixel_values`` from ``kwargs`` in a validated schema object.

        Returns ``None`` when no ``pixel_values`` were passed. Otherwise the
        schema's ``h`` and ``w`` dimensions are both bound to the vision
        config's ``image_size``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        side = self.config.vision_config.image_size
        bindings = {"h": side, "w": side}
        return SiglipImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings=bindings,
        )

    def _process_image_inputs(self, inputs: SiglipImagePixelInputs) -> torch.Tensor:
        """Encode the validated pixel data into image features."""
        return self.get_image_features(inputs["data"])

    def _embed_text_input_ids(
        self,
        input_ids: torch.Tensor,
        embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
        *,
        is_multimodal: torch.Tensor | None,
        handle_oov_mm_token: bool,
    ) -> torch.Tensor:
        """Embed text tokens and right-pad them to ``text_projection_size``.

        The model runner's ``inputs_embeds`` buffer is
        ``text_config.projection_size`` wide so that image embeddings fit.
        Narrower text embeddings are therefore zero-padded on the right;
        ``forward`` slices the padding off again for text-only batches.

        Raises:
            NotImplementedError: If the embeddings are wider than
                ``text_projection_size``.
        """
        inputs_embeds = super()._embed_text_input_ids(
            input_ids,
            embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        # NOTE: inputs_embeds in model runner has size text_config.projection_size
        # (instead of text_config.hidden_size) to accommodate image embeddings
        target_width = self.text_projection_size
        width = inputs_embeds.shape[1]
        if width < target_width:
            # Zero-fill instead of leaving uninitialized memory, so the padding
            # columns are deterministic for any consumer that sees them.
            inputs_embeds = nn.functional.pad(
                inputs_embeds, (0, target_width - width)
            )
        elif width > target_width:
            # No need to handle this case for now
            raise NotImplementedError(
                f"Text embedding width {width} exceeds "
                f"text_config.projection_size {target_width}"
            )

        return inputs_embeds

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging in multimodal embeddings when given.

        Also records in ``self._is_text_input`` whether this call carried no
        multimodal embeddings; ``forward`` reads that flag to choose between
        the text path and returning the image embeddings directly.
        """
        # NOTE(review): per-instance state shared with forward; this assumes
        # embed_input_ids and forward are called in lockstep for one batch.
        self._is_text_input = (
            multimodal_embeddings is None or len(multimodal_embeddings) == 0
        )

        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute image embeddings from ``kwargs``.

        Returns an empty list when no ``pixel_values`` are present.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_inputs(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Return image embeddings as-is, or run the text path.

        Which path runs is decided by ``self._is_text_input``, as set by
        ``embed_input_ids``.

        Raises:
            RuntimeError: If ``intermediate_tensors`` is given (no PP support).
            NotImplementedError: If ``inputs_embeds`` is narrower than the
                text hidden size.
        """
        if intermediate_tensors is not None:
            raise RuntimeError("PP is not supported for this model")

        # Multimodal inputs (image embeddings)
        if not self._is_text_input:
            return inputs_embeds

        # NOTE: inputs_embeds in model runner has size text_config.projection_size
        # (instead of text_config.hidden_size) to accommodate image embeddings.
        # Without inputs_embeds, the text model embeds input_ids itself.
        if inputs_embeds is not None:
            hidden_size = self.text_embed_dim
            width = inputs_embeds.shape[1]
            if width > hidden_size:
                inputs_embeds = inputs_embeds[:, :hidden_size]
            elif width < hidden_size:
                # No need to handle this case for now
                raise NotImplementedError(
                    f"inputs_embeds width {width} is smaller than "
                    f"text hidden size {hidden_size}"
                )

        return self.get_text_features(input_ids, positions, inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load text and vision tower weights through ``AutoWeightsLoader``.

        Entries containing ``.position_ids`` are skipped. Checkpoint entries
        under ``logit_scale.`` / ``logit_bias.`` are ignored if this module has
        no matching parameter.
        """
        unused_prefixes = ["logit_scale.", "logit_bias."]
        return AutoWeightsLoader(
            self,
            skip_substrs=[".position_ids"],
            ignore_unexpected_prefixes=unused_prefixes,
        ).load_weights(weights)