clip.py 34 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable, Iterable, Mapping, Sequence
4
from functools import cached_property
5
from typing import Annotated, Literal
6
7
8

import torch
import torch.nn as nn
9
10
11
12
13
14
15
from transformers import (
    BatchFeature,
    CLIPConfig,
    CLIPProcessor,
    CLIPTextConfig,
    CLIPVisionConfig,
)
16

17
18
from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
19
from vllm.config import VllmConfig
20
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
21
from vllm.distributed import divide, get_tensor_model_parallel_world_size
22
from vllm.model_executor.layers.activation import get_act_fn
23
from vllm.model_executor.layers.conv import Conv2dLayer
24
25
26
27
28
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
29
from vllm.model_executor.layers.pooler import DispatchPooler
30
from vllm.model_executor.layers.quantization import QuantizationConfig
31
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
32
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
33
from vllm.model_executor.models.interfaces import SupportsQuant
34
from vllm.multimodal import MULTIMODAL_REGISTRY
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptReplacement,
    PromptUpdate,
)
50
51
52
53
54
55
56
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, maybe_prefix
57
58
59
60
61
62
63
from .vision import (
    VisionEncoderInfo,
    VisionFeatureSelectStrategy,
    VisionFeatureSelectStrategyStr,
    get_num_selected_vision_tokens,
    resolve_visual_encoder_outputs,
)
64

65

66
67
68
69
70
71
72
73
class CLIPImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs fed to the CLIP vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Tag identifying this input variant; always the literal "pixel_values".
    type: Literal["pixel_values"]
    # Pixel tensor laid out as (bn, 3, h, w); the channel axis is fixed to 3.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


79
80
81
82
83
84
85
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
    """Geometry helpers describing how the CLIP vision tower tokenizes images.

    The token count depends only on the configured ``image_size`` and
    ``patch_size``; the dimensions of the incoming image are ignored.
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the vision encoder sequence length for one image.

        This is one token per patch of the square patch grid plus one extra
        token for the class embedding. ``image_width`` and ``image_height``
        are accepted for interface compatibility but do not affect the result.
        """
        return self.get_patch_grid_length() ** 2 + 1

    def get_image_size(self) -> int:
        """Return the configured input resolution (side length, in pixels)."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return the configured patch side length, in pixels."""
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return the number of patches along one side of the image.

        Raises:
            ValueError: If ``image_size`` is not a multiple of ``patch_size``.
        """
        image_size, patch_size = self.get_image_size(), self.get_patch_size()
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``. Without the check, a misconfigured model would produce
        # an inconsistent token count instead of a clear error.
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by "
                f"patch_size ({patch_size})"
            )
        return image_size // patch_size


100
101
102
103
104
105
106
107
108
109
110
111
112
# Maps a pooler ``pooling_type`` onto the vision feature-selection strategy
# used when counting/selecting vision tokens.
_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = {
    "MEAN": "full",
    "ALL": "full",
    "CLS": "class",
    # This lets us use the same pooling type for both text and image
    "LAST": "class",
}


def _get_vision_feature_select_strategy(pooling_type: str):
    """Translate ``pooling_type`` into a vision feature-selection strategy.

    Raises:
        ValueError: If no strategy is registered for ``pooling_type``.
    """
    strategy = _POOLING_TYPE_TO_STRATEGY.get(pooling_type)
    if strategy is None:
        raise ValueError(
            f"No feature selection strategy is defined for "
            f"pooling_type: {pooling_type!r}"
        )
    return strategy


class CLIPProcessingInfo(BaseProcessingInfo):
    """Processing metadata for CLIP: HF config/processor and image token counts."""

    def get_hf_config(self):
        """Return the model's HF ``CLIPConfig``."""
        return self.ctx.get_hf_config(CLIPConfig)

    def get_vision_encoder_info(self):
        """Wrap the HF config in a ``CLIPEncoderInfo`` geometry helper."""
        hf_config = self.get_hf_config()
        return CLIPEncoderInfo(hf_config)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``CLIPProcessor``, forwarding any ``kwargs``."""
        return self.ctx.get_hf_processor(CLIPProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Report a limit of one image item per prompt."""
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many placeholder tokens an image needs.

        The raw vision-encoder length is narrowed according to the
        feature-selection strategy implied by the pooler's ``pooling_type``.
        """
        encoder_info = self.get_vision_encoder_info()

        pooler_config = self.ctx.model_config.pooler_config
        assert pooler_config is not None

        num_encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = _get_vision_feature_select_strategy(pooler_config.pooling_type)
        return get_num_selected_vision_tokens(num_encoder_tokens, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the configured square input resolution."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for an image at the largest supported size."""
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width, image_height=height)


class CLIPDummyInputsBuilder(BaseDummyInputsBuilder[CLIPProcessingInfo]):
    """Builds dummy inputs for CLIP: an empty prompt plus full-size images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty prompt (CLIP images are passed without text)."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images at the largest resolution.

        Image overrides are taken from ``mm_options["image"]`` when
        ``mm_options`` is non-empty.
        """
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()
        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
    """Multimodal processor for CLIP.

    Each request is either text-only or image-only. For an image-only request
    the prompt is filled with placeholder image tokens at its start.
    """

    @cached_property
    def image_token_id(self) -> int:
        """Placeholder token id used for image positions (token id 0).

        The assertion guards against token 0 colliding with one of the
        tokenizer's special tokens.
        """
        placeholder_id = 0
        assert placeholder_id not in self.info.get_tokenizer().all_special_ids
        return placeholder_id

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Enforce text/image exclusivity, then defer to the base processor.

        Raises:
            ValueError: If both a non-empty prompt and image data are given.
        """
        if prompt and mm_data:
            raise ValueError(
                "CLIP accepts text-only or image-only inputs, not both! "
                "Image-only inputs means passing an image with an empty text "
                "prompt."
            )

        if mm_data:
            # Image-only request: the processed prompt should consist only of
            # the placeholder image tokens, so disable special-token insertion.
            merged_kwargs = dict(tokenization_kwargs or {})
            merged_kwargs["add_special_tokens"] = False
            tokenization_kwargs = merged_kwargs

        return super().apply(
            prompt=prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Always False: the placeholder updates are applied by this processor."""
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as a field batched per image item."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert the image placeholder tokens at the start of the prompt."""
        token_id = self.image_token_id

        def _replacement_for(item_idx: int) -> list[int]:
            # Items are fetched lazily, only when a replacement is requested.
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)
            count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [token_id] * count

        replacement = PromptReplacement(
            modality="image",
            target=PromptIndexTargets.start(),
            replacement=_replacement_for,
        )
        return [replacement]


# Adapted from: https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/models/clip/modeling_clip.py
class CLIPTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for CLIP text."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()

        embed_dim = config.hidden_size

        self.token_embedding = VocabParallelEmbedding(config.vocab_size, embed_dim)
        # One learned row per position, up to ``max_position_embeddings``.
        self.position_embedding = VocabParallelEmbedding(
            config.max_position_embeddings, embed_dim
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return token embeddings summed with position embeddings.

        Args:
            input_ids: Token ids to embed; ignored when ``inputs_embeds`` is given.
            position_ids: Position indices looked up in ``position_embedding``.
            inputs_embeds: Optional precomputed token embeddings.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                # Name the actual keyword argument (``inputs_embeds``) so the
                # message points callers at the right parameter.
                raise ValueError(
                    "Either `input_ids` or `inputs_embeds` must be provided"
                )
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings


308
309
310
311
312
313
314
class CLIPVisionEmbeddings(nn.Module):
    """Patch embeddings + class token + learned position embeddings for CLIP."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Explicit check instead of ``assert`` so it is not stripped under
        # ``python -O``. This matches the ValueError used by CLIPAttention for
        # its own config validation.
        if self.image_size % self.patch_size != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be divisible by "
                f"patch_size ({self.patch_size})"
            )

        # Learned vector prepended to the patch sequence as the class token.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patchify: kernel size == stride == patch size.
        self.patch_embedding = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Shape (1, num_positions); non-persistent, so it is excluded from the
        # module's state_dict.
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images into a token sequence.

        Args:
            pixel_values: Image batch; presumably (batch, channels, image_size,
                image_size). The position table only covers the configured
                resolution. TODO confirm with callers.

        Returns:
            Embeddings of shape (batch, 1 + num_patches, embed_dim), with the
            class token first.
        """
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype)
        )  # shape = [*, width, grid, grid]
        # [*, width, grid, grid] -> [*, grid * grid, width]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings


351
class CLIPAttention(nn.Module):
    """Multi-head self-attention for CLIP with a fused QKV projection.

    When ``multimodal_config.mm_encoder_tp_mode == "data"``, tensor
    parallelism is disabled for the projections (``disable_tp``) and all
    heads are computed locally (``tp_size == 1``).
    """

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5

        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        # MMEncoderAttention, and any subclass of it, is constructed with the
        # multimodal config. ``issubclass`` (rather than ``==``) keeps that true
        # for specialized subclasses too.
        if issubclass(attn_cls, MMEncoderAttention):
            self.attn = attn_cls(
                self.num_heads_per_partition,
                self.head_dim,
                self.scale,
                prefix=f"{prefix}.attn",
                multimodal_config=multimodal_config,
            )
        else:
            self.attn = attn_cls(
                self.num_heads_per_partition,
                self.head_dim,
                self.scale,
                prefix=f"{prefix}.attn",
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel

        Returns:
            A pair ``(attn_output, None)``. The second element is always
            ``None``; no attention weights are returned.
        """
        qkv_states, _ = self.qkv_proj(hidden_states)
        # Equal-sized Q/K/V slices along the last dim of the fused projection.
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        return attn_output, None


432
class CLIPMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        # In "data" encoder-TP mode the linear layers are built with TP disabled.
        disable_tp = False
        if multimodal_config:
            disable_tp = multimodal_config.mm_encoder_tp_mode == "data"

        self.activation_fn = get_act_fn(config.hidden_act)

        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        self.fc1 = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=disable_tp,
        )
        self.fc2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=disable_tp,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation, then fc2.

        Each parallel linear layer returns a pair; only its first element
        (the output tensor) is used.
        """
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class CLIPEncoderLayer(nn.Module):
    """One pre-norm transformer layer: self-attention then MLP, each residual."""

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        hidden_size = config.hidden_size
        eps = config.layer_norm_eps

        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.self_attn",
            attn_cls=attn_cls,
        )
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = CLIPMLP(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute ``x + attn(ln1(x))``, then ``y + mlp(ln2(y))``."""
        attn_out, _ = self.self_attn(hidden_states=self.layer_norm1(hidden_states))
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out


class CLIPEncoder(nn.Module):
    """
    Stack of [`CLIPEncoderLayer`] self-attention layers.

    The depth defaults to `config.num_hidden_layers` and can be changed with
    `num_hidden_layers_override`.

    Args:
        config: CLIP text or vision sub-config (`CLIPTextConfig` or
            `CLIPVisionConfig`).
    """

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        num_hidden_layers_override: int | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.config = config

        depth = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )

        self.layers = nn.ModuleList(
            CLIPEncoderLayer(
                config=config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                prefix=f"{prefix}.layers.{idx}",
                attn_cls=attn_cls,
            )
            for idx in range(depth)
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Run every layer in order.

        Returns:
            The final hidden states. If ``return_all_hidden_states`` is set,
            returns a list instead: the input embeddings followed by each
            layer's output. Callers pick feature layers from it by index.
        """
        hidden_states = inputs_embeds
        collected = [hidden_states] if return_all_hidden_states else None

        for layer in self.layers:
            hidden_states = layer(hidden_states)
            if collected is not None:
                collected.append(hidden_states)

        return collected if collected is not None else hidden_states


578
579
580
581
class CLIPTextTransformer(nn.Module):
    """CLIP text tower: embeddings -> encoder -> final LayerNorm."""

    def __init__(
        self,
        config: CLIPTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.embeddings = CLIPTextEmbeddings(config)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_cls=Attention,
        )

        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings only, with no position embeddings added."""
        token_embedding = self.embeddings.token_embedding
        return token_embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the final-layer-normalized last hidden state."""
        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        encoded = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=False,
        )
        return self.final_layer_norm(encoded)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v projections into ``qkv_proj``.

        Returns:
            The set of parameter names that received weights.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # First stacked shard whose checkpoint name occurs in ``name``.
            shard = next(
                (
                    entry
                    for entry in stacked_params_mapping
                    if entry[1] in name
                ),
                None,
            )
            if shard is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = shard
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        return loaded_params


656
class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower: patch embeddings, pre-norm, encoder, optional post-norm.

    ``num_hidden_layers_override`` builds fewer encoder layers than the config
    specifies. ``require_post_norm`` defaults to building the post-norm only
    when the full depth is instantiated.
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MMEncoderAttention,
        )

        max_layers = config.num_hidden_layers
        built_layers = len(self.encoder.layers)
        if built_layers > max_layers:
            raise ValueError(
                f"The original encoder only has {max_layers} "
                f"layers, but you requested {built_layers} layers."
            )

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            require_post_norm = built_layers == max_layers

        self.post_layernorm = (
            nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
            if require_post_norm
            else None
        )

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Encode images and resolve the requested feature layers."""
        embedded = self.pre_layrnorm(self.embeddings(pixel_values))

        # Every layer's hidden states are needed only when specific layers
        # are selected; otherwise just the last layer's output.
        want_all_layers = select_layers is not None
        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=want_all_layers,
        )

        # Apply post-norm (if present) and stack the selected feature layers.
        return resolve_visual_encoder_outputs(
            encoder_outputs,
            self.post_layernorm,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v projections into ``qkv_proj``.

        Skips ``post_layernorm`` weights when that norm was not built, and
        weights for encoder layers beyond those instantiated.

        Returns:
            The set of parameter names that received weights.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        num_layers = len(self.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if self.post_layernorm is None and name.startswith("post_layernorm"):
                continue

            # omit layers when num_hidden_layers_override is set
            if (
                name.startswith("encoder.layers")
                and int(name.split(".")[2]) >= num_layers
            ):
                continue

            # First stacked shard whose checkpoint name occurs in ``name``.
            shard = next(
                (
                    entry
                    for entry in stacked_params_mapping
                    if entry[1] in name
                ),
                None,
            )
            if shard is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = shard
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        return loaded_params


class CLIPVisionModel(nn.Module):
    """Thin wrapper exposing a ``CLIPVisionTransformer`` as ``vision_model``."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Delegate to the wrapped vision transformer."""
        tower = self.vision_model
        return tower(
            pixel_values,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )

    @property
    def dtype(self):
        """Dtype reported by the wrapped vision transformer."""
        return self.vision_model.dtype

    @property
    def device(self):
        """Device reported by the wrapped vision transformer."""
        return self.vision_model.device


# Assume EOS token corresponds to LAST token in text model
@default_pooling_type("LAST")
@MULTIMODAL_REGISTRY.register_processor(
    CLIPMultiModalProcessor,
    info=CLIPProcessingInfo,
    dummy_inputs=CLIPDummyInputsBuilder,
)
class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
    """CLIP dual encoder served as a pooling (embedding) model.

    Text prompts go through ``text_model`` followed by ``text_projection``.
    Images go through ``vision_model`` followed by ``visual_projection``.
    Both projections are ``nn.Linear`` layers mapping to
    ``config.projection_dim``.

    The model runner's ``inputs_embeds`` buffer has width
    ``config.projection_dim`` so it can hold image embeddings.
    ``_embed_text_input_ids`` therefore pads text token embeddings up to that
    width, and ``forward`` truncates them back to the text hidden size.

    Whether a request is text or multimodal is tracked in
    ``self._is_text_input``. ``embed_input_ids`` sets it and ``forward``
    reads it, so ``forward`` must follow ``embed_input_ids`` for the same
    batch.
    """

    is_pooling_model = True

    # Maps the fused ``qkv_proj`` module to the separate q/k/v projections
    # it packs.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for ``modality``.

        Returns ``None`` (no placeholder) for any modality whose name starts
        with ``"image"``.

        Raises:
            ValueError: for any non-image modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build both encoders, both projections and the embedding pooler.

        Args:
            vllm_config: Source of the HF ``CLIPConfig``, quantization,
                multimodal and pooler configs.
            prefix: Parent module name used to derive submodule prefixes.
        """
        super().__init__()

        config: CLIPConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPTextTransformer(
            text_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "text_model"),
        )
        self.vision_model = CLIPVisionTransformer(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.visual_projection = nn.Linear(
            self.vision_embed_dim,
            self.projection_dim,
            bias=False,
        )
        self.text_projection = nn.Linear(
            self.text_embed_dim,
            self.projection_dim,
            bias=False,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler_config = pooler_config

        self.pooler = DispatchPooler.for_embedding(pooler_config)

        # Assumes that self.forward is called after self.embed_input_ids
        self._is_text_input = True

    def get_text_features(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode text with ``text_model`` and apply ``text_projection``.

        Args:
            input_ids: Token ids. May be ``None`` when ``inputs_embeds`` is
                given.
            position_ids: Token positions, forwarded to the text model.
            inputs_embeds: Optional precomputed token embeddings of width
                ``text_config.hidden_size``.

        Returns:
            The text model's pooled output projected to ``projection_dim``.
        """
        pooled_output = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        return self.text_projection(pooled_output)

    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Encode images with ``vision_model`` and apply ``visual_projection``.

        If ``feature_select_strategy`` is not supplied, it is derived from
        the pooler config's ``pooling_type`` via
        ``_get_vision_feature_select_strategy``.
        """
        if feature_select_strategy is None:
            feature_select_strategy = _get_vision_feature_select_strategy(
                self.pooler_config.pooling_type
            )

        pooled_output = self.vision_model(
            pixel_values=pixel_values,
            select_layers=None,
            feature_select_strategy=feature_select_strategy,
        )

        return self.visual_projection(pooled_output)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> CLIPImagePixelInputs | None:
        """Pop ``pixel_values`` from ``kwargs`` and wrap them.

        Returns ``None`` when no ``pixel_values`` were provided. The
        ``h``/``w`` bindings are both set to ``vision_config.image_size``.
        Presumably ``CLIPImagePixelInputs`` checks the tensor shape against
        them (TODO confirm).
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        expected_h = expected_w = self.config.vision_config.image_size
        return CLIPImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": expected_h, "w": expected_w},
        )

    def _process_image_inputs(self, inputs: CLIPImagePixelInputs) -> torch.Tensor:
        """Return projected image features for validated pixel inputs."""
        pixel_values = inputs["data"]

        return self.get_image_features(pixel_values)

    def get_language_model(self) -> torch.nn.Module:
        """Return the text encoder."""
        return self.text_model

    def _embed_text_input_ids(
        self,
        input_ids: torch.Tensor,
        embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
        *,
        is_multimodal: torch.Tensor | None,
        handle_oov_mm_token: bool,
    ) -> torch.Tensor:
        """Embed text tokens, then right-pad them to ``projection_dim``.

        The padding columns come from ``new_empty`` and are uninitialized.
        On the text path ``forward`` slices them off again before the text
        model sees them.

        Raises:
            NotImplementedError: if the token embeddings are already wider
                than ``projection_dim``.
        """
        inputs_embeds = super()._embed_text_input_ids(
            input_ids,
            embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        # NOTE: inputs_embeds in model runner has size config.projection_dim
        # (instead of text_config.hidden_size) to accommodate image embeddings
        target_width = self.projection_dim
        current_width = inputs_embeds.shape[1]
        if current_width < target_width:
            padding = inputs_embeds.new_empty(
                inputs_embeds.shape[0],
                target_width - current_width,
            )
            inputs_embeds = torch.cat([inputs_embeds, padding], dim=1)
        elif current_width > target_width:
            # No need to handle this case for now
            raise NotImplementedError(
                f"Text embedding width ({current_width}) exceeds "
                f"projection_dim ({target_width}); truncation is not supported"
            )

        return inputs_embeds

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and record whether this batch is text-only.

        ``self._is_text_input`` is set to ``True`` when no (or empty)
        ``multimodal_embeddings`` are given. ``forward`` relies on this flag.
        """
        self._is_text_input = (
            multimodal_embeddings is None or len(multimodal_embeddings) == 0
        )

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return projected image embeddings, or ``[]`` if there is no image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_inputs(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Produce text features, or pass multimodal embeddings through.

        For multimodal batches (per ``self._is_text_input``),
        ``inputs_embeds`` is returned unchanged. For text batches, any
        ``inputs_embeds`` are truncated from ``projection_dim`` back to the
        text hidden size before encoding. When ``inputs_embeds`` is ``None``,
        the text model embeds ``input_ids`` itself.

        Raises:
            RuntimeError: if ``intermediate_tensors`` is given (no pipeline
                parallelism).
            NotImplementedError: if ``inputs_embeds`` is narrower than the
                text hidden size.
        """
        if intermediate_tensors is not None:
            raise RuntimeError("PP is not supported for this model")

        # Multimodal inputs: presumably the projected image features merged
        # in by embed_input_ids, so no further encoding is done here.
        if not self._is_text_input:
            return inputs_embeds

        if inputs_embeds is None:
            # No precomputed embeddings: let the text model embed input_ids.
            return self.get_text_features(input_ids, positions)

        # NOTE: inputs_embeds in model runner has size config.projection_dim
        # (instead of text_config.hidden_size) to accommodate image embeddings
        hidden_size = self.text_embed_dim
        current_width = inputs_embeds.shape[1]
        if current_width > hidden_size:
            inputs_embeds = inputs_embeds[:, :hidden_size]
        elif current_width < hidden_size:
            # No need to handle this case for now
            raise NotImplementedError(
                f"inputs_embeds width ({current_width}) is smaller than the "
                f"text hidden size ({hidden_size}); padding is not supported"
            )

        return self.get_text_features(input_ids, positions, inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights via ``AutoWeightsLoader``.

        Names containing ``.position_ids`` are skipped, and unexpected
        ``logit_scale.``-prefixed entries are ignored.
        """
        loader = AutoWeightsLoader(
            self,
            skip_substrs=[".position_ids"],
            ignore_unexpected_prefixes=["logit_scale."],
        )

        return loader.load_weights(weights)