Commit d46a984e authored by dongcl

add get_gpt_layer_with_flux_spec

parent 53627040
@@ -5,6 +5,8 @@ import types
import argparse
import torch
from megatron.training import get_args


class MegatronAdaptation:
    """
@@ -168,12 +170,6 @@ class CoreAdaptation(MegatronAdaptationABC):
    def patch_tensor_parallel(self):
        from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
        from ..core.tensor_parallel import (
            ColumnParallelLinearPatch,
            RowParallelLinearPatch,
            column_parallel_linear_init_wrapper,
            row_parallel_linear_init_wrapper
        )

        # VocabParallelEmbedding
        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
@@ -195,13 +191,18 @@ class CoreAdaptation(MegatronAdaptationABC):
                                    apply_wrapper=True)

        # flux
        HAS_FLUX = False
        try:
            args = get_args()
            if args.use_flux:
                import flux
                HAS_FLUX = True
        except ImportError:
            HAS_FLUX = False

        if HAS_FLUX:
            from ..core.tensor_parallel import (
                ColumnParallelLinearPatch,
                RowParallelLinearPatch,
                column_parallel_linear_init_wrapper,
                row_parallel_linear_init_wrapper
            )
            from ..core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec

            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
                                        column_parallel_linear_init_wrapper,
                                        apply_wrapper=True)
@@ -212,6 +213,8 @@ class CoreAdaptation(MegatronAdaptationABC):
                                        apply_wrapper=True)
            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
                                        RowParallelLinearPatch.forward)
            MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_local_spec",
                                        get_gpt_layer_with_flux_spec)

    def patch_training(self):
        from ..training.tokenizer import build_tokenizer
.../core/models/gpt/gpt_layer_specs.py
import warnings
from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.multi_latent_attention import (
    MLASelfAttention,
    MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import (
    TransformerLayer,
    TransformerLayerSubmodules,
)
from megatron.core.utils import is_te_min_version

try:
    from megatron.core.extensions.transformer_engine import (
        TEDotProductAttention,
        TENorm,
    )
except ImportError:
    warnings.warn('transformer_engine is not installed.')

try:
    import apex  # pylint: disable=unused-import
    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
except ImportError:
    warnings.warn('Apex is not installed.')


def get_gpt_layer_with_flux_spec(
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    qk_layernorm: Optional[bool] = False,
    multi_latent_attention: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-argument
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Use this spec to build GPT layers whose linear projections are plain
    Megatron-Core ColumnParallelLinear/RowParallelLinear, so that the Flux patches
    registered in CoreAdaptation.patch_tensor_parallel apply to them; attention and
    normalization still use Transformer Engine modules.

    Args:
        num_experts (int, optional): Number of experts. Defaults to None.
        moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
        qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
        multi_latent_attention (bool, optional): To use multi-latent attention (MLA).
            Defaults to False.
        fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
        moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
            Defaults to False.

    Returns:
        ModuleSpec: Module specification for a GPT transformer layer.
    """
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "get_gpt_layer_with_flux_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    mlp = get_mlp_module_flux_spec(
        use_te=False,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )
    if multi_latent_attention:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=TENorm,
                self_attention=ModuleSpec(
                    module=MLASelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=MLASelfAttentionSubmodules(
                        linear_q_proj=ColumnParallelLinear,
                        linear_q_down_proj=ColumnParallelLinear,
                        linear_q_up_proj=ColumnParallelLinear,
                        linear_kv_down_proj=ColumnParallelLinear,
                        linear_kv_up_proj=ColumnParallelLinear,
                        core_attention=TEDotProductAttention,
                        linear_proj=RowParallelLinear,
                        q_layernorm=TENorm if qk_layernorm else IdentityOp,
                        kv_layernorm=TENorm if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )
    else:
        # TENorm significantly harms convergence when used
        # for QKLayerNorm if TE Version < 1.9;
        # we instead use the Apex implementation.
        qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm

        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=TENorm,
                self_attention=ModuleSpec(
                    module=SelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=SelfAttentionSubmodules(
                        linear_qkv=ColumnParallelLinear,
                        core_attention=TEDotProductAttention,
                        linear_proj=RowParallelLinear,
                        q_layernorm=qk_norm if qk_layernorm else IdentityOp,
                        k_layernorm=qk_norm if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )


def get_mlp_module_flux_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-argument
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Helper function to get the module spec for the MLP/MoE block."""
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "get_mlp_module_flux_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    if num_experts is None:
        # Dense MLP built from plain tensor-parallel linear layers.
        return ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=ColumnParallelLinear,
                linear_fc2=RowParallelLinear,
            ),
        )
    else:
        # Mixture of experts with modules in Megatron Core.
        return get_moe_module_spec(
            use_te=True,
            num_experts=num_experts,
            moe_grouped_gemm=moe_grouped_gemm,
            moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        )
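
A minimal usage sketch of the new spec builder, for reference (not part of this commit). The package prefix my_adaptor is a hypothetical stand-in for this repository's top-level package, and Megatron-LM with transformer_engine installed is assumed:

# Usage sketch (illustrative only; "my_adaptor" is a hypothetical package name).
from my_adaptor.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec

# Dense GPT layer spec with QK layernorm. The linear submodules are plain
# ColumnParallelLinear / RowParallelLinear, so the Flux __init__/forward patches
# registered in CoreAdaptation.patch_tensor_parallel take effect on them.
dense_spec = get_gpt_layer_with_flux_spec(qk_layernorm=True)
print(dense_spec.module)          # TransformerLayer
print(dense_spec.submodules.mlp)  # ModuleSpec wrapping MLP (dense branch)

# MoE variant: routed experts built through get_moe_module_spec.
moe_spec = get_gpt_layer_with_flux_spec(num_experts=8, moe_grouped_gemm=True)
print(moe_spec.submodules.pre_mlp_layernorm)  # TENorm when num_experts is set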