cb1230db
Commit
cb1230db
authored
Jun 05, 2025
by
dongcl
Browse files
rewrite dualpipev_schedules; deduplicate the code
parent
a58a2da6
Showing 1 changed file with 259 additions and 203 deletions
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py
@@ -557,6 +557,82 @@ def forward_backward_pipelining_with_cutinhalf(
             disable_grad_sync()

+    def combined_forward_backward_helper(
+        fwd_model_chunk_id, bwd_model_chunk_id, fwd_input_tensor=None, bwd_output_tensor_grad=None
+    ):
+        """Helper method to run combined forward and backward step"""
+        # forward prepare
+        fwd_microbatch_id = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
+        f_context = contextlib.nullcontext()
+        set_dualpipe_chunk(fwd_model_chunk_id)
+
+        # backward prepare
+        b_context = contextlib.nullcontext()
+        bwd_input_tensor = input_tensors[bwd_model_chunk_id].pop(0)[1]
+        bwd_output_tensor = output_tensors[bwd_model_chunk_id].pop(0)
+
+        output_tensor, num_tokens, input_tensor_grad = forward_backward_step(
+            forward_step_func,
+            data_iterator[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
+            model[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
+            num_microbatches,
+            fwd_input_tensor,
+            forward_data_store,
+            model[bwd_model_chunk_id] if bwd_model_chunk_id is not None else None,
+            bwd_input_tensor,
+            bwd_output_tensor,
+            bwd_output_tensor_grad,
+            config,
+            f_context=f_context,
+            b_context=b_context,
+            collect_non_loss_data=collect_non_loss_data,
+            checkpoint_activations_microbatch=None,
+            is_first_microbatch=False,
+            current_microbatch=fwd_microbatch_id,
+        )
+
+        # forward post process
+        if fwd_model_chunk_id is not None:
+            with f_context:
+                nonlocal total_num_tokens
+                total_num_tokens += num_tokens.item()
+                if not forward_only:
+                    input_tensors[fwd_model_chunk_id].append((fwd_microbatch_id, fwd_input_tensor))
+                    output_tensors[fwd_model_chunk_id].append(output_tensor)
+
+        # backward post process
+        if bwd_model_chunk_id is not None:
+            with b_context:
+                # launch grad synchronization (custom grad sync)
+                # Note: Asynchronous communication tends to slow down compute.
+                # To reduce idling from mismatched microbatch times, we launch
+                # asynchronous communication at the same time across the
+                # pipeline-parallel group.
+                if config.grad_sync_func is not None:
+                    grad_sync_virtual_microbatch_id = b_virtual_microbatch_id - pipeline_parallel_rank
+                    if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
+                        grad_sync_virtual_microbatch_id
+                    ):
+                        grad_sync_chunk_id = get_model_chunk_id(
+                            grad_sync_virtual_microbatch_id, forward=False
+                        )
+                        enable_grad_sync()
+                        config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
+                        synchronized_model_chunks.add(grad_sync_chunk_id)
+                        disable_grad_sync()
+
+        if input_tensor is not None:
+            assert input_tensor_grad is not None
+
+        return output_tensor, input_tensor_grad
+
     # Compute number of steps for each stage
     pp_size = parallel_state.get_pipeline_model_parallel_world_size()
     rank = parallel_state.get_pipeline_model_parallel_rank()
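The grad-sync gating inside this new helper (only fire config.grad_sync_func for a chunk once its last microbatch has run backward, shifted by the pipeline rank) is easy to lose in the diff noise. The following is a toy, self-contained sketch of just that gating decision; the microbatch count and the is_last_microbatch_for_model_chunk stand-in are made up for illustration and are not the real schedule logic.

# Toy sketch of the grad-sync gating; numbers and helpers are made up for illustration.
def is_last_microbatch_for_model_chunk(virtual_mb: int, num_microbatches: int = 4) -> bool:
    """Hypothetical stand-in: the chunk finishes when its last microbatch index is reached."""
    return (virtual_mb % num_microbatches) == num_microbatches - 1

def should_launch_grad_sync(b_virtual_microbatch_id: int, pipeline_parallel_rank: int) -> bool:
    # Same shape as the diff: shift by the pipeline rank so every rank launches
    # its asynchronous grad reduction at the same point in the schedule.
    grad_sync_virtual_microbatch_id = b_virtual_microbatch_id - pipeline_parallel_rank
    return grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
        grad_sync_virtual_microbatch_id
    )

print(should_launch_grad_sync(b_virtual_microbatch_id=3, pipeline_parallel_rank=0))  # True
print(should_launch_grad_sync(b_virtual_microbatch_id=3, pipeline_parallel_rank=1))  # False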
@@ -583,35 +659,74 @@ def forward_backward_pipelining_with_cutinhalf(
     master_microbatch_max = num_microbatches
     slave_microbatch_max = num_microbatches * 2
-    set_dualpipe_chunk(master_chunk_id)
     checkpoint_activations_microbatch = None
-    input_tensor = recv_forward(tensor_shape, config, master_chunk_id, step=0)[0]
     fwd_wait_handles_warmup = None

-    # Run warmup forward passes
-    for i in range(schedule['warmup'][rank]):
-        output_tensor_warmup, num_tokens = forward_step_no_model_graph(
+    def forward_step_helper(input_tensor, model_chunk_id, cur_microbatch, is_first_microbatch=False):
+        set_dualpipe_chunk(model_chunk_id)
+        output_tensor, num_tokens = forward_step_no_model_graph(
             forward_step_func,
-            master_chunk_id,
-            data_iterator[master_chunk_id],
-            model[master_chunk_id],
+            model_chunk_id,
+            data_iterator[model_chunk_id],
+            model[model_chunk_id],
             num_microbatches,
             input_tensor,
             forward_data_store,
             config,
             collect_non_loss_data,
             checkpoint_activations_microbatch,
-            is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
-            current_microbatch=master_cur_microbatch
+            is_first_microbatch=is_first_microbatch,
+            current_microbatch=cur_microbatch
         )
+        nonlocal total_num_tokens
         total_num_tokens += num_tokens.item()
         if not forward_only:
-            input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
-            output_tensors[master_chunk_id].append(output_tensor_warmup)
+            input_tensors[model_chunk_id].append((cur_microbatch, input_tensor))
+            output_tensors[model_chunk_id].append(output_tensor)
+        return output_tensor
+
+    def backward_step_helper(input_tensor, output_tensor, output_tensor_grad, is_last_microbatch=False):
+        # # launch grad synchronization (default)
+        # if config.grad_sync_func is None and is_last_microbatch:
+        #     enable_grad_sync()
+
+        input_tensor_grad = backward_step(
+            input_tensor, output_tensor, output_tensor_grad, model_type, config
+        )
+
+        # # launch grad synchronization (custom grad sync)
+        # # Note: Asynchronous communication tends to slow down compute.
+        # # To reduce idling from mismatched microbatch times, we launch
+        # # asynchronous communication at the same time across the
+        # # pipeline-parallel group.
+        # if config.grad_sync_func is not None:
+        #     grad_sync_virtual_microbatch_id = virtual_microbatch_id - pipeline_parallel_rank
+        #     if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
+        #         grad_sync_virtual_microbatch_id
+        #     ):
+        #         grad_sync_chunk_id = get_model_chunk_id(
+        #             grad_sync_virtual_microbatch_id, forward=False
+        #         )
+        #         enable_grad_sync()
+        #         config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
+        #         synchronized_model_chunks.add(grad_sync_chunk_id)
+        #         disable_grad_sync()
+
+        return input_tensor_grad
+
+    # Run warmup forward passes
+    input_tensor, _ = recv_forward(tensor_shape, config, master_chunk_id)
+    for i in range(schedule['warmup'][rank]):
+        is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
+        output_tensor_warmup = forward_step_helper(
+            input_tensor, master_chunk_id, master_cur_microbatch, is_first_microbatch=is_first_microbatch
+        )
         master_cur_microbatch += 1
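For readers skimming the diff, the deduplication pattern above (one forward_step_helper that owns the per-microbatch bookkeeping for either chunk) can be illustrated with a stripped-down, self-contained sketch. Everything below is hypothetical scaffolding (a toy run_forward, plain lists for the saved tensors); it only mirrors the shape of the refactor, not the real Megatron APIs.

# Illustrative sketch only: toy stand-ins, not the real dualpipev/Megatron APIs.
from typing import Any, Dict, List, Tuple

MASTER, SLAVE = 0, 1

# Saved (microbatch_id, input) / output pairs per model chunk, as in the diff.
input_tensors: Dict[int, List[Tuple[int, Any]]] = {MASTER: [], SLAVE: []}
output_tensors: Dict[int, List[Any]] = {MASTER: [], SLAVE: []}

def run_forward(chunk_id: int, microbatch_id: int, inp: Any) -> Any:
    """Hypothetical forward step; the real code calls forward_step_no_model_graph."""
    return ("out", chunk_id, microbatch_id, inp)

def forward_step_helper(inp: Any, chunk_id: int, microbatch_id: int) -> Any:
    """One place that runs a forward pass and records the tensors for backward.

    Before the refactor this bookkeeping was repeated for the master chunk
    (warmup / steady state) and for the slave chunk; now every call site shares it.
    """
    out = run_forward(chunk_id, microbatch_id, inp)
    input_tensors[chunk_id].append((microbatch_id, inp))
    output_tensors[chunk_id].append(out)
    return out

# Warmup for the master chunk and an interleaved slave step both reuse the helper.
forward_step_helper("x0", MASTER, 0)
forward_step_helper("y0", SLAVE, 0)
print(len(input_tensors[MASTER]), len(input_tensors[SLAVE]))  # 1 1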
@@ -639,29 +754,13 @@ def forward_backward_pipelining_with_cutinhalf(
         fwd_wait_handles = None

         is_first_microbatch = parallel_state.is_pipeline_last_stage(ignore_virtual=True) and (i == 0)
-        set_dualpipe_chunk(master_chunk_id)
-        output_tensor, num_tokens = forward_step_no_model_graph(
-            forward_step_func,
-            master_chunk_id,
-            data_iterator[master_chunk_id],
-            model[master_chunk_id],
-            num_microbatches,
+        is_first_microbatch = check_first_val_step(first_val_step, forward_only, is_first_microbatch)
+        output_tensor = forward_step_helper(
             input_tensor,
-            forward_data_store,
-            config,
-            collect_non_loss_data,
-            checkpoint_activations_microbatch,
-            is_first_microbatch=check_first_val_step(first_val_step, forward_only, is_first_microbatch),
-            current_microbatch=master_cur_microbatch
+            master_chunk_id,
+            master_cur_microbatch,
+            is_first_microbatch=is_first_microbatch
         )
-        total_num_tokens += num_tokens.item()
-        if not forward_only:
-            input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
-            output_tensors[master_chunk_id].append(output_tensor)
         master_cur_microbatch += 1

         if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and fwd_wait_handles_send is not None:
@@ -676,19 +775,16 @@ def forward_backward_pipelining_with_cutinhalf(
         if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
             if not forward_only:
-                input_tensor_slave_chunk = output_tensor.detach()
-                input_tensor_slave_chunk.requires_grad = True
+                input_tensor_slave = output_tensor.detach()
+                input_tensor_slave.requires_grad = True
             else:
-                input_tensor_slave_chunk = output_tensor
-            input_tensor, fwd_wait_handles = recv_forward(tensor_shape, config, master_chunk_id, async_op=True)
+                input_tensor_slave = output_tensor
         else:
-            input_tensor_slave_chunk, _ = recv_forward(
+            input_tensor_slave, _ = recv_forward(
                 tensor_shape, config, slave_chunk_id
             )
-            input_tensor, fwd_wait_handles = recv_forward(tensor_shape, config, master_chunk_id, async_op=True)
+        input_tensor, fwd_wait_handles = recv_forward(tensor_shape, config, master_chunk_id, async_op=True)

         if fwd_wait_handles_warmup is not None:
             for req, req_handle in fwd_wait_handles_warmup.items():
@@ -710,28 +806,13 @@ def forward_backward_pipelining_with_cutinhalf(
             deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)

-        set_dualpipe_chunk(slave_chunk_id)
-        output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
-            forward_step_func,
+        is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
+        output_tensor_slave_chunk = forward_step_helper(
+            input_tensor_slave,
             slave_chunk_id,
-            data_iterator[slave_chunk_id],
-            model[slave_chunk_id],
-            num_microbatches,
-            input_tensor_slave_chunk,
-            forward_data_store,
-            config,
-            collect_non_loss_data,
-            checkpoint_activations_microbatch,
-            current_microbatch=slave_cur_microbatch,
+            slave_cur_microbatch,
+            is_first_microbatch=is_first_microbatch
         )
-        total_num_tokens += num_tokens.item()
-        if not forward_only:
-            input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
-            output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
         slave_cur_microbatch += 1

         if not forward_only:
@@ -765,8 +846,6 @@ def forward_backward_pipelining_with_cutinhalf(
     # Run 1b1w1f stages for slave chunk
     bwd_wait_handles = None
     for _ in range(schedule['1b1w1f'][rank]):
-        # WeightGradStore.start_decouple()
         if not forward_only:
             input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
             output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
@@ -775,8 +854,6 @@ def forward_backward_pipelining_with_cutinhalf(
                 input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
             )
-            # WeightGradStore.end_decouple()

         if fwd_wait_handles_slave_chunk is not None:
             for req in fwd_wait_handles_slave_chunk:
                 req.wait()
@@ -800,12 +877,9 @@ def forward_backward_pipelining_with_cutinhalf(
                 tensor_shape, config, slave_chunk_id
             )
         # If asynchronous, the memory will rise.
-        input_tensor_slave_chunk, recv_forward_handle = recv_forward(
+        input_tensor_slave, recv_forward_handle = recv_forward(
             tensor_shape, config, slave_chunk_id
         )

         # 1w: Weight Grad Compute
-        # WeightGradStore.pop()
         if recv_forward_handle is not None:
             for req, handle in recv_forward_handle.items():
                 if handle is not None:
@@ -813,27 +887,12 @@ def forward_backward_pipelining_with_cutinhalf(
             recv_forward_handle = None

         # 1F: Forward pass
-        set_dualpipe_chunk(slave_chunk_id)
-        output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
-            forward_step_func,
+        output_tensor = forward_step_helper(
+            input_tensor_slave,
             slave_chunk_id,
-            data_iterator[slave_chunk_id],
-            model[slave_chunk_id],
-            num_microbatches,
-            input_tensor_slave_chunk,
-            forward_data_store,
-            config,
-            collect_non_loss_data,
-            checkpoint_activations_microbatch,
-            current_microbatch=slave_cur_microbatch
+            slave_cur_microbatch,
+            is_first_microbatch=False
         )
-        total_num_tokens += num_tokens.item()
-        if not forward_only:
-            input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
-            output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
         slave_cur_microbatch += 1

         if not forward_only:
@@ -844,6 +903,8 @@ def forward_backward_pipelining_with_cutinhalf(
                 tensor_shape, config, slave_chunk_id, async_op=True
             )

     # Run overlaping f&bw stages
     fwd_wait_handles = None
     bwd_wait_handles = None
+    fwd_wait_handles_recv = None
     fwd_model_chunk_id = master_chunk_id
     bwd_model_chunk_id = slave_chunk_id
@@ -858,105 +919,121 @@ def forward_backward_pipelining_with_cutinhalf(
             only_bwd = True

         if not only_bwd:
-            fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
-            set_dualpipe_chunk(fwd_model_chunk_id)
-            if fwd_wait_handles_recv is not None:
-                for req, req_handle in fwd_wait_handles_recv.items():
-                    req_handle.wait()
-                fwd_wait_handles_recv = None
-            output_tensor, num_tokens = forward_step_no_model_graph(
-                forward_step_func,
-                fwd_model_chunk_id,
-                data_iterator[fwd_model_chunk_id],
-                model[fwd_model_chunk_id],
-                num_microbatches,
-                input_tensor,
-                forward_data_store,
-                config,
-                collect_non_loss_data,
-                checkpoint_activations_microbatch,
-                current_microbatch=fwd_microbatch
-            )
-            total_num_tokens += num_tokens.item()
-            if not forward_only:
-                input_tensors[fwd_model_chunk_id].append((fwd_microbatch, input_tensor))
-                output_tensors[fwd_model_chunk_id].append(output_tensor)
-            if fwd_model_chunk_id == master_chunk_id:
-                master_cur_microbatch += 1
-                fwd_send_only = False
-            else:
-                slave_cur_microbatch += 1
-                fwd_send_only = (master_cur_microbatch == master_microbatch_max)
-            # Wait for the previous stage's last slave-chunk forward send
-            if fwd_wait_handles_slave_chunk is not None:
-                for req, req_handle in fwd_wait_handles_slave_chunk.items():
-                    if req_handle is not None:
-                        req_handle.wait()
-                fwd_wait_handles_slave_chunk = None
-                if not forward_only:
-                    deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-            if fwd_send_only:
-                fwd_wait_handles = send_forward(output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
-            else:
-                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                    # Wait for the previous stage's last slave-chunk forward send
-                    if fwd_wait_handles_slave_chunk is not None:
-                        for req, req_handle in fwd_wait_handles_slave_chunk.items():
-                            if req_handle is not None:
-                                req_handle.wait()
-                        fwd_wait_handles_slave_chunk = None
-                    if not forward_only:
-                        input_tensor = output_tensor.detach()
-                        input_tensor.requires_grad = True
-                        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-                    else:
-                        input_tensor = output_tensor
-                else:
-                    input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
-            deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-            if not forward_only and firstFB_no_overlp_handle is not None:
-                for req, req_handle in firstFB_no_overlp_handle.items():
-                    if req_handle is not None:
-                        req_handle.wait()
-                firstFB_no_overlp_handle = None
-            if not forward_only:
-                if bwd_wait_handles is not None:
-                    for req, req_handle in bwd_wait_handles.items():
-                        if req_handle is not None:
-                            req_handle.wait()
-                    bwd_wait_handles = None
-                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-                input_tensor_grad = backward_step(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
-                if fwd_wait_handles is not None:
-                    for req, req_handle in fwd_wait_handles.items():
-                        if req_handle is not None:
-                            req_handle.wait()
-                    fwd_wait_handles = None
-                    if not forward_only:
-                        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                    output_tensor_grad_bwd = input_tensor_grad
-                else:
-                    # send_backward_recv_slave_backward
-                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad, tensor_shape, config, fwd_model_chunk_id, async_op=True)
+            def pp_pre_forward():
+                nonlocal fwd_wait_handles_recv
+                if fwd_wait_handles_recv is not None:
+                    for req, req_handle in fwd_wait_handles_recv.items():
+                        req_handle.wait()
+                fwd_wait_handles_recv = None
+
+            def pp_post_forward(output_tensor):
+                nonlocal master_cur_microbatch
+                nonlocal slave_cur_microbatch
+                nonlocal fwd_wait_handles
+                nonlocal fwd_wait_handles_slave_chunk
+                nonlocal firstFB_no_overlp_handle
+                if fwd_model_chunk_id == master_chunk_id:
+                    master_cur_microbatch += 1
+                    fwd_send_only = False
+                else:
+                    slave_cur_microbatch += 1
+                    fwd_send_only = (master_cur_microbatch == master_microbatch_max)
+                if fwd_send_only:
+                    input_tensor = None
+                    fwd_wait_handles = send_forward(output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
+                else:
+                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
+                        if not forward_only:
+                            input_tensor = output_tensor.detach()
+                            input_tensor.requires_grad = True
+                            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+                        else:
+                            input_tensor = output_tensor
+                    else:
+                        input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
+                if not forward_only and firstFB_no_overlp_handle is not None:
+                    for req, req_handle in firstFB_no_overlp_handle.items():
+                        if req_handle is not None:
+                            req_handle.wait()
+                    firstFB_no_overlp_handle = None
+                return input_tensor
+
+            def pp_pre_backward():
+                nonlocal bwd_wait_handles
+                if bwd_wait_handles is not None:
+                    for _, req_handle in bwd_wait_handles.items():
+                        if req_handle is not None:
+                            req_handle.wait()
+                    bwd_wait_handles = None
+
+            def pp_post_backward(input_tensor_grad):
+                nonlocal fwd_wait_handles
+                nonlocal bwd_wait_handles
+                if fwd_wait_handles is not None:
+                    for _, req_handle in fwd_wait_handles.items():
+                        if req_handle is not None:
+                            req_handle.wait()
+                    fwd_wait_handles = None
+                    if not forward_only:
+                        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+                if not forward_only:
+                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
+                        output_tensor_grad = input_tensor_grad
+                    else:
+                        output_tensor_grad, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad, tensor_shape, config, fwd_model_chunk_id, async_op=True)
+                else:
+                    output_tensor_grad = None
+                return output_tensor_grad
+
+            # forward
+            pp_pre_forward()
+            fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
+            output_tensor = forward_step_helper(input_tensor, fwd_model_chunk_id, fwd_microbatch, is_first_microbatch=False)
+            input_tensor = pp_post_forward(output_tensor)
+
+            # backward
+            pp_pre_backward()
+            if not forward_only:
+                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
+                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
+                input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
+            else:
+                input_tensor_grad = None
+            output_tensor_grad_bwd = pp_post_backward(input_tensor_grad)

         # only run backward
         else:
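The rewritten steady-state loop separates communication waits from compute by wrapping them in the pp_pre_forward / pp_post_forward / pp_pre_backward / pp_post_backward closures and then driving them in a fixed order. A minimal, purely illustrative sketch of that control flow follows; the hook names mirror the closures in the diff, but the bodies here are hypothetical no-ops, whereas the real code waits on asynchronous send/recv handles and frees buffers.

# Illustrative control-flow sketch; hook bodies are hypothetical stand-ins.
from typing import Any, Callable

def run_overlapped_iteration(
    input_tensor: Any,
    forward_step: Callable[[Any], Any],
    backward_step: Callable[[Any], Any],
    pre_forward: Callable[[], None],
    post_forward: Callable[[Any], Any],
    pre_backward: Callable[[], None],
    post_backward: Callable[[Any], Any],
    forward_only: bool = False,
) -> tuple:
    # forward
    pre_forward()                              # e.g. wait for pending recv handles
    output_tensor = forward_step(input_tensor)
    next_input = post_forward(output_tensor)   # e.g. send activation, prime next input

    # backward
    pre_backward()                             # e.g. wait for pending grad-send handles
    grad = backward_step(output_tensor) if not forward_only else None
    next_output_grad = post_backward(grad)     # e.g. exchange grads with the peer stage
    return next_input, next_output_grad

# Toy usage with identity hooks.
nxt_in, nxt_grad = run_overlapped_iteration(
    input_tensor=1.0,
    forward_step=lambda x: x * 2,
    backward_step=lambda y: y,
    pre_forward=lambda: None,
    post_forward=lambda out: out,
    pre_backward=lambda: None,
    post_backward=lambda g: g,
)
print(nxt_in, nxt_grad)  # 2.0 2.0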
@@ -970,13 +1047,9 @@ def forward_backward_pipelining_with_cutinhalf(
                         req_handle.wait()
                 bwd_wait_handles = None

-            input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-            input_tensor_grad = backward_step(
-                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
-            )
+            input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
+            output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
+            input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)

             if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                 output_tensor_grad_bwd = input_tensor_grad
@@ -1022,15 +1095,7 @@ def forward_backward_pipelining_with_cutinhalf(
             input_tensor_bwd = merged_input_tensors.pop(0)[1]
             output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
-            # if not args.dualpipe_no_dw_detach:
-            #     WeightGradStore.start_decouple()
-            input_tensor_grad = backward_step(
-                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
-            )
-            # if not args.dualpipe_no_dw_detach:
-            #     WeightGradStore.end_decouple()
+            input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)

             if i == pp_size - 1:
                 bwd_wait_handles = send_backward(
                     input_tensor_grad,
@@ -1048,15 +1113,6 @@ def forward_backward_pipelining_with_cutinhalf(
             output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(
                 input_tensor_grad, tensor_shape, config, 1 - bwd_model_chunk_id
             )
-            # WeightGradStore.flush_chunk_grad()
-            # if i >= schedule['cooldown'][rank][0] - 1:
-            #     WeightGradStore.pop_single()
-            # for _ in range(schedule['cooldown'][rank][2] - 1):
-            #     WeightGradStore.pop_single()
-            # assert WeightGradStore.weight_grad_queue.empty()

             if bwd_wait_handles is not None:
                 for req, req_handle in bwd_wait_handles.items():
                     if req_handle is not None: