Commit 56819e16, authored May 29, 2025 by dongcl
Merge branch 'a2a_overlap', support 1f1b overlap
Parent: 1e8185f4
Showing 10 changed files with 126 additions and 76 deletions (+126, -76).
dcu_megatron/adaptor/features_manager.py                  +3   -4
dcu_megatron/core/extensions/transformer_engine.py        +38  -10
dcu_megatron/core/models/gpt/fine_grained_schedule.py     +23  -14
dcu_megatron/core/pipeline_parallel/combined_1f1b.py      +49  -29
dcu_megatron/core/pipeline_parallel/schedules.py           +3   -3
dcu_megatron/core/transformer/transformer_config.py        +1   -2
dcu_megatron/core/transformer/transformer_layer.py         +6   -6
dcu_megatron/training/initialize.py                        +1   -1
dcu_megatron/training/training.py                          +2   -1
pretrain_gpt.py                                            +0   -6
dcu_megatron/adaptor/features_manager.py (+3, -4)

@@ -39,10 +39,9 @@ def a2a_overlap_adaptation(patches_manager):
                                     create_dummy=True)

     # backward_dw
-    if is_te_min_version("2.4.0.dev0"):
     patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs',
                                    _get_extra_te_kwargs_wrapper,
                                    apply_wrapper=True)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear',
                                    TELinear)
     patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
...
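The registrations above patch upstream Megatron symbols at import time; with apply_wrapper=True the original function is handed to a wrapper factory and the wrapper's result replaces it. Below is a minimal, self-contained sketch of that pattern; the stand-in _get_extra_te_kwargs body and the Cfg class are illustrative assumptions, not the real implementations.

import functools

def _get_extra_te_kwargs(config):
    # stand-in for the upstream Megatron function that gets patched
    return {"params_dtype": "float32"}

def _get_extra_te_kwargs_wrapper(fn):
    @functools.wraps(fn)
    def wrapper(config):
        kwargs = fn(config)
        # forward the split_bw flag to Transformer Engine as delay_wgrad_compute
        if getattr(config, "split_bw", False):
            kwargs["delay_wgrad_compute"] = True
        return kwargs
    return wrapper

# what register_patch(..., apply_wrapper=True) effectively does:
_get_extra_te_kwargs = _get_extra_te_kwargs_wrapper(_get_extra_te_kwargs)

class Cfg:
    split_bw = True

print(_get_extra_te_kwargs(Cfg()))  # {'params_dtype': 'float32', 'delay_wgrad_compute': True}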
dcu_megatron/core/extensions/transformer_engine.py (+38, -10)

 import os
+import copy
 import torch
 import dataclasses
 import transformer_engine as te
...
@@ -7,6 +8,7 @@ from functools import wraps
 from typing import Any, Optional, Callable
 from packaging.version import Version as PkgVersion

+from megatron.training import get_args
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.tensor_parallel import get_cuda_rng_tracker
 from megatron.core.utils import get_te_version, is_te_min_version
...
@@ -25,15 +27,17 @@ from megatron.core.parallel_state import (
 )


-def _get_extra_te_kwargs_wrapper(fn):
-    @wraps(fn)
+def _get_extra_te_kwargs_wrapper(_get_extra_te_kwargs_func):
+    @wraps(_get_extra_te_kwargs_func)
     def wrapper(config: TransformerConfig):
-        extra_transformer_engine_kwargs = fn(config)
+        extra_transformer_engine_kwargs = _get_extra_te_kwargs_func(config)
         if hasattr(config, "split_bw"):
             extra_transformer_engine_kwargs["delay_wgrad_compute"] = config.split_bw
         return extra_transformer_engine_kwargs

-    return wrapper
+    if is_te_min_version("2.3.0.dev0"):
+        return wrapper
+    return _get_extra_te_kwargs_func


 class TELinear(MegatronCoreTELinear):
...
@@ -66,8 +70,14 @@ class TELinear(MegatronCoreTELinear):
         tp_comm_buffer_name: Optional[str] = None,
         is_expert: bool = False,
     ):
-        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-        assert not self.split_bw, "split_bw is currently not supported"
+        args = get_args()
+        self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+        if not is_te_min_version("2.3.0.dev0"):
+            assert not self.split_bw, "split_bw is currently not supported"
+        if self.split_bw:
+            config = copy.copy(config)
+            config.split_bw = True
         super().__init__(
             input_size,
...
@@ -86,6 +96,8 @@ class TELinear(MegatronCoreTELinear):
         if not self.split_bw:
             return
+        return super(MegatronCoreTELinear, self).backward_dw()


 class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
     """
...
@@ -107,8 +119,14 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
         skip_weight_param_allocation: bool = False,
         tp_comm_buffer_name: Optional[str] = None,
     ):
-        self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-        assert not self.split_bw, "split_bw is currently not supported"
+        args = get_args()
+        self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+        if not is_te_min_version("2.3.0.dev0"):
+            assert not self.split_bw, "split_bw is currently not supported"
+        if self.split_bw:
+            config = copy.copy(config)
+            config.split_bw = True
         super().__init__(
             input_size,
...
@@ -127,6 +145,8 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinea
         if not self.split_bw:
             return
+        return super(MegatronCoreTELayerNormColumnParallelLinear, self).backward_dw()


 class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
     def __init__(
...
@@ -289,8 +309,14 @@ if is_te_min_version("1.9.0.dev0"):
             is_expert: bool = False,
             tp_comm_buffer_name: Optional[str] = None,
         ):
-            self.split_bw = config.split_bw if hasattr(config, "split_bw") else False
-            assert not self.split_bw, "split_bw is currently not supported"
+            args = get_args()
+            self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
+            if not is_te_min_version("2.3.0.dev0"):
+                assert not self.split_bw, "split_bw is currently not supported"
+            if self.split_bw:
+                config = copy.copy(config)
+                config.split_bw = True
             super().__init__(
                 num_gemms,
...
@@ -308,3 +334,5 @@ if is_te_min_version("1.9.0.dev0"):
         def backward_dw(self):
             if not self.split_bw:
                 return
+            return super(MegatronCoreTEGroupedLinear, self).backward_dw()
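The split_bw path forwards the flag to Transformer Engine as delay_wgrad_compute and exposes backward_dw() so weight gradients can be computed after the input gradients have been produced. A minimal sketch of that dgrad/wgrad split using plain PyTorch autograd follows; it is an illustration of the idea, not the Transformer Engine mechanism.

import torch

# One linear layer y = x @ w.T, reduced to a scalar loss for the sketch.
w = torch.randn(16, 16, requires_grad=True)
x = torch.randn(4, 16, requires_grad=True)
y = (x @ w.t()).sum()

# dgrad now: the previous layer's backward needs the input gradient immediately.
dgrad, = torch.autograd.grad(y, x, retain_graph=True)

# wgrad later: deferring it lets a scheduler overlap it with other work
# (e.g. all-to-all communication) before the optimizer step.
def backward_dw():
    wgrad, = torch.autograd.grad(y, w)
    w.grad = wgrad

backward_dw()
print(dgrad.shape, w.grad.shape)  # torch.Size([4, 16]) torch.Size([16, 16])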
dcu_megatron/core/models/gpt/fine_grained_schedule.py (+23, -14)

@@ -540,6 +540,8 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
     def state(self):
         return self._model_chunk_state


+# F_DISPATCH_B_MLP_SYNC_EVENT = torch.cuda.Event()
+F_DISPATCH_B_MLP_SYNC_EVENT = None


 def schedule_layer_1f1b(
     f_layer,
...
@@ -579,13 +581,17 @@ def schedule_layer_1f1b(
         with f_context:
             f_input = f_layer.attn.forward(f_input)

+    f_dispatch_b_mlp_sync_event = None
+    if f_layer is not None and b_layer is not None:
+        f_dispatch_b_mlp_sync_event = F_DISPATCH_B_MLP_SYNC_EVENT
     if f_layer is not None:
         with f_context:
-            f_input = f_layer.dispatch.forward(f_input)
+            f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)

     if b_layer is not None:
         with b_context:
-            b_grad = b_layer.mlp.backward(b_grad)
+            b_grad = b_layer.mlp.backward(b_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
             b_grad = b_layer.dispatch.backward(b_grad)
             b_layer.mlp.dw()
...
@@ -688,32 +694,28 @@ def schedule_chunk_1f1b(
             )
             torch.cuda.nvtx.range_pop()

-    # tail forward
-    f_input = layer_pre_forward()
-    del layer_pre_forward
     # tail backward
     grad = layer_pre_backward()
     del layer_pre_backward
     with b_context:
         for i in range(overlaped_layers, b_num_layers):
             b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
             torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
-            tmp, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
+            _, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
             torch.cuda.nvtx.range_pop()

-        if b_schedule_plan is not None:
-            b_schedule_plan.pre_process.backward(grad)

+    # tail forward
+    f_input = layer_pre_forward()
+    del layer_pre_forward
     with f_context:
         for i in range(overlaped_layers, f_num_layers):
             f_layer = f_schedule_plan.get_layer(i)
             torch.cuda.nvtx.range_push(f"layer_{i}f")
-            f_input, tmp, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
+            f_input, _, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
             torch.cuda.nvtx.range_pop()

-        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
-            f_input = f_schedule_plan.post_process.forward(f_input)

     # output pp send receive, overlapped with attn backward
     if f_schedule_plan is not None and post_forward is not None:
         with f_context:
...
@@ -730,6 +732,13 @@ def schedule_chunk_1f1b(
         layer_pre_backward_dw()
         del layer_pre_backward_dw

+    with f_context:
+        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
+            f_input = f_schedule_plan.post_process.forward(f_input)
+    with b_context:
+        if b_schedule_plan is not None:
+            b_schedule_plan.pre_process.backward(grad)

     if f_schedule_plan:
         f_schedule_plan.wait_current_stream()
     if b_schedule_plan:
...
@@ -764,7 +773,7 @@ def build_model_chunk_schedule_plan(
     state.attention_mask = attention_mask
     state.decoder_input = decoder_input
     state.labels = labels
     state.inference_context = inference_context
     state.packed_seq_params = packed_seq_params
     state.extra_block_kwargs = extra_block_kwargs
     state.runtime_gather_output = runtime_gather_output
...
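The new stream_record_event / stream_wait_event plumbing orders the forward dispatch (on one stream) against the backward MLP (on another) with a CUDA event; note that F_DISPATCH_B_MLP_SYNC_EVENT is left as None in this commit (the Event() construction is commented out), so the path is inert until an event is supplied. A minimal sketch of the record/wait pattern on two streams is shown below; it requires a CUDA device, and the tensor operations are placeholders for the dispatch and MLP work.

import torch

comp_stream = torch.cuda.current_stream()
com_stream = torch.cuda.Stream()
sync_event = torch.cuda.Event()

x = torch.randn(1024, 1024, device="cuda")

with torch.cuda.stream(com_stream):
    y = x + 1                      # stands in for dispatch.forward on the comm stream
    sync_event.record(com_stream)  # stream_record_event

sync_event.wait(comp_stream)       # stream_wait_event: compute waits for the dispatch
with torch.cuda.stream(comp_stream):
    z = y @ y                      # stands in for mlp.backward on the compute stream

torch.cuda.synchronize()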
dcu_megatron/core/pipeline_parallel/combined_1f1b.py (+49, -29)

 import contextlib
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, List, Tuple, Union
+from typing import List, Union

 import torch
 from torch import Tensor
...
@@ -12,13 +12,10 @@ from megatron.core.distributed import DistributedDataParallel
 from megatron.core.transformer.module import Float16Module
 from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
+from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor

-# Types
-Shape = Union[List[int], torch.Size]


 def make_viewless(e):
     """make_viewless util func"""
     e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
...
@@ -56,6 +53,11 @@ class ScheduleNode:
         self.outputs = None

     def default_backward_func(self, outputs, output_grad):
+        # Handle scalar output
+        if output_grad is None:
+            assert outputs.numel() == 1, "implicit grad requires scalar output."
+            output_grad = torch.ones_like(outputs, memory_format=torch.preserve_format)
         Variable._execution_engine.run_backward(
             tensors=outputs,
             grad_tensors=output_grad,
...
@@ -67,17 +69,20 @@ class ScheduleNode:
         )
         return output_grad

-    def forward(self, inputs=()):
+    def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):
         """schedule node forward"""
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
-        return self._forward(*inputs)
+        return self._forward(*inputs, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)

-    def _forward(self, *inputs):
+    def _forward(self, *inputs, stream_wait_event=None, stream_record_event=None):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} forward")
             with torch.cuda.stream(self.stream):
+                if stream_wait_event is not None:
+                    stream_wait_event.wait(self.stream)
                 self.inputs = [make_viewless(e).detach() if e is not None else None for e in inputs]
                 for i, input in enumerate(self.inputs):
                     if input is not None:
...
@@ -92,6 +97,10 @@ class ScheduleNode:
                 data = tuple([make_viewless(e) if isinstance(e, Tensor) else e for e in data])
                 self.output = data
+                if stream_record_event is not None:
+                    stream_record_event.record(self.stream)
             torch.cuda.nvtx.range_pop()

         if self.free_inputs:
...
@@ -105,16 +114,19 @@ class ScheduleNode:
         """get the forward output"""
         return self.output

-    def backward(self, output_grad):
+    def backward(self, output_grad, stream_wait_event=None, stream_record_event=None):
         """schedule node backward"""
         if not isinstance(output_grad, tuple):
             output_grad = (output_grad,)
-        return self._backward(*output_grad)
+        return self._backward(*output_grad, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event)

-    def _backward(self, *output_grad):
+    def _backward(self, *output_grad, stream_wait_event=None, stream_record_event=None):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} backward")
             with torch.cuda.stream(self.stream):
+                if stream_wait_event is not None:
+                    stream_wait_event.wait(self.stream)
                 outputs = self.output
                 if not isinstance(outputs, tuple):
                     outputs = (outputs,)
...
@@ -125,6 +137,10 @@ class ScheduleNode:
                     output_grad = self.backward_func(outputs, output_grad)
                 else:
                     output_grad = self.default_backward_func(outputs, output_grad)
+                if stream_record_event is not None:
+                    stream_record_event.record(self.stream)
             torch.cuda.nvtx.range_pop()

         # output_grad maybe from another stream
...
@@ -192,17 +208,6 @@ def schedule_chunk_1f1b(
     )

-def schedule_chunk_forward(schedule_plan):
-    """model level fine-grained forward schedule"""
-    f_input = schedule_chunk_1f1b(schedule_plan, None, None)
-    return f_input

-def schedule_chunk_backward(schedule_plan, grad):
-    """model level fine-grained backward schedule"""
-    tmp = schedule_chunk_1f1b(None, schedule_plan, grad)


 _COMP_STREAM = None
 _COM_STREAM = None
...
@@ -215,7 +220,7 @@ def set_streams(comp_stream=None, com_stream=None):
         return
     if comp_stream is None:
-        comp_stream = torch.cuda.Stream(device="cuda")
+        comp_stream = torch.cuda.current_stream()
     if com_stream is None:
         com_stream = torch.cuda.Stream(device="cuda")
...
@@ -342,7 +347,7 @@ def forward_backward_step(
         Tensor or list[Tensor]: The output object(s) from the forward step.
         Tensor: The number of tokens.
     """
-    from .schedules import set_current_microbatch
+    from megatron.core.pipeline_parallel.schedules import set_current_microbatch

     if config.timers is not None:
         config.timers('forward-compute', log_level=2).start()
...
@@ -441,12 +446,13 @@ def forward_backward_step(
             output_tensor, num_tokens, loss_reduced = outputs
             if not config.calculate_per_token_loss:
                 output_tensor /= num_tokens
+                output_tensor *= parallel_state.get_context_parallel_world_size()
                 output_tensor /= num_microbatches
         else:
-            # preserve legacy loss averaging behavior
-            # (ie, over the number of microbatches)
+            # preserve legacy loss averaging behavior (ie, over the number of microbatches)
             assert len(outputs) == 2
             output_tensor, loss_reduced = outputs
+            output_tensor *= parallel_state.get_context_parallel_world_size()
             output_tensor = output_tensor / num_microbatches
         forward_data_store.append(loss_reduced)
...
@@ -464,12 +470,11 @@ def forward_backward_step(
    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
    # explicitly.
    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
-        # Calculate the loss scale based on the grad_scale_func if available,
-        # else default to 1.
+        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
-            else torch.tensor(1.0)
+            else torch.ones(1, device=output_tensor.device)
        )
        # Set the loss scale
        if config.calculate_per_token_loss:
...
@@ -477,8 +482,23 @@ def forward_backward_step(
        else:
            MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

+    # Set the loss scale for Multi-Token Prediction (MTP) loss.
+    if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
+        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
+        loss_scale = (
+            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
+            if config.grad_scale_func is not None
+            else torch.ones(1, device=output_tensor.device)
+        )
+        # Set the loss scale
+        if config.calculate_per_token_loss:
+            MTPLossAutoScaler.set_loss_scale(loss_scale)
+        else:
+            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    if not unwrap_output_tensor:
        output_tensor, num_tokens = [output_tensor], num_tokens

    # backward post process
    input_tensor_grad = None
    if b_model is not None:
...
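default_backward_func now accepts output_grad=None for scalar outputs and substitutes the implicit all-ones gradient before invoking the autograd engine. A minimal sketch of the same behaviour follows, using the public torch.autograd.backward in place of Variable._execution_engine; it is an illustration, not the ScheduleNode code.

import torch

x = torch.randn(8, requires_grad=True)
outputs = (x ** 2).sum()          # scalar output, as for a loss node
output_grad = None

# Handle scalar output: substitute the implicit all-ones gradient.
if output_grad is None:
    assert outputs.numel() == 1, "implicit grad requires scalar output."
    output_grad = torch.ones_like(outputs, memory_format=torch.preserve_format)

torch.autograd.backward(outputs, grad_tensors=output_grad)
print(x.grad.shape)  # torch.Size([8])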
dcu_megatron/core/pipeline_parallel/schedules.py (+3, -3)

 import contextlib
-from typing import Callable, Iterator, List, Optional, Union
+from typing import Iterator, List, Union

 import torch
...
@@ -7,10 +7,8 @@ from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.enums import ModelType
 from megatron.core.pipeline_parallel import p2p_communication
-from megatron.core.pipeline_parallel.schedules import set_current_microbatch
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
 from megatron.core.utils import (
-    get_attr_wrapped_model,
     get_model_config,
     get_model_type,
     get_model_xattn,
...
@@ -448,6 +446,8 @@ def forward_backward_pipelining_with_interleaving(
         """Helper method to run backward step with model split into chunks
         (run set_virtual_pipeline_model_parallel_rank() before calling
         backward_step())."""
+        nonlocal output_tensor_grads
         model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
         parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
...
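The added nonlocal output_tensor_grads matters because the nested backward helper rebinds that name; without the declaration the assignment would only create a local variable and the enclosing scheduler state would never change. A small, self-contained sketch of the pattern follows; the scheduler and helper below are illustrative stand-ins, not the Megatron code.

def scheduler():
    output_tensor_grads = {0: ["g0"], 1: ["g1"]}

    def backward_step_helper(model_chunk_id):
        nonlocal output_tensor_grads
        grad = output_tensor_grads[model_chunk_id].pop(0)
        if not any(output_tensor_grads.values()):
            output_tensor_grads = {}   # rebinding requires the nonlocal declaration
        return grad

    return [backward_step_helper(0), backward_step_helper(1)]

print(scheduler())  # ['g0', 'g1']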
dcu_megatron/core/transformer/transformer_config.py (+1, -2)

@@ -40,8 +40,7 @@ class ExtraTransformerConfig:
     combined_1f1b_recipe: str = 'ep_a2a'
     """Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""

-    split_bw: bool = False
-    """If true, split dgrad and wgrad for better overlapping in combined 1F1B."""
+    # split_bw: bool = False


 @dataclass
...
dcu_megatron/core/transformer/transformer_layer.py (+6, -6)

-from functools import partial
 from typing import Any, Optional

-import torch
 from torch import Tensor

 from megatron.core import tensor_parallel
...
@@ -12,8 +10,7 @@ from megatron.core.utils import (
 )
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
-from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
+from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher


 class TransformerLayer(MegatronCoreTransformerLayer):
...
@@ -34,7 +31,10 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         inference_params: Optional[Any] = None,
     ):
-        if not isinstance(self.mlp, MoELayer):
+        if (
+            not isinstance(self.mlp, MoELayer)
+            or not isinstance(self.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher)
+        ):
             return super().forward(
                 hidden_states=hidden_states,
                 context=context,
...
@@ -55,7 +55,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             pre_mlp_layernorm_output,
             tokens_per_expert,
             permutated_local_input_tokens,
-            probs,
+            _,
         ) = self._submodule_attention_router_compound_forward(
             hidden_states,
             attention_mask,
...
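The widened guard means the fine-grained a2a-overlap path is taken only when the MLP is a MoE layer whose token dispatcher is the all-to-all variant; everything else falls back to the upstream forward. A toy, self-contained sketch of that gate follows; the classes below are stand-ins for illustration, not the Megatron ones.

class DenseMLP: ...
class MoEAllGatherTokenDispatcher: ...
class MoEAlltoAllTokenDispatcher: ...

class MoELayer:
    def __init__(self, dispatcher):
        self.token_dispatcher = dispatcher

def uses_fine_grained_path(mlp) -> bool:
    # mirrors the isinstance checks added to TransformerLayer.forward
    return isinstance(mlp, MoELayer) and isinstance(
        mlp.token_dispatcher, MoEAlltoAllTokenDispatcher
    )

assert not uses_fine_grained_path(DenseMLP())
assert not uses_fine_grained_path(MoELayer(MoEAllGatherTokenDispatcher()))
assert uses_fine_grained_path(MoELayer(MoEAlltoAllTokenDispatcher()))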
dcu_megatron/training/initialize.py (+1, -1)

@@ -97,7 +97,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
             print("> initializing torch distributed ...", flush=True)

         # Manually set the device ids.
         if device_count > 0:
-            torch.cuda.set_device(args.local_rank % device_count)
+            torch.cuda.set_device(args.local_rank)
             device_id = torch.device(f'cuda:{args.local_rank}')
         else:
             device_id = None
...
dcu_megatron/training/training.py (+2, -1)

@@ -195,7 +195,8 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                 active=args.profile_step_end - args.profile_step_start,
                 repeat=1),
             on_trace_ready=trace_handler,
             record_shapes=True,
+            with_stack=True,
         )
         prof.start()
     elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
...
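train() now passes with_stack=True so profiler traces carry Python stack information alongside operator timings. A minimal standalone profiler configured along the same lines is sketched below; the schedule values, output directory, and dummy workload are assumptions for illustration.

import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

prof = profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=0, warmup=1, active=2, repeat=1),
    on_trace_ready=tensorboard_trace_handler("./profile_out"),
    record_shapes=True,
    with_stack=True,   # the flag added in this commit
)

prof.start()
for step in range(4):
    torch.randn(64, 64) @ torch.randn(64, 64)   # dummy work per training step
    prof.step()
prof.stop()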
pretrain_gpt.py (+0, -6)

@@ -16,9 +16,6 @@ from megatron.core.enums import ModelType
 from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
 from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
 from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
-from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
-    get_gpt_heterogeneous_layer_spec,
-)
 from megatron.core.rerun_state_machine import get_rerun_state_machine
 import megatron.legacy.model
 from megatron.core.models.gpt import GPTModel
...
@@ -38,7 +35,6 @@ from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_layer_with_transformer_engine_spec,
     get_gpt_mtp_block_spec,
 )
-from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from dcu_megatron import megatron_adaptor
...
@@ -102,8 +98,6 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
     if args.num_experts:
         # Define the decoder block spec
         transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization)
-    elif args.heterogeneous_layers_config_path is not None:
-        transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
     else:
         # Define the decoder layer spec
         if use_te:
...