Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
08efd4ec
Commit
08efd4ec
authored
May 17, 2025
by
dongcl
Browse files
rewrite schedules based on megatron a73b4d2d4a993e9bea97fdebb841a393eb4ad5e7
parent
c964fcca
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
23 deletions
+51
-23
dcu_megatron/core/pipeline_parallel/schedules.py
dcu_megatron/core/pipeline_parallel/schedules.py
+51
-23
No files found.
dcu_megatron/core/pipeline_parallel/schedules.py
View file @
08efd4ec
...
...
@@ -13,6 +13,8 @@ from megatron.core.utils import (
get_model_config
,
get_model_type
,
get_model_xattn
,
nvtx_range_pop
,
nvtx_range_push
,
)
from
megatron.core.pipeline_parallel.schedules
import
(
forward_step
,
...
...
@@ -430,7 +432,7 @@ def forward_backward_pipelining_with_interleaving(
)
# forward step
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
,
vp_stage
=
model_chunk_id
):
if
len
(
input_tensors
[
model_chunk_id
])
==
len
(
output_tensors
[
model_chunk_id
]):
input_tensors
[
model_chunk_id
].
append
(
None
)
...
...
@@ -458,6 +460,7 @@ def forward_backward_pipelining_with_interleaving(
is_first_microbatch_for_model_chunk
(
virtual_microbatch_id
),
),
current_microbatch
=
microbatch_id
,
vp_stage
=
model_chunk_id
,
)
output_tensors
[
model_chunk_id
].
append
(
output_tensor
)
...
...
@@ -477,7 +480,6 @@ def forward_backward_pipelining_with_interleaving(
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
nonlocal
output_tensor_grads
# TODO(dongcl)
model_chunk_id
=
get_model_chunk_id
(
virtual_microbatch_id
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
model_chunk_id
)
...
...
@@ -489,7 +491,7 @@ def forward_backward_pipelining_with_interleaving(
synchronized_model_chunks
.
add
(
model_chunk_id
)
# pylint: disable=E0606
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
,
vp_stage
=
model_chunk_id
):
if
len
(
output_tensor_grads
[
model_chunk_id
])
==
0
:
output_tensor_grads
[
model_chunk_id
].
append
(
None
)
input_tensor
=
input_tensors
[
model_chunk_id
].
pop
(
0
)
...
...
@@ -728,9 +730,14 @@ def forward_backward_pipelining_with_interleaving(
input_tensor_grad
=
post_backward
(
input_tensor_grad
)
return
output_tensor
,
input_tensor_grad
is_vp_first_stage
=
partial
(
parallel_state
.
is_pipeline_first_stage
,
ignore_virtual
=
False
)
is_vp_last_stage
=
partial
(
parallel_state
.
is_pipeline_last_stage
,
ignore_virtual
=
False
)
# Run warmup forward passes.
nvtx_range_push
(
suffix
=
"warmup"
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
0
)
input_tensors
[
0
].
append
(
p2p_communication
.
recv_forward
(
tensor_shape
,
config
))
input_tensors
[
0
].
append
(
p2p_communication
.
recv_forward
(
tensor_shape
,
config
,
is_vp_first_stage
())
)
fwd_wait_handles
=
None
fwd_wait_recv_handles
=
None
...
...
@@ -760,7 +767,7 @@ def forward_backward_pipelining_with_interleaving(
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
cur_model_chunk_id
)
if
config
.
overlap_p2p_comm_warmup_flush
:
if
not
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
)
and
k
!=
0
:
if
not
is_vp_first_stage
(
vp_stage
=
cur_model_chunk_id
)
and
k
!=
0
:
assert
recv_prev_wait_handles
,
(
f
'pp rank
{
pipeline_parallel_rank
}
, iteration
{
k
}
,'
'should have registered recv handle'
...
...
@@ -807,7 +814,7 @@ def forward_backward_pipelining_with_interleaving(
)
# Don't send tensor downstream if on last stage.
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
if
is_vp_last_stage
(
vp_stage
=
cur_model_chunk_id
):
output_tensor
=
None
# Send and receive tensors as appropriate (send tensors computed
...
...
@@ -910,8 +917,10 @@ def forward_backward_pipelining_with_interleaving(
if
recv_next
:
output_tensor_grads
[
num_model_chunks
-
1
].
append
(
bwd_recv_buffer
[
-
1
])
nvtx_range_pop
(
suffix
=
"warmup"
)
# Run 1F1B in steady state.
nvtx_range_push
(
suffix
=
"steady"
)
for
k
in
range
(
num_microbatches_remaining
):
# Forward pass.
forward_k
=
k
+
num_warmup_microbatches
...
...
@@ -928,14 +937,14 @@ def forward_backward_pipelining_with_interleaving(
else
:
checkpoint_activations_microbatch
=
None
cur_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
cur_model_chunk_id
)
if
config
.
overlap_p2p_comm
:
backward_k
=
k
# output send / receive sync
def
pp_pre_forward
():
if
not
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
nonlocal
recv_prev_wait_handles
cur_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
cur_model_chunk_id
)
if
not
is_vp_first_stage
(
vp_stage
=
cur_model_chunk_id
):
if
config
.
overlap_p2p_comm_warmup_flush
:
assert
recv_prev_wait_handles
,
(
f
'pp rank
{
pipeline_parallel_rank
}
, fwd iteration
{
forward_k
}
, '
...
...
@@ -956,8 +965,14 @@ def forward_backward_pipelining_with_interleaving(
nonlocal
fwd_recv_buffer
nonlocal
fwd_wait_handles
nonlocal
recv_prev_wait_handles
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
forward_model_chunk_id
)
# Last virtual stage no activation tensor to send.
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
if
is_vp_last_stage
(
vp_stage
=
forward_model_chunk_id
):
output_tensor
=
None
recv_prev
,
next_forward_model_chunk_id
=
recv_tensor_from_previous_stage
(
...
...
@@ -1002,10 +1017,14 @@ def forward_backward_pipelining_with_interleaving(
return
output_tensor
backward_k
=
k
# grad send receive sync
def
pp_pre_backward
():
nonlocal
recv_next_wait_handles
if
not
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
backward_model_chunk_id
)
if
not
is_vp_last_stage
(
vp_stage
=
backward_model_chunk_id
):
if
config
.
overlap_p2p_comm_warmup_flush
:
assert
recv_next_wait_handles
,
(
f
'pp rank
{
pipeline_parallel_rank
}
, bwd iteration
{
backward_k
}
, '
...
...
@@ -1023,8 +1042,13 @@ def forward_backward_pipelining_with_interleaving(
nonlocal
send_prev_wait_handle
nonlocal
bwd_wait_handles
nonlocal
recv_next_wait_handles
nonlocal
bwd_recv_buffer
backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
backward_model_chunk_id
)
# First virtual stage no activation gradient tensor to send.
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
if
is_vp_first_stage
(
vp_stage
=
backward_model_chunk_id
):
input_tensor_grad
=
None
recv_next
,
next_backward_model_chunk_id
=
recv_tensor_from_previous_stage
(
...
...
@@ -1044,16 +1068,13 @@ def forward_backward_pipelining_with_interleaving(
send_prev_wait_handle
.
wait
()
if
bwd_wait_handles
is
not
None
:
send_prev_wait_handle
=
(
bwd_wait_handles
.
pop
(
"send_prev"
)
if
"send_prev"
in
bwd_wait_handles
else
None
bwd_wait_handles
.
pop
(
"send_prev"
)
if
"send_prev"
in
bwd_wait_handles
else
None
)
if
"recv_next"
in
bwd_wait_handles
:
recv_next_wait_handles
.
append
(
bwd_wait_handles
.
pop
(
"recv_next"
))
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if
recv_next
:
output_tensor_grads
[
next_backward_model_chunk_id
].
append
(
bwd_recv_buffer
[
backward_k
%
bwd_recv_buffer_size
]
...
...
@@ -1088,12 +1109,12 @@ def forward_backward_pipelining_with_interleaving(
# otherwise set tensor to None.
forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
forward_model_chunk_id
)
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
if
is_vp_last_stage
(
vp_stage
=
forward_model_chunk_id
):
output_tensor
=
None
backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
backward_model_chunk_id
)
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
if
is_vp_first_stage
(
vp_stage
=
backward_model_chunk_id
):
input_tensor_grad
=
None
recv_prev
,
next_forward_model_chunk_id
=
recv_tensor_from_previous_stage
(
...
...
@@ -1135,8 +1156,10 @@ def forward_backward_pipelining_with_interleaving(
print_rank_0
(
f
"rank first. 1F1B in steady state end"
)
print_rank_4
(
f
"rank last. 1F1B in steady state end"
)
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
nvtx_range_pop
(
suffix
=
"steady"
)
# Run cooldown backward passes (flush out pipeline).
nvtx_range_push
(
suffix
=
"cooldown"
)
if
not
forward_only
:
if
bwd_wait_handles
is
not
None
:
for
bwd_wait_handle
in
bwd_wait_handles
.
values
():
...
...
@@ -1144,12 +1167,14 @@ def forward_backward_pipelining_with_interleaving(
if
are_all_microbatches_in_warmup
:
output_tensor_grads
[
num_model_chunks
-
1
].
append
(
p2p_communication
.
recv_backward
(
tensor_shape
,
config
=
config
)
p2p_communication
.
recv_backward
(
tensor_shape
,
config
=
config
,
is_last_stage
=
is_vp_last_stage
()
)
)
for
k
in
range
(
num_microbatches_remaining
,
total_num_microbatches
):
cur_model_chunk_id
=
get_model_chunk_id
(
k
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
cur_model_chunk_id
)
if
not
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
)
and
k
!=
0
:
if
not
is_vp_last_stage
(
vp_stage
=
cur_model_chunk_id
)
and
k
!=
0
:
if
config
.
overlap_p2p_comm_warmup_flush
:
assert
recv_next_wait_handles
,
(
f
'pp rank
{
pipeline_parallel_rank
}
, backward iteration
{
k
}
, '
...
...
@@ -1189,7 +1214,7 @@ def forward_backward_pipelining_with_interleaving(
_
,
input_tensor_grad
=
forward_backward_helper_wrapper
(
b_virtual_microbatch_id
=
k
)
# First virtual stage no activation gradient tensor to send.
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
if
is_vp_first_stage
(
vp_stage
=
cur_model_chunk_id
):
input_tensor_grad
=
None
if
config
.
overlap_p2p_comm_warmup_flush
:
...
...
@@ -1246,7 +1271,9 @@ def forward_backward_pipelining_with_interleaving(
if
model_chunk_id
not
in
synchronized_model_chunks
:
config
.
grad_sync_func
[
model_chunk_id
](
model
[
model_chunk_id
].
parameters
())
synchronized_model_chunks
.
add
(
model_chunk_id
)
nvtx_range_pop
(
suffix
=
"cooldown"
)
nvtx_range_push
(
suffix
=
"misc"
)
assert
(
not
recv_prev_wait_handles
),
'recv_prev_wait_handles should be cleared at the end of a step'
...
...
@@ -1276,6 +1303,7 @@ def forward_backward_pipelining_with_interleaving(
if
hasattr
(
config
,
'enable_cuda_graph'
)
and
config
.
enable_cuda_graph
:
create_cudagraphs
()
nvtx_range_pop
(
suffix
=
"misc"
)
return
forward_data_store
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment