clip.py 33.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Callable, Iterable, Mapping, Sequence
4
from functools import cached_property
5
from typing import Annotated, Literal
6
7
8

import torch
import torch.nn as nn
9
10
11
12
13
14
15
from transformers import (
    BatchFeature,
    CLIPConfig,
    CLIPProcessor,
    CLIPTextConfig,
    CLIPVisionConfig,
)
16

17
from vllm.config import VllmConfig
18
from vllm.config.multimodal import BaseDummyOptions
19
from vllm.distributed import divide, get_tensor_model_parallel_world_size
20
from vllm.model_executor.layers.activation import get_act_fn
21
from vllm.model_executor.layers.attention import Attention, MMEncoderAttention
22
from vllm.model_executor.layers.conv import Conv2dLayer
23
24
25
26
27
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
28
from vllm.model_executor.layers.pooler import DispatchPooler
29
from vllm.model_executor.layers.quantization import QuantizationConfig
30
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.model_executor.models.interfaces import SupportsQuant
33
from vllm.multimodal import MULTIMODAL_REGISTRY
34
35
36
37
38
39
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
)
40
41
42
43
44
45
from vllm.multimodal.parse import (
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
    MultiModalUUIDItems,
)
46
from vllm.multimodal.processing import (
47
    BaseDummyInputsBuilder,
48
49
50
51
52
53
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptReplacement,
    PromptUpdate,
)
54
55
56
57
58
59
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, maybe_prefix
60
61
62
63
64
from .vision import (
    VisionEncoderInfo,
    VisionFeatureSelectStrategy,
    VisionFeatureSelectStrategyStr,
    get_num_selected_vision_tokens,
65
    is_vit_use_data_parallel,
66
67
    resolve_visual_encoder_outputs,
)
68

69

70
71
72
73
74
75
76
77
class CLIPImagePixelInputs(TensorSchema):
    """
    Schema describing a batch of image pixel values.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Tag identifying this input variant; always the literal "pixel_values".
    type: Literal["pixel_values"]
    # Pixel tensor declared with shape (bn, 3, h, w).
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


83
84
85
86
87
88
89
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
    """Size and token-count helpers derived from `self.vision_config`."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the token count for one image.

        The result does not depend on `image_width` / `image_height`; it is
        computed from the configured patch grid only.
        """
        # One token per patch of the square grid, plus one extra token
        # (presumably the class token -- confirm against the embeddings).
        grid_length = self.get_patch_grid_length()
        return grid_length**2 + 1

    def get_image_size(self) -> int:
        """Return the configured image size."""
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        """Return the configured patch size."""
        return self.vision_config.patch_size

    def get_patch_grid_length(self) -> int:
        """Return how many patches fit along one side of the image."""
        patches_per_side, leftover = divmod(
            self.get_image_size(), self.get_patch_size()
        )
        # The image size is expected to be an exact multiple of the patch size.
        assert leftover == 0
        return patches_per_side
102
103


104
105
106
107
108
109
110
111
112
113
114
115
116
# Lookup table from a sequence pooling type name to the vision feature
# selection strategy used for it. Pooling types missing from this table are
# rejected by `_get_vision_feature_select_strategy`.
_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = {
    "MEAN": "full",
    "ALL": "full",
    "CLS": "class",
    # This lets us use the same pooling type for both text and image
    "LAST": "class",
}


def _get_vision_feature_select_strategy(pooling_type: str):
    """Translate a pooling type into its vision feature selection strategy.

    Args:
        pooling_type: Key to look up in `_POOLING_TYPE_TO_STRATEGY`.

    Returns:
        The strategy string registered for `pooling_type`.

    Raises:
        ValueError: If no strategy is registered for `pooling_type`.
    """
    strategy = _POOLING_TYPE_TO_STRATEGY.get(pooling_type)
    if strategy is None:
        raise ValueError(
            f"No feature selection strategy is defined for "
            f"pooling_type: {pooling_type!r}"
        )
    return strategy
121
122
123
124
125
126
127
128
129
130
131
132


class CLIPProcessingInfo(BaseProcessingInfo):
    """Accessors for the CLIP config/processor and image token counts."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as `CLIPConfig`."""
        return self.ctx.get_hf_config(CLIPConfig)

    def get_vision_encoder_info(self):
        """Build a `CLIPEncoderInfo` from the HF config."""
        hf_config = self.get_hf_config()
        return CLIPEncoderInfo(hf_config)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor, forwarding `kwargs` to the context."""
        return self.ctx.get_hf_processor(CLIPProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits: one image."""
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many image tokens remain after feature selection.

        The selection strategy is derived from the pooler config's
        `seq_pooling_type`.
        """
        encoder_info = self.get_vision_encoder_info()

        pooler_config = self.ctx.model_config.pooler_config
        assert pooler_config is not None

        encoder_token_count = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = _get_vision_feature_select_strategy(
            pooler_config.seq_pooling_type
        )
        return get_num_selected_vision_tokens(encoder_token_count, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` using the encoder's image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for the size with the most features."""
        max_width, max_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=max_width,
            image_height=max_height,
        )


