# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import pytest
import torch

from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
from megatron.core.ssm.mamba_layer import MambaLayer
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils


class TestMambaLayer:

    def setup_method(self, method):
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        transformer_config = TransformerConfig(
            hidden_size=256,  # The Mamba layer places several constraints on this
            # Need to specify num_attention_heads and num_layers or TransformerConfig
            # will generate errors.
            num_layers=1,
            num_attention_heads=1,
            use_cpu_initialization=True,
        )
        modules = mamba_stack_spec.submodules.mamba_layer.submodules
        self.layer = MambaLayer(transformer_config, modules)

    def teardown_method(self, method):
        Utils.destroy_model_parallel()

    def test_gpu_forward(self):
        layer = self.layer
        layer.cuda()
        micro_batch_size = 2
        sequence_length = 32
        hidden_states = torch.ones((sequence_length, micro_batch_size, layer.config.hidden_size))
        hidden_states = hidden_states.cuda()
        attention_mask = torch.ones(
            (micro_batch_size, 1, sequence_length, sequence_length), dtype=bool
        )
        attention_mask = attention_mask.cuda()
        output = layer(hidden_states, attention_mask=attention_mask)
        assert output.shape[0] == sequence_length
        assert output.shape[1] == micro_batch_size
        assert output.shape[2] == layer.config.hidden_size
        assert output.dtype == torch.float32
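

# A minimal sketch of invoking just this test directly with pytest; it assumes the
# file is saved somewhere pytest can discover it and that a CUDA-capable GPU is
# available for the forward pass.
if __name__ == "__main__":
    # Run only the GPU forward test in this module and stop on the first failure.
    pytest.main([__file__, "-k", "test_gpu_forward", "-x"])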