"vllm/vscode:/vscode.git/clone" did not exist on "475dcaa02ebd5485e18aa799c7e787537f3001c9"
clip.py 14.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Minimal implementation of CLIPVisionModel intended to be only used
4
within a vision language model."""
5
6
from collections.abc import Iterable
from typing import Optional, Union
7
8
9
10
11

import torch
import torch.nn as nn
from transformers import CLIPVisionConfig

12
from vllm.attention.layer import MultiHeadAttention
13
from vllm.distributed import divide, get_tensor_model_parallel_world_size
14
15
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
16
                                               QKVParallelLinear,
17
                                               RowParallelLinear)
18
from vllm.model_executor.layers.quantization import QuantizationConfig
19
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
20
from vllm.model_executor.models.interfaces import SupportsQuant
21

22
23
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
                     resolve_visual_encoder_outputs)
24

25

26
27
28
29
30
31
32
33
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
    """Image-token accounting for a CLIP vision tower.

    The vision config describes a single ``image_size`` (the patch grid is
    squared), so the number of tokens per image does not depend on the
    requested width/height.
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of encoder tokens produced for one image.

        ``image_width`` and ``image_height`` are accepted for interface
        compatibility but are not used: the count is one token per patch of
        the configured grid plus a single class token.
        """
        return self.get_patch_grid_length()**2 + 1

    def get_image_size(self) -> int:
        """Return the configured input image side length."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return the configured patch side length."""
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return the number of patches along one side of the image.

        Raises:
            ValueError: if ``image_size`` is not a multiple of ``patch_size``.
        """
        image_size, patch_size = self.get_image_size(), self.get_patch_size()
        # Explicit check instead of ``assert`` so it is not stripped by -O.
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by "
                f"patch_size ({patch_size}).")
        return image_size // patch_size
46
47


48
49
50
51
52
53
54
55
56
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
    """Patch, class-token and learned position embeddings for CLIP vision.

    Converts a batch of images into a token sequence of length
    ``(image_size // patch_size) ** 2 + 1``: one class token followed by
    one token per patch, each of width ``hidden_size``.

    Raises:
        ValueError: if ``config.image_size`` is not a multiple of
            ``config.patch_size``.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Raise rather than assert so the check survives ``python -O``.
        if self.image_size % self.patch_size != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be divisible by "
                f"patch_size ({self.patch_size}).")

        # Random init; presumably overwritten when checkpoint weights load.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # kernel_size == stride == patch_size -> non-overlapping patches.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        self.register_buffer("position_ids",
                             torch.arange(self.num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: expected (batch, num_channels, image_size,
                image_size). Other spatial sizes give a patch count that
                does not match the fixed position table.

        Returns:
            Tensor of shape (batch, num_patches + 1, hidden_size).
        """
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        # (batch, width, grid, grid) -> (batch, grid * grid, width)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings


91
class CLIPAttention(nn.Module):
    """Multi-head self-attention used inside each CLIP encoder layer.

    Q, K and V come from one fused ``QKVParallelLinear`` projection. The
    attention itself is delegated to ``MultiHeadAttention`` over this
    rank's share of the heads, and ``RowParallelLinear`` projects the
    result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        heads_span_embed = self.head_dim * self.num_heads == self.embed_dim
        if not heads_span_embed:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        # Softmax temperature 1 / sqrt(head_dim).
        self.scale = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Heads are split evenly across tensor-parallel ranks.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel.

        Returns a pair ``(output, None)``; the second slot is always None
        (no attention weights are produced).
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = fused_qkv.chunk(3, dim=-1)
        context = self.attn(q, k, v)
        projected, _ = self.out_proj(context)
        return projected, None
145
146


147
148
class CLIPMLP(nn.Module):
    """Feed-forward sublayer: fc1 -> activation -> fc2.

    fc1 is a ``ColumnParallelLinear`` (hidden -> intermediate) and fc2 a
    ``RowParallelLinear`` (intermediate -> hidden), both with bias.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP; the parallel layers' second outputs are ignored."""
        expanded, _ = self.fc1(hidden_states)
        contracted, _ = self.fc2(self.activation_fn(expanded))
        return contracted


class CLIPEncoderLayer(nn.Module):
    """Pre-norm transformer block.

    Computes ``x = x + Attn(LN1(x))`` followed by ``x = x + MLP(LN2(x))``.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        norm_eps = config.layer_norm_eps

        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.mlp = CLIPMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run attention and MLP sublayers, each with a residual add."""
        attn_out, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states))
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers (or `num_hidden_layers_override` of them when given).
    Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPVisionConfig
        quant_config: optional quantization config passed to every layer.
        num_hidden_layers_override: if not None, build this many layers
            instead of `config.num_hidden_layers`.
        prefix: parameter-name prefix; layers live under `{prefix}.layers.i`.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        num_hidden_layers = (config.num_hidden_layers
                             if num_hidden_layers_override is None else
                             num_hidden_layers_override)
        self.layers = nn.ModuleList([
            CLIPEncoderLayer(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run every layer in order.

        Returns:
            If `return_all_hidden_states` is True, a list of
            `len(self.layers) + 1` tensors: the input embeddings followed by
            each layer's output, in order. Otherwise only the last layer's
            output tensor.
        """
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states


class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower: embeddings -> pre-LayerNorm -> encoder.

    An optional post-LayerNorm is applied to the encoder output, and
    feature-layer selection is handled by ``resolve_visual_encoder_outputs``.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(hidden_size, eps=norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        total_layers = config.num_hidden_layers
        built_layers = len(self.encoder.layers)
        if built_layers > total_layers:
            raise ValueError(
                f"The original encoder only has {total_layers} "
                f"layers, but you requested {built_layers} layers.")

        # Unless told otherwise, keep post_layernorm only for the full-depth
        # encoder; skipping it for a truncated encoder saves memory.
        if require_post_norm is None:
            require_post_norm = built_layers == total_layers

        self.post_layernorm = (nn.LayerNorm(hidden_size, eps=norm_eps)
                               if require_post_norm else None)

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        select_layers: Optional[list[int]] = None,
        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` and return the resolved features."""
        embedded = self.pre_layrnorm(self.embeddings(pixel_values))

        # All intermediate hidden states are only needed when the caller
        # asks for specific layers; otherwise just the last output.
        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=select_layers is not None,
        )

        # Applies post-norm (when present) and picks/stacks feature layers.
        return resolve_visual_encoder_outputs(
            encoder_outputs,
            self.post_layernorm,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )
337
338


339
class CLIPVisionModel(nn.Module, SupportsQuant):
    """Top-level CLIP vision model; all weights live under ``vision_model.``.

    Wraps ``CLIPVisionTransformer`` and knows how to load checkpoint
    weights, fusing separate q/k/v projections into ``qkv_proj``.
    """

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        select_layers: Optional[list[int]] = None,
        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
        """Delegate to the wrapped vision transformer."""
        return self.vision_model(
            pixel_values,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        Skips post_layernorm weights when that module was not built, and
        skips encoder layers beyond the number actually built. Returns the
        set of (possibly renamed) parameter names that were loaded.
        """
        # checkpoint shard name -> (fused parameter name, shard id);
        # checked in this order, first match wins.
        fused_shards = {
            "q_proj": ("qkv_proj", "q"),
            "k_proj": ("qkv_proj", "k"),
            "v_proj": ("qkv_proj", "v"),
        }
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        built_layers = len(self.vision_model.encoder.layers)
        post_norm_missing = self.vision_model.post_layernorm is None

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if (post_norm_missing
                    and name.startswith("vision_model.post_layernorm")):
                continue

            # omit layers when num_hidden_layers_override is set
            if (name.startswith("vision_model.encoder.layers")
                    and int(name.split(".")[3]) >= built_layers):
                continue

            target = name
            shard_id = None
            for shard_name, (fused_name, fused_shard) in fused_shards.items():
                if shard_name in name:
                    target = name.replace(shard_name, fused_name)
                    shard_id = fused_shard
                    break

            param = params_dict[target]
            if shard_id is None:
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(target)
        return loaded_params