siglip.py 19.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

import math
7
8
from collections.abc import Iterable
from typing import Optional, Union
9
10
11

import torch
from torch import nn
12
from transformers import SiglipVisionConfig
13

14
from vllm.attention.layer import MultiHeadAttention
15
from vllm.distributed import divide, get_tensor_model_parallel_world_size
16
from vllm.model_executor.layers.activation import get_act_fn
17
18
19
20
21
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
22
from vllm.model_executor.layers.quantization import QuantizationConfig
23
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
24
from vllm.model_executor.model_loader.weight_utils import (
25
26
27
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
28

29
30
31
32
33
from .vision import (
    VisionEncoderInfo,
    VisionFeatureSelectStrategy,
    resolve_visual_encoder_outputs,
)
34

35

36
37
38
39
40
41
42
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
    """Expose image/patch geometry read from a SigLIP vision config."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of tokens produced for one image.

        The requested width and height are not consulted; the result is
        the square of the configured patch grid length.
        """
        grid_length = self.get_patch_grid_length()
        return grid_length**2

    def get_image_size(self) -> int:
        """Return ``image_size`` from the vision config."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return ``patch_size`` from the vision config."""
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return how many whole patches fit along one side of the image."""
        return self.get_image_size() // self.get_patch_size()
54
55


56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
    """Turn pixel values into patch embeddings and add position embeddings."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Kernel and stride both equal the patch size, so the convolution
        # sees each patch exactly once (no overlap, no padding).
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # There is no class token, so one position per patch.
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim
        )

        # Non-persistent buffer: follows the module across devices but is
        # not written to the state dict.
        all_positions = torch.arange(self.num_positions, dtype=torch.int64)
        self.register_buffer(
            "position_ids",
            all_positions.expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        Resize the pre-trained position embeddings to the patch grid of an
        image of ``height`` x ``width`` pixels, so the model can be used on
        resolutions other than the configured one. Adapted for SigLIP, which
        (unlike other ViTs) has no class embedding.

        When the number of patches in ``embeddings`` equals the number of
        stored positions and the image is square, the stored embeddings are
        returned as-is (with a leading batch dimension of 1).

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Raises:
            ValueError: If the interpolated grid does not have the expected
                number of rows or columns.
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embeddings.shape[1]
        if embeddings.shape[1] == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        # Target grid size in patches. The extra 0.1 avoids floating point
        # error in the interpolation; see the discussion at
        # https://github.com/facebookresearch/dino/issues/8
        target_rows = height // self.patch_size + 0.1
        target_cols = width // self.patch_size + 0.1

        # The stored positions are laid out as a square grid.
        source_side = math.sqrt(num_positions)
        source_grid = int(source_side)

        # (1, grid, grid, dim) -> (1, dim, grid, grid) for interpolate().
        patch_pos_embed = position_embeddings.reshape(
            1, source_grid, source_grid, dim
        ).permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(target_rows / source_side, target_cols / source_side),
            mode="bicubic",
            align_corners=False,
        )

        out_rows, out_cols = patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]
        if int(target_rows) != out_rows or int(target_cols) != out_cols:
            raise ValueError(
                "Width or height does not match with "
                "the interpolated position embeddings"
            )

        # Back to (1, rows * cols, dim).
        return patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(
        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
    ) -> torch.Tensor:
        """Embed ``pixel_values`` and add position embeddings.

        ``pixel_values`` must be 4-dimensional; its last two dimensions are
        taken as height and width. When ``interpolate_pos_encoding`` is
        true, the position embeddings are resized to the input resolution;
        otherwise the stored embeddings are looked up directly.
        """
        _batch, _channels, height, width = pixel_values.shape

        # Cast the input to the convolution weight's dtype before embedding.
        conv_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=conv_dtype))

        # Collapse the two spatial dims and move them ahead of the channels.
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            positional = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            positional = self.position_embedding(self.position_ids)
        embeddings += positional

        return embeddings


152
class SiglipAttention(nn.Module):
    """Multi-head self-attention with a fused QKV input projection.

    The query/key/value projections are a single ``QKVParallelLinear`` and
    the output projection is a ``RowParallelLinear``; the attention heads
    are divided across the tensor-parallel world size.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the attention operator.

        Args:
            config: Vision config supplying ``hidden_size``,
                ``num_attention_heads`` and ``attention_dropout``.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            prefix: Dotted name of this module, used to derive the
                sub-layer prefixes.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            # Every fragment must be an f-string, otherwise the placeholder
            # is printed literally instead of the offending value.
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got "
                f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5
        # Kept from the config; not read anywhere else in this class.
        self.dropout = config.attention_dropout
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Each tensor-parallel rank attends over its own share of the heads.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(
            self.num_heads_per_partition, self.head_dim, self.scale
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel

        Returns:
            A ``(attn_output, None)`` pair. The second element is always
            ``None``; ``SiglipEncoderLayer`` discards it.
        """
        qkv_states, _ = self.qkv_proj(hidden_states)
        # The fused projection output is split into three equal parts
        # along the last dimension: query, key, value.
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
208
209
210
211
212


