chameleon.py 40.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
from itertools import islice
7
from typing import Annotated, Any, Literal
8
9

import torch
10
import torch.nn as nn
11
import torch.nn.functional as F
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    ChameleonConfig,
    ChameleonProcessor,
    ChameleonVQVAEConfig,
)
18

19
from vllm.attention import Attention
20
from vllm.config import CacheConfig, VllmConfig
21
from vllm.config.multimodal import BaseDummyOptions
22
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
23
from vllm.logger import init_logger
24
from vllm.model_executor.layers.activation import SiluAndMul
25
from vllm.model_executor.layers.conv import Conv2dLayer
26
from vllm.model_executor.layers.layernorm import RMSNorm
27
28
29
30
31
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
35
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
36
37
38
    ParallelLMHead,
    VocabParallelEmbedding,
)
39
from vllm.model_executor.model_loader.weight_utils import (
40
41
42
    default_weight_loader,
    row_parallel_weight_loader,
)
43
from vllm.model_executor.utils import set_weight_attrs
44
from vllm.multimodal import MULTIMODAL_REGISTRY
45
46
47
48
49
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
50
from vllm.multimodal.parse import MultiModalDataItems
51
52
53
54
55
56
57
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
58
from vllm.multimodal.profiling import BaseDummyInputsBuilder
59
from vllm.sequence import IntermediateTensors
60
from vllm.utils.tensor_schema import TensorSchema, TensorShape
61

62
63
64
65
66
67
68
69
70
71
72
73
from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
74

75
76
# Module-level logger for this model file, created via vLLM's init_logger.
logger = init_logger(__name__)

77

78
79
80
81
82
83
84
85
class ChameleonImagePixelInputs(TensorSchema):
    """
    Pixel-value image inputs for Chameleon, validated by ``TensorSchema``.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Discriminator tag for this input variant.
    type: Literal["pixel_values"]

    # Stacked RGB images; the channel dimension is fixed to 3.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
89
90


91
92
class ChameleonProcessingInfo(BaseProcessingInfo):
    """Config / processor accessors and per-prompt limits for Chameleon."""

    def get_hf_config(self):
        """Return the model's HF config, typed as ``ChameleonConfig``."""
        return self.ctx.get_hf_config(ChameleonConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``ChameleonProcessor``, forwarding ``kwargs``."""
        return self.ctx.get_hf_processor(ChameleonProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Allow at most one image item."""
        limits: dict[str, int | None] = {"image": 1}
        return limits

    def get_num_image_tokens(self) -> int:
        """Number of image placeholder tokens the HF processor emits per image."""
        return self.get_hf_processor().image_seq_length


106
class ChameleonDummyInputsBuilder(BaseDummyInputsBuilder[ChameleonProcessingInfo]):
    """Builds dummy prompt text and images for Chameleon profiling runs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        image_token = self.info.get_hf_processor().image_token
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy square images at the VQ-VAE's configured resolution.

        ``seq_len`` is not used; the image side length comes from
        ``config.vq_config.resolution``. Per-image overrides are taken from
        ``mm_options["image"]`` when ``mm_options`` is non-empty.
        """
        side = self.info.get_hf_config().vq_config.resolution
        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=mm_counts.get("image", 0),
            overrides=overrides,
        )
        return {"image": dummy_images}


138
class ChameleonMultiModalProcessor(BaseMultiModalProcessor[ChameleonProcessingInfo]):
    """Multimodal processor for Chameleon.

    Each image placeholder token in the prompt is expanded into
    ``[image_start] + [image_token] * N + [image_end]``.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or tokenize directly when there is no media.

        Text-only prompts skip the HF processor. They are encoded with the
        tokenizer and then post-processed exactly like pre-tokenized prompts
        (see ``_apply_hf_processor_tokens_only``).
        """
        if mm_data:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

        encoded = self.info.get_tokenizer().encode(prompt)
        token_ids = self._apply_hf_processor_tokens_only(encoded)
        return BatchFeature(dict(input_ids=[token_ids]), tensor_type="pt")

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Return ``prompt_tokens`` with the tokenizer's sep token appended.

        This mirrors the HF processor, which adds the sep token in chat mode.
        A new list is returned; the input is not modified.
        """
        tokenizer = self.info.get_tokenizer()
        sep_token_id = tokenizer.get_vocab()[tokenizer.sep_token]  # type: ignore
        return prompt_tokens + [sep_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """``pixel_values`` is batched along the image dimension."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with a full start/image*N/end span.

        ``embed_token_id`` marks the image-token positions of the span as the
        slots that receive image embeddings.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        start_id = vocab[hf_processor.image_start_token]
        token_id = vocab[hf_processor.image_token]
        end_id = vocab[hf_processor.image_end_token]

        n_tokens = self.info.get_num_image_tokens()
        span = [start_id, *([token_id] * n_tokens), end_id]

        replacement = PromptReplacement(
            modality="image",
            target=[token_id],
            replacement=PromptUpdateDetails.select_token_id(
                span,
                embed_token_id=token_id,
            ),
        )
        return [replacement]

205
206
207
208

class ChameleonLayerNorm(nn.LayerNorm):
    """LayerNorm over the last dimension with a full-shape affine transform.

    ``hidden_size`` is a shape tuple, e.g. ``(num_heads, head_dim)`` as used by
    ``ChameleonAttention``. The affine ``weight``/``bias`` keep that full shape,
    so every head has its own scale and shift. Normalization statistics,
    however, are computed over the last dimension only.

    ``eps`` (and any other ``nn.LayerNorm`` argument) can be passed through
    ``*args``/``**kwargs``. The default matches ``nn.LayerNorm`` (1e-5).
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        # Normalize over the trailing dim only; the affine params still cover
        # the full ``hidden_size`` shape.
        self.normalized_shape = (hidden_size[-1],)

        # NOTE(review): row_parallel_weight_loader presumably shards these
        # along the leading (head) dim under tensor parallelism -- confirm.
        set_weight_attrs(self.weight, {"weight_loader": row_parallel_weight_loader})
        set_weight_attrs(self.bias, {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states):
        # The affine transform is applied by hand because weight/bias have
        # shape ``hidden_size`` rather than ``normalized_shape``.
        # Honor the configured epsilon instead of a hard-coded constant.
        hidden_states = F.layer_norm(
            hidden_states, self.normalized_shape, None, None, eps=self.eps
        )
        hidden_states = hidden_states * self.weight + self.bias
        return hidden_states


# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):
    """Gated feed-forward block.

    The block applies a fused gate/up projection, then ``SiluAndMul``, then
    the down projection. Only ``hidden_act == "silu"`` is accepted; any other
    value raises ``ValueError``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        shared = dict(bias=bias, quant_config=quant_config)

        # Gate and up projections share one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **shared,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            prefix=f"{prefix}.down_proj",
            **shared,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_fn(fused))
        return projected


# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):
    """Self-attention with per-head LayerNorm on queries and keys.

    Queries and keys are normalized with ``ChameleonLayerNorm`` before the
    rotary embedding is applied. Query heads are split across tensor-parallel
    ranks. KV heads are split when there are at least as many as ranks, and
    replicated otherwise.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 4096,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly over the TP ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world:
            # Fewer KV heads than ranks: each KV head is replicated.
            assert world % self.total_num_kv_heads == 0
        else:
            # At least one KV head per rank: partition them.
            assert self.total_num_kv_heads % world == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # Norms operate on this rank's local heads.
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize q/k per head and flatten the head dims back out."""
        q_heads = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim))
        k_heads = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim))
        return (
            q_heads.view(*q_heads.shape[:-2], -1),
            k_heads.view(*k_heads.shape[:-2], -1),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q_normed, k_normed = self._apply_qk_norm(q, k)
        q_rot, k_rot = self.rotary_emb(positions, q_normed, k_normed)
        output, _ = self.o_proj(self.attn(q_rot, k_rot, v))
        return output


class ChameleonDecoderLayer(nn.Module):
    """Pre-norm decoder layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is threaded through the two-argument RMSNorm calls,
    which return an updated ``(hidden_states, residual)`` pair.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_scaling = getattr(config, "rope_scaling", None)
        original_max = getattr(config, "original_max_position_embeddings", None)
        if rope_scaling is not None and original_max:
            # NOTE: this writes into the config's own rope_scaling dict.
            rope_scaling["original_max_position_embeddings"] = original_max

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=rope_scaling,
            max_position_embeddings=getattr(config, "max_position_embeddings", 4096),
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run one layer and return ``(mlp_output, residual)``.

        When ``residual`` is None, the raw input becomes the residual.
        """
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(positions=positions, hidden_states=normed)

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class ChameleonSwinDecoderLayer(nn.Module):
    """Post-norm decoder layer variant.

    Each sub-block computes ``x + norm(sublayer(x))``: attention first, then
    the MLP. The incoming ``residual`` argument of ``forward`` is ignored.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_scaling = getattr(config, "rope_scaling", None)
        original_max = getattr(config, "original_max_position_embeddings", None)
        if rope_scaling is not None and original_max:
            # NOTE: this writes into the config's own rope_scaling dict.
            rope_scaling["original_max_position_embeddings"] = original_max

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(
                config, "num_key_value_heads", config.num_attention_heads
            ),
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=rope_scaling,
            max_position_embeddings=getattr(config, "max_position_embeddings", 4096),
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(output, pre_mlp_hidden_states)``."""
        # Attention sub-block (norm applied to the sub-layer output).
        skip = hidden_states
        attn_out = self.self_attn(positions=positions, hidden_states=hidden_states)
        hidden_states = self.input_layernorm(attn_out) + skip

        # Fully Connected
        skip = hidden_states
        mlp_out = self.post_attention_layernorm(self.mlp(hidden_states))
        hidden_states = skip + mlp_out
        return hidden_states, skip


503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):
    """Nearest-neighbour codebook lookup for VQ-VAE latents."""

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        # Weight on the second loss term in forward(); 0.25 if not configured.
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        """Quantize a channels-first latent map.

        The input is permuted with ``(0, 2, 3, 1)``, so it must be rank 4 with
        ``embedding_dim`` channels on dim 1. Returns
        ``(quantized_channels_first, loss, flat_codebook_indices)``.
        """
        channels_last = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat = channels_last.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # Squared L2 distance expanded as ||z||^2 + ||e||^2 - 2 z.e
        z_sq = torch.sum(flat**2, dim=1, keepdim=True)
        e_sq = torch.sum(codebook**2, dim=1)
        cross = torch.einsum("bd,dn->bn", flat, codebook.transpose(0, 1))
        distances = z_sq + e_sq - 2 * cross

        indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(indices).view(channels_last.shape)

        # compute loss for embedding
        loss = torch.mean((quantized.detach() - channels_last) ** 2) + self.beta * (
            torch.mean((quantized - channels_last.detach()) ** 2)
        )

        # Straight-through estimator: forward value is the quantized tensor,
        # gradients flow to the encoder output.
        quantized = channels_last + (quantized - channels_last).detach()

        # Back to channels-first, matching the input layout.
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return quantized, loss, indices


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """Stride-2 3x3 convolution used between encoder stages.

    The input is zero-padded by one pixel on the right and bottom only,
    because torch convolutions cannot pad asymmetrically.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = Conv2dLayer(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, hidden_states: torch.Tensor):
        # pad tuple is (left, right, top, bottom) for the last two dims.
        padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    """Residual block: (GroupNorm -> SiLU -> 3x3 conv) twice, plus a skip.

    ``out_channels`` defaults to ``in_channels`` when None. If the widths
    differ, the skip path is projected with a 3x3 conv (``conv_shortcut=True``)
    or a 1x1 conv (otherwise).
    """

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut
        # Use the resolved width from here on. The raw argument may be None,
        # which previously reached the layer constructors below and crashed.
        out_channels = self.out_channels

        self.norm1 = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = Conv2dLayer(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = torch.nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = Conv2dLayer(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = Conv2dLayer(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = Conv2dLayer(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        # SiLU, in place on the fresh norm output.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        # Project the skip path when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """Single-head spatial self-attention over a feature map, with residual.

    Every spatial position attends to every other. Q/K/V and the output
    projection are 1x1 convolutions.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.q = Conv2dLayer(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = Conv2dLayer(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = Conv2dLayer(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = Conv2dLayer(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, hidden_states: torch.Tensor):
        skip = hidden_states
        normed = self.norm(hidden_states)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, chans, height, width = query.shape
        positions = height * width

        # (b, c, h, w) -> (b, hw, c) for queries; keys/values stay (b, c, hw).
        query = query.reshape(batch, chans, positions).transpose(1, 2)
        key = key.reshape(batch, chans, positions)
        value = value.reshape(batch, chans, positions)

        scores = torch.bmm(query, key) * (int(chans) ** (-0.5))
        probs = F.softmax(scores, dim=2)

        # attend to values: (b, c, hw) @ (b, hw, hw) -> (b, c, hw)
        attended = torch.bmm(value, probs.transpose(1, 2)).reshape(
            batch, chans, height, width
        )
        return skip + self.proj_out(attended)


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):
    """Convolutional VQ-VAE encoder from pixels to latent feature maps.

    The pipeline is:

    - ``conv_in``.
    - One stage per entry of ``config.channel_multiplier``. Each stage has
      resnet blocks, optional attention, and a stride-2 downsample on every
      stage except the last.
    - Middle resnet/attention/resnet.
    - GroupNorm -> SiLU -> ``conv_out``.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        multipliers = config.channel_multiplier
        base = config.base_channels
        self.num_resolutions = len(multipliers)
        self.num_res_blocks = config.num_res_blocks

        self.conv_in = Conv2dLayer(
            config.in_channels, base, kernel_size=3, stride=1, padding=1
        )

        # Stage i starts from base * in_channel_multiplier[i] channels.
        self.in_channel_multiplier = (1,) + tuple(multipliers)
        self.down = nn.ModuleList()
        curr_res = config.resolution
        for level in range(self.num_resolutions):
            channels = base * self.in_channel_multiplier[level]
            stage_out = base * multipliers[level]
            # Attention follows every resnet block of a stage whose spatial
            # size appears in attn_resolutions (vanilla attention only).
            with_attn = (
                config.attn_resolutions is not None
                and curr_res in config.attn_resolutions
                and config.attn_type == "vanilla"
            )

            blocks = nn.ModuleList()
            attns = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                blocks.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=channels,
                        out_channels=stage_out,
                    )
                )
                channels = stage_out
                if with_attn:
                    attns.append(ChameleonVQVAEEncoderAttnBlock(channels))

            stage = nn.Module()
            stage.block = blocks
            stage.attn = attns
            if level != self.num_resolutions - 1:
                stage.downsample = ChameleonVQVAEEncoderConvDownsample(channels)
                curr_res //= 2
            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=channels,
            out_channels=channels,
        )
        self.mid.attn_1 = (
            ChameleonVQVAEEncoderAttnBlock(channels)
            if config.attn_type == "vanilla"
            else nn.Identity()
        )
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=channels,
            out_channels=channels,
        )

        self.norm_out = torch.nn.GroupNorm(
            num_groups=32, num_channels=channels, eps=1e-6, affine=True
        )
        latent = config.latent_channels
        self.conv_out = Conv2dLayer(
            channels,
            2 * latent if config.double_latent else latent,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        # Match the encoder's parameter dtype before the first conv.
        hidden = self.conv_in(pixel_values.to(self.conv_in.weight.dtype))

        # downsampling
        last_level = self.num_resolutions - 1
        for level, stage in enumerate(self.down):
            for idx in range(self.num_res_blocks):
                hidden = stage.block[idx](hidden)
                if len(stage.attn) > 0:
                    hidden = stage.attn[idx](hidden)
            if level != last_level:
                hidden = stage.downsample(hidden)

        # middle
        hidden = self.mid.block_1(hidden)
        hidden = self.mid.attn_1(hidden)
        hidden = self.mid.block_2(hidden)

        # end: GroupNorm -> SiLU (in place) -> conv_out
        hidden = self.norm_out(hidden)
        hidden *= torch.sigmoid(hidden)
        return self.conv_out(hidden)


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):
    """Frozen VQ-VAE; only the encoding path is implemented here.

    ``post_quant_conv`` is constructed but is not used by ``encode``.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = Conv2dLayer(config.latent_channels, config.embed_dim, 1)
        self.post_quant_conv = Conv2dLayer(config.embed_dim, config.latent_channels, 1)
        # Chameleon's VQ model is frozen, so it is kept in eval mode.
        self.eval()

    def encode(
        self, pixel_values: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode and quantize images.

        Returns ``(quantized, embedding_loss, codebook_indices)`` from the
        quantizer.
        """
        latents = self.quant_conv(self.encoder(pixel_values))
        quantized, embedding_loss, codebook_indices = self.quantize(latents)
        return quantized, embedding_loss, codebook_indices


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.

    ``vocab_map`` maps token names to BPE ids. Image tokens are the entries
    whose names start with ``"IMGIMG"``. The characters between that prefix
    and the final character spell the VQGAN codebook index, using the letters
    ``A``-``J`` for the digits ``0``-``9`` (e.g. ``"IMGIMGBAZ"`` -> ``10``).
    """

    # Name prefix identifying image tokens in the BPE vocabulary.
    _IMAGE_TOKEN_PREFIX = "IMGIMG"

    def __init__(self, vocab_map: dict[str, int]):
        self.vocab_map = vocab_map
        # ``None`` when the vocabulary has no "<image>" placeholder token.
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self) -> dict[int, str]:
        """Inverse of ``vocab_map``: BPE id -> token name."""
        return {v: k for k, v in self.vocab_map.items()}

    @cached_property
    def image_tokens(self) -> list[int]:
        """Sorted BPE ids of all image tokens."""
        return sorted(
            val
            for name, val in self.vocab_map.items()
            if name.startswith(self._IMAGE_TOKEN_PREFIX)
        )

    @cached_property
    def bpe2img(self) -> dict[int, int]:
        """Map each image-token BPE id to its VQGAN codebook index."""
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
        prefix_len = len(self._IMAGE_TOKEN_PREFIX)

        def remap(old_name: str) -> str:
            # Drop the prefix and the trailing character, then decode the
            # letter-encoded digits.
            return "".join(
                img_tkn_chr_mapping.get(c, c) for c in old_name[prefix_len:-1]
            )

        return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens}

    @cached_property
    def img2bpe(self) -> dict[int, int]:
        """Inverse of ``bpe2img``: VQGAN codebook index -> BPE id."""
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(bpe_ids, img_ids)`` tensors ordered by BPE id.

        ``img_ids[i]`` is the codebook index of ``bpe_ids[i]``. A lookup can
        therefore binary-search ``bpe_ids`` (e.g. with ``torch.searchsorted``)
        and read the matching entry of ``img_ids``.
        """
        # Sort the (bpe, img) pairs together. Sorting keys and values
        # independently would mis-pair them whenever the mapping is not
        # monotonic.
        bpe_ids = sorted(self.bpe2img)
        img_ids = [self.bpe2img[tok] for tok in bpe_ids]
        return torch.tensor(bpe_ids), torch.tensor(img_ids)

    @cached_property
    def img2bpe_mapping_tensor(self) -> torch.Tensor:
        """Dense lookup table where ``mapping[img_id]`` is the BPE id of ``img_id``.

        Codebook indices without a corresponding image token map to 0.
        """
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Convert VQGAN codebook indices to BPE token ids.

        The indices are moved to CPU to index the lookup table. The result is
        moved back to ``img_batch``'s original device.
        """
        device = img_batch.device
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)


class ChameleonModel(nn.Module):
    """Chameleon decoder backbone together with its VQ-VAE image tokenizer."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
        # ``swin_norm`` selects which decoder-layer variant is stacked.
        decoder_layer = (
            ChameleonDecoderLayer
            if not self.config.swin_norm
            else ChameleonSwinDecoderLayer
        )

        # ``forward`` only runs ``self.layers[start_layer:end_layer]``;
        # presumably the slice owned by this pipeline-parallel rank.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenize images into BPE token ids with the VQ-VAE.

        The images are encoded into discrete VQGAN codebook indices. Those
        indices are then converted into the corresponding BPE image-token ids.
        No begin-/end-of-image special tokens are added here.

        Returns:
            A tensor of shape ``(batch_size, tokens_per_image)``.
        """
        batch_size = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of decoder layers.

        On the first pipeline rank, the input is ``inputs_embeds`` when
        given, otherwise the embeddings of ``input_ids``. Other ranks resume
        from ``intermediate_tensors``. Non-last ranks return
        ``IntermediateTensors`` holding ``hidden_states`` and ``residual``.
        The last rank returns the final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    ChameleonMultiModalProcessor,
    info=ChameleonProcessingInfo,
    dummy_inputs=ChameleonDummyInputsBuilder,
)
class ChameleonForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
):
    """Chameleon backbone plus LM head, with VQ-VAE based image inputs."""

    merge_by_field_config = True

    # Fused parameter name -> checkpoint shard names it is built from.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality`` (images only)."""
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.model = ChameleonModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> ChameleonImagePixelInputs | None:
        """Wrap ``pixel_values`` from ``kwargs`` as image inputs.

        The height and width are bound to ``vq_config.resolution``. Returns
        ``None`` when no pixel values were passed.
        """
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        vq_config: ChameleonVQVAEConfig = self.config.vq_config
        expected_h = expected_w = vq_config.resolution

        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": expected_h, "w": expected_w},
        )

    def get_language_model(self) -> torch.nn.Module:
        """Return the text backbone."""
        return self.model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Tokenize images with the VQ-VAE and embed the resulting BPE ids.

        Returns an empty list when no image input is present.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        assert self.model.vqmodel is not None
        image_tokens = self.model.get_image_tokens(
            image_input["data"].to(self.config.dtype)
        )
        vision_embeddings = self.model.embed_input_ids(image_tokens)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone.

        ``inputs_embeds`` is ignored when ``intermediate_tensors`` is given.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits, with image-token ids forced to the dtype minimum."""
        logits = self.logits_processor(self.lm_head, hidden_states)

        # Prevent the model from sampling image tokens. ``image_tokens`` holds
        # only the "IMGIMG*" ids, so the special begin-image / end-image
        # tokens are not masked.
        if logits is not None:
            image_tokens = self.model.vocabulary_mapping.image_tokens
            logits[:, image_tokens] = torch.finfo(logits.dtype).min

        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model.

        q/k/v and gate/up projections are loaded as shards of the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters. VQ-VAE weights go through
        the default (unsharded) loading path.

        Returns:
            The set of (possibly remapped) names that were processed.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            use_default_weight_loading = False
            if "vqmodel" in name:
                if self.model.vqmodel is not None:
                    # We only do sharding for language model and
                    # not vqvae for now.
                    use_default_weight_loading = True
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    # NOTE: ``name`` is rewritten in place. If a shard is then
                    # skipped by one of the ``continue`` statements below, the
                    # loop ends without ``break`` and the ``else`` branch
                    # re-checks the rewritten name, skipping it the same way.
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale"
                        )
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        else:
                            name = remapped_kv_scale_name

                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            if use_default_weight_loading and name in params_dict:
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params