clip.py 33.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable, Iterable, Mapping, Sequence
4
from functools import cached_property
5
from typing import Annotated, Literal
6
7
8

import torch
import torch.nn as nn
9
10
11
12
13
14
15
from transformers import (
    BatchFeature,
    CLIPConfig,
    CLIPProcessor,
    CLIPTextConfig,
    CLIPVisionConfig,
)
16

17
from vllm.config import VllmConfig
18
from vllm.config.multimodal import BaseDummyOptions
19
from vllm.distributed import divide, get_tensor_model_parallel_world_size
20
from vllm.model_executor.layers.activation import get_act_fn
21
from vllm.model_executor.layers.attention import Attention, MMEncoderAttention
22
from vllm.model_executor.layers.conv import Conv2dLayer
23
24
25
26
27
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
28
from vllm.model_executor.layers.pooler import DispatchPooler
29
from vllm.model_executor.layers.quantization import QuantizationConfig
30
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.model_executor.models.interfaces import SupportsQuant
33
from vllm.multimodal import MULTIMODAL_REGISTRY
34
35
36
37
38
39
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
)
40
41
42
43
44
45
from vllm.multimodal.parse import (
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
    MultiModalUUIDItems,
)
46
from vllm.multimodal.processing import (
47
    BaseDummyInputsBuilder,
48
49
50
51
52
53
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptReplacement,
    PromptUpdate,
)
54
55
56
57
58
59
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, maybe_prefix
60
61
62
63
64
from .vision import (
    VisionEncoderInfo,
    VisionFeatureSelectStrategy,
    VisionFeatureSelectStrategyStr,
    get_num_selected_vision_tokens,
65
    is_vit_use_data_parallel,
66
67
    resolve_visual_encoder_outputs,
)
68

69

70
71
72
73
74
75
76
77
class CLIPImagePixelInputs(TensorSchema):
    """
    Schema for a batch of raw images fed to the CLIP vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (fixed to 3)
        - h: Height of each image
        - w: Width of each image
    """

    type: Literal["pixel_values"]
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


83
84
85
86
87
88
89
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
    """Token-count geometry of the CLIP vision encoder.

    Every value is derived from the vision config. The image is cut into a
    square grid of patches, and one extra class token is counted on top of
    the patches.
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        # The resolution comes from the config, so the requested width and
        # height do not affect the count. The trailing +1 is the class token.
        side = self.get_patch_grid_length()
        return side * side + 1

    def get_image_size(self) -> int:
        """Return the configured input resolution (one side, in pixels)."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return the configured patch side length (in pixels)."""
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return the number of patches along one side of the image.

        Raises:
            AssertionError: if the image size is not a multiple of the patch
                size.
        """
        patches_per_side, leftover = divmod(
            self.get_image_size(), self.get_patch_size()
        )
        assert leftover == 0
        return patches_per_side
102
103


104
105
106
107
108
109
110
111
112
113
114
115
116
# Maps a pooler's sequence pooling type to the vision feature-selection
# strategy name. It is looked up by `_get_vision_feature_select_strategy`.
# The strategy names are interpreted by helpers in `.vision`. Presumably
# "full" keeps every encoder token and "class" keeps only the class token;
# confirm there.
_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = {
    "MEAN": "full",
    "ALL": "full",
    "CLS": "class",
    # This lets us use the same pooling type for both text and image
    "LAST": "class",
}


def _get_vision_feature_select_strategy(pooling_type: str):
    """Translate a pooler ``seq_pooling_type`` into a vision feature strategy.

    Args:
        pooling_type: Pooling type name, e.g. ``"CLS"`` or ``"LAST"``.

    Returns:
        The strategy name registered in ``_POOLING_TYPE_TO_STRATEGY``.

    Raises:
        ValueError: if ``pooling_type`` has no registered strategy.
    """
    strategy = _POOLING_TYPE_TO_STRATEGY.get(pooling_type)
    if strategy is None:
        raise ValueError(
            f"No feature selection strategy is defined for "
            f"pooling_type: {pooling_type!r}"
        ) from None
    return strategy
121
122
123
124
125
126
127
128
129
130
131
132


class CLIPProcessingInfo(BaseProcessingInfo):
    """Processing metadata for CLIP: configs, processor and image-token counts."""

    def get_hf_config(self):
        """Return the model's HF config, requested as a ``CLIPConfig``."""
        return self.ctx.get_hf_config(CLIPConfig)

    def get_vision_encoder_info(self):
        """Return token-geometry helpers built from the HF config."""
        return CLIPEncoderInfo(self.get_hf_config())

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``CLIPProcessor``, forwarding ``kwargs``."""
        return self.ctx.get_hf_processor(CLIPProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Only a single image item is supported.
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many placeholder tokens a single image expands to.

        The encoder's raw token count is reduced according to the feature
        selection strategy implied by the configured sequence pooling type.
        """
        encoder_info = self.get_vision_encoder_info()

        pooler_config = self.ctx.model_config.pooler_config
        assert pooler_config is not None

        raw_token_count = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = _get_vision_feature_select_strategy(
            pooler_config.seq_pooling_type
        )
        return get_num_selected_vision_tokens(raw_token_count, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the configured (square) input resolution."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for the largest supported image."""
        largest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=largest.width,
            image_height=largest.height,
        )


class CLIPDummyInputsBuilder(BaseDummyInputsBuilder[CLIPProcessingInfo]):
    """Builds placeholder CLIP inputs: an empty prompt plus dummy images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # CLIP rejects text combined with images, so the dummy prompt is empty.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images at the largest size.

        ``seq_len`` and ``mm_processor_kwargs`` are not used here.
        """
        image_count = mm_counts.get("image", 0)
        largest = self.info.get_image_size_with_most_features()

        if mm_options:
            image_overrides = mm_options.get("image")
        else:
            image_overrides = None

        dummy_images = self._get_dummy_images(
            width=largest.width,
            height=largest.height,
            num_images=image_count,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
    """Multimodal processor for CLIP.

    A CLIP request carries either text or an image, never both. For an image
    request the prompt is replaced by a run of dummy image tokens whose length
    equals the image token count reported by ``CLIPProcessingInfo``.
    """

    @cached_property
    def image_token_id(self) -> int:
        """Token id used to fill image placeholder positions.

        Id 0 is reused as the placeholder. It must not be one of the
        tokenizer's special ids.
        """
        placeholder_id = 0
        assert placeholder_id not in self.info.get_tokenizer().all_special_ids
        return placeholder_id

    def apply(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        mm_uuid_items: MultiModalUUIDItems | None = None,
        hf_processor_mm_kwargs: Mapping[str, object] | None = None,
        tokenization_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalInputs:
        """Enforce the text-xor-image rule, then delegate to the base class.

        Raises:
            ValueError: if multimodal items are given together with a
                non-empty text prompt, or with a token prompt that contains
                any non-special token.
        """
        if mm_items:
            if isinstance(prompt, str):
                if prompt:
                    raise ValueError(
                        "CLIP accepts text-only or image-only inputs, not both! "
                        "You must pass an image with an empty text prompt."
                    )
            else:
                # A token prompt made solely of special tokens (presumably
                # BOS/EOS added during tokenization) is treated as empty.
                special_ids = self.info.get_tokenizer().all_special_ids
                if any(tok not in special_ids for tok in prompt):
                    raise ValueError(
                        "CLIP accepts text-only or image-only inputs, not both! "
                        "You must pass an image with an empty token prompt."
                    )
                prompt = []

            # For multi-modal data, the processed prompt should contain only
            # the dummy image tokens, so special tokens are not added.
            tokenization_kwargs = {
                **(tokenization_kwargs or {}),
                "add_special_tokens": False,
            }

        return super().apply(
            prompt=prompt,
            mm_items=mm_items,
            mm_uuid_items=mm_uuid_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # Always reports that the HF processor output does not already include
        # the prompt updates (presumably so the updates from
        # `_get_prompt_updates` are applied instead).
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as batched along the image items."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert the image's placeholder tokens at the start of the prompt."""
        token_id = self.image_token_id

        def _image_placeholder(item_idx: int) -> list[int]:
            # Looked up lazily, only when a replacement is actually requested.
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)
            count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [token_id] * count

        image_update = PromptReplacement(
            modality="image",
            target=PromptIndexTargets.start(),
            replacement=_image_placeholder,
        )
        return [image_update]


# Adapted from: https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/models/clip/modeling_clip.py
class CLIPTextEmbeddings(nn.Module):
    """Token embeddings plus absolute position embeddings for CLIP text."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()

        embed_dim = config.hidden_size

        # Both lookup tables use VocabParallelEmbedding. Position ids are
        # looked up like token ids in a table of `max_position_embeddings` rows.
        self.token_embedding = VocabParallelEmbedding(config.vocab_size, embed_dim)
        self.position_embedding = VocabParallelEmbedding(
            config.max_position_embeddings, embed_dim
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return token embeddings summed with position embeddings.

        Args:
            input_ids: Token ids. Ignored when ``inputs_embeds`` is given.
            position_ids: Position index for each token.
            inputs_embeds: Precomputed token embeddings, used instead of
                looking up ``input_ids``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is
                provided.
        """
        if inputs_embeds is None:
            if input_ids is None:
                # The message names the actual parameter, `inputs_embeds`.
                raise ValueError(
                    "Either `input_ids` or `inputs_embeds` must be provided"
                )
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings


321
322
323
324
325
326
327
class CLIPVisionEmbeddings(nn.Module):
    """Patchify an image, prepend a class token and add position embeddings."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        assert self.image_size % self.patch_size == 0

        # Learned class token, prepended to the patch sequence in forward().
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # kernel_size == stride == patch_size, so the patches do not overlap.
        self.patch_embedding = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.num_positions = self.num_patches + 1  # +1 for the class token
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Index row 0..num_positions-1. It is registered as non-persistent, so
        # it is not saved in the state dict.
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` into a ``[batch, 1 + num_patches, dim]`` tensor.

        Assumes the images are at the configured resolution; otherwise the
        position-embedding addition will not line up.
        """
        weight_dtype = self.patch_embedding.weight.dtype
        # [batch, channels, H, W] -> [batch, dim, grid, grid]
        patch_grid = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        # [batch, dim, grid, grid] -> [batch, grid * grid, dim]
        patch_tokens = patch_grid.flatten(2).transpose(1, 2)

        batch = pixel_values.shape[0]
        cls_tokens = self.class_embedding.expand(batch, 1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)


364
class CLIPAttention(nn.Module):
    """Multi-head self-attention for CLIP with a fused QKV projection.

    The projections are tensor-parallel, except that tensor parallelism is
    disabled for them when ``is_vit_use_data_parallel()`` is true. The
    attention kernel class is supplied by the caller via ``attn_cls``.
    """

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5

        use_data_parallel = is_vit_use_data_parallel()

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        # Heads are split across TP ranks unless the ViT runs data-parallel.
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        # Both supported attention classes are constructed with the same
        # arguments, so a single call covers them.
        self.attn = attn_cls(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Apply self-attention to ``hidden_states``.

        The last dimension of the fused projection output is split evenly into
        query, key and value. The exact leading dims of ``hidden_states``
        depend on the caller (presumably ``[..., embed_dim]``; confirm).

        Returns:
            ``(attn_output, None)``. The second element is always ``None``.
        """
        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)
        return attn_output, None
437
438


439
class CLIPMLP(nn.Module):
    """Feed-forward block: fc1 (hidden -> intermediate), activation, fc2 back."""

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        # Tensor parallelism is disabled for both projections when the ViT
        # runs data-parallel.
        tp_disabled = is_vit_use_data_parallel()
        self.activation_fn = get_act_fn(config.hidden_act)

        # Column-parallel first projection, row-parallel second projection.
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=tp_disabled,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=tp_disabled,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run fc1 -> activation -> fc2 and return the fc2 output."""
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class CLIPEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block: residual self-attention, then residual MLP."""

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_cls=attn_cls,
        )
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.mlp = CLIPMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute ``x + attn(ln1(x))``, then ``y + mlp(ln2(y))`` on that result."""
        attn_out, _ = self.self_attn(hidden_states=self.layer_norm1(hidden_states))
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out


class CLIPEncoder(nn.Module):
    """
    Transformer encoder made of a stack of self-attention layers, each a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIP text or vision config (`CLIPTextConfig` or
            `CLIPVisionConfig`).
        quant_config: Optional quantization config passed to every layer.
        num_hidden_layers_override: If given, build this many layers instead
            of `config.num_hidden_layers`.
        prefix: Parameter-name prefix for the layers.
        attn_cls: Attention class used by every layer.
    """

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        num_hidden_layers_override: int | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.config = config

        depth = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )
        self.layers = nn.ModuleList(
            CLIPEncoderLayer(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{idx}",
                attn_cls=attn_cls,
            )
            for idx in range(depth)
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Run every layer in order.

        Returns:
            The last layer's output. If ``return_all_hidden_states`` is set,
            returns instead a list of ``len(self.layers) + 1`` tensors: the
            input embeddings followed by each layer's output. Callers pick the
            ones they need by index.
        """
        if not return_all_hidden_states:
            hidden_states = inputs_embeds
            for layer in self.layers:
                hidden_states = layer(hidden_states)
            return hidden_states

        all_hidden_states = [inputs_embeds]
        for layer in self.layers:
            all_hidden_states.append(layer(all_hidden_states[-1]))
        return all_hidden_states


575
576
577
578
class CLIPTextTransformer(nn.Module):
    """CLIP text tower: embeddings, encoder stack, final LayerNorm."""

    def __init__(
        self,
        config: CLIPTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.embeddings = CLIPTextEmbeddings(config)

        # The text tower uses the generic `Attention` layer; the vision tower
        # uses `MMEncoderAttention` instead.
        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_cls=Attention,
        )

        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings only (position embeddings not added)."""
        return self.embeddings.token_embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed, encode and normalize; returns the final-layer hidden states."""
        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        encoded = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=False,
        )
        return self.final_layer_norm(encoded)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written as
        shards of the fused ``qkv_proj`` parameter.

        Returns:
            The set of parameter names (after renaming) that were loaded.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    name = name.replace(shard_name, fused_name)
                    param = params_dict[name]
                    param.weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                # Not a QKV shard: use the parameter's own loader if it has
                # one, otherwise `default_weight_loader`.
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params


653
class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower: patch embeddings, pre-LN, encoder, optional post-LN."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MMEncoderAttention,
        )

        num_hidden_layers = config.num_hidden_layers
        built_layers = len(self.encoder.layers)
        if built_layers > num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {built_layers} layers."
            )

        # Unless explicitly requested, post_layernorm is built only when the
        # full encoder depth is used. For a truncated encoder it is skipped to
        # conserve memory.
        if require_post_norm is None:
            require_post_norm = built_layers == num_hidden_layers
        self.post_layernorm = (
            nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
            if require_post_norm
            else None
        )

    @property
    def dtype(self):
        """dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` and resolve the requested features.

        ``select_layers`` controls whether all hidden states or only the last
        one are produced. Post-norm handling and layer selection/stacking are
        delegated to ``resolve_visual_encoder_outputs``.
        """
        embedded = self.pre_layrnorm(self.embeddings(pixel_values))

        # Every layer's hidden state is collected only when specific layers
        # are requested.
        want_all_layers = select_layers is not None
        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=want_all_layers,
        )

        return resolve_visual_encoder_outputs(
            encoder_outputs,
            self.post_layernorm,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, skipping weights for modules not built.

        Skips ``post_layernorm.*`` when post-norm was not built, and skips
        ``encoder.layers.<i>.*`` for ``i`` beyond a truncated encoder.
        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written as
        shards of the fused ``qkv_proj`` parameter.

        Returns:
            The set of parameter names (after renaming) that were loaded.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.encoder.layers)
        post_norm_missing = self.post_layernorm is None

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if post_norm_missing and name.startswith("post_layernorm"):
                continue
            # Omit layers beyond num_hidden_layers_override.
            if (
                name.startswith("encoder.layers")
                and int(name.split(".")[2]) >= layer_count
            ):
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    name = name.replace(shard_name, fused_name)
                    param = params_dict[name]
                    param.weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                # Not a QKV shard: use the parameter's own loader if it has
                # one, otherwise `default_weight_loader`.
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params


class CLIPVisionModel(nn.Module):
    """Wrapper holding a `CLIPVisionTransformer` under the name `vision_model`."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        tower_prefix = f"{prefix}.vision_model"
        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=tower_prefix,
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Forward ``pixel_values`` and options to the wrapped vision tower."""
        tower = self.vision_model
        return tower(
            pixel_values,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )

    @property
    def dtype(self):
        """dtype reported by the wrapped vision tower."""
        return self.vision_model.dtype

    @property
    def device(self):
        """Device reported by the wrapped vision tower."""
        return self.vision_model.device
812
813


# Assume EOS token corresponds to LAST token in text model
@default_pooling_type(seq_pooling_type="LAST")
@MULTIMODAL_REGISTRY.register_processor(
    CLIPMultiModalProcessor,
    info=CLIPProcessingInfo,
    dummy_inputs=CLIPDummyInputsBuilder,
)
class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
    """CLIP served as a pooling (embedding) model for text and images.

    Text goes through ``text_model`` followed by ``text_projection``. Images go
    through ``vision_model`` followed by ``visual_projection``. Both
    projections output ``config.projection_dim`` features.

    NOTE: which tower ``forward`` uses is decided by ``self._is_text_input``.
    That flag is set as a side effect of ``embed_input_ids``, so ``forward``
    assumes ``embed_input_ids`` has already been called for the same batch.
    """

    is_pooling_model = True

    # Packed module name -> the separate projection names it combines.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for item ``i`` of ``modality``.

        No placeholder string is used for images, so ``None`` is returned for
        any modality name starting with ``"image"``.

        Raises:
            ValueError: For any non-image modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build both CLIP towers, their projections, and the pooler.

        Args:
            vllm_config: Source of the HF ``CLIPConfig``, the quantization
                config, the multimodal config and the (required) pooler config.
            prefix: Module-name prefix for this model's sub-modules.
        """
        super().__init__()

        config: CLIPConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        # Presumably tags these sub-modules as the language model for the
        # SupportsMultiModal machinery — TODO confirm in the interface.
        with self._mark_language_model(vllm_config):
            self.text_model = CLIPTextTransformer(
                text_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "text_model"),
            )
            self.text_projection = nn.Linear(
                self.text_embed_dim,
                self.projection_dim,
                bias=False,
            )

        # Presumably tags these sub-modules as the "image" tower — TODO confirm.
        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = CLIPVisionTransformer(
                vision_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.visual_projection = nn.Linear(
                self.vision_embed_dim,
                self.projection_dim,
                bias=False,
            )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler_config = pooler_config

        self.pooler = DispatchPooler.for_embedding(pooler_config)

        # Assumes that self.forward is called after self.embed_input_ids
        self._is_text_input = True

    def get_text_features(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode text with ``text_model`` and apply ``text_projection``.

        Args:
            input_ids: Token ids; may be ``None`` when ``inputs_embeds`` is
                given (handling is delegated to ``text_model``).
            position_ids: Position ids forwarded to ``text_model``.
            inputs_embeds: Optional precomputed token embeddings of width
                ``text_embed_dim``.

        Returns:
            The pooled text output projected by ``text_projection``.
        """
        pooled_output = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        return self.text_projection(pooled_output)

    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Encode images with ``vision_model`` and apply ``visual_projection``.

        Args:
            pixel_values: Preprocessed image tensor.
            feature_select_strategy: How ``vision_model`` selects its output
                features. When ``None``, it is derived from the pooler config's
                ``seq_pooling_type``.

        Returns:
            The vision output projected by ``visual_projection``.
        """
        if feature_select_strategy is None:
            feature_select_strategy = _get_vision_feature_select_strategy(
                self.pooler_config.seq_pooling_type
            )

        pooled_output = self.vision_model(
            pixel_values=pixel_values,
            select_layers=None,
            feature_select_strategy=feature_select_strategy,
        )

        return self.visual_projection(pooled_output)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> CLIPImagePixelInputs | None:
        """Pop ``pixel_values`` from ``kwargs`` and wrap it in a schema object.

        Returns ``None`` when no ``pixel_values`` are present. The ``h``/``w``
        bindings are both set to ``vision_config.image_size``; presumably the
        schema checks the pixel tensor against them — TODO confirm.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        expected_h = expected_w = self.config.vision_config.image_size

        return CLIPImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": expected_h, "w": expected_w},
        )

    def _process_image_inputs(self, inputs: CLIPImagePixelInputs) -> torch.Tensor:
        """Return projected image features for validated pixel inputs."""
        pixel_values = inputs["data"]

        return self.get_image_features(pixel_values)

    def _embed_text_input_ids(
        self,
        input_ids: torch.Tensor,
        embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
        *,
        is_multimodal: torch.Tensor | None,
        handle_oov_mm_token: bool,
    ) -> torch.Tensor:
        """Embed text tokens, then right-pad them to ``projection_dim`` columns.

        Raises:
            NotImplementedError: If the text embeddings are wider than
                ``projection_dim``.
        """
        inputs_embeds = super()._embed_text_input_ids(
            input_ids,
            embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        # NOTE: inputs_embeds in model runner has size text_config.projection_dim
        # (instead of text_config.hidden_size) to accommodate image embeddings
        inputs_embeds_size = self.projection_dim
        current_size = inputs_embeds.shape[1]
        if current_size < inputs_embeds_size:
            # The padding columns are left uninitialised (new_empty): forward()
            # slices them off again before the text tower reads the embeddings.
            inputs_embeds = torch.cat(
                [
                    inputs_embeds,
                    inputs_embeds.new_empty(
                        inputs_embeds.shape[0],
                        inputs_embeds_size - current_size,
                    ),
                ],
                dim=1,
            )
        elif current_size > inputs_embeds_size:
            # No need to handle this case for now
            raise NotImplementedError(
                f"Text embedding width ({current_size}) exceeds "
                f"projection_dim ({inputs_embeds_size})"
            )

        return inputs_embeds

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed a batch and record whether it is text-only.

        Side effect: sets ``self._is_text_input``, which the following
        ``forward`` call reads to decide between the text and image paths.
        """
        self._is_text_input = (
            multimodal_embeddings is None or len(multimodal_embeddings) == 0
        )

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return projected image features, or ``[]`` if no image was given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_inputs(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Produce the features to be pooled for the current batch.

        For a multimodal batch (per ``self._is_text_input``) the incoming
        ``inputs_embeds`` is returned unchanged. For a text batch,
        ``inputs_embeds`` (if given) is trimmed back to ``text_embed_dim``
        columns and the text tower is run.

        Raises:
            RuntimeError: If ``intermediate_tensors`` is provided
                (pipeline parallelism is unsupported).
            NotImplementedError: If ``inputs_embeds`` is narrower than
                ``text_embed_dim`` on the text path.
        """
        if intermediate_tensors is not None:
            raise RuntimeError("PP is not supported for this model")

        # Multimodal inputs: the image embeddings (already projected by
        # embed_multimodal) are expected in inputs_embeds and returned as-is.
        if not self._is_text_input:
            return inputs_embeds

        # Only trim when embeddings were actually supplied; with
        # inputs_embeds=None the text tower embeds input_ids itself
        # (get_text_features accepts inputs_embeds=None).
        if inputs_embeds is not None:
            # NOTE: inputs_embeds in model runner has size text_config.projection_dim
            # (instead of text_config.hidden_size) to accommodate image embeddings
            hidden_size = self.text_embed_dim
            current_size = inputs_embeds.shape[1]
            if current_size > hidden_size:
                inputs_embeds = inputs_embeds[:, :hidden_size]
            elif current_size < hidden_size:
                # No need to handle this case for now
                raise NotImplementedError(
                    f"inputs_embeds width ({current_size}) is smaller than "
                    f"the text hidden size ({hidden_size})"
                )

        return self.get_text_features(input_ids, positions, inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights via ``AutoWeightsLoader``.

        Names containing ``.position_ids`` are skipped. Unexpected names under
        ``logit_scale.`` are tolerated; this class defines no such parameter.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (presumably the
            set of loaded parameter names — TODO confirm).
        """
        loader = AutoWeightsLoader(
            self,
            skip_substrs=[".position_ids"],
            ignore_unexpected_prefixes=["logit_scale."],
        )

        return loader.load_weights(weights)