"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPVisionConfig

from vllm.attention.selector import _Backend
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   consecutive_placeholder_ranges,
                                   repeat_and_pad_placeholder_tokens,
                                   resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData

from .utils import get_vit_attn_backend
29

30

31
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return how many patches fit along one side of the input image.

    Args:
        image_size: Side length of the input image in pixels.
        patch_size: Side length of one patch in pixels.

    Returns:
        ``image_size // patch_size``.

    Raises:
        AssertionError: If ``image_size`` is not a multiple of ``patch_size``.
    """
    assert image_size % patch_size == 0, (
        f"image_size ({image_size}) must be divisible by "
        f"patch_size ({patch_size})")
    return image_size // patch_size


def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total patch count for an image split into a square grid.

    The grid has ``image_size // patch_size`` patches per side, so the total
    is that side length squared.
    """
    side = get_clip_patch_grid_length(patch_size=patch_size,
                                      image_size=image_size)
    return side**2


def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
    """Return the number of output positions CLIP produces for one image.

    This is one position per patch plus one extra position for the class
    embedding, which ``CLIPVisionEmbeddings`` prepends to the patch sequence.

    Args:
        hf_config: Vision config providing ``image_size`` and ``patch_size``.

    Returns:
        ``num_patches + 1``.
    """
    num_patches = get_clip_num_patches(image_size=hf_config.image_size,
                                       patch_size=hf_config.patch_size)
    return num_patches + 1
45
46


47
48
49
50
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
    """Return the maximum number of image tokens for one CLIP image.

    CLIP always emits the same number of positions per image, so the maximum
    equals the per-image feature size.
    """
    max_tokens = get_clip_image_feature_size(hf_config)
    return max_tokens


51
52
53
54
55
56
57
def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig,
                            seq_len: int,
                            num_images: int,
                            *,
                            image_token_id: int,
                            image_feature_size_override: Optional[int] = None,
                            mm_key: str = "image"):
    """Build dummy sequence data containing ``num_images`` image-token runs.

    The sequence starts with ``image_feature_size * num_images`` copies of
    ``image_token_id``. The remaining ``seq_len - image_feature_size *
    num_images`` positions are filled with token id 0.

    Args:
        hf_config: Vision config, used to compute the per-image feature size
            when no override is given.
        seq_len: Total target sequence length.
        num_images: Number of images represented in the sequence.
        image_token_id: Token id used for image placeholder positions.
        image_feature_size_override: If given, used instead of the computed
            per-image feature size.
        mm_key: Key under which the placeholder ranges are returned.

    Returns:
        A ``(SequenceData, placeholders)`` tuple, where ``placeholders`` maps
        ``mm_key`` to ``consecutive_placeholder_ranges(num_items=num_images,
        item_size=image_feature_size)``.

    NOTE(review): nothing here guards against ``seq_len`` being smaller than
    the image-token count. In that case the padding count is negative; confirm
    how ``SequenceData.from_prompt_token_counts`` treats it.
    """
    if image_feature_size_override is None:
        image_feature_size = get_clip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    num_image_tokens = image_feature_size * num_images
    return SequenceData.from_prompt_token_counts(
        (image_token_id, num_image_tokens),
        (0, seq_len - num_image_tokens),
    ), {
        mm_key:
        consecutive_placeholder_ranges(num_items=num_images,
                                       item_size=image_feature_size)
    }
71
72


73
def dummy_image_for_clip(
    hf_config: CLIPVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create black RGB dummy image(s) for CLIP.

    Args:
        hf_config: Vision config; ``image_size`` is the default width and
            height.
        num_images: Number of images requested.
        image_width_override: Optional width to use instead of
            ``image_size``.
        image_height_override: Optional height to use instead of
            ``image_size``.

    Returns:
        ``{"image": image}`` when ``num_images == 1``, otherwise
        ``{"image": [image] * num_images}``. The list repeats the same
        ``PIL.Image`` object rather than holding copies.
    """
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}
88
89


90
91
92
def dummy_video_for_clip(
    hf_config: CLIPVisionConfig,
    num_frames: int,
    num_videos: int = 1,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create dummy video data made of repeated black frames.

    One dummy frame is built with ``dummy_image_for_clip`` and converted to a
    NumPy array. It is then repeated ``num_frames`` times along a new leading
    axis.

    Args:
        hf_config: Vision config forwarded to ``dummy_image_for_clip``.
        num_frames: Number of frames per video.
        num_videos: Number of videos in the returned list.
        image_width_override: Optional frame width.
        image_height_override: Optional frame height.

    Returns:
        ``{"video": [frames] * num_videos}``, where ``frames`` is a single
        array of shape ``(num_frames, *np_frame.shape)``. For an RGB PIL
        frame, NumPy yields ``(height, width, 3)``. The list repeats the same
        array object.
    """
    pil_frame = dummy_image_for_clip(
        hf_config,
        num_images=1,
        image_width_override=image_width_override,
        image_height_override=image_height_override)
    np_frame = np.array(pil_frame["image"])
    # Wrapping the frame in a list adds a leading axis, which is repeated.
    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
    video_data = [mm_data_per_video] * num_videos
    mm_data = {"video": video_data}
    return mm_data


110
111
112
def input_processor_for_clip(
    model_config: ModelConfig,
    hf_config: CLIPVisionConfig,
    inputs: DecoderOnlyInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    """Expand image placeholder tokens in ``inputs`` to CLIP's feature size.

    ``inputs`` is returned unchanged when it has no image data, or when it
    already carries image placeholders.

    Args:
        model_config: Model config; its ``tokenizer`` is used to load the
            tokenizer.
        hf_config: Vision config used to compute the feature size of a
            ``PIL.Image`` input.
        inputs: Decoder-only inputs, possibly with ``multi_modal_data``.
        image_token_id: Placeholder token id to repeat.
        image_feature_size_override: If given, used as the repeat count
            instead of deriving it from the image data.

    Returns:
        New token inputs with the expanded prompt, the same
        ``multi_modal_data``, and ``multi_modal_placeholders={"image": ...}``.

    Raises:
        TypeError: If no override is given and the image data is neither a
            ``PIL.Image.Image`` nor a ``torch.Tensor``.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    # `or {}` also tolerates the key being present with a None value.
    if "image" in (inputs.get("multi_modal_placeholders") or {}):
        # The inputs already have placeholders.
        return inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_clip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            # Embeddings laid out as (num_images, feature_size, hidden_size);
            # the unpacking also enforces a 3-D tensor.
            _, image_feature_size, _ = image_data.shape
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
        tokenizer,
        inputs.get("prompt"),
        inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=new_prompt,
                        multi_modal_data=multi_modal_data,
                        multi_modal_placeholders={"image": ranges})
153
154


155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
    """Turn pixel values into patch embeddings with a prepended class token.

    A non-overlapping Conv2d (kernel = stride = ``patch_size``) produces one
    embedding per patch. A learned class embedding is prepended, and learned
    absolute position embeddings are added.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = get_clip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        # Fixed 0..num_positions-1 indices. Non-persistent, so the buffer is
        # not part of the state dict.
        self.register_buffer("position_ids",
                             torch.arange(self.num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: ``(batch, num_channels, H, W)`` tensor, as required
                by ``nn.Conv2d``. H and W must produce exactly
                ``num_patches`` patches (i.e. ``image_size`` x
                ``image_size``), otherwise adding the position embeddings
                fails to broadcast.

        Returns:
            ``(batch, num_patches + 1, embed_dim)`` embeddings, with the
            class token at index 0.
        """
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        # [*, width, grid, grid] -> [*, grid * grid, width]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings


198
class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Fused Q/K/V projection. Checkpoint q_proj/k_proj/v_proj weights are
        # loaded into it as shards by CLIPVisionModel.load_weights.
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        # Heads are divided across tensor-parallel ranks; forward() reshapes
        # the local Q/K/V using this per-rank head count.
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(support_fa=False)
        if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}:
            raise RuntimeError(
                f"CLIP does not support {self.attn_backend} backend now.")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape to (bsz, num_heads, seq_len, head_dim) using total heads.

        Not used by ``forward``; kept for interface compatibility.
        """
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Apply self-attention.

        Args:
            hidden_states: Input of shape Batch x Time x Channel.

        Returns:
            ``(attn_output, None)``. The second element is always ``None``,
            since no attention weights are returned. ``attn_output`` is the
            result of ``out_proj``, whose output size is ``embed_dim``.
        """
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        # (bsz, tgt_len, heads_on_this_rank, head_dim)
        query_states = query_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(bsz, tgt_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        # Attention dropout belongs to training only. The Hugging Face
        # reference disables it in eval mode; previously it was applied
        # unconditionally, which randomised inference outputs whenever
        # attention_dropout > 0.
        dropout_p = self.dropout if self.training else 0.0

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            # xformers takes the (batch, seq, heads, head_dim) layout as is.
            out = xops.memory_efficient_attention_forward(query_states,
                                                          key_states,
                                                          value_states,
                                                          p=dropout_p,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA wants (batch, heads, seq, head_dim).
            query_states, key_states, value_states = (x.transpose(1, 2)
                                                      for x in (query_states,
                                                                key_states,
                                                                value_states))
            out = F.scaled_dot_product_attention(query_states,
                                                 key_states,
                                                 value_states,
                                                 dropout_p=dropout_p,
                                                 scale=self.scale)
            out = out.transpose(1, 2)

        # Merge heads back into the channel dimension.
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
292
293


294
295
class CLIPMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2.

    ``fc1`` is column-parallel and ``fc2`` is row-parallel, so the
    intermediate activation stays sharded across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up to ``intermediate_size``, activate, project back."""
        # The parallel linear layers return (output, bias); the bias slot is
        # discarded.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states


class CLIPEncoderLayer(nn.Module):
    """One transformer layer with pre-norm residual attention and MLP.

    Computes ``x = x + attn(ln1(x))`` and then ``x = x + mlp(ln2(x))``.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each with a residual add."""
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        # CLIPAttention returns (output, None); only the output is used.
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers (or `num_hidden_layers_override` layers, if given).
    Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPVisionConfig
        quant_config: Optional quantization config passed to every layer.
        num_hidden_layers_override: If not None, build this many layers
            instead of `config.num_hidden_layers`.
        prefix: Parameter-name prefix; layer `i` uses `{prefix}.layers.{i}`.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            CLIPEncoderLayer(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run all layers in order.

        Args:
            inputs_embeds: Input hidden states for the first layer.
            return_all_hidden_states: If True, return every layer's output.

        Returns:
            The last layer's output, or, when ``return_all_hidden_states`` is
            True, a list with one output per layer in order. The list does
            not include ``inputs_embeds``.
        """
        hidden_states_pool = []
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states


class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower: embeddings, pre-norm, encoder, optional post-norm.

    Args:
        config: Vision config.
        quant_config: Optional quantization config for the encoder layers.
        num_hidden_layers_override: Build only this many encoder layers.
            Must not exceed ``config.num_hidden_layers``.
        require_post_norm: Whether to create ``post_layernorm``. When None,
            it is created only if the full encoder depth is built.
        prefix: Parameter-name prefix for the encoder.

    Raises:
        ValueError: If more layers are requested than the config defines.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers

        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(embed_dim,
                                               eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

    def forward(
        self,
        pixel_values: torch.Tensor,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Encode images.

        Args:
            pixel_values: Image batch passed to ``CLIPVisionEmbeddings``.
            feature_sample_layers: If given, all per-layer hidden states are
                collected and passed to ``resolve_visual_encoder_outputs``
                together with these indices. If None, only the last layer's
                output is passed.

        Returns:
            Whatever ``resolve_visual_encoder_outputs`` returns, which
            handles the post-norm (if any) and feature-layer stacking.
        """
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        return_all_hidden_states = feature_sample_layers is not None

        # Produces either the last layer output or all of the hidden states,
        # depending on if we have feature_sample_layers or not
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=return_all_hidden_states)

        # Handle post-norm (if applicable) and stacks feature layers if needed
        encoder_outputs = resolve_visual_encoder_outputs(
            encoder_outputs, feature_sample_layers, self.post_layernorm,
            self.config.num_hidden_layers)

        return encoder_outputs
479
480
481
482
483
484
485


class CLIPVisionModel(nn.Module):
    """Thin wrapper exposing ``CLIPVisionTransformer`` as ``vision_model``.

    The ``vision_model`` attribute name matches the ``vision_model.*``
    checkpoint parameter names handled in ``load_weights``.
    """

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model")

    def forward(
        self,
        pixel_values: torch.Tensor,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Delegate to ``CLIPVisionTransformer.forward``."""
        return self.vision_model(pixel_values, feature_sample_layers)

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into this model.

        Separate ``q_proj``/``k_proj``/``v_proj`` weights are loaded as shards
        of the fused ``qkv_proj``. The following weights are skipped:
        ``post_layernorm`` weights when the module was not created, and
        encoder layers whose index is beyond the number of layers built.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.

        Raises:
            KeyError: If a non-skipped weight has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if (name.startswith("vision_model.post_layernorm")
                    and self.vision_model.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                # "vision_model.encoder.layers.<idx>...." -> <idx>
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load directly, with the parameter's own
                # loader if it defines one.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params