"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
3
from typing import Iterable, List, Optional, Tuple, Union
4

5
import numpy as np
6
7
import torch
import torch.nn as nn
8
from PIL import Image
9
from transformers import CLIPVisionConfig
10
from transformers.models.clip.modeling_clip import CLIPSdpaAttention
11

12
from vllm.config import ModelConfig
13
from vllm.distributed import divide, get_tensor_model_parallel_world_size
14
from vllm.inputs import DecoderOnlyInputs, token_inputs
15
16
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
17
                                               QKVParallelLinear,
18
                                               RowParallelLinear)
19
from vllm.model_executor.layers.quantization import QuantizationConfig
20
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
21
22
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   repeat_and_pad_placeholder_tokens)
23
from vllm.sequence import SequenceData
24

25
26
27
28
29
30
try:
    from xformers import ops as xops
except ImportError:
    # xformers is optional; when it is missing, CLIPEncoderLayer falls back
    # to the Hugging Face ``CLIPSdpaAttention`` implementation.
    USE_XFORMERS_OPS = False
else:
    USE_XFORMERS_OPS = True

31

32
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return the number of patches along one side of the square image.

    Args:
        image_size: Side length of the square input image, in pixels.
        patch_size: Side length of one square patch, in pixels.

    Returns:
        ``image_size // patch_size``.

    Raises:
        ValueError: If ``patch_size`` is not positive or does not evenly
            divide ``image_size``.
    """
    # Validate with explicit raises rather than ``assert``: assertions are
    # stripped under ``python -O``, after which the floor division below
    # would silently produce a wrong grid size.
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if image_size % patch_size != 0:
        raise ValueError(f"image_size ({image_size}) must be divisible by "
                         f"patch_size ({patch_size})")
    return image_size // patch_size


def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total number of patches in the square patch grid."""
    side = get_clip_patch_grid_length(image_size=image_size,
                                      patch_size=patch_size)
    return side**2


def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
    """Return how many feature positions one image occupies in CLIP.

    This is the number of patches plus one extra position for the class
    embedding that ``CLIPVisionEmbeddings`` prepends to the patch sequence.
    """
    patch_count = get_clip_num_patches(image_size=hf_config.image_size,
                                       patch_size=hf_config.patch_size)
    return patch_count + 1
46
47


48
49
50
51
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
    """Return the maximum number of placeholder tokens a single image needs.

    The per-image feature size depends only on ``hf_config``, so the
    maximum is simply that feature size.
    """
    max_tokens = get_clip_image_feature_size(hf_config)
    return max_tokens


52
53
54
def dummy_seq_data_for_clip(
    hf_config: CLIPVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build dummy prompt token data sized for ``num_images`` CLIP images.

    Two ``(token_id, count)`` pairs are handed to
    ``SequenceData.from_prompt_token_counts``:
    ``image_token_id`` repeated ``feature_size * num_images`` times, and
    token id ``0`` filling the rest of ``seq_len``.

    Args:
        hf_config: Vision config used to derive the per-image feature size.
        seq_len: Total number of tokens in the dummy sequence.
        num_images: Number of images whose placeholders are included.
        image_token_id: Token id used for image placeholder positions.
        image_feature_size_override: Used instead of the feature size
            derived from ``hf_config`` when not ``None``.
    """
    image_feature_size = (get_clip_image_feature_size(hf_config)
                          if image_feature_size_override is None else
                          image_feature_size_override)
    image_token_count = image_feature_size * num_images

    return SequenceData.from_prompt_token_counts(
        (image_token_id, image_token_count),
        (0, seq_len - image_token_count),
    )
69
70


71
def dummy_image_for_clip(
    hf_config: CLIPVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create placeholder RGB image data (all channels zero) for profiling.

    Args:
        hf_config: Vision config; ``image_size`` is the default width and
            height.
        num_images: Number of images to return.
        image_width_override: Width to use instead of ``image_size``.
        image_height_override: Height to use instead of ``image_size``.

    Returns:
        ``{"image": img}`` when ``num_images == 1``; otherwise
        ``{"image": [img, ...]}`` holding ``num_images`` references to the
        same PIL image object.
    """
    default_side = hf_config.image_size
    width = (default_side
             if image_width_override is None else image_width_override)
    height = (default_side
              if image_height_override is None else image_height_override)

    blank = Image.new("RGB", (width, height), color=0)
    if num_images == 1:
        return {"image": blank}
    return {"image": [blank] * num_images}
86
87


88
89
90
def dummy_video_for_clip(
    hf_config: CLIPVisionConfig,
    num_frames: int,
    num_videos: int = 1,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create placeholder video data made of repeated dummy frames.

    Each video is a NumPy array of ``num_frames`` copies of the frame
    produced by ``dummy_image_for_clip``, stacked along a new leading axis.

    Returns:
        ``{"video": [video, ...]}`` with ``num_videos`` references to the
        same array.
    """
    single = dummy_image_for_clip(
        hf_config,
        num_images=1,
        image_width_override=image_width_override,
        image_height_override=image_height_override,
    )
    frame = np.array(single["image"])

    video = np.repeat([frame], num_frames, axis=0)
    return {"video": [video] * num_videos}


108
109
110
def input_processor_for_clip(
    model_config: ModelConfig,
    hf_config: CLIPVisionConfig,
    inputs: DecoderOnlyInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    """Expand image placeholder tokens in a prompt to CLIP feature length.

    Inputs without image data are returned unchanged. Otherwise the
    ``image_token_id`` placeholders are expanded by
    ``repeat_and_pad_placeholder_tokens`` using the feature size(s)
    determined below.

    Args:
        model_config: Model config; its ``tokenizer`` is loaded via
            ``cached_get_tokenizer``.
        hf_config: Vision config used to compute the per-image feature size.
        inputs: Decoder-only inputs; ``"multi_modal_data"``,
            ``"prompt"`` and ``"prompt_token_ids"`` are read.
        image_token_id: Placeholder token id to expand.
        image_feature_size_override: A single (``int``) or per-image
            (``List[int]``) feature size used instead of inferring it from
            the image data.

    Returns:
        A new ``token_inputs`` object with the expanded prompt/token ids and
        the original multi-modal data.

    Raises:
        TypeError: If no override is given and the image data is not a PIL
            image, a non-empty list of PIL images, or a tensor.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_clip_image_feature_size(hf_config)
        elif (isinstance(image_data, list) and image_data and all(
                isinstance(image, Image.Image) for image in image_data)):
            # Multi-image prompt: the feature size depends only on
            # ``hf_config``, so every image gets the same count. A per-image
            # list mirrors the ``List[int]`` form of the override.
            image_feature_size = [get_clip_image_feature_size(hf_config)
                                  ] * len(image_data)
        elif isinstance(image_data, torch.Tensor):
            # Precomputed embeddings unpacked as
            # (num_images, feature_size, hidden_size); only the feature
            # dimension is needed here.
            _, image_feature_size, _ = image_data.shape
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        inputs.get("prompt"),
        inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=new_prompt,
                        multi_modal_data=multi_modal_data)
145
146


147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
    """Patch, class-token and position embeddings for the CLIP vision tower.

    A strided convolution (``kernel_size == stride == patch_size``) turns
    the image into one embedding per patch. A learned class embedding is
    prepended and a learned position embedding is added to every position.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned vector prepended to the sequence of patch embeddings.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # kernel_size == stride: each output location covers exactly one
        # non-overlapping patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = get_clip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        # One extra position for the class embedding.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        # Non-persistent: excluded from the module's state dict.
        self.register_buffer("position_ids",
                             torch.arange(self.num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        weight_dtype = self.patch_embedding.weight.dtype
        # [batch, channels, H, W] -> [batch, embed_dim, grid, grid]
        patch_grid = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        # [batch, embed_dim, grid, grid] -> [batch, grid * grid, embed_dim]
        patch_tokens = patch_grid.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(pixel_values.shape[0], 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)


190
class CLIPParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Tensor-parallel variant built on vLLM's ``QKVParallelLinear`` and
    ``RowParallelLinear``. Each rank attends over
    ``num_heads_per_partition`` (= ``num_heads / tp_size``) heads using
    xformers' ``memory_efficient_attention_forward``.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        # Attention dropout probability from the config; only applied in
        # training mode (see ``forward``).
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``[bsz, seq_len, local_heads * head_dim]`` into
        ``[bsz, local_heads, seq_len, head_dim]``.

        Uses the per-rank head count (as ``forward`` does). Viewing with the
        global ``num_heads`` would fail whenever ``tp_size > 1``, because a
        rank only holds ``num_heads_per_partition`` heads.
        """
        return tensor.view(bsz, seq_len, self.num_heads_per_partition,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel

        Returns ``(attn_output, None)``. This is the 2-tuple shape that
        ``CLIPEncoderLayer`` unpacks for both this class and the HF fallback
        attention.
        """
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        # xformers takes [batch, seq, heads, head_dim] inputs.
        local_shape = (bsz, tgt_len, self.num_heads_per_partition,
                       self.head_dim)
        query_states = query_states.view(*local_shape)
        key_states = key_states.view(*local_shape)
        value_states = value_states.view(*local_shape)

        # xformers applies dropout whenever p > 0. Only use it while
        # training, matching the HF CLIP attention fallback; inference must
        # be deterministic.
        dropout_p = self.dropout if self.training else 0.0
        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=dropout_p,
                                                      scale=self.scale)
        # reshape (not view) so a non-contiguous kernel output still works.
        out = out.reshape(bsz, tgt_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
263
264


265
266
class CLIPMLP(nn.Module):
    """Feed-forward sub-block of a CLIP encoder layer.

    ``fc1`` (column-parallel) maps hidden_size -> intermediate_size, the
    configured activation is applied, and ``fc2`` (row-parallel) maps back
    to hidden_size.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a 2-tuple; only the first
        # element (the projected tensor) is used.
        expanded, _ = self.fc1(hidden_states)
        projected, _ = self.fc2(self.activation_fn(expanded))
        return projected


class CLIPEncoderLayer(nn.Module):
    """One pre-LayerNorm transformer block: attention, then MLP.

    Each sub-block has a residual connection. Tensor-parallel xformers
    attention is used when xformers is installed and the head count divides
    evenly across TP ranks. Otherwise the layer falls back to Hugging Face's
    ``CLIPSdpaAttention``.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        tp_world_size = get_tensor_model_parallel_world_size()
        use_parallel_attn = (USE_XFORMERS_OPS and
                             config.num_attention_heads % tp_world_size == 0)
        if not use_parallel_attn:
            self.self_attn = CLIPSdpaAttention(config)
        else:
            self.self_attn = CLIPParallelAttention(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )

        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Both attention implementations return a 2-tuple; the second
        # element is ignored.
        attn_out, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states))
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out


class CLIPEncoder(nn.Module):
    """
    Transformer encoder made of a stack of self-attention layers, each a
    [`CLIPEncoderLayer`]. The depth is `config.num_hidden_layers`, or
    `num_hidden_layers_override` when that is given.

    Args:
        config: CLIPVisionConfig
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        depth = (config.num_hidden_layers
                 if num_hidden_layers_override is None else
                 num_hidden_layers_override)
        stack = [
            CLIPEncoderLayer(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{idx}")
            for idx in range(depth)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, inputs_embeds: torch.Tensor):
        out = inputs_embeds
        for block in self.layers:
            out = block(out)
        return out


class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower: embeddings, pre-LayerNorm, encoder, post-LayerNorm.

    The post-LayerNorm is optional. By default it is created only when the
    full encoder depth is built.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        hidden_size = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(hidden_size,
                                         eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        full_depth = config.num_hidden_layers
        built_depth = len(self.encoder.layers)
        if built_depth > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {full_depth} "
                f"layers, but you requested {built_depth} layers.")

        # If possible, skip post_layernorm to conserve memory: by default it
        # is only needed when the output of the last layer is returned.
        if require_post_norm is None:
            require_post_norm = built_depth == full_depth

        self.post_layernorm = (nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_eps)
                               if require_post_norm else None)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        embedded = self.pre_layrnorm(self.embeddings(pixel_values))
        encoded = self.encoder(inputs_embeds=embedded)

        if self.post_layernorm is not None:
            return self.post_layernorm(encoded)
        return encoded
437
438
439
440
441
442
443


class CLIPVisionModel(nn.Module):
    """vLLM wrapper around ``CLIPVisionTransformer`` with checkpoint loading.

    ``load_weights`` fuses separate q/k/v checkpoint tensors into
    ``qkv_proj`` when the tensor-parallel attention is in use. It skips
    weights for encoder layers that were not built, and skips the
    post-LayerNorm when that module was not created.
    """

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Same condition CLIPEncoderLayer uses to pick the parallel
        # attention; it decides whether q/k/v weights must be fused.
        tp_world_size = get_tensor_model_parallel_world_size()
        self.shard_weight = (USE_XFORMERS_OPS and
                             config.num_attention_heads % tp_world_size == 0)

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.vision_model(pixel_values)

    @property
    def device(self):
        return next(self.parameters()).device

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # (fused param name, checkpoint name, shard id) for q/k/v; empty
        # when the unfused HF attention is used.
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ] if self.shard_weight else []
        params_dict = dict(self.named_parameters())
        built_layers = len(self.vision_model.encoder.layers)
        no_post_norm = self.vision_model.post_layernorm is None

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if no_post_norm and name.startswith("vision_model.post_layernorm"):
                continue

            # omit layers when num_hidden_layers_override is set; the layer
            # index is the 4th dot-separated component of the name.
            if (name.startswith("vision_model.encoder.layers")
                    and int(name.split(".")[3]) >= built_layers):
                continue

            stacked = next(
                (entry for entry in stacked_params_mapping
                 if entry[1] in name),
                None,
            )
            if stacked is not None:
                fused_name, ckpt_name, shard_id = stacked
                param = params_dict[name.replace(ckpt_name, fused_name)]
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, loaded_weight)