chameleon.py 40.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
from itertools import islice
7
from typing import Annotated, Any, Literal
8
9

import torch
10
import torch.nn as nn
11
import torch.nn.functional as F
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    ChameleonConfig,
    ChameleonProcessor,
    ChameleonVQVAEConfig,
)
18

19
from vllm.attention import Attention
20
from vllm.config import CacheConfig, VllmConfig
21
from vllm.config.multimodal import BaseDummyOptions
22
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
23
from vllm.logger import init_logger
24
25
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
26
27
28
29
30
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
31
from vllm.model_executor.layers.logits_processor import LogitsProcessor
32
from vllm.model_executor.layers.quantization import QuantizationConfig
33
34
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
35
36
37
    ParallelLMHead,
    VocabParallelEmbedding,
)
38
from vllm.model_executor.model_loader.weight_utils import (
39
40
41
    default_weight_loader,
    row_parallel_weight_loader,
)
42
from vllm.model_executor.utils import set_weight_attrs
43
from vllm.multimodal import MULTIMODAL_REGISTRY
44
45
46
47
48
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
49
from vllm.multimodal.parse import MultiModalDataItems
50
51
52
53
54
55
56
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
57
from vllm.multimodal.profiling import BaseDummyInputsBuilder
58
from vllm.sequence import IntermediateTensors
59
from vllm.utils.tensor_schema import TensorSchema, TensorShape
60

61
62
63
64
65
66
67
68
69
70
71
72
from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
73

74
75
logger = init_logger(__name__)

76

77
78
79
80
81
82
83
84
class ChameleonImagePixelInputs(TensorSchema):
    """
    Schema for pixel-value image inputs to the Chameleon model.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    type: Literal["pixel_values"]
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
88
89


90
91
class ChameleonProcessingInfo(BaseProcessingInfo):
    """Processing metadata (config, HF processor, limits) for Chameleon."""

    def get_hf_config(self):
        """Return the model's HF config, typed as ``ChameleonConfig``."""
        return self.ctx.get_hf_config(ChameleonConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``ChameleonProcessor``, forwarding ``kwargs``."""
        return self.ctx.get_hf_processor(ChameleonProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """At most one image per prompt is supported."""
        return {"image": 1}

    def get_num_image_tokens(self) -> int:
        """Number of placeholder tokens one image expands to.

        Read from the HF processor's ``image_seq_length``.
        """
        processor = self.get_hf_processor()
        return processor.image_seq_length


105
class ChameleonDummyInputsBuilder(BaseDummyInputsBuilder[ChameleonProcessingInfo]):
    """Builds dummy text/image inputs used for memory profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one HF image placeholder token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy square images sized to the VQ-VAE input resolution.

        ``seq_len`` is accepted for interface compatibility and is not used
        here; optional per-image overrides come from ``mm_options["image"]``.
        """
        config = self.info.get_hf_config()

        # The VQ tokenizer operates on square images of this side length.
        width = height = config.vq_config.resolution
        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=width,
                height=height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


137
class ChameleonMultiModalProcessor(BaseMultiModalProcessor[ChameleonProcessingInfo]):
    """Multimodal processor that expands image placeholders for Chameleon."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or tokenize directly for text-only prompts.

        For text-only prompts, the sep token that the HF processor would add
        is appended manually so both paths produce the same token layout.
        """
        if not mm_data:
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Return ``prompt_tokens`` with the tokenizer's sep token appended."""
        # HF processor adds sep token for chat mode
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        sep_token_id = vocab[tokenizer.sep_token]  # type: ignore

        return prompt_tokens + [sep_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """``pixel_values`` is batched along the image dimension."""
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with ``<start> image_token * N <end>``.

        Only the inner image tokens are marked as embedding positions.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_start_id = vocab[processor.image_start_token]
        image_token_id = vocab[processor.image_token]
        image_end_id = vocab[processor.image_end_token]

        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptUpdateDetails.select_token_id(
                    [image_start_id] + image_tokens + [image_end_id],
                    embed_token_id=image_token_id,
                ),
            )
        ]

204
205
206
207

class ChameleonLayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes only the last dim but has a multi-dim affine.

    ``hidden_size`` is a shape tuple such as ``(num_heads, head_dim)``. The
    parent allocates ``weight``/``bias`` with that full shape. Normalization,
    however, runs over ``hidden_size[-1]`` only. The affine transform is then
    applied by broadcasting, so each leading index (e.g. each head) gets its
    own scale and shift.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        # Normalize over the last dim only; the affine params keep full shape.
        self.normalized_shape = (hidden_size[-1],)

        # Weight and bias load through row_parallel_weight_loader (presumably
        # so each tensor-parallel rank loads its own slice of heads — confirm).
        set_weight_attrs(self.weight, {"weight_loader": row_parallel_weight_loader})
        set_weight_attrs(self.bias, {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states):
        """Normalize the last dim, then apply the broadcast weight/bias."""
        # Honour the configured epsilon. ``nn.LayerNorm`` defaults ``eps`` to
        # 1e-5, so the default behaviour is unchanged.
        hidden_states = F.layer_norm(
            hidden_states, self.normalized_shape, None, None, eps=self.eps
        )
        hidden_states = hidden_states * self.weight + self.bias
        return hidden_states


# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiluAndMul, down.

    Args:
        hidden_size: Model hidden size (input and output width).
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config for the linear layers.
        bias: Whether the linear layers have a bias.
        prefix: Parameter-name prefix for the submodules.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):
    """Tensor-parallel self-attention with per-head QK LayerNorm and RoPE.

    Query and key are layer-normed per head (``ChameleonLayerNorm``) before
    the rotary embedding is applied. Heads are split across tensor-parallel
    ranks. KV heads are partitioned when there are at least ``tp_size`` of
    them and replicated otherwise.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 4096,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices in the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # QK norms are sized with the per-rank head counts.
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Layer-norm q and k per head, returning them flattened again."""
        # reshape for layernorm
        q = q.reshape(-1, self.num_heads, self.head_dim)
        k = k.reshape(-1, self.num_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        # Merge the (heads, head_dim) axes back into one feature axis.
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # QK norm is applied before the rotary embedding.
        q, k = self._apply_qk_norm(q, k)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class ChameleonDecoderLayer(nn.Module):
    """Pre-norm decoder layer: RMSNorm -> attention, RMSNorm -> MLP.

    Residual accumulation uses the RMSNorm call that takes a residual and
    returns the updated ``(hidden_states, residual)`` pair. The caller
    therefore carries ``residual`` between layers.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            # NOTE(review): this writes into config.rope_scaling in place (the
            # same dict object for every layer); the value is identical each time.
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run one layer; ``residual`` is None only for the first layer."""
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


class ChameleonSwinDecoderLayer(nn.Module):
    """Post-norm ("swin norm") decoder layer variant.

    The norms are applied to the attention/MLP *outputs* before each
    residual add, rather than to the sub-block inputs. Construction is
    identical to ``ChameleonDecoderLayer``; only ``forward`` differs.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            # NOTE(review): mutates config.rope_scaling in place (shared dict).
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one post-norm layer.

        The incoming ``residual`` argument is not read. The returned residual
        is the MLP sub-block's input, and the returned ``hidden_states``
        already include both residual additions.
        """
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, residual


502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):
    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        distances = (
519
520
521
522
523
524
525
526
527
            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2
            * torch.einsum(
                "bd,dn->bn",
                hidden_state_flattened,
                self.embedding.weight.transpose(0, 1),
            )
        )
528
529
530

        min_encoding_indices = torch.argmin(distances, dim=1)
        hidden_state_quant = self.embedding(min_encoding_indices).view(
531
532
            hidden_state.shape
        )
533
534

        # compute loss for embedding
535
536
537
        loss = torch.mean(
            (hidden_state_quant.detach() - hidden_state) ** 2
        ) + self.beta * torch.mean((hidden_state_quant - hidden_state.detach()) ** 2)
538
539

        # preserve gradients
540
        hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
541
542

        # reshape back to match original input shape
543
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
544
545
546
547
548
549
550
551

        return hidden_state_quant, loss, min_encoding_indices


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """Stride-2 3x3 conv downsample with asymmetric (right/bottom) padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, hidden_states: torch.Tensor):
        # no asymmetric padding in torch conv, must do it ourselves
        hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        hidden_states = self.conv(hidden_states)
        return hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
574
        self.out_channels = in_channels if out_channels is None else out_channels
575
576
        self.use_conv_shortcut = conv_shortcut

577
578
579
580
581
582
583
584
585
        self.norm1 = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = torch.nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
586
        self.dropout = torch.nn.Dropout(config.dropout)
587
588
589
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
590
591
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
592
593
594
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
595
            else:
596
597
598
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """Single-head spatial self-attention over all H*W positions.

    Input is GroupNorm-ed, then projected to q/k/v with 1x1 convs. Scaled
    dot-product attention runs across spatial positions, the result is
    projected with a 1x1 conv, and it is added back to the input.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        query_states = self.q(hidden_states)
        key_states = self.k(hidden_states)
        value_states = self.v(hidden_states)

        # compute attention
        batch_size, channels, height, width = query_states.shape
        # (B, C, H, W) -> (B, HW, C) for queries and (B, C, HW) for keys.
        query_states = query_states.reshape(
            batch_size, channels, height * width
        ).permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, height * width)
        attn_weights = torch.bmm(query_states, key_states)
        attn_weights = attn_weights * (int(channels) ** (-0.5))
        attn_weights = F.softmax(attn_weights, dim=2)

        # attend to values
        value_states = value_states.reshape(batch_size, channels, height * width)
        attn_weights = attn_weights.permute(0, 2, 1)
        attn_output = torch.bmm(value_states, attn_weights).reshape(
            batch_size, channels, height, width
        )

        attn_output = self.proj_out(attn_output)
        return residual + attn_output


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):
    """Convolutional encoder of the Chameleon VQ-VAE.

    Layout:

    - ``conv_in``.
    - One level per entry of ``config.channel_multiplier``. Each level has
      ``num_res_blocks`` resnet blocks. Attention blocks are added when the
      current resolution is in ``config.attn_resolutions`` and
      ``attn_type == "vanilla"``. A stride-2 downsample sits between levels.
    - A middle resnet/attention/resnet stack.
    - GroupNorm, then x * sigmoid(x), then ``conv_out``.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(
            in_channels, base_channels, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                block_in = block_out
                if (
                    config.attn_resolutions is not None
                    and curr_res in config.attn_resolutions
                    and config.attn_type == "vanilla"
                ):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = (
            ChameleonVQVAEEncoderAttnBlock(block_in)
            if config.attn_type == "vanilla"
            else nn.Identity()
        )
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        pixel_values = pixel_values.to(self.conv_in.weight.dtype)

        # downsampling
        # Only the most recent activation is ever consumed, so keep a single
        # running tensor instead of a list of every intermediate activation
        # (which kept them all alive until the end of the call).
        hidden_state = self.conv_in(pixel_values)
        for i_level in range(self.num_resolutions):
            down = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                hidden_state = down.block[i_block](hidden_state)
                if len(down.attn) > 0:
                    hidden_state = down.attn[i_block](hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_state = down.downsample(hidden_state)

        # middle
        last_hidden_state = self.mid.block_1(hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)

        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):
    """Frozen VQ-VAE image tokenizer: encoder, 1x1 quant conv, quantizer.

    ``post_quant_conv`` is constructed (presumably so checkpoint weights have
    a destination), but ``encode`` does not use it.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(
            config.embed_dim, config.latent_channels, 1
        )
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode images; returns ``(quantized, embedding_loss, indices)``."""
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quant, emb_loss, indices = self.quantize(hidden_states)
        return quant, emb_loss, indices


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    Translate between VQGAN codebook indices and BPE token ids.

    Image tokens in the BPE vocabulary are named ``IMGIMG<code>`` plus one
    trailing character, where ``<code>`` spells the codebook index using the
    letters ``A``-``J`` for the digits ``0``-``9``.
    """

    def __init__(self, vocab_map: dict[str, int]):
        # Token name -> BPE id.
        self.vocab_map = vocab_map
        # BPE id of the "<image>" placeholder token, or None if it is absent.
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        """BPE id -> token name (inverse of ``vocab_map``)."""
        inverse = {}
        for token_name, token_id in self.vocab_map.items():
            inverse[token_id] = token_name
        return inverse

    @cached_property
    def image_tokens(self):
        """Sorted list of BPE ids whose token name starts with ``IMGIMG``."""
        ids = [
            token_id
            for token_name, token_id in self.vocab_map.items()
            if token_name.startswith("IMGIMG")
        ]
        ids.sort()
        return ids

    @cached_property
    def bpe2img(self):
        """BPE id of each image token -> VQGAN codebook index."""
        letters_to_digits = str.maketrans(
            {chr(ord("A") + digit): str(digit) for digit in range(10)}
        )
        prefix_len = len("IMGIMG")
        return {
            bpe_id: int(
                self.val2name[bpe_id][prefix_len:-1].translate(letters_to_digits)
            )
            for bpe_id in self.image_tokens
        }

    @cached_property
    def img2bpe(self):
        """VQGAN codebook index -> BPE id (inverse of ``bpe2img``)."""
        forward_map = self.bpe2img
        return dict(zip(forward_map.values(), forward_map.keys()))

    @cached_property
    def bpe2img_search_tensors(self):
        """BPE ids and codebook indices, each sorted independently."""
        bpe_ids = torch.tensor(sorted(self.bpe2img))
        img_ids = torch.tensor(sorted(self.bpe2img.values()))
        return bpe_ids, img_ids

    @cached_property
    def img2bpe_mapping_tensor(self):
        """Dense ``torch.int`` lookup table with ``table[img_index] == bpe_id``.

        Indices that have no corresponding image token map to 0.
        """
        lookup = self.img2bpe
        table = torch.zeros(max(lookup) + 1, dtype=torch.int)
        table[torch.tensor(list(lookup), dtype=torch.long)] = torch.tensor(
            list(lookup.values()), dtype=torch.int
        )
        return table

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Map a tensor of codebook indices to BPE ids on the input's device."""
        target_device = img_batch.device
        # Index the lookup table with a CPU copy of the indices, then move
        # the result back to where the input lived.
        bpe_ids = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return bpe_ids.to(target_device)


class ChameleonModel(nn.Module):
    """Chameleon decoder stack together with its VQ-VAE image tokenizer.

    Under pipeline parallelism each rank runs only the decoder layers in
    ``[start_layer, end_layer)`` as returned by ``make_layers``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hf_config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
            hf_config.vocabulary_map
        )

        # `swin_norm` selects the alternative decoder-layer implementation.
        layer_cls = (
            ChameleonSwinDecoderLayer if hf_config.swin_norm else ChameleonDecoderLayer
        )

        def build_layer(prefix: str) -> nn.Module:
            return layer_cls(
                config=hf_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(hf_config.vq_config)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenize images with the VQ-VAE and map the resulting codebook
        indices to BPE token ids.

        Returns a 2-D tensor whose first dimension is ``pixel_values.shape[0]``.
        No begin/end-of-image special tokens are added by this method.
        """
        num_images = pixel_values.shape[0]
        *_, codebook_indices = self.vqmodel.encode(pixel_values)
        bpe_ids = self.vocabulary_mapping.convert_img2bpe(codebook_indices)
        return bpe_ids.view(num_images, -1)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        pp_group = get_pp_group()

        if pp_group.is_first_rank:
            # The first stage starts from token embeddings (precomputed or not).
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.get_input_embeddings(input_ids)
            )
            residual = None
        else:
            # Later stages resume from the previous stage's activations.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    ChameleonMultiModalProcessor,
    info=ChameleonProcessingInfo,
    dummy_inputs=ChameleonDummyInputsBuilder,
)
class ChameleonForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
):
    """Chameleon causal LM with image inputs.

    Images are converted to discrete tokens by ``ChameleonModel.get_image_tokens``
    and embedded with the same embedding table used for text tokens.
    """

    merge_by_field_config = True

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.multimodal_config = vllm_config.model_config.multimodal_config

        self.model = ChameleonModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.unpadded_vocab_size = hf_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Share a single weight between the input embedding and the head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            hf_config.vocab_size,
            getattr(hf_config, "logit_scale", 1.0),
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> ChameleonImagePixelInputs | None:
        """Pop ``pixel_values`` from kwargs and wrap them, or return None."""
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        # Images are expected to be square at the VQ-VAE's native resolution.
        side = self.config.vq_config.resolution
        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": side, "w": side},
        )

    def get_language_model(self) -> torch.nn.Module:
        return self.model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        assert self.model.vqmodel is not None
        pixels = image_input["data"].to(self.config.dtype)
        image_tokens = self.model.get_image_tokens(pixels)
        return self.model.get_input_embeddings(image_tokens)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        # When intermediate tensors are supplied (presumably a non-first
        # pipeline stage), precomputed input embeddings are ignored.
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        logits = self.logits_processor(self.lm_head, hidden_states)
        if logits is None:
            return logits

        # Forbid sampling the raw IMGIMG* image tokens by pushing their
        # logits to the smallest representable value.
        forbidden = self.model.vocabulary_mapping.image_tokens
        logits[:, forbidden] = torch.finfo(logits.dtype).min
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Fused projections (``qkv_proj``, ``gate_up_proj``) are filled shard by
        shard. VQ-VAE weights are not sharded and use their default loader.
        Returns the set of (post-renaming) names recorded as handled.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        # Rotary buffers are never loaded; models trained using ColossalAI
        # may still include the cos/sin caches in the checkpoint.
        ignored_substrings = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if any(marker in name for marker in ignored_substrings):
                continue

            # With tied embeddings lm_head.weight is redundant; it may still
            # appear if the model was processed with quantization, LoRA,
            # fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            if "vqmodel" in name:
                # Only the language model is sharded; the VQ-VAE is loaded
                # with its plain weight loader.
                if self.model.vqmodel is not None and name in params_dict:
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    loader = getattr(param, "weight_loader", default_weight_loader)
                    loader(param, loaded_weight)
                loaded_params.add(name)
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a fused shard (or a skipped one): load as a plain weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Remap the FP8 kv-scale name onto the attention layer.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale"
                    )
                    if remapped_kv_scale_name not in params_dict:
                        logger.warning_once(
                            "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                            name,
                            remapped_kv_scale_name,
                        )
                        continue
                    name = remapped_kv_scale_name

                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)

            loaded_params.add(name)
        return loaded_params