zamba2.py 40.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
"""PyTorch Zamba2 model implementation for vLLM.

This module implements the Zamba2 architecture from 
https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer 
architectures in a hybrid model optimized for efficient sequence modeling. The 
model alternates between state space model layers and attention-based layers.
"""
10
from collections.abc import Iterable
11
from itertools import cycle
12
from typing import Optional, Union
13
14
15
16
17

import torch
from torch import nn
from transformers import Zamba2Config

18
from vllm import envs
19
20
21
22
23
24
25
26
27
28
29
30
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
32
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    Mamba2Metadata, prepare_mamba2_metadata)
33
34
35
36
37
38
39
40
41
42
43
44
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
    MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

45
from .interfaces import HasInnerState, IsHybrid
46
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
47
48
49
50
51
52
53
54
55
56
57
58
59


class Zamba2LoRA(nn.Module):
    """LoRA layer for the Zamba2 model.

    Implements a LoRA layer that is used in shared attention and gated MLP
    blocks. The input is projected down to ``rank`` features by ``A`` and
    back up to ``output_dim`` by ``B``. The result is returned unscaled;
    callers add it to the output of the corresponding base projection.
    """

    def __init__(
        self,
        input_dim: int,
        rank: int,
        output_dim: Union[int, list[int]],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Initialize the LoRA layer.

        Args:
            input_dim: input dimension
            rank: LoRA rank (width of the low-rank bottleneck)
            output_dim: output dimension. A list of sizes builds a merged
                projection (e.g. mirroring a fused gate/up projection).
            quant_config: Configuration for model quantization
            prefix: Prefix for parameter names. It is propagated to the
                ``A`` and ``B`` projections so that quantization configs can
                identify them. Previously this argument was ignored.
        """
        super().__init__()

        # A gathers its output across tensor-parallel ranks, so every rank
        # feeds the full rank-sized bottleneck into the column-parallel B.
        self.A = ColumnParallelLinear(input_dim,
                                      rank,
                                      bias=False,
                                      quant_config=quant_config,
                                      gather_output=True,
                                      prefix=maybe_prefix(prefix, "A"))

        # A list of output sizes means the adapter mirrors a merged
        # projection, so B must use the merged layer type as well.
        if isinstance(output_dim, list):
            B_class = MergedColumnParallelLinear
        else:
            B_class = ColumnParallelLinear
        self.B = B_class(rank,
                         output_dim,
                         bias=False,
                         quant_config=quant_config,
                         prefix=maybe_prefix(prefix, "B"))

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the low-rank update ``B(A(hidden_states))``.

        Args:
            hidden_states: Input tensor whose last dimension is ``input_dim``.

        Returns:
            The adapter output. Because ``B`` does not gather, the output is
            this rank's shard.
        """
        lora_output, _ = self.A(hidden_states)
        lora_output, _ = self.B(lora_output)
        return lora_output


class Zamba2Attention(nn.Module):
    """Multi-head attention mechanism for the Zamba2 model.

    Implements attention with tensor-parallel QKV and output projections,
    optional per-layer adapters on Q/K/V, and optional rotary position
    embeddings. The projection weights belong to one shared ("bare") block
    that is reused by several hybrid layers. Each of those layers gets its
    own ``Attention`` object and, if enabled, its own adapters.
    """

    def __init__(
        self,
        config: Zamba2Config,
        bare_block_idx: int,
        num_hybrid_layers: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Initialize the attention layer.

        Args:
            config: The Zamba2 model configuration
            bare_block_idx: Index of the bare attention block. This block
                serves the hybrid layers ``i`` for which
                ``i % config.num_mem_blocks == bare_block_idx``.
            num_hybrid_layers: Total number of hybrid layers
            cache_config: Configuration for key-value caching
            quant_config: Configuration for model quantization
            prefix: Optional prefix, used to name the per-layer ``Attention``
                objects

        Raises:
            ValueError: If ``attention_head_dim * num_attention_heads`` does
                not equal ``attention_hidden_size``.
        """
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.config = config
        self.num_hybrid_layers = num_hybrid_layers
        self.rope_theta = config.rope_theta

        self.attention_hidden_size = config.attention_hidden_size
        self.total_num_attention_heads = config.num_attention_heads
        # Heads are split evenly across tensor-parallel ranks.
        assert self.total_num_attention_heads % tp_size == 0
        self.num_attention_heads = config.num_attention_heads // tp_size
        self.attention_head_dim = config.attention_head_dim
        # Per-rank width of each of the Q, K and V slices of the fused
        # qkv_proj output (see the split in forward()).
        self.qkv_size = self.attention_hidden_size // tp_size
        # NOTE(review): this scales by (head_dim / 2) ** -0.5 rather than the
        # usual head_dim ** -0.5. It is presumably intentional, to match the
        # reference Zamba2 implementation; confirm before changing.
        self.scale = (self.attention_head_dim / 2)**-0.5

        if (self.attention_head_dim *
                self.total_num_attention_heads) != self.attention_hidden_size:
            # Report the global head count. The per-rank count would be
            # misleading under tensor parallelism.
            raise ValueError(
                f"attention_hidden_size must equal"
                f" attention_head_dim * num_attention_heads"
                f" (got `attention_hidden_size`: {self.attention_hidden_size},"
                f" `attention_head_dim`: {self.attention_head_dim}"
                f" and `num_heads`: {self.total_num_attention_heads}).")

        self.qkv_proj = QKVParallelLinear(
            self.attention_hidden_size,
            self.attention_head_dim,
            self.total_num_attention_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(self.attention_hidden_size,
                                        config.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)

        # Even though in Zamba2 weights are shared between attention layers, KV
        # cache is unique for every attention layer. Hence, we need to define
        # separate Attention objects, because in recent vLLM KV cache tensors
        # are tied to specific Attention objects.

        # Initialize attention blocks with proper indexing.
        # `j` numbers the Attention objects owned by this bare block. The
        # starting offset places consecutive bare blocks at least
        # ceil(num_hybrid_layers / num_mem_blocks) apart, which is enough to
        # keep the names from colliding across bare blocks. The names are not
        # necessarily contiguous when num_mem_blocks > 2.
        self.dpa_list = nn.ModuleList([])
        j = bare_block_idx * (self.num_hybrid_layers + config.num_mem_blocks -
                              1) // config.num_mem_blocks
        for block_idx in range(self.num_hybrid_layers):
            if block_idx % config.num_mem_blocks == bare_block_idx:
                dpa = Attention(
                    self.num_attention_heads,
                    self.attention_head_dim,
                    self.scale,
                    cache_config=cache_config,
                    prefix=f"{prefix}.attn.{j}",
                )
                j += 1
            else:
                # Slot owned by another bare block; never called from here.
                dpa = nn.Identity()
            self.dpa_list.append(dpa)

        # Initialize adapter layers if enabled
        if config.use_shared_attention_adapter:
            self.linear_q_adapter_list = nn.ModuleList([])
            self.linear_k_adapter_list = nn.ModuleList([])
            self.linear_v_adapter_list = nn.ModuleList([])

            for block_idx in range(self.num_hybrid_layers):
                if block_idx % config.num_mem_blocks == bare_block_idx:
                    linear_q_adapter = Zamba2LoRA(
                        self.attention_hidden_size,
                        config.adapter_rank,
                        self.attention_hidden_size,
                        quant_config=quant_config,
                    )
                    linear_k_adapter = Zamba2LoRA(
                        self.attention_hidden_size,
                        config.adapter_rank,
                        self.attention_hidden_size,
                        quant_config=quant_config,
                    )
                    linear_v_adapter = Zamba2LoRA(
                        self.attention_hidden_size,
                        config.adapter_rank,
                        self.attention_hidden_size,
                        quant_config=quant_config,
                    )
                else:
                    linear_q_adapter = nn.Identity()
                    linear_k_adapter = nn.Identity()
                    linear_v_adapter = nn.Identity()

                self.linear_q_adapter_list.append(linear_q_adapter)
                self.linear_k_adapter_list.append(linear_k_adapter)
                self.linear_v_adapter_list.append(linear_v_adapter)

        if config.use_mem_rope:
            self.rotary_emb = get_rope(
                head_size=self.attention_head_dim,
                rotary_dim=self.attention_head_dim,
                max_position=config.max_position_embeddings,
                base=self.rope_theta,
                rope_scaling=None,
                is_neox_style=True,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        block_idx: int,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass through the attention layer.

        Args:
            hidden_states: Input tensor whose last dimension is
                ``attention_hidden_size`` (the input size of ``qkv_proj``)
            block_idx: Index of the hybrid layer being computed. It selects
                that layer's ``Attention`` object and adapters, and must be
                served by this bare block.
            position_ids: Position IDs for rotary embeddings (used only when
                ``config.use_mem_rope`` is set)

        Returns:
            Output tensor whose last dimension is ``config.hidden_size``
        """
        qkv, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv.split([self.qkv_size] * 3,
                                                           dim=-1)

        if self.config.use_shared_attention_adapter:
            # Apply adapter transformations to Q, K, V if enabled
            q_adapter = self.linear_q_adapter_list[block_idx]
            assert not isinstance(q_adapter, nn.Identity)
            q_lora_output = q_adapter(hidden_states)
            query_states = query_states + q_lora_output

            k_adapter = self.linear_k_adapter_list[block_idx]
            assert not isinstance(k_adapter, nn.Identity)
            k_lora_output = k_adapter(hidden_states)
            key_states = key_states + k_lora_output

            v_adapter = self.linear_v_adapter_list[block_idx]
            assert not isinstance(v_adapter, nn.Identity)
            v_lora_output = v_adapter(hidden_states)
            value_states = value_states + v_lora_output

        if self.config.use_mem_rope:
            query_states, key_states = self.rotary_emb(position_ids,
                                                       query_states,
                                                       key_states)

        y = self.dpa_list[block_idx](query_states, key_states, value_states)
        y, _ = self.o_proj(y)
        return y


class Zamba2MLP(nn.Module):
    """Feed-forward MLP layer for the Zamba2 model.

    Implements a gated feed-forward network with four steps:
      1. Project the input to a larger intermediate size.
      2. Apply GELU activation with gating.
      3. Project back to the original size.
      4. Add, for each hybrid layer this shared block serves, that layer's own
         adapter output to the gate/up projection.
    """

    def __init__(
        self,
        config: Zamba2Config,
        bare_block_idx: int,
        num_hybrid_layers: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Initialize the MLP layer.

        Args:
            config: The Zamba2 model configuration
            bare_block_idx: Index of the bare block in the model. Adapters
                are created only for hybrid layers ``i`` with
                ``i % config.num_mem_blocks == bare_block_idx``.
            num_hybrid_layers: Total number of hybrid layers
            quant_config: Configuration for model quantization
            prefix: Optional prefix for parameter names (currently unused)

        Raises:
            ValueError: If ``config.hidden_act`` is not ``"gelu"``.
        """
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_hybrid_layers = num_hybrid_layers
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Main projection layers with gating
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            2 * [self.intermediate_size],  # 2x for gate and input projections
            bias=self.config.add_bias_linear,
            quant_config=quant_config)

        self.down_proj = RowParallelLinear(self.intermediate_size,
                                           self.hidden_size,
                                           bias=self.config.add_bias_linear,
                                           quant_config=quant_config)

        # Only allow GELU activations
        if config.hidden_act != "gelu":
            raise ValueError(f"Only GELU activation is supported "
                             f"(got `hidden_act`: {config.hidden_act})")
        self.act_fn = GeluAndMul()

        # Initialize adapter layers. Slots for hybrid layers served by other
        # bare blocks hold an nn.Identity placeholder, so the list can be
        # indexed directly by the global hybrid-layer index.
        self.gate_up_proj_adapter_list = nn.ModuleList([])
        for block_idx in range(self.num_hybrid_layers):
            if block_idx % config.num_mem_blocks == bare_block_idx:
                gate_up_proj_adapter = Zamba2LoRA(
                    config.hidden_size,
                    config.adapter_rank,
                    2 * [self.intermediate_size],
                    quant_config,
                )
            else:
                gate_up_proj_adapter = nn.Identity()
            self.gate_up_proj_adapter_list.append(gate_up_proj_adapter)

    def forward(self, hidden_states: torch.Tensor,
                block_idx: int) -> torch.Tensor:
        """Forward pass through the MLP layer.

        Args:
            hidden_states: Input tensor whose last dimension is
                ``config.hidden_size``
            block_idx: Index of the hybrid layer being computed. It must be
                served by this bare block.

        Returns:
            Output tensor whose last dimension is ``config.hidden_size``,
            after the gated feed-forward transformation
        """
        # Project input to intermediate size with gating
        gate_up_states, _ = self.gate_up_proj(hidden_states)

        # Add this hybrid layer's adapter. It always exists for layers served
        # by this block; an Identity here means block_idx was routed to the
        # wrong bare block.
        adapter = self.gate_up_proj_adapter_list[block_idx]
        assert not isinstance(adapter, nn.Identity)
        lora_output = adapter(hidden_states)
        gate_up_states = gate_up_states + lora_output

        # Apply GELU activation with gating
        hidden_states = self.act_fn(gate_up_states)

        # Project back to hidden size
        output, _ = self.down_proj(hidden_states)
        return output


class Zamba2AttentionDecoderLayer(nn.Module):
    """Shared transformer block (attention + MLP) used by Zamba2 hybrid layers.

    Data flow through the block:
      1. concatenate the incoming states with the original hidden states,
      2. RMSNorm over that concatenation,
      3. multi-head self-attention,
      4. RMSNorm over the attention output,
      5. gated feed-forward network (MLP).
    No residual connection is applied inside this block.
    """

    def __init__(
        self,
        config: Zamba2Config,
        bare_block_idx: int,
        num_hybrid_layers: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention and MLP sublayers and their norms.

        Args:
            config: The Zamba2 model configuration
            bare_block_idx: Which shared (bare) block this is
            num_hybrid_layers: Total number of hybrid layers in the model
            cache_config: Key-value cache configuration (attention only)
            quant_config: Quantization configuration for all sublayers
            prefix: Name prefix handed to the attention sublayer
        """
        super().__init__()

        # Arguments common to both sublayers.
        common = dict(
            bare_block_idx=bare_block_idx,
            num_hybrid_layers=num_hybrid_layers,
            quant_config=quant_config,
        )
        self.self_attn = Zamba2Attention(config,
                                         cache_config=cache_config,
                                         prefix=prefix,
                                         **common)
        self.feed_forward = Zamba2MLP(config, **common)

        width = config.hidden_size
        eps = config.rms_norm_eps
        # The attention input is [hidden_states, original_hidden_states],
        # hence a norm over twice the hidden width.
        self.input_layernorm = RMSNorm(width * 2, eps=eps)
        self.pre_ff_layernorm = RMSNorm(width, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        original_hidden_states: torch.Tensor,
        block_idx: int,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Run the shared transformer block for one hybrid layer.

        Args:
            hidden_states: Output of the previous (mamba) layer
            original_hidden_states: Original hidden states, concatenated
                with ``hidden_states`` before the first norm
            block_idx: Index of the hybrid layer being computed
            positions: Position IDs forwarded to the attention sublayer

        Returns:
            The MLP output; no residual is added here.
        """
        # Pre-attention RMSNorm over the concatenated states
        # (see fig. 2 in https://arxiv.org/pdf/2405.16712).
        stacked = torch.cat((hidden_states, original_hidden_states), dim=-1)
        attn_out = self.self_attn(self.input_layernorm(stacked),
                                  block_idx=block_idx,
                                  position_ids=positions)
        ff_in = self.pre_ff_layernorm(attn_out)
        return self.feed_forward(ff_in, block_idx=block_idx)


class Zamba2MambaDecoderLayer(nn.Module):
    """Single Mamba decoder layer with normalization.

    Wraps a ``MambaMixer2`` with a pre-mixer RMSNorm and a residual
    connection. When this layer is used inside a hybrid layer, the projected
    output of the shared transformer block is added to the mixer input before
    normalization.
    """

    def __init__(self,
                 config: Zamba2Config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Initialize the Mamba decoder layer.

        Args:
            config: The Zamba2 model configuration
            quant_config: Configuration for model quantization
            prefix: Prefix for parameter names. The mixer is constructed with
                ``f"{prefix}.mixer"``.
        """
        super().__init__()

        # Initialize Mamba mixer with expanded intermediate size
        intermediate_size = config.mamba_expand * config.hidden_size
        self.mamba = MambaMixer2(hidden_size=config.hidden_size,
                                 ssm_state_size=config.mamba_d_state,
                                 conv_kernel_size=config.mamba_d_conv,
                                 intermediate_size=intermediate_size,
                                 use_conv_bias=config.use_conv_bias,
                                 use_bias=config.add_bias_linear,
                                 n_groups=config.mamba_ngroups,
                                 num_heads=config.n_mamba_heads,
                                 head_dim=intermediate_size //
                                 config.n_mamba_heads,
                                 rms_norm_eps=config.rms_norm_eps,
                                 activation="silu",
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mixer",
                                 chunk_size=config.chunk_size)

        # Input normalization
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        transformer_hidden_states: Optional[torch.Tensor] = None,
        positions: Optional[torch.Tensor] = None,
        original_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass through the Mamba decoder layer.

        Args:
            hidden_states: Input tensor whose last dimension is
                ``config.hidden_size``
            mamba_cache_params: Parameters for Mamba's state caches
                (one for conv, one for ssm)
            mamba2_metadata: Batch metadata passed unchanged to
                ``MambaMixer2``. It presumably describes the prefill/decode
                layout used for chunked processing (TODO confirm).
            transformer_hidden_states: Optional output from the transformer
                path. If provided, it is added to the input before
                normalization (used in the hybrid architecture).
            positions: Optional position IDs (unused in Mamba)
            original_hidden_states: Optional original inputs (unused in Mamba)

        Returns:
            Transformed hidden states with the residual connection applied
        """
        # Store input for residual connection. The transformer contribution
        # below is deliberately not part of the residual.
        residual = hidden_states

        # `transformer_hidden_states` is the output from shared
        # transformer + linear layer (see fig. 2 in
        # https://arxiv.org/pdf/2405.16712).
        # `transformer_hidden_states` is then added to the input to the mamba
        # layer below (as described in eq. (6) of
        # https://arxiv.org/pdf/2405.16712).
        if transformer_hidden_states is not None:
            hidden_states = hidden_states + transformer_hidden_states

        # Apply input normalization
        hidden_states = self.input_layernorm(hidden_states)

        # Process through Mamba mixer
        hidden_states = self.mamba(
            hidden_states,
            mamba_cache_params=mamba_cache_params,
            mamba2_metadata=mamba2_metadata,
        )

        # residual connection after mamba
        hidden_states = residual + hidden_states

        return hidden_states


class Zamba2HybridLayer(nn.Module):
    """Hybrid layer combining Transformer and Mamba architectures.

    This layer implements the hybrid architecture described in the Zamba
    paper, in three steps:
      1. A shared transformer block processes the input together with the
         original hidden states.
      2. A per-layer linear layer projects the transformer output.
      3. The projected output is added to the input of this layer's own
         Mamba block.
    """

    def __init__(
        self,
        shared_transformer: Zamba2AttentionDecoderLayer,
        config: Zamba2Config,
        block_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Initialize the hybrid layer.

        Args:
            shared_transformer: Transformer decoder layer for the attention
                pathway. It may be shared with other hybrid layers.
            config: The Zamba2 model configuration
            block_idx: Index of this hybrid layer. The shared transformer
                uses it to select this layer's attention and adapter modules.
            quant_config: Configuration for model quantization
            prefix: Prefix for parameter names of the Mamba decoder
        """
        super().__init__()
        self.block_idx = block_idx
        self.shared_transformer = shared_transformer
        # Per-layer (not shared) projection of the transformer output.
        self.linear = ReplicatedLinear(config.hidden_size,
                                       config.hidden_size,
                                       bias=False,
                                       quant_config=quant_config)
        self.mamba_decoder = Zamba2MambaDecoderLayer(config,
                                                     quant_config=quant_config,
                                                     prefix=prefix)

    def forward(
        self,
        hidden_states: torch.Tensor,
        original_hidden_states: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
    ) -> torch.Tensor:
        """Forward pass through the hybrid layer.

        Processing steps:
        1. The shared transformer processes the input, concatenated with the
           original hidden states.
        2. The transformer output is projected by ``self.linear``.
        3. The projected output is added to the Mamba decoder's input.
        4. The Mamba decoder applies its own residual over ``hidden_states``.

        Args:
            hidden_states: Input tensor whose last dimension is
                ``config.hidden_size``
            original_hidden_states: Original hidden states that the shared
                transformer concatenates with ``hidden_states``
            positions: Position IDs for positional embeddings
            mamba_cache_params: Parameters for Mamba's state caches
                (one for conv, one for ssm)
            mamba2_metadata: Batch metadata forwarded unchanged to the Mamba
                decoder

        Returns:
            Output tensor combining transformer and Mamba representations
        """
        # Process through transformer pathway
        transformer_hidden_states = self.shared_transformer(
            hidden_states,
            original_hidden_states=original_hidden_states,
            block_idx=self.block_idx,
            positions=positions,
        )

        # Project transformer output
        transformer_hidden_states, _ = self.linear(transformer_hidden_states)

        # Process through Mamba pathway with transformer injection
        layer_outputs = self.mamba_decoder(
            hidden_states,
            transformer_hidden_states=transformer_hidden_states,
            mamba_cache_params=mamba_cache_params,
            mamba2_metadata=mamba2_metadata,
        )

        return layer_outputs


class Zamba2Model(nn.Module):
    """Core Zamba2 model combining transformer and Mamba architectures.

    Token embeddings are passed through a sequence of hybrid layers (shared
    transformer block + Mamba) and Mamba-only layers, followed by a final
    RMS normalization.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Initialize the Zamba2 model.

        Args:
            vllm_config: Configuration object containing model, cache,
                quantization and LoRA settings
            prefix: Optional prefix for parameter names. It is forwarded only
                to the shared transformer blocks; per-layer prefixes are
                derived from the layer index instead.

        Raises:
            AssertionError: If LoRA is enabled (not supported).
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        is_lora_enabled = bool(lora_config)
        assert not is_lora_enabled

        self.config = config
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        # Initialize token embeddings
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # Map hybrid layer indices to block indices
        layer2block_map = {
            layer_idx: block_idx
            for block_idx, layer_idx in enumerate(config.hybrid_layer_ids)
        }

        # Build config.num_mem_blocks shared transformer blocks and hand them
        # out round-robin to the hybrid layers (several hybrid layers can
        # therefore reuse the same block instance).
        blocks = cycle([
            Zamba2AttentionDecoderLayer(config,
                                        bare_block_idx=idx,
                                        num_hybrid_layers=len(layer2block_map),
                                        cache_config=cache_config,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}")
            for idx in range(config.num_mem_blocks)
        ])

        # Initialize layers according to block type configuration
        layers = []
        for layer_idx, layer_type in enumerate(config.layers_block_type):
            # tdoublep: avoid layers getting same index
            # somewhat hacky but correct (I think)
            # NOTE(review): the offset by the number of hybrid layers
            # presumably keeps these layer indices disjoint from the ones
            # the shared attention blocks derive from num_hybrid_layers;
            # verify against Zamba2AttentionDecoderLayer. A separate name is
            # used so the constructor's `prefix` argument is not shadowed.
            layer_prefix = str(len(layer2block_map) + layer_idx)
            if layer_type == "hybrid":
                block = next(blocks)
                block_idx = layer2block_map[layer_idx]
                layers.append(
                    Zamba2HybridLayer(block,
                                      config,
                                      block_idx,
                                      quant_config,
                                      prefix=layer_prefix))
            else:
                layers.append(
                    Zamba2MambaDecoderLayer(config,
                                            quant_config=quant_config,
                                            prefix=layer_prefix))
        self.layers = nn.ModuleList(layers)

        # Final layer normalization
        self.final_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Convert input token IDs to embeddings.

        Args:
            input_ids: Tensor of input token IDs

        Returns:
            Embedded representation of the input tokens
        """
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mamba_cache_params: Optional[MambaCacheParams],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward pass through the model.

        Args:
            input_ids: Input token IDs (unused when inputs_embeds is given)
            positions: Position IDs for positional embeddings
            mamba_cache_params: Mamba conv/ssm state caches for all layers,
                or None (Zamba2ForCausalLM passes None under V1)
            inputs_embeds: Optional pre-computed input embeddings

        Returns:
            Final hidden states after all layers and the final layer norm.
            This implementation always returns a tensor; it never produces
            IntermediateTensors.
        """
        # Embed tokens unless the caller already supplied embeddings.
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings(input_ids)
        hidden_states = inputs_embeds

        attn_metadata = get_forward_context().attn_metadata

        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 get mamba2_metadata from forward_context
            mamba2_metadata = None

        # Every layer also receives the original embeddings (used by the
        # shared transformer path). Keep a separate copy so that any in-place
        # updates inside the layers cannot modify it.
        original_hidden_states = torch.clone(hidden_states)
        for layer_idx, layer in enumerate(self.layers):
            # Slice this layer's state out of the per-model cache (if any).
            layer_mamba_cache_params = None
            if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
                    and mamba_cache_params is not None):
                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                    layer_idx)

            hidden_states = layer(
                hidden_states,
                original_hidden_states=original_hidden_states,
                positions=positions,
                mamba_cache_params=layer_mamba_cache_params,
                mamba2_metadata=mamba2_metadata,
            )

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this model's parameters.

        Checkpoint names containing ``q_proj``/``k_proj``/``v_proj`` are
        renamed to the fused ``qkv_proj`` parameter and loaded with the
        matching shard id; a fused name that does not exist raises KeyError.
        All other weights are loaded by exact name with the parameter's
        ``weight_loader`` (or ``default_weight_loader``); names with no
        matching parameter are skipped.

        Args:
            weights: Iterable of (checkpoint parameter name, tensor) pairs

        Returns:
            Names (after q/k/v renaming) of the parameters that were loaded
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for chkpt_weight_name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in chkpt_weight_name:
                    continue
                # Route the separate projection into its shard of the fused
                # qkv parameter.
                chkpt_weight_name = chkpt_weight_name.replace(
                    weight_name, param_name)
                param = params_dict[chkpt_weight_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked projection: load by exact name if present.
                if chkpt_weight_name not in params_dict:
                    continue
                param = params_dict[chkpt_weight_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(chkpt_weight_name)
        return loaded_params


class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
    """Zamba2 model with causal language modeling head.

    This class wraps the core Zamba2 model and adds:
    - A language modeling head, tied to the input embeddings, plus a logits
      processor for next token prediction
    - Mamba state caching for the V0 engine (including CUDA graph helpers)
    - Tensor-parallel aware Mamba cache shapes
    """

    # Substring renames applied to checkpoint names by load_weights before
    # they are matched against this model's parameters.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
        "A_log": "A",
        "0.weight": "A.weight",
        "1.weight": "B.weight",
    })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Initialize the Zamba2 model for causal language modeling.

        Args:
            vllm_config: Configuration containing model, cache, quantization,
                        LoRA and scheduler settings
            prefix: Optional prefix for parameter names

        Raises:
            AssertionError: If prefix caching is enabled (not supported by
            Mamba)
        """
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "Mamba does not support prefix caching"

        super().__init__()
        self.config = config
        self.vllm_config = vllm_config
        self.scheduler_config = scheduler_config
        self.model_config = vllm_config.model_config
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        # Initialize core model
        self.model = Zamba2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))

        # Initialize language modeling head
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # The output projection always shares its weights with the input
        # embedding table (tying is unconditional here).
        self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

        # V0 Mamba state cache; created lazily on the first forward() call.
        self.mamba_cache: Optional[MambaCacheManager] = None

        # Initialize logits processing
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Convert input token IDs to embeddings.

        Args:
            input_ids: Tensor of input token IDs

        Returns:
            Embedded representation of the input tokens
        """
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Forward pass through the model.

        Args:
            input_ids: Input token IDs
            positions: Position IDs for embeddings
            inputs_embeds: Optional pre-computed input embeddings
            **kwargs: Additional arguments passed to the V0 cache manager

        Returns:
            Output hidden states
        """
        # Under V0 this model owns the Mamba state cache and allocates it on
        # the first call. Under V1 no cache params are passed down (the core
        # model presumably obtains its state via the forward context).
        mamba_cache_params = None
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                # One cache slot per decoder layer: Zamba2Model indexes the
                # cache by layer position for hybrid and Mamba-only layers.
                num_mamba_layers = self.config.num_hidden_layers
                self.mamba_cache = MambaCacheManager(
                    self.vllm_config, self.lm_head.weight.dtype,
                    num_mamba_layers, *self._get_mamba_cache_shape())

            # Get cache parameters for current run
            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

        # Forward pass through model
        hidden_states = self.model(
            input_ids,
            positions,
            mamba_cache_params,
            inputs_embeds,
        )

        return hidden_states

    def _require_mamba_cache(self) -> MambaCacheManager:
        """Return the V0 Mamba cache, raising a clear error if it is unset.

        Raises:
            RuntimeError: If forward() has not yet created the cache (it is
                only created when VLLM_USE_V1 is disabled).
        """
        if self.mamba_cache is None:
            raise RuntimeError(
                "Mamba cache is not initialized; it is created on the first "
                "forward() call when VLLM_USE_V1 is disabled.")
        return self.mamba_cache

    def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str,
                                                                 torch.Tensor],
                                       **kwargs) -> dict[str, torch.Tensor]:
        """Copy inputs before CUDA graph capture.

        Args:
            input_buffers: Dictionary of input tensors
            **kwargs: Additional arguments passed to cache manager

        Returns:
            Updated input buffers

        Raises:
            RuntimeError: If the Mamba cache has not been created yet.
        """
        return self._require_mamba_cache().copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(
            self, batch_size: int) -> dict[str, torch.Tensor]:
        """Get inputs for sequence-length-agnostic graph capture.

        Args:
            batch_size: Size of batch to capture

        Returns:
            Dictionary of capture inputs

        Raises:
            RuntimeError: If the Mamba cache has not been created yet.
        """
        return self._require_mamba_cache().get_seqlen_agnostic_capture_inputs(
            batch_size)

    def _get_mamba_cache_shape(
            self) -> tuple[tuple[int, int], tuple[int, int]]:
        """Calculate per-rank shapes for Mamba's conv and SSM state caches.

        Returns:
            Tuple containing:
            - conv_state_shape: (conv channels per TP rank, d_conv - 1)
            - temporal_state_shape: (heads per TP rank, head_dim, d_state)
        """
        world_size = get_tensor_model_parallel_world_size()

        intermediate_size = self.config.mamba_expand * self.config.hidden_size

        # If n_groups is not divisible by world_size, extend the groups so
        # that every group a head needs is sharded together with that head.
        n_groups = (self.config.mamba_ngroups + extra_groups_for_head_shards(
            self.config.mamba_ngroups, world_size))

        # Conv state covers x plus the B and C group projections; the channel
        # dimension is split across TP ranks.
        conv_dim = (intermediate_size +
                    2 * n_groups * self.config.mamba_d_state)
        conv_state_shape = (
            divide(conv_dim, world_size),
            self.config.mamba_d_conv - 1,
        )

        # SSM (temporal) state is per head: the head dimension is split
        # across TP ranks, while head_dim and d_state are kept whole.
        #   e.g., (h_heads, d_head, d_state) = (128, 64, 128)
        temporal_state_shape = (
            divide(divide(intermediate_size, self.config.mamba_headdim),
                   world_size),
            self.config.mamba_headdim,
            self.config.mamba_d_state,
        )

        return conv_state_shape, temporal_state_shape

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits for next token prediction.

        Args:
            hidden_states: Hidden states from model forward pass
            sampling_metadata: Metadata for sampling process

        Returns:
            Logits for next token prediction
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, applying ``hf_to_vllm_mapper`` renames.

        Args:
            weights: Iterable of (checkpoint parameter name, tensor) pairs

        Returns:
            Set of loaded parameter names, as reported by AutoWeightsLoader
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)