clip.py 32.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Annotated, Literal, Optional, Union
6
7
8

import torch
import torch.nn as nn
9
10
from transformers import (BatchFeature, CLIPConfig, CLIPProcessor,
                          CLIPTextConfig, CLIPVisionConfig)
11

12
from vllm.attention import Attention
13
from vllm.attention.layer import MultiHeadAttention
14
15
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
16
from vllm.distributed import divide, get_tensor_model_parallel_world_size
17
18
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
19
                                               QKVParallelLinear,
20
                                               RowParallelLinear)
21
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
22
from vllm.model_executor.layers.quantization import QuantizationConfig
23
24
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
25
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
26
from vllm.model_executor.models.interfaces import SupportsQuant
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputs, MultiModalKwargsItems,
                                    MultiModalUUIDDict)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptIndexTargets,
                                        PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, maybe_prefix
43
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
44
45
                     VisionFeatureSelectStrategyStr,
                     get_num_selected_vision_tokens,
46
                     resolve_visual_encoder_outputs)
47

48

49
50
51
52
53
54
55
56
57
58
59
60
class CLIPImagePixelInputs(TensorSchema):
    """
    Schema for CLIP image inputs supplied as raw pixel values.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """
    # Literal tag naming this input variant (presumably used to tell
    # input kinds apart — confirm against the consumer of this schema).
    type: Literal["pixel_values"]
    # Pixel tensor; the schema pins the channel dimension to exactly 3.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


61
62
63
64
65
66
67
68
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
    """Geometry helpers for the CLIP vision encoder.

    Every quantity is derived from ``self.vision_config``; the per-image
    width/height passed to :meth:`get_num_image_tokens` are not consulted.
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the encoder's sequence length for a single image.

        That is one token per cell of the square patch grid, plus one extra
        token for the class embedding.
        """
        side = self.get_patch_grid_length()
        return side * side + 1

    def get_image_size(self) -> int:
        """Return ``vision_config.image_size`` (input resolution, pixels)."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return ``vision_config.patch_size`` (patch side, pixels)."""
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return how many patches fit along one side of the image.

        The image size is required to be an exact multiple of the patch size.
        """
        grid_length, leftover = divmod(self.get_image_size(),
                                       self.get_patch_size())
        assert leftover == 0
        return grid_length
81
82


83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# Maps the pooler's ``pooling_type`` to the vision feature-selection strategy
# that is passed to ``get_num_selected_vision_tokens`` when counting image
# tokens. Pooling types missing from this table are rejected with a
# ValueError by ``_get_vision_feature_select_strategy``.
_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = {
    "MEAN": "full",
    "ALL": "full",
    "CLS": "class",
    # This lets us use the same pooling type for both text and image
    "LAST": "class",
}


def _get_vision_feature_select_strategy(pooling_type: str):
    """Translate a pooler ``pooling_type`` into a vision feature strategy.

    Raises:
        ValueError: if ``pooling_type`` has no entry in
            ``_POOLING_TYPE_TO_STRATEGY``.
    """
    strategy = _POOLING_TYPE_TO_STRATEGY.get(pooling_type)
    if strategy is not None:
        return strategy

    raise ValueError(f"No feature selection strategy is defined for "
                     f"pooling_type: {pooling_type!r}") from None


class CLIPProcessingInfo(BaseProcessingInfo):
    """Processing metadata for CLIP: config/processor access and token counts."""

    def get_hf_config(self):
        """Return the model's HF config, resolved as ``CLIPConfig``."""
        return self.ctx.get_hf_config(CLIPConfig)

    def get_vision_encoder_info(self):
        """Build a ``CLIPEncoderInfo`` helper over the HF config."""
        hf_config = self.get_hf_config()
        return CLIPEncoderInfo(hf_config)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``CLIPProcessor``; ``kwargs`` are forwarded as-is."""
        return self.ctx.get_hf_processor(CLIPProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Allow at most one image per prompt."""
        return dict(image=1)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many placeholder tokens one image expands to.

        Starts from the vision encoder's raw token count and keeps only the
        tokens selected by the strategy derived from the pooler's
        ``pooling_type``. Requires a pooler config to be present.
        """
        encoder_info = self.get_vision_encoder_info()

        pooler_config = self.ctx.model_config.pooler_config
        assert pooler_config is not None

        raw_token_count = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = _get_vision_feature_select_strategy(
            pooler_config.pooling_type)
        return get_num_selected_vision_tokens(raw_token_count, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the square input size configured for the vision encoder."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for the largest-feature image size."""
        largest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=largest.width,
            image_height=largest.height,
        )


