# siglip.py
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

import math
6
7
from collections.abc import Iterable
from typing import Optional, Union
8
9
10

import torch
from torch import nn
11
from transformers import SiglipVisionConfig
12

13
from vllm.attention.layer import MultiHeadAttention
14
from vllm.distributed import divide, get_tensor_model_parallel_world_size
15
16
17
18
19
20
21
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
22
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23

24
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
25

class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
    """Patch-grid geometry for a SigLIP vision encoder.

    All values are derived from ``self.vision_config.image_size`` and
    ``self.vision_config.patch_size``.
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of patch tokens the encoder emits per image.

        The count depends only on the configured image and patch sizes;
        ``image_width`` and ``image_height`` are accepted for interface
        compatibility but are not used here.
        """
        return self.get_patch_grid_length()**2

    def get_image_size(self) -> int:
        """Return the configured (square) input image side length."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return the configured patch side length."""
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return the number of patches along one side of the image."""
        image_size, patch_size = self.get_image_size(), self.get_patch_size()
        return image_size // patch_size


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
    """Patch embedding plus learned absolute position embedding for SigLIP.

    Pixels are projected by a ``Conv2d`` whose kernel and stride both equal
    ``patch_size``. The resulting patch grid is flattened into a token
    sequence, and a learned position embedding is added to each token.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patchifier: kernel == stride == patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        # SigLIP has no class token, so there is exactly one position per
        # patch.
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        This method is an adapted method for SigLIP (due to SigLIP not having
        class embedding unlike other ViTs) that allows the model to interpolate
        the pre-trained position encodings such that it can be usable on higher
        resolution images.

        Args:
            embeddings: Patch embeddings; ``shape[1]`` is the patch count and
                ``shape[-1]`` the embedding dim.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            Position embeddings of shape ``(1, num_patches, dim)``.

        Raises:
            ValueError: If the interpolated grid does not match the patch
                grid implied by ``height`` and ``width``.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        # VocabParallelEmbedding may pad its weight table beyond the requested
        # number of rows (vLLM rounds embedding tables up to a padding
        # multiple, e.g. 729 -> 768). Only the first ``num_positions`` rows are
        # real position embeddings. Without this slice the early return below
        # never matches and the square-grid reshape fails.
        # NOTE(review): under tensor parallelism each rank holds only a shard
        # of this table, so interpolation here presumably assumes TP == 1 —
        # confirm before relying on it with TP > 1.
        position_embeddings = self.position_embedding.weight[:self.
                                                             num_positions]
        position_embeddings = position_embeddings.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        # The pre-trained table is a square grid of side sqrt(num_positions).
        grid_side = math.sqrt(num_positions)
        patch_pos_embed = position_embeddings.reshape(1, int(grid_side),
                                                      int(grid_side), dim)
        # (1, H, W, dim) -> (1, dim, H, W), the layout F.interpolate expects.
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / grid_side,
                width / grid_side,
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        # Back to (1, H*W, dim) so it adds onto the flattened patch tokens.
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """Embed ``pixel_values`` of shape ``(batch, channels, H, W)``.

        When ``interpolate_pos_encoding`` is False, the fixed position table
        is added as-is. That expects the input to yield exactly
        ``num_positions`` patches.
        """
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        # (batch, embed_dim, gh, gw) -> (batch, gh * gw, embed_dim)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings


class SiglipAttention(nn.Module):
    """Multi-head self-attention with tensor-parallel projections.

    Q, K and V come from one fused ``QKVParallelLinear``. Each
    tensor-parallel rank runs attention over
    ``num_attention_heads // tp_size`` heads. A ``RowParallelLinear`` then
    projects the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            # Every fragment carries the f-prefix so the values are actually
            # interpolated into the message.
            raise ValueError(f"embed_dim must be divisible by num_heads (got "
                             f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                             f" {self.num_heads}).")

        self.scale = self.head_dim**-0.5
        # Read from the config but not applied in forward().
        self.dropout = config.attention_dropout
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        # divide() is expected to reject head counts not divisible by TP size.
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel

        Returns ``(attn_output, None)``. The second element is always
        ``None``; no attention weights are returned.
        """
        qkv_states, _ = self.qkv_proj(hidden_states)
        # The fused output is split into equal thirds. This presumably relies
        # on K/V using the same head count as Q (no total_num_kv_heads is
        # passed to QKVParallelLinear) — verify if GQA is ever introduced.
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        return attn_output, None