class CLIPDummyInputsBuilder(BaseDummyInputsBuilder[CLIPProcessingInfo]):
    """Builds dummy text and image inputs for CLIP."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always empty."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy images sized by `get_image_size_with_most_features`.

        The number of images comes from `mm_counts["image"]` (default 0) and
        any `mm_options["image"]` entry is passed through as overrides.
        """
        image_count = mm_counts.get("image", 0)
        max_width, max_height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(
            width=max_width,
            height=max_height,
            num_images=image_count,
            overrides=mm_options.get("image"),
        )
        return {"image": dummy_images}


class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
    """Multimodal processor for CLIP.

    When image items are supplied, the text prompt must be empty (or, for a
    token prompt, consist only of special tokens); the prompt is then filled
    with dummy image tokens via the prompt updates below.
    """

    @cached_property
    def image_token_id(self) -> int:
        """Token id used as the dummy image placeholder."""
        tokenizer = self.info.get_tokenizer()
        placeholder_id = 0

        # The placeholder must not coincide with one of the special tokens.
        assert placeholder_id not in tokenizer.all_special_ids

        return placeholder_id

    def apply(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        mm_uuid_items: MultiModalUUIDItems | None = None,
        hf_processor_mm_kwargs: Mapping[str, object] | None = None,
        tokenization_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalInputs:
        """Validate the prompt against the image items, then defer to the base.

        Raises:
            ValueError: If image items are given together with a non-empty
                text prompt, or with a token prompt containing any
                non-special token.
        """
        if mm_items:
            if isinstance(prompt, str):
                if prompt:
                    raise ValueError(
                        "CLIP accepts text-only or image-only inputs, not both! "
                        "You must pass an image with an empty text prompt."
                    )
            else:
                special_ids = self.info.get_tokenizer().all_special_ids
                if any(tok not in special_ids for tok in prompt):
                    raise ValueError(
                        "CLIP accepts text-only or image-only inputs, not both! "
                        "You must pass an image with an empty token prompt."
                    )
                # A prompt made only of special tokens is treated as empty.
                prompt = []

            # For multi-modal data, the prompt after processing should
            # only contain the dummy image tokens
            merged_kwargs = dict(tokenization_kwargs or {})
            merged_kwargs["add_special_tokens"] = False
            tokenization_kwargs = merged_kwargs

        return super().apply(
            prompt=prompt,
            mm_items=mm_items,
            mm_uuid_items=mm_uuid_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Always report that the HF processor does not apply the updates."""
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` as a field batched over the image modality."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return one replacement that places image tokens at the prompt start."""
        placeholder_id = self.image_token_id

        def build_placeholder_tokens(item_idx: int):
            # Size of the item_idx-th image decides how many tokens it needs.
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)

            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [placeholder_id] * token_count

        image_replacement = PromptReplacement(
            modality="image",
            target=PromptIndexTargets.start(),
            replacement=build_placeholder_tokens,
        )
        return [image_replacement]


# Adapted from: https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/models/clip/modeling_clip.py
class CLIPTextEmbeddings(nn.Module):
    """Sum of token embeddings and position embeddings for the text model."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()

        embed_dim = config.hidden_size

        self.token_embedding = VocabParallelEmbedding(config.vocab_size, embed_dim)
        self.position_embedding = VocabParallelEmbedding(
            config.max_position_embeddings, embed_dim
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return token embeddings plus position embeddings.

        Args:
            input_ids: Token ids; only used when `inputs_embeds` is None.
            position_ids: Positions looked up in `position_embedding`.
            inputs_embeds: Precomputed token embeddings; when given,
                `input_ids` is ignored.

        Raises:
            ValueError: If both `input_ids` and `inputs_embeds` are None.
        """
        if inputs_embeds is None:
            if input_ids is None:
                # The message names the actual parameter (`inputs_embeds`).
                raise ValueError(
                    "Either `input_ids` or `inputs_embeds` must be provided"
                )

            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings


320
321
322
323
324
325
326
class CLIPVisionEmbeddings(nn.Module):
    """Patch, class and position embeddings for the vision model."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # The image must tile exactly into patches.
        assert self.image_size % self.patch_size == 0

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Kernel size equals stride, so each patch is projected exactly once.
        self.patch_embedding = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # One extra position for the class embedding.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        position_ids = torch.arange(self.num_positions).expand((1, -1))
        self.register_buffer("position_ids", position_ids, persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed `pixel_values` into class + patch tokens with positions added."""
        num_images = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype

        # shape = [*, width, grid, grid]
        patch_features = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        # Collapse the trailing grid dims, then move that axis to dim 1.
        patch_tokens = patch_features.flatten(2).transpose(1, 2)

        class_tokens = self.class_embedding.expand(num_images, 1, -1)
        tokens = torch.cat([class_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)


363
class CLIPAttention(nn.Module):
    """Self-attention block with a fused QKV projection.

    The attention backend is supplied by the caller through `attn_cls`.
    """

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        """Build the projections and the attention backend.

        Args:
            config: Supplies `hidden_size` and `num_attention_heads`.
            quant_config: Forwarded to the linear layers.
            prefix: Name prefix for the submodules.
            attn_cls: Attention class to instantiate.

        Raises:
            ValueError: If `hidden_size` is not divisible by the head count.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5

        use_data_parallel = is_vit_use_data_parallel()
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        # Tensor parallelism is treated as size 1 when data parallel is used.
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        # Both supported attention classes are constructed with the same
        # arguments, so no per-class branching is needed.
        self.attn = attn_cls(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel

        Returns:
            A tuple `(attn_output, None)`; the second element is always None.
        """

        qkv_states, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into equal Q, K and V chunks.
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
436
437


438
class CLIPMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        use_data_parallel = is_vit_use_data_parallel()
        self.activation_fn = get_act_fn(config.hidden_act)

        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

        # Expand from the hidden size to the intermediate size...
        self.fc1 = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        # ...and project back down to the hidden size.
        self.fc2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2 in sequence."""
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class CLIPEncoderLayer(nn.Module):
    """Encoder layer: attention and MLP, each with a norm before it and a
    residual connection around it."""

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_cls=attn_cls,
        )
        self.layer_norm1 = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.mlp = CLIPMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(hidden_size, eps=norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the attention sub-block, then the MLP sub-block."""
        # Attention sub-block; the second element of the result is discarded.
        attn_output, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states)
        )
        after_attention = hidden_states + attn_output

        # MLP sub-block.
        mlp_output = self.mlp(self.layer_norm2(after_attention))
        return after_attention + mlp_output


class CLIPEncoder(nn.Module):
    """
    Stack of [`CLIPEncoderLayer`] self-attention layers.

    The stack has `config.num_hidden_layers` layers unless
    `num_hidden_layers_override` is given.

    Args:
        config: CLIPConfig
    """

    def __init__(
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        num_hidden_layers_override: int | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.config = config

        # An explicit override takes precedence over the configured depth.
        num_hidden_layers = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )

        self.layers = nn.ModuleList(
            CLIPEncoderLayer(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.layers.{layer_idx}",
                attn_cls=attn_cls,
            )
            for layer_idx in range(num_hidden_layers)
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Run the layers in order.

        Returns:
            The last layer's output, or, when `return_all_hidden_states` is
            set, a list holding the input followed by every layer's output.
        """
        if not return_all_hidden_states:
            current = inputs_embeds
            for layer in self.layers:
                current = layer(current)
            return current

        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        collected = [inputs_embeds]
        for layer in self.layers:
            collected.append(layer(collected[-1]))
        return collected


574
575
576
577
class CLIPTextTransformer(nn.Module):
    """Text model: embeddings, encoder stack and a final layer norm."""

    def __init__(
        self,
        config: CLIPTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_cls=Attention,
        )
        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings only (no position embeddings added)."""
        return self.embeddings.token_embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed the inputs, run the encoder and apply the final layer norm."""
        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        encoded = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=False,
        )
        return self.final_layer_norm(encoded)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v shards into `qkv_proj`.

        Returns:
            The parameter names (after any q/k/v renaming) that were loaded.
        """
        qkv_shards = (
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # First shard entry whose checkpoint name occurs in this weight.
            shard_match = next(
                (entry for entry in qkv_shards if entry[1] in name), None
            )
            if shard_match is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                fused_name, shard_name, shard_id = shard_match
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params


652
class CLIPVisionTransformer(nn.Module):
    """Vision model: embeddings, pre-norm, encoder stack, optional post-norm."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        """Build the vision transformer.

        Args:
            config: Vision config.
            quant_config: Forwarded to the encoder.
            num_hidden_layers_override: Build this many encoder layers
                instead of `config.num_hidden_layers`.
            require_post_norm: Whether to create `post_layernorm`. When None,
                it is created only if the full layer stack was built.
            prefix: Name prefix for the submodules.

        Raises:
            ValueError: If more layers were built than the config declares.
        """
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MMEncoderAttention,
        )

        num_hidden_layers = config.num_hidden_layers
        built_layers = len(self.encoder.layers)
        if built_layers > num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {built_layers} layers."
            )

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            require_post_norm = built_layers == num_hidden_layers

        self.post_layernorm = (
            nn.LayerNorm(embed_dim, eps=norm_eps) if require_post_norm else None
        )

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        first_param = next(self.parameters())
        return first_param.dtype

    @property
    def device(self):
        """Device of the first parameter."""
        first_param = next(self.parameters())
        return first_param.device

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Encode `pixel_values` and resolve the requested encoder outputs."""
        embedded = self.pre_layrnorm(self.embeddings(pixel_values))

        # Produces either the last layer output or all of the hidden states,
        # depending on if we have select_layers or not
        want_all_layers = select_layers is not None
        raw_outputs = self.encoder(
            inputs_embeds=embedded,
            return_all_hidden_states=want_all_layers,
        )

        # Handle post-norm (if applicable) and stacks feature layers if needed
        return resolve_visual_encoder_outputs(
            raw_outputs,
            self.post_layernorm,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v shards into `qkv_proj`.

        Weights for an absent `post_layernorm` and for encoder layers beyond
        the built stack are skipped.

        Returns:
            The parameter names (after any q/k/v renaming) that were loaded.
        """
        qkv_shards = (
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if self.post_layernorm is None and name.startswith("post_layernorm"):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("encoder.layers"):
                # Names look like "encoder.layers.<idx>...."; take <idx>.
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            # First shard entry whose checkpoint name occurs in this weight.
            shard_match = next(
                (entry for entry in qkv_shards if entry[1] in name), None
            )
            if shard_match is None:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                fused_name, shard_name, shard_id = shard_match
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params


class CLIPVisionModel(nn.Module):
    """Wrapper that holds a `CLIPVisionTransformer` as `vision_model`."""

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # All arguments are passed straight through to the transformer.
        tower_prefix = f"{prefix}.vision_model"
        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=tower_prefix,
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Delegate to the wrapped vision transformer."""
        tower = self.vision_model
        return tower(
            pixel_values,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )

    @property
    def dtype(self):
        """Dtype reported by the wrapped vision transformer."""
        tower = self.vision_model
        return tower.dtype

    @property
    def device(self):
        """Device reported by the wrapped vision transformer."""
        tower = self.vision_model
        return tower.device
811
812


813
# The text model's EOS token is assumed to be the LAST token of the sequence,
# hence the "LAST" sequence pooling default.
@default_pooling_type(seq_pooling_type="LAST")
@MULTIMODAL_REGISTRY.register_processor(
    CLIPMultiModalProcessor,
    info=CLIPProcessingInfo,
    dummy_inputs=CLIPDummyInputsBuilder,
)
class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
    """CLIP pooling model with a text tower and an image tower.

    Both towers are followed by a bias-free linear projection to
    `projection_dim`.
    """

    # Marks this model as a pooling model.
    is_pooling_model = True

    # Fused module name -> names of the separate projections it packs.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }
824

825
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for a multimodal item.

        Any modality whose name starts with ``"image"`` has no placeholder
        string, so ``None`` is returned. The index `i` is not used.

        Raises:
            ValueError: If `modality` is not an image modality.
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return None

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the text and vision towers, their projections and the pooler.

        Args:
            vllm_config: Engine configuration; the HF config, quantization
                config, multimodal config and pooler config are read from it.
            prefix: Name prefix of this module inside the parent model.
        """
        super().__init__()

        hf_config: CLIPConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.multimodal_config = vllm_config.model_config.multimodal_config

        text_config = hf_config.text_config
        vision_config = hf_config.vision_config

        # Widths of the two towers and of the projection space they map into.
        self.projection_dim = hf_config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        # Text tower followed by its bias-free projection.
        text_prefix = maybe_prefix(prefix, "text_model")
        with self._mark_language_model(vllm_config):
            self.text_model = CLIPTextTransformer(
                text_config,
                quant_config=quant_config,
                prefix=text_prefix,
            )
            self.text_projection = nn.Linear(
                self.text_embed_dim, self.projection_dim, bias=False
            )

        # Image tower followed by its bias-free projection.
        vision_prefix = maybe_prefix(prefix, "vision_model")
        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = CLIPVisionTransformer(
                vision_config,
                quant_config=quant_config,
                prefix=vision_prefix,
            )
            self.visual_projection = nn.Linear(
                self.vision_embed_dim, self.projection_dim, bias=False
            )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler_config = pooler_config
        self.pooler = DispatchPooler.for_embedding(pooler_config)

        # Assumes that self.forward is called after self.embed_input_ids,
        # which updates this flag for the current input.
        self._is_text_input = True

    def get_text_features(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the text tower and project its output.

        The three arguments are forwarded by keyword to `self.text_model`;
        its output is then passed through `self.text_projection`.
        """
        text_pooled = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        return self.text_projection(text_pooled)

    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        """Run the vision tower and project its output.

        When `feature_select_strategy` is not given, it is derived from the
        pooler config's `seq_pooling_type`.
        """
        strategy = feature_select_strategy
        if strategy is None:
            strategy = _get_vision_feature_select_strategy(
                self.pooler_config.seq_pooling_type
            )

        vision_pooled = self.vision_model(
            pixel_values=pixel_values,
            select_layers=None,
            feature_select_strategy=strategy,
        )
        return self.visual_projection(vision_pooled)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> CLIPImagePixelInputs | None:
        """Wrap the ``pixel_values`` keyword argument in `CLIPImagePixelInputs`.

        Returns ``None`` when ``pixel_values`` is absent or ``None``.
        Otherwise both the ``h`` and ``w`` bindings are set to the vision
        config's `image_size`.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is not None:
            # A single size is used for both height and width.
            image_size = self.config.vision_config.image_size
            return CLIPImagePixelInputs(
                type="pixel_values",
                data=pixel_values,
                resolve_bindings={"h": image_size, "w": image_size},
            )

        return None

    def _process_image_inputs(self, inputs: CLIPImagePixelInputs) -> torch.Tensor:
        """Compute projected image features for the ``"data"`` entry of `inputs`."""
        return self.get_image_features(inputs["data"])

936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
    def _embed_text_input_ids(
        self,
        input_ids: torch.Tensor,
        embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
        *,
        is_multimodal: torch.Tensor | None,
        handle_oov_mm_token: bool,
    ) -> torch.Tensor:
        """Embed text tokens and widen the result to `projection_dim` columns.

        The parent implementation produces the embeddings. If they are
        narrower than `self.projection_dim`, uninitialized columns are
        appended along dim 1 to reach that width.

        Raises:
            NotImplementedError: If the embeddings are wider than
                `self.projection_dim`.
        """
        text_embeds = super()._embed_text_input_ids(
            input_ids,
            embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        # NOTE: inputs_embeds in model runner has size projection_dim
        # (instead of text_config.hidden_size) to accommodate image embeddings
        target_width = self.projection_dim
        current_width = text_embeds.shape[1]

        if current_width > target_width:
            # No need to handle this case for now
            raise NotImplementedError

        if current_width == target_width:
            return text_embeds

        # The padding is left uninitialized (new_empty).
        padding = text_embeds.new_empty(
            text_embeds.shape[0],
            target_width - current_width,
        )
        return torch.cat([text_embeds, padding], dim=1)

971
    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed `input_ids` and record whether the input is text-only.

        `self._is_text_input` is set to True when no multimodal embeddings
        are supplied (``None`` or empty). `forward` reads that flag to decide
        how to treat its inputs.
        """
        has_no_mm_embeds = (
            multimodal_embeddings is None or len(multimodal_embeddings) == 0
        )
        self._is_text_input = has_no_mm_embeds

        # The two calls are kept separate to satisfy the type checker for
        # each overload of the parent method.
        if multimodal_embeddings is not None and is_multimodal is not None:
            return super().embed_input_ids(
                input_ids,
                multimodal_embeddings=multimodal_embeddings,
                is_multimodal=is_multimodal,
                handle_oov_mm_token=handle_oov_mm_token,
            )

        return super().embed_input_ids(input_ids)

994
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for the image input found in `kwargs`.

        An empty list is returned when `kwargs` carries no image input.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        return [] if image_input is None else self._process_image_inputs(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Produce the model output for a text or a multimodal input.

        If the last `embed_input_ids` call saw multimodal embeddings,
        `inputs_embeds` is returned unchanged. Otherwise the text tower and
        its projection are run on the input.

        Args:
            input_ids: Token ids, forwarded to the text tower.
            positions: Position ids, forwarded to the text tower.
            intermediate_tensors: Must be ``None``; pipeline parallelism is
                not supported.
            inputs_embeds: Optional precomputed embeddings. On the text path
                they are truncated to `text_embed_dim` columns when wider.
            **kwargs: Ignored.

        Raises:
            RuntimeError: If `intermediate_tensors` is provided.
            NotImplementedError: If `inputs_embeds` is narrower than
                `text_embed_dim`.
        """
        if intermediate_tensors is not None:
            raise RuntimeError("PP is not supported for this model")

        # Multimodal inputs
        if not self._is_text_input:
            return inputs_embeds

        # `inputs_embeds` is optional in the signature; only inspect its
        # width when it is present, instead of failing on `None.shape`.
        if inputs_embeds is not None:
            # NOTE: inputs_embeds in model runner has size projection_dim
            # (instead of text_config.hidden_size) to accommodate image
            # embeddings, so drop the extra columns before the text tower.
            hidden_size = self.text_embed_dim
            embed_width = inputs_embeds.shape[1]
            if embed_width < hidden_size:
                # No need to handle this case for now
                raise NotImplementedError
            if embed_width > hidden_size:
                inputs_embeds = inputs_embeds[:, :hidden_size]

        return self.get_text_features(input_ids, positions, inputs_embeds)
1027
1028
1029
1030
1031
1032
1033
1034
1035

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights through `AutoWeightsLoader`.

        The loader is configured with ``".position_ids"`` as a skip substring
        and ``"logit_scale."`` as an ignorable unexpected prefix; its return
        value is passed back to the caller.
        """
        return AutoWeightsLoader(
            self,
            skip_substrs=[".position_ids"],
            ignore_unexpected_prefixes=["logit_scale."],
        ).load_weights(weights)