# SPDX-License-Identifier: Apache-2.0
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
4
from typing import Iterable, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
from transformers import CLIPVisionConfig

from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
20

21

22
23
24
25
26
27
28
29
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
    """Static size information about a CLIP vision encoder.

    All values are derived from ``self.vision_config`` (presumably set by
    ``VisionEncoderInfo`` -- confirm against the base class).
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of tokens produced for one image.

        ``image_width`` and ``image_height`` are accepted for interface
        compatibility but are not used: the count depends only on the
        configured patch grid, plus one extra position for the class token.
        """
        return self.get_patch_grid_length()**2 + 1

    def get_max_image_tokens(self) -> int:
        """Return the maximum number of tokens for one image.

        Same value as ``get_num_image_tokens``: grid_length**2 patches plus
        one class token.
        """
        return self.get_patch_grid_length()**2 + 1

    def get_image_size(self) -> int:
        """Return the configured image size."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return the configured patch size."""
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return how many patches fit along one side of the image.

        The image size must be an exact multiple of the patch size.
        """
        image_size = self.get_image_size()
        patch_size = self.get_patch_size()
        assert image_size % patch_size == 0, (
            f"image_size ({image_size}) must be divisible by "
            f"patch_size ({patch_size})")
        return image_size // patch_size
45
46


47
48
49
50
51
52
53
54
55
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
    """Turn pixel values into a sequence of patch embeddings.

    The sequence consists of one learned class embedding followed by one
    embedding per image patch, with a learned position embedding added to
    every element.
    """

    def __init__(self, config: CLIPVisionConfig) -> None:
        """Build the embedding parameters.

        Args:
            config: Supplies ``hidden_size``, ``image_size``, ``patch_size``
                and ``num_channels``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # The patch grid below assumes the image tiles exactly into patches.
        assert self.image_size % self.patch_size == 0

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patches: kernel size and stride are both the
        # patch size, so each output location covers exactly one patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        # One additional position for the class embedding.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        # Non-persistent buffer: follows the module across devices but is
        # excluded from the state dict.
        self.register_buffer("position_ids",
                             torch.arange(self.num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: Image batch; the first dimension is the batch
                size. It is cast to the dtype of the patch-embedding weight.

        Returns:
            The class embedding concatenated with the patch embeddings
            along dim 1, with position embeddings added.
        """
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        # Collapse the spatial grid into a sequence dimension and move it
        # ahead of the channel dimension.
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings


90
class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Query, key and value projections are fused into a single
    ``QKVParallelLinear``; the attention heads are divided across the
    tensor-parallel world size.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projection layers and the attention operator.

        Args:
            config: Supplies ``hidden_size``, ``num_attention_heads`` and
                ``attention_dropout``.
            quant_config: Forwarded to the parallel linear layers.
            prefix: Dotted name of this module, used to build the
                sub-layer prefixes.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        # Stored from the config; not read anywhere else in this class.
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def _shape(self, tensor: torch.Tensor, seq_len: int,
               bsz: int) -> torch.Tensor:
        """Reshape ``tensor`` to (bsz, num_heads, seq_len, head_dim).

        NOTE(review): not called anywhere in this class, and it uses the
        total ``num_heads`` rather than ``num_heads_per_partition`` --
        confirm before relying on it under tensor parallelism.
        """
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel

        Returns:
            A ``(attn_output, None)`` tuple. The second element is always
            ``None``.
        """
        # The parallel linear layers return a tuple; only the first
        # element (the projected tensor) is used here.
        qkv_states, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into three equal parts on the last dim.
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
149
150


151
152
class CLIPMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the two linear layers and look up the activation.

        Args:
            config: Supplies ``hidden_size``, ``intermediate_size`` and
                ``hidden_act``.
            quant_config: Forwarded to both linear layers.
            prefix: Dotted name of this module, used to build the
                sub-layer prefixes.
        """
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # hidden_size -> intermediate_size
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        # intermediate_size -> hidden_size
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2 in sequence."""
        # Each linear layer returns a tuple; only the first element (the
        # output tensor) is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states


class CLIPEncoderLayer(nn.Module):
    """One pre-norm transformer layer: self-attention followed by an MLP.

    Each sub-block is preceded by its own ``LayerNorm`` and wrapped in a
    residual connection.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and normalisation sub-modules.

        Args:
            config: Supplies ``hidden_size`` and ``layer_norm_eps`` and is
                forwarded to the attention and MLP sub-modules.
            quant_config: Forwarded to the attention and MLP sub-modules.
            prefix: Dotted name of this layer, used to build the
                sub-module prefixes.
        """
        super().__init__()
        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the attention and MLP sub-blocks with residual connections."""
        # Attention sub-block: norm -> attention -> add residual.
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        # The second element returned by the attention module is unused.
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        # MLP sub-block: norm -> MLP -> add residual.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPVisionConfig
        quant_config: Forwarded to every layer.
        num_hidden_layers_override: If given, the number of layers to
            build instead of `config.num_hidden_layers`.
        prefix: Dotted name of this module; layer ``i`` is built with the
            prefix ``{prefix}.layers.{i}``.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        # An explicit override takes precedence over the configured depth.
        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            CLIPEncoderLayer(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run the embeddings through every layer in order.

        Args:
            inputs_embeds: Input to the first layer.
            return_all_hidden_states: If true, return the output of every
                layer instead of only the last one.

        Returns:
            The last layer's output, or, when ``return_all_hidden_states``
            is true, a list with one entry per layer in layer order. The
            input embeddings themselves are not included in that list.
        """
        hidden_states_pool = []
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states


class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower: embeddings, pre-norm, encoder, optional post-norm."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        """Build the vision transformer.

        Args:
            config: Supplies ``hidden_size``, ``layer_norm_eps`` and
                ``num_hidden_layers`` and is forwarded to the sub-modules.
            quant_config: Forwarded to the encoder.
            num_hidden_layers_override: If given, the number of encoder
                layers to build instead of ``config.num_hidden_layers``.
            require_post_norm: Whether to create ``post_layernorm``. When
                ``None``, it is created only if the encoder has the full
                configured number of layers.
            prefix: Dotted name of this module, used to build the encoder
                prefix.

        Raises:
            ValueError: If more encoder layers were built than
                ``config.num_hidden_layers``.
        """
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        num_hidden_layers = config.num_hidden_layers
        num_layers_built = len(self.encoder.layers)
        if num_layers_built > num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {num_layers_built} layers."
            )

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            require_post_norm = num_layers_built == num_hidden_layers

        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(embed_dim,
                                               eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

    def forward(
        self,
        pixel_values: torch.Tensor,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            pixel_values: Image batch passed to the embedding module.
            feature_sample_layers: If given, the encoder returns the output
                of every layer and the requested ones are selected by
                ``resolve_visual_encoder_outputs``.

        Returns:
            The result of ``resolve_visual_encoder_outputs``.
        """
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        return_all_hidden_states = feature_sample_layers is not None

        # Produces either the last layer output or all of the hidden states,
        # depending on if we have feature_sample_layers or not
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=return_all_hidden_states)

        # Handle post-norm (if applicable) and stacks feature layers if needed
        encoder_outputs = resolve_visual_encoder_outputs(
            encoder_outputs, feature_sample_layers, self.post_layernorm,
            self.config.num_hidden_layers)

        return encoder_outputs
336
337
338
339
340
341
342


class CLIPVisionModel(nn.Module):
    """Wrapper holding a ``CLIPVisionTransformer`` under ``vision_model``.

    Also provides the checkpoint loader that maps separate q/k/v
    projection weights onto the fused ``qkv_proj`` parameters.
    """

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        """Build the wrapped vision transformer.

        All arguments are forwarded to ``CLIPVisionTransformer``, with
        ``prefix`` extended by ``.vision_model``.
        """
        super().__init__()
        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model")

    def forward(
        self,
        pixel_values: torch.Tensor,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Delegate to the wrapped vision transformer."""
        return self.vision_model(pixel_values, feature_sample_layers)

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into this module's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The names of the parameters that were loaded. For q/k/v
            projection weights this is the fused ``qkv_proj`` name.

        Raises:
            KeyError: If a weight that is not skipped has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if (name.startswith("vision_model.post_layernorm")
                    and self.vision_model.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                # name looks like "vision_model.encoder.layers.<idx>...."
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Redirect the separate q/k/v weight to the fused parameter
                # and let its loader place the shard.
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v projection: load directly, using the
                # parameter's own loader when it has one.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params