"""Implementation of SiglipVisionModel intended to be used only within
a vision language model."""

import math
5
from typing import Iterable, List, Optional, Tuple, Union
6

7
import numpy as np
8
9
10
import torch
from PIL import Image
from torch import nn
11
from transformers import SiglipVisionConfig
12
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
13
14

from vllm.config import ModelConfig
15
from vllm.distributed import divide, get_tensor_model_parallel_world_size
16
from vllm.inputs import DecoderOnlyInputs, token_inputs
17
18
19
20
21
22
23
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
24
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
25
26
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   repeat_and_pad_placeholder_tokens)
27
from vllm.sequence import SequenceData
28

29
30
31
32
33
34
try:
    from xformers import ops as xops
except ImportError:
    # xformers is optional. Without it, SiglipEncoderLayer falls back to the
    # HuggingFace SiglipSdpaAttention implementation.
    USE_XFORMERS_OPS = False
else:
    USE_XFORMERS_OPS = True


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return the number of patches along one side of the square patch grid.

    ``image_size`` is deliberately not required to be divisible by
    ``patch_size``: position embeddings can be interpolated at runtime, so
    the grid length is simply the floored quotient.
    """
    return image_size // patch_size


def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total number of patches in the square patch grid."""
    side = get_siglip_patch_grid_length(
        image_size=image_size,
        patch_size=patch_size,
    )
    return side**2


def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
    """Return the number of feature tokens produced for one image.

    This equals the number of patches implied by ``hf_config.image_size``
    and ``hf_config.patch_size``.
    """
    image_size = hf_config.image_size
    patch_size = hf_config.patch_size
    return get_siglip_num_patches(image_size=image_size,
                                  patch_size=patch_size)


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    """Return the maximum number of image tokens for a single image.

    The per-image feature size depends only on the config, so the maximum
    is the same as the feature size.
    """
    max_tokens = get_siglip_image_feature_size(hf_config)
    return max_tokens


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build dummy prompt token data for profiling.

    The data is described as two ``(token_id, count)`` runs passed to
    ``SequenceData.from_prompt_token_counts``:
    ``image_feature_size * num_images`` copies of ``image_token_id``,
    followed by token id ``0`` for the remainder of ``seq_len``.

    Args:
        hf_config: Vision config used to derive the per-image feature size.
        seq_len: Target total number of prompt tokens.
        num_images: Number of images whose placeholder tokens are included.
        image_token_id: Token id used as the image placeholder.
        image_feature_size_override: If given, used instead of the feature
            size derived from ``hf_config``.

    Returns:
        The ``SequenceData`` built from the two runs above.
    """
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    num_image_tokens = image_feature_size * num_images
    # NOTE(review): if seq_len < num_image_tokens the padding count is
    # negative; behavior then depends on from_prompt_token_counts.
    return SequenceData.from_prompt_token_counts(
        (image_token_id, num_image_tokens),
        (0, seq_len - num_image_tokens),
    )


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create all-black RGB dummy image(s) for profiling.

    The default size is ``hf_config.image_size`` square. Width and height
    can be overridden independently.

    Returns:
        ``{"image": image}`` when ``num_images == 1``, otherwise
        ``{"image": [image] * num_images}``. The list repeats the same PIL
        object rather than holding copies.
    """
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