class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # BNB and torchao get special handling: they are always treated as
        # quantizable. For any other quantization, both layer sizes must be
        # multiples of 64.
        size_agnostic_methods = ("bitsandbytes", "torchao")
        if quant_config and quant_config.get_name() in size_agnostic_methods:
            quantizable = True
        else:
            quantizable = all(
                size % 64 == 0
                for size in (config.hidden_size, config.intermediate_size)
            )

        # Fall back to unquantized linears when the sizes do not qualify.
        layer_quant_config = quant_config if quantizable else None
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=layer_quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=layer_quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the configured activation, then fc2."""
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.self_attn = SiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Run the layer and return ``(output, None)``."""
        # Attention sub-block: normalize, attend, add the input back in place.
        attn_out, _ = self.self_attn(hidden_states=self.layer_norm1(hidden_states))
        attn_out += hidden_states

        # MLP sub-block: normalize, project, add the attention result back.
        mlp_out = self.mlp(self.layer_norm2(attn_out))
        mlp_out += attn_out

        return mlp_out, None


class SiglipEncoder(nn.Module):
    """Stack of ``SiglipEncoderLayer`` modules applied in order."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        # An explicit override takes precedence over the configured depth.
        num_hidden_layers = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )

        stack = [
            SiglipEncoderLayer(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
            )
            for layer_idx in range(num_hidden_layers)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run every layer over ``inputs_embeds``.

        Returns the final hidden state, or, when
        ``return_all_hidden_states`` is set, a list holding the input
        embeddings followed by each layer's output in order.
        """
        collected = [inputs_embeds]
        current = inputs_embeds

        for layer in self.layers:
            current, _ = layer(current)
            if return_all_hidden_states:
                collected.append(current)

        # With multiple feature sample layers, all hidden states are
        # returned in order and the caller picks the ones it needs by index.
        return collected if return_all_hidden_states else current


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        width = config.hidden_size

        # Single learned query vector used to pool the sequence.
        self.probe = nn.Parameter(torch.randn(1, 1, width))
        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(
            width, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_state`` by attending to it from the probe query."""
        # One copy of the probe per batch element.
        queries = self.probe.repeat(hidden_state.shape[0], 1, 1)

        # Keep only the attention output (first element of the result).
        attended = self.attention(queries, hidden_state, hidden_state)[0]

        # Normalized MLP with a residual connection around it.
        refined = self.mlp(self.layernorm(attended))
        refined += attended

        # Drop the length-1 query dimension.
        return refined[:, 0]


class SiglipVisionTransformer(nn.Module):
    """Embeddings, encoder stack, optional post-layernorm and pooling head."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)

        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        num_hidden_layers = config.num_hidden_layers
        built_layers = len(self.encoder.layers)
        if built_layers > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {built_layers} layers."
            )

        # If possible, skip post_layernorm to conserve memory: by default it
        # is only created when the full layer stack was built.
        if require_post_norm is None:
            require_post_norm = built_layers == num_hidden_layers

        self.post_layernorm = (
            nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
            if require_post_norm
            else None
        )

        # Configs without a ``vision_use_head`` attribute default to a head.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.head",
            )

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        interpolate_pos_encoding: bool = False,
        select_layers: Optional[list[int]] = None,
        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` and return the resolved encoder features.

        The pooling head is not applied here (see the TODO below).
        """
        embedded = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        # The encoder yields either the last layer output or all of the
        # hidden states, depending on whether select_layers was given.
        want_all_layers = select_layers is not None
        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=want_all_layers,
        )

        # TODO: add this back when pooled_output is used in inference.
        # if self.use_head:
        # pooled_output = self.head(encoder_outputs)

        # Handle post-norm (if applicable) and stack feature layers if needed.
        return resolve_visual_encoder_outputs(
            encoder_outputs,
            self.post_layernorm,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )
458
459
460
461
462
463
464
465
466
467


class SiglipVisionModel(nn.Module):
    """Wrapper around ``SiglipVisionTransformer`` that also loads weights."""

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution."""
        return self.vision_model.embeddings.patch_embedding

    @property
    def dtype(self):
        """Dtype of the patch-embedding weight."""
        return self.get_input_embeddings().weight.dtype

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
        select_layers: Optional[list[int]] = None,
        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
        """Delegate to the wrapped vision transformer."""
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are loaded as
        shards of the fused ``qkv_proj`` parameter. Weights for an absent
        post-layernorm or for encoder layers beyond the ones built are
        skipped.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        scale_suffixes = (".k_scale", ".v_scale", ".q_scale", ".prob_scale")

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)
        # post_layernorm is optional in SiglipVisionModel
        post_norm_absent = self.vision_model.post_layernorm is None

        for name, loaded_weight in weights:
            if post_norm_absent and name.startswith("vision_model.post_layernorm"):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                # "vision_model.encoder.layers.<idx>..." -> field 3 is <idx>.
                if int(name.split(".")[3]) >= layer_count:
                    continue

            # Scale parameters may need their name remapped before lookup.
            if name.endswith(scale_suffixes):
                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                if remapped_name is not None and remapped_name in params_dict:
                    # Successfully remapped, use the remapped name
                    target = params_dict[remapped_name]
                    load_fn = getattr(target, "weight_loader", default_weight_loader)
                    load_fn(target, loaded_weight)
                    loaded_params.add(remapped_name)
                    continue
                # If remapping failed, continue with normal processing

            # The first mapping whose shard name occurs in the weight name
            # decides whether this tensor is one shard of a fused parameter.
            shard_match = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if shard_match is None:
                target = params_dict[name]
                load_fn = getattr(target, "weight_loader", default_weight_loader)
                load_fn(target, loaded_weight)
            else:
                param_name, weight_name, shard_id = shard_match
                name = name.replace(weight_name, param_name)
                target = params_dict[name]
                load_fn = target.weight_loader
                load_fn(target, loaded_weight, shard_id)

            loaded_params.add(name)
        return loaded_params