evt_fugx1 / dcu_megatron / Commits / 7c9dc3ec

Commit 7c9dc3ec, authored May 07, 2025 by dongcl
Parent: 649bfbdb

    forward_backward_pipelining_without_interleaving supports a2a_overlap

Showing 8 changed files with 433 additions and 1167 deletions:
    dcu_megatron/adaptor/megatron_adaptor.py              +26   -11
    dcu_megatron/core/models/gpt/gpt_layer_specs.py       +10    -8
    dcu_megatron/core/models/gpt/gpt_model.py              +9   -10
    dcu_megatron/core/pipeline_parallel/schedules.py     +362 -1133
    dcu_megatron/core/tensor_parallel/__init__.py          +0    -4
    dcu_megatron/core/tensor_parallel/layers.py            +5    -1
    dcu_megatron/core/transformer/transformer_config.py    +7    -0
    dcu_megatron/training/arguments.py                    +14    -0
dcu_megatron/adaptor/megatron_adaptor.py

@@ -5,6 +5,8 @@ import types
 import argparse
 import torch
+from megatron.core.utils import is_te_min_version


 class MegatronAdaptation:
     """
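For readers outside the repo: MegatronAdaptation.register takes a dotted import path and a replacement object (or, with apply_wrapper=True, a decorator for the original attribute). A minimal sketch of the mechanism, assuming the simplest possible registry; the real class also records patches and re-applies them after Megatron imports:

import importlib

def register(dotted_path, patch, apply_wrapper=False):
    # Hypothetical, simplified stand-in for MegatronAdaptation.register:
    # resolve the owner of the final attribute, then replace or wrap it.
    module_path, attr_name = dotted_path.rsplit('.', 1)
    try:
        owner = importlib.import_module(module_path)      # module-level target
    except ModuleNotFoundError:                           # class-level target
        pkg_path, cls_name = module_path.rsplit('.', 1)
        owner = getattr(importlib.import_module(pkg_path), cls_name)
    original = getattr(owner, attr_name, None)
    setattr(owner, attr_name, patch(original) if apply_wrapper else patch)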
@@ -89,14 +91,14 @@ class CoreAdaptation(MegatronAdaptationABC):
         pass

     def patch_core_models(self):
-        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward
+        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, GPTModel

         # GPT Model
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                     gpt_model_init_wrapper, apply_wrapper=True)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
@@ -116,9 +118,9 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                     torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                     apply_wrapper=True)
-        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
-        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
-        #                             apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
+                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
+                                    apply_wrapper=True)
         MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                     torch.compile(mode='max-autotune-no-cudagraphs'), apply_wrapper=True)
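Worth noting: torch.compile called with only keyword arguments returns a decorator rather than a compiled function, which is why these registrations pass apply_wrapper=True so the registry applies the decorator to the original function. A self-contained example of the same pattern:

import torch

# torch.compile(options=...) with no function argument returns a decorator.
compile_with_cudagraphs = torch.compile(
    options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}
)

def permute_rows(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    return x[idx]

fast_permute = compile_with_cudagraphs(permute_rows)  # decorate, don't call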
@@ -132,12 +134,25 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
         from megatron.core.extensions.transformer_engine import TEGroupedLinear

-        # kv channels, te_min_version 1.10.0 -> 1.9.0
-        MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
-                                    TEDotProductAttentionPatch.__init__)
+        if not is_te_min_version("1.10.0"):
+            # kv channels, te_min_version 1.10.0 -> 1.9.0
+            MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
+                                        TEDotProductAttentionPatch.__init__)

         if int(os.getenv("GROUPED_GEMM_BatchLinear", '0')):
-            TEGroupedLinear.__bases__ = (te.pytorch.BatchLinear,)
+            TEGroupedLinear.__bases__ = (te.pytorch.BatchedLinear if is_te_min_version("2.3.0.dev0") else te.pytorch.BatchLinear,)

     def patch_pipeline_parallel(self):
-        from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches
+        from ..core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving

         # num_warmup_microbatches + 1
         MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches',
                                     get_pp_rank_microbatches)
+        # a2a_overlap
+        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
+                                    forward_backward_pipelining_with_interleaving)

     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
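The GROUPED_GEMM_BatchLinear branch rebinds the base class of TEGroupedLinear at runtime, so the choice between te.pytorch.BatchLinear and the newer BatchedLinear tracks the installed Transformer Engine version. Assigning to __bases__ is legal in Python when the class layouts are compatible; a toy illustration with made-up names:

class OldBackend:
    def run(self):
        return "old"

class NewBackend:
    def run(self):
        return "new"

class GroupedLinear(OldBackend):
    pass

# Rebind inheritance in place: existing and future instances now resolve
# methods through NewBackend instead of OldBackend.
GroupedLinear.__bases__ = (NewBackend,)
assert GroupedLinear().run() == "new"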
@@ -162,7 +177,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         # flux
         if int(os.getenv("USE_FLUX_OVERLAP", "0")):
-            from ..core.tensor_parallel import (
+            from ..core.tensor_parallel.layers import (
                 FluxColumnParallelLinear,
                 FluxRowParallelLinear
             )
dcu_megatron/core/models/gpt/gpt_layer_specs.py

@@ -12,6 +12,7 @@ from megatron.core.transformer.multi_latent_attention import (
     MLASelfAttentionSubmodules,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.torch_norm import L2Norm
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import (
@@ -40,12 +41,6 @@ from dcu_megatron.core.tensor_parallel.layers import (
     FluxColumnParallelLinear,
     FluxRowParallelLinear
 )
-from dcu_megatron.core.transformer.multi_token_prediction import (
-    MultiTokenPredictionBlockSubmodules,
-    get_mtp_layer_offset,
-    get_mtp_layer_spec,
-    get_mtp_num_layers_to_build,
-)

 def get_gpt_layer_with_flux_spec(
@@ -55,6 +50,7 @@ def get_gpt_layer_with_flux_spec(
     multi_latent_attention: Optional[bool] = False,
     fp8: Optional[str] = None,  # pylint: disable=unused-arguments
     moe_use_legacy_grouped_gemm: Optional[bool] = False,
+    qk_l2_norm: Optional[bool] = False,
 ) -> ModuleSpec:
     """Use this spec to use flux modules (required for fp8 training).

@@ -66,6 +62,7 @@ def get_gpt_layer_with_flux_spec(
         fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
         moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
             Defaults to False.
+        qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.

     Returns:
         ModuleSpec: Module specification with flux modules
@@ -84,6 +81,7 @@ def get_gpt_layer_with_flux_spec(
     )

     if multi_latent_attention:
+        assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA."
         return ModuleSpec(
             module=TransformerLayer,
             submodules=TransformerLayerSubmodules(
@@ -127,8 +125,12 @@ def get_gpt_layer_with_flux_spec(
                 linear_qkv=FluxColumnParallelLinear,
                 core_attention=TEDotProductAttention,
                 linear_proj=FluxRowParallelLinear,
-                q_layernorm=qk_norm if qk_layernorm else IdentityOp,
-                k_layernorm=qk_norm if qk_layernorm else IdentityOp,
+                q_layernorm=(
+                    L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
+                ),
+                k_layernorm=(
+                    L2Norm if qk_l2_norm else (qk_norm if qk_layernorm else IdentityOp)
+                ),
             ),
         ),
         self_attn_bda=get_bias_dropout_add,
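The q/k layernorm expressions form a priority chain: qk_l2_norm overrides qk_layernorm, and with both unset the op degrades to identity. A standalone rendering of that chain, where the classes are stand-ins for Megatron's L2Norm, the configured qk_norm, and IdentityOp:

class IdentityOp: ...
class RMSNorm: ...      # stand-in for the configured qk_norm
class L2Norm: ...       # stand-in for megatron's torch_norm.L2Norm

def pick_qk_norm(qk_l2_norm: bool, qk_layernorm: bool):
    # qk_l2_norm wins; qk_layernorm is the fallback; identity otherwise.
    return L2Norm if qk_l2_norm else (RMSNorm if qk_layernorm else IdentityOp)

assert pick_qk_norm(True, True) is L2Norm
assert pick_qk_norm(False, True) is RMSNorm
assert pick_qk_norm(False, False) is IdentityOp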
dcu_megatron/core/models/gpt/gpt_model.py

@@ -13,8 +13,6 @@ from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
-from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
-from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear


 def gpt_model_init_wrapper(fn):
     @wraps(fn)
@@ -22,12 +20,13 @@ def gpt_model_init_wrapper(fn):
         fn(self, *args, **kwargs)

         # Output
-        if self.post_process or self.mtp_process:
-            if int(os.getenv("USE_FLUX_OVERLAP", "0")):
-                parallel_linear_impl = FluxColumnParallelLinear
-            else:
-                parallel_linear_impl = tensor_parallel.ColumnParallelLinear
-            self.output_layer = parallel_linear_impl(
+        if (
+            (self.post_process or self.mtp_process)
+            and int(os.getenv("USE_FLUX_OVERLAP", "0"))
+        ):
+            from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear
+
+            self.output_layer = FluxColumnParallelLinear(
                 self.config.hidden_size,
                 self.vocab_size,
                 config=self.config,
@@ -41,8 +40,8 @@ def gpt_model_init_wrapper(fn):
                 grad_output_buffer=self.grad_output_buffer,
             )

-        if self.pre_process or self.post_process:
-            self.setup_embeddings_and_output_layer()
+            if self.pre_process or self.post_process:
+                self.setup_embeddings_and_output_layer()

     return wrapper
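Taken together, the two hunks change the wrapper from "always rebuild the output layer, picking an implementation" to "rebuild it only when flux overlap is enabled, then re-run the embedding/output setup". The shape of the pattern, with stand-in classes:

import os
from functools import wraps

class DefaultLinear: ...  # stand-in for tensor_parallel.ColumnParallelLinear
class FluxLinear: ...     # stand-in for FluxColumnParallelLinear

def init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)                 # original __init__ runs first
        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
            self.output_layer = FluxLinear()      # swap the submodule...
            self.setup()                          # ...and redo dependent setup
    return wrapper

class Model:
    @init_wrapper
    def __init__(self):
        self.output_layer = DefaultLinear()

    def setup(self):                              # stand-in for
        pass                                      # setup_embeddings_and_output_layer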
dcu_megatron/core/pipeline_parallel/schedules.py

(+362 additions, -1133 deletions; this diff is collapsed in the page view and not reproduced here.)
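schedules.py is where the commit message lands: the 1F1B pipeline schedule gains all-to-all (a2a) overlap for MoE expert parallelism. Since the patched schedule itself is collapsed above, here is a rough sketch of the underlying idea only, assuming an initialized torch.distributed process group: the token all-to-all is issued asynchronously so independent compute can proceed while it is in flight.

import torch
import torch.distributed as dist

def dispatch_with_overlap(tokens: torch.Tensor, other_work):
    # Start the expert-parallel all-to-all without blocking...
    out = torch.empty_like(tokens)
    handle = dist.all_to_all_single(out, tokens, async_op=True)
    other_work()   # ...run another microbatch's compute meanwhile...
    handle.wait()  # ...and only synchronize when the tokens are needed.
    return out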
dcu_megatron/core/tensor_parallel/__init__.py  (deleted, 100644 → 0)

-from .layers import (
-    FluxColumnParallelLinear,
-    FluxRowParallelLinear,
-)
\ No newline at end of file

With this re-export gone, callers import the Flux classes from dcu_megatron.core.tensor_parallel.layers directly, as the adaptor and gpt_model changes above now do.
dcu_megatron/core/tensor_parallel/layers.py
@@ -740,6 +740,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
         is_expert: bool = False,
         tp_comm_buffer_name: str = None,  # Not used
         disable_grad_reduce: bool = False,
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         super(FluxColumnParallelLinear, self).__init__(
             input_size=input_size,
@@ -757,6 +758,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
             is_expert=is_expert,
             tp_comm_buffer_name=tp_comm_buffer_name,
             disable_grad_reduce=disable_grad_reduce,
+            tp_group=tp_group,
         )

         # flux params
@@ -961,6 +963,7 @@ class FluxRowParallelLinear(RowParallelLinear):
         keep_master_weight_for_test: bool = False,
         is_expert: bool = False,
         tp_comm_buffer_name: str = None,  # Not used
+        tp_group: Optional[torch.distributed.ProcessGroup] = None,
     ):
         super(FluxRowParallelLinear, self).__init__(
@@ -974,7 +977,8 @@ class FluxRowParallelLinear(RowParallelLinear):
             stride=stride,
             keep_master_weight_for_test=keep_master_weight_for_test,
             is_expert=is_expert,
-            tp_comm_buffer_name=tp_comm_buffer_name
+            tp_comm_buffer_name=tp_comm_buffer_name,
+            tp_group=tp_group,
         )

         # flux params
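All four layers.py hunks are the same mechanical change: accept the tp_group keyword that upstream Megatron added to ColumnParallelLinear/RowParallelLinear and forward it to super().__init__, keeping the Flux subclasses signature-compatible. In miniature, with a stand-in base class:

from typing import Optional
import torch.distributed as dist

class ParallelLinearBase:  # stand-in for Megatron's ColumnParallelLinear
    def __init__(self, input_size: int,
                 tp_group: Optional[dist.ProcessGroup] = None):
        self.input_size = input_size
        self.tp_group = tp_group   # None falls back to the default TP group

class FluxParallelLinear(ParallelLinearBase):
    def __init__(self, input_size: int,
                 tp_group: Optional[dist.ProcessGroup] = None):
        # Mirror the base signature and forward the new keyword verbatim.
        super().__init__(input_size=input_size, tp_group=tp_group)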
dcu_megatron/core/transformer/transformer_config.py
@@ -23,6 +23,7 @@ def transformer_config_post_init_wrapper(fn):
         ##################
         self.flux_transpose_weight = args.flux_transpose_weight
+
     return wrapper
@@ -33,6 +34,12 @@ class ExtraTransformerConfig:
     ##################
     flux_transpose_weight: bool = False

+    combined_1f1b: bool = False
+    """If true, use combined 1F1B for communication hiding."""
+
+    combined_1f1b_recipe: str = 'ep_a2a'
+    """Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""
+

 @dataclass
 class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
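The new fields follow the repo's config-extension pattern: extra options live on a plain ExtraTransformerConfig dataclass, and TransformerConfigPatch inherits from both it and the upstream config. A self-contained miniature, where the base class stands in for Megatron's TransformerConfig:

from dataclasses import dataclass

@dataclass
class TransformerConfig:        # stand-in for megatron's TransformerConfig
    hidden_size: int = 1024

@dataclass
class ExtraTransformerConfig:   # field names match the diff above
    combined_1f1b: bool = False
    combined_1f1b_recipe: str = 'ep_a2a'

@dataclass
class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
    pass

# Because every field has a default, the multiple-inheritance field
# ordering stays legal and all options remain keyword-addressable.
cfg = TransformerConfigPatch(combined_1f1b=True, combined_1f1b_recipe='golden')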
dcu_megatron/training/arguments.py
@@ -26,6 +26,8 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
     parser = _add_extra_training_args(parser)
     parser = _add_extra_distributed_args(parser)
     parser = _add_extra_tokenizer_args(parser)
+    parser = _add_extra_moe_args(parser)
     parser = _add_flux_args(parser)

     return parser
@@ -128,6 +130,18 @@ def _add_extra_tokenizer_args(parser):
     return parser


+def _add_extra_moe_args(parser):
+    group = parser.add_argument_group(title="extra moe args")
+
+    group.add_argument('--combined-1f1b', action='store_true',
+                       help='Batch-level overlapping in 1f1b stage.')
+    group.add_argument('--combined-1f1b-recipe', type=str, choices=['ep_a2a', 'golden'],
+                       default='golden', help='Options are "ep_a2a" and "golden".')
+
+    return parser
+
+
 def _add_flux_args(parser):
     group = parser.add_argument_group(title='flux args')
     group.add_argument('--flux-transpose-weight', action='store_true', default=False,
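The new flags can be sanity-checked with argparse alone; note that the CLI default here is 'golden' while the ExtraTransformerConfig default above is 'ep_a2a'. A quick standalone check:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="extra moe args")
group.add_argument('--combined-1f1b', action='store_true',
                   help='Batch-level overlapping in 1f1b stage.')
group.add_argument('--combined-1f1b-recipe', type=str, choices=['ep_a2a', 'golden'],
                   default='golden', help='Options are "ep_a2a" and "golden".')

args = parser.parse_args(['--combined-1f1b', '--combined-1f1b-recipe', 'ep_a2a'])
assert args.combined_1f1b and args.combined_1f1b_recipe == 'ep_a2a'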