# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import pytest
import torch

from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
from megatron.core.ssm.mamba_mixer import MambaMixer
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils


class TestMambaMixer:

    def setup_method(self, method):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        transformer_config = TransformerConfig(
            hidden_size=256,  # The Mamba layer places several constraints on this
            # Need to specify num_attention_heads and num_layers or TransformerConfig
            # will generate errors.
            num_layers=1,
            num_attention_heads=1,
            use_cpu_initialization=True,
        )
        modules = mamba_stack_spec.submodules.mamba_layer.submodules.mixer.submodules
        # Build one mixer with the default (memory-efficient) path and one without it.
        self.mixer = MambaMixer(transformer_config, modules, transformer_config.hidden_size)
        self.mixer_no_mem_eff_path = MambaMixer(
            transformer_config, modules, transformer_config.hidden_size, use_mem_eff_path=False
        )

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    @pytest.mark.parametrize("use_mem_eff_path", [True, False])
    def test_gpu_forward(self, use_mem_eff_path):
        if use_mem_eff_path:
            mixer = self.mixer
        else:
            mixer = self.mixer_no_mem_eff_path
        mixer.cuda()
        micro_batch_size = 2
        sequence_length = 32
        # Sequence-first input: [sequence_length, micro_batch_size, hidden_size]
        hidden_states = torch.ones((sequence_length, micro_batch_size, mixer.config.hidden_size))
        hidden_states = hidden_states.cuda()
        output, bias = mixer(hidden_states)
        # The forward pass should preserve the input shape and dtype.
        assert output.shape[0] == sequence_length
        assert output.shape[1] == micro_batch_size
        assert output.shape[2] == mixer.config.hidden_size
        assert output.dtype == torch.float32