clip.py 17.2 KB
Newer Older
1
"""Minimal implementation of CLIPVisionModel intended only for use
within a vision-language model."""
3
from typing import Iterable, List, Optional, Tuple, Union
4

5
import numpy as np
6
7
import torch
import torch.nn as nn
8
from PIL import Image
9
from transformers import CLIPVisionConfig
10
from transformers.models.clip.modeling_clip import CLIPSdpaAttention
11

12
from vllm.config import ModelConfig
13
from vllm.distributed import divide, get_tensor_model_parallel_world_size
14
from vllm.inputs import LLMInputs
15
16
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
17
                                               QKVParallelLinear,
18
                                               RowParallelLinear)
19
from vllm.model_executor.layers.quantization import QuantizationConfig
20
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
21
22
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   repeat_and_pad_placeholder_tokens)
23
from vllm.sequence import SequenceData
zhuwenwen's avatar
zhuwenwen committed
24
import vllm.envs as envs
25

26
# xformers kernels back CLIPParallelAttention. They are only imported when
# VLLM_ATTENTION_BACKEND selects the XFormers backend and the package is
# importable; otherwise USE_XFORMERS_OPS is False and CLIPEncoderLayer uses
# HF's CLIPSdpaAttention instead.
# NOTE(review): the comparison is case-sensitive against "XFormers"; confirm
# this is the backend name this build expects (upstream vLLM backend names
# are upper-case, e.g. "XFORMERS").
try:
    USE_XFORMERS_OPS = envs.VLLM_ATTENTION_BACKEND == "XFormers"
    if USE_XFORMERS_OPS:
        from xformers import ops as xops
except ImportError:
    USE_XFORMERS_OPS = False

35

36
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return how many patches fit along one side of a square image.

    ``image_size`` must be an exact multiple of ``patch_size``; this is
    checked with an ``assert``.
    """
    grid, leftover = divmod(image_size, patch_size)
    assert leftover == 0
    return grid


def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total number of patches covering a square image.

    This is the per-side patch count squared.
    """
    side = get_clip_patch_grid_length(image_size=image_size,
                                      patch_size=patch_size)
    return side**2


def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
    """Return the number of feature tokens CLIP produces for one image.

    One token per patch, plus one for the class embedding that
    ``CLIPVisionEmbeddings`` prepends to the patch sequence.
    """
    patch_count = get_clip_num_patches(image_size=hf_config.image_size,
                                       patch_size=hf_config.patch_size)
    return 1 + patch_count
50
51


52
53
54
55
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
    """Return the maximum number of tokens a single image can occupy.

    The count depends only on ``hf_config``, so the maximum equals the
    per-image feature size.
    """
    max_tokens = get_clip_image_feature_size(hf_config)
    return max_tokens


56
57
58
def dummy_seq_data_for_clip(
    hf_config: CLIPVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build placeholder sequence data for profiling CLIP-based models.

    Two ``(token_id, count)`` pairs are handed to
    ``SequenceData.from_token_counts``:
    ``(image_token_id, feature_size * num_images)`` followed by
    ``(0, seq_len - feature_size * num_images)``. ``feature_size`` is
    ``image_feature_size_override`` when given, otherwise derived from
    ``hf_config``.
    """
    per_image = (get_clip_image_feature_size(hf_config)
                 if image_feature_size_override is None else
                 image_feature_size_override)
    image_token_total = per_image * num_images
    return SequenceData.from_token_counts(
        (image_token_id, image_token_total),
        (0, seq_len - image_token_total),
    )
73
74


75
def dummy_image_for_clip(
    hf_config: CLIPVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create blank (all-zero) RGB image data for profiling runs.

    Width and height default to ``hf_config.image_size`` and can each be
    overridden. Returns ``{"image": img}`` when ``num_images == 1``;
    otherwise ``{"image": [img] * num_images}``, a list holding the same
    image object repeatedly.
    """
    default_side = hf_config.image_size
    width = (default_side
             if image_width_override is None else image_width_override)
    height = (default_side
              if image_height_override is None else image_height_override)

    blank = Image.new("RGB", (width, height), color=0)
    if num_images == 1:
        return {"image": blank}
    return {"image": [blank] * num_images}
90
91


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
def dummy_video_for_clip(
    hf_config: CLIPVisionConfig,
    num_frames: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create dummy video data made of ``num_frames`` identical blank frames.

    One frame comes from ``dummy_image_for_clip`` and is converted to a
    NumPy array. ``np.repeat`` then stacks copies of it along a new leading
    axis. Returns ``{"video": frames}``.
    """
    single = dummy_image_for_clip(
        hf_config,
        num_images=1,
        image_width_override=image_width_override,
        image_height_override=image_height_override,
    )
    frame = np.array(single["image"])
    return {"video": np.repeat([frame], num_frames, axis=0)}


110
111
112
113
114
115
def input_processor_for_clip(
    model_config: ModelConfig,
    hf_config: CLIPVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    """Expand image placeholder tokens in a prompt for a CLIP vision tower.

    If ``llm_inputs`` has no ``"image"`` multi-modal data, it is returned
    unchanged. Otherwise the prompt text and token ids are passed to
    ``repeat_and_pad_placeholder_tokens`` with
    ``placeholder_token_id=image_token_id``. ``repeat_count`` is set to:

    * ``image_feature_size_override`` when provided;
    * ``get_clip_image_feature_size(hf_config)`` for a single
      ``PIL.Image.Image``;
    * the second dimension of a 3-D ``torch.Tensor`` of image data.

    Raises:
        TypeError: image data is neither a PIL image nor a tensor (and no
            override was given).
        ValueError: a tensor input is not 3-D (from the shape unpacking).

    Returns:
        A new ``LLMInputs`` built from the expanded prompt and token ids.
    """
    mm_data = llm_inputs.get("multi_modal_data")
    if mm_data is None or "image" not in mm_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    feature_size = image_feature_size_override
    if feature_size is None:
        images = mm_data["image"]
        if isinstance(images, Image.Image):
            feature_size = get_clip_image_feature_size(hf_config)
        elif isinstance(images, torch.Tensor):
            # Unpacking (instead of indexing) also rejects non-3-D tensors.
            _, feature_size, _ = images.shape
        else:
            raise TypeError(f"Invalid image type: {type(images)}")

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=mm_data)


149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
    """Turn pixel values into the CLIP encoder's input token embeddings.

    A strided ``Conv2d`` (kernel == stride == ``patch_size``) embeds
    non-overlapping patches. A learned class embedding is prepended, and
    learned absolute position embeddings are added.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned class embedding placed in front of the patch sequence.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = get_clip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        # One extra position for the class embedding.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        all_positions = torch.arange(self.num_positions).expand((1, -1))
        self.register_buffer("position_ids", all_positions, persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        weight_dtype = self.patch_embedding.weight.dtype
        patches = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        # [batch, embed_dim, grid, grid] -> [batch, grid * grid, embed_dim]
        patches = patches.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(pixel_values.shape[0], 1, -1)
        tokens = torch.cat([cls_tokens, patches], dim=1)
        return tokens + self.position_embedding(self.position_ids)


192
class CLIPParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Tensor-parallel variant for CLIP. The q/k/v projections are fused into
    one ``QKVParallelLinear`` and the output projection is a
    ``RowParallelLinear``. The attention itself runs through xformers'
    ``memory_efficient_attention_forward``. Each rank handles
    ``num_heads // tp_size`` heads.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape a per-rank projection to [bsz, heads, seq_len, head_dim].

        Projections coming out of ``qkv_proj`` hold only this rank's heads,
        so the per-partition head count is used, not the global one.
        """
        return tensor.view(bsz, seq_len, self.num_heads_per_partition,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel

        Returns ``(attn_output, None)``, so callers can unpack it like the
        HF ``CLIPSdpaAttention`` it replaces. No attention weights are
        produced.
        """
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        # Splitting into three equal chunks relies on q, k and v having the
        # same number of heads (only total_num_heads is passed to qkv_proj).
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        # Layout for xformers: [batch, seq_len, heads_on_this_rank, head_dim].
        query_states = query_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(bsz, tgt_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        # xformers applies dropout whenever p > 0. Gate it on training mode
        # so inference is deterministic, as in CLIPSdpaAttention.
        dropout_p = self.dropout if self.training else 0.0
        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=dropout_p,
                                                      scale=self.scale)
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
262
263


264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
class CLIPMLP(nn.Module):
    """Feed-forward block of a CLIP encoder layer: fc1 -> activation -> fc2.

    ``fc1`` is a vLLM ``ColumnParallelLinear`` and ``fc2`` a
    ``RowParallelLinear``. Both have biases and use ``quant_config``.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        hidden, intermediate = config.hidden_size, config.intermediate_size
        self.fc1 = ColumnParallelLinear(hidden,
                                        intermediate,
                                        bias=True,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(intermediate,
                                     hidden,
                                     bias=True,
                                     quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Each parallel linear returns (output, bias); only output is used.
        projected, _ = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        output, _ = self.fc2(activated)
        return output


class CLIPEncoderLayer(nn.Module):
    """Pre-norm transformer block: LN -> attention -> residual, LN -> MLP -> residual.

    Attention is the tensor-parallel ``CLIPParallelAttention`` when xformers
    ops are enabled and the head count divides the TP world size evenly.
    Otherwise it is HF's ``CLIPSdpaAttention``.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        tp_world_size = get_tensor_model_parallel_world_size()
        use_parallel_attn = (USE_XFORMERS_OPS and
                             config.num_attention_heads % tp_world_size == 0)
        if use_parallel_attn:
            self.self_attn = CLIPParallelAttention(config,
                                                   quant_config=quant_config)
        else:
            self.self_attn = CLIPSdpaAttention(config)

        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Both attention implementations return a 2-tuple; the second item
        # (attention weights) is unused.
        attn_out, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states))
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out


class CLIPEncoder(nn.Module):
    """Stack of ``CLIPEncoderLayer`` blocks applied one after another.

    Args:
        config: CLIP vision config. ``config.num_hidden_layers`` is the
            default depth.
        quant_config: Optional quantization config passed to every layer.
        num_hidden_layers_override: When not ``None``, build this many
            layers instead of ``config.num_hidden_layers``.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config

        depth = (config.num_hidden_layers
                 if num_hidden_layers_override is None else
                 num_hidden_layers_override)
        self.layers = nn.ModuleList(
            CLIPEncoderLayer(config=config, quant_config=quant_config)
            for _ in range(depth))

    def forward(self, inputs_embeds: torch.Tensor):
        out = inputs_embeds
        for layer in self.layers:
            out = layer(out)
        return out


class CLIPVisionTransformer(nn.Module):
    """Embeddings -> pre-LayerNorm -> encoder -> optional post-LayerNorm.

    ``post_layernorm`` is only built when the encoder keeps every layer of
    ``config``. With fewer layers (an override), the raw encoder output is
    returned. Requesting more layers than ``config`` has raises
    ``ValueError``.
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(hidden_size,
                                         eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

        built_layers = len(self.encoder.layers)
        full_depth = config.num_hidden_layers
        if built_layers > full_depth:
            raise ValueError(
                f"The original encoder only has {full_depth} "
                f"layers, but you requested {built_layers} layers.")

        # post_layernorm is unused when intermediate features are extracted,
        # so it is skipped in that case to conserve memory.
        self.post_layernorm = (nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_eps)
                               if built_layers == full_depth else None)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        embedded = self.pre_layrnorm(self.embeddings(pixel_values))
        encoded = self.encoder(inputs_embeds=embedded)
        if self.post_layernorm is not None:
            encoded = self.post_layernorm(encoded)
        return encoded
405
406
407
408
409
410
411
412
413


class CLIPVisionModel(nn.Module):
    """CLIP vision tower used inside vLLM vision-language models.

    Wraps ``CLIPVisionTransformer`` and loads HuggingFace-style CLIP vision
    weights into it. When the tensor-parallel attention is in use, separate
    q/k/v projection weights are routed into the fused ``qkv_proj``.
    """

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        tp_size = get_tensor_model_parallel_world_size()
        num_heads = config.num_attention_heads
        # Same condition CLIPEncoderLayer uses to pick CLIPParallelAttention.
        # Only that module has a fused qkv_proj, so only then must q/k/v
        # checkpoint weights be stacked while loading.
        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.vision_model(pixel_values)

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs into this model's parameters.

        The following weights are skipped:

        * ``vision_model.post_layernorm.*`` when that module was not built;
        * ``vision_model.encoder.layers.<idx>.*`` for ``idx`` beyond the
          built depth.

        When ``self.shard_weight`` is set, names containing ``q_proj``,
        ``k_proj`` or ``v_proj`` go to the matching shard of ``qkv_proj``.
        Any other name must match a parameter exactly; otherwise the dict
        lookup raises ``KeyError``.
        """
        if self.shard_weight:
            stacked_params_mapping = [
                # (param_name, shard_name, shard_id)
                ("qkv_proj", "q_proj", "q"),
                ("qkv_proj", "k_proj", "k"),
                ("qkv_proj", "v_proj", "v"),
            ]
        else:
            stacked_params_mapping = []

        params_dict = dict(self.named_parameters())
        kept_layers = len(self.vision_model.encoder.layers)
        drop_post_layernorm = self.vision_model.post_layernorm is None

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if drop_post_layernorm and name.startswith(
                    "vision_model.post_layernorm"):
                continue

            # Omit layers when num_hidden_layers_override is set; the layer
            # index is the 4th dot-separated component of the name.
            if (name.startswith("vision_model.encoder.layers")
                    and int(name.split(".")[3]) >= kept_layers):
                continue

            stacked = next(
                (entry for entry in stacked_params_mapping
                 if entry[1] in name),
                None,
            )
            if stacked is None:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            else:
                param_name, shard_name, shard_id = stacked
                param = params_dict[name.replace(shard_name, param_name)]
                param.weight_loader(param, loaded_weight, shard_id)