siglip.py 18.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

import math
6
from typing import Iterable, Optional, Set, Tuple, Union
7
8
9

import torch
from torch import nn
10
from transformers import SiglipVisionConfig
11

12
from vllm.attention.layer import MultiHeadAttention
13
from vllm.distributed import divide, get_tensor_model_parallel_world_size
14
15
16
17
18
19
20
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
21
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
22

23
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
24

25

26
27
28
29
30
31
32
33
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
    """Image-geometry queries for a SigLIP vision encoder.

    Every value is read from ``self.vision_config`` (presumably populated by
    the ``VisionEncoderInfo`` base class, which is not shown here).
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of vision tokens produced for one image.

        The count is one token per patch of the patch grid and depends only on
        the configured ``image_size`` and ``patch_size``; ``image_width`` and
        ``image_height`` are accepted for interface compatibility but unused.
        """
        return self.get_patch_grid_length()**2

    def get_image_size(self) -> int:
        """Return the configured input image size (``image_size``)."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return the configured patch size (``patch_size``)."""
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return the number of patches along one side of the patch grid."""
        image_size, patch_size = self.get_image_size(), self.get_patch_size()
        return image_size // patch_size


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
    """Turn pixel values into patch embeddings plus position embeddings.

    A strided ``Conv2d`` (kernel size == stride == ``patch_size``) cuts the
    image into non-overlapping patches. The resulting grid is flattened into a
    token sequence, and a learned position embedding is added to every token.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        # One learned position per patch of the (image_size // patch_size)^2
        # grid; SigLIP has no class token, so positions == patches.
        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        This method is an adapted method for SigLIP (due to SigLIP not having
        class embedding unlike other ViTs) that allows the model to interpolate
        the pre-trained position encodings such that it can be usable on higher
        resolution images.

        Args:
            embeddings: patch embeddings as produced in ``forward``,
                shape ``(batch, num_patches, dim)``.
            height: input image height in pixels.
            width: input image width in pixels.

        Returns:
            Position embeddings of shape ``(1, N, dim)``, where N is the
            number of patches of the (possibly resized) grid.

        Raises:
            ValueError: if the interpolated grid does not match
                ``(height // patch_size, width // patch_size)``.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        # Fetch the table through the embedding layer instead of reading
        # ``self.position_embedding.weight``. VocabParallelEmbedding can store
        # a table padded beyond ``num_positions`` rows (and only a shard of
        # it under tensor parallelism). That would defeat the fast-path
        # comparison and make the square reshape below fail. The lookup
        # returns exactly ``num_positions`` rows with shape
        # (1, num_positions, dim), the same path ``forward`` uses when
        # interpolation is off.
        position_embeddings = self.position_embedding(self.position_ids)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        # Lay the 1-D position table out as a square grid: (1, dim, side, side)
        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
            dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        # Back to a token sequence: (1, new_h * new_w, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """Embed ``pixel_values`` (shape ``(batch, channels, H, W)``).

        Returns patch tokens of shape ``(batch, num_patches, hidden_size)``
        with position embeddings added. These are interpolated to the input
        resolution when ``interpolate_pos_encoding`` is set.
        """
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings


class SiglipAttention(nn.Module):
    """Multi-head self-attention for one SigLIP encoder layer.

    Q, K and V come from a single fused ``QKVParallelLinear`` projection and
    are merged back by a row-parallel ``out_proj``. Each tensor-parallel rank
    runs ``num_heads // tp_size`` heads.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            # Every piece of this implicitly concatenated message needs the
            # ``f`` prefix; otherwise ``{self.embed_dim}`` is printed verbatim.
            raise ValueError(f"embed_dim must be divisible by num_heads (got "
                             f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                             f" {self.num_heads}).")

        self.scale = self.head_dim**-0.5
        # Read from the config; not applied anywhere in this module.
        self.dropout = config.attention_dropout
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel

        Returns ``(attn_output, None)``. The second element is a placeholder
        (no attention weights are computed) kept for the caller's unpacking.
        """
        qkv_states, _ = self.qkv_proj(hidden_states)
        # Local Q, K and V slices have equal width (no grouped KV heads).
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        return attn_output, None


class SiglipMLP(nn.Module):
    """Feed-forward block: ``fc1`` (column-parallel) -> activation -> ``fc2``.

    Quantization is applied to both linears only when the quantization method
    is bitsandbytes or torchao, or when ``hidden_size`` and
    ``intermediate_size`` are both multiples of 64. Otherwise the linears are
    built without a quantization config.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # Special handling for BNB and torchao quantization
        if quant_config and quant_config.get_name() in [
                "bitsandbytes", "torchao"
        ]:
            quantizable = True
        else:
            # For other quantization, we require the hidden size to be a
            # multiple of 64
            quantizable = (config.hidden_size % 64 == 0
                           and config.intermediate_size % 64 == 0)

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the configured activation, then fc2."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):
    """One pre-norm transformer layer: LN -> attention -> add, LN -> MLP -> add.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size

        self.self_attn = SiglipAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Run the layer; returns ``(hidden_states, None)``.

        The ``None`` slot is a placeholder the encoder unpacks and discards.
        """
        # Attention sub-block with residual connection.
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        # MLP sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):
    """Stack of ``SiglipEncoderLayer``s.

    ``num_hidden_layers_override``, when given, replaces
    ``config.num_hidden_layers`` as the number of layers built.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        num_hidden_layers = (config.num_hidden_layers
                             if num_hidden_layers_override is None else
                             num_hidden_layers_override)

        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config,
                               quant_config=quant_config,
                               prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run every layer over ``inputs_embeds``.

        Returns the final hidden states. If ``return_all_hidden_states`` is
        set, it instead returns a list whose first element is ``inputs_embeds``,
        followed by the output of each layer in order.
        """
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A single learned ``probe`` query attends over the encoder tokens. The
    attended vector then passes through LayerNorm and an MLP with a residual
    connection.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp")

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_state`` (batch-first) into one vector per batch item.
        """
        batch_size = hidden_state.shape[0]
        # One copy of the probe query per batch element: (batch, 1, hidden).
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        # Drop the length-1 query axis.
        return hidden_state[:, 0]


class SiglipVisionTransformer(nn.Module):
    """SigLIP vision tower: embeddings -> encoder -> optional post-LayerNorm.

    Args:
        config: vision config.
        quant_config: optional quantization config forwarded to submodules.
        num_hidden_layers_override: build only this many encoder layers; it
            must not exceed ``config.num_hidden_layers``.
        require_post_norm: whether to build ``post_layernorm``. ``None``
            builds it only when the full layer stack is used.
        prefix: parameter-name prefix for submodules.

    Raises:
        ValueError: if more layers are requested than the config defines.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)

        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers

        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(embed_dim,
                                               eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

        # Configs lacking ``vision_use_head`` default to building the head.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.head",
            )

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` into vision features.

        When ``feature_sample_layers`` is given, all hidden states are
        collected and ``resolve_visual_encoder_outputs`` selects and combines
        the requested ones. Otherwise the last layer's output is used. In both
        cases the post-norm is applied per that helper's logic, when built.
        The pooling head is not run here.
        """
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        return_all_hidden_states = feature_sample_layers is not None

        # Produces either the last layer output or all of the hidden states,
        # depending on if we have feature_sample_layers or not
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=return_all_hidden_states,
        )

        # Handle post-norm (if applicable) and stacks feature layers if needed
        encoder_outputs = resolve_visual_encoder_outputs(
            encoder_outputs, feature_sample_layers, self.post_layernorm,
            self.config.num_hidden_layers)

        # TODO: add this back when pooled_output is used in inference.
        # if self.use_head:
        # pooled_output = self.head(encoder_outputs)

        return encoder_outputs


class SiglipVisionModel(nn.Module):
    """Top-level SigLIP vision model wrapping ``SiglipVisionTransformer``.

    Parameters live under the ``vision_model.`` prefix, matching the names
    that ``load_weights`` expects from checkpoints.
    """
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Delegate to the vision transformer.

        Note that this wrapper defaults ``interpolate_pos_encoding`` to False
        and always passes it explicitly.
        """
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            feature_sample_layers=feature_sample_layers,
        )

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into the
        fused ``qkv_proj`` shards. Weights for modules that were not built are
        skipped: ``post_layernorm`` when it is disabled, and encoder layers at
        or beyond the number actually constructed.

        Returns:
            The set of (post-renaming) parameter names that were loaded.

        Raises:
            KeyError: if a checkpoint name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is optional in SiglipVisionModel
            if (name.startswith("vision_model.post_layernorm")
                    and self.vision_model.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                # "vision_model.encoder.layers.<idx>...." -> <idx>
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load the tensor into its own parameter.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params