Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
a58a2da6
Commit
a58a2da6
authored
Jun 04, 2025
by
dongcl
Browse files
if forward_only is true, recv_backward should not be called
parent
2b81ee55
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
168 additions
and
22 deletions
+168
-22
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py
...n/core/pipeline_parallel/dualpipev/dualpipev_schedules.py
+33
-20
dcu_megatron/training/training.py
dcu_megatron/training/training.py
+126
-2
dcu_megatron/training/utils.py
dcu_megatron/training/utils.py
+9
-0
No files found.
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py
View file @
a58a2da6
...
@@ -19,7 +19,12 @@ from megatron.core.utils import (
...
@@ -19,7 +19,12 @@ from megatron.core.utils import (
from
megatron.core.pipeline_parallel.schedules
import
clear_embedding_activation_buffer
,
deallocate_output_tensor
from
megatron.core.pipeline_parallel.schedules
import
clear_embedding_activation_buffer
,
deallocate_output_tensor
from
megatron.core
import
ModelParallelConfig
from
megatron.core
import
ModelParallelConfig
from
megatron.core.pipeline_parallel.p2p_communication
import
_communicate
from
megatron.core.pipeline_parallel.p2p_communication
import
_communicate
from
megatron.core.pipeline_parallel.schedules
import
backward_step
,
set_current_microbatch
,
finish_embedding_wgrad_compute
from
megatron.core.pipeline_parallel.schedules
import
(
backward_step
,
set_current_microbatch
,
check_first_val_step
,
finish_embedding_wgrad_compute
)
# from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
# from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
...
@@ -114,7 +119,7 @@ def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelPa
...
@@ -114,7 +119,7 @@ def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelPa
return
reqs
return
reqs
def
recv_forward
(
tensor_shape
:
Shape
,
config
:
ModelParallelConfig
,
model_chunk_id
,
async_op
=
False
,
step
=-
1
)
->
torch
.
Tensor
:
def
recv_forward
(
tensor_shape
:
Shape
,
config
:
ModelParallelConfig
,
model_chunk_id
,
async_op
=
False
)
->
torch
.
Tensor
:
""" Receive tensor from previous rank in pipeline (forward receive).
""" Receive tensor from previous rank in pipeline (forward receive).
See _communicate for argument details.
See _communicate for argument details.
...
@@ -565,9 +570,10 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -565,9 +570,10 @@ def forward_backward_pipelining_with_cutinhalf(
tensor_shape
[
0
]
=
tensor_shape
[
0
]
//
parallel_state
.
get_tensor_model_parallel_world_size
()
tensor_shape
[
0
]
=
tensor_shape
[
0
]
//
parallel_state
.
get_tensor_model_parallel_world_size
()
total_num_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
int
).
cuda
()
total_num_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
int
).
cuda
()
input_tensors
=
[[],
[]]
output_tensors
=
[[],
[]]
forward_data_store
=
[]
forward_data_store
=
[]
if
not
forward_only
:
input_tensors
=
[[],
[]]
output_tensors
=
[[],
[]]
master_chunk_id
=
0
master_chunk_id
=
0
slave_chunk_id
=
1
slave_chunk_id
=
1
...
@@ -728,17 +734,16 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -728,17 +734,16 @@ def forward_backward_pipelining_with_cutinhalf(
slave_cur_microbatch
+=
1
slave_cur_microbatch
+=
1
if
i
==
schedule
[
'interleaved_forward'
][
rank
]
-
1
:
if
not
forward_only
:
firstFB_no_overlp
=
False
if
i
==
schedule
[
'interleaved_forward'
][
rank
]
-
1
:
firstFB_no_overlp_handle
=
None
firstFB_no_overlp_handle
=
None
# last rank not overlap first F&B
# last rank not overlap first F&B
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
firstFB_no_overlp
=
True
output_tensor_grad_bwd
,
firstFB_no_overlp_handle
=
recv_backward
(
output_tensor_grad_bwd
,
firstFB_no_overlp_handle
=
recv_backward
(
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
else
:
else
:
output_tensor_grad_bwd
,
_
=
recv_backward
(
output_tensor_grad_bwd
,
_
=
recv_backward
(
tensor_shape
,
config
,
slave_chunk_id
)
tensor_shape
,
config
,
slave_chunk_id
)
fwd_wait_handles_slave_chunk
=
send_forward
(
output_tensor_slave_chunk
,
fwd_wait_handles_slave_chunk
=
send_forward
(
output_tensor_slave_chunk
,
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
...
@@ -838,11 +843,14 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -838,11 +843,14 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_wait_handles_slave_chunk
=
send_forward
(
output_tensor_slave_chunk
,
fwd_wait_handles_slave_chunk
=
send_forward
(
output_tensor_slave_chunk
,
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
fwd_wait_handles_recv
=
None
# Run overlaping f&bw stages
# Run overlaping f&bw stages
fwd_wait_handles_recv
=
None
fwd_model_chunk_id
=
master_chunk_id
fwd_model_chunk_id
=
master_chunk_id
bwd_model_chunk_id
=
slave_chunk_id
bwd_model_chunk_id
=
slave_chunk_id
for
step_id
in
range
(
schedule
[
'overlap'
][
rank
]
+
schedule
[
'1b1overlap'
][
rank
]
+
schedule
[
'interleaved_backward'
][
rank
]):
num_overlap_steps
=
schedule
[
'overlap'
][
rank
]
+
schedule
[
'1b1overlap'
][
rank
]
if
not
forward_only
:
num_overlap_steps
+=
schedule
[
'interleaved_backward'
][
rank
]
for
step_id
in
range
(
num_overlap_steps
):
only_bwd
=
False
only_bwd
=
False
if
fwd_model_chunk_id
==
master_chunk_id
and
master_cur_microbatch
==
master_microbatch_max
:
if
fwd_model_chunk_id
==
master_chunk_id
and
master_cur_microbatch
==
master_microbatch_max
:
only_bwd
=
True
only_bwd
=
True
...
@@ -853,6 +861,11 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -853,6 +861,11 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_microbatch
=
master_cur_microbatch
if
fwd_model_chunk_id
==
master_chunk_id
else
slave_cur_microbatch
fwd_microbatch
=
master_cur_microbatch
if
fwd_model_chunk_id
==
master_chunk_id
else
slave_cur_microbatch
set_dualpipe_chunk
(
fwd_model_chunk_id
)
set_dualpipe_chunk
(
fwd_model_chunk_id
)
if
fwd_wait_handles_recv
is
not
None
:
for
req
,
req_handle
in
fwd_wait_handles_recv
.
items
():
req_handle
.
wait
()
fwd_wait_handles_recv
=
None
output_tensor
,
num_tokens
=
forward_step_no_model_graph
(
output_tensor
,
num_tokens
=
forward_step_no_model_graph
(
forward_step_func
,
forward_step_func
,
fwd_model_chunk_id
,
fwd_model_chunk_id
,
...
@@ -906,7 +919,7 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -906,7 +919,7 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensor
,
fwd_wait_handles
=
send_forward_recv_slave_forward
(
input_tensor
,
fwd_wait_handles
=
send_forward_recv_slave_forward
(
output_tensor
,
tensor_shape
,
config
,
fwd_model_chunk_id
,
async_op
=
True
)
output_tensor
,
tensor_shape
,
config
,
fwd_model_chunk_id
,
async_op
=
True
)
if
firstFB_no_overlp_handle
is
not
None
:
if
not
forward_only
and
firstFB_no_overlp_handle
is
not
None
:
for
req
,
req_handle
in
firstFB_no_overlp_handle
.
items
():
for
req
,
req_handle
in
firstFB_no_overlp_handle
.
items
():
if
req_handle
is
not
None
:
if
req_handle
is
not
None
:
req_handle
.
wait
()
req_handle
.
wait
()
...
@@ -948,8 +961,8 @@ def forward_backward_pipelining_with_cutinhalf(
...
@@ -948,8 +961,8 @@ def forward_backward_pipelining_with_cutinhalf(
# only run backward
# only run backward
else
:
else
:
if
bwd_model_chunk_id
==
slave_chunk_id
and
slave_cur_microbatch
<
slave_microbatch_max
:
if
bwd_model_chunk_id
==
slave_chunk_id
and
slave_cur_microbatch
<
slave_microbatch_max
:
input_tensor
,
_
=
recv_forward
(
input_tensor
,
fwd_wait_handles_recv
=
recv_forward
(
tensor_shape
,
config
,
slave_chunk_id
)
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
if
not
forward_only
:
if
not
forward_only
:
if
bwd_wait_handles
is
not
None
:
if
bwd_wait_handles
is
not
None
:
for
req
,
req_handle
in
bwd_wait_handles
.
items
():
for
req
,
req_handle
in
bwd_wait_handles
.
items
():
...
...
dcu_megatron/training/training.py
View file @
a58a2da6
...
@@ -20,12 +20,16 @@ from megatron.core.num_microbatches_calculator import (
...
@@ -20,12 +20,16 @@ from megatron.core.num_microbatches_calculator import (
get_current_global_batch_size
,
get_current_global_batch_size
,
get_current_running_global_batch_size
,
get_current_running_global_batch_size
,
get_num_microbatches
,
get_num_microbatches
,
update_num_microbatches
)
update_num_microbatches
,
)
from
megatron.training.async_utils
import
maybe_finalize_async_save
from
megatron.training.async_utils
import
maybe_finalize_async_save
from
megatron.training.utils
import
(
from
megatron.training.utils
import
(
calc_params_l2_norm
,
calc_params_l2_norm
,
print_rank_0
,
print_rank_0
,
logical_and_across_model_parallel_group
,
reduce_max_stat_across_model_parallel_group
,
unwrap_model
,
)
)
from
megatron.training.global_vars
import
(
from
megatron.training.global_vars
import
(
get_args
,
get_args
,
...
@@ -41,7 +45,6 @@ from megatron.training.training import (
...
@@ -41,7 +45,6 @@ from megatron.training.training import (
print_datetime
,
print_datetime
,
should_disable_forward_pre_hook
,
should_disable_forward_pre_hook
,
disable_forward_pre_hook
,
disable_forward_pre_hook
,
train_step
,
save_checkpoint_and_time
,
save_checkpoint_and_time
,
enable_forward_pre_hook
,
enable_forward_pre_hook
,
num_floating_point_operations
,
num_floating_point_operations
,
...
@@ -49,7 +52,12 @@ from megatron.training.training import (
...
@@ -49,7 +52,12 @@ from megatron.training.training import (
evaluate_and_print_results
,
evaluate_and_print_results
,
post_training_step_callbacks
,
post_training_step_callbacks
,
checkpoint_and_decide_exit
,
checkpoint_and_decide_exit
,
cuda_graph_capture
,
cuda_graph_set_manual_hooks
,
dummy_train_step
,
)
)
from
megatron.core.pipeline_parallel
import
get_forward_backward_func
stimer
=
StragglerDetector
()
stimer
=
StragglerDetector
()
...
@@ -77,6 +85,122 @@ def build_train_valid_test_data_iterators_wrapper(build_train_valid_test_data_it
...
@@ -77,6 +85,122 @@ def build_train_valid_test_data_iterators_wrapper(build_train_valid_test_data_it
return
wrapper
return
wrapper
def
train_step
(
forward_step_func
,
data_iterator
,
model
,
optimizer
,
opt_param_scheduler
,
config
):
"""Single training step."""
args
=
get_args
()
timers
=
get_timers
()
# CUDA Graph capturing only executes once, when it's the first training iteration.
if
args
.
curr_iteration
==
args
.
iteration
and
args
.
external_cuda_graph
:
cuda_graph_capture
(
model
,
config
,
args
)
# Set grad to zero.
for
model_chunk
in
model
:
model_chunk
.
zero_grad_buffer
()
optimizer
.
zero_grad
()
# Collect garbage and empty unused memory.
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
rerun_state_machine
=
get_rerun_state_machine
()
while
rerun_state_machine
.
should_run_forward_backward
(
data_iterator
):
# Set grad to zero.
for
model_chunk
in
model
:
model_chunk
.
zero_grad_buffer
()
optimizer
.
zero_grad
()
# Forward pass.
forward_backward_func
=
get_forward_backward_func
()
losses_reduced
=
forward_backward_func
(
forward_step_func
=
forward_step_func
,
data_iterator
=
data_iterator
,
model
=
model
,
num_microbatches
=
get_num_microbatches
(),
seq_length
=
args
.
seq_length
,
micro_batch_size
=
args
.
micro_batch_size
,
decoder_seq_length
=
args
.
decoder_seq_length
,
forward_only
=
False
)
should_checkpoint
,
should_exit
,
exit_code
=
rerun_state_machine
.
should_checkpoint_and_exit
()
if
should_exit
:
return
{},
True
,
should_checkpoint
,
should_exit
,
exit_code
,
None
,
None
# Empty unused memory.
if
args
.
empty_unused_memory_level
>=
1
:
torch
.
cuda
.
empty_cache
()
# Vision gradients.
if
args
.
vision_pretraining
and
args
.
vision_pretraining_type
==
"dino"
:
unwrapped_model
=
unwrap_model
(
model
[
0
])
unwrapped_model
.
cancel_gradients_last_layer
(
args
.
curr_iteration
)
# Update parameters.
timers
(
'optimizer'
,
log_level
=
1
).
start
(
barrier
=
args
.
barrier_with_L1_time
)
update_successful
,
grad_norm
,
num_zeros_in_grad
=
optimizer
.
step
()
timers
(
'optimizer'
).
stop
()
# when freezing sub-models we may have a mixture of successful and unsucessful ranks,
# so we must gather across mp ranks
update_successful
=
logical_and_across_model_parallel_group
(
update_successful
)
# grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
# so we must gather across mp ranks
grad_norm
=
reduce_max_stat_across_model_parallel_group
(
grad_norm
)
if
args
.
log_num_zeros_in_grad
:
num_zeros_in_grad
=
reduce_max_stat_across_model_parallel_group
(
num_zeros_in_grad
)
# Vision momentum.
if
args
.
vision_pretraining
and
args
.
vision_pretraining_type
==
"dino"
:
unwrapped_model
=
unwrap_model
(
model
[
0
])
unwrapped_model
.
update_momentum
(
args
.
curr_iteration
)
# Update learning rate.
if
update_successful
:
increment
=
get_num_microbatches
()
*
\
args
.
micro_batch_size
*
\
args
.
data_parallel_size
opt_param_scheduler
.
step
(
increment
=
increment
)
skipped_iter
=
0
else
:
skipped_iter
=
1
# Empty unused memory.
if
args
.
empty_unused_memory_level
>=
2
:
torch
.
cuda
.
empty_cache
()
# Set the manual hooks when CUDA Graphs are enabled.
if
args
.
curr_iteration
==
args
.
iteration
and
args
.
external_cuda_graph
:
if
args
.
use_distributed_optimizer
and
args
.
overlap_param_gather
:
cuda_graph_set_manual_hooks
(
model
)
if
args
.
schedule_method
==
'dualpipev'
:
is_last_stage
=
mpu
.
is_pipeline_first_stage
(
ignore_virtual
=
True
)
else
:
is_last_stage
=
mpu
.
is_pipeline_last_stage
(
ignore_virtual
=
True
)
if
is_last_stage
:
# Average loss across microbatches.
loss_reduced
=
{}
for
key
in
losses_reduced
[
0
].
keys
():
numerator
=
0
denominator
=
0
for
x
in
losses_reduced
:
val
=
x
[
key
]
# there is one dict per microbatch. in new reporting, we average
# over the total number of tokens across the global batch.
if
isinstance
(
val
,
tuple
)
or
isinstance
(
val
,
list
):
numerator
+=
val
[
0
]
denominator
+=
val
[
1
]
else
:
# legacy behavior. we average over the number of microbatches,
# and so the denominator is 1.
numerator
+=
val
denominator
+=
1
loss_reduced
[
key
]
=
numerator
/
denominator
return
loss_reduced
,
skipped_iter
,
should_checkpoint
,
should_exit
,
exit_code
,
grad_norm
,
num_zeros_in_grad
return
{},
skipped_iter
,
should_checkpoint
,
should_exit
,
exit_code
,
grad_norm
,
num_zeros_in_grad
def
train
(
forward_step_func
,
model
,
optimizer
,
opt_param_scheduler
,
def
train
(
forward_step_func
,
model
,
optimizer
,
opt_param_scheduler
,
train_data_iterator
,
valid_data_iterator
,
train_data_iterator
,
valid_data_iterator
,
process_non_loss_data_func
,
config
,
checkpointing_context
,
non_loss_data_func
):
process_non_loss_data_func
,
config
,
checkpointing_context
,
non_loss_data_func
):
...
...
dcu_megatron/training/utils.py
View file @
a58a2da6
...
@@ -4,6 +4,15 @@ from megatron.training import get_args
...
@@ -4,6 +4,15 @@ from megatron.training import get_args
from
megatron.core
import
mpu
from
megatron.core
import
mpu
def
print_rank_message
(
message
,
rank_id
=
0
):
"""If distributed is initialized, print only on rank specified by rank_id."""
if
torch
.
distributed
.
is_initialized
():
if
torch
.
distributed
.
get_rank
()
==
rank_id
:
print
(
f
"[rank
{
rank_id
}
]
{
message
}
"
,
flush
=
True
)
else
:
print
(
f
"[rank
{
rank_id
}
]
{
message
}
"
,
flush
=
True
)
def
get_batch_on_this_tp_rank
(
data_iterator
):
def
get_batch_on_this_tp_rank
(
data_iterator
):
args
=
get_args
()
args
=
get_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment