zamba2.py 34.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""PyTorch Zamba2 model implementation for vLLM.

This module implements the Zamba2 architecture from
https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer
architectures in a hybrid model optimized for efficient sequence modeling. The
model alternates between state space model layers and attention-based layers.
"""
10

11
from collections.abc import Iterable
12
from itertools import cycle
13
from typing import Any
14
15
16
17
18
19

import torch
from torch import nn
from transformers import Zamba2Config

from vllm.attention.layer import Attention
20
from vllm.compilation.decorators import support_torch_compile
21
from vllm.config import CacheConfig, ModelConfig, VllmConfig
22
from vllm.distributed import get_tensor_model_parallel_world_size
23
24
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
25
26
27
28
29
30
31
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
34
from vllm.model_executor.layers.mamba.mamba_utils import (
35
36
37
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
38
39
40
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
42
43
    ParallelLMHead,
    VocabParallelEmbedding,
)
44
45
46
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

47
from .interfaces import HasInnerState, IsHybrid, SupportsMambaPrefixCaching
48
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
49
50
51
52


class Zamba2LoRA(nn.Module):
    """LoRA layer for the Zamba2 model.

    A low-rank adapter used in the shared attention and gated MLP blocks: the
    input is projected down to ``rank`` features by ``A`` and back up to
    ``output_dim`` features by ``B``.
    """

    def __init__(
        self,
        input_dim: int,
        rank: int,
        output_dim: int | list[int],
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Initialize the LoRA layer.

        Args:
            input_dim: Input dimension.
            rank: LoRA rank (width of the low-rank bottleneck).
            output_dim: Output dimension. When a list of sizes is given, ``B``
                is built as a ``MergedColumnParallelLinear`` over those sizes
                (used for the fused gate/up adapter).
            quant_config: Configuration for model quantization.
            prefix: Prefix for parameter names; the two projections are
                created with prefixes ``{prefix}.A`` and ``{prefix}.B``.
        """
        super().__init__()

        # gather_output=True so B receives the full rank-sized activation
        # rather than a per-rank slice of it.
        self.A = ColumnParallelLinear(
            input_dim,
            rank,
            bias=False,
            quant_config=quant_config,
            gather_output=True,
            prefix=f"{prefix}.A",
        )

        B_class = (
            MergedColumnParallelLinear
            if isinstance(output_dim, list)
            else ColumnParallelLinear
        )
        # Pass a prefix to B as well (it was previously unnamed), so
        # quantization configs that select layers by name see it consistently
        # with A.
        self.B = B_class(
            rank,
            output_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.B",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the adapter, i.e. compute ``B(A(hidden_states))``.

        Args:
            hidden_states: Input tensor whose last dimension is ``input_dim``.

        Returns:
            The adapter output. B does not gather its output, so presumably
            this is the current tensor-parallel rank's shard of
            ``output_dim`` — TODO confirm against the linear-layer contract.
        """
        # The parallel linear layers return (output, bias); bias is unused.
        lora_output, _ = self.A(hidden_states)
        lora_output, _ = self.B(lora_output)
        return lora_output


class Zamba2Attention(nn.Module):
    """Multi-head attention mechanism for the Zamba2 model.

    Implements attention with tensor-parallel QKV and output projections,
    optional per-layer LoRA adapters on Q, K and V, and optional rotary
    position embeddings. The projection weights are shared by every hybrid
    layer that uses this block, but each such layer gets its own ``Attention``
    object (and therefore its own KV cache).
    """

    def __init__(
        self,
        config: Zamba2Config,
        bare_block_idx: int,
        num_hybrid_layers: int,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Initialize the attention layer.

        Args:
            config: The Zamba2 model configuration.
            bare_block_idx: Index of this shared block among the
                ``config.num_mem_blocks`` blocks. Hybrid layer ``i`` is served
                by this block iff ``i % config.num_mem_blocks ==
                bare_block_idx``.
            num_hybrid_layers: Total number of hybrid layers.
            cache_config: Configuration for key-value caching.
            quant_config: Configuration for model quantization.
            prefix: Optional prefix for parameter names.

        Raises:
            ValueError: If ``attention_head_dim * num_attention_heads`` does
                not equal ``attention_hidden_size``.
        """
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.num_hybrid_layers = num_hybrid_layers

        self.attention_hidden_size = config.attention_hidden_size
        self.total_num_attention_heads = config.num_attention_heads
        assert self.total_num_attention_heads % tp_size == 0
        # Heads handled by this tensor-parallel rank.
        self.num_attention_heads = config.num_attention_heads // tp_size
        self.attention_head_dim = config.attention_head_dim
        # Per-rank width of each of the Q, K and V slices of the fused output.
        self.qkv_size = self.attention_hidden_size // tp_size
        # NOTE(review): the softmax scale uses head_dim / 2 instead of
        # head_dim; presumably this matches the reference Zamba2
        # implementation — confirm before changing.
        self.scale = (self.attention_head_dim / 2) ** -0.5

        if (
            self.attention_head_dim * self.total_num_attention_heads
            != self.attention_hidden_size
        ):
            # Report the configured (total) head count, not the per-rank one.
            raise ValueError(
                f"attention_hidden_size must equal attention_head_dim *"
                f" num_attention_heads"
                f" (got `attention_hidden_size`: {self.attention_hidden_size},"
                f" `attention_head_dim`: {self.attention_head_dim}"
                f" and `num_heads`: {self.total_num_attention_heads})."
            )

        self.qkv_proj = QKVParallelLinear(
            self.attention_hidden_size,
            self.attention_head_dim,
            self.total_num_attention_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.attention_hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Even though in Zamba2 weights are shared between attention layers, KV
        # cache is unique for every attention layer. Hence, we need to define
        # separate Attention objects, because in recent vLLM KV cache tensors
        # are tied to specific Attention objects.

        # Initialize attention blocks with proper indexing. `j` is the index
        # used in the name of the first Attention object owned by this block;
        # the offset is derived from bare_block_idx so that different bare
        # blocks use disjoint ranges of j.
        self.dpa_list = nn.ModuleList([])
        j = (
            bare_block_idx
            * (self.num_hybrid_layers + config.num_mem_blocks - 1)
            // config.num_mem_blocks
        )
        for block_idx in range(self.num_hybrid_layers):
            if block_idx % config.num_mem_blocks == bare_block_idx:
                dpa = Attention(
                    self.num_attention_heads,
                    self.attention_head_dim,
                    self.scale,
                    cache_config=cache_config,
                    prefix=f"{prefix}.attn.{j}",
                )
                j += 1
            else:
                # Placeholder keeps dpa_list indexable by hybrid block index.
                dpa = nn.Identity()
            self.dpa_list.append(dpa)

        # Initialize adapter layers if enabled (one Q/K/V triple per hybrid
        # layer served by this block, Identity placeholders elsewhere).
        if config.use_shared_attention_adapter:
            self.linear_q_adapter_list = nn.ModuleList([])
            self.linear_k_adapter_list = nn.ModuleList([])
            self.linear_v_adapter_list = nn.ModuleList([])

            for block_idx in range(self.num_hybrid_layers):
                if block_idx % config.num_mem_blocks == bare_block_idx:
                    linear_q_adapter, linear_k_adapter, linear_v_adapter = (
                        Zamba2LoRA(
                            self.attention_hidden_size,
                            config.adapter_rank,
                            self.attention_hidden_size,
                            quant_config=quant_config,
                            prefix=f"{prefix}.{adapter_name}",
                        )
                        for adapter_name in (
                            "linear_q_adapter",
                            "linear_k_adapter",
                            "linear_v_adapter",
                        )
                    )
                else:
                    linear_q_adapter = nn.Identity()
                    linear_k_adapter = nn.Identity()
                    linear_v_adapter = nn.Identity()

                self.linear_q_adapter_list.append(linear_q_adapter)
                self.linear_k_adapter_list.append(linear_k_adapter)
                self.linear_v_adapter_list.append(linear_v_adapter)

        if config.use_mem_rope:
            self.rotary_emb = get_rope(
                head_size=self.attention_head_dim,
                rotary_dim=self.attention_head_dim,
                max_position=config.max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=True,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        block_idx: int,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass through the attention layer.

        Args:
            hidden_states: Input tensor whose last dimension is
                ``attention_hidden_size`` (the input size of ``qkv_proj``).
            block_idx: Hybrid block index; selects the Attention object (and
                adapters) belonging to that hybrid layer.
            position_ids: Position IDs for rotary embeddings (used only when
                ``config.use_mem_rope`` is set).

        Returns:
            Output tensor whose last dimension is ``config.hidden_size``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv.split([self.qkv_size] * 3, dim=-1)

        if self.config.use_shared_attention_adapter:
            # Apply adapter transformations to Q, K, V if enabled. The
            # adapters read the same input as qkv_proj.
            q_adapter = self.linear_q_adapter_list[block_idx]
            assert not isinstance(q_adapter, nn.Identity)
            query_states = query_states + q_adapter(hidden_states)

            k_adapter = self.linear_k_adapter_list[block_idx]
            assert not isinstance(k_adapter, nn.Identity)
            key_states = key_states + k_adapter(hidden_states)

            v_adapter = self.linear_v_adapter_list[block_idx]
            assert not isinstance(v_adapter, nn.Identity)
            value_states = value_states + v_adapter(hidden_states)

        if self.config.use_mem_rope:
            query_states, key_states = self.rotary_emb(
                position_ids, query_states, key_states
            )

        y = self.dpa_list[block_idx](query_states, key_states, value_states)
        y, _ = self.o_proj(y)
        return y


class Zamba2MLP(nn.Module):
    """Feed-forward MLP layer for the Zamba2 model.

    A gated feed-forward network: the input is projected to two
    ``intermediate_size`` halves (gate and up), a per-hybrid-layer LoRA
    adapter output is added, GELU gating is applied, and the result is
    projected back to ``hidden_size``.
    """

    def __init__(
        self,
        config: Zamba2Config,
        bare_block_idx: int,
        num_hybrid_layers: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Initialize the MLP layer.

        Args:
            config: The Zamba2 model configuration.
            bare_block_idx: Index of this shared block among the
                ``config.num_mem_blocks`` blocks. An adapter is created for
                hybrid layer ``i`` iff ``i % config.num_mem_blocks ==
                bare_block_idx``.
            num_hybrid_layers: Total number of hybrid layers.
            quant_config: Configuration for model quantization.
            prefix: Optional prefix for parameter names.

        Raises:
            ValueError: If ``config.hidden_act`` is not ``"gelu"``.
        """
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_hybrid_layers = num_hybrid_layers
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Main projection layers with gating
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            2 * [self.intermediate_size],  # 2x for gate and input projections
            bias=self.config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=self.config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

        # Only allow GELU activations
        if config.hidden_act != "gelu":
            raise ValueError(
                f"Only GELU activation is supported "
                f"(got `hidden_act`: {config.hidden_act})"
            )
        self.act_fn = GeluAndMul()

        # One gate/up adapter per hybrid layer served by this block; Identity
        # placeholders keep the list indexable by hybrid block index. Adapters
        # get a named prefix, matching how the attention adapters are built.
        self.gate_up_proj_adapter_list = nn.ModuleList([])
        for block_idx in range(self.num_hybrid_layers):
            if block_idx % config.num_mem_blocks == bare_block_idx:
                gate_up_proj_adapter = Zamba2LoRA(
                    config.hidden_size,
                    config.adapter_rank,
                    2 * [self.intermediate_size],
                    quant_config=quant_config,
                    prefix=f"{prefix}.gate_up_proj_adapter",
                )
            else:
                gate_up_proj_adapter = nn.Identity()
            self.gate_up_proj_adapter_list.append(gate_up_proj_adapter)

    def forward(self, hidden_states: torch.Tensor, block_idx: int) -> torch.Tensor:
        """Forward pass through the MLP layer.

        Args:
            hidden_states: Input tensor whose last dimension is
                ``hidden_size``.
            block_idx: Hybrid block index; selects the adapter belonging to
                that hybrid layer.

        Returns:
            Output tensor whose last dimension is ``hidden_size``.
        """
        # Project input to intermediate size with gating
        gate_up_states, _ = self.gate_up_proj(hidden_states)

        # Add the adapter output; the selected entry must be a real adapter
        # for any hybrid layer this block serves.
        adapter = self.gate_up_proj_adapter_list[block_idx]
        assert not isinstance(adapter, nn.Identity)
        gate_up_states = gate_up_states + adapter(hidden_states)

        # Apply GELU activation with gating
        hidden_states = self.act_fn(gate_up_states)

        # Project back to hidden size
        output, _ = self.down_proj(hidden_states)
        return output


class Zamba2AttentionDecoderLayer(nn.Module):
    """Shared transformer block used by the Zamba2 hybrid layers.

    The data flow is::

        concat(hidden_states, original_hidden_states)
            -> input_layernorm -> self-attention
            -> pre_ff_layernorm -> feed-forward MLP

    There are no residual connections inside this block; the caller combines
    its output with the Mamba pathway.
    """

    def __init__(
        self,
        config: Zamba2Config,
        bare_block_idx: int,
        num_hybrid_layers: int,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and normalization sublayers.

        Args:
            config: The Zamba2 model configuration.
            bare_block_idx: Index of this shared block among
                ``config.num_mem_blocks`` blocks.
            num_hybrid_layers: Total number of hybrid layers.
            cache_config: Configuration for key-value caching.
            quant_config: Configuration for model quantization.
            prefix: Optional prefix for parameter names.
        """
        super().__init__()

        shared_kwargs = dict(
            bare_block_idx=bare_block_idx,
            num_hybrid_layers=num_hybrid_layers,
            quant_config=quant_config,
        )
        self.self_attn = Zamba2Attention(
            config,
            cache_config=cache_config,
            prefix=prefix,
            **shared_kwargs,
        )
        self.feed_forward = Zamba2MLP(
            config,
            prefix=f"{prefix}.feed_forward",
            **shared_kwargs,
        )

        eps = config.rms_norm_eps
        # The attention input is the feature-wise concatenation of two
        # hidden_size tensors, hence the doubled width here.
        self.input_layernorm = RMSNorm(2 * config.hidden_size, eps=eps)
        # Normalizes the attention output (hidden_size wide).
        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        original_hidden_states: torch.Tensor,
        block_idx: int,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Run the shared transformer block for one hybrid layer.

        Args:
            hidden_states: Output of the previous layer.
            original_hidden_states: The model's input embeddings, concatenated
                with ``hidden_states`` before the first normalization.
            block_idx: Hybrid block index, forwarded to the attention and MLP
                sublayers to pick per-layer KV cache and adapters.
            positions: Position IDs for positional embeddings.

        Returns:
            The MLP output for this block.
        """
        # Concatenate along the feature axis before the pre-attention RMSNorm
        # (see fig. 2 in https://arxiv.org/pdf/2405.16712).
        combined = torch.cat((hidden_states, original_hidden_states), dim=-1)
        attn_input = self.input_layernorm(combined)
        attn_output = self.self_attn(
            attn_input,
            position_ids=positions,
            block_idx=block_idx,
        )
        ff_input = self.pre_ff_layernorm(attn_output)
        return self.feed_forward(ff_input, block_idx=block_idx)


class Zamba2MambaDecoderLayer(nn.Module):
    """Pre-norm Mamba2 layer with a residual connection.

    Computes ``x + mamba(rmsnorm(x + t))``, where ``x`` is the layer input
    and ``t`` is an optional tensor injected by the shared transformer
    pathway (when ``t`` is not given, the mixer input is just ``x``).
    """

    def __init__(
        self,
        config: Zamba2Config,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the Mamba mixer and its input normalization.

        Args:
            config: The Zamba2 model configuration.
            model_config: vLLM model configuration, forwarded to the mixer.
            cache_config: Cache configuration, forwarded to the mixer.
            quant_config: Configuration for model quantization.
            prefix: Prefix for parameter names; the mixer is created with
                prefix ``{prefix}.mixer``.
        """
        super().__init__()

        # The mixer's inner width is the hidden size scaled by mamba_expand,
        # divided evenly among the Mamba heads.
        expanded_size = config.mamba_expand * config.hidden_size
        per_head_dim = expanded_size // config.n_mamba_heads

        self.mamba = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.mamba_d_state,
            conv_kernel_size=config.mamba_d_conv,
            intermediate_size=expanded_size,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.add_bias_linear,
            n_groups=config.mamba_ngroups,
            num_heads=config.n_mamba_heads,
            head_dim=per_head_dim,
            rms_norm_eps=config.rms_norm_eps,
            activation="silu",
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        transformer_hidden_states: torch.Tensor | None = None,
        positions: torch.Tensor | None = None,
        original_hidden_states: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the Mamba mixer with pre-normalization and a residual.

        Args:
            hidden_states: Layer input; also used as the residual.
            transformer_hidden_states: Optional projected output of the shared
                transformer, added to the mixer input (not to the residual).
            positions: Accepted so every layer can be called with the same
                keyword arguments; not used here.
            original_hidden_states: Accepted for the same reason; not used
                here.

        Returns:
            ``hidden_states`` plus the mixer output.
        """
        # Inject the transformer output into the Mamba input (eq. (6) and
        # fig. 2 of https://arxiv.org/pdf/2405.16712); the residual keeps the
        # un-injected input.
        if transformer_hidden_states is None:
            mixer_input = hidden_states
        else:
            mixer_input = hidden_states + transformer_hidden_states

        mixer_output = self.mamba(self.input_layernorm(mixer_input))
        return hidden_states + mixer_output


class Zamba2HybridLayer(nn.Module):
    """Hybrid layer: a shared transformer block feeding a Mamba layer.

    The shared transformer runs on the layer input; its output is projected by
    a replicated linear layer and injected into the input of this layer's own
    Mamba decoder, whose result is the layer output.
    """

    def __init__(
        self,
        shared_transformer: Zamba2AttentionDecoderLayer,
        config: Zamba2Config,
        block_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Wire up the shared transformer, projection and Mamba decoder.

        Args:
            shared_transformer: Transformer block for the attention pathway.
                The same instance may be passed to several hybrid layers.
            config: The Zamba2 model configuration.
            block_idx: Hybrid block index of this layer, forwarded to the
                shared transformer on every call.
            model_config: vLLM model configuration for the Mamba decoder.
            cache_config: Cache configuration for the Mamba decoder.
            quant_config: Configuration for model quantization.
            prefix: Prefix for parameter names.
        """
        super().__init__()
        self.block_idx = block_idx
        self.shared_transformer = shared_transformer

        width = config.hidden_size
        self.linear = ReplicatedLinear(
            width,
            width,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.linear",
        )
        self.mamba_decoder = Zamba2MambaDecoderLayer(
            config,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        original_hidden_states: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Run the transformer pathway, then the Mamba pathway.

        Args:
            hidden_states: Output of the previous layer.
            original_hidden_states: The model's input embeddings, consumed by
                the shared transformer.
            positions: Position IDs for positional embeddings.

        Returns:
            Output of the Mamba decoder after the transformer injection.
        """
        shared_output = self.shared_transformer(
            hidden_states,
            original_hidden_states=original_hidden_states,
            block_idx=self.block_idx,
            positions=positions,
        )
        # ReplicatedLinear returns (output, bias); only the output is used.
        injected, _ = self.linear(shared_output)
        return self.mamba_decoder(
            hidden_states,
            transformer_hidden_states=injected,
        )


661
@support_torch_compile
class Zamba2Model(nn.Module):
    """Core Zamba2 model combining transformer and Mamba architectures.

    Token embeddings pass through a stack of layers whose types come from
    ``config.layers_block_type``: ``"hybrid"`` entries become
    ``Zamba2HybridLayer`` (a shared transformer block feeding a Mamba layer),
    every other entry becomes a plain ``Zamba2MambaDecoderLayer``. A final
    RMSNorm is applied to the last layer's output.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Initialize the Zamba2 model.

        Args:
            vllm_config: Configuration object containing model, cache,
                quantization and LoRA settings.
            prefix: Optional prefix for parameter names in state dict.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)
        assert not is_lora_enabled, "LoRA is not supported for Zamba2 models."

        self.config = config
        self.vocab_size = config.vocab_size

        # Initialize token embeddings
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        # Map hybrid layer indices (positions in layers_block_type) to hybrid
        # block indices 0..num_hybrid_layers-1.
        layer2block_map = {
            layer_idx: block_idx
            for block_idx, layer_idx in enumerate(config.hybrid_layer_ids)
        }
        num_hybrid_layers = len(layer2block_map)

        # num_mem_blocks shared transformer blocks, handed out round-robin to
        # the hybrid layers below.
        blocks = cycle(
            [
                Zamba2AttentionDecoderLayer(
                    config,
                    bare_block_idx=idx,
                    num_hybrid_layers=num_hybrid_layers,
                    cache_config=cache_config,
                    quant_config=quant_config,
                    prefix=prefix,
                )
                for idx in range(config.num_mem_blocks)
            ]
        )

        # Initialize layers according to block type configuration
        layers = []
        for layer_idx, layer_type in enumerate(config.layers_block_type):
            # tdoublep: avoid layers getting same index
            # somewhat hacky but correct (I think)
            # Offsetting by num_hybrid_layers presumably keeps these indices
            # distinct from the `attn.{j}` indices of the shared attention
            # layers. Use a separate name so the `prefix` argument is not
            # shadowed.
            layer_prefix = str(num_hybrid_layers + layer_idx)
            if layer_type == "hybrid":
                layers.append(
                    Zamba2HybridLayer(
                        next(blocks),
                        config,
                        layer2block_map[layer_idx],
                        model_config=model_config,
                        cache_config=cache_config,
                        quant_config=quant_config,
                        prefix=layer_prefix,
                    )
                )
            else:
                layers.append(
                    Zamba2MambaDecoderLayer(
                        config,
                        model_config=model_config,
                        cache_config=cache_config,
                        quant_config=quant_config,
                        prefix=layer_prefix,
                    )
                )
        self.layers = nn.ModuleList(layers)

        # Final layer normalization
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Convert input token IDs to embeddings.

        Args:
            input_ids: Tensor of input token IDs.

        Returns:
            Embedded representation of the input tokens.
        """
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Forward pass through the model.

        Args:
            input_ids: Input token IDs (ignored when ``inputs_embeds`` is
                given).
            positions: Position IDs for embeddings.
            inputs_embeds: Optional pre-computed input embeddings.

        Returns:
            The final-normalized hidden states. This implementation does not
            handle pipeline parallelism and never returns
            ``IntermediateTensors``.
        """
        # Embed the tokens unless embeddings were supplied.
        if inputs_embeds is None:
            inputs_embeds = self.embed_input_ids(input_ids)
        hidden_states = inputs_embeds

        # Every hybrid layer re-reads the original embeddings. Keep an
        # independent copy (presumably to guard against in-place updates
        # downstream — confirm before removing the clone).
        original_hidden_states = torch.clone(hidden_states)
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                original_hidden_states=original_hidden_states,
                positions=positions,
            )

        return self.final_layernorm(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint tensors are
        loaded as shards of the fused ``qkv_proj`` parameter. Checkpoint
        tensors with no matching parameter are skipped, both for renamed
        q/k/v names (previously a ``KeyError``) and for all other names.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (as in ``named_parameters()``) that
            received a tensor.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for chkpt_weight_name, loaded_weight in weights:
            # Resolve the target parameter name, folding q/k/v into qkv_proj.
            name = chkpt_weight_name
            shard_id = None
            for param_name, weight_name, shard in stacked_params_mapping:
                if weight_name in chkpt_weight_name:
                    name = chkpt_weight_name.replace(weight_name, param_name)
                    shard_id = shard
                    break

            if name not in params_dict:
                continue
            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters always carry a shard-aware weight_loader.
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

827

828
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixCaching):
    """Zamba2 model with causal language modeling head.

    This class wraps the core Zamba2 model and adds:
    - A language modeling head, tied to the input embeddings, for next-token
      prediction
    - Class-level hooks reporting the dtypes and shapes of the Mamba state
      caches
    - A logits processor that maps final hidden states to vocabulary logits
    """

    # To ensure correct weight loading and mapping: substring renames applied
    # to checkpoint parameter names in ``load_weights`` so that they match
    # this model's parameter names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "A_log": "A",
            "0.weight": "A.weight",
            "1.weight": "B.weight",
        }
    )

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the dtypes of the Mamba2 state caches.

        The dtypes are derived by ``MambaStateDtypeCalculator.mamba2_state_dtype``
        from the model dtype and the cache config's ``mamba_cache_dtype`` and
        ``mamba_ssm_cache_dtype`` settings.

        Args:
            vllm_config: vLLM config

        Returns:
            Pair of dtypes for the Mamba2 state caches, in the order produced
            by ``MambaStateDtypeCalculator.mamba2_state_dtype``.
        """
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        # Expanded Mamba width: mamba_expand * hidden_size.
        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.mamba_ngroups,
            num_heads=hf_config.n_mamba_heads,
            head_dim=hf_config.mamba_headdim,
            state_size=hf_config.mamba_d_state,
            conv_kernel=hf_config.mamba_d_conv,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Initialize the Zamba2 model for causal language modeling.

        Args:
            vllm_config: Configuration containing model, cache, quantization,
                LoRA and scheduler settings
            prefix: Optional prefix for parameter names
        """
        config = vllm_config.model_config.hf_config
        scheduler_config = vllm_config.scheduler_config

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.scheduler_config = scheduler_config
        self.model_config = vllm_config.model_config

        # Initialize core model
        self.model = Zamba2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        # Initialize language modeling head
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # Always tie the LM head to the input embedding table (embed_tokens);
        # there is no config check here.
        self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

        # Initialize logits processing
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Convert input token IDs to embeddings.

        Args:
            input_ids: Tensor of input token IDs

        Returns:
            Embedded representation of the input tokens
        """
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Forward pass through the model.

        Args:
            input_ids: Input token IDs
            positions: Position IDs for embeddings
            inputs_embeds: Optional pre-computed input embeddings
            **kwargs: Accepted for interface compatibility; not used here

        Returns:
            Output hidden states from the core Zamba2 model
        """
        # Forward pass through model
        hidden_states = self.model(
            input_ids,
            positions,
            inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits for next token prediction.

        Args:
            hidden_states: Hidden states from model forward pass

        Returns:
            Logits for next token prediction
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming them with ``hf_to_vllm_mapper``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names reported as loaded by
            ``AutoWeightsLoader``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)