chameleon.py 45.1 KB
Newer Older
1
from functools import cached_property
2
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
3
                    Tuple, TypedDict, Union)
4
5

import torch
6
import torch.nn as nn
7
import torch.nn.functional as F
8
9
from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
                          ChameleonVQVAEConfig)
10
11

from vllm.attention import Attention, AttentionMetadata
12
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
14
from vllm.logger import init_logger
15
16
17
18
19
20
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
21
from vllm.model_executor.layers.quantization import QuantizationConfig
22
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
23
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
25
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
26
27
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, row_parallel_weight_loader)
28
from vllm.model_executor.sampling_metadata import SamplingMetadata
29
from vllm.model_executor.utils import set_weight_attrs
30
from vllm.multimodal import MULTIMODAL_REGISTRY
31
32
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
33
from vllm.multimodal.parse import MultiModalDataItems
34
from vllm.multimodal.processing import (BaseMultiModalProcessor,
35
36
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptReplacementDetails)
37
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
38
from vllm.sequence import IntermediateTensors
39

40
41
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
42
                    make_empty_intermediate_tensors_factory, make_layers,
43
                    maybe_prefix, merge_multimodal_embeddings)
44

45
46
# Module-level logger, keyed by this module's import name.
logger = init_logger(__name__)

47
48
49
50

class ChameleonImagePixelInputs(TypedDict):
    """Typed dict carrying image pixel inputs for Chameleon.

    Keys:
        type: Always the literal string ``"pixel_values"``.
        data: Pixel tensor of shape
            ``(batch_size * num_images, num_channels, height, width)``.
    """
    type: Literal["pixel_values"]
    # Shape: `(batch_size * num_images, num_channels, height, width)`
    data: torch.Tensor
52
53


54
class ChameleonProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Chameleon.

    Exposes the HF config / processor through the processing context and
    reports the multimodal limits used by the rest of this module.
    """

    def get_hf_config(self):
        """Return the HF config, requested from the context as a
        ``ChameleonConfig``."""
        return self.ctx.get_hf_config(ChameleonConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``ChameleonProcessor`` from the context.

        Args:
            **kwargs: Optional processor overrides, forwarded unchanged to
                ``self.ctx.get_hf_processor``. The prompt-replacement code
                in this module calls this method with the per-request
                processor kwargs, so they must be accepted here; with no
                kwargs the call is identical to the zero-argument form.
        """
        return self.ctx.get_hf_processor(ChameleonProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: a single image per prompt."""
        return {"image": 1}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Return the token count per image.

        ``seq_len`` is not consulted; the value comes from
        ``get_num_image_tokens``.
        """
        return {"image": self.get_num_image_tokens()}

    def get_num_image_tokens(self) -> int:
        """Return the HF processor's ``image_seq_length``."""
        processor = self.get_hf_processor()
        return processor.image_seq_length


class ChameleonDummyInputsBuilder(
        BaseDummyInputsBuilder[ChameleonProcessingInfo]):
    """Builds placeholder processor inputs (prompt + dummy images)."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create dummy inputs containing ``mm_counts["image"]`` images.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Requested item count per modality; a missing
                ``"image"`` entry is treated as zero.

        Returns:
            ``ProcessorInputs`` whose prompt is ``"<image>"`` repeated once
            per image and whose ``mm_data`` holds square dummy images sized
            to ``config.vq_config.resolution``.
        """
        config = self.info.get_hf_config()

        # Dummy images are square, at the VQ model's configured resolution.
        width = height = config.vq_config.resolution
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=width,
                                   height=height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text="<image>" * num_images,
            mm_data=mm_data,
        )


99
100
class ChameleonMultiModalProcessor(
        BaseMultiModalProcessor[ChameleonProcessingInfo]):
    """Multimodal processor for Chameleon prompts with image placeholders."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, with a tokenizer-only path for text prompts.

        When ``mm_data`` is empty the prompt is encoded directly with the
        tokenizer and post-processed by ``_apply_hf_processor_tokens_only``;
        otherwise the call is delegated to the base class.
        """
        if not mm_data:
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Return ``prompt_tokens`` with the tokenizer's sep token appended.

        A new list is returned; the input list is not modified.
        """
        # HF processor adds sep token for chat mode
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        sep_token_id = vocab[tokenizer.sep_token]  # type: ignore

        return prompt_tokens + [sep_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as a batched field of the image modality.

        Neither argument is inspected.
        """
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Describe how each image placeholder token is expanded.

        A single image token is replaced by
        ``[image_start] + [image_token] * N + [image_end]`` where ``N`` is
        ``self.info.get_num_image_tokens()``; only the repeated image tokens
        are marked as the feature positions.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        # Resolve the processor's special-token strings to vocabulary ids.
        image_start_id = vocab[processor.image_start_token]
        image_token_id = vocab[processor.image_token]
        image_end_id = vocab[processor.image_end_token]

        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=PromptReplacementDetails(
                    full=([image_start_id] + image_tokens + [image_end_id]),
                    features=image_tokens,
                ),
            )
        ]

166
167
168
169
170
171
172

class ChameleonLayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes over the last dim only while keeping
    affine parameters shaped like the full ``hidden_size`` tuple.

    ``hidden_size`` must be indexable (e.g. ``(num_heads, head_dim)``): the
    parent constructor allocates ``weight``/``bias`` with that full shape,
    after which ``normalized_shape`` is narrowed to ``(hidden_size[-1], )``.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        # Statistics are computed over the last dimension only; weight and
        # bias keep the shape allocated by nn.LayerNorm above.
        self.normalized_shape = (hidden_size[-1], )

        # Register the loader used when checkpoint weights are loaded.
        # NOTE(review): presumably this shards along the leading (heads)
        # dim under tensor parallelism - confirm against the loader.
        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})
        set_weight_attrs(self.bias,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states):
        """Normalize over the last dim, then apply weight and bias.

        The affine step is done by broadcasting rather than inside
        ``F.layer_norm`` because the parameters have more dims than
        ``normalized_shape``.
        """
        # Use the eps configured on the module (nn.LayerNorm default is
        # 1e-5) instead of a hard-coded literal, so a caller-supplied eps
        # is honored.
        hidden_states = F.layer_norm(hidden_states,
                                     self.normalized_shape,
                                     None,
                                     None,
                                     eps=self.eps)
        hidden_states = hidden_states * self.weight + self.bias
        return hidden_states


# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):
    """Gated MLP: merged gate/up projection, SiluAndMul, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        """Build the projections.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # Gate and up projections share one merged linear layer with two
        # equally sized outputs.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=merged_output_sizes,
            bias=bias,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection."""
        # Each linear layer returns a pair; only the first item is used.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected


# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):
    """Self-attention with per-head layer norm on q and k before RoPE."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build projections, q/k norms, rotary embedding and attention.

        Args:
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP split).
            num_kv_heads: Total number of key/value heads (before TP split).
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed through to ``get_rope``.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            quant_config: Forwarded to the linear and attention layers.
            bias: Whether the qkv / output projections carry a bias.
            cache_config: Forwarded to ``Attention``.
            prefix: Name prefix; the attention layer gets
                ``f"{prefix}.attn"``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        # The norms are sized with the per-rank head counts.
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the q/k norms per head and flatten the heads back.

        Inputs are reshaped to ``(-1, heads, head_dim)`` for the norms and
        returned with the last two dims merged again.
        """
        # reshape for layernorm
        q = q.reshape(-1, self.num_heads, self.head_dim)
        k = k.reshape(-1, self.num_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, normalize q and k, apply RoPE, attend, and
        project the result back with ``o_proj``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Normalization happens before the rotary embedding is applied.
        q, k = self._apply_qk_norm(q, k)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class ChameleonDecoderLayer(nn.Module):
    """Decoder layer that normalizes before attention and before the MLP."""

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, MLP and the two RMSNorm layers from ``config``.

        Optional config attributes fall back to defaults: ``rope_theta``
        (10000), ``rope_scaling`` (None), ``max_position_embeddings``
        (4096), ``num_key_value_heads`` (``num_attention_heads``) and
        ``mlp_bias`` (False).
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            # NOTE: this writes into the dict object held by the config.
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one layer and return ``(hidden_states, residual)``.

        The returned ``hidden_states`` is the raw MLP output; the residual
        is returned separately rather than added here.
        """
        if residual is None:
            # First layer: the input itself becomes the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # RMSNorm called with a residual returns a
            # (normalized, updated residual) pair.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


class ChameleonSwinDecoderLayer(nn.Module):
    """Decoder layer that normalizes the attention and MLP outputs before
    adding them to the residual."""

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, MLP and the two RMSNorm layers from ``config``.

        Optional config attributes fall back to defaults: ``rope_theta``
        (10000), ``rope_scaling`` (None), ``max_position_embeddings``
        (4096), ``num_key_value_heads`` (``num_attention_heads``) and
        ``mlp_bias`` (False).
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            # NOTE: this writes into the dict object held by the config.
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one layer and return ``(hidden_states, residual)``.

        The incoming ``residual`` argument is ignored (it is overwritten
        immediately). The returned ``hidden_states`` already includes both
        residual additions; the returned ``residual`` is the tensor that
        was fed into the MLP.
        """
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Norm is applied to the attention output, then the skip is added.
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, residual


469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):
    """Nearest-neighbour codebook lookup over channel vectors."""

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        # ``beta`` is optional on the config and defaults to 0.25.
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        """Quantize a 4-D input whose dim 1 has size ``embedding_dim``.

        Returns:
            A 3-tuple ``(quantized, loss, indices)``: the quantized tensor
            in the input's layout, a scalar loss, and the flat argmin
            indices into the embedding table.
        """
        # Move dim 1 last so each row of the flattened view is one vector.
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat = hidden_state.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        flat_sq = torch.sum(flat**2, dim=1, keepdim=True)
        codebook_sq = torch.sum(codebook**2, dim=1)
        cross = torch.einsum("bd,dn->bn", flat, codebook.transpose(0, 1))
        distances = flat_sq + codebook_sq - 2 * cross

        min_encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(min_encoding_indices)
        quantized = quantized.view(hidden_state.shape)

        # compute loss for embedding
        first_term = torch.mean((quantized.detach() - hidden_state)**2)
        second_term = torch.mean((quantized - hidden_state.detach())**2)
        loss = first_term + self.beta * second_term

        # preserve gradients
        delta = (quantized - hidden_state).detach()
        quantized = hidden_state + delta

        # reshape back to match original input shape
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        return quantized, loss, min_encoding_indices


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """Stride-2 3x3 convolution preceded by one-sided zero padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        # Channel count is preserved; padding is applied manually in
        # forward(), so the conv itself uses none.
        self.conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=0,
        )

    def forward(self, hidden_states: torch.Tensor):
        """Pad one zero on the right and bottom, then convolve."""
        # no asymmetric padding in torch conv, must do it ourselves
        # F.pad order for the last two dims: (left, right, top, bottom).
        one_sided_pad = (0, 1, 0, 1)
        padded = F.pad(hidden_states,
                       pad=one_sided_pad,
                       mode="constant",
                       value=0)
        return self.conv(padded)


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> x*sigmoid(x) -> 3x3 conv) stages
    plus a shortcut that is projected when the channel count changes."""

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        """Build the block.

        Args:
            config: Supplies ``dropout``.
            in_channels: Input channel count.
            out_channels: Output channel count; ``None`` means "same as
                ``in_channels``".
            conv_shortcut: When the channel count changes, use a 3x3 conv
                for the shortcut instead of a 1x1 conv.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None \
            else out_channels
        self.use_conv_shortcut = conv_shortcut

        # Always size layers from the resolved value: the raw argument may
        # be None, which would make Conv2d / GroupNorm construction fail.
        out_channels = self.out_channels

        self.norm1 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=in_channels,
                                        eps=1e-6,
                                        affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=out_channels,
                                        eps=1e-6,
                                        affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        # A shortcut projection only exists when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, hidden_states: torch.Tensor):
        """Return ``shortcut(input) + block(input)``."""
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        # In-place x * sigmoid(x) on the fresh norm output; ``residual``
        # still refers to the untouched input tensor.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """Single-head attention over spatial positions with a residual add."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        # q, k, v and the output projection are identical 1x1 convs;
        # they are created in this fixed order.
        for proj_name in ("q", "k", "v", "proj_out"):
            setattr(
                self, proj_name,
                torch.nn.Conv2d(in_channels,
                                in_channels,
                                kernel_size=1,
                                stride=1,
                                padding=0))

    def forward(self, hidden_states: torch.Tensor):
        """Return ``input + proj_out(attention(norm(input)))``."""
        skip = hidden_states
        normed = self.norm(hidden_states)
        query_states = self.q(normed)
        key_states = self.k(normed)
        value_states = self.v(normed)

        # compute attention
        batch_size, channels, height, width = query_states.shape
        num_positions = height * width
        # (batch, positions, channels) x (batch, channels, positions)
        query_states = query_states.reshape(batch_size, channels,
                                            num_positions)
        query_states = query_states.permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, num_positions)
        scale = int(channels)**(-0.5)
        scores = torch.bmm(query_states, key_states)
        scores = scores * scale
        probs = F.softmax(scores, dim=2)

        # attend to values
        value_states = value_states.reshape(batch_size, channels,
                                            num_positions)
        probs = probs.permute(0, 2, 1)
        attended = torch.bmm(value_states, probs)
        attended = attended.reshape(batch_size, channels, height, width)

        return skip + self.proj_out(attended)


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):
    """Convolutional encoder of the Chameleon VQ-VAE.

    Layout, in execution order:

    * ``conv_in``: 3x3 convolution from ``config.in_channels`` to
      ``config.base_channels``.
    * ``down``: one stage per entry of ``config.channel_multiplier``. Each
      stage holds ``config.num_res_blocks`` ResNet blocks, optionally one
      attention block per ResNet block, and (for every stage but the last)
      a downsampling layer.
    * ``mid``: ResNet block, attention (or identity), ResNet block.
    * ``norm_out`` + swish + ``conv_out``: projection to the latent channels
      (doubled when ``config.double_latent`` is set).
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels,
                                       base_channels,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # Tracks the spatial resolution of the current stage so attention can
        # be enabled only at the resolutions listed in the config.
        curr_res = resolution
        # Stage i consumes base_channels * in_channel_multiplier[i] channels,
        # i.e. the output width of the previous stage (1x for the first).
        in_channel_multiplier = (1, ) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    ))
                block_in = block_out
                if (config.attn_resolutions is not None
                        and curr_res in config.attn_resolutions
                        and config.attn_type == "vanilla"):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
            block_in) if config.attn_type == "vanilla" else nn.Identity()
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(num_groups=32,
                                           num_channels=block_in,
                                           eps=1e-6,
                                           affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        """Encode images into the pre-quantization latent feature map.

        Args:
            pixel_values: image batch; it is cast to the dtype of the
                ``conv_in`` weights before being processed.

        Returns:
            The output of ``conv_out``.
        """
        pixel_values = pixel_values.to(self.conv_in.weight.dtype)

        # downsampling
        # Only the most recent activation is ever consumed, so a single
        # running tensor is kept instead of a list of all intermediates.
        hidden_state = self.conv_in(pixel_values)
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                hidden_state = stage.block[i_block](hidden_state)
                if len(stage.attn) > 0:
                    hidden_state = stage.attn[i_block](hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_state = stage.downsample(hidden_state)

        # middle
        last_hidden_state = self.mid.block_1(hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)

        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        # swish activation: x * sigmoid(x), applied in place
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):
    """Encoder-side VQ-VAE used to turn images into discrete codes.

    Bundles the convolutional encoder, the vector quantizer and the 1x1
    convolutions that map between the latent and the codebook embedding
    widths. The module is put into eval mode on construction.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        latent_channels = config.latent_channels
        embed_dim = config.embed_dim

        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        # 1x1 convolutions: latent width <-> codebook embedding width.
        self.quant_conv = torch.nn.Conv2d(latent_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, latent_channels, 1)
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode images and quantize the resulting latents.

        Returns:
            The three values produced by the quantizer, in order:
            ``(quant, emb_loss, indices)``.
        """
        latents = self.quant_conv(self.encoder(pixel_values))
        quant, emb_loss, indices = self.quantize(latents)
        return quant, emb_loss, indices


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.

    Image tokens are the vocabulary entries whose name starts with
    ``"IMGIMG"``. The image code of such a token is obtained from its name
    by dropping the ``"IMGIMG"`` prefix and the final character, and reading
    the letters ``A``..``J`` as the digits ``0``..``9``.
    """

    def __init__(self, vocab_map: Dict[str, int]):
        # Token name -> BPE id.
        self.vocab_map = vocab_map
        # BPE id of the "<image>" placeholder, or None if it is absent.
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self) -> Dict[int, str]:
        """Inverse of ``vocab_map``: BPE id -> token name."""
        return {v: k for k, v in self.vocab_map.items()}

    @cached_property
    def image_tokens(self) -> List[int]:
        """Sorted BPE ids of all tokens whose name starts with "IMGIMG"."""
        return sorted(val for name, val in self.vocab_map.items()
                      if name.startswith("IMGIMG"))

    @cached_property
    def bpe2img(self) -> Dict[int, int]:
        """Mapping from image-token BPE id to VQGAN image code."""
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}

        def remap(old_name: str) -> str:
            # Strip the "IMGIMG" prefix and the trailing character, then
            # translate the letter-encoded digits.
            return "".join(
                img_tkn_chr_mapping.get(c, c)
                for c in old_name[len("IMGIMG"):-1])

        return {
            tok: int(remap(self.val2name[tok]))
            for tok in self.image_tokens
        }

    @cached_property
    def img2bpe(self) -> Dict[int, int]:
        """Mapping from VQGAN image code to image-token BPE id."""
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sorted BPE ids and their image codes as index-aligned tensors.

        The second tensor follows the order of the first, so position ``i``
        of both tensors describes the same token. Sorting keys and values
        independently would only agree with this when the mapping happens
        to be monotonic.
        """
        bpe_ids = sorted(self.bpe2img)
        img_ids = [self.bpe2img[bpe_id] for bpe_id in bpe_ids]
        return torch.tensor(bpe_ids), torch.tensor(img_ids)

    @cached_property
    def img2bpe_mapping_tensor(self) -> torch.Tensor:
        """Dense lookup table: ``table[image_code] == bpe_id``.

        Codes without an entry map to 0. The table is empty when the
        vocabulary contains no image tokens.
        """
        mapping = torch.zeros(max(self.img2bpe.keys(), default=-1) + 1,
                              dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Translate a tensor of image codes into BPE ids.

        The lookup is done on CPU and the result is moved back to the
        device of ``img_batch``.
        """
        device = img_batch.device
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)


class ChameleonModel(nn.Module):
    """Chameleon decoder stack together with its image tokenizer.

    Holds the token embedding table, the decoder layers created through
    ``make_layers``, the final RMSNorm, the VQ-VAE used to tokenize images
    and the mapping from VQ-VAE codes to BPE token ids.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
            config.vocabulary_map)
        # The decoder layer class is selected by the `swin_norm` config flag.
        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
            else ChameleonSwinDecoderLayer

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(config=config,
                                         cache_config=cache_config,
                                         quant_config=quant_config,
                                         prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings of ``input_ids`` in ``embed_tokens``."""
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes images into discrete tokens with VQGAN module and converts
        the obtained image tokens into BPE tokens.

        The result is reshaped to ``(batch_size, -1)``, where ``batch_size``
        is the first dimension of ``pixel_values``. No begin-of-image or
        end-of-image tokens are added here.
        """
        batch_size = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers owned by this pipeline-parallel rank.

        On the first rank the input is ``inputs_embeds`` when given,
        otherwise the embeddings of ``input_ids``. On later ranks the
        hidden states and residual are read from ``intermediate_tensors``.

        Returns:
            On the last rank, the hidden states after the final norm; on
            any other rank, an ``IntermediateTensors`` holding
            ``hidden_states`` and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                # kv_caches is indexed relative to this rank's first layer.
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


936
937
938
939
@MULTIMODAL_REGISTRY.register_processor(
    ChameleonMultiModalProcessor,
    info=ChameleonProcessingInfo,
    dummy_inputs=ChameleonDummyInputsBuilder)
940
941
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
942

943
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the language model, LM head, logits processor and sampler.

        Args:
            vllm_config: configuration bundle; the HF model config and the
                multimodal config are read from ``vllm_config.model_config``.
            prefix: name prefix under which the inner model is registered.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.model = ChameleonModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
        )
        if config.tie_word_embeddings:
            # Share the input embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
965

966
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
967
968
        vq_config: ChameleonVQVAEConfig = self.config.vq_config
        expected_dims = (3, vq_config.resolution, vq_config.resolution)
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
        """Extract and validate the ``pixel_values`` keyword argument.

        Returns:
            ``None`` when no ``pixel_values`` were passed, otherwise a
            ``ChameleonImagePixelInputs`` wrapping the validated tensor.

        Raises:
            ValueError: if ``pixel_values`` is not a tensor, or if its shape
                is rejected by ``_validate_pixel_values``.
        """
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        # Remove the N dimension until multiple images are supported.
        pixel_values = pixel_values.squeeze(1)

        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Embed the image passed through ``kwargs``, if there is one.

        The pixel values are cast to ``self.config.torch_dtype``, tokenized
        by the inner model and looked up in its embedding table.

        Returns:
            ``None`` when ``kwargs`` carries no image input, otherwise the
            embeddings of the image tokens.
        """
        parsed = self._parse_and_validate_image_input(**kwargs)
        if parsed is None:
            return None

        assert self.model.vqmodel is not None
        pixels = parsed["data"].to(self.config.torch_dtype)
        token_ids = self.model.get_image_tokens(pixels)
        return self.model.get_input_embeddings(token_ids)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in multimodal embeddings, if any.

        When ``multimodal_embeddings`` is given, it is combined with the text
        embeddings via ``merge_multimodal_embeddings`` using the id of the
        ``<image>`` placeholder token.
        """
        text_embeds = self.model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return text_embeds

        placeholder_id = self.model.vocabulary_mapping.image_token_id
        return merge_multimodal_embeddings(input_ids, text_embeds,
                                           multimodal_embeddings,
                                           placeholder_id)

1021
1022
1023
1024
1025
1026
1027
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the inner model, computing input embeddings when needed.

        * If ``intermediate_tensors`` is given, ``inputs_embeds`` is
          discarded and the inner model reads its input from the
          intermediate tensors.
        * Otherwise, if ``inputs_embeds`` is missing, it is built from
          ``input_ids`` and any image input found in ``kwargs``;
          ``input_ids`` is then cleared.

        Returns:
            Whatever the inner model returns for this pipeline rank.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)
        return hidden_states

1051
1052
1053
1054
1055
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute vocabulary logits and mask out the image tokens.

        Returns:
            The logits from the logits processor with every image-token
            column set to the minimum finite value of the logits dtype, or
            ``None`` when the logits processor returns ``None``.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        # Disallow image tokens. The masked set does not include the
        # special begin-image and end-image tokens.
        if logits is not None:
            image_tokens = self.model.vocabulary_mapping.image_tokens
            logits[:, image_tokens] = torch.finfo(logits.dtype).min

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Draw the next tokens from ``logits`` with the model's sampler."""
        return self.sampler(logits, sampling_metadata)

1075
1076
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Handling per checkpoint entry:

        * rotary-embedding buffers (``inv_freq``, ``cos_cached``,
          ``sin_cached``) are skipped;
        * ``lm_head.weight`` is skipped when word embeddings are tied;
        * ``vqmodel`` weights are loaded unsharded with the default loader;
        * ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj`` are
          loaded as shards of the fused ``qkv_proj`` / ``gate_up_proj``
          parameters;
        * ``kv_scale`` entries are remapped to ``.attn.kv_scale``;
        * everything else is loaded under its own name.

        Args:
            weights: iterable of ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names (after remapping) of the entries that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            if "vqmodel" in name:
                # We only do sharding for language model and
                # not vqvae for now.
                if self.model.vqmodel is None or name not in params_dict:
                    # Nothing to load this tensor into; do not report it as
                    # loaded.
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # Reached when no stacked shard was loaded above; the
                    # `continue` statements below skip to the next entry.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. "
                                f"{name}), but not found the expected name in "
                                f"the model (e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params