93
94
95
def dummy_video_for_siglip(
    hf_config: SiglipVisionConfig,
    num_frames: int,
    num_videos: int = 1,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create dummy video data for profiling.

    A single dummy frame (see ``dummy_image_for_siglip``) is converted to a
    NumPy array and stacked ``num_frames`` times along a new leading axis.

    Returns:
        ``{"video": [clip] * num_videos}``. The list repeats the same
        array object rather than holding copies.
    """
    pil_frame = dummy_image_for_siglip(
        hf_config,
        num_images=1,
        image_width_override=image_width_override,
        image_height_override=image_height_override)
    np_frame = np.array(pil_frame["image"])
    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)

    return {"video": [mm_data_per_video] * num_videos}


113
114
115
def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    inputs: DecoderOnlyInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    """Expand image placeholder tokens in a prompt.

    Inputs without ``"image"`` multi-modal data are returned unchanged.
    Otherwise the prompt and its token ids are rewritten by
    ``repeat_and_pad_placeholder_tokens`` using ``image_token_id`` as the
    placeholder and the per-image feature size as the repeat count.

    Args:
        model_config: Provides the tokenizer name/path.
        hf_config: Vision config used to derive the feature size.
        inputs: Decoder-only inputs; ``prompt_token_ids`` is required.
        image_token_id: Placeholder token id to expand.
        image_feature_size_override: Repeat count to use instead of the
            size derived from the image data.

    Raises:
        TypeError: If no override is given and the image data is neither a
            PIL image nor a tensor.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_siglip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            # Presumably precomputed embeddings shaped
            # (num_images, feature_size, hidden_size). The 3-way unpack also
            # rejects tensors of any other rank.
            _, image_feature_size, _ = image_data.shape
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        inputs.get("prompt"),
        inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
    """Patch embeddings plus learned position embeddings for SigLIP.

    Pixels are split into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution. A learned position embedding (one per
    patch; SigLIP has no class token) is then added to each patch token.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # kernel_size == stride == patch_size -> non-overlapping patches.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        This method is an adapted method for SigLIP (due to SigLIP not having
        class embedding unlike other ViTs) that allows the model to interpolate
        the pre-trained position encodings such that it can be usable on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        # Fetch the position table through the embedding layer instead of
        # reading ``self.position_embedding.weight``. VocabParallelEmbedding
        # may pad its row count and shards rows across tensor-parallel ranks,
        # so the raw weight need not have exactly ``num_positions`` rows and
        # then cannot be reshaped into a square grid. The lookup over
        # ``position_ids`` yields exactly (1, num_positions, embed_dim).
        position_embeddings = self.position_embedding(self.position_ids)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
            dim)
        # (1, grid, grid, dim) -> (1, dim, grid, grid) for interpolate().
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """Embed ``pixel_values`` of shape (batch, channels, height, width).

        Returns patch tokens of shape (batch, num_patches, embed_dim) with
        position embeddings added. The position embeddings are interpolated
        to the input resolution if ``interpolate_pos_encoding`` is set.
        """
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings


247
class SiglipParallelAttention(nn.Module):
    """Tensor-parallel multi-head self-attention computed with xformers.

    Q, K and V come from one fused ``QKVParallelLinear``. The output
    projection is a ``RowParallelLinear``. Each tensor-parallel rank handles
    ``num_attention_heads / tp_size`` heads.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            # Every segment is an f-string so the actual values are shown.
            raise ValueError(f"embed_dim must be divisible by num_heads (got "
                             f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                             f" {self.num_heads}).")

        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        # Raises if the heads cannot be split evenly across ranks.
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel

        Returns ``(attn_output, None)``, so callers can unpack it the same
        way as the HF ``SiglipSdpaAttention`` fallback.
        """
        batch_size, q_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        # Split this rank's hidden slice into heads:
        # (batch, seq, heads_per_partition, head_dim).
        query_states = query_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(batch_size, q_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(
            query_states,
            key_states,
            value_states,
            # Attention dropout must only be active in training mode; at
            # inference it would randomly zero attention probabilities.
            p=self.dropout if self.training else 0.0,
            scale=self.scale,
        )
        out = out.view(batch_size, q_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None


class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2.

    ``fc1`` is column-parallel and ``fc2`` is row-parallel, so the
    intermediate activations stay sharded across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # Quantize only when BOTH hidden_size and intermediate_size are
        # multiples of 64; otherwise both linears run unquantized.
        quantizable = (config.hidden_size % 64 == 0
                       and config.intermediate_size % 64 == 0)
        linear_quant_config = quant_config if quantizable else None

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=linear_quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=linear_quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2; the output keeps the input's
        last dimension (hidden_size)."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``x + attn(ln1(x))`` followed by ``x + mlp(ln2(x))``.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size

        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # The tensor-parallel attention needs xformers and an even split of
        # heads across ranks. Otherwise fall back to HF's SiglipSdpaAttention,
        # which takes no quant_config. SiglipVisionModel.shard_weight uses the
        # same condition to decide how q/k/v weights are loaded.
        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
            self.self_attn = SiglipParallelAttention(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            self.self_attn = SiglipSdpaAttention(config)

        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Run attention and MLP sub-blocks; returns ``(hidden_states, None)``."""
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):
    """Stack of ``SiglipEncoderLayer`` blocks.

    By default the stack has ``config.num_hidden_layers`` layers.
    ``num_hidden_layers_override`` builds a different (typically truncated)
    number of layers instead.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config,
                               quant_config=quant_config,
                               prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Run every layer in order and return the final hidden states."""
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)

        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A single learned ``probe`` query attends over all tokens. The attended
    vector goes through a residual LayerNorm + MLP, and the probe position
    is returned as the pooled representation of shape (batch, hidden_size).
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Random init; presumably overwritten when checkpoint weights load.
        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp")

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_state`` (batch, seq, hidden) to (batch, hidden)."""
        batch_size = hidden_state.shape[0]
        # One probe query per batch element: (batch, 1, hidden).
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]


class SiglipVisionTransformer(nn.Module):
    """SigLIP vision encoder.

    The pipeline is patch/position embeddings, then the encoder stack, then
    an optional post-LayerNorm. The attention-pooling head is built so its
    checkpoint weights have a home, but it is not used in ``forward``.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size
        num_hidden_layers = config.num_hidden_layers

        # Validate the override before building the encoder so an invalid
        # request does not allocate layers only to throw them away.
        if (num_hidden_layers_override is not None
                and num_hidden_layers_override > num_hidden_layers):
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {num_hidden_layers_override} "
                "layers.")

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        # If possible, skip post_layernorm to conserve memory: by default it
        # is only built when the full encoder stack is used.
        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers

        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(embed_dim,
                                               eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

        # Configs without ``vision_use_head`` default to building the head.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.head",
            )

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` and return the last hidden states.

        The post-LayerNorm is applied only if it was built.
        """
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        if self.post_layernorm is None:
            return encoder_outputs

        last_hidden_state = self.post_layernorm(encoder_outputs)
        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        # pooled_output = self.head(last_hidden_state)

        return last_hidden_state


class SiglipVisionModel(nn.Module):
    """vLLM SigLIP vision model: a ``SiglipVisionTransformer`` wrapper.

    Adds HF-checkpoint weight loading, including fusing separate q/k/v
    projections into ``qkv_proj`` when the tensor-parallel attention is in
    use.
    """
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Must mirror the attention choice in SiglipEncoderLayer. When True
        # the layers use SiglipParallelAttention (fused qkv_proj), so
        # checkpoint q/k/v weights are stacked during loading.
        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` with the wrapped vision transformer."""
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Three kinds of weight are handled specially:
        - ``vision_model.post_layernorm*`` weights are skipped when no
          post-LayerNorm was built.
        - Weights of encoder layers beyond the (possibly truncated) layer
          count are skipped.
        - When ``shard_weight`` is set, ``q_proj``/``k_proj``/``v_proj``
          weights are routed into the fused ``qkv_proj`` parameter with
          shard ids ``"q"``/``"k"``/``"v"``.

        All other weights use the parameter's ``weight_loader`` if it has
        one, else ``default_weight_loader``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ] if self.shard_weight else []
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is optional in SiglipVisionModel
            if (name.startswith("vision_model.post_layernorm")
                    and self.vision_model.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                # "vision_model.encoder.layers.<idx>.<...>" -> <idx>
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue

                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)