class CLIPDummyInputsBuilder(BaseDummyInputsBuilder[CLIPProcessingInfo]):
    """Builds image-only dummy inputs (for vLLM's multimodal profiling)."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty prompt; CLIP rejects text combined with images."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images at the maximal size.

        ``seq_len`` is unused. ``mm_options["image"]``, when present, is
        forwarded to the image generator as overrides.
        """
        num_images = mm_counts.get("image", 0)
        size = self.info.get_image_size_with_most_features()

        if mm_options:
            image_overrides = mm_options.get("image")
        else:
            image_overrides = None

        dummy_images = self._get_dummy_images(
            width=size.width,
            height=size.height,
            num_images=num_images,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
    """Multimodal processor for CLIP embedding requests.

    A request is either text-only or image-only. For an image, a run of
    dummy placeholder tokens (one per selected vision token) is placed at the
    start of the otherwise empty prompt.
    """

    @cached_property
    def image_token_id(self) -> int:
        """Token id used to fill image placeholder positions.

        Id 0 is borrowed as a dummy; it must not be one of the tokenizer's
        special token ids.
        """
        placeholder_id = 0
        tokenizer = self.info.get_tokenizer()
        assert placeholder_id not in tokenizer.all_special_ids
        return placeholder_id

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Validate the text/image mix, then defer to the base processor.

        Raises:
            ValueError: if both a non-empty prompt and multimodal data are
                supplied.
        """
        if mm_data:
            if prompt:
                raise ValueError(
                    "CLIP accepts text-only or image-only inputs, not both! "
                    "Image-only inputs means passing an image with an empty "
                    "text prompt.")
            # Without special tokens, the processed prompt consists solely of
            # the dummy image tokens.
            tokenization_kwargs = dict(tokenization_kwargs or {},
                                       add_special_tokens=False)

        return super().apply(
            prompt=prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Report that the HF processor never inserts image placeholders.

        The updates from ``_get_prompt_updates`` are presumably applied by
        vLLM instead.
        """
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as batched per image item."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert image placeholder tokens at the start of the prompt."""
        placeholder_id = self.image_token_id

        def replacement_for(item_idx: int):
            # Looked up lazily, only when a replacement is requested.
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [placeholder_id] * token_count

        return [
            PromptReplacement(
                modality="image",
                target=PromptIndexTargets.start(),
                replacement=replacement_for,
            ),
        ]


# Adapted from: https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/models/clip/modeling_clip.py
class CLIPTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()

        embed_dim = config.hidden_size

        # Both tables are vocab-parallel embeddings; the position table is
        # simply indexed by position id instead of token id.
        self.token_embedding = VocabParallelEmbedding(config.vocab_size,
                                                      embed_dim)
        self.position_embedding = VocabParallelEmbedding(
            config.max_position_embeddings, embed_dim)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return token embeddings plus position embeddings.

        Args:
            input_ids: Token ids; only looked up when ``inputs_embeds`` is
                None.
            position_ids: Position index for every token.
            inputs_embeds: Precomputed token embeddings. When given,
                ``input_ids`` is ignored.

        Raises:
            ValueError: if both ``input_ids`` and ``inputs_embeds`` are None.
        """
        if inputs_embeds is None:
            if input_ids is None:
                # Name the actual keyword argument (`inputs_embeds`).
                raise ValueError(
                    "Either `input_ids` or `inputs_embeds` must be provided")

            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings


291
292
293
294
295
296
297
298
class CLIPVisionEmbeddings(nn.Module):
    """Class token + patch embeddings + learned position embeddings."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # The image must tile exactly into patches.
        assert self.image_size % self.patch_size == 0

        # Learned vector prepended to the patch sequence (the class token).
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patchify: kernel size == stride == patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        # One position per patch plus one for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        self.register_buffer("position_ids",
                             torch.arange(self.num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` as a [class token, patches...] sequence.

        Pixels are cast to the conv weight's dtype before patchifying, and
        learned position embeddings are added to every position.
        """
        conv_dtype = self.patch_embedding.weight.dtype
        patch_grid = self.patch_embedding(pixel_values.to(dtype=conv_dtype))
        # [batch, embed_dim, gh, gw] -> [batch, gh * gw, embed_dim]
        patch_tokens = patch_grid.flatten(2).transpose(1, 2)

        num_images = pixel_values.shape[0]
        cls_tokens = self.class_embedding.expand(num_images, 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)


333
class CLIPAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    ``attn_cls`` selects the attention implementation: the text tower in
    this file passes ``Attention``, the vision tower ``MultiHeadAttention``.
    """

    def __init__(
        self,
        config: Union[CLIPTextConfig, CLIPVisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
        attn_cls: Union[type[Attention], type[MultiHeadAttention]],
    ) -> None:
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.embed_dim != self.head_dim * self.num_heads:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5

        # One matmul produces Q, K and V together.
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # Heads are split evenly across tensor-parallel ranks.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = attn_cls(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Self-attend over ``hidden_states`` (Batch x Time x Channel).

        Returns an ``(output, None)`` pair; the second slot is always None.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = fused_qkv.chunk(3, dim=-1)
        context = self.attn(q, k, v)
        projected, _ = self.out_proj(context)
        return projected, None
393
394


395
396
class CLIPMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config: Union[CLIPTextConfig, CLIPVisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # Column-parallel up-projection then row-parallel down-projection
        # (presumably keeping the intermediate activation sharded across
        # tensor-parallel ranks in between).
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation, then fc2.

        The second value returned by each linear layer is ignored.
        """
        projected_up, _ = self.fc1(hidden_states)
        activated = self.activation_fn(projected_up)
        projected_down, _ = self.fc2(activated)
        return projected_down


class CLIPEncoderLayer(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(
        self,
        config: Union[CLIPTextConfig, CLIPVisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
        attn_cls: Union[type[Attention], type[MultiHeadAttention]],
    ) -> None:
        super().__init__()
        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_cls=attn_cls,
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute ``h = x + attn(ln1(x))`` and return ``h + mlp(ln2(h))``."""
        normed = self.layer_norm1(hidden_states)
        attn_out, _ = self.self_attn(hidden_states=normed)
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out


class CLIPEncoder(nn.Module):
    """
    Stack of [`CLIPEncoderLayer`] self-attention blocks.

    The depth is ``config.num_hidden_layers`` unless
    ``num_hidden_layers_override`` is given.

    Args:
        config: CLIPTextConfig or CLIPVisionConfig of the tower being built.
    """

    def __init__(
        self,
        config: Union[CLIPTextConfig, CLIPVisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        *,
        prefix: str = "",
        attn_cls: Union[type[Attention], type[MultiHeadAttention]],
    ) -> None:
        super().__init__()

        self.config = config

        depth = (config.num_hidden_layers
                 if num_hidden_layers_override is None else
                 num_hidden_layers_override)

        blocks = []
        for layer_idx in range(depth):
            blocks.append(
                CLIPEncoderLayer(config=config,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.layers.{layer_idx}",
                                 attn_cls=attn_cls))
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run every encoder layer over ``inputs_embeds``.

        Returns:
            The last layer's hidden states; or, when
            ``return_all_hidden_states`` is set, a list of the input
            embeddings followed by each layer's output
            (``len(self.layers) + 1`` entries) so callers can select feature
            layers by index.
        """
        if not return_all_hidden_states:
            hidden_states = inputs_embeds
            for encoder_layer in self.layers:
                hidden_states = encoder_layer(hidden_states)
            return hidden_states

        all_hidden_states = [inputs_embeds]
        for encoder_layer in self.layers:
            all_hidden_states.append(encoder_layer(all_hidden_states[-1]))
        return all_hidden_states


519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
class CLIPTextTransformer(nn.Module):
    """CLIP text tower: embeddings -> encoder -> final LayerNorm."""

    def __init__(
        self,
        config: CLIPTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPTextEmbeddings(config)

        # The text tower uses the `Attention` layer class.
        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_cls=Attention,
        )

        self.final_layer_norm = nn.LayerNorm(embed_dim,
                                             eps=config.layer_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings only (no position embeddings added)."""
        return self.embeddings.token_embedding(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return final-layer-normed hidden states for every position."""
        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        encoded = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=False,
        )
        return self.final_layer_norm(encoded)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights`` into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are routed into
        the matching shard of the fused ``qkv_proj`` parameter; every other
        tensor is loaded by its own name.

        Returns:
            The set of parameter names that received weights.
        """
        # (fused param name, checkpoint shard name, shard id)
        qkv_shards = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        params_by_name = dict(self.named_parameters())
        loaded: set[str] = set()

        for name, tensor in weights:
            # First shard whose checkpoint name occurs in `name`, if any.
            shard = next(
                (entry for entry in qkv_shards if entry[1] in name), None)
            if shard is None:
                param = params_by_name[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, tensor)
            else:
                fused_name, shard_name, shard_id = shard
                name = name.replace(shard_name, fused_name)
                param = params_by_name[name]
                param.weight_loader(param, tensor, shard_id)
            loaded.add(name)
        return loaded


600
601
class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower: embeddings -> pre-norm -> encoder -> post-norm.

    The encoder may be truncated via ``num_hidden_layers_override``; the
    post-encoder LayerNorm is then skipped by default.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # The vision tower uses the `MultiHeadAttention` layer class.
        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MultiHeadAttention,
        )

        full_depth = config.num_hidden_layers
        built_depth = len(self.encoder.layers)
        if built_depth > full_depth:
            raise ValueError(
                f"The original encoder only has {full_depth} "
                f"layers, but you requested {built_depth} layers.")

        # By default keep post_layernorm only for an untruncated encoder;
        # skipping it otherwise conserves memory.
        if require_post_norm is None:
            require_post_norm = built_depth == full_depth

        self.post_layernorm = (nn.LayerNorm(embed_dim,
                                            eps=config.layer_norm_eps)
                               if require_post_norm else None)

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        select_layers: Optional[list[int]] = None,
        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` into vision features.

        When ``select_layers`` is given, every hidden state is collected so
        the requested layers can be picked out; post-norm handling and layer
        stacking are delegated to ``resolve_visual_encoder_outputs``.
        """
        want_all_layers = select_layers is not None

        embedded = self.pre_layrnorm(self.embeddings(pixel_values))
        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=want_all_layers,
        )

        return resolve_visual_encoder_outputs(
            encoder_outputs,
            self.post_layernorm,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights`` into this module's parameters.

        Weights for a skipped ``post_layernorm`` and for encoder layers past
        the built depth are ignored. Separate ``q_proj``/``k_proj``/``v_proj``
        tensors are routed into the fused ``qkv_proj`` parameter.

        Returns:
            The set of parameter names that received weights.
        """
        # (fused param name, checkpoint shard name, shard id)
        qkv_shards = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        params_by_name = dict(self.named_parameters())
        loaded: set[str] = set()
        built_depth = len(self.encoder.layers)

        for name, tensor in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if (self.post_layernorm is None
                    and name.startswith("post_layernorm")):
                continue

            # omit layers when num_hidden_layers_override is set
            if (name.startswith("encoder.layers")
                    and int(name.split(".")[2]) >= built_depth):
                continue

            # First shard whose checkpoint name occurs in `name`, if any.
            shard = next(
                (entry for entry in qkv_shards if entry[1] in name), None)
            if shard is None:
                param = params_by_name[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, tensor)
            else:
                fused_name, shard_name, shard_id = shard
                name = name.replace(shard_name, fused_name)
                param = params_by_name[name]
                param.weight_loader(param, tensor, shard_id)
            loaded.add(name)
        return loaded


class CLIPVisionModel(nn.Module):
    """Wraps a ``CLIPVisionTransformer`` under the ``vision_model`` attribute.

    Calls, ``dtype`` and ``device`` are all forwarded to the wrapped tower.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    @property
    def dtype(self):
        """Dtype reported by the wrapped vision tower."""
        return self.vision_model.dtype

    @property
    def device(self):
        """Device reported by the wrapped vision tower."""
        return self.vision_model.device

    def forward(
        self,
        pixel_values: torch.Tensor,
        select_layers: Optional[list[int]] = None,
        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
        """Run the wrapped vision tower on ``pixel_values``."""
        tower = self.vision_model
        return tower(
            pixel_values,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )
766
767


# Assume EOS token corresponds to LAST token in text model
@default_pooling_type("LAST")
@MULTIMODAL_REGISTRY.register_processor(CLIPMultiModalProcessor,
                                        info=CLIPProcessingInfo,
                                        dummy_inputs=CLIPDummyInputsBuilder)
class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
    """CLIP text/image dual encoder exposed as a vLLM pooling model.

    Text prompts go through ``text_model`` + ``text_projection``; images go
    through ``vision_model`` + ``visual_projection``. Both projections map
    to ``config.projection_dim`` features.
    """

    is_pooling_model = True

    # Fused ``qkv_proj`` module name -> the separate projections it packs.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    merge_by_field_config = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for item ``i`` of ``modality``.

        Image modalities use no placeholder string, so ``None`` is returned.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build both encoders, their projections and the pooler.

        Args:
            vllm_config: Engine config; ``model_config.hf_config`` must be a
                ``CLIPConfig`` and ``model_config.pooler_config`` must be set.
            prefix: Module-name prefix for the text/vision sub-modules.
        """
        super().__init__()

        config: CLIPConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPTextTransformer(
            text_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "text_model"),
        )
        self.vision_model = CLIPVisionTransformer(
            vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        # Bias-free projections from each encoder's hidden size to the
        # common projection_dim.
        self.visual_projection = nn.Linear(
            self.vision_embed_dim,
            self.projection_dim,
            bias=False,
        )
        self.text_projection = nn.Linear(
            self.text_embed_dim,
            self.projection_dim,
            bias=False,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler_config = pooler_config

        self.pooler = DispatchPooler({
            "encode": Pooler.for_encode(pooler_config),
            "embed": Pooler.for_embed(pooler_config),
        })

        # Assumes that self.forward is called after self.get_input_embeddings
        self._is_text_input = True

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode text and project it with ``text_projection``.

        Args:
            input_ids: Token ids, or ``None`` when ``inputs_embeds`` is given.
            position_ids: Positions forwarded to the text transformer.
            inputs_embeds: Optional precomputed input embeddings.

        Returns:
            The text transformer output mapped to ``projection_dim``.
        """
        pooled_output = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        text_features = self.text_projection(pooled_output)

        return text_features

    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
        """Encode images and project them with ``visual_projection``.

        Args:
            pixel_values: Preprocessed image tensor for the vision transformer.
            feature_select_strategy: How to reduce the vision outputs; when
                ``None`` it is derived from the pooler's ``pooling_type``.

        Returns:
            The vision transformer output mapped to ``projection_dim``.
        """
        if feature_select_strategy is None:
            feature_select_strategy = _get_vision_feature_select_strategy(
                self.pooler_config.pooling_type)

        pooled_output = self.vision_model(
            pixel_values=pixel_values,
            select_layers=None,
            feature_select_strategy=feature_select_strategy,
        )

        image_features = self.visual_projection(pooled_output)

        return image_features

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[CLIPImagePixelInputs]:
        """Wrap ``pixel_values`` from ``kwargs`` in ``CLIPImagePixelInputs``.

        Returns ``None`` when no ``pixel_values`` were passed. The schema's
        ``h``/``w`` dimensions are bound to ``vision_config.image_size``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        expected_h = expected_w = self.config.vision_config.image_size
        return CLIPImagePixelInputs(type="pixel_values",
                                    data=pixel_values,
                                    resolve_bindings={
                                        "h": expected_h,
                                        "w": expected_w
                                    })

    def _process_image_inputs(self,
                              inputs: CLIPImagePixelInputs) -> torch.Tensor:
        """Compute projected image features for validated pixel inputs."""
        pixel_values = inputs["data"]

        return self.get_image_features(pixel_values)

    def get_language_model(self) -> torch.nn.Module:
        """Return the text transformer as the language backbone."""
        return self.text_model

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging multimodal embeddings when provided.

        Side effect: records in ``self._is_text_input`` whether this call was
        text-only, which the next ``forward`` call relies on.
        """
        self._is_text_input = (multimodal_embeddings is None
                               or len(multimodal_embeddings) == 0)

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return projected image features, or ``[]`` if no image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_inputs(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Produce the features the pooler consumes.

        Text-only requests (as recorded by ``get_input_embeddings``) are run
        through ``get_text_features``. Multimodal requests return
        ``inputs_embeds`` unchanged (presumably the image features already
        merged by ``get_input_embeddings`` — confirm).

        Raises:
            RuntimeError: If ``intermediate_tensors`` is given (pipeline
                parallelism is not supported).
        """
        if intermediate_tensors is not None:
            raise RuntimeError("PP is not supported for this model")

        # Multimodal inputs
        if not self._is_text_input:
            return inputs_embeds

        # Text inputs
        return self.get_text_features(input_ids=input_ids,
                                      position_ids=positions,
                                      inputs_embeds=inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights via ``AutoWeightsLoader``.

        Names containing ``.position_ids`` are skipped, and unexpected names
        starting with ``logit_scale.`` are ignored rather than raising.
        Returns whatever ``AutoWeightsLoader.load_weights`` returns
        (presumably the set of loaded parameter names).
        """
        loader = AutoWeightsLoader(
            self,
            skip_substrs=[".position_ids"],
            ignore_unexpected_prefixes=["logit_scale."],
        )

        return loader.load_weights(weights)