# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings

import torch

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelLinear,
        TEDotProductAttention,
        TELayerNormColumnParallelLinear,
        TENorm,
        TERowParallelLinear,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

try:
    import apex

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
    from megatron.core.transformer.torch_norm import WrappedTorchNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    from megatron.core.transformer.torch_norm import WrappedTorchNorm

    warnings.warn('Apex is not installed. Falling back to Torch Norm')
    HAVE_APEX = False
    LNImpl = WrappedTorchNorm


def get_layer_spec(is_vit, normalization) -> ModuleSpec:
    """Build a transformer layer spec using local (non-Transformer-Engine) modules.

    Vision (ViT) layers use no attention mask; language layers use a causal mask.
    """
    attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal

    if normalization == "LayerNorm":
        norm = LNImpl
    elif normalization == "RMSNorm":
        if HAVE_TE:
            norm = TENorm
        else:
            version = torch.__version__.split('.')
            version_geq_2_4 = int(version[0]) > 2 or (
                int(version[0]) == 2 and int(version[1]) >= 4
            )
            assert version_geq_2_4, "Torch version >= 2.4.0 is required for RMSNorm"
            if HAVE_APEX:
                warnings.warn('Apex does not support RMSNorm. Falling back to Torch Norm')
            norm = WrappedTorchNorm
    else:
        raise RuntimeError("unknown normalization", normalization)

    mlp = get_mlp_module_spec(use_te=False)  # Doesn't include norm.

    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=norm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": attn_mask_type},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=norm,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def get_layer_spec_te(is_vit=False) -> ModuleSpec:
    """Build a transformer layer spec using Transformer Engine modules.

    The norms are fused into the TE linear layers, so the standalone layernorm
    submodules are identity ops.
    """
    attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal

    mlp = get_norm_mlp_module_spec_te()

    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": attn_mask_type},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TELayerNormColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=IdentityOp,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
    """Dense MLP spec, with or without Transformer Engine modules."""
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
            linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
        ),
    )


def get_norm_mlp_module_spec_te() -> ModuleSpec:
    """Dense MLP spec with the input norm fused into the TE fc1 layer."""
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
        ),
    )
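

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the upstream module): a minimal sketch of how
# a multimodal model might select layer specs for its language and vision
# towers, assuming the causal-mask / no-mask split implemented above. The
# surrounding model-construction code (transformer config, parallel state
# initialization) is omitted and would be required before these specs can be
# instantiated.
#
#     if HAVE_TE:
#         language_layer_spec = get_layer_spec_te(is_vit=False)
#         vision_layer_spec = get_layer_spec_te(is_vit=True)
#     else:
#         language_layer_spec = get_layer_spec(is_vit=False, normalization="LayerNorm")
#         vision_layer_spec = get_layer_spec(is_vit=True, normalization="LayerNorm")
# ---------------------------------------------------------------------------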