Commit d46a984e
authored Apr 16, 2025 by dongcl

add get_gpt_layer_with_flux_spec

parent 53627040
Showing 2 changed files with 173 additions and 11 deletions:

dcu_megatron/adaptor/megatron_adaptor.py         +14  -11
dcu_megatron/core/models/gpt/gpt_layer_specs.py  +159  -0
dcu_megatron/adaptor/megatron_adaptor.py

@@ -5,6 +5,8 @@ import types
import argparse
import torch

from megatron.training import get_args


class MegatronAdaptation:
    """

@@ -168,12 +170,6 @@ class CoreAdaptation(MegatronAdaptationABC):
    def patch_tensor_parallel(self):
        from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init
        from ..core.tensor_parallel import (
            ColumnParallelLinearPatch,
            RowParallelLinearPatch,
            column_parallel_linear_init_wrapper,
            row_parallel_linear_init_wrapper
        )

        # VocabParallelEmbedding
        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',

@@ -195,13 +191,18 @@ class CoreAdaptation(MegatronAdaptationABC):
                                    apply_wrapper=True)

        # flux
        try:
            args = get_args()
            if args.use_flux:
                import flux
                HAS_FLUX = True
        except ImportError:
            HAS_FLUX = False

        if HAS_FLUX:
            from ..core.tensor_parallel import (
                ColumnParallelLinearPatch,
                RowParallelLinearPatch,
                column_parallel_linear_init_wrapper,
                row_parallel_linear_init_wrapper
            )
            from ..core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec
            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.ColumnParallelLinear.__init__",
                                        column_parallel_linear_init_wrapper,
                                        apply_wrapper=True)

@@ -212,6 +213,8 @@ class CoreAdaptation(MegatronAdaptationABC):
                                        apply_wrapper=True)
            MegatronAdaptation.register("megatron.core.tensor_parallel.layers.RowParallelLinear.forward",
                                        RowParallelLinearPatch.forward)
            MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_local_spec",
                                        get_gpt_layer_with_flux_spec)

    def patch_training(self):
        from ..training.tokenizer import build_tokenizer
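For context, each MegatronAdaptation.register(...) call above maps a dotted path inside megatron.core onto a local dcu_megatron implementation, optionally wrapping the original when apply_wrapper=True. The MegatronAdaptation class itself is not part of this diff, so the following is only a minimal sketch of the dotted-path patching idea; apply_patch and its exact behaviour are assumptions for illustration, not the repository's API.

# Illustrative sketch only: MegatronAdaptation is not shown in this diff.
# apply_patch() below is a hypothetical helper demonstrating how a dotted
# path such as "megatron.core.tensor_parallel.layers.RowParallelLinear.forward"
# can be resolved and replaced at runtime.
import importlib

def apply_patch(dotted_path, replacement, apply_wrapper=False):
    # Find the longest importable module prefix of the dotted path.
    parts = dotted_path.split('.')
    module, attrs = None, parts
    for i in range(len(parts), 0, -1):
        try:
            module = importlib.import_module('.'.join(parts[:i]))
            attrs = parts[i:]
            break
        except ModuleNotFoundError:
            continue
    if module is None:
        raise ImportError(f'cannot resolve {dotted_path}')

    # Walk down to the object that owns the final attribute.
    owner = module
    for name in attrs[:-1]:
        owner = getattr(owner, name)

    # With apply_wrapper=True the replacement wraps the original callable;
    # otherwise it substitutes it outright.
    original = getattr(owner, attrs[-1], None)
    setattr(owner, attrs[-1], replacement(original) if apply_wrapper else replacement)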
dcu_megatron/core/models/gpt/gpt_layer_specs.py  (new file, 0 → 100644)

import warnings
from typing import Optional

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.multi_latent_attention import (
    MLASelfAttention,
    MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import (
    TransformerLayer,
    TransformerLayerSubmodules,
)
from megatron.core.utils import is_te_min_version

try:
    from megatron.core.extensions.transformer_engine import (
        TEDotProductAttention,
        TENorm,
    )
except ImportError:
    warnings.warn('transformer_engine is not installed.')

try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
except ImportError:
    warnings.warn('Apex is not installed.')


def get_gpt_layer_with_flux_spec(
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    qk_layernorm: Optional[bool] = False,
    multi_latent_attention: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).

    Args:
        num_experts (int, optional): Number of experts. Defaults to None.
        moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
        qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
        fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
        moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
            Defaults to False.

    Returns:
        ModuleSpec: Module specification with TE modules
    """
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    mlp = get_mlp_module_flux_spec(
        use_te=False,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )

    if multi_latent_attention:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=TENorm,
                self_attention=ModuleSpec(
                    module=MLASelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=MLASelfAttentionSubmodules(
                        linear_q_proj=ColumnParallelLinear,
                        linear_q_down_proj=ColumnParallelLinear,
                        linear_q_up_proj=ColumnParallelLinear,
                        linear_kv_down_proj=ColumnParallelLinear,
                        linear_kv_up_proj=ColumnParallelLinear,
                        core_attention=TEDotProductAttention,
                        linear_proj=RowParallelLinear,
                        q_layernorm=TENorm if qk_layernorm else IdentityOp,
                        kv_layernorm=TENorm if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )
    else:
        # TENorm significantly harms convergence when used
        # for QKLayerNorm if TE Version < 1.9;
        # we instead use the Apex implementation.
        qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm

        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=TENorm,
                self_attention=ModuleSpec(
                    module=SelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=SelfAttentionSubmodules(
                        linear_qkv=ColumnParallelLinear,
                        core_attention=TEDotProductAttention,
                        linear_proj=RowParallelLinear,
                        q_layernorm=qk_norm if qk_layernorm else IdentityOp,
                        k_layernorm=qk_norm if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )


def get_mlp_module_flux_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Helper function to get module spec for MLP/MoE"""
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    if num_experts is None:
        # Dense MLP w/ or w/o TE modules.
        return ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=ColumnParallelLinear,
                linear_fc2=RowParallelLinear,
            ),
        )
    else:
        # Mixture of experts with modules in megatron core.
        return get_moe_module_spec(
            use_te=True,
            num_experts=num_experts,
            moe_grouped_gemm=moe_grouped_gemm,
            moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        )
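A minimal usage sketch of the new spec follows. It assumes megatron-core with Transformer Engine installed and dcu_megatron on the import path; the import path mirrors this commit's file layout but is otherwise an assumption.

# Usage sketch; assumes transformer_engine (for TENorm/TEDotProductAttention)
# and megatron-core are installed, and dcu_megatron is importable.
from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_flux_spec

# Dense GPT layer: attention and MLP are built from the ColumnParallelLinear /
# RowParallelLinear classes whose __init__/forward the adaptor patches for flux.
dense_spec = get_gpt_layer_with_flux_spec(qk_layernorm=True)
print(dense_spec.module)                             # TransformerLayer
print(dense_spec.submodules.self_attention.module)   # SelfAttention

# MoE variant: num_experts routes the MLP through get_moe_module_spec instead.
moe_spec = get_gpt_layer_with_flux_spec(num_experts=8, moe_grouped_gemm=True)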