chameleon.py 39.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
from itertools import islice
7
from typing import Annotated, Any, Literal
8
9

import torch
10
import torch.nn as nn
11
import torch.nn.functional as F
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    ChameleonConfig,
    ChameleonProcessor,
    ChameleonVQVAEConfig,
)
18

19
from vllm.config import CacheConfig, VllmConfig
20
from vllm.config.multimodal import BaseDummyOptions
21
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
22
from vllm.logger import init_logger
23
from vllm.model_executor.layers.activation import SiluAndMul
24
from vllm.model_executor.layers.attention import Attention
25
from vllm.model_executor.layers.conv import Conv2dLayer
26
from vllm.model_executor.layers.layernorm import RMSNorm
27
28
29
30
31
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
35
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
36
37
38
    ParallelLMHead,
    VocabParallelEmbedding,
)
39
from vllm.model_executor.model_loader.weight_utils import (
40
41
42
    default_weight_loader,
    row_parallel_weight_loader,
)
43
from vllm.model_executor.utils import set_weight_attrs
44
from vllm.multimodal import MULTIMODAL_REGISTRY
45
46
47
48
49
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
50
from vllm.multimodal.parse import MultiModalDataItems
51
from vllm.multimodal.processing import (
52
    BaseDummyInputsBuilder,
53
54
55
56
57
58
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
59
from vllm.sequence import IntermediateTensors
60
from vllm.utils.tensor_schema import TensorSchema, TensorShape
61

62
63
64
65
66
67
68
69
70
71
72
73
from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
74

75
76
logger = init_logger(__name__)

77

78
79
80
81
82
83
84
85
class ChameleonImagePixelInputs(TensorSchema):
    """
    Raw pixel-value image inputs accepted by the Chameleon model.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Input-kind tag; the only accepted value is "pixel_values".
    type: Literal["pixel_values"]

    # Stacked images with the channel dimension fixed to 3.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
89
90


91
92
class ChameleonProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Chameleon: typed HF config/processor and limits."""

    def get_hf_config(self):
        """Return the HF model config, requested as a ``ChameleonConfig``."""
        return self.ctx.get_hf_config(ChameleonConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the ``ChameleonProcessor``, forwarding ``kwargs`` unchanged."""
        return self.ctx.get_hf_processor(ChameleonProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only a single image per prompt is supported."""
        return dict(image=1)

    def get_num_image_tokens(self) -> int:
        """Number of image placeholder tokens each image expands into."""
        hf_processor = self.get_hf_processor()
        seq_length = hf_processor.image_seq_length
        return seq_length


106
class ChameleonDummyInputsBuilder(BaseDummyInputsBuilder[ChameleonProcessingInfo]):
    """Builds dummy prompts/images (used for profiling) for Chameleon."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return square dummy images at the VQ tokenizer's input resolution.

        ``seq_len`` and ``mm_processor_kwargs`` are accepted for interface
        compatibility but not used here.
        """
        side = self.info.get_hf_config().vq_config.resolution
        image_count = mm_counts.get("image", 0)

        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


139
class ChameleonMultiModalProcessor(BaseMultiModalProcessor[ChameleonProcessingInfo]):
    """Multimodal processor that expands each image placeholder token.

    Each image placeholder is replaced by
    ``[image_start] + [image_token] * num_image_tokens + [image_end]``, with
    ``embed_token_id`` set to the image token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or tokenize directly for text-only prompts."""
        if mm_data:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

        # Text-only prompt: tokenize here and append what the HF processor
        # would have appended.
        encoded = self.info.get_tokenizer().encode(prompt)
        token_ids = self._apply_hf_processor_tokens_only(encoded)
        return BatchFeature(dict(input_ids=[token_ids]), tensor_type="pt")

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Append the tokenizer's sep token, as the HF processor does in chat mode."""
        tok = self.info.get_tokenizer()
        sep_id = tok.get_vocab()[tok.sep_token]  # type: ignore
        return prompt_tokens + [sep_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """``pixel_values`` is batched along the image dimension."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each single image token with the full start/tokens/end run."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        token_to_id = self.info.get_tokenizer().get_vocab()

        start_id = token_to_id[hf_processor.image_start_token]
        img_id = token_to_id[hf_processor.image_token]
        end_id = token_to_id[hf_processor.image_end_token]

        repeat = self.info.get_num_image_tokens()
        full_sequence = [start_id, *([img_id] * repeat), end_id]

        replacement = PromptReplacement(
            modality="image",
            target=[img_id],
            replacement=PromptUpdateDetails.select_token_id(
                full_sequence,
                embed_token_id=img_id,
            ),
        )
        return [replacement]

206
207
208
209

class ChameleonLayerNorm(nn.LayerNorm):
    """LayerNorm over the last dim with affine params of the full input shape.

    ``hidden_size`` is a shape tuple (e.g. ``(num_heads, head_dim)``). The
    weight and bias are allocated with that full shape, so each head gets its
    own affine parameters, but normalization uses only ``hidden_size[-1]``.
    The ``eps`` given at construction (``nn.LayerNorm`` default ``1e-5``) is
    honoured.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        # Normalize over the last dim only; weight/bias keep the full shape.
        self.normalized_shape = (hidden_size[-1],)

        # NOTE(review): presumably shards these params along the leading
        # (head) dim under tensor parallelism, like row-parallel weights;
        # confirm against row_parallel_weight_loader.
        set_weight_attrs(self.weight, {"weight_loader": row_parallel_weight_loader})
        set_weight_attrs(self.bias, {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states):
        # Parameter-free normalization first: F.layer_norm requires the weight
        # shape to equal normalized_shape, but ours is the full per-head shape.
        hidden_states = F.layer_norm(
            hidden_states, self.normalized_shape, None, None, eps=self.eps
        )
        # Per-head affine transform via broadcasting.
        hidden_states = hidden_states * self.weight + self.bias
        return hidden_states


# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):
    """Gated SiLU MLP: ``down_proj(silu(gate) * up)``.

    Gate and up projections are fused into a single column-parallel layer;
    the down projection is row-parallel.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        shared = dict(bias=bias, quant_config=quant_config)
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **shared,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            prefix=f"{prefix}.down_proj",
            **shared,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.down_proj(activated)
        return out


# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):
    """Tensor-parallel self-attention with per-head q/k LayerNorm.

    The q/k normalization is applied *before* the rotary embedding. Query
    heads are split across TP ranks; KV heads are split when there are at
    least as many as TP ranks, and replicated otherwise.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 4096,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        partition_kv = self.total_num_kv_heads >= tp_size
        if partition_kv:
            # Enough KV heads: each TP rank owns an equal share.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Too few KV heads: they are replicated across TP ranks.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        head_dim = hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # Per-head norms, sized with the per-rank head counts.
        self.q_norm = ChameleonLayerNorm((self.num_heads, head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, head_dim))
        self.rotary_emb = get_rope(
            head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize q/k per head, then flatten the head dims back out."""
        q_heads = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim))
        k_heads = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim))
        return (
            q_heads.view(*q_heads.shape[:-2], -1),
            k_heads.view(*k_heads.shape[:-2], -1),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        fused_qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = fused_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        normed_q, normed_k = self._apply_qk_norm(q, k)
        rotated_q, rotated_k = self.rotary_emb(positions, normed_q, normed_k)
        context = self.attn(rotated_q, rotated_k, v)
        projected, _ = self.o_proj(context)
        return projected


class ChameleonDecoderLayer(nn.Module):
    """Pre-norm decoder layer: RMSNorm -> attention, RMSNorm -> gated MLP.

    When a ``residual`` is supplied, each RMSNorm is called with it and
    returns ``(normalized, updated_residual)``, so the residual add happens
    inside the norm calls.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        max_positions = getattr(config, "max_position_embeddings", 4096)
        kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # First layer in the stack: no residual yet, so start one here.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(mlp_in)
        return mlp_out, residual


class ChameleonSwinDecoderLayer(nn.Module):
    """Post-norm ("swin norm") decoder layer.

    Each sub-layer's *output* is normalized and then added to its input:
    ``h = x + norm(attn(x)); out = h + norm(mlp(h))``. The incoming
    ``residual`` argument is ignored.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        max_positions = getattr(config, "max_position_embeddings", 4096)
        kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Attention sub-layer (post-norm).
        attn_input = hidden_states
        attn_out = self.self_attn(positions=positions, hidden_states=attn_input)
        after_attn = self.input_layernorm(attn_out) + attn_input

        # Fully Connected
        mlp_out = self.mlp(after_attn)
        output = after_attn + self.post_attention_layernorm(mlp_out)

        # NOTE(review): the returned "residual" is the MLP input, which is
        # already folded into ``output``; confirm the caller does not add it
        # a second time (e.g. in a final fused norm).
        return output, after_attn


482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):
    """Nearest-neighbour codebook lookup for the Chameleon VQ-VAE.

    ``forward`` replaces every spatial feature vector with its closest
    codebook entry (squared L2 distance) and returns
    ``(quantized, loss, indices)``; ``quantized`` has the input's layout.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        # Weight of the codebook-side loss term; 0.25 if the config has no beta.
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        # Move dim 1 (which must equal embedding_dim) last so every spatial
        # position becomes one row of ``flat``.
        channels_last = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat = channels_last.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # ||z - e||^2 expanded as ||z||^2 + ||e||^2 - 2 <z, e>
        sq_norm_inputs = torch.sum(flat**2, dim=1, keepdim=True)
        sq_norm_codes = torch.sum(codebook**2, dim=1)
        cross = torch.einsum("bd,dn->bn", flat, codebook.transpose(0, 1))
        distances = sq_norm_inputs + sq_norm_codes - 2 * cross

        indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(indices).view(channels_last.shape)

        # First term only sends gradient to the inputs; the beta-scaled second
        # term only sends gradient to the codebook (see the detach() calls).
        input_side = torch.mean((quantized.detach() - channels_last) ** 2)
        codebook_side = torch.mean((quantized - channels_last.detach()) ** 2)
        loss = input_side + self.beta * codebook_side

        # Straight-through estimator: forward value equals ``quantized`` while
        # gradients flow to ``channels_last`` unchanged.
        quantized = channels_last + (quantized - channels_last).detach()

        # Restore the original dim order.
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        return quantized, loss, indices


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """Stride-2 3x3 convolution with right/bottom-only zero padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = Conv2dLayer(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, hidden_states: torch.Tensor):
        # Conv padding is symmetric, so pad asymmetrically by hand: one zero
        # column on the right of the last dim and one zero row at the bottom
        # of the second-to-last dim.
        padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    """Residual block: (GroupNorm -> SiLU -> 3x3 conv) x 2, plus a shortcut.

    Args:
        config: VQ-VAE config; only ``config.dropout`` is read here.
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``None`` means
            ``in_channels``.
        conv_shortcut: When the channel count changes, use a 3x3 conv
            shortcut instead of the default 1x1 ("nin") projection.
    """

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut
        # Build every layer from the *resolved* channel count; the raw
        # argument may be None.
        out_channels = self.out_channels

        self.norm1 = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = Conv2dLayer(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = torch.nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = Conv2dLayer(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        # A shortcut projection is only needed when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = Conv2dLayer(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = Conv2dLayer(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        # x * sigmoid(x) (SiLU), in place on the freshly normalized tensor.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """Single-head spatial self-attention over all H*W positions, with residual.

    q/k/v/output projections are 1x1 convolutions over the channel dim.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.q = Conv2dLayer(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = Conv2dLayer(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = Conv2dLayer(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = Conv2dLayer(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, hidden_states: torch.Tensor):
        shortcut = hidden_states
        normed = self.norm(hidden_states)
        queries = self.q(normed)
        keys = self.k(normed)
        values = self.v(normed)

        batch_size, channels, height, width = queries.shape
        positions = height * width

        # (B, C, H, W) -> (B, HW, C) for queries and (B, C, HW) for keys, so
        # bmm yields (B, HW_query, HW_key) scores.
        queries = queries.reshape(batch_size, channels, positions).permute(0, 2, 1)
        keys = keys.reshape(batch_size, channels, positions)
        scores = torch.bmm(queries, keys) * (int(channels) ** (-0.5))
        probs = F.softmax(scores, dim=2)

        # (B, C, HW_key) @ (B, HW_key, HW_query) -> (B, C, HW_query)
        values = values.reshape(batch_size, channels, positions)
        mixed = torch.bmm(values, probs.permute(0, 2, 1))
        mixed = mixed.reshape(batch_size, channels, height, width)

        return shortcut + self.proj_out(mixed)


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):
    """Convolutional VQ-VAE encoder mapping pixels to a latent feature map.

    Pipeline: ``conv_in`` -> for each resolution level, ``num_res_blocks``
    resnet blocks (each optionally followed by an attention block) and a
    stride-2 downsample on all but the last level -> ``mid`` (resnet, attn,
    resnet) -> GroupNorm -> SiLU -> ``conv_out``. Attribute names are kept
    as-is because they determine the parameter names.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = Conv2dLayer(
            in_channels, base_channels, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        # Level i reads channels from multiplier i-1 (level 0 reads base).
        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                block_in = block_out
                # Attention is added per block only at configured resolutions.
                if (
                    config.attn_resolutions is not None
                    and curr_res in config.attn_resolutions
                    and config.attn_type == "vanilla"
                ):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = (
            ChameleonVQVAEEncoderAttnBlock(block_in)
            if config.attn_type == "vanilla"
            else nn.Identity()
        )
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = Conv2dLayer(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        # Run in the dtype of the encoder's weights.
        pixel_values = pixel_values.to(self.conv_in.weight.dtype)

        # downsampling: only the latest activation is ever consumed, so keep
        # a single running tensor instead of a list of every intermediate
        # (which would keep all of them alive until the end of forward).
        hidden_state = self.conv_in(pixel_values)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                hidden_state = level.block[i_block](hidden_state)
                if len(level.attn) > 0:
                    hidden_state = level.attn[i_block](hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_state = level.downsample(hidden_state)

        # middle
        hidden_state = self.mid.block_1(hidden_state)
        hidden_state = self.mid.attn_1(hidden_state)
        hidden_state = self.mid.block_2(hidden_state)

        # end
        hidden_state = self.norm_out(hidden_state)
        # x * sigmoid(x) (SiLU), in place on the freshly normalized tensor.
        hidden_state *= torch.sigmoid(hidden_state)
        hidden_state = self.conv_out(hidden_state)
        return hidden_state


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):
    """Frozen VQ-VAE used to turn images into discrete codebook indices.

    Only ``encode`` is implemented here. ``post_quant_conv`` is created but
    not used by this class (NOTE(review): presumably kept so checkpoint
    weights load cleanly; confirm).
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = Conv2dLayer(config.latent_channels, config.embed_dim, 1)
        self.post_quant_conv = Conv2dLayer(config.embed_dim, config.latent_channels, 1)
        # Must come last: eval() only affects submodules registered so far.
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode pixels and quantize; returns ``(quantized, loss, indices)``."""
        latents = self.quant_conv(self.encoder(pixel_values))
        quantized, vq_loss, code_indices = self.quantize(latents)
        return quantized, vq_loss, code_indices


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.

    Image tokens are vocabulary entries whose names start with ``IMGIMG``.
    The characters after that prefix (excluding the final character) spell the
    VQ codebook index, with letters ``A``-``J`` standing for digits ``0``-``9``.
    All derived tables are computed lazily and cached per instance.
    """

    def __init__(self, vocab_map: dict[str, int]):
        self.vocab_map = vocab_map
        # None when the vocabulary has no "<image>" entry.
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        """Inverse vocabulary: token id -> token name."""
        return {v: k for k, v in self.vocab_map.items()}

    @cached_property
    def image_tokens(self):
        """Sorted BPE ids of all ``IMGIMG*`` tokens."""
        return sorted(
            [val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")]
        )

    @cached_property
    def bpe2img(self):
        """BPE token id -> VQ codebook index."""
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}

        def remap(old_name: str) -> str:
            # Drop the "IMGIMG" prefix and the final character, then map
            # letters A-J to digits 0-9.
            return "".join(
                img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]
            )

        return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens}

    @cached_property
    def img2bpe(self):
        """VQ codebook index -> BPE token id."""
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        """Return ``(sorted_bpe_ids, codebook_indices)`` as paired tensors.

        ``codebook_indices[i]`` is the codebook index of ``sorted_bpe_ids[i]``.
        The values are ordered by their keys rather than sorted independently,
        which would break the pairing whenever the mapping is not monotonic.
        """
        sorted_bpe = sorted(self.bpe2img)
        return torch.tensor(sorted_bpe), torch.tensor(
            [self.bpe2img[tok] for tok in sorted_bpe]
        )

    @cached_property
    def img2bpe_mapping_tensor(self):
        """Dense lookup table indexed by codebook index, holding BPE ids."""
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Map codebook indices to BPE token ids, keeping the input's device."""
        device = img_batch.device
        # The lookup table is built on CPU, so index it there.
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)


class ChameleonModel(nn.Module):
    """Chameleon transformer backbone together with its VQ-VAE image tokenizer.

    Holds the token embedding, this pipeline stage's decoder layers, the final
    RMSNorm, and ``vqmodel``. ``vqmodel`` turns images into token ids that are
    embedded with the same ``embed_tokens`` table as text.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
        # `swin_norm` in the HF config selects the Swin-norm decoder variant.
        decoder_layer = (
            ChameleonDecoderLayer
            if not self.config.swin_norm
            else ChameleonSwinDecoderLayer
        )

        # make_layers also reports the [start_layer, end_layer) range owned by
        # this pipeline stage; forward() only runs that slice.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (text and image tokens share one table)."""
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes images into BPE token ids with the VQ-VAE module.

        The VQ-VAE codebook indices are translated to BPE ids through
        ``vocabulary_mapping``. No begin-/end-of-image special tokens are
        added here.

        Args:
            pixel_values: Batch of images; dimension 0 is the batch.

        Returns:
            A ``(batch_size, num_image_tokens)`` tensor of BPE token ids.
        """
        batch_size = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's decoder layers.

        On the first pipeline rank the input is ``inputs_embeds`` if given,
        otherwise the embedding of ``input_ids``. On later ranks the hidden
        states and residual are read from ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` with ``hidden_states``/``residual`` on
            every rank except the last. The last rank returns the
            final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


913
914
915
@MULTIMODAL_REGISTRY.register_processor(
    ChameleonMultiModalProcessor,
    info=ChameleonProcessingInfo,
    dummy_inputs=ChameleonDummyInputsBuilder,
)
class ChameleonForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
):
    """Chameleon causal LM over interleaved text and image tokens.

    Images are converted to vocabulary ids by the backbone's VQ-VAE and then
    embedded with the text embedding table. Raw image-token logits are masked
    out in ``compute_logits``.
    """

    # Fused parameter name -> the checkpoint projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Prompt placeholder for ``modality``; only images are supported."""
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        self.config = hf_config
        self.multimodal_config = model_config.multimodal_config

        if self.config.swin_norm:
            language_layer_cls = ChameleonSwinDecoderLayer
        else:
            language_layer_cls = ChameleonDecoderLayer

        with self._mark_composite_model(
            vllm_config,
            language_targets=language_layer_cls,
            tower_targets={"image": ChameleonVQVAE},
        ):
            self.model = ChameleonModel(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "model"),
            )

        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Output projection reuses the input embedding matrix.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            hf_config.vocab_size, scale=getattr(hf_config, "logit_scale", 1.0)
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> ChameleonImagePixelInputs | None:
        """Wrap ``pixel_values`` from ``kwargs``; None when no image is given.

        The expected height and width are both ``vq_config.resolution``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        side = self.config.vq_config.resolution
        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": side, "w": side},
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Embed any images in ``kwargs``; returns ``[]`` if there are none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        assert self.model.vqmodel is not None

        pixel_values = image_input["data"].to(self.config.dtype)
        image_tokens = self.model.get_image_tokens(pixel_values)
        return self.model.embed_input_ids(image_tokens)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate to the backbone; extra ``kwargs`` are ignored."""
        # Activations handed over in intermediate_tensors take precedence
        # over any inputs_embeds.
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project to vocabulary logits with raw image tokens disabled."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        if logits is None:
            return None

        # Disallow image tokens which does not include special
        # begin-image and end-image tokens
        banned_ids = self.model.vocabulary_mapping.image_tokens
        logits[:, banned_ids] = torch.finfo(logits.dtype).min
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights`` into this model.

        q/k/v and gate/up projections are routed into their fused parameters.
        VQ-VAE weights are loaded unsharded with their parameter's loader.

        Returns:
            The (possibly remapped) names that were processed. VQ-VAE names
            are included even when no matching parameter exists.
        """
        # (fused param suffix, checkpoint suffix, shard id)
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        # Rotary-embedding buffers are never loaded from the checkpoint.
        # Models trained using ColossalAI may include the cos/sin caches.
        rotary_buffer_keys = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if any(key in name for key in rotary_buffer_keys):
                continue

            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            if "vqmodel" in name:
                # We only do sharding for language model and
                # not vqvae for now.
                if self.model.vqmodel is not None and name in params_dict:
                    if is_pp_missing_parameter(name, self):
                        continue
                    vq_param = params_dict[name]
                    vq_loader = getattr(
                        vq_param, "weight_loader", default_weight_loader
                    )
                    vq_loader(vq_param, loaded_weight)
                loaded_params.add(name)
                continue

            for fused_suffix, ckpt_suffix, shard in stacked_params_mapping:
                if ckpt_suffix not in name:
                    continue
                name = name.replace(ckpt_suffix, fused_suffix)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                fused_param = params_dict[name]
                fused_param.weight_loader(fused_param, loaded_weight, shard)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale"
                    )
                    if remapped_kv_scale_name not in params_dict:
                        logger.warning_once(
                            "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                            name,
                            remapped_kv_scale_name,
                        )
                        continue
                    name = remapped_kv_scale_name

                if is_pp_missing_parameter(name, self):
                    continue
                plain_param = params_dict[name]
                plain_loader = getattr(
                    plain_param, "weight_loader", default_weight_loader
                )
                plain_loader(plain_param, loaded_weight)

            loaded_params.add(name)
        return loaded_params