# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Union

import torch
from torch import Tensor

from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig


@dataclass
class MambaLayerSubmodules:
    """
    Configuration class for specifying the submodules of a Mamba layer.

    This class defines the structure and default implementations for various
    components of a Mamba layer, allowing for flexible customization of the
    layer's architecture.

    Args:
        norm (Union[ModuleSpec, type]): Specification for the input layer normalization.
        mixer (Union[ModuleSpec, type]): Specification for the along-sequence mixing mechanism.
        mamba_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation
            after the mixer.
    """

    norm: Union[ModuleSpec, type] = IdentityOp
    mixer: Union[ModuleSpec, type] = IdentityOp
    mamba_bda: Union[ModuleSpec, type] = IdentityOp


class MambaLayer(MegatronModule):
    """
    A single Mamba layer.

    Mamba layer takes input with size [s, b, h] and returns an output of the same size.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MambaLayerSubmodules,
        mamba_ssm_ngroups=8,
        layer_number: int = 1,
        residual_in_fp32=False,
    ):
        """Initialize Mamba Layer."""
        super().__init__(config)
        self.config = config
        self.layer_number = layer_number
        self.residual_in_fp32 = residual_in_fp32
        self.hidden_dropout = config.hidden_dropout
        self.mixer = build_module(
            submodules.mixer,
            self.config,
            d_model=self.config.hidden_size,
            ngroups=mamba_ssm_ngroups,
            layer_number=layer_number,
        )
        self.norm = build_module(submodules.norm, self.config, self.config.hidden_size)
        self.mamba_bda = build_module(submodules.mamba_bda)
        self.bias_dropout_add_exec_handler = torch.enable_grad

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,  # Not used in MambaLayer
        inference_params=None,
        rotary_pos_emb: Tensor = None,  # Not used in MambaLayer
    ):
        """
        Perform a forward pass through the Mamba layer.

        This method implements the core computation of a Mamba layer, including
        the convolution and the selective SSM/SSD.

        Args:
            hidden_states (Tensor): Input tensor of shape [s, b, h] where s is sequence length,
                b is batch size, and h is hidden size.
            attention_mask (Tensor): Mask tensor for self-attention. Not used by this layer.
            inference_params (object, optional): Parameters for inference-time optimizations.
            rotary_pos_emb (Tensor, optional): Rotary positional embeddings.

        Returns:
            output (Tensor): Transformed hidden states of shape [s, b, h].
""" residual = hidden_states if self.residual_in_fp32: residual = residual.to(torch.float32) hidden_states = hidden_states.to(dtype=self.config.params_dtype) hidden_states = self.norm(hidden_states) mixer_out_with_bias = self.mixer(hidden_states, inference_params=inference_params) with self.bias_dropout_add_exec_handler(): hidden_states = self.mamba_bda(self.training, self.config.bias_dropout_fusion)( mixer_out_with_bias, residual, self.hidden_dropout ) return hidden_states def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): """Allocate the inference cache.""" return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)