# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

import math
7
from collections.abc import Iterable
8
9
10

import torch
from torch import nn
11
from transformers import SiglipVisionConfig
12

13
from vllm.attention.layer import MultiHeadAttention
14
from vllm.distributed import divide, get_tensor_model_parallel_world_size
15
from vllm.model_executor.layers.activation import get_act_fn
16
17
18
19
20
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
21
from vllm.model_executor.layers.quantization import QuantizationConfig
22
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
23
from vllm.model_executor.model_loader.weight_utils import (
24
25
26
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
27

28
29
30
31
32
from .vision import (
    VisionEncoderInfo,
    VisionFeatureSelectStrategy,
    resolve_visual_encoder_outputs,
)
33

34

35
36
37
38
39
40
41
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
    """Geometry queries for a SigLIP vision encoder.

    Every answer is derived from ``self.vision_config`` (its ``image_size``
    and ``patch_size`` fields).
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens, i.e. the patch grid length squared.

        ``image_width`` and ``image_height`` are accepted but not consulted
        here; the result depends only on the configured image/patch sizes.
        """
        grid_length = self.get_patch_grid_length()
        return grid_length**2

    def get_image_size(self) -> int:
        """Return ``vision_config.image_size``."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return ``vision_config.patch_size``."""
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return how many whole patches fit along one side of the image."""
        return self.get_image_size() // self.get_patch_size()
53
54


55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
    """Patch embedding plus learned position embedding for SigLIP.

    Pixels are projected by a strided ``nn.Conv2d`` (kernel == stride ==
    ``patch_size``) and a position term is added, either looked up directly
    or bicubically resized from the learned table.
    """

    def __init__(self, config: SiglipVisionConfig):
        """Create the patch projection, position table and position ids.

        Args:
            config: Supplies ``hidden_size``, ``image_size``, ``patch_size``
                and ``num_channels``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patches: the kernel steps by its own size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        grid_length = self.image_size // self.patch_size
        self.num_patches = grid_length**2
        # SigLIP has no class token, so there is one position per patch.
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim
        )

        # Non-persistent buffer: kept on the module but not in the state dict.
        position_ids = torch.arange(self.num_positions, dtype=torch.int64)
        self.register_buffer(
            "position_ids", position_ids.expand((1, -1)), persistent=False
        )

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """Return position embeddings matching the patch grid of the input.

        Adapted for SigLIP (which, unlike other ViTs, has no class embedding)
        so that the pre-trained position encodings can be used on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings: Patch embeddings; ``shape[1]`` is read as the number
                of patches and ``shape[-1]`` as the embedding width.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            The learned table with a leading batch axis when the patch count
            already matches and the image is square; otherwise the table
            resized with bicubic interpolation and flattened to
            ``(1, positions, dim)``.

        Raises:
            ValueError: If the interpolated grid does not have the expected
                height/width.
        """
        pos_table = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = pos_table.shape[1]
        if num_patches == num_positions and height == width:
            return pos_table

        dim = embeddings.shape[-1]
        # Side length of the (square) learned grid, kept as a float because it
        # is also the denominator of the scale factors below.
        grid_side = math.sqrt(num_positions)

        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        target_h = height // self.patch_size + 0.1
        target_w = width // self.patch_size + 0.1

        # (1, positions, dim) -> (1, side, side, dim) -> (1, dim, side, side)
        pos_grid = pos_table.reshape(1, int(grid_side), int(grid_side), dim)
        pos_grid = pos_grid.permute(0, 3, 1, 2)
        pos_grid = nn.functional.interpolate(
            pos_grid,
            scale_factor=(target_h / grid_side, target_w / grid_side),
            mode="bicubic",
            align_corners=False,
        )

        out_h, out_w = pos_grid.shape[-2], pos_grid.shape[-1]
        if int(target_h) != out_h or int(target_w) != out_w:
            raise ValueError(
                "Width or height does not match with "
                "the interpolated position embeddings"
            )

        # Back to channels-last and flatten the grid into a sequence.
        return pos_grid.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(
        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
    ) -> torch.Tensor:
        """Embed ``pixel_values`` into a sequence of patch embeddings.

        Args:
            pixel_values: 4-D tensor; its last two axes are read as
                height and width.
            interpolate_pos_encoding: When true, resize the position table to
                the input resolution instead of looking it up directly.

        Returns:
            Patch embeddings with the position term added in place.
        """
        _, _, height, width = pixel_values.shape
        conv_dtype = self.patch_embedding.weight.dtype

        # Conv output is (batch, embed_dim, grid_h, grid_w).
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=conv_dtype))
        # Collapse the grid and move it ahead of the channel axis.
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            position_term = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            position_term = self.position_embedding(self.position_ids)
        embeddings += position_term
        return embeddings


