Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
2385a133
Commit
2385a133
authored
Jun 07, 2025
by
dongcl
Browse files
modify dualpipev, add foward_step_helper and backward_step_helper
parent
cb1230db
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
125 additions
and
112 deletions
+125
-112
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py
...n/core/pipeline_parallel/dualpipev/dualpipev_schedules.py
+125
-112
No files found.
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py
View file @
2385a133
...
@@ -354,9 +354,9 @@ def generate_dualpipev_schedule(pp_size, num_microbatches):
...
@@ -354,9 +354,9 @@ def generate_dualpipev_schedule(pp_size, num_microbatches):
num_1b1overlap_stages
[
i
]
=
(
pp_size
//
2
-
i
-
1
)
*
2
num_1b1overlap_stages
[
i
]
=
(
pp_size
//
2
-
i
-
1
)
*
2
num_interleaved_backward_stages
[
i
]
=
i
+
1
num_interleaved_backward_stages
[
i
]
=
(
i
+
1
)
*
2
num_cooldown_stages
[
i
]
=
[
i
+
1
,
pp_size
-
2
*
i
-
2
,
i
+
1
]
num_cooldown_stages
[
i
]
=
[
pp_size
//
2
-
i
-
1
,
pp_size
-
2
*
i
-
2
,
i
+
1
]
schedule_all_stages
=
{
schedule_all_stages
=
{
'warmup'
:
num_warmup_stages
,
'warmup'
:
num_warmup_stages
,
...
@@ -537,10 +537,33 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -537,10 +537,33 @@ def forward_backward_pipelining_with_cutinhalf(
# Disable async grad reductions
# Disable async grad reductions
no_sync_func
=
config
.
no_sync_func
no_sync_func
=
config
.
no_sync_func
if
isinstance
(
no_sync_func
,
list
):
def
multi_no_sync
():
stack
=
contextlib
.
ExitStack
()
for
model_chunk_no_sync_func
in
config
.
no_sync_func
:
stack
.
enter_context
(
model_chunk_no_sync_func
())
return
stack
no_sync_func
=
multi_no_sync
if
no_sync_func
is
None
:
if
no_sync_func
is
None
:
no_sync_func
=
contextlib
.
nullcontext
no_sync_func
=
contextlib
.
nullcontext
no_sync_context
=
None
no_sync_context
=
None
if
config
.
grad_sync_func
is
not
None
and
not
isinstance
(
config
.
grad_sync_func
,
list
):
config
.
grad_sync_func
=
[
config
.
grad_sync_func
for
_
in
model
]
if
config
.
param_sync_func
is
not
None
and
not
isinstance
(
config
.
param_sync_func
,
list
):
config
.
param_sync_func
=
[
config
.
param_sync_func
for
_
in
model
]
# Disable config.grad_sync_func and config.param_sync_func if only running forward passes.
# They will be re-enabled at the end of this function.
grad_sync_func
,
param_sync_func
=
None
,
None
if
forward_only
:
grad_sync_func
,
param_sync_func
=
config
.
grad_sync_func
,
config
.
param_sync_func
config
.
grad_sync_func
,
config
.
param_sync_func
=
None
,
None
def
disable_grad_sync
():
def
disable_grad_sync
():
"""Disable asynchronous grad reductions"""
"""Disable asynchronous grad reductions"""
nonlocal
no_sync_context
nonlocal
no_sync_context
...
@@ -565,7 +588,7 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -565,7 +588,7 @@ def forward_backward_pipelining_with_cutinhalf(
):
):
"""Helper method to run combined forward and backward step"""
"""Helper method to run combined forward and backward step"""
# forward prepare
# forward prepare
fwd_microbatch_id
=
master_cur
_microbatch
if
fwd_model_chunk_id
==
master_chunk_id
else
slave_cur_microbatch
fwd_microbatch_id
=
cur_fwd_chunk
_microbatch
[
fwd_model_chunk_id
]
f_context
=
contextlib
.
nullcontext
()
f_context
=
contextlib
.
nullcontext
()
set_dualpipe_chunk
(
fwd_model_chunk_id
)
set_dualpipe_chunk
(
fwd_model_chunk_id
)
...
@@ -653,11 +676,9 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -653,11 +676,9 @@ def forward_backward_pipelining_with_cutinhalf(
master_chunk_id
=
0
master_chunk_id
=
0
slave_chunk_id
=
1
slave_chunk_id
=
1
cur_fwd_chunk_microbatch
=
[
0
,
num_microbatches
]
master_cur_microbatch
=
0
cur_bwd_chunk_microbatch
=
[
0
,
num_microbatches
]
slave_cur_microbatch
=
num_microbatches
num_chunk_max_microbatch
=
[
num_microbatches
,
num_microbatches
*
2
]
master_microbatch_max
=
num_microbatches
slave_microbatch_max
=
num_microbatches
*
2
checkpoint_activations_microbatch
=
None
checkpoint_activations_microbatch
=
None
fwd_wait_handles_warmup
=
None
fwd_wait_handles_warmup
=
None
...
@@ -688,33 +709,31 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -688,33 +709,31 @@ def forward_backward_pipelining_with_cutinhalf(
return
output_tensor
return
output_tensor
def
backward_step_helper
(
input_tensor
,
output_tensor
,
output_tensor_grad
,
is_last_microbatch
=
False
):
def
backward_step_helper
(
input_tensor
,
output_tensor
,
output_tensor_grad
,
bwd_model_chunk_id
=
None
,
bwd_cur_microbatch
=
None
):
# # launch grad synchronization (default)
nonlocal
master_chunk_id
# if config.grad_sync_func is None and is_last_microbatch:
nonlocal
slave_chunk_id
# enable_grad_sync()
nonlocal
num_chunk_max_microbatch
# Enable async grad reduction in the last backward pass
# Note: If grad sync function is provided, only enable
# async grad reduction in first pipeline stage. Other
# pipeline stages do grad reduction during pipeline
# bubble.
if
(
bwd_cur_microbatch
is
not
None
and
bwd_cur_microbatch
==
num_chunk_max_microbatch
[
bwd_model_chunk_id
]
-
1
):
if
(
config
.
grad_sync_func
is
None
or
(
bwd_model_chunk_id
==
slave_chunk_id
and
parallel_state
.
is_pipeline_last_stage
())
or
(
bwd_model_chunk_id
==
master_chunk_id
and
parallel_state
.
is_pipeline_first_stage
())
):
enable_grad_sync
()
input_tensor_grad
=
backward_step
(
input_tensor_grad
=
backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
,
model_type
,
config
input_tensor
,
output_tensor
,
output_tensor_grad
,
model_type
,
config
)
)
# # launch grad synchronization (custom grad sync)
# # Note: Asynchronous communication tends to slow down compute.
# # To reduce idling from mismatched microbatch times, we launch
# # asynchronous communication at the same time across the
# # pipeline-parallel group.
# if config.grad_sync_func is not None:
# grad_sync_virtual_microbatch_id = virtual_microbatch_id - pipeline_parallel_rank
# if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
# grad_sync_virtual_microbatch_id
# ):
# grad_sync_chunk_id = get_model_chunk_id(
# grad_sync_virtual_microbatch_id, forward=False
# )
# enable_grad_sync()
# config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
# synchronized_model_chunks.add(grad_sync_chunk_id)
# disable_grad_sync()
return
input_tensor_grad
return
input_tensor_grad
# Run warmup forward passes
# Run warmup forward passes
...
@@ -724,11 +743,10 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -724,11 +743,10 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_warmup
=
forward_step_helper
(
output_tensor_warmup
=
forward_step_helper
(
input_tensor
,
input_tensor
,
master_chunk_id
,
master_chunk_id
,
master_cur_microbatch
,
cur_fwd_chunk_microbatch
[
master_chunk_id
]
,
is_first_microbatch
=
is_first_microbatch
is_first_microbatch
=
is_first_microbatch
)
)
cur_fwd_chunk_microbatch
[
master_chunk_id
]
+=
1
master_cur_microbatch
+=
1
if
i
!=
schedule
[
'warmup'
][
rank
]
-
1
:
if
i
!=
schedule
[
'warmup'
][
rank
]
-
1
:
input_tensor
,
_
=
send_forward_recv_forward
(
input_tensor
,
_
=
send_forward_recv_forward
(
...
@@ -758,10 +776,10 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -758,10 +776,10 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor
=
forward_step_helper
(
output_tensor
=
forward_step_helper
(
input_tensor
,
input_tensor
,
master_chunk_id
,
master_chunk_id
,
master_cur_microbatch
,
cur_fwd_chunk_microbatch
[
master_chunk_id
]
,
is_first_microbatch
=
is_first_microbatch
is_first_microbatch
=
is_first_microbatch
)
)
master_cur_microbatch
+=
1
cur_fwd_chunk_microbatch
[
master_chunk_id
]
+=
1
if
not
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
)
and
fwd_wait_handles_send
is
not
None
:
if
not
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
)
and
fwd_wait_handles_send
is
not
None
:
for
req
,
req_handle
in
fwd_wait_handles_send
.
items
():
for
req
,
req_handle
in
fwd_wait_handles_send
.
items
():
...
@@ -810,10 +828,10 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -810,10 +828,10 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_slave_chunk
=
forward_step_helper
(
output_tensor_slave_chunk
=
forward_step_helper
(
input_tensor_slave
,
input_tensor_slave
,
slave_chunk_id
,
slave_chunk_id
,
slave_cur_microbatch
,
cur_fwd_chunk_microbatch
[
slave_chunk_id
]
,
is_first_microbatch
=
is_first_microbatch
is_first_microbatch
=
is_first_microbatch
)
)
slave_cur_microbatch
+=
1
cur_fwd_chunk_microbatch
[
slave_chunk_id
]
+=
1
if
not
forward_only
:
if
not
forward_only
:
if
i
==
schedule
[
'interleaved_forward'
][
rank
]
-
1
:
if
i
==
schedule
[
'interleaved_forward'
][
rank
]
-
1
:
...
@@ -849,10 +867,8 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -849,10 +867,8 @@ def forward_backward_pipelining_with_cutinhalf(
if
not
forward_only
:
if
not
forward_only
:
input_tensor_bwd
=
input_tensors
[
slave_chunk_id
].
pop
(
0
)[
1
]
input_tensor_bwd
=
input_tensors
[
slave_chunk_id
].
pop
(
0
)[
1
]
output_tensor_bwd
=
output_tensors
[
slave_chunk_id
].
pop
(
0
)
output_tensor_bwd
=
output_tensors
[
slave_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step_helper
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
)
input_tensor_grad
=
backward_step
(
cur_bwd_chunk_microbatch
[
slave_chunk_id
]
+=
1
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
model_type
,
config
)
if
fwd_wait_handles_slave_chunk
is
not
None
:
if
fwd_wait_handles_slave_chunk
is
not
None
:
for
req
in
fwd_wait_handles_slave_chunk
:
for
req
in
fwd_wait_handles_slave_chunk
:
...
@@ -890,10 +906,10 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -890,10 +906,10 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor
=
forward_step_helper
(
output_tensor
=
forward_step_helper
(
input_tensor_slave
,
input_tensor_slave
,
slave_chunk_id
,
slave_chunk_id
,
slave_cur_microbatch
,
cur_fwd_chunk_microbatch
[
slave_chunk_id
]
,
is_first_microbatch
=
False
is_first_microbatch
=
False
)
)
slave_cur_microbatch
+=
1
cur_fwd_chunk_microbatch
[
slave_chunk_id
]
+=
1
if
not
forward_only
:
if
not
forward_only
:
output_tensor_grad_bwd
,
_
=
recv_backward
(
output_tensor_grad_bwd
,
_
=
recv_backward
(
...
@@ -913,9 +929,7 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -913,9 +929,7 @@ def forward_backward_pipelining_with_cutinhalf(
num_overlap_steps
+=
schedule
[
'interleaved_backward'
][
rank
]
num_overlap_steps
+=
schedule
[
'interleaved_backward'
][
rank
]
for
step_id
in
range
(
num_overlap_steps
):
for
step_id
in
range
(
num_overlap_steps
):
only_bwd
=
False
only_bwd
=
False
if
fwd_model_chunk_id
==
master_chunk_id
and
master_cur_microbatch
==
master_microbatch_max
:
if
cur_fwd_chunk_microbatch
[
fwd_model_chunk_id
]
==
num_chunk_max_microbatch
[
fwd_model_chunk_id
]:
only_bwd
=
True
if
fwd_model_chunk_id
==
slave_chunk_id
and
slave_cur_microbatch
==
slave_microbatch_max
:
only_bwd
=
True
only_bwd
=
True
if
not
only_bwd
:
if
not
only_bwd
:
...
@@ -928,18 +942,16 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -928,18 +942,16 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_wait_handles_recv
=
None
fwd_wait_handles_recv
=
None
def
pp_post_forward
(
output_tensor
):
def
pp_post_forward
(
output_tensor
):
nonlocal
master_cur
_microbatch
nonlocal
cur_fwd_chunk
_microbatch
nonlocal
slave_cur
_microbatch
nonlocal
num_chunk_max
_microbatch
nonlocal
fwd_wait_handles
nonlocal
fwd_wait_handles
nonlocal
fwd_wait_handles_slave_chunk
nonlocal
fwd_wait_handles_slave_chunk
nonlocal
firstFB_no_overlp_handle
nonlocal
firstFB_no_overlp_handle
if
fwd_model_chunk_id
==
master_chunk_id
:
if
fwd_model_chunk_id
==
master_chunk_id
:
master_cur_microbatch
+=
1
fwd_send_only
=
False
fwd_send_only
=
False
else
:
else
:
slave_cur_microbatch
+=
1
fwd_send_only
=
(
cur_fwd_chunk_microbatch
[
master_chunk_id
]
==
num_chunk_max_microbatch
[
master_chunk_id
])
fwd_send_only
=
(
master_cur_microbatch
==
master_microbatch_max
)
# 同步上个阶段最后一个slave前向send
# 同步上个阶段最后一个slave前向send
if
fwd_wait_handles_slave_chunk
is
not
None
:
if
fwd_wait_handles_slave_chunk
is
not
None
:
...
@@ -1016,13 +1028,13 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -1016,13 +1028,13 @@ def forward_backward_pipelining_with_cutinhalf(
# forward
# forward
pp_pre_forward
()
pp_pre_forward
()
fwd_microbatch
=
master_cur_microbatch
if
fwd_model_chunk_id
==
master_chunk_id
else
slave_cur_microbatch
output_tensor
=
forward_step_helper
(
output_tensor
=
forward_step_helper
(
input_tensor
,
input_tensor
,
fwd_model_chunk_id
,
fwd_model_chunk_id
,
fwd_microbatch
,
cur_fwd_chunk_microbatch
[
fwd_model_chunk_id
]
,
is_first_microbatch
=
False
is_first_microbatch
=
False
)
)
cur_fwd_chunk_microbatch
[
fwd_model_chunk_id
]
+=
1
input_tensor
=
pp_post_forward
(
output_tensor
)
input_tensor
=
pp_post_forward
(
output_tensor
)
# backward
# backward
...
@@ -1031,13 +1043,14 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -1031,13 +1043,14 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensor_bwd
=
input_tensors
[
bwd_model_chunk_id
].
pop
(
0
)[
1
]
input_tensor_bwd
=
input_tensors
[
bwd_model_chunk_id
].
pop
(
0
)[
1
]
output_tensor_bwd
=
output_tensors
[
bwd_model_chunk_id
].
pop
(
0
)
output_tensor_bwd
=
output_tensors
[
bwd_model_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step_helper
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
)
input_tensor_grad
=
backward_step_helper
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
)
cur_bwd_chunk_microbatch
[
bwd_model_chunk_id
]
+=
1
else
:
else
:
input_tensor_grad
=
None
input_tensor_grad
=
None
output_tensor_grad_bwd
=
pp_post_backward
(
input_tensor_grad
)
output_tensor_grad_bwd
=
pp_post_backward
(
input_tensor_grad
)
# only run backward
# only run backward
else
:
else
:
if
bwd_model_chunk_id
==
slave_chunk_id
and
slave_cur
_microbatch
<
slave_
microbatch_max
:
if
bwd_model_chunk_id
==
slave_chunk_id
and
cur_fwd_chunk
_microbatch
[
slave_
chunk_id
]
<
num_chunk_max_microbatch
[
slave_chunk_id
]
:
input_tensor
,
fwd_wait_handles_recv
=
recv_forward
(
input_tensor
,
fwd_wait_handles_recv
=
recv_forward
(
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
if
not
forward_only
:
if
not
forward_only
:
...
@@ -1049,75 +1062,71 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -1049,75 +1062,71 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensor_bwd
=
input_tensors
[
bwd_model_chunk_id
].
pop
(
0
)[
1
]
input_tensor_bwd
=
input_tensors
[
bwd_model_chunk_id
].
pop
(
0
)[
1
]
output_tensor_bwd
=
output_tensors
[
bwd_model_chunk_id
].
pop
(
0
)
output_tensor_bwd
=
output_tensors
[
bwd_model_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step_helper
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
)
input_tensor_grad
=
backward_step_helper
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
bwd_model_chunk_id
=
bwd_model_chunk_id
,
bwd_cur_microbatch
=
cur_bwd_chunk_microbatch
[
bwd_model_chunk_id
]
)
cur_bwd_chunk_microbatch
[
bwd_model_chunk_id
]
+=
1
if
parallel_state
.
is_pipeline_last_stage
()
and
fwd_model_chunk_id
==
master_chunk_id
:
if
parallel_state
.
is_pipeline_last_stage
()
and
fwd_model_chunk_id
==
master_chunk_id
:
output_tensor_grad_bwd
=
input_tensor_grad
output_tensor_grad_bwd
=
input_tensor_grad
else
:
else
:
# send_backward_recv_slave_backward
if
step_id
==
num_overlap_steps
-
1
:
output_tensor_grad_bwd
,
bwd_wait_handles
=
send_backward_recv_slave_backward
(
input_tensor_grad
,
bwd_wait_handles
=
send_backward
(
tensor_shape
,
config
,
fwd_model_chunk_id
)
input_tensor_grad
,
tensor_shape
,
config
,
bwd_model_chunk_id
,
)
else
:
# send_backward_recv_slave_backward
output_tensor_grad_bwd
,
bwd_wait_handles
=
send_backward_recv_slave_backward
(
input_tensor_grad
,
tensor_shape
,
config
,
fwd_model_chunk_id
)
# swap fwd & bwd chunks
# swap fwd & bwd chunks
fwd_model_chunk_id
,
bwd_model_chunk_id
=
bwd_model_chunk_id
,
fwd_model_chunk_id
fwd_model_chunk_id
,
bwd_model_chunk_id
=
bwd_model_chunk_id
,
fwd_model_chunk_id
if
not
forward_only
:
# Launch any remaining grad reductions.
# Run cooldown phases
if
config
.
grad_sync_func
is
not
None
:
merged_input_tensors
=
[]
enable_grad_sync
()
merged_output_tensors
=
[]
config
.
grad_sync_func
(
model
[
slave_chunk_id
].
parameters
())
while
len
(
input_tensors
[
0
])
>
0
or
len
(
input_tensors
[
1
])
>
0
:
disable_grad_sync
()
if
len
(
input_tensors
[
bwd_model_chunk_id
])
>
0
:
merged_input_tensors
.
append
(
input_tensors
[
bwd_model_chunk_id
].
pop
(
0
))
merged_output_tensors
.
append
(
(
output_tensors
[
bwd_model_chunk_id
].
pop
(
0
),
bwd_model_chunk_id
))
if
len
(
input_tensors
[
1
-
bwd_model_chunk_id
])
>
0
:
merged_input_tensors
.
append
(
input_tensors
[
1
-
bwd_model_chunk_id
].
pop
(
0
))
merged_output_tensors
.
append
(
(
output_tensors
[
1
-
bwd_model_chunk_id
].
pop
(
0
),
1
-
bwd_model_chunk_id
))
bwd_wait_handles_recv
=
None
for
i
in
range
(
pp_size
):
if
bwd_wait_handles
is
not
None
:
for
req
,
req_handle
in
bwd_wait_handles
.
items
():
if
req_handle
is
not
None
:
req_handle
.
wait
()
bwd_wait_handles
=
None
if
bwd_wait_handles_recv
is
not
None
:
for
req
,
req_handle
in
bwd_wait_handles_recv
.
items
():
if
req_handle
is
not
None
:
req_handle
.
wait
()
bwd_wait_handles_recv
=
None
input_tensor_bwd
=
merged_input_tensors
.
pop
(
0
)[
1
]
output_tensor_bwd
,
bwd_model_chunk_id
=
merged_output_tensors
.
pop
(
0
)
input_tensor_grad
=
backward_step_helper
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
)
# Run cooldown phases
if
not
forward_only
:
for
i
in
range
(
schedule
[
'cooldown'
][
rank
][
0
]):
output_tensor_grad_bwd
,
_
=
recv_backward
(
tensor_shape
,
config
,
master_chunk_id
)
input_tensor_bwd
=
input_tensors
[
master_chunk_id
].
pop
(
0
)[
1
]
output_tensor_bwd
=
output_tensors
[
master_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step_helper
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
bwd_model_chunk_id
=
master_chunk_id
,
bwd_cur_microbatch
=
cur_bwd_chunk_microbatch
[
master_chunk_id
]
)
cur_bwd_chunk_microbatch
[
master_chunk_id
]
+=
1
if
i
==
pp_size
-
1
:
_
=
send_backward
(
bwd_wait_handles
=
send_backward
(
input_tensor_grad
,
input_tensor_grad
,
tensor_shape
,
config
,
bwd_model_chunk_id
,
async_op
=
True
)
tensor_shape
,
elif
i
>=
schedule
[
'cooldown'
][
rank
][
0
]
-
1
:
config
,
bwd_wait_handles
=
send_backward
(
input_tensor_grad
,
master_chunk_id
,
tensor_shape
,
config
,
bwd_model_chunk_id
,
async_op
=
True
)
)
output_tensor_grad_bwd
,
bwd_wait_handles_recv
=
recv_backward
(
tensor_shape
,
config
,
bwd_model_chunk_id
,
async_op
=
True
)
else
:
if
parallel_state
.
is_pipeline_last_stage
()
and
(
1
-
bwd_model_chunk_id
)
==
master_chunk_id
:
output_tensor_grad_bwd
=
input_tensor_grad
else
:
# send_backward_recv_slave_backward
output_tensor_grad_bwd
,
bwd_wait_handles
=
send_backward_recv_slave_backward
(
input_tensor_grad
,
tensor_shape
,
config
,
1
-
bwd_model_chunk_id
)
if
bwd_wait_handles
is
not
None
:
# Launch any remaining grad reductions.
for
req
,
req_handle
in
bwd_wait_handles
.
items
():
if
config
.
grad_sync_func
is
not
None
:
if
req_handle
is
not
None
:
enable_grad_sync
()
req_handle
.
wait
()
config
.
grad_sync_func
(
model
[
master_chunk_id
].
parameters
())
bwd_wait_handles
=
None
if
config
.
finalize_model_grads_func
is
not
None
and
not
forward_only
:
if
config
.
finalize_model_grads_func
is
not
None
and
not
forward_only
:
...
@@ -1132,4 +1141,8 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -1132,4 +1141,8 @@ def forward_backward_pipelining_with_cutinhalf(
model
,
total_num_tokens
if
config
.
calculate_per_token_loss
else
None
model
,
total_num_tokens
if
config
.
calculate_per_token_loss
else
None
)
)
# Restore config.grad_sync_func and config.param_sync_func.
if
forward_only
:
config
.
grad_sync_func
,
config
.
param_sync_func
=
grad_sync_func
,
param_sync_func
return
forward_data_store
return
forward_data_store
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment