evt_fugx1 / dcu_megatron / Commits / 124accba

Commit 124accba, authored May 20, 2025 by dongcl

    support split_bw when te_version > 2.3.0.dev0

Parent: bfe0b4a9

Showing 5 changed files with 69 additions and 22 deletions (+69 -22):
  dcu_megatron/adaptor/features_manager.py                +3   -4
  dcu_megatron/core/extensions/transformer_engine.py      +38  -10
  dcu_megatron/core/models/gpt/fine_grained_schedule.py   +2   -0
  dcu_megatron/core/pipeline_parallel/combined_1f1b.py    +26  -5
  dcu_megatron/core/transformer/transformer_config.py     +0   -3
dcu_megatron/adaptor/features_manager.py

@@ -39,10 +39,9 @@ def a2a_overlap_adaptation(patches_manager):
         create_dummy=True
     )

     # backward_dw
-    if is_te_min_version("2.4.0.dev0"):
-        patches_manager.register_patch(
-            'megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
-            _get_extra_te_kwargs_wrapper, apply_wrapper=True)
+    patches_manager.register_patch(
+        'megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
+        _get_extra_te_kwargs_wrapper, apply_wrapper=True)
     patches_manager.register_patch(
         'megatron.core.extensions.transformer_engine.TELinear', TELinear)
     patches_manager.register_patch(
         'megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
         ...
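For readers unfamiliar with the patching mechanism used above: register_patch records a dotted target path together with a replacement, and apply_wrapper=True indicates the replacement is a decorator that receives the original object and returns the patched one. The sketch below is a minimal, hypothetical illustration of that contract; this PatchesManager is not dcu_megatron's actual implementation, and json.dumps is used only as a runnable stand-in target.

    import importlib
    from functools import wraps


    class PatchesManager:
        """Illustrative stand-in for the project's patch registry (hypothetical)."""

        def __init__(self):
            self._patches = []

        def register_patch(self, target, replacement, apply_wrapper=False):
            # target is a dotted path such as "package.module.attribute"
            self._patches.append((target, replacement, apply_wrapper))

        def apply(self):
            for target, replacement, apply_wrapper in self._patches:
                module_path, attr = target.rsplit('.', 1)
                module = importlib.import_module(module_path)
                original = getattr(module, attr)
                # apply_wrapper=True: call the replacement with the original and
                # install its return value (the decorator pattern used for
                # _get_extra_te_kwargs_wrapper above); otherwise install it directly.
                setattr(module, attr, replacement(original) if apply_wrapper else replacement)


    # Runnable stand-in: wrap json.dumps so every call sorts keys by default.
    def dumps_wrapper(fn):
        @wraps(fn)
        def wrapper(obj, **kwargs):
            kwargs.setdefault("sort_keys", True)
            return fn(obj, **kwargs)
        return wrapper

    manager = PatchesManager()
    manager.register_patch("json.dumps", dumps_wrapper, apply_wrapper=True)
    manager.apply()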
dcu_megatron/core/extensions/transformer_engine.py

 import os
+import copy
 import torch
 import dataclasses
 import transformer_engine as te
...
@@ -7,6 +8,7 @@ from functools import wraps
 from typing import Any, Optional, Callable
 from packaging.version import Version as PkgVersion

+from megatron.training import get_args
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.tensor_parallel import get_cuda_rng_tracker
 from megatron.core.utils import get_te_version, is_te_min_version
...
@@ -25,15 +27,17 @@ from megatron.core.parallel_state import (
 )


-def _get_extra_te_kwargs_wrapper(fn):
-    @wraps(fn)
+def _get_extra_te_kwargs_wrapper(_get_extra_te_kwargs_func):
+    @wraps(_get_extra_te_kwargs_func)
     def wrapper(config: TransformerConfig):
-        extra_transformer_engine_kwargs = fn(config)
+        extra_transformer_engine_kwargs = _get_extra_te_kwargs_func(config)
         if hasattr(config, "split_bw"):
             extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw
         return extra_transformer_engine_kwargs
-    return wrapper
+    if is_te_min_version("2.3.0.dev0"):
+        return wrapper
+    return _get_extra_te_kwargs_func


 class TELinear(MegatronCoreTELinear):
...
@@ -66,8 +70,14 @@ class TELinear(MegatronCoreTELinear):
         tp_comm_buffer_name: Optional[str] = None,
         is_expert: bool = False,
     ):
-        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-        assert not self.split_bw, "split_bw is currently not supported"
+        args = get_args()
+        self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+        if not is_te_min_version("2.3.0.dev0"):
+            assert not self.split_bw, "split_bw is currently not supported"
+        if self.split_bw:
+            config = copy.copy(config)
+            config.split_bw = True
+
         super().__init__(
             input_size,
...
@@ -86,6 +96,8 @@ class TELinear(MegatronCoreTELinear):
+        if not self.split_bw:
+            return
         return super(MegatronCoreTELinear, self).backward_dw()


 class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
     """
...
@@ -107,8 +119,14 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
         skip_weight_param_allocation: bool = False,
         tp_comm_buffer_name: Optional[str] = None,
     ):
-        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-        assert not self.split_bw, "split_bw is currently not supported"
+        args = get_args()
+        self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+        if not is_te_min_version("2.3.0.dev0"):
+            assert not self.split_bw, "split_bw is currently not supported"
+        if self.split_bw:
+            config = copy.copy(config)
+            config.split_bw = True
+
         super().__init__(
             input_size,
...
@@ -127,6 +145,8 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
+        if not self.split_bw:
+            return
         return super(MegatronCoreTELayerNormColumnParallelLinear, self).backward_dw()


 class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
     def __init__(
...
@@ -289,8 +309,14 @@ if is_te_min_version("1.9.0.dev0"):
             is_expert: bool = False,
             tp_comm_buffer_name: Optional[str] = None,
         ):
-            self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-            assert not self.split_bw, "split_bw is currently not supported"
+            args = get_args()
+            self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+            if not is_te_min_version("2.3.0.dev0"):
+                assert not self.split_bw, "split_bw is currently not supported"
+            if self.split_bw:
+                config = copy.copy(config)
+                config.split_bw = True
+
             super().__init__(
                 num_gemms,
...
@@ -308,3 +334,5 @@ if is_te_min_version("1.9.0.dev0"):
         def backward_dw(self):
+            if not self.split_bw:
+                return
             return super(MegatronCoreTEGroupedLinear, self).backward_dw()
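The recurring pattern in this file: when split_bw is requested and TE is at least 2.3.0.dev0, the wrapper passes delay_wgrad_compute to Transformer Engine so the weight gradient is not computed inside autograd's backward, and backward_dw() triggers it later from the schedule. The sketch below illustrates the same dgrad/wgrad split in plain PyTorch; it is not the Transformer Engine implementation, and all class and method names are illustrative except backward_dw, which mirrors the hook above.

    import torch

    class _SplitBwLinear(torch.autograd.Function):
        """Backward computes dgrad only; inputs needed for wgrad are stashed for later."""

        @staticmethod
        def forward(ctx, x, weight, pending):
            ctx.save_for_backward(x, weight)
            ctx.pending = pending
            return x @ weight.t()

        @staticmethod
        def backward(ctx, grad_out):
            x, weight = ctx.saved_tensors
            ctx.pending.append((x.detach(), grad_out.detach()))  # defer wgrad
            return grad_out @ weight, None, None                 # dgrad now


    class DelayedWgradLinear(torch.nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(out_features, in_features) * 0.02)
            self._pending = []

        def forward(self, x):
            return _SplitBwLinear.apply(x, self.weight, self._pending)

        def backward_dw(self):
            # Launched whenever the schedule decides, e.g. overlapped with
            # communication in combined 1F1B.
            for x, grad_out in self._pending:
                wgrad = grad_out.t() @ x
                self.weight.grad = wgrad if self.weight.grad is None else self.weight.grad + wgrad
            self._pending.clear()


    layer = DelayedWgradLinear(8, 4)
    x = torch.randn(2, 8, requires_grad=True)
    layer(x).sum().backward()          # only dgrad runs here
    assert layer.weight.grad is None
    layer.backward_dw()                # wgrad computed later
    assert layer.weight.grad is not None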
dcu_megatron/core/models/gpt/fine_grained_schedule.py

@@ -342,6 +342,8 @@ class MoeMlPNode(TransformerLayerNode):
         assert mlp_bias is None

         # pre_mlp_layernorm_output used
+        # cur_stream = torch.cuda.current_stream()
+        # self.common_state.pre_mlp_layernorm_output.record_stream(cur_stream)
         self.common_state.pre_mlp_layernorm_output = None
         return expert_output, shared_expert_output
...
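The two added (commented-out) lines document a CUDA stream concern: when a tensor such as pre_mlp_layernorm_output is produced on one stream and consumed on another, Tensor.record_stream() is the usual way to keep the caching allocator from recycling its storage before the consumer stream finishes. A minimal, self-contained illustration follows; the stream handling here is illustrative and not the schedule's actual code, and it requires a CUDA device.

    import torch

    def consume_on_side_stream(t: torch.Tensor) -> torch.Tensor:
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())   # order side stream after the producer
        with torch.cuda.stream(side):
            out = t * 2
        # Tell the caching allocator that `t` is still in use by `side`, so freeing
        # `t` on the default stream will not recycle its memory prematurely.
        t.record_stream(side)
        torch.cuda.current_stream().wait_stream(side)   # re-join before returning
        return out

    if torch.cuda.is_available():
        x = torch.randn(1024, device="cuda")
        y = consume_on_side_stream(x)
        del x                      # safe: record_stream() delays reuse of x's storage
        torch.cuda.synchronize()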
dcu_megatron/core/pipeline_parallel/combined_1f1b.py

@@ -12,6 +12,7 @@ from megatron.core.distributed import DistributedDataParallel
 from megatron.core.transformer.module import Float16Module
 from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
+from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
...
@@ -56,6 +57,11 @@ class ScheduleNode:
         self.outputs = None

     def default_backward_func(self, outputs, output_grad):
+        # Handle scalar output
+        if output_grad is None:
+            assert outputs.numel() == 1, "implicit grad requires scalar output."
+            output_grad = torch.ones_like(outputs, memory_format=torch.preserve_format)
         Variable._execution_engine.run_backward(
             tensors=outputs,
             grad_tensors=output_grad,
...
@@ -441,12 +447,13 @@ def forward_backward_step(
             output_tensor, num_tokens, loss_reduced = outputs
             if not config.calculate_per_token_loss:
                 output_tensor /= num_tokens
+                output_tensor *= parallel_state.get_context_parallel_world_size()
                 output_tensor /= num_microbatches
         else:
-            # preserve legacy loss averaging behavior
-            # (ie, over the number of microbatches)
+            # preserve legacy loss averaging behavior (ie, over the number of microbatches)
             assert len(outputs) == 2
             output_tensor, loss_reduced = outputs
+            output_tensor *= parallel_state.get_context_parallel_world_size()
             output_tensor = output_tensor / num_microbatches
         forward_data_store.append(loss_reduced)
...
@@ -464,12 +471,11 @@ def forward_backward_step(
     # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
    # explicitly.
     if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
-        # Calculate the loss scale based on the grad_scale_func if available,
-        # else default to 1.
+        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
         loss_scale = (
             config.grad_scale_func(torch.ones(1, device=output_tensor.device))
             if config.grad_scale_func is not None
-            else torch.tensor(1.0)
+            else torch.ones(1, device=output_tensor.device)
         )
         # Set the loss scale
         if config.calculate_per_token_loss:
...
@@ -477,8 +483,23 @@ def forward_backward_step(
         else:
             MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

+    # Set the loss scale for Multi-Token Prediction (MTP) loss.
+    if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
+        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
+        loss_scale = (
+            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
+            if config.grad_scale_func is not None
+            else torch.ones(1, device=output_tensor.device)
+        )
+        # Set the loss scale
+        if config.calculate_per_token_loss:
+            MTPLossAutoScaler.set_loss_scale(loss_scale)
+        else:
+            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
+
     if not unwrap_output_tensor:
         output_tensor, num_tokens = [output_tensor], num_tokens

     # backward post process
     input_tensor_grad = None
     if b_model is not None:
...
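The added MTP block mirrors the existing MoE auxiliary-loss handling: compute a loss scale (grad_scale_func if present, otherwise ones), divide it by num_microbatches unless per-token loss averaging is enabled, and hand it to the auto-scaler. Below is a simplified stand-in for that auto-scaler pattern, not Megatron's MoEAuxLossAutoScaler/MTPLossAutoScaler, showing how a class-level scale set once per step is applied to an auxiliary loss's gradient during backward.

    import torch

    class AuxLossAutoScaler(torch.autograd.Function):
        """Attaches an auxiliary loss to the graph; its gradient is scaled by a class-level factor."""

        main_loss_backward_scale = torch.tensor(1.0)

        @staticmethod
        def forward(ctx, output, aux_loss):
            ctx.save_for_backward(aux_loss)
            return output                      # forward value is unchanged

        @staticmethod
        def backward(ctx, grad_output):
            (aux_loss,) = ctx.saved_tensors
            scaled = torch.ones_like(aux_loss) * AuxLossAutoScaler.main_loss_backward_scale
            return grad_output, scaled         # aux loss gets the configured scale

        @classmethod
        def set_loss_scale(cls, scale):
            cls.main_loss_backward_scale = scale


    # With 4 microbatches and no per-token loss, each microbatch's aux-loss gradient
    # is scaled by loss_scale / num_microbatches, matching the diff above.
    num_microbatches, loss_scale = 4, torch.ones(1)
    AuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    x = torch.randn(3, requires_grad=True)
    aux = x.pow(2).sum()                       # pretend auxiliary (e.g. load-balancing) loss
    main = AuxLossAutoScaler.apply(x.sum(), aux)
    main.backward()                            # x.grad now includes the scaled aux-loss gradient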
dcu_megatron/core/transformer/transformer_config.py

@@ -40,9 +40,6 @@ class ExtraTransformerConfig:
     combined_1f1b_recipe: str = 'ep_a2a'
     """Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""

-    split_bw: bool = False
-    """If true, split dgrad and wgrad for better overlapping in combined 1F1B."""
-

 @dataclass
 class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
...
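With split_bw removed from ExtraTransformerConfig, the switch is now read from the runtime args (get_args().split_bw) in the TE wrappers above, and the config keeps only the schedule recipe. For reference, the dataclass mix-in pattern used by TransformerConfigPatch looks like the following sketch; field names other than combined_1f1b_recipe are illustrative, and BaseConfig stands in for Megatron's TransformerConfig.

    from dataclasses import dataclass


    @dataclass
    class BaseConfig:
        hidden_size: int = 1024            # stand-in for TransformerConfig fields


    @dataclass
    class ExtraConfig:
        combined_1f1b_recipe: str = 'ep_a2a'
        """Recipe to use for combined 1F1B."""


    @dataclass
    class ConfigPatch(BaseConfig, ExtraConfig):
        """Combined view: inherits defaults from both parents, like TransformerConfigPatch."""


    cfg = ConfigPatch()
    print(cfg.hidden_size, cfg.combined_1f1b_recipe)   # 1024 ep_a2a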