evt_fugx1 / dcu_megatron · Commits

Commit 2b81ee55 · authored Jun 03, 2025 by dongcl

    dualpipev supports evaluation mode

Parent: f3434cc7
Showing 1 changed file with 176 additions and 146 deletions.
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py  (+176, −146)  @ 2b81ee55
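The hunks that follow all apply the same two adjustments inside forward_backward_pipelining_with_cutinhalf: bookkeeping that only matters for the backward pass (saving input/output tensors, deallocating pipeline outputs, receiving and sending gradients, the cooldown phase) is gated behind "if not forward_only:", and is_first_microbatch is now routed through check_first_val_step(first_val_step, forward_only, ...). For context only (this helper is not part of the diff), upstream Megatron-LM core defines it roughly as sketched below; the exact body in dcu_megatron may differ.

# Reference sketch of check_first_val_step as found in upstream Megatron-LM
# (megatron/core/pipeline_parallel/schedules.py); shown here for context only.
def check_first_val_step(first_val_step, forward_only, cond):
    # During evaluation (forward_only=True) the caller may pin whether this
    # counts as the first validation step; otherwise fall back to the
    # schedule's own condition, e.g. i == 0.
    if (first_val_step is not None) and forward_only:
        return first_val_step
    return cond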
@@ -597,11 +597,12 @@ def forward_backward_pipelining_with_cutinhalf(
                 config,
                 collect_non_loss_data,
                 checkpoint_activations_microbatch,
-                is_first_microbatch=(i == 0),
+                is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
                 current_microbatch=master_cur_microbatch
             )
             total_num_tokens += num_tokens.item()
-            input_tensors[master_chunk_id].append(
-                (master_cur_microbatch, input_tensor))
-            output_tensors[master_chunk_id].append(output_tensor_warmup)
+            if not forward_only:
+                input_tensors[master_chunk_id].append(
+                    (master_cur_microbatch, input_tensor))
+                output_tensors[master_chunk_id].append(output_tensor_warmup)
@@ -611,6 +612,7 @@ def forward_backward_pipelining_with_cutinhalf(
             if i != schedule['warmup'][rank] - 1:
                 input_tensor, _ = send_forward_recv_forward(
                     output_tensor_warmup, tensor_shape, config, master_chunk_id)
-                deallocate_output_tensor(
-                    output_tensor_warmup, config.deallocate_pipeline_outputs)
+                if not forward_only:
+                    deallocate_output_tensor(
+                        output_tensor_warmup, config.deallocate_pipeline_outputs)
             else:
@@ -644,11 +646,12 @@ def forward_backward_pipelining_with_cutinhalf(
                 config,
                 collect_non_loss_data,
                 checkpoint_activations_microbatch,
-                is_first_microbatch=is_first_microbatch,
+                is_first_microbatch=check_first_val_step(first_val_step, forward_only, is_first_microbatch),
                 current_microbatch=master_cur_microbatch
             )
             total_num_tokens += num_tokens.item()
-            input_tensors[master_chunk_id].append(
-                (master_cur_microbatch, input_tensor))
-            output_tensors[master_chunk_id].append(output_tensor)
+            if not forward_only:
+                input_tensors[master_chunk_id].append(
+                    (master_cur_microbatch, input_tensor))
+                output_tensors[master_chunk_id].append(output_tensor)
@@ -659,13 +662,18 @@ def forward_backward_pipelining_with_cutinhalf(
             for req, req_handle in fwd_wait_handles_send.items():
                 if req_handle is not None:
                     req_handle.wait()
-            deallocate_output_tensor(
-                output_tensor_send, config.deallocate_pipeline_outputs)
-            fwd_wait_handles_send = None
+            fwd_wait_handles_send = None
+            if not forward_only:
+                deallocate_output_tensor(
+                    output_tensor_send, config.deallocate_pipeline_outputs)
         if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
-            input_tensor_slave_chunk = output_tensor.detach()
-            input_tensor_slave_chunk.requires_grad = True
+            if not forward_only:
+                input_tensor_slave_chunk = output_tensor.detach()
+                input_tensor_slave_chunk.requires_grad = True
+            else:
+                input_tensor_slave_chunk = output_tensor
             input_tensor, fwd_wait_handles = recv_forward(
                 tensor_shape, config, master_chunk_id, async_op=True)
@@ -680,17 +688,21 @@ def forward_backward_pipelining_with_cutinhalf(
             for req, req_handle in fwd_wait_handles_warmup.items():
                 if req_handle is not None:
                     req_handle.wait()
-            deallocate_output_tensor(
-                output_tensor_warmup, config.deallocate_pipeline_outputs)
-            fwd_wait_handles_warmup = None
+            fwd_wait_handles_warmup = None
+            if not forward_only:
+                deallocate_output_tensor(
+                    output_tensor_warmup, config.deallocate_pipeline_outputs)
         if fwd_wait_handles_slave_chunk is not None:
             for req, req_handle in fwd_wait_handles_slave_chunk.items():
                 if req_handle is not None:
                     req_handle.wait()
-            deallocate_output_tensor(
-                output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-            fwd_wait_handles_slave_chunk = None
+            fwd_wait_handles_slave_chunk = None
+            if not forward_only:
+                deallocate_output_tensor(
+                    output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
         set_dualpipe_chunk(slave_chunk_id)
         output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
@@ -707,15 +719,16 @@ def forward_backward_pipelining_with_cutinhalf(
                 current_microbatch=slave_cur_microbatch,
             )
-            input_tensors[slave_chunk_id].append(
-                (slave_cur_microbatch, input_tensor_slave_chunk))
-            total_num_tokens += num_tokens.item()
-            output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
+            total_num_tokens += num_tokens.item()
+            if not forward_only:
+                input_tensors[slave_chunk_id].append(
+                    (slave_cur_microbatch, input_tensor_slave_chunk))
+                output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
             slave_cur_microbatch += 1
             if i == schedule['interleaved_forward'][rank] - 1:
                 firstFB_no_overlp = False
                 firstFB_no_overlp_handle = None
                 # last rank not overlap first F&B
@@ -749,7 +762,7 @@ def forward_backward_pipelining_with_cutinhalf(
         for _ in range(schedule['1b1w1f'][rank]):
             # WeightGradStore.start_decouple()
-            input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
+            if not forward_only:
+                input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
+                output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
@@ -762,17 +775,21 @@ def forward_backward_pipelining_with_cutinhalf(
             if fwd_wait_handles_slave_chunk is not None:
                 for req in fwd_wait_handles_slave_chunk:
                     req.wait()
-                deallocate_output_tensor(
-                    output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-                fwd_wait_handles_slave_chunk = None
+                fwd_wait_handles_slave_chunk = None
+                if not forward_only:
+                    deallocate_output_tensor(
+                        output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
             if fwd_wait_handles_send is not None:
                 for req, req_handle in fwd_wait_handles_send.items():
                     if req_handle is not None:
                         req_handle.wait()
-                deallocate_output_tensor(
-                    output_tensor, config.deallocate_pipeline_outputs)
-                fwd_wait_handles_send = None
+                fwd_wait_handles_send = None
+                if not forward_only:
+                    deallocate_output_tensor(
+                        output_tensor, config.deallocate_pipeline_outputs)
-            # If asynchronous, the memory will rise.
-            bwd_wait_handles = send_backward(input_tensor_grad,
-                tensor_shape, config, slave_chunk_id)
+            if not forward_only:
+                # If asynchronous, the memory will rise.
+                bwd_wait_handles = send_backward(input_tensor_grad,
+                    tensor_shape, config, slave_chunk_id)
@@ -806,13 +823,15 @@ def forward_backward_pipelining_with_cutinhalf(
                 current_microbatch=slave_cur_microbatch
             )
-            input_tensors[slave_chunk_id].append(
-                (slave_cur_microbatch, input_tensor_slave_chunk))
-            total_num_tokens += num_tokens.item()
-            output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
+            total_num_tokens += num_tokens.item()
+            if not forward_only:
+                input_tensors[slave_chunk_id].append(
+                    (slave_cur_microbatch, input_tensor_slave_chunk))
+                output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
             slave_cur_microbatch += 1
-            output_tensor_grad_bwd, _ = recv_backward(
-                tensor_shape, config, slave_chunk_id)
+            if not forward_only:
+                output_tensor_grad_bwd, _ = recv_backward(
+                    tensor_shape, config, slave_chunk_id)
@@ -847,9 +866,12 @@ def forward_backward_pipelining_with_cutinhalf(
                 checkpoint_activations_microbatch,
                 current_microbatch=fwd_microbatch
             )
-            input_tensors[fwd_model_chunk_id].append(
-                (fwd_microbatch, input_tensor))
-            total_num_tokens += num_tokens.item()
-            output_tensors[fwd_model_chunk_id].append(output_tensor)
+            total_num_tokens += num_tokens.item()
+            if not forward_only:
+                input_tensors[fwd_model_chunk_id].append(
+                    (fwd_microbatch, input_tensor))
+                output_tensors[fwd_model_chunk_id].append(output_tensor)
             if fwd_model_chunk_id == master_chunk_id:
@@ -857,26 +879,29 @@ def forward_backward_pipelining_with_cutinhalf(
                 fwd_send_only = False
             else:
                 slave_cur_microbatch += 1
                 fwd_send_only = (master_cur_microbatch ==
                                  master_microbatch_max)
             # Synchronize the last slave forward send from the previous phase
             if fwd_wait_handles_slave_chunk is not None:
                 for req, req_handle in fwd_wait_handles_slave_chunk.items():
                     if req_handle is not None:
                         req_handle.wait()
-                deallocate_output_tensor(
-                    output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-                fwd_wait_handles_slave_chunk = None
+                fwd_wait_handles_slave_chunk = None
+                if not forward_only:
+                    deallocate_output_tensor(
+                        output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
             if fwd_send_only:
                 fwd_wait_handles = send_forward(
                     output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
             else:
                 if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                    input_tensor = output_tensor.detach()
-                    input_tensor.requires_grad = True
-                    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+                    if not forward_only:
+                        input_tensor = output_tensor.detach()
+                        input_tensor.requires_grad = True
+                        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+                    else:
+                        input_tensor = output_tensor
                 else:
                     input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
                         output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
@@ -887,6 +912,7 @@ def forward_backward_pipelining_with_cutinhalf(
                     req_handle.wait()
                 firstFB_no_overlp_handle = None
-            if bwd_wait_handles is not None:
-                for req, req_handle in bwd_wait_handles.items():
-                    if req_handle is not None:
+            if not forward_only:
+                if bwd_wait_handles is not None:
+                    for req, req_handle in bwd_wait_handles.items():
+                        if req_handle is not None:
@@ -907,9 +933,11 @@ def forward_backward_pipelining_with_cutinhalf(
                     if req_handle is not None:
                         req_handle.wait()
                 fwd_wait_handles = None
-                deallocate_output_tensor(
-                    output_tensor, config.deallocate_pipeline_outputs)
-            if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                output_tensor_grad_bwd = input_tensor_grad
-            else:
+                if not forward_only:
+                    deallocate_output_tensor(
+                        output_tensor, config.deallocate_pipeline_outputs)
+            if not forward_only:
+                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
+                    output_tensor_grad_bwd = input_tensor_grad
+                else:
@@ -922,6 +950,7 @@ def forward_backward_pipelining_with_cutinhalf(
             if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
                 input_tensor, _ = recv_forward(
                     tensor_shape, config, slave_chunk_id)
-            if bwd_wait_handles is not None:
-                for req, req_handle in bwd_wait_handles.items():
-                    if req_handle is not None:
+            if not forward_only:
+                if bwd_wait_handles is not None:
+                    for req, req_handle in bwd_wait_handles.items():
+                        if req_handle is not None:
@@ -940,12 +969,13 @@ def forward_backward_pipelining_with_cutinhalf(
                     output_tensor_grad_bwd = input_tensor_grad
                 else:
                     # send_backward_recv_slave_backward
-                    output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
-                        tensor_shape, config, fwd_model_chunk_id)
+                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
+                        tensor_shape, config, fwd_model_chunk_id)
             # swap fwd & bwd chunks
             fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
-    # Run cooldown phases
-    merged_input_tensors = []
-    merged_output_tensors = []
+    if not forward_only:
+        # Run cooldown phases
+        merged_input_tensors = []
+        merged_output_tensors = []
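With these guards in place, the dualpipev schedule can be driven in evaluation mode simply by passing forward_only=True. A minimal usage sketch, assuming the schedule is resolved through a Megatron-style get_forward_backward_func() and keeps the usual Megatron-LM argument names (none of the call-site names below are confirmed by this diff):

# Hypothetical evaluation-mode call; everything except forward_only and
# first_val_step is assumed from the standard Megatron-LM schedule API.
forward_backward_func = get_forward_backward_func()  # expected to resolve to the dualpipev schedule
loss_dicts = forward_backward_func(
    forward_step_func=forward_step_func,
    data_iterator=val_data_iterator,
    model=model,
    num_microbatches=num_microbatches,
    seq_length=seq_length,
    micro_batch_size=micro_batch_size,
    forward_only=True,    # evaluation: skip grad bookkeeping, deallocation, and backward sends
    first_val_step=True,  # threaded into check_first_val_step(...)
)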