Commit 6dcd0fb8, authored May 26, 2025 by dongcl

    modify schedule_chunk_1f1b to fix the nccl communication bug

Parent: 124accba

Showing 5 changed files, with 49 additions and 34 deletions:
dcu_megatron/core/models/gpt/fine_grained_schedule.py    +22 -15
dcu_megatron/core/pipeline_parallel/combined_1f1b.py     +21 -18
dcu_megatron/core/pipeline_parallel/schedules.py          +2  -0
dcu_megatron/core/transformer/transformer_config.py       +2  -0
dcu_megatron/training/training.py                         +2  -1
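The fix hinges on explicit ordering between the communication stream and the compute stream: the forward token dispatch records a CUDA event after it is enqueued, and the overlapped backward MLP waits on that event before launching, instead of letting the two streams race. The snippet below is a minimal, hypothetical sketch of that record/wait handshake; comm_stream, comp_stream, sync_event and the stand-in kernels are illustrative names, not the repository's API.

```python
import torch

# Minimal sketch (not the repository's code) of the CUDA-event handshake this
# commit threads through the schedule: work queued on the communication stream
# records an event, and dependent work on the compute stream waits on it, so
# the two streams are ordered without a full device synchronization.
comp_stream = torch.cuda.current_stream()       # compute stream
comm_stream = torch.cuda.Stream(device="cuda")  # communication stream
sync_event = torch.cuda.Event()                 # plays the role of F_DISPATCH_B_MLP_SYNC_EVENT

x = torch.randn(1024, 1024, device="cuda")
comm_stream.wait_stream(comp_stream)            # make the input visible to the comm stream

with torch.cuda.stream(comm_stream):
    y = x * 2.0                     # stand-in for the MoE dispatch all-to-all
    sync_event.record(comm_stream)  # what the new stream_record_event argument does

with torch.cuda.stream(comp_stream):
    sync_event.wait(comp_stream)    # what the new stream_wait_event argument does
    z = y @ y                       # stand-in for the overlapped MLP backward

torch.cuda.synchronize()
```

The same primitive is exposed as the new stream_record_event / stream_wait_event keyword arguments on ScheduleNode.forward and ScheduleNode.backward in combined_1f1b.py below.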
dcu_megatron/core/models/gpt/fine_grained_schedule.py

@@ -342,8 +342,6 @@ class MoeMlPNode(TransformerLayerNode):
         assert mlp_bias is None

         # pre_mlp_layernorm_output used
-        # cur_stream = torch.cuda.current_stream()
-        # self.common_state.pre_mlp_layernorm_output.record_stream(cur_stream)
         self.common_state.pre_mlp_layernorm_output = None
         return expert_output, shared_expert_output

@@ -542,6 +540,8 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
     def state(self):
         return self._model_chunk_state

+# F_DISPATCH_B_MLP_SYNC_EVENT = torch.cuda.Event()
+F_DISPATCH_B_MLP_SYNC_EVENT = None

 def schedule_layer_1f1b(
     f_layer,

@@ -581,13 +581,17 @@ def schedule_layer_1f1b(
         with f_context:
             f_input = f_layer.attn.forward(f_input)

+    f_dispatch_b_mlp_sync_event = None
+    if f_layer is not None and b_layer is not None:
+        f_dispatch_b_mlp_sync_event = F_DISPATCH_B_MLP_SYNC_EVENT
+
     if f_layer is not None:
         with f_context:
-            f_input = f_layer.dispatch.forward(f_input)
+            f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)

     if b_layer is not None:
         with b_context:
-            b_grad = b_layer.mlp.backward(b_grad)
+            b_grad = b_layer.mlp.backward(b_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
             b_grad = b_layer.dispatch.backward(b_grad)
             b_layer.mlp.dw()

@@ -690,32 +694,28 @@ def schedule_chunk_1f1b(
             )
             torch.cuda.nvtx.range_pop()

-    # tail forward
-    f_input = layer_pre_forward()
-    del layer_pre_forward
     # tail backward
     grad = layer_pre_backward()
     del layer_pre_backward

     with b_context:
         for i in range(overlaped_layers, b_num_layers):
             b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
             torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
-            tmp, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
+            _, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
             torch.cuda.nvtx.range_pop()

-    if b_schedule_plan is not None:
-        b_schedule_plan.pre_process.backward(grad)
+    # tail forward
+    f_input = layer_pre_forward()
+    del layer_pre_forward

     with f_context:
         for i in range(overlaped_layers, f_num_layers):
             f_layer = f_schedule_plan.get_layer(i)
             torch.cuda.nvtx.range_push(f"layer_{i}f")
-            f_input, tmp, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
+            f_input, _, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
             torch.cuda.nvtx.range_pop()

-    if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
-        f_input = f_schedule_plan.post_process.forward(f_input)
-
     # output pp send receive, overlapped with attn backward
     if f_schedule_plan is not None and post_forward is not None:
         with f_context:

@@ -732,6 +732,13 @@ def schedule_chunk_1f1b(
     layer_pre_backward_dw()
     del layer_pre_backward_dw

+    with f_context:
+        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
+            f_input = f_schedule_plan.post_process.forward(f_input)
+
+    with b_context:
+        if b_schedule_plan is not None:
+            b_schedule_plan.pre_process.backward(grad)
+
     if f_schedule_plan:
         f_schedule_plan.wait_current_stream()
     if b_schedule_plan:
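Structurally, the tail-forward prologue (f_input = layer_pre_forward()) is deferred until after the backward tail loop, and post_process.forward / pre_process.backward are moved to the end of schedule_chunk_1f1b inside their f_context / b_context stream contexts. Note also that F_DISPATCH_B_MLP_SYNC_EVENT is committed as None and the torch.cuda.Event() construction is left commented out, so unless the global is assigned elsewhere in the codebase, the new stream_record_event / stream_wait_event arguments are currently no-ops. A hedged sketch of one way such a module-level event could be initialized lazily (this helper is hypothetical and not part of the commit):

```python
import torch

# Hypothetical helper: create the shared sync event on first use so the module
# can still be imported on CPU-only builds. The commit itself leaves
# F_DISPATCH_B_MLP_SYNC_EVENT = None.
F_DISPATCH_B_MLP_SYNC_EVENT = None

def _get_dispatch_mlp_sync_event():
    global F_DISPATCH_B_MLP_SYNC_EVENT
    if F_DISPATCH_B_MLP_SYNC_EVENT is None and torch.cuda.is_available():
        F_DISPATCH_B_MLP_SYNC_EVENT = torch.cuda.Event()
    return F_DISPATCH_B_MLP_SYNC_EVENT
```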
dcu_megatron/core/pipeline_parallel/combined_1f1b.py

@@ -73,17 +73,20 @@ class ScheduleNode:
         )
         return output_grad

-    def forward(self, inputs=()):
+    def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):
         """schedule node forward"""
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
-        return self._forward(*inputs)
+        return self._forward(*inputs, stream_wait_event=stream_wait_event,
+                             stream_record_event=stream_record_event)

-    def _forward(self, *inputs):
+    def _forward(self, *inputs, stream_wait_event=None, stream_record_event=None):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} forward")
             with torch.cuda.stream(self.stream):
+                if stream_wait_event is not None:
+                    stream_wait_event.wait(self.stream)
                 self.inputs = [make_viewless(e).detach() if e is not None else None for e in inputs]
                 for i, input in enumerate(self.inputs):
                     if input is not None:

@@ -98,6 +101,10 @@ class ScheduleNode:
                 data = tuple([make_viewless(e) if isinstance(e, Tensor) else e for e in data])
                 self.output = data
+
+                if stream_record_event is not None:
+                    stream_record_event.record(self.stream)
+
             torch.cuda.nvtx.range_pop()

         if self.free_inputs:

@@ -111,16 +118,19 @@ class ScheduleNode:
         """get the forward output"""
         return self.output

-    def backward(self, output_grad):
+    def backward(self, output_grad, stream_wait_event=None, stream_record_event=None):
         """schedule node backward"""
         if not isinstance(output_grad, tuple):
             output_grad = (output_grad,)
-        return self._backward(*output_grad)
+        return self._backward(*output_grad, stream_wait_event=stream_wait_event,
+                              stream_record_event=stream_record_event)

-    def _backward(self, *output_grad):
+    def _backward(self, *output_grad, stream_wait_event=None, stream_record_event=None):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} backward")
             with torch.cuda.stream(self.stream):
+                if stream_wait_event is not None:
+                    stream_wait_event.wait(self.stream)
                 outputs = self.output
                 if not isinstance(outputs, tuple):
                     outputs = (outputs,)

@@ -131,6 +141,10 @@ class ScheduleNode:
                     output_grad = self.backward_func(outputs, output_grad)
                 else:
                     output_grad = self.default_backward_func(outputs, output_grad)
+
+                if stream_record_event is not None:
+                    stream_record_event.record(self.stream)
+
             torch.cuda.nvtx.range_pop()

         # output_grad maybe from another stream

@@ -198,17 +212,6 @@ def schedule_chunk_1f1b(
     )


-def schedule_chunk_forward(schedule_plan):
-    """model level fine-grained forward schedule"""
-    f_input = schedule_chunk_1f1b(schedule_plan, None, None)
-    return f_input
-
-
-def schedule_chunk_backward(schedule_plan, grad):
-    """model level fine-grained backward schedule"""
-    tmp = schedule_chunk_1f1b(None, schedule_plan, grad)
-
-
 _COMP_STREAM = None
 _COM_STREAM = None

@@ -221,7 +224,7 @@ def set_streams(comp_stream=None, com_stream=None):
         return
     if comp_stream is None:
-        comp_stream = torch.cuda.Stream(device="cuda")
+        comp_stream = torch.cuda.current_stream()
     if com_stream is None:
         com_stream = torch.cuda.Stream(device="cuda")
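In combined_1f1b.py, ScheduleNode.forward and ScheduleNode.backward gain keyword-only stream_wait_event / stream_record_event arguments that default to None, so existing call sites keep working, and set_streams now reuses the current stream for compute (when no compute stream is passed in) instead of allocating a dedicated one, presumably so the event handshake only has to order the communication side stream against the stream the framework already uses. A small illustrative sketch of that stream selection; make_streams is a hypothetical name, not the repository's function:

```python
import torch

def make_streams(use_current_for_compute=True):
    """Illustrative mirror of what set_streams now does by default: compute
    shares the stream the framework is already using, and only communication
    gets a dedicated side stream."""
    comp_stream = (torch.cuda.current_stream()
                   if use_current_for_compute
                   else torch.cuda.Stream(device="cuda"))
    com_stream = torch.cuda.Stream(device="cuda")
    return comp_stream, com_stream
```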
dcu_megatron/core/pipeline_parallel/schedules.py

@@ -448,6 +448,8 @@ def forward_backward_pipelining_with_interleaving(
         """Helper method to run backward step with model split into chunks
         (run set_virtual_pipeline_model_parallel_rank() before calling
         backward_step())."""
+        nonlocal output_tensor_grads
+
         model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
         parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
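The schedules.py change is a pure Python scoping fix: the nested backward helper evidently rebinds output_tensor_grads, and without a nonlocal declaration that assignment would create a new local name (or raise UnboundLocalError) instead of updating the enclosing function's variable. A generic illustration of the rule, not the Megatron code:

```python
def outer():
    grads = [None]

    def helper():
        # Without the nonlocal declaration, the assignment below would make
        # `grads` a local name and raise UnboundLocalError on the right-hand side.
        nonlocal grads
        grads = grads + [1]   # rebinding now updates outer()'s `grads`

    helper()
    return grads

assert outer() == [None, 1]
```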
dcu_megatron/core/transformer/transformer_config.py

@@ -40,6 +40,8 @@ class ExtraTransformerConfig:
     combined_1f1b_recipe: str = 'ep_a2a'
     """Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""

+    # split_bw: bool = False
+

 @dataclass
 class TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig):
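transformer_config.py only gains a commented-out split_bw placeholder next to the existing combined_1f1b_recipe option; the extra fields are merged into the upstream config through dataclass multiple inheritance, as TransformerConfigPatch(TransformerConfig, ExtraTransformerConfig) shows. A generic sketch of that mixin pattern with stand-in classes (BaseConfig, ExtraConfig, PatchedConfig are illustrative names, not the Megatron ones):

```python
from dataclasses import dataclass

@dataclass
class BaseConfig:              # stand-in for the upstream TransformerConfig
    hidden_size: int = 1024

@dataclass
class ExtraConfig:             # stand-in for ExtraTransformerConfig
    combined_1f1b_recipe: str = 'ep_a2a'   # same default as the diff
    # split_bw: bool = False               # left disabled in this commit

@dataclass
class PatchedConfig(BaseConfig, ExtraConfig):
    # Multiple inheritance merges both sets of fields into one config class.
    pass

cfg = PatchedConfig()
assert cfg.combined_1f1b_recipe == 'ep_a2a' and cfg.hidden_size == 1024
```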
dcu_megatron/training/training.py

@@ -195,7 +195,8 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                     active=args.profile_step_end - args.profile_step_start,
                     repeat=1),
                 on_trace_ready=trace_handler,
-                record_shapes=True,
+                record_shapes=True,
+                with_stack=True,
             )
             prof.start()
         elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
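training.py extends the torch.profiler configuration with with_stack=True, so the captured trace includes Python stack frames for each recorded op. A hedged sketch of the resulting profiler setup using torch.profiler's public API; the step bounds, iteration count, and trace handler here are illustrative, not the repository's values:

```python
import torch
from torch.profiler import profile, schedule, ProfilerActivity

def trace_handler(prof):
    # Illustrative handler: dump a Chrome trace when a profiling cycle completes.
    prof.export_chrome_trace("trace.json")

prof = profile(
    activities=[ProfilerActivity.CPU]
    + ([ProfilerActivity.CUDA] if torch.cuda.is_available() else []),
    schedule=schedule(wait=0, warmup=1, active=4, repeat=1),
    on_trace_ready=trace_handler,
    record_shapes=True,
    with_stack=True,   # the flag added by this commit: record Python stacks per op
)
prof.start()
for _ in range(5):
    # ... one training step would run here ...
    prof.step()        # advance the profiler schedule once per iteration
prof.stop()
```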