evt_fugx1 / dcu_megatron · Commits

Commit 2b81ee55, authored Jun 03, 2025 by dongcl
Parent: f3434cc7

    dualpipev supports evaluation mode

Showing 1 changed file with 176 additions and 146 deletions:
    dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py (+176, -146)
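The change is mechanical but pervasive: every action that only matters for training (stashing each forward pass's input/output pair for a later backward_step, releasing sent activations, and sending or receiving gradients) is now gated behind `if not forward_only:`, so the dualpipev schedule can run pure evaluation. A minimal, self-contained sketch of that gating idea; the names below (run_forward, run_backward, send_grad, pipeline_step) are illustrative stand-ins, not the real dualpipev/Megatron APIs.

from typing import List


def run_forward(mb: float) -> float:          # stand-in for forward_step
    return mb * 2.0


def run_backward(inp: float, out: float) -> float:   # stand-in for backward_step
    return out - inp


def send_grad(grad: float) -> None:           # stand-in for send_backward
    print(f"send grad {grad}")


def pipeline_step(microbatches: List[float], forward_only: bool) -> None:
    input_tensors: List[float] = []
    output_tensors: List[float] = []
    for mb in microbatches:
        out = run_forward(mb)
        if not forward_only:
            # Training caches the pair so a later backward_step can reuse it.
            input_tensors.append(mb)
            output_tensors.append(out)
    if not forward_only:
        # Cooldown: drain the cached pairs; evaluation skips this entirely.
        while input_tensors:
            send_grad(run_backward(input_tensors.pop(0), output_tensors.pop(0)))


if __name__ == "__main__":
    pipeline_step([1.0, 2.0, 3.0], forward_only=True)    # evaluation: forward passes only
    pipeline_step([1.0, 2.0, 3.0], forward_only=False)   # training: backward + grad sends

With forward_only=True the stashes stay empty, so the drain loop (the cooldown phase in the real schedule) becomes a no-op.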
@@ -597,22 +597,24 @@ def forward_backward_pipelining_with_cutinhalf(
                 config,
                 collect_non_loss_data,
                 checkpoint_activations_microbatch,
-                is_first_microbatch=(i == 0),
+                is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
                 current_microbatch=master_cur_microbatch)
             total_num_tokens += num_tokens.item()
-            input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
-            output_tensors[master_chunk_id].append(output_tensor_warmup)
+            if not forward_only:
+                input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
+                output_tensors[master_chunk_id].append(output_tensor_warmup)
             master_cur_microbatch += 1

             if i != schedule['warmup'][rank] - 1:
                 input_tensor, _ = send_forward_recv_forward(output_tensor_warmup, tensor_shape, config, master_chunk_id)
-                deallocate_output_tensor(output_tensor_warmup, config.deallocate_pipeline_outputs)
+                if not forward_only:
+                    deallocate_output_tensor(output_tensor_warmup, config.deallocate_pipeline_outputs)
             else:
                 input_tensor, _ = recv_forward(tensor_shape, config, master_chunk_id)
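The only non-mechanical edit in this hunk is that is_first_microbatch now goes through check_first_val_step, so the first-microbatch flag can be controlled separately for validation steps. A rough sketch of what that helper does in upstream Megatron-LM, paraphrased from memory rather than copied from this repository; check the actual source if the exact semantics matter.

def check_first_val_step(first_val_step, forward_only, cond):
    # During evaluation (forward_only), an explicitly passed first_val_step
    # overrides the per-iteration condition; during training the condition
    # (e.g. i == 0) is used as-is.
    if (first_val_step is not None) and forward_only:
        return first_val_step and cond
    return cond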
@@ -644,14 +646,15 @@ def forward_backward_pipelining_with_cutinhalf(
                 config,
                 collect_non_loss_data,
                 checkpoint_activations_microbatch,
-                is_first_microbatch=is_first_microbatch,
+                is_first_microbatch=check_first_val_step(first_val_step, forward_only, is_first_microbatch),
                 current_microbatch=master_cur_microbatch)
             total_num_tokens += num_tokens.item()
-            input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
-            output_tensors[master_chunk_id].append(output_tensor)
+            if not forward_only:
+                input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
+                output_tensors[master_chunk_id].append(output_tensor)
             master_cur_microbatch += 1
@@ -659,13 +662,18 @@ def forward_backward_pipelining_with_cutinhalf(
             for req, req_handle in fwd_wait_handles_send.items():
                 if req_handle is not None:
                     req_handle.wait()
-            deallocate_output_tensor(output_tensor_send, config.deallocate_pipeline_outputs)
             fwd_wait_handles_send = None
+            if not forward_only:
+                deallocate_output_tensor(output_tensor_send, config.deallocate_pipeline_outputs)

         if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
-            input_tensor_slave_chunk = output_tensor.detach()
-            input_tensor_slave_chunk.requires_grad = True
+            if not forward_only:
+                input_tensor_slave_chunk = output_tensor.detach()
+                input_tensor_slave_chunk.requires_grad = True
+            else:
+                input_tensor_slave_chunk = output_tensor
         else:
             input_tensor, fwd_wait_handles = recv_forward(tensor_shape, config, master_chunk_id, async_op=True)
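The detach in this hunk only matters when a backward pass will follow: training needs a fresh leaf tensor at the master→slave boundary so backward_step can later produce a gradient for the master chunk, while evaluation can pass the output straight through. A self-contained sketch of that branch in plain PyTorch; chunk_boundary is an illustrative name, not the real dualpipev code.

import torch


def chunk_boundary(output_tensor: torch.Tensor, forward_only: bool) -> torch.Tensor:
    # Last-stage hand-off from the master chunk to the slave chunk.
    if not forward_only:
        # Training: cut the autograd graph at the boundary but keep a leaf that
        # requires grad, so the slave chunk's backward can yield a gradient to
        # feed back into the master chunk's backward pass.
        boundary = output_tensor.detach()
        boundary.requires_grad = True
        return boundary
    # Evaluation: no backward will run, so the output is reused directly.
    return output_tensor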
@@ -680,18 +688,22 @@ def forward_backward_pipelining_with_cutinhalf(
             for req, req_handle in fwd_wait_handles_warmup.items():
                 if req_handle is not None:
                     req_handle.wait()
-            deallocate_output_tensor(output_tensor_warmup, config.deallocate_pipeline_outputs)
             fwd_wait_handles_warmup = None
+            if not forward_only:
+                deallocate_output_tensor(output_tensor_warmup, config.deallocate_pipeline_outputs)

         if fwd_wait_handles_slave_chunk is not None:
             for req, req_handle in fwd_wait_handles_slave_chunk.items():
                 if req_handle is not None:
                     req_handle.wait()
-            deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
             fwd_wait_handles_slave_chunk = None
+            if not forward_only:
+                deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)

         set_dualpipe_chunk(slave_chunk_id)
         output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
             forward_step_func,
@@ -707,15 +719,16 @@ def forward_backward_pipelining_with_cutinhalf(
                 current_microbatch=slave_cur_microbatch,
             )
-            input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
             total_num_tokens += num_tokens.item()
-            output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
+            if not forward_only:
+                input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
+                output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
             slave_cur_microbatch += 1

             if i == schedule['interleaved_forward'][rank] - 1:
                 firstFB_no_overlp = False
                 firstFB_no_overlp_handle = None
                 # last rank not overlap first F&B
@@ -749,33 +762,37 @@ def forward_backward_pipelining_with_cutinhalf(
         for _ in range(schedule['1b1w1f'][rank]):
             # WeightGradStore.start_decouple()
-            input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
-            input_tensor_grad = backward_step(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
+            if not forward_only:
+                input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
+                output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
+                input_tensor_grad = backward_step(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
             # WeightGradStore.end_decouple()

             if fwd_wait_handles_slave_chunk is not None:
                 for req in fwd_wait_handles_slave_chunk:
                     req.wait()
-                deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
                 fwd_wait_handles_slave_chunk = None
+                if not forward_only:
+                    deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)

             if fwd_wait_handles_send is not None:
                 for req, req_handle in fwd_wait_handles_send.items():
                     if req_handle is not None:
                         req_handle.wait()
-                deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
                 fwd_wait_handles_send = None
+                if not forward_only:
+                    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

-            # If asynchronous, the memory will rise.
-            bwd_wait_handles = send_backward(input_tensor_grad, tensor_shape, config, slave_chunk_id)
+            if not forward_only:
+                # If asynchronous, the memory will rise.
+                bwd_wait_handles = send_backward(input_tensor_grad, tensor_shape, config, slave_chunk_id)

             # If asynchronous, the memory will rise.
             input_tensor_slave_chunk, recv_forward_handle = recv_forward(
@@ -806,15 +823,17 @@ def forward_backward_pipelining_with_cutinhalf(
                 current_microbatch=slave_cur_microbatch)
-            input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
             total_num_tokens += num_tokens.item()
-            output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
+            if not forward_only:
+                input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
+                output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
             slave_cur_microbatch += 1

-            output_tensor_grad_bwd, _ = recv_backward(tensor_shape, config, slave_chunk_id)
+            if not forward_only:
+                output_tensor_grad_bwd, _ = recv_backward(tensor_shape, config, slave_chunk_id)

             fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk, tensor_shape, config, slave_chunk_id, async_op=True)
@@ -847,36 +866,42 @@ def forward_backward_pipelining_with_cutinhalf(
                 checkpoint_activations_microbatch,
                 current_microbatch=fwd_microbatch)
-            input_tensors[fwd_model_chunk_id].append((fwd_microbatch, input_tensor))
             total_num_tokens += num_tokens.item()
-            output_tensors[fwd_model_chunk_id].append(output_tensor)
+            if not forward_only:
+                input_tensors[fwd_model_chunk_id].append((fwd_microbatch, input_tensor))
+                output_tensors[fwd_model_chunk_id].append(output_tensor)

             if fwd_model_chunk_id == master_chunk_id:
                 master_cur_microbatch += 1
                 fwd_send_only = False
             else:
                 slave_cur_microbatch += 1
-                fwd_send_only = (master_cur_microbatch == master_microbatch_max)
+                fwd_send_only = (master_cur_microbatch == master_microbatch_max)

             # Synchronize the previous stage's last slave forward send
             if fwd_wait_handles_slave_chunk is not None:
                 for req, req_handle in fwd_wait_handles_slave_chunk.items():
                     if req_handle is not None:
                         req_handle.wait()
-                deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
                 fwd_wait_handles_slave_chunk = None
+                if not forward_only:
+                    deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)

             if fwd_send_only:
                 fwd_wait_handles = send_forward(output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
             else:
                 if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                    input_tensor = output_tensor.detach()
-                    input_tensor.requires_grad = True
-                    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+                    if not forward_only:
+                        input_tensor = output_tensor.detach()
+                        input_tensor.requires_grad = True
+                        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+                    else:
+                        input_tensor = output_tensor
                 else:
                     input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
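Several hunks above repeat the same wait-then-release pattern: drain the async send handles, then free the sent activation's storage, with the free now gated on training, presumably because in evaluation the un-detached output may still be in use as the next chunk's input. A standalone sketch of that pattern in plain PyTorch; release_sent_output is an illustrative helper, and the storage-shrinking line only mimics (roughly) what Megatron's deallocate_output_tensor does.

import torch


def release_sent_output(output_tensor: torch.Tensor, wait_handles, forward_only: bool) -> None:
    # Wait for all outstanding async communication handles first.
    if wait_handles is not None:
        for _, handle in wait_handles.items():
            if handle is not None:
                handle.wait()
    if not forward_only:
        # Training: the activation has been sent and stashed elsewhere, so its
        # bulky storage can be dropped while the tensor object stays alive.
        output_tensor.data = torch.empty(1, dtype=output_tensor.dtype, device=output_tensor.device)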
@@ -887,123 +912,128 @@ def forward_backward_pipelining_with_cutinhalf(
                         req_handle.wait()
                 firstFB_no_overlp_handle = None

-            if bwd_wait_handles is not None:
-                for req, req_handle in bwd_wait_handles.items():
-                    if req_handle is not None:
-                        req_handle.wait()
-                bwd_wait_handles = None
+            if not forward_only:
+                if bwd_wait_handles is not None:
+                    for req, req_handle in bwd_wait_handles.items():
+                        if req_handle is not None:
+                            req_handle.wait()
+                    bwd_wait_handles = None

-            input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-            input_tensor_grad = backward_step(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
+                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
+                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
+                input_tensor_grad = backward_step(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)

             if fwd_wait_handles is not None:
                 for req, req_handle in fwd_wait_handles.items():
                     if req_handle is not None:
                         req_handle.wait()
                 fwd_wait_handles = None
-                deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+                if not forward_only:
+                    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

-            if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                output_tensor_grad_bwd = input_tensor_grad
-            else:
-                # send_backward_recv_slave_backward
-                output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad, tensor_shape, config, fwd_model_chunk_id, async_op=True)
+            if not forward_only:
+                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
+                    output_tensor_grad_bwd = input_tensor_grad
+                else:
+                    # send_backward_recv_slave_backward
+                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad, tensor_shape, config, fwd_model_chunk_id, async_op=True)

         # only run backward
         else:
             if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
                 input_tensor, _ = recv_forward(tensor_shape, config, slave_chunk_id)

-            if bwd_wait_handles is not None:
-                for req, req_handle in bwd_wait_handles.items():
-                    if req_handle is not None:
-                        req_handle.wait()
-                bwd_wait_handles = None
+            if not forward_only:
+                if bwd_wait_handles is not None:
+                    for req, req_handle in bwd_wait_handles.items():
+                        if req_handle is not None:
+                            req_handle.wait()
+                    bwd_wait_handles = None

-            input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-            input_tensor_grad = backward_step(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
+                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
+                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
+                input_tensor_grad = backward_step(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)

-            if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                output_tensor_grad_bwd = input_tensor_grad
-            else:
-                # send_backward_recv_slave_backward
-                output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad, tensor_shape, config, fwd_model_chunk_id)
+                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
+                    output_tensor_grad_bwd = input_tensor_grad
+                else:
+                    # send_backward_recv_slave_backward
+                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad, tensor_shape, config, fwd_model_chunk_id)

         # swap fwd & bwd chunks
         fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id

-    # Run cooldown phases
-    merged_input_tensors = []
-    merged_output_tensors = []
-    while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
-        if len(input_tensors[bwd_model_chunk_id]) > 0:
-            merged_input_tensors.append(input_tensors[bwd_model_chunk_id].pop(0))
-            merged_output_tensors.append((output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))
-        if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
-            merged_input_tensors.append(input_tensors[1 - bwd_model_chunk_id].pop(0))
-            merged_output_tensors.append((output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))
-
-    bwd_wait_handles_recv = None
-    for i in range(pp_size):
-        if bwd_wait_handles is not None:
-            for req, req_handle in bwd_wait_handles.items():
-                if req_handle is not None:
-                    req_handle.wait()
-            bwd_wait_handles = None
-        if bwd_wait_handles_recv is not None:
-            for req, req_handle in bwd_wait_handles_recv.items():
-                if req_handle is not None:
-                    req_handle.wait()
-            bwd_wait_handles_recv = None
-
-        # if not args.dualpipe_no_dw_detach:
-        #     WeightGradStore.start_decouple()
-        input_tensor_bwd = merged_input_tensors.pop(0)[1]
-        output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
-        input_tensor_grad = backward_step(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
-        # if not args.dualpipe_no_dw_detach:
-        #     WeightGradStore.end_decouple()
-
-        if i == pp_size - 1:
-            bwd_wait_handles = send_backward(input_tensor_grad, tensor_shape, config, bwd_model_chunk_id, async_op=True)
-        elif i >= schedule['cooldown'][rank][0] - 1:
-            bwd_wait_handles = send_backward(input_tensor_grad, tensor_shape, config, bwd_model_chunk_id, async_op=True)
-            output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(tensor_shape, config, bwd_model_chunk_id, async_op=True)
-        else:
-            if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
-                output_tensor_grad_bwd = input_tensor_grad
-            else:
-                # send_backward_recv_slave_backward
-                output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad, tensor_shape, config, 1 - bwd_model_chunk_id)
+    if not forward_only:
+        # Run cooldown phases
+        merged_input_tensors = []
+        merged_output_tensors = []
+        while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
+            if len(input_tensors[bwd_model_chunk_id]) > 0:
+                merged_input_tensors.append(input_tensors[bwd_model_chunk_id].pop(0))
+                merged_output_tensors.append((output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))
+            if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
+                merged_input_tensors.append(input_tensors[1 - bwd_model_chunk_id].pop(0))
+                merged_output_tensors.append((output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))
+
+        bwd_wait_handles_recv = None
+        for i in range(pp_size):
+            if bwd_wait_handles is not None:
+                for req, req_handle in bwd_wait_handles.items():
+                    if req_handle is not None:
+                        req_handle.wait()
+                bwd_wait_handles = None
+            if bwd_wait_handles_recv is not None:
+                for req, req_handle in bwd_wait_handles_recv.items():
+                    if req_handle is not None:
+                        req_handle.wait()
+                bwd_wait_handles_recv = None
+
+            # if not args.dualpipe_no_dw_detach:
+            #     WeightGradStore.start_decouple()
+            input_tensor_bwd = merged_input_tensors.pop(0)[1]
+            output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
+            input_tensor_grad = backward_step(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
+            # if not args.dualpipe_no_dw_detach:
+            #     WeightGradStore.end_decouple()
+
+            if i == pp_size - 1:
+                bwd_wait_handles = send_backward(input_tensor_grad, tensor_shape, config, bwd_model_chunk_id, async_op=True)
+            elif i >= schedule['cooldown'][rank][0] - 1:
+                bwd_wait_handles = send_backward(input_tensor_grad, tensor_shape, config, bwd_model_chunk_id, async_op=True)
+                output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(tensor_shape, config, bwd_model_chunk_id, async_op=True)
+            else:
+                if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
+                    output_tensor_grad_bwd = input_tensor_grad
+                else:
+                    # send_backward_recv_slave_backward
+                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad, tensor_shape, config, 1 - bwd_model_chunk_id)

             # WeightGradStore.flush_chunk_grad()
             # if i >= schedule['cooldown'][rank][0] - 1:
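For reference, the cooldown drain that this hunk now wraps in `if not forward_only:` interleaves the two chunks' pending microbatches, starting with the chunk whose backward is due next. A standalone toy of that interleaving, using plain lists in place of tensors; merge_cooldown_queues is an illustrative name, not a function in this file.

def merge_cooldown_queues(input_tensors, bwd_model_chunk_id):
    # Alternate between the two chunks' queues, beginning with the chunk whose
    # backward runs next (mirrors the while-loop in the hunk above).
    merged = []
    while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
        if len(input_tensors[bwd_model_chunk_id]) > 0:
            merged.append((bwd_model_chunk_id, input_tensors[bwd_model_chunk_id].pop(0)))
        if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
            merged.append((1 - bwd_model_chunk_id, input_tensors[1 - bwd_model_chunk_id].pop(0)))
    return merged


print(merge_cooldown_queues([["m0", "m1"], ["s0", "s1", "s2"]], bwd_model_chunk_id=1))
# [(1, 's0'), (0, 'm0'), (1, 's1'), (0, 'm1'), (1, 's2')]

In evaluation mode both queues stay empty (nothing is stashed during the forward phases), so skipping this drain is consistent with the rest of the diff.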
@@ -1014,11 +1044,11 @@ def forward_backward_pipelining_with_cutinhalf(
     # assert WeightGradStore.weight_grad_queue.empty()
-    if bwd_wait_handles is not None:
-        for req, req_handle in bwd_wait_handles.items():
-            if req_handle is not None:
-                req_handle.wait()
-        bwd_wait_handles = None
+    if bwd_wait_handles is not None:
+        for req, req_handle in bwd_wait_handles.items():
+            if req_handle is not None:
+                req_handle.wait()
+        bwd_wait_handles = None

     if config.finalize_model_grads_func is not None and not forward_only: