clip.py 33.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable, Iterable, Mapping, Sequence
4
from functools import cached_property
5
from typing import Annotated, Literal
6
7
8

import torch
import torch.nn as nn
9
10
11
12
13
14
15
from transformers import (
    BatchFeature,
    CLIPConfig,
    CLIPProcessor,
    CLIPTextConfig,
    CLIPVisionConfig,
)
16

17
from vllm.config import VllmConfig
18
from vllm.config.multimodal import BaseDummyOptions
19
from vllm.distributed import divide, get_tensor_model_parallel_world_size
20
from vllm.model_executor.layers.activation import get_act_fn
21
from vllm.model_executor.layers.attention import Attention, MMEncoderAttention
22
from vllm.model_executor.layers.conv import Conv2dLayer
23
24
25
26
27
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
28
from vllm.model_executor.layers.pooler import DispatchPooler
29
from vllm.model_executor.layers.quantization import QuantizationConfig
30
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.model_executor.models.interfaces import SupportsQuant
33
from vllm.multimodal import MULTIMODAL_REGISTRY
34
35
36
37
38
39
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
)
40
41
42
43
44
from vllm.multimodal.parse import (
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
45
from vllm.multimodal.processing import (
46
    BaseDummyInputsBuilder,
47
48
    BaseMultiModalProcessor,
    BaseProcessingInfo,
49
    ProcessorInputs,
50
51
52
    PromptIndexTargets,
    PromptReplacement,
    PromptUpdate,
53
    TimingContext,
54
)
55
56
57
58
59
60
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, maybe_prefix
61
62
63
64
65
from .vision import (
    VisionEncoderInfo,
    VisionFeatureSelectStrategy,
    VisionFeatureSelectStrategyStr,
    get_num_selected_vision_tokens,
66
    is_vit_use_data_parallel,
67
68
    resolve_visual_encoder_outputs,
)
69

70

71
72
73
74
75
76
77
78
class CLIPImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the CLIP image tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Literal tag identifying this input variant.
    type: Literal["pixel_values"]
    # Image tensor; the channel dimension is fixed to 3 by the schema.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


84
85
86
87
88
89
90
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
    """Size and token-count helpers for the CLIP vision encoder."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        # The requested width/height are not consulted: the count depends only
        # on the configured image/patch sizes. The "+ 1" accounts for the class
        # embedding that CLIPVisionEmbeddings prepends to the patch sequence.
        side = self.get_patch_grid_length()
        return side * side + 1

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return the number of patches along one side of the image."""
        grid, leftover = divmod(self.get_image_size(), self.get_patch_size())
        assert leftover == 0
        return grid
103
104


105
106
107
108
109
110
111
112
113
114
115
116
117
# Maps a pooler `seq_pooling_type` onto the vision feature-selection strategy
# used when counting and selecting image tokens.
_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = dict(
    MEAN="full",
    ALL="full",
    CLS="class",
    # This lets us use the same pooling type for both text and image
    LAST="class",
)


def _get_vision_feature_select_strategy(pooling_type: str):
    """Return the vision feature-selection strategy for ``pooling_type``.

    Raises:
        ValueError: if ``pooling_type`` has no entry in
            ``_POOLING_TYPE_TO_STRATEGY``.
    """
    strategy = _POOLING_TYPE_TO_STRATEGY.get(pooling_type)
    if strategy is None:
        raise ValueError(
            f"No feature selection strategy is defined for "
            f"pooling_type: {pooling_type!r}"
        )
    return strategy
122
123
124
125
126
127
128
129
130
131
132
133


class CLIPProcessingInfo(BaseProcessingInfo):
    """CLIP processing metadata: HF config/processor access and token counts."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(CLIPConfig)

    def get_vision_encoder_info(self):
        return CLIPEncoderInfo(self.get_hf_config())

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(CLIPProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Only a single image is supported.
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens kept after feature selection.

        The selection strategy is derived from the pooler config's
        ``seq_pooling_type``, so a pooler config must be present.
        """
        encoder_info = self.get_vision_encoder_info()

        pooler_config = self.ctx.model_config.pooler_config
        assert pooler_config is not None

        raw_token_count = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = _get_vision_feature_select_strategy(pooler_config.seq_pooling_type)
        return get_num_selected_vision_tokens(raw_token_count, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square image at the encoder's configured resolution.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width, image_height=height)


class CLIPDummyInputsBuilder(BaseDummyInputsBuilder[CLIPProcessingInfo]):
    """Builds dummy CLIP inputs: an empty text prompt plus max-size images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # Image inputs must come with an empty prompt (see CLIPMultiModalProcessor).
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
    """Multimodal processor for CLIP.

    A request carries either text or an image, never both. When an image is
    present, the prompt must be empty (or contain only special tokens). It is
    then filled with placeholder token ids inserted at the start of the prompt.
    """

    @cached_property
    def image_token_id(self) -> int:
        # Token id 0 is reused as the image placeholder; it must not collide
        # with any of the tokenizer's special token ids.
        special_ids = self.info.get_tokenizer().all_special_ids
        placeholder_id = 0
        assert placeholder_id not in special_ids
        return placeholder_id

    def apply(
        self,
        inputs: ProcessorInputs,
        timing_ctx: TimingContext,
    ) -> MultiModalInputs:
        # Text-only requests need no special handling.
        if not inputs.mm_data_items:
            return super().apply(inputs, timing_ctx)

        prompt = inputs.prompt
        if isinstance(prompt, str):
            if prompt:
                raise ValueError(
                    "CLIP accepts text-only or image-only inputs, not both! "
                    "You must pass an image with an empty text prompt."
                )
        else:
            special_ids = self.info.get_tokenizer().all_special_ids
            if not all(token in special_ids for token in prompt):
                raise ValueError(
                    "CLIP accepts text-only or image-only inputs, not both! "
                    "You must pass an image with an empty token prompt."
                )
            # Only special tokens were supplied: discard them.
            inputs.prompt = []

        # For multi-modal data, the prompt after processing should
        # only contain the dummy image tokens
        inputs.tokenization_kwargs = {
            **inputs.tokenization_kwargs,
            "add_special_tokens": False,
        }
        return super().apply(inputs, timing_ctx)

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        # Always report that the HF processor did not apply prompt updates,
        # so the updates from _get_prompt_updates are used instead.
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # pixel_values is batched along the "image" modality.
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        placeholder_id = self.image_token_id

        def get_replacement(item_idx: int):
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [placeholder_id] * token_count

        # The placeholder run is inserted at the very start of the prompt.
        replacement = PromptReplacement(
            modality="image",
            target=PromptIndexTargets.start(),
            replacement=get_replacement,
        )
        return [replacement]


# Adapted from: https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/models/clip/modeling_clip.py
class CLIPTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for CLIP text."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()

        embed_dim = config.hidden_size

        self.token_embedding = VocabParallelEmbedding(config.vocab_size, embed_dim)
        # Position embeddings use the same layer type, indexed by position id
        # (up to config.max_position_embeddings) rather than token id.
        self.position_embedding = VocabParallelEmbedding(
            config.max_position_embeddings, embed_dim
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return token embeddings (or ``inputs_embeds``) plus position embeddings.

        Args:
            input_ids: Token ids; only used when ``inputs_embeds`` is None.
            position_ids: Ids looked up in ``position_embedding``.
            inputs_embeds: Optional precomputed token embeddings.

        Raises:
            ValueError: If both ``input_ids`` and ``inputs_embeds`` are None.
        """
        if inputs_embeds is None:
            if input_ids is None:
                # Name the real keyword argument (`inputs_embeds`) in the message.
                raise ValueError(
                    "Either `input_ids` or `inputs_embeds` must be provided"
                )
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings


312
313
314
315
316
317
318
class CLIPVisionEmbeddings(nn.Module):
    """Patch embeddings with a prepended class embedding plus position embeddings."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        assert self.image_size % self.patch_size == 0

        # Learned vector prepended to every image's patch sequence.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # kernel_size == stride == patch_size: one output cell per patch.
        self.patch_embedding = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        grid = self.image_size // self.patch_size
        self.num_patches = grid * grid
        # One extra position for the class embedding.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        weight_dtype = self.patch_embedding.weight.dtype
        # shape = [*, width, grid, grid], flattened to [*, grid * grid, width]
        patches = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        patches = patches.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(pixel_values.shape[0], 1, -1)
        tokens = torch.cat([cls_tokens, patches], dim=1)
        return tokens + self.position_embedding(self.position_ids)


355
class CLIPAttention(nn.Module):
    """Multi-head self-attention for CLIP encoder layers.

    Q/K/V are produced by one fused ``QKVParallelLinear``. The attention
    kernel is built from ``attn_cls``: in this file the text tower passes
    ``Attention`` and the vision tower passes ``MMEncoderAttention``.
    """

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5

        # When the ViT runs data-parallel, tensor-parallel sharding of the
        # projections is disabled and this rank keeps all heads (tp_size == 1).
        use_data_parallel = is_vit_use_data_parallel()

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        # Heads computed locally after tensor-parallel partitioning.
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        # Both supported backends take identical constructor arguments, so a
        # single call replaces the former (duplicated) branch on attn_cls.
        self.attn = attn_cls(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel

        Returns ``(attn_output, None)``; the second slot is always ``None``
        and is discarded by CLIPEncoderLayer, which unpacks two values.
        """
        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)
        return attn_output, None
428
429


430
class CLIPMLP(nn.Module):
    """Feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        tp_disabled = is_vit_use_data_parallel()
        self.activation_fn = get_act_fn(config.hidden_act)

        # Column-parallel up-projection then row-parallel down-projection;
        # sharding is switched off when the ViT runs data-parallel.
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=tp_disabled,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=tp_disabled,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class CLIPEncoderLayer(nn.Module):
    """Pre-norm block: LN1 -> self-attention -> residual, LN2 -> MLP -> residual."""

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        norm_eps = config.layer_norm_eps
        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_cls=attn_cls,
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.mlp = CLIPMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.self_attn(hidden_states=self.layer_norm1(hidden_states))
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + mlp_out


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: The tower's sub-config (`CLIPTextConfig` or
            `CLIPVisionConfig`), not the top-level `CLIPConfig`.
        quant_config: Optional quantization config passed to every layer.
        num_hidden_layers_override: When set, build this many layers instead
            of `config.num_hidden_layers`.
        prefix: Weight-name prefix; layer `i` uses `{prefix}.layers.{i}`.
        attn_cls: Attention backend class forwarded to each layer.
    """

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        num_hidden_layers_override: int | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.config = config

        num_hidden_layers = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )

        self.layers = nn.ModuleList(
            [
                CLIPEncoderLayer(
                    config=config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    attn_cls=attn_cls,
                )
                for layer_idx in range(num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Run every layer over ``inputs_embeds``.

        Returns the last layer's output, or, when ``return_all_hidden_states``
        is true, a list of the input embeddings followed by each layer's output
        (``len(self.layers) + 1`` entries).
        """
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)

        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states


566
567
568
569
class CLIPTextTransformer(nn.Module):
    """CLIP text tower: embeddings -> encoder -> final LayerNorm."""

    def __init__(
        self,
        config: CLIPTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        hidden_size = config.hidden_size

        self.embeddings = CLIPTextEmbeddings(config)

        # Text layers use the generic `Attention` backend.
        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_cls=Attention,
        )

        self.final_layer_norm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Token embeddings only; position embeddings are added by
        # self.embeddings inside forward().
        token_embedding = self.embeddings.token_embedding
        return token_embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        encoded = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=False,
        )
        return self.final_layer_norm(encoded)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, fusing separate q/k/v weights into qkv_proj.

        Returns the set of parameter names that received a tensor.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for ckpt_name, tensor in weights:
            # First mapping entry whose shard name occurs in the checkpoint key.
            shard = next(
                (
                    (fused, part, shard_id)
                    for fused, part, shard_id in stacked_params_mapping
                    if part in ckpt_name
                ),
                None,
            )
            if shard is None:
                param = params_dict[ckpt_name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, tensor)
            else:
                fused, part, shard_id = shard
                ckpt_name = ckpt_name.replace(part, fused)
                param = params_dict[ckpt_name]
                param.weight_loader(param, tensor, shard_id)
            loaded_params.add(ckpt_name)
        return loaded_params


644
class CLIPVisionTransformer(nn.Module):
    """CLIP vision tower: embeddings -> pre-LN -> encoder -> optional post-LN."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        hidden_size = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MMEncoderAttention,
        )

        full_depth = config.num_hidden_layers
        built_depth = len(self.encoder.layers)
        if built_depth > full_depth:
            raise ValueError(
                f"The original encoder only has {full_depth} "
                f"layers, but you requested {built_depth} layers."
            )

        # If possible, skip post_layernorm to conserve memory: by default it is
        # built only when the full encoder depth is kept.
        if require_post_norm is None:
            require_post_norm = built_depth == full_depth
        self.post_layernorm = (
            nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
            if require_post_norm
            else None
        )

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        embedded = self.pre_layrnorm(self.embeddings(pixel_values))

        # Request every hidden state only when specific layers were selected;
        # otherwise the encoder returns just the last layer's output.
        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=select_layers is not None,
        )

        # Apply post-norm (when present) and pick/stack the requested layers.
        return resolve_visual_encoder_outputs(
            encoder_outputs,
            self.post_layernorm,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load vision-tower tensors, fusing q/k/v and skipping unused ones.

        ``post_layernorm`` tensors are dropped when that norm was not built,
        and encoder layers at or beyond the built depth are dropped.
        Returns the set of parameter names that received a tensor.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.encoder.layers)

        for ckpt_name, tensor in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if self.post_layernorm is None and ckpt_name.startswith("post_layernorm"):
                continue

            # omit layers when num_hidden_layers_override is set
            if (
                ckpt_name.startswith("encoder.layers")
                and int(ckpt_name.split(".")[2]) >= layer_count
            ):
                continue

            # First mapping entry whose shard name occurs in the checkpoint key.
            shard = next(
                (
                    (fused, part, shard_id)
                    for fused, part, shard_id in stacked_params_mapping
                    if part in ckpt_name
                ),
                None,
            )
            if shard is None:
                param = params_dict[ckpt_name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, tensor)
            else:
                fused, part, shard_id = shard
                ckpt_name = ckpt_name.replace(part, fused)
                param = params_dict[ckpt_name]
                param.weight_loader(param, tensor, shard_id)
            loaded_params.add(ckpt_name)
        return loaded_params


class CLIPVisionModel(nn.Module):
    """Wrapper that nests CLIPVisionTransformer under the `vision_model` prefix."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    @property
    def dtype(self):
        return self.vision_model.dtype

    @property
    def device(self):
        return self.vision_model.device

    def forward(
        self,
        pixel_values: torch.Tensor,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        # Delegate straight to the wrapped transformer.
        tower = self.vision_model
        return tower(
            pixel_values,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )
803
804


805
# Assume EOS token corresponds to LAST token in text model
@default_pooling_type(seq_pooling_type="LAST")
@MULTIMODAL_REGISTRY.register_processor(
    CLIPMultiModalProcessor,
    info=CLIPProcessingInfo,
    dummy_inputs=CLIPDummyInputsBuilder,
)
class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
    """CLIP dual encoder served as a pooling (embedding) model.

    Text prompts are encoded by ``text_model`` and projected by
    ``text_projection``. Images are encoded by ``vision_model`` and projected
    by ``visual_projection``. Both projections output vectors of size
    ``config.projection_dim``.
    """

    is_pooling_model = True

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for item ``i`` of ``modality``.

        No placeholder string is used for image modalities, so ``None`` is
        returned for any modality name starting with ``"image"``.

        Raises:
            ValueError: If ``modality`` is not an image modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the text and vision towers, their projections and the pooler.

        Args:
            vllm_config: Engine config; ``model_config.hf_config`` must be a
                ``CLIPConfig`` and ``model_config.pooler_config`` must be set.
            prefix: Module-name prefix for quantization/weight naming.
        """
        super().__init__()

        config: CLIPConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        with self._mark_language_model(vllm_config):
            self.text_model = CLIPTextTransformer(
                text_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "text_model"),
            )
            self.text_projection = nn.Linear(
                self.text_embed_dim,
                self.projection_dim,
                bias=False,
            )

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = CLIPVisionTransformer(
                vision_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.visual_projection = nn.Linear(
                self.vision_embed_dim,
                self.projection_dim,
                bias=False,
            )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler_config = pooler_config

        self.pooler = DispatchPooler.for_embedding(pooler_config)

        # Assumes that self.forward is called after self.embed_input_ids,
        # which records whether the current batch is text-only.
        self._is_text_input = True

    def get_text_features(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode text and project it with ``text_projection``.

        Args:
            input_ids: Token ids, or ``None`` when ``inputs_embeds`` is given.
            position_ids: Positions passed to the text transformer.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            The text transformer output passed through ``text_projection``.
        """
        pooled_output = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        text_features = self.text_projection(pooled_output)

        return text_features

    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Encode images and project them with ``visual_projection``.

        Args:
            pixel_values: Image tensor fed to the vision transformer.
            feature_select_strategy: Feature selection strategy; when
                ``None`` it is derived from the pooler's
                ``seq_pooling_type``.

        Returns:
            The vision transformer output passed through
            ``visual_projection``.
        """
        if feature_select_strategy is None:
            feature_select_strategy = _get_vision_feature_select_strategy(
                self.pooler_config.seq_pooling_type
            )

        pooled_output = self.vision_model(
            pixel_values=pixel_values,
            select_layers=None,
            feature_select_strategy=feature_select_strategy,
        )

        image_features = self.visual_projection(pooled_output)

        return image_features

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> CLIPImagePixelInputs | None:
        """Wrap the ``pixel_values`` kwarg in a ``CLIPImagePixelInputs``.

        Returns ``None`` when no ``pixel_values`` were supplied. The height
        and width bindings are both set to ``vision_config.image_size``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        expected_h = expected_w = self.config.vision_config.image_size

        return CLIPImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": expected_h, "w": expected_w},
        )

    def _process_image_inputs(self, inputs: CLIPImagePixelInputs) -> torch.Tensor:
        """Compute projected image features for validated image inputs."""
        pixel_values = inputs["data"]

        return self.get_image_features(pixel_values)

    def _embed_text_input_ids(
        self,
        input_ids: torch.Tensor,
        embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
        *,
        is_multimodal: torch.Tensor | None,
    ) -> torch.Tensor:
        """Embed text tokens, right-padding them to ``projection_dim`` columns.

        Raises:
            NotImplementedError: If the text embeddings are wider than
                ``projection_dim``.
        """
        inputs_embeds = super()._embed_text_input_ids(
            input_ids,
            embed_input_ids,
            is_multimodal=is_multimodal,
        )

        # NOTE: inputs_embeds in model runner has size text_config.projection_dim
        # (instead of text_config.hidden_size) to accommodate image embeddings
        inputs_embeds_size = self.projection_dim
        width = inputs_embeds.shape[1]
        if width < inputs_embeds_size:
            # The padding is left uninitialized: forward() slices it off again
            # before the text transformer sees the embeddings.
            inputs_embeds = torch.cat(
                [
                    inputs_embeds,
                    inputs_embeds.new_empty(
                        inputs_embeds.shape[0],
                        inputs_embeds_size - width,
                    ),
                ],
                dim=1,
            )
        elif width > inputs_embeds_size:
            # No need to handle this case for now
            raise NotImplementedError(
                f"Text embedding width {width} exceeds "
                f"projection_dim {inputs_embeds_size}"
            )

        return inputs_embeds

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging ``multimodal_embeddings`` if given.

        Also records whether this batch is text-only. ``forward`` uses that
        flag to choose between the text and multimodal paths.
        """
        self._is_text_input = (
            multimodal_embeddings is None or len(multimodal_embeddings) == 0
        )

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return projected image embeddings, or ``[]`` when there is no image."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_inputs(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Produce embeddings for the current batch.

        For multimodal batches, ``inputs_embeds`` is returned unchanged
        (presumably the projected image embeddings merged by
        ``embed_input_ids``). For text batches, the text transformer and
        ``text_projection`` are applied.

        Raises:
            RuntimeError: If pipeline-parallel ``intermediate_tensors`` are
                given.
            NotImplementedError: If text ``inputs_embeds`` are narrower than
                the text hidden size.
        """
        if intermediate_tensors is not None:
            raise RuntimeError("PP is not supported for this model")

        # Multimodal inputs
        if not self._is_text_input:
            return inputs_embeds

        # Without precomputed embeddings the text transformer embeds
        # input_ids itself, so there is nothing to trim.
        if inputs_embeds is not None:
            # NOTE: inputs_embeds in model runner has size
            # text_config.projection_dim (instead of text_config.hidden_size)
            # to accommodate image embeddings
            hidden_size = self.text_embed_dim
            width = inputs_embeds.shape[1]
            if width > hidden_size:
                inputs_embeds = inputs_embeds[:, :hidden_size]
            elif width < hidden_size:
                # No need to handle this case for now
                raise NotImplementedError(
                    f"Text embedding width {width} is smaller than "
                    f"text hidden size {hidden_size}"
                )

        return self.get_text_features(input_ids, positions, inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights.

        Entries whose names contain ``.position_ids`` are skipped.
        Unexpected entries under the ``logit_scale.`` prefix are ignored.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (the loaded
            parameter names).
        """
        loader = AutoWeightsLoader(
            self,
            skip_substrs=[".position_ids"],
            ignore_unexpected_prefixes=["logit_scale."],
        )

        return loader.load_weights(weights)