class SiglipMLP(nn.Module):
    """Feed-forward block: ``fc2(act(fc1(x)))``.

    ``fc1`` is a ``ColumnParallelLinear`` (hidden -> intermediate) and
    ``fc2`` a ``RowParallelLinear`` (intermediate -> hidden).
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # Special handling for BNB and torchao quantization
        if quant_config and quant_config.get_name() in [
                "bitsandbytes", "torchao"
        ]:
            quantizable = True
        else:
            # For other quantization, we require both the hidden size and the
            # intermediate size to be multiples of 64
            quantizable = (config.hidden_size % 64 == 0
                           and config.intermediate_size % 64 == 0)
        # Layers that fail the check above are built unquantized.
        linear_quant_config = quant_config if quantizable else None

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=linear_quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=linear_quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> activation -> fc2 (the linear layers' bias outputs
        are discarded)."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer block.

    Computes ``x = x + attn(LN1(x))`` followed by ``x = x + mlp(LN2(x))``.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = config.hidden_size

        self.self_attn = SiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Run one block; returns ``(hidden_states, None)``."""
        # Attention sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        # MLP sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):
    """Sequential stack of ``SiglipEncoderLayer`` modules.

    The number of layers is ``num_hidden_layers_override`` when it is given,
    otherwise ``config.num_hidden_layers``.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # An explicit ``is None`` check keeps an override of 0 meaningful.
        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config,
                               quant_config=quant_config,
                               prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run all layers over ``inputs_embeds``.

        Returns:
            The last layer's output. If ``return_all_hidden_states`` is set,
            returns instead the list ``[inputs_embeds, out_1, ..., out_N]``:
            the input followed by every layer's output, in order.
        """
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A single learned probe token attends over the input sequence, followed by
    a residual ``LayerNorm`` + MLP. One pooled vector is produced per batch
    element.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Learned query of shape (1, 1, hidden_size), broadcast per batch.
        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp")

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_state`` (batch-first, ``(batch, seq, hidden)``) into
        a ``(batch, hidden)`` tensor."""
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        # nn.MultiheadAttention returns (output, weights); keep the output.
        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        # Drop the length-1 probe axis.
        return hidden_state[:, 0]


class SiglipVisionTransformer(nn.Module):
    """SigLIP vision tower: embeddings -> encoder -> optional post-norm.

    ``num_hidden_layers_override`` may truncate the encoder, but it may not
    exceed ``config.num_hidden_layers``.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)

        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # If possible, skip post_layernorm to conserve memory: by default it
        # is only built when the full encoder depth is used.
        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers

        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(embed_dim,
                                               eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

        # Configs without ``vision_use_head`` get the pooling head.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.head",
            )

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` into patch features.

        Args:
            pixel_values: Input images.
            interpolate_pos_encoding: Whether to resample the position table
                to the input resolution.
            feature_sample_layers: If given, the encoder returns every
                layer's hidden states for ``resolve_visual_encoder_outputs``
                to select from.
        """
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        return_all_hidden_states = feature_sample_layers is not None

        # Produces either the last layer output or all of the hidden states,
        # depending on if we have feature_sample_layers or not
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=return_all_hidden_states,
        )

        # Handle post-norm (if applicable) and stacks feature layers if needed
        encoder_outputs = resolve_visual_encoder_outputs(
            encoder_outputs, feature_sample_layers, self.post_layernorm,
            self.config.num_hidden_layers)

        # TODO: apply self.head (when self.use_head) once the pooled output
        # is needed for inference; it is built but not run here.
        return encoder_outputs


class SiglipVisionModel(nn.Module):
    """Thin wrapper exposing ``SiglipVisionTransformer`` as ``vision_model``
    and loading HF-format SigLIP checkpoints into it."""
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Delegate to ``vision_model``.

        Note that this wrapper defaults ``interpolate_pos_encoding`` to
        False.
        """
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            feature_sample_layers=feature_sample_layers,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into the
        fused ``qkv_proj`` shards. Checkpoint tensors for parts that were not
        built are skipped: ``post_layernorm``, layers beyond an override,
        and the pooling head when ``vision_use_head`` is False.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is optional in SiglipVisionModel
            if (name.startswith("vision_model.post_layernorm")
                    and self.vision_model.post_layernorm is None):
                continue

            # The pooling head is only built when use_head is True; skip its
            # checkpoint tensors otherwise instead of failing the lookup.
            if (name.startswith("vision_model.head.")
                    and not self.vision_model.use_head):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                # "vision_model.encoder.layers.<idx>.<...>"
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params