151
class SiglipAttention(nn.Module):
    """Multi-head self-attention block of the SigLIP vision encoder.

    Q, K and V come from a single ``QKVParallelLinear`` projection whose
    output is split into three equal chunks along the last dimension; the
    attention result then goes through a ``RowParallelLinear`` projection.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the attention operator.

        Args:
            config: Supplies ``hidden_size``, ``num_attention_heads`` and
                ``attention_dropout``.
            quant_config: Forwarded unchanged to both linear projections.
            prefix: Dotted name of this module; ``.qkv_proj`` and
                ``.out_proj`` are appended for the sub-layers.

        Raises:
            ValueError: If ``hidden_size`` is not a multiple of
                ``num_attention_heads``.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            # Every fragment that carries a placeholder must be an f-string,
            # otherwise the braces are emitted literally.
            raise ValueError(
                "embed_dim must be divisible by num_heads (got "
                f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5
        # Stored for reference; not read anywhere else in this class.
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # The attention operator is sized for this rank's share of the heads.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(
            self.num_heads_per_partition, self.head_dim, self.scale
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel

        Returns:
            ``(attn_output, None)``. The second slot is always ``None``;
            callers unpack two values.
        """
        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
207
208
209
210
211


class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: ``fc1`` -> activation -> ``fc2``."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the two projections and resolve the activation.

        Args:
            config: Supplies ``hidden_act``, ``hidden_size`` and
                ``intermediate_size``.
            quant_config: Passed to both projections only when the layer is
                considered quantizable (see below); otherwise ``None`` is
                passed instead.
            prefix: Dotted name of this module; ``.fc1`` / ``.fc2`` are
                appended for the sub-layers.
        """
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        quant_name = quant_config.get_name() if quant_config else None
        # BNB and torchao are always allowed; for other quantization, we
        # require both sizes to be a multiple of 64.
        quantizable = quant_name in ("bitsandbytes", "torchao") or (
            config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
        )
        layer_quant_config = quant_config if quantizable else None

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=layer_quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=layer_quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the activation, then ``fc2``.

        Each projection returns a pair; only its first element is used.
        """
        projected, _ = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        output, _ = self.fc2(activated)
        return output


class SiglipEncoderLayer(nn.Module):
    """One pre-norm transformer layer: attention and MLP, each with a residual."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the attention block, the MLP and their layer norms.

        Args:
            config: Supplies ``hidden_size`` and ``layer_norm_eps`` and is
                forwarded to the sub-blocks.
            quant_config: Forwarded to the attention block and the MLP.
            prefix: Dotted name of this layer; ``.self_attn`` / ``.mlp`` are
                appended for the sub-blocks.
        """
        super().__init__()

        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.self_attn = SiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Run the layer and return ``(output, None)``.

        The second slot is always ``None``; callers unpack two values.
        """
        # Attention sub-block: normalise, attend, add the input back in place.
        attn_out, _ = self.self_attn(hidden_states=self.layer_norm1(hidden_states))
        attn_out += hidden_states

        # MLP sub-block: same pattern, with the attention result as residual.
        mlp_out = self.mlp(self.layer_norm2(attn_out))
        mlp_out += attn_out

        return mlp_out, None


class SiglipEncoder(nn.Module):
    """Stack of ``SiglipEncoderLayer`` modules applied in sequence."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        num_hidden_layers_override: int | None = None,
        prefix: str = "",
    ) -> None:
        """Build the layer stack.

        Args:
            config: Forwarded to every layer; ``num_hidden_layers`` gives the
                default depth.
            quant_config: Forwarded to every layer.
            num_hidden_layers_override: When not ``None``, the number of
                layers to build instead of ``config.num_hidden_layers``.
            prefix: Dotted name of this module; ``.layers.<idx>`` is appended
                for each layer.
        """
        super().__init__()

        self.config = config

        depth = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )
        built_layers = [
            SiglipEncoderLayer(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
            )
            for layer_idx in range(depth)
        ]
        self.layers = nn.ModuleList(built_layers)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Run every layer over ``inputs_embeds``.

        Returns:
            The last layer's output, or - when ``return_all_hidden_states``
            is true - a list holding the input embeddings followed by each
            layer's output in order, so callers can pick layers by index.
        """
        collected = [inputs_embeds]
        current = inputs_embeds

        for layer in self.layers:
            current, _ = layer(current)
            if return_all_hidden_states:
                collected.append(current)

        return collected if return_all_hidden_states else current


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A learned probe vector attends over the input sequence; the attended
    result goes through a layer norm and an MLP with a residual connection.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the probe, attention, layer norm and MLP.

        Args:
            config: Supplies ``hidden_size``, ``num_attention_heads`` and
                ``layer_norm_eps``; also forwarded to the MLP.
            quant_config: Forwarded to the MLP only.
            prefix: Dotted name of this module; ``.mlp`` is appended for the
                MLP.
        """
        super().__init__()

        hidden_size = config.hidden_size

        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(
            hidden_size, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_state`` and return index 0 along axis 1 of the result."""
        # One copy of the probe per batch element serves as the query.
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        # Only the attention output is needed, not the attention weights.
        attended = self.attention(probe, hidden_state, hidden_state)[0]

        pooled = self.mlp(self.layernorm(attended))
        pooled += attended

        return pooled[:, 0]


class SiglipVisionTransformer(nn.Module):
    """SigLIP vision tower: embeddings, encoder stack and optional post-norm."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        """Assemble the vision tower.

        Args:
            config: Model configuration, forwarded to every sub-module.
            quant_config: Forwarded to the encoder and the pooling head.
            num_hidden_layers_override: Forwarded to the encoder to build a
                different number of layers than the config specifies.
            require_post_norm: Whether to create ``post_layernorm``. When
                ``None``, it is created only if the encoder has exactly
                ``config.num_hidden_layers`` layers.
            prefix: Dotted name of this module; ``.encoder`` / ``.head`` are
                appended for the sub-modules.

        Raises:
            ValueError: If the encoder ends up with more layers than
                ``config.num_hidden_layers``.
        """
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        num_hidden_layers = config.num_hidden_layers
        built_layers = len(self.encoder.layers)
        if built_layers > num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {built_layers} layers."
            )

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            require_post_norm = built_layers == num_hidden_layers

        self.post_layernorm = (
            nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
            if require_post_norm
            else None
        )

        # Configs without a ``vision_use_head`` attribute default to a head.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.head",
            )

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        interpolate_pos_encoding: bool = False,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` and return the resolved encoder outputs.

        Args:
            pixel_values: Input images, passed to the embedding module.
            interpolate_pos_encoding: Forwarded to the embedding module.
            select_layers: When not ``None``, the encoder is asked for all
                hidden states and this list is forwarded to
                ``resolve_visual_encoder_outputs``.
            feature_select_strategy: Forwarded to
                ``resolve_visual_encoder_outputs``.
        """
        embedded = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        # Produces either the last layer output or all of the hidden states,
        # depending on if we have select_layers or not
        want_all_layers = select_layers is not None
        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=want_all_layers,
        )

        # Handle post-norm (if applicable) and stacks feature layers if needed
        encoder_outputs = resolve_visual_encoder_outputs(
            encoder_outputs,
            self.post_layernorm,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

        # TODO: add this back when pooled_output is used in inference.
        # if self.use_head:
        # pooled_output = self.head(encoder_outputs)

        return encoder_outputs
457
458
459
460
461
462
463
464
465


class SiglipVisionModel(nn.Module):
    """Top-level wrapper around ``SiglipVisionTransformer`` with weight loading."""

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        """Create the wrapped vision transformer as ``self.vision_model``.

        All arguments are forwarded unchanged, except ``prefix`` which gets
        ``.vision_model`` appended.
        """
        super().__init__()

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution of the vision tower."""
        return self.vision_model.embeddings.patch_embedding

    @property
    def dtype(self):
        """Dtype of the patch-embedding weight."""
        return self.get_input_embeddings().weight.dtype

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Delegate to ``self.vision_model`` with the same arguments."""
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from a checkpoint.

        Returns:
            The parameter names (after any renaming) that were loaded.

        Raises:
            KeyError: If a weight that is not skipped maps to a name that is
                not among this module's parameters.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        scale_suffixes = (".k_scale", ".v_scale", ".q_scale", ".prob_scale")

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)
        post_norm_absent = self.vision_model.post_layernorm is None

        for name, loaded_weight in weights:
            # post_layernorm is optional in SiglipVisionModel
            if post_norm_absent and name.startswith("vision_model.post_layernorm"):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                # The layer index is the fourth dotted component of the name.
                if int(name.split(".")[3]) >= layer_count:
                    continue

            # Scale parameters get a chance to be remapped first; if that
            # does not yield a known parameter, normal processing continues.
            if name.endswith(scale_suffixes):
                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                if remapped_name is not None and remapped_name in params_dict:
                    param = params_dict[remapped_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(remapped_name)
                    continue

            # First mapping entry whose checkpoint shard name occurs in `name`.
            shard_match = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if shard_match is not None:
                # Separate q/k/v checkpoint tensors are loaded into the
                # stacked qkv_proj parameter, tagged with their shard id.
                param_name, weight_name, shard_id = shard_match
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)

        return loaded_params