clip.py 16.5 KB
Newer Older
1
"""Minimal implementation of CLIPVisionModel intended to be only used
2
within a vision language model."""
3
from typing import Iterable, List, Optional, Tuple, Union
4
5
6

import torch
import torch.nn as nn
7
from PIL import Image
8
from transformers import CLIPVisionConfig
9
from transformers.models.clip.modeling_clip import CLIPSdpaAttention
10

11
from vllm.config import ModelConfig
12
from vllm.distributed import divide, get_tensor_model_parallel_world_size
13
from vllm.inputs import LLMInputs
14
15
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
16
                                               QKVParallelLinear,
17
                                               RowParallelLinear)
18
from vllm.model_executor.layers.quantization import QuantizationConfig
19
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
20
21
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   repeat_and_pad_placeholder_tokens)
22
from vllm.sequence import SequenceData
23

24
25
26
27
28
29
# xformers is optional. USE_XFORMERS_OPS records whether it imported; when it
# did (and heads split evenly across TP ranks), CLIPEncoderLayer uses
# CLIPParallelAttention, otherwise it falls back to HF's CLIPSdpaAttention.
try:
    from xformers import ops as xops
    USE_XFORMERS_OPS = True
except ImportError:
    USE_XFORMERS_OPS = False

30

31
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return the number of patches along one side of the input image.

    Args:
        image_size: Side length of the input image, in pixels.
        patch_size: Side length of each square patch, in pixels.

    Returns:
        ``image_size // patch_size``.

    Raises:
        ValueError: If ``patch_size`` is not positive or does not evenly
            divide ``image_size``.
    """
    # Explicit exception instead of ``assert``: assertions are stripped when
    # Python runs with ``-O``, which would silently let a mismatched config
    # through (and a zero patch size would crash with ZeroDivisionError).
    if patch_size <= 0 or image_size % patch_size != 0:
        raise ValueError(
            f"image_size ({image_size}) must be a multiple of a positive "
            f"patch_size ({patch_size}).")
    return image_size // patch_size


def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total number of patches: the grid side length squared."""
    side = get_clip_patch_grid_length(image_size=image_size,
                                      patch_size=patch_size)
    return side**2


def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
    """Return how many vision tokens CLIP produces for a single image.

    That is one token per patch plus one for the class embedding that
    ``CLIPVisionEmbeddings`` prepends to the patch tokens.
    """
    patch_tokens = get_clip_num_patches(image_size=hf_config.image_size,
                                        patch_size=hf_config.patch_size)
    class_tokens = 1
    return patch_tokens + class_tokens
45
46


47
48
49
50
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
    """Return the maximum number of tokens one image can occupy.

    The per-image feature size depends only on ``hf_config``, so the maximum
    is simply that feature size.
    """
    max_tokens = get_clip_image_feature_size(hf_config)
    return max_tokens


51
52
53
def dummy_seq_data_for_clip(
    hf_config: CLIPVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build dummy ``SequenceData`` of image placeholders plus padding.

    Two token runs are passed to ``SequenceData.from_token_counts``:
    ``image_token_id`` repeated ``feature_size * num_images`` times, then
    token id ``0`` repeated for the rest of ``seq_len``.

    Args:
        hf_config: Vision config used to compute the per-image feature size
            when no override is given.
        seq_len: Requested total sequence length.
        num_images: Number of images the placeholder run represents.
        image_token_id: Token id used for the image placeholder run.
        image_feature_size_override: Per-image feature size to use instead
            of the value derived from ``hf_config``.
    """
    if image_feature_size_override is not None:
        per_image = image_feature_size_override
    else:
        per_image = get_clip_image_feature_size(hf_config)

    image_token_total = per_image * num_images
    # NOTE(review): the padding count goes negative if the image tokens
    # exceed seq_len; how from_token_counts handles that is not visible here.
    padding_total = seq_len - image_token_total
    return SequenceData.from_token_counts(
        (image_token_id, image_token_total),
        (0, padding_total),
    )
68
69


70
def dummy_image_for_clip(
    hf_config: CLIPVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create all-black RGB placeholder image(s).

    Args:
        hf_config: Vision config; ``image_size`` is the default width and
            height.
        num_images: How many images to return.
        image_width_override: Width to use instead of the default.
        image_height_override: Height to use instead of the default.

    Returns:
        ``{"image": img}`` when ``num_images == 1``; otherwise
        ``{"image": [img] * num_images}``. The list repeats one shared
        ``PIL.Image`` object rather than copies.
    """
    default_side = hf_config.image_size
    width = (default_side
             if image_width_override is None else image_width_override)
    height = (default_side
              if image_height_override is None else image_height_override)

    canvas = Image.new("RGB", (width, height), color=0)
    if num_images == 1:
        return {"image": canvas}
    return {"image": [canvas] * num_images}
85
86


87
88
89
90
91
92
def input_processor_for_clip(
    model_config: ModelConfig,
    hf_config: CLIPVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    """Expand image placeholder tokens so the prompt fits CLIP's features.

    Occurrences of ``image_token_id`` are expanded by
    ``repeat_and_pad_placeholder_tokens`` using the feature size computed
    here (or the override) as ``repeat_count``.

    Args:
        model_config: Model config; ``model_config.tokenizer`` is used to
            obtain the tokenizer passed to the placeholder expansion.
        hf_config: CLIP vision config used to compute the feature size of
            raw ``PIL.Image`` inputs.
        llm_inputs: Inputs with ``prompt_token_ids`` and optionally
            ``prompt`` / ``multi_modal_data``.
        image_token_id: Placeholder token id to expand.
        image_feature_size_override: If given, passed as ``repeat_count``
            unchanged: a single int, or one int per image.

    Returns:
        ``llm_inputs`` itself when there is no image data; otherwise a new
        ``LLMInputs`` carrying the expanded prompt and token ids.

    Raises:
        TypeError: If the image data is not a ``PIL.Image``, a non-empty
            list/tuple of ``PIL.Image``, or a ``torch.Tensor``.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is not None:
        image_feature_size = image_feature_size_override
    else:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_clip_image_feature_size(hf_config)
        elif (isinstance(image_data, (list, tuple)) and image_data
              and all(isinstance(img, Image.Image) for img in image_data)):
            # Several raw images (e.g. what dummy_image_for_clip returns for
            # num_images > 1). Every image has the same feature size, so pass
            # one repeat count per image; the List[int] form is already
            # allowed by this function's override type.
            image_feature_size = [get_clip_image_feature_size(hf_config)
                                  ] * len(image_data)
        elif isinstance(image_data, torch.Tensor):
            # Precomputed image features. The unpacking also enforces a 3-D
            # (num_images, feature_size, hidden_size) tensor.
            _, image_feature_size, _ = image_data.shape
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)


126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
    """Turn pixel batches into CLIP token embeddings.

    Patches are extracted by a strided ``Conv2d`` (kernel == stride ==
    ``patch_size``). A learned class embedding is prepended, and a learned
    position embedding is added to every token.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned vector for the class token prepended in forward().
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # kernel_size == stride, so patches do not overlap.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = get_clip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        # One position per patch plus one for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        # Non-persistent buffer: excluded from the module's state_dict.
        position_index = torch.arange(self.num_positions).expand((1, -1))
        self.register_buffer("position_ids", position_index, persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: Image batch, cast to the patch-embedding weight
                dtype before the convolution. Presumably shaped
                ``(batch, num_channels, image_size, image_size)``.

        Returns:
            Class token followed by patch tokens, plus position embeddings.
            Shape is ``(batch, num_patches + 1, embed_dim)`` when the input
            resolution matches ``config.image_size``.
        """
        conv_dtype = self.patch_embedding.weight.dtype
        patch_tokens = self.patch_embedding(pixel_values.to(dtype=conv_dtype))
        # [batch, embed_dim, grid, grid] -> [batch, grid * grid, embed_dim]
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

        batch = pixel_values.shape[0]
        cls_tokens = self.class_embedding.expand(batch, 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)


169
class CLIPParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Tensor-parallel variant. Q/K/V come from a fused ``QKVParallelLinear``,
    the output goes through ``RowParallelLinear``, and attention itself runs
    on xformers' memory-efficient kernel. Each rank handles
    ``num_heads_per_partition`` heads.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        # Attention-probability dropout rate; only applied in training mode
        # (see forward).
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # NOTE(review): not called anywhere in this file's visible code, and it
        # reshapes with the total ``num_heads`` rather than
        # ``num_heads_per_partition``. Kept for interface compatibility.
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel

        Returns:
            ``(attn_output, None)``. The second slot mirrors the
            ``(output, attn_weights)`` tuple of the HF attention modules;
            weights are never materialized here.
        """
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        # xformers takes [batch, seq_len, heads, head_dim] ("BMHK") inputs.
        query_states = query_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(bsz, tgt_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        # The xformers kernel applies dropout whenever p > 0, regardless of
        # module mode. Gate it on self.training, as nn.Dropout and HF's
        # CLIPSdpaAttention do, so inference is deterministic.
        dropout_p = self.dropout if self.training else 0.0
        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=dropout_p,
                                                      scale=self.scale)
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
239
240


241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
class CLIPMLP(nn.Module):
    """Feed-forward sub-block of a CLIP encoder layer.

    The layer order is ``fc1`` (column-parallel), then the configured
    activation, then ``fc2`` (row-parallel). Both linears have biases.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        hidden, inner = config.hidden_size, config.intermediate_size
        self.fc1 = ColumnParallelLinear(hidden,
                                        inner,
                                        bias=True,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(inner,
                                     hidden,
                                     bias=True,
                                     quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up with fc1, apply the activation, project back with fc2."""
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class CLIPEncoderLayer(nn.Module):
    """One CLIP transformer block with pre-LayerNorm residual sub-blocks.

    The layer computes ``x + attn(ln1(x))`` and then ``x + mlp(ln2(x))``.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        # Use the tensor-parallel xformers attention only when xformers is
        # importable and the heads split evenly across TP ranks. Otherwise
        # fall back to HF's CLIPSdpaAttention, built without quant_config.
        tp_size = get_tensor_model_parallel_world_size()
        use_parallel_attn = (USE_XFORMERS_OPS
                             and config.num_attention_heads % tp_size == 0)
        if not use_parallel_attn:
            self.self_attn = CLIPSdpaAttention(config)
        else:
            self.self_attn = CLIPParallelAttention(config,
                                                   quant_config=quant_config)

        eps = config.layer_norm_eps
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=eps)
        self.mlp = CLIPMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks, each with a residual."""
        attn_out, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states))
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPVisionConfig
        quant_config: Optional quantization config forwarded to every layer.
        num_hidden_layers_override: If not None, build this many layers
            instead of `config.num_hidden_layers`.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config

        depth = (config.num_hidden_layers
                 if num_hidden_layers_override is None else
                 num_hidden_layers_override)
        blocks = [
            CLIPEncoderLayer(config=config, quant_config=quant_config)
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, inputs_embeds: torch.Tensor):
        """Run the embeddings through every layer in order."""
        states = inputs_embeds
        for block in self.layers:
            states = block(states)
        return states


class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower.

    The pipeline is embeddings, then pre-LayerNorm, then the encoder, then
    post-LayerNorm. The post-LayerNorm is built only when the full encoder
    depth is used.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

        built_layers = len(self.encoder.layers)
        if built_layers > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {config.num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # post_layernorm only matters for the final layer's output. When
        # intermediate features are extracted (fewer layers), it is not
        # built, to save memory.
        uses_full_depth = built_layers == config.num_hidden_layers
        self.post_layernorm = (nn.LayerNorm(embed_dim,
                                            eps=config.layer_norm_eps)
                               if uses_full_depth else None)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode pixels; apply post_layernorm only if it was built."""
        embedded = self.pre_layrnorm(self.embeddings(pixel_values))
        encoded = self.encoder(inputs_embeds=embedded)

        if self.post_layernorm is not None:
            return self.post_layernorm(encoded)
        return encoded
382
383
384
385
386
387
388
389
390


class CLIPVisionModel(nn.Module):
    """Wrapper around ``CLIPVisionTransformer`` with checkpoint loading.

    The tower is stored as ``vision_model``, so checkpoint tensors named
    ``vision_model.*`` resolve directly against ``named_parameters()``.
    """

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        # Same condition CLIPEncoderLayer uses to pick CLIPParallelAttention.
        # Only that path has a fused qkv_proj, so only then must q/k/v
        # checkpoint tensors be stacked during load_weights.
        world_size = get_tensor_model_parallel_world_size()
        self.shard_weight = (USE_XFORMERS_OPS
                             and config.num_attention_heads % world_size == 0)

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the wrapped vision tower on a pixel batch."""
        return self.vision_model(pixel_values)

    @property
    def device(self):
        """Device of the first registered parameter."""
        first_param = next(self.parameters())
        return first_param.device

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy ``(name, tensor)`` checkpoint entries into parameters.

        Entries are skipped in two cases:
        - ``vision_model.post_layernorm*`` when that layer was not built;
        - encoder layers whose index is >= the number of built layers.

        When ``shard_weight`` is set, ``q_proj`` / ``k_proj`` / ``v_proj``
        tensors are routed into the fused ``qkv_proj`` parameter with shard
        id ``"q"`` / ``"k"`` / ``"v"``. Everything else uses the parameter's
        ``weight_loader`` or ``default_weight_loader``.
        """
        # (fused_param_name, checkpoint_name, shard_id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ] if self.shard_weight else []

        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)
        post_layernorm_missing = self.vision_model.post_layernorm is None

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if (post_layernorm_missing
                    and name.startswith("vision_model.post_layernorm")):
                continue

            # omit layers when num_hidden_layers_override is set
            # (name layout: vision_model.encoder.layers.<idx>.…)
            if (name.startswith("vision_model.encoder.layers")
                    and int(name.split(".")[3]) >= layer_count):
                continue

            # First stacked mapping whose checkpoint name occurs in `name`.
            match = next(
                (entry
                 for entry in stacked_params_mapping if entry[1] in name),
                None)
            if match is None:
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, loaded_weight)
            else:
                fused_name, ckpt_name, shard_id = match
                param = params_dict[name.replace(ckpt_name, fused_name)]
                param.weight_loader(param, loaded_weight, shard_id)