evt_fugx1 / dcu_megatron · Commits

Commit 8d5bae2a, authored May 29, 2025 by dongcl
Parent: e5f5eb4d

    add dualpipev_chunks to support dualpipev

Showing 2 changed files with 1248 additions and 0 deletions
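A reading of the code below (an inference from this commit, not something it states): dualpipev is a DualPipe-style "V" schedule in which the model is cut in half, each pipeline rank hosts two chunks, and the chunk order folds back on itself, so the first pipeline rank holds both the embedding chunk and the lm-head/loss chunk. A minimal sketch of that placement, with dualpipev_chunk_ids as a hypothetical helper name:

def dualpipev_chunk_ids(rank, pp_size):
    # Cutting the model into 2 * pp_size pieces arranged in a V means rank r
    # runs piece r on the way down and piece 2 * pp_size - 1 - r on the way
    # back up; rank 0 therefore owns both the embedding and the output head.
    return rank, 2 * pp_size - 1 - rank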
dcu_megatron/core/pipeline_parallel/dualpipev_schedules/dualpipev_chunks.py     +230   -0
dcu_megatron/core/pipeline_parallel/dualpipev_schedules/dualpipev_schedules.py  +1018  -0
dcu_megatron/core/pipeline_parallel/dualpipev_schedules/dualpipev_chunks.py  (new file, mode 0 → 100644) @ 8d5bae2a
import torch

from functools import wraps
from typing import List, Optional

from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config
from megatron.core.transformer.module import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType
from megatron.training.global_vars import get_args, get_timers
from megatron.training.utils import unwrap_model
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.transformer.module import fp32_to_float16, float16_to_fp32
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _allreduce_layernorm_grads
from dcu_megatron.core.pipeline_parallel.dualpipev_schedules.dualpipev_schedules import get_dualpipe_chunk

# The helpers below are used by train_step() but were missing from the original
# import list; these locations assume a recent Megatron-LM layout.
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.training.utils import (
    logical_and_across_model_parallel_group,
    reduce_max_stat_across_model_parallel_group,
)
def dualpipev_fp16forward(self, *inputs, **kwargs):
    dualpipe_first_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 0

    if dualpipe_first_stage:
        inputs = fp32_to_float16(inputs, self.float16_convertor)

    outputs = self.module(*inputs, **kwargs)

    # In dualpipev the pipeline folds back on itself, so the last model chunk
    # (chunk 1, which carries the lm head) also lives on the first pipeline rank.
    dualpipe_last_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 1

    if dualpipe_last_stage:
        outputs = float16_to_fp32(outputs)

    return outputs
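A sketch of how this forward is presumably installed (assumption: dcu_megatron monkeypatches Megatron's fp16/bf16 wrapper the way MindSpeed-style plugins commonly do; the patch site shown is illustrative, not taken from this commit):

# Route Float16Module forwards through the dualpipev-aware version, so chunk 0
# converts fp32 inputs down and chunk 1 converts outputs back to fp32.
Float16Module.forward = dualpipev_fp16forward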
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model."""
    args = get_args()
    args.model_type = model_type

    assert model_type != ModelType.encoder_and_decoder, \
        "dualpipev schedule not supported for model with both encoder and decoder"

    model = []

    # Chunk 0: the embedding-side half of the model.
    args.dualpipev_first_chunk = True
    first_model = model_provider_func(
        pre_process=mpu.is_pipeline_first_stage(),
        post_process=False
    )
    first_model.model_type = model_type
    model.append(first_model)

    # Chunk 1: the output-side half; post-processing happens on the first
    # pipeline rank because the dualpipev pipeline folds back on itself.
    args.dualpipev_first_chunk = False
    second_model = model_provider_func(
        pre_process=False,
        post_process=mpu.is_pipeline_first_stage()
    )
    second_model.model_type = model_type
    model.append(second_model)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum([sum([p.nelement() for p in model_module.parameters()])
                       for model_module in model])),
              flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if wrap_with_ddp:
        config = get_model_config(model[0])
        ddp_config = DistributedDataParallelConfig(
            grad_reduce_in_fp32=args.accumulate_allreduce_grads_in_fp32,
            overlap_grad_reduce=args.overlap_grad_reduce,
            use_distributed_optimizer=args.use_distributed_optimizer,
            check_for_nan_in_grad=args.check_for_nan_in_loss_and_grad,
            bucket_size=args.ddp_bucket_size,
            average_in_collective=args.ddp_average_in_collective)
        model = [DDP(config,
                     ddp_config,
                     model_chunk,
                     # Turn off bucketing for model_chunk 2 onwards, since communication for these
                     # model chunks is overlapped with compute anyway.
                     disable_bucketing=(model_chunk_idx > 0))
                 for (model_chunk_idx, model_chunk) in enumerate(model)]

        # Broadcast params from data parallel src rank to other data parallel ranks.
        if args.data_parallel_random_init:
            for model_module in model:
                model_module.broadcast_params()

    return model
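A hedged usage sketch (toy_model_provider below is a stand-in for a real GPT chunk builder, and the call assumes Megatron's args and parallel state are already initialized on a CUDA-capable rank):

import torch.nn as nn

def toy_model_provider(pre_process=True, post_process=True):
    # A real provider would build a GPT chunk; pre_process/post_process say
    # whether this chunk owns the embedding or the output head.
    return nn.Linear(8, 8)

chunks = get_model(toy_model_provider, ModelType.encoder_or_decoder, wrap_with_ddp=False)
assert len(chunks) == 2  # chunk 0 (embedding side) and chunk 1 (output side)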
def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    rerun_state_machine = get_rerun_state_machine()
    while rerun_state_machine.should_run_forward_backward(data_iterator):
        # Set grad to zero.
        for model_chunk in model:
            model_chunk.zero_grad_buffer()
        optimizer.zero_grad()

        # Forward pass.
        forward_backward_func = get_forward_backward_func()
        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=data_iterator,
            model=model,
            num_microbatches=get_num_microbatches(),
            seq_length=args.seq_length,
            micro_batch_size=args.micro_batch_size,
            decoder_seq_length=args.decoder_seq_length,
            forward_only=False)
    should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
    if should_exit:
        return {}, True, should_checkpoint, should_exit, exit_code, None, None

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Vision gradients.
    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

    # Update parameters.
    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # When freezing sub-models we may have a mixture of successful and unsuccessful ranks,
    # so we must gather across mp ranks.
    update_successful = logical_and_across_model_parallel_group(update_successful)
    # grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
    # so we must gather across mp ranks.
    grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
    if args.log_num_zeros_in_grad:
        num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)

    # Vision momentum.
    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.update_momentum(args.curr_iteration)

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        opt_param_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    # Under dualpipev the loss is computed on the first pipeline stage (chunk 1).
    if mpu.is_pipeline_first_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0].keys():
            numerator = 0
            denominator = 0
            for x in losses_reduced:
                val = x[key]
                # There is one dict per microbatch. In new reporting, we average
                # over the total number of tokens across the global batch.
                if isinstance(val, tuple) or isinstance(val, list):
                    numerator += val[0]
                    denominator += val[1]
                else:
                    # Legacy behavior: we average over the number of microbatches,
                    # and so the denominator is 1.
                    numerator += val
                    denominator += 1
            loss_reduced[key] = numerator / denominator
        return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
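A small worked example of the loss reduction above (values invented): with microbatch entries {'lm loss': (4.0, 2)} and {'lm loss': (6.0, 3)}, the token-weighted path gives (4.0 + 6.0) / (2 + 3) = 2.0, whereas plain scalar entries fall back to the per-microbatch mean with denominator 1 each.

losses_reduced = [{'lm loss': (4.0, 2)}, {'lm loss': (6.0, 3)}]
# numerator = 4.0 + 6.0, denominator = 2 + 3  ->  loss_reduced['lm loss'] == 2.0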
def get_num_layers_to_build(config: TransformerConfig) -> int:
    num_layers_per_pipeline_rank = (
        config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
    )
    # Each pipeline rank hosts two dualpipev chunks, so each chunk builds half
    # of the rank's layers.
    num_layers_to_build = num_layers_per_pipeline_rank // 2

    return num_layers_to_build
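For example, num_layers=32 with a pipeline-parallel world size of 4 yields 32 // 4 = 8 layers per rank, and each of the rank's two dualpipev chunks builds 8 // 2 = 4 layers.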
def _allreduce_embedding_grads_wrapper(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if get_args().schedules_method == 'dualpipev':
            # dualpipev has no need for the embedding allreduce:
            # the embedding and the lm head live on the same rank.
            if not get_args().untie_embeddings_and_output_weights:
                raise NotImplementedError
            else:
                return
        else:
            return fn(*args, **kwargs)

    return wrapper
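A sketch of applying the wrapper (assumption: dcu_megatron patches Megatron's finalize_model_grads module during initialization; the patch below is illustrative rather than taken from this commit):

import megatron.core.distributed.finalize_model_grads as finalize_model_grads

# Skip the embedding-grad all-reduce entirely under the dualpipev schedule.
finalize_model_grads._allreduce_embedding_grads = _allreduce_embedding_grads_wrapper(
    finalize_model_grads._allreduce_embedding_grads)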
dcu_megatron/core/pipeline_parallel/dualpipev.py → dcu_megatron/core/pipeline_parallel/dualpipev_schedules/dualpipev_schedules.py  (renamed) @ 8d5bae2a
@@ -21,8 +21,6 @@ from megatron.core import ModelParallelConfig
 from megatron.core.pipeline_parallel.p2p_communication import _communicate
 from megatron.core.pipeline_parallel.schedules import backward_step, set_current_microbatch, custom_backward, finish_embedding_wgrad_compute
 from megatron.core.models.gpt import GPTModel
-from mindspeed.core.pipeline_parallel.fb_overlap.gpt_model import gpt_model_backward
-from mindspeed.core.pipeline_parallel.fb_overlap.transformer_layer import P2PCommParams
 from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
@@ -34,10 +32,10 @@ LOSS_BACKWARD_SCALE = torch.tensor(1.0)
 _DUALPIPE_CHUNK = None


-def set_dualpipe_chunk(chunkid):
+def set_dualpipe_chunk(chunk_id):
     """set_dualpipe_chunk for fp16forward patch"""
     global _DUALPIPE_CHUNK
-    _DUALPIPE_CHUNK = chunkid
+    _DUALPIPE_CHUNK = chunk_id


 def get_dualpipe_chunk():
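Usage sketch, consistent with the scheduler later in this file, which calls set_dualpipe_chunk before each chunk's compute so the fp16 patch can distinguish first-chunk inputs from last-chunk outputs:

set_dualpipe_chunk(0)   # about to run the embedding-side chunk
assert get_dualpipe_chunk() == 0
set_dualpipe_chunk(1)   # about to run the lm-head-side chunk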
@@ -48,7 +46,7 @@ def get_dualpipe_chunk():
         raise AssertionError("_DUALPIPE_CHUNK is None")


-def is_dualpipev_last_stgae(model_chunk_id):
+def is_dualpipev_last_stage(model_chunk_id):
     return parallel_state.is_pipeline_first_stage(ignore_virtual=True) and model_chunk_id == 1
@@ -59,11 +57,11 @@ def send_forward(output_tensor: torch.Tensor, tensor_shape, config: ModelParallelConfig, ...
     """
     tensor_send_next, tensor_send_prev = None, None
     if model_chunk_id == 0:
-        if parallel_state.is_pipeline_last_stage():
+        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
             return None
         tensor_send_next = output_tensor
     else:
-        if parallel_state.is_pipeline_first_stage():
+        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
             return None
         tensor_send_prev = output_tensor
@@ -93,11 +91,11 @@ def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelParallelConfig, ...
     tensor_send_next, tensor_send_prev = None, None
     if model_chunk_id == 0:
-        if parallel_state.is_pipeline_first_stage():
+        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
             return None
         tensor_send_prev = input_tensor_grad
     else:
-        if parallel_state.is_pipeline_last_stage():
+        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
             return None
         tensor_send_next = input_tensor_grad
@@ -128,7 +126,10 @@ def recv_forward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, ...
     else:
         recv_next = True

-    if (parallel_state.is_pipeline_first_stage() and recv_prev) or (parallel_state.is_pipeline_last_stage() and recv_next):
+    if (
+        (parallel_state.is_pipeline_first_stage(ignore_virtual=True) and recv_prev)
+        or (parallel_state.is_pipeline_last_stage(ignore_virtual=True) and recv_next)
+    ):
         fwd_wait_handles = None
         return None, fwd_wait_handles
     else:
@@ -163,7 +164,10 @@ def recv_backward(tensor_shape: Shape, config: ModelParallelConfig, model_chunk_id, ...
     else:
         recv_prev = True

-    if (parallel_state.is_pipeline_first_stage() and recv_prev) or (parallel_state.is_pipeline_last_stage() and recv_next):
+    if (
+        (parallel_state.is_pipeline_first_stage(ignore_virtual=True) and recv_prev)
+        or (parallel_state.is_pipeline_last_stage(ignore_virtual=True) and recv_next)
+    ):
         output_tensor_grad = None
         bwd_wait_handles = None
         return output_tensor_grad, bwd_wait_handles
@@ -203,14 +207,14 @@ def send_forward_recv_forward(
     recv_prev, recv_next = False, False
     tensor_send_next, tensor_send_prev = None, None
     if model_chunk_id == 0:
-        if not parallel_state.is_pipeline_last_stage():
+        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
             tensor_send_next = output_tensor
-        if not parallel_state.is_pipeline_first_stage():
+        if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
             recv_prev = True
     if model_chunk_id == 1:
-        if not parallel_state.is_pipeline_first_stage():
+        if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
             tensor_send_prev = output_tensor
-        if not parallel_state.is_pipeline_last_stage():
+        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
             recv_next = True

     if config.timers is not None:
@@ -228,22 +232,23 @@ def send_forward_recv_forward(
         config.timers('forward-send-forward-recv').stop()

     if model_chunk_id == 0:
-        if not parallel_state.is_pipeline_first_stage():
+        if not parallel_state.is_pipeline_first_stage(ignore_virtual=True):
             return tensor_recv_prev, fwd_wait_handles
         else:
             return None, fwd_wait_handles
     else:
-        if not parallel_state.is_pipeline_last_stage():
+        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
             return tensor_recv_next, fwd_wait_handles
         else:
             return None, fwd_wait_handles


+# TODO (dongcl)
 def send_forward_recv_slave_forward(
     output_tensor: torch.Tensor,
     tensor_shape: Shape,
     config: ModelParallelConfig,
-    model_chunk_id,
+    fwd_model_chunk_id,
     async_op=False,
 ) -> torch.Tensor:
     """Batched recv from previous rank and send to next rank in pipeline.
@@ -251,13 +256,13 @@ def send_forward_recv_slave_forward(
     """
     recv_prev, recv_next = False, False
     tensor_send_next, tensor_send_prev = None, None
-    if model_chunk_id == 0:
-        if parallel_state.is_pipeline_last_stage():
+    if fwd_model_chunk_id == 0:
+        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
             return None, None
         tensor_send_next = output_tensor
         recv_next = True
-    if model_chunk_id == 1:
-        if parallel_state.is_pipeline_first_stage():
+    if fwd_model_chunk_id == 1:
+        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
             return None, None
         tensor_send_prev = output_tensor
         recv_prev = True
@@ -275,7 +280,49 @@ def send_forward_recv_slave_forward(
     if config.timers is not None:
         config.timers('forward-send-slave-forward-recv').stop()

-    if model_chunk_id == 0:
+    if fwd_model_chunk_id == 0:
         return tensor_recv_next, fwd_wait_handles
     else:
         return tensor_recv_prev, fwd_wait_handles


+def send_backward_recv_slave_backward(
+    input_tensor_grad: torch.Tensor,
+    tensor_shape: Shape,
+    config: ModelParallelConfig,
+    fwd_model_chunk_id,
+    async_op=False,
+) -> torch.Tensor:
+    """Batched recv from previous rank and send to next rank in pipeline.
+
+    See _communicate for argument details.
+    """
+    recv_prev, recv_next = False, False
+    tensor_send_next, tensor_send_prev = None, None
+    if fwd_model_chunk_id == 0:
+        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+            return None, None
+        tensor_send_next = input_tensor_grad
+        recv_next = True
+    if fwd_model_chunk_id == 1:
+        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+            return None, None
+        tensor_send_prev = input_tensor_grad
+        recv_prev = True
+
+    if config.timers is not None:
+        config.timers('forward-send-slave-forward-recv', log_level=2).start()
+    tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
+        tensor_send_next=tensor_send_next,
+        tensor_send_prev=tensor_send_prev,
+        recv_prev=recv_prev,
+        recv_next=recv_next,
+        tensor_shape=tensor_shape,
+        wait_on_reqs=(not async_op),
+        config=config,
+    )
+    if config.timers is not None:
+        config.timers('forward-send-slave-forward-recv').stop()
+
+    if fwd_model_chunk_id == 0:
+        return tensor_recv_next, fwd_wait_handles
+    else:
+        return tensor_recv_prev, fwd_wait_handles
@@ -320,38 +367,6 @@ def generate_dualpipev_schedule(pp_size, num_microbatches):
     return schedule_all_stages


-def pretrain_gpt_forward_step_dualpipe(data_iterator, model: GPTModel, extra_block_kwargs=None):
-    from megatron.training import get_timers
-    from functools import partial
-    from pretrain_gpt import get_batch, loss_func
-
-    """Forward training step.
-
-    Args:
-        data_iterator : Input data iterator
-        model (GPTModel): The GPT Model
-    """
-    timers = get_timers()
-
-    # Get the batch.
-    timers('batch-generator', log_level=2).start()
-    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
-    timers('batch-generator').stop()
-
-    if extra_block_kwargs is not None:
-        # execute forward-backward overlapping
-        output_tensor, model_graph, pp_comm_output = \
-            model(tokens, position_ids, attention_mask, labels=labels,
-                  extra_block_kwargs=extra_block_kwargs)
-        return (output_tensor, model_graph, pp_comm_output), partial(loss_func, loss_mask)
-    else:
-        output_tensor, model_graph = model(tokens, position_ids, attention_mask, labels=labels)
-        return (output_tensor, model_graph), partial(loss_func, loss_mask)


 def forward_step_no_model_graph(
     forward_step_func,
     model_chunk_id,
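The deleted helper above returned (output, partial(loss_func, loss_mask)), which is the reducer contract forward_step_no_model_graph relies on in the next hunk: the reducer yields either a 2-tuple (legacy per-microbatch averaging) or a 3-tuple that adds a token count. A hedged toy reducer showing both shapes (names invented):

def toy_loss_func(loss_mask, output_tensor, per_token=True):
    # Mirrors the contract branched on via len(outputs) below.
    loss = (output_tensor * loss_mask).sum()
    num_tokens = loss_mask.sum().int()
    if per_token:
        return loss, num_tokens, {'lm loss': (loss.detach(), num_tokens)}
    return loss, {'lm loss': loss.detach()}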
@@ -395,18 +410,20 @@ def forward_step_no_model_graph(
     )

     num_tokens = torch.tensor(0, dtype=torch.int)
-    if is_dualpipev_last_stgae:
+    if is_dualpipev_last_stage(model_chunk_id):
         if not collect_non_loss_data:
             outputs = loss_func(output_tensor)
             if len(outputs) == 3:
                 output_tensor, num_tokens, loss_reduced = outputs
                 if not config.calculate_per_token_loss:
                     output_tensor /= num_tokens
+                    output_tensor *= parallel_state.get_context_parallel_world_size()
                     output_tensor /= num_microbatches
             else:
                 # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                 assert len(outputs) == 2
                 output_tensor, loss_reduced = outputs
+                output_tensor *= parallel_state.get_context_parallel_world_size()
                 output_tensor /= num_microbatches
             forward_data_store.append(loss_reduced)
     else:
@@ -417,251 +434,36 @@ def forward_step_no_model_graph(
         config.timers('forward-compute').stop()

     # Set the loss scale for the auxiliary loss of the MoE layer.
-    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
+    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
+    # explicitly.
     if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
         # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
         loss_scale = (
             config.grad_scale_func(torch.ones(1, device=output_tensor.device))
             if config.grad_scale_func is not None
-            else torch.tensor(1.0)
+            else torch.ones(1, device=output_tensor.device)
         )
         # Set the loss scale
         MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

     # If T5 model (or other model with encoder and decoder)
     # and in decoder stack, then send encoder_hidden_state
     # downstream as well.
     model_type = get_model_type(model)
     if (parallel_state.is_pipeline_stage_after_split()
             and model_type == ModelType.encoder_and_decoder):
         return [output_tensor, input_tensor[-1]], num_tokens
     if unwrap_output_tensor:
         return output_tensor, num_tokens
     return [output_tensor], num_tokens


-def backward_step_with_model_graph(input_tensor, output_tensor, output_tensor_grad, model_type,
-                                   config, model_graph=None):
-    """Backward step through passed-in output tensor.
-
-    If last stage, output_tensor_grad is None, otherwise gradient of loss
-    with respect to stage's output tensor.
-
-    Returns gradient of loss with respect to input tensor (None if first
-    stage)."""
-
-    # NOTE: This code currently can handle at most one skip connection. It
-    # needs to be modified slightly to support arbitrary numbers of skip
-    # connections.
-
-    if config.timers is not None:
-        config.timers('backward-compute', log_level=2).start()
-
-    # Retain the grad on the input_tensor.
-    unwrap_input_tensor_grad = False
-    if not isinstance(input_tensor, list):
-        input_tensor = [input_tensor]
-        unwrap_input_tensor_grad = True
-    for x in input_tensor:
-        if x is not None:
-            x.retain_grad()
-
-    if not isinstance(output_tensor, list):
-        output_tensor = [output_tensor]
-    if not isinstance(output_tensor_grad, list):
-        output_tensor_grad = [output_tensor_grad]
-
-    # Backward pass.
-    if output_tensor_grad[0] is None and config.grad_scale_func is not None and model_graph is None:
-        output_tensor[0] = config.grad_scale_func(output_tensor[0])
-
-    if config.deallocate_pipeline_outputs:
-        if model_graph is None:
-            custom_backward(output_tensor[0], output_tensor_grad[0])
-        else:
-            layer_output_grad = gpt_model_backward(output_tensor_grad[0], model_graph)
-    else:
-        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
-
-    # Collect the grad of the input_tensor.
-    input_tensor_grad = [None]
-    if input_tensor is not None:
-        input_tensor_grad = []
-        if model_graph is not None:
-            input_tensor_grad.append(layer_output_grad)
-            if config.calculate_per_token_loss:
-                MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
-        else:
-            for x in input_tensor:
-                if x is None:
-                    input_tensor_grad.append(None)
-                else:
-                    input_tensor_grad.append(x.grad)
-
-    # Handle single skip connection if it exists (encoder_hidden_state in
-    # model with encoder and decoder).
-    if (parallel_state.get_pipeline_model_parallel_world_size() > 1
-            and parallel_state.is_pipeline_stage_after_split()
-            and model_type == ModelType.encoder_and_decoder):
-        if output_tensor_grad[1] is not None:
-            input_tensor_grad[-1].add_(output_tensor_grad[1])
-    if unwrap_input_tensor_grad:
-        input_tensor_grad = input_tensor_grad[0]
-    MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
-
-    if config.timers is not None:
-        config.timers('backward-compute').stop()
-
-    return input_tensor_grad
-
-
-def forward_step_with_model_graph(
-    forward_step_func,
-    model_chunk_id,
-    data_iterator,
-    model,
-    num_microbatches,
-    input_tensor,
-    forward_data_store,
-    config,
-    collect_non_loss_data=False,
-    checkpoint_activations_microbatch=None,
-    is_first_microbatch=False,
-    current_microbatch=None,
-    extra_block_kwargs=None,
-):
-    """Forward step for passed-in model.
-
-    If it is the first stage, the input tensor is obtained from the data_iterator.
-    Otherwise, the passed-in input_tensor is used.
-
-    Args:
-        forward_step_func (callable): The forward step function for the model that takes the
-            data iterator as the first argument, and model as the second.
-            This user's forward step is expected to output a tuple of two elements:
-
-                1. The output object from the forward step. This output object needs to be a
-                    tensor or some kind of collection of tensors. The only hard requirement
-                    for this object is that it needs to be acceptible as input into the second
-                    function.
-                2. A function to reduce (optionally) the output from the forward step. This
-                    could be a reduction over the loss from the model, it could be a function that
-                    grabs the output from the model and reformats, it could be a function that just
-                    passes through the model output. This function must have one of the following
-                    patterns, and depending on the pattern different things happen internally:
-
-                        a. A tuple of reduced loss and some other data. Note that in this case
-                            the first argument is divided by the number of global microbatches,
-                            assuming it is a loss, so that the loss is stable as a function of
-                            the number of devices the step is split across.
-                        b. A triple of reduced loss, number of tokens, and some other data. This
-                            is similar to case (a), but the loss is further averaged across the
-                            number of tokens in the batch. If the user is not already averaging
-                            across the number of tokens, this pattern is useful to use.
-                        c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
-                            of tensors, etc in the case of inference). To trigger case 3 you need
-                            to specify `collect_non_loss_data=True` and you may also want to
-                            specify `forward_only=True` in the call to the parent forward_backward
-                            function.
-        data_iterator (iterator): The data iterator.
-        model (nn.Module): The model to perform the forward step on.
-        num_microbatches (int): The number of microbatches.
-        input_tensor (Tensor or list[Tensor]): The input tensor(s) for the forward step.
-        forward_data_store (list): The list to store the forward data. If you go down path 2.a or
-            2.b for the return of your forward reduction function then this will store only the
-            final dimension of the output, for example the metadata output by the loss function.
-            If you go down the path of 2.c then this will store the entire output of the forward
-            reduction function applied to the model output.
-        config (object): The configuration object.
-        collect_non_loss_data (bool, optional): Whether to collect non-loss data. Defaults to False.
-            This is the path to use if you want to collect arbitrary output from the model forward,
-            such as with inference use cases. Defaults to False.
-        checkpoint_activations_microbatch (int, optional): The microbatch to checkpoint activations.
-            Defaults to None.
-        is_first_microbatch (bool, optional): Whether it is the first microbatch. Defaults to False.
-        current_microbatch (int, optional): The current microbatch. Defaults to None.
-
-    Returns:
-        Tensor or list[Tensor]: The output object(s) from the forward step.
-        Tensor: The number of tokens.
-    """
-    if config.timers is not None:
-        config.timers('forward-compute', log_level=2).start()
-
-    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
-        model.set_is_first_microbatch()
-    if current_microbatch is not None:
-        set_current_microbatch(model, current_microbatch)
-
-    unwrap_output_tensor = False
-    if not isinstance(input_tensor, list):
-        input_tensor = [input_tensor]
-        unwrap_output_tensor = True
-
-    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
-    set_input_tensor(input_tensor)
-
-    if config.enable_autocast:
-        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
-    else:
-        context_manager = contextlib.nullcontext()
-    with context_manager:
-        if checkpoint_activations_microbatch is None:
-            output_tensor, loss_func = pretrain_gpt_forward_step_dualpipe(
-                data_iterator, model, extra_block_kwargs)
-        else:
-            output_tensor, loss_func = pretrain_gpt_forward_step_dualpipe(
-                data_iterator, model, checkpoint_activations_microbatch, extra_block_kwargs)
-
-    num_tokens = torch.tensor(0, dtype=torch.int)
-    if is_dualpipev_last_stgae(model_chunk_id):
-        if not collect_non_loss_data:
-            next_info = None
-            if isinstance(output_tensor, tuple):
-                # use pp overlapping
-                if len(output_tensor) == 2:
-                    output_tensor, model_graph = output_tensor
-                elif len(output_tensor) == 3:
-                    output_tensor, model_graph, next_info = output_tensor
-            outputs = loss_func(output_tensor)
-            if len(outputs) == 3:
-                output_tensor, num_tokens, loss_reduced = outputs
-                if not config.calculate_per_token_loss:
-                    output_tensor /= num_tokens
-                    output_tensor /= num_microbatches
-            else:
-                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
-                assert len(outputs) == 2
-                output_tensor, loss_reduced = outputs
-                output_tensor /= num_microbatches
-            forward_data_store.append(loss_reduced)
-            output_tensor = (output_tensor, model_graph, next_info) if next_info is not None \
-                else (output_tensor, model_graph)
-    else:
-        data = loss_func(output_tensor, non_loss_data=True)
-        forward_data_store.append(data)
-
-    if config.timers is not None:
-        config.timers('forward-compute').stop()
-
-    # Set the loss scale for the auxiliary loss of the MoE layer.
-    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
-    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
-        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
-        loss_scale = (
-            config.grad_scale_func(LOSS_BACKWARD_SCALE)
-            if config.grad_scale_func is not None
-            else torch.tensor(1.0)
-        )
-        # Set the loss scale
-        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
-
-    # Set the loss scale for Multi-Token Prediction (MTP) loss.
-    if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
-        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
-        loss_scale = (
-            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
-            if config.grad_scale_func is not None
-            else torch.ones(1, device=output_tensor.device)
-        )
-        if config.calculate_per_token_loss:
-            MTPLossAutoScaler.set_loss_scale(loss_scale)
-        else:
-            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
-
-    # If T5 model and in decoder stack, then send encoder_hidden_state
-    # downstream as well.
-    model_type = get_model_type(model)
-    if (
@@ -713,11 +515,11 @@ def forward_backward_pipelining_with_cutinhalf(
     set_shared_embedding_from_dual_chunk(model[0], model[1])
     assert (
         isinstance(model, list) and len(model) == 2
-    ), 'Dualpipe Schedule only support chunk model for two consecutive chunks'
+    ), 'Dualpipe Schedule expects two model chunks'
     assert (
         isinstance(data_iterator, list) and len(data_iterator) == 2
-    ), 'Dualpipe Schedule only support two data_iterators'
+    ), 'Dualpipe Schedule expects two data_iterators'
     config = get_model_config(model[0])
    config.batch_p2p_comm = False
@@ -727,8 +529,7 @@ def forward_backward_pipelining_with_cutinhalf(
     embedding_module = clear_embedding_activation_buffer(config, model)

     if config.timers is not None:
-        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
-        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
+        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

     # Disable async grad reductions
     no_sync_func = config.no_sync_func
@@ -783,97 +584,30 @@ def forward_backward_pipelining_with_cutinhalf(
     checkpoint_activations_microbatch = None

-    def forward_step_helper(model_chunk_id, current_microbatch, checkpoint_activations_microbatch,
-                            is_first_microbatch=False, extra_block_kwargs=None):
-        input_tensor = input_tensors[model_chunk_id][-1][1]
-        output_tensor, num_tokens = forward_step_with_model_graph(
-            forward_step_func,
-            model_chunk_id,
-            data_iterator[model_chunk_id],
-            model[model_chunk_id],
-            num_microbatches,
-            input_tensor,
-            forward_data_store,
-            config,
-            collect_non_loss_data,
-            checkpoint_activations_microbatch,
-            is_first_microbatch,
-            current_microbatch=current_microbatch,
-            extra_block_kwargs=extra_block_kwargs)
-        if isinstance(output_tensor, tuple):
-            if len(output_tensor) == 2:
-                output_tensor_, model_graph = output_tensor
-            elif len(output_tensor) == 3:
-                output_tensor_, model_graph, pp_comm_output = output_tensor
-            if is_dualpipev_last_stgae(model_chunk_id):
-                logits_inputs.append(model_graph.layer_graphs[-1].unperm2_graph[1])
-            model_graphs[model_chunk_id].append(model_graph)
-        else:
-            output_tensor_ = output_tensor
-        output_tensors[model_chunk_id].append(output_tensor_)
-        if extra_block_kwargs is not None:
-            input_tensors[1 - model_chunk_id].pop(0)
-            output_tensors[1 - model_chunk_id].pop(0)
-        nonlocal total_num_tokens
-        total_num_tokens += num_tokens.item()
-        # if forward-only, no need to save tensors for a backward pass
-        if forward_only:
-            input_tensors[model_chunk_id].pop()
-            output_tensors[model_chunk_id].pop()
-        return output_tensor
-
-    def check_pipeline_stage(model_chunk_id, fwd_send_only):
-        send_next, recv_next, send_prev, recv_prev = True, True, True, True
-        if parallel_state.is_pipeline_first_stage():
-            send_prev, recv_prev = False, False
-        if parallel_state.is_pipeline_last_stage():
-            send_next, recv_next = False, False
-        if model_chunk_id == 0:
-            return P2PCommParams(send_next=send_next, recv_next=not fwd_send_only and recv_next), \
-                P2PCommParams(send_next=send_next, recv_next=recv_next)
-        else:
-            return P2PCommParams(send_prev=send_prev, recv_prev=not fwd_send_only and recv_prev), \
-                P2PCommParams(send_prev=send_prev, recv_prev=recv_prev)
-
     input_tensor = recv_forward(tensor_shape, config, master_chunk_id)[0]

     fwd_wait_handles_warmup = None
     # Run warmup forward passes
     for i in range(schedule['warmup'][rank]):
-        if args.moe_fb_overlap:
-            input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
-            output_tensor_warmup, _ = forward_step_helper(
-                master_chunk_id, master_cur_microbatch, checkpoint_activations_microbatch,
-                is_first_microbatch=(i == 0))
-        else:
-            output_tensor_warmup, num_tokens = forward_step_no_model_graph(
-                forward_step_func,
-                master_chunk_id,
-                data_iterator[master_chunk_id],
-                model[master_chunk_id],
-                num_microbatches,
-                input_tensor,
-                forward_data_store,
-                config,
-                collect_non_loss_data,
-                checkpoint_activations_microbatch,
-                is_first_microbatch=(i == 0),
-                current_microbatch=master_cur_microbatch)
-            total_num_tokens += num_tokens.item()
-            input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
-            output_tensors[master_chunk_id].append(output_tensor_warmup)
+        output_tensor_warmup, num_tokens = forward_step_no_model_graph(
+            forward_step_func,
+            master_chunk_id,
+            data_iterator[master_chunk_id],
+            model[master_chunk_id],
+            num_microbatches,
+            input_tensor,
+            forward_data_store,
+            config,
+            collect_non_loss_data,
+            checkpoint_activations_microbatch,
+            is_first_microbatch=(i == 0),
+            current_microbatch=master_cur_microbatch)
+        total_num_tokens += num_tokens.item()
+        input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
+        output_tensors[master_chunk_id].append(output_tensor_warmup)
         master_cur_microbatch += 1
@@ -899,45 +633,39 @@ def forward_backward_pipelining_with_cutinhalf(
                 req.wait()
             fwd_wait_handles = None

-        is_first_microbatch = parallel_state.is_pipeline_last_stage() and (i == 0)
+        is_first_microbatch = parallel_state.is_pipeline_last_stage(ignore_virtual=True) and (i == 0)
         set_dualpipe_chunk(master_chunk_id)
-        if args.moe_fb_overlap:
-            input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
-            output_tensor, _ = forward_step_helper(
-                master_chunk_id, master_cur_microbatch, checkpoint_activations_microbatch,
-                is_first_microbatch=is_first_microbatch)
-        else:
-            output_tensor, num_tokens = forward_step_no_model_graph(
-                forward_step_func,
-                master_chunk_id,
-                data_iterator[master_chunk_id],
-                model[master_chunk_id],
-                num_microbatches,
-                input_tensor,
-                forward_data_store,
-                config,
-                collect_non_loss_data,
-                checkpoint_activations_microbatch,
-                is_first_microbatch=is_first_microbatch,
-                current_microbatch=master_cur_microbatch)
-            total_num_tokens += num_tokens.item()
-            input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
-            output_tensors[master_chunk_id].append(output_tensor)
+        output_tensor, num_tokens = forward_step_no_model_graph(
+            forward_step_func,
+            master_chunk_id,
+            data_iterator[master_chunk_id],
+            model[master_chunk_id],
+            num_microbatches,
+            input_tensor,
+            forward_data_store,
+            config,
+            collect_non_loss_data,
+            checkpoint_activations_microbatch,
+            is_first_microbatch=is_first_microbatch,
+            current_microbatch=master_cur_microbatch)
+        total_num_tokens += num_tokens.item()
+        input_tensors[master_chunk_id].append((master_cur_microbatch, input_tensor))
+        output_tensors[master_chunk_id].append(output_tensor)
         master_cur_microbatch += 1

-        if not parallel_state.is_pipeline_last_stage() and fwd_wait_handles_send is not None:
+        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and fwd_wait_handles_send is not None:
             for req in fwd_wait_handles_send:
                 req.wait()
             deallocate_output_tensor(output_tensor_send, config.deallocate_pipeline_outputs)
             fwd_wait_handles_send = None

-        if parallel_state.is_pipeline_last_stage():
+        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
             input_tensor_slave_chunk = output_tensor
             input_tensor, fwd_wait_handles = recv_forward(
@@ -964,31 +692,24 @@ def forward_backward_pipelining_with_cutinhalf(
             fwd_wait_handles_slave_chunk = None

         set_dualpipe_chunk(slave_chunk_id)
-        if args.moe_fb_overlap:
-            input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
-            output_tensor_slave_chunk, _ = forward_step_helper(
-                slave_chunk_id, slave_cur_microbatch, checkpoint_activations_microbatch)
-        else:
-            output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
-                forward_step_func,
-                slave_chunk_id,
-                data_iterator[slave_chunk_id],
-                model[slave_chunk_id],
-                num_microbatches,
-                input_tensor_slave_chunk,
-                forward_data_store,
-                config,
-                collect_non_loss_data,
-                checkpoint_activations_microbatch,
-                current_microbatch=slave_cur_microbatch,
-            )
-            input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
-            total_num_tokens += num_tokens.item()
-            output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
+        output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
+            forward_step_func,
+            slave_chunk_id,
+            data_iterator[slave_chunk_id],
+            model[slave_chunk_id],
+            num_microbatches,
+            input_tensor_slave_chunk,
+            forward_data_store,
+            config,
+            collect_non_loss_data,
+            checkpoint_activations_microbatch,
+            current_microbatch=slave_cur_microbatch,
+        )
+        input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
+        total_num_tokens += num_tokens.item()
+        output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
         slave_cur_microbatch += 1
@@ -997,7 +718,7 @@ def forward_backward_pipelining_with_cutinhalf(
     firstFB_no_overlp = False
     firstFB_no_overlp_handle = None
     # last rank does not overlap the first F&B
-    if parallel_state.is_pipeline_last_stage():
+    if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
         firstFB_no_overlp = True
         output_tensor_grad_bwd, firstFB_no_overlp_handle = recv_backward(
             tensor_shape, config, slave_chunk_id, async_op=True)
@@ -1008,7 +729,7 @@ def forward_backward_pipelining_with_cutinhalf(
         fwd_wait_handles_slave_chunk = send_forward(
             output_tensor_slave_chunk, tensor_shape, config, slave_chunk_id, async_op=True)

-    if not parallel_state.is_pipeline_last_stage():
+    if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
         output_tensor_send = output_tensor
         fwd_wait_handles_send = send_forward(
             output_tensor_send, tensor_shape, config, master_chunk_id, async_op=True)
@@ -1024,32 +745,12 @@ def forward_backward_pipelining_with_cutinhalf(
         WeightGradStore.start_decouple()

-        if args.moe_fb_overlap:
-            if is_dualpipev_last_stgae(slave_chunk_id):
-                input_tensor_bwd = logits_inputs.pop(0)
-                output_tensor_bwd = output_tensors[slave_chunk_id][0]
-                model_graph = None
-                output_tensor_grad_bwd = backward_step_with_model_graph(
-                    input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph)
-            input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
-            model_graph = model_graphs[slave_chunk_id].pop(0)
-            input_tensor_grad = backward_step_with_model_graph(
-                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph)
-        else:
-            input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
-            input_tensor_grad = backward_step(
-                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
+        input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
+        output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
+        input_tensor_grad = backward_step(
+            input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)

         WeightGradStore.end_decouple()
@@ -1084,31 +785,24 @@ def forward_backward_pipelining_with_cutinhalf(
             # 1F: Forward pass
             set_dualpipe_chunk(slave_chunk_id)
-            if args.moe_fb_overlap:
-                input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
-                output_tensor_slave_chunk, _ = forward_step_helper(
-                    slave_chunk_id, slave_cur_microbatch, checkpoint_activations_microbatch)
-            else:
-                output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
-                    forward_step_func,
-                    slave_chunk_id,
-                    data_iterator[slave_chunk_id],
-                    model[slave_chunk_id],
-                    num_microbatches,
-                    input_tensor_slave_chunk,
-                    forward_data_store,
-                    config,
-                    collect_non_loss_data,
-                    checkpoint_activations_microbatch,
-                    current_microbatch=slave_cur_microbatch)
-                input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
-                total_num_tokens += num_tokens.item()
-                output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
+            output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
+                forward_step_func,
+                slave_chunk_id,
+                data_iterator[slave_chunk_id],
+                model[slave_chunk_id],
+                num_microbatches,
+                input_tensor_slave_chunk,
+                forward_data_store,
+                config,
+                collect_non_loss_data,
+                checkpoint_activations_microbatch,
+                current_microbatch=slave_cur_microbatch)
+            input_tensors[slave_chunk_id].append((slave_cur_microbatch, input_tensor_slave_chunk))
+            total_num_tokens += num_tokens.item()
+            output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
             slave_cur_microbatch += 1
@@ -1129,277 +823,110 @@ def forward_backward_pipelining_with_cutinhalf(
             if fwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch == slave_microbatch_max:
                 only_bwd = True

-            if args.moe_fb_overlap and not firstFB_no_overlp:
-                if not only_bwd:
-                    if fwd_wait_handles is not None:
-                        for req in fwd_wait_handles:
-                            req.wait()
-                        fwd_wait_handles = None
-                    if fwd_wait_handles_recv is not None:
-                        for req in fwd_wait_handles_recv:
-                            req.wait()
-                        fwd_wait_handles_recv = None
-                    if bwd_wait_handles is not None:
-                        for req in bwd_wait_handles:
-                            req.wait()
-                        bwd_wait_handles = None
-                    if not parallel_state.is_pipeline_last_stage() or fwd_model_chunk_id == master_chunk_id:
-                        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-
-                    fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
-                    set_dualpipe_chunk(fwd_model_chunk_id)
-                    fwd_send_only = False
-                    if fwd_model_chunk_id == slave_chunk_id and master_cur_microbatch == master_microbatch_max:
-                        fwd_send_only = True
-
-                    extra_block_kwargs = {}
-                    if is_dualpipev_last_stgae(bwd_model_chunk_id):
-                        input_tensor_bwd = logits_inputs.pop(0)
-                        output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
-                        model_graph = None
-                        input_tensor_grad = backward_step_with_model_graph(
-                            input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph)
-                        extra_block_kwargs.setdefault('bwd_model_grad', input_tensor_grad)
-                    else:
-                        extra_block_kwargs.setdefault('bwd_model_grad', output_tensor_grad_bwd)
-
-                    fwd_pp_comm_params, bwd_pp_comm_params = check_pipeline_stage(fwd_model_chunk_id, fwd_send_only)
-                    fwd_pp_comm_params.config, bwd_pp_comm_params.config = config, config
-                    fwd_pp_comm_params.tensor_shape, bwd_pp_comm_params.tensor_shape = tensor_shape, tensor_shape
-                    extra_block_kwargs.setdefault('bwd_model_graph', model_graphs[bwd_model_chunk_id].pop(0))
-                    extra_block_kwargs.setdefault('pp_comm_params', fwd_pp_comm_params)
-                    extra_block_kwargs.setdefault('bwd_pp_comm_params', bwd_pp_comm_params)
-
-                    input_tensors[fwd_model_chunk_id].append((fwd_microbatch, input_tensor))
-                    output_tensor, model_graph, pp_comm_output = forward_step_helper(
-                        fwd_model_chunk_id, fwd_microbatch, checkpoint_activations_microbatch,
-                        extra_block_kwargs=extra_block_kwargs)
-
-                    if fwd_model_chunk_id == master_chunk_id:
-                        master_cur_microbatch += 1
-                    else:
-                        slave_cur_microbatch += 1
-
-                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                        input_tensor = output_tensor
-                        output_tensor_grad_bwd = pp_comm_output.input_tensor_grad
-                    else:
-                        input_tensor, fwd_wait_handles = pp_comm_output.input_tensor, pp_comm_output.fwd_wait_handles
-                        output_tensor_grad_bwd, bwd_wait_handles = pp_comm_output.output_tensor_grad, pp_comm_output.bwd_wait_handles
-
-                    if fwd_wait_handles_slave_chunk is not None:
-                        for req in fwd_wait_handles_slave_chunk:
-                            # synchronize the last slave-chunk forward send from the previous phase
-                            req.wait()
-                        deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-                        fwd_wait_handles_slave_chunk = None
-                # only run backward
-                else:
-                    if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
-                        input_tensor, fwd_wait_handles_recv = recv_forward(
-                            tensor_shape, config, slave_chunk_id, async_op=True)
-                    if bwd_wait_handles is not None:
-                        for req in bwd_wait_handles:
-                            req.wait()
-                        bwd_wait_handles = None
-                    if is_dualpipev_last_stgae(bwd_model_chunk_id):
-                        input_tensor_bwd = logits_inputs.pop(0)
-                        output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
-                        model_graph = None
-                        output_tensor_grad_bwd = backward_step_with_model_graph(
-                            input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph)
-                    input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-                    output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-                    model_graph = model_graphs[bwd_model_chunk_id].pop(0)
-                    if fwd_wait_handles is not None:
-                        for req in fwd_wait_handles:
-                            req.wait()
-                        fwd_wait_handles = None
-                        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-                    input_tensor_grad = backward_step_with_model_graph(
-                        input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph)
-                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                        output_tensor_grad_bwd = input_tensor_grad
-                    else:
-                        # send_backward_recv_slave_backward
-                        output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(
-                            input_tensor_grad, tensor_shape, config, fwd_model_chunk_id)
-            else:
-                firstFB_no_overlp = False
-                if not only_bwd:
-                    fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
-                    set_dualpipe_chunk(fwd_model_chunk_id)
-                    if args.moe_fb_overlap:
-                        input_tensors[fwd_model_chunk_id].append((fwd_microbatch, input_tensor))
-                        output_tensor, _ = forward_step_helper(
-                            fwd_model_chunk_id, fwd_microbatch, checkpoint_activations_microbatch)
-                    else:
-                        output_tensor, num_tokens = forward_step_no_model_graph(
-                            forward_step_func,
-                            fwd_model_chunk_id,
-                            data_iterator[fwd_model_chunk_id],
-                            model[fwd_model_chunk_id],
-                            num_microbatches,
-                            input_tensor,
-                            forward_data_store,
-                            config,
-                            collect_non_loss_data,
-                            checkpoint_activations_microbatch,
-                            current_microbatch=fwd_microbatch)
-                        input_tensors[fwd_model_chunk_id].append((fwd_microbatch, input_tensor))
-                        total_num_tokens += num_tokens.item()
-                        output_tensors[fwd_model_chunk_id].append(output_tensor)
-                    if fwd_model_chunk_id == master_chunk_id:
-                        master_cur_microbatch += 1
-                        fwd_send_only = False
-                    else:
-                        slave_cur_microbatch += 1
-                        fwd_send_only = (master_cur_microbatch == master_microbatch_max)
-                    if fwd_send_only:
-                        fwd_wait_handles = send_forward(
-                            output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
-                    else:
-                        if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                            input_tensor = output_tensor
-                        else:
-                            input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
-                                output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
-                    if firstFB_no_overlp_handle is not None:
-                        for req in firstFB_no_overlp_handle:
-                            req.wait()
-                        firstFB_no_overlp_handle = None
-                    if bwd_wait_handles is not None:
-                        for req in bwd_wait_handles:
-                            req.wait()
-                        bwd_wait_handles = None
-                    if args.moe_fb_overlap:
-                        if is_dualpipev_last_stgae(bwd_model_chunk_id):
-                            input_tensor_bwd = logits_inputs.pop(0)
-                            output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
-                            model_graph = None
-                            output_tensor_grad_bwd = backward_step_with_model_graph(
-                                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph)
-                        input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-                        output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-                        model_graph = model_graphs[bwd_model_chunk_id].pop(0)
-                        input_tensor_grad = backward_step_with_model_graph(
-                            input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph)
-                    else:
-                        input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-                        output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-                        input_tensor_grad = backward_step(
-                            input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
-                    if fwd_wait_handles is not None:
-                        for req in fwd_wait_handles:
-                            req.wait()
-                        fwd_wait_handles = None
-                    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-                    if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
-                        input_tensor, _ = recv_forward(tensor_shape, config, slave_chunk_id)
-                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                        output_tensor_grad_bwd = input_tensor_grad
-                    else:
-                        # send_backward_recv_slave_backward
-                        output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(
-                            input_tensor_grad, tensor_shape, config, fwd_model_chunk_id, async_op=True)
-                    if fwd_wait_handles_slave_chunk is not None:
-                        for req in fwd_wait_handles_slave_chunk:
-                            # synchronize the last slave-chunk forward send from the previous phase
-                            req.wait()
-                        deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-                        fwd_wait_handles_slave_chunk = None
-                # only run backward
-                else:
-                    if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
-                        input_tensor, _ = recv_forward(tensor_shape, config, slave_chunk_id)
-                    if bwd_wait_handles is not None:
-                        for req in bwd_wait_handles:
-                            req.wait()
-                        bwd_wait_handles = None
-                    if args.moe_fb_overlap:
-                        if is_dualpipev_last_stgae(bwd_model_chunk_id):
-                            input_tensor_bwd = logits_inputs.pop(0)
-                            output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
-                            model_graph = None
-                            output_tensor_grad_bwd = backward_step_with_model_graph(
-                                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph)
-                        input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-                        output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-                        model_graph = model_graphs[bwd_model_chunk_id].pop(0)
-                        input_tensor_grad = backward_step_with_model_graph(
-                            input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph)
-                    else:
-                        input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-                        output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-                        input_tensor_grad = backward_step(
-                            input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
-                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                        output_tensor_grad_bwd = input_tensor_grad
-                    else:
-                        # send_backward_recv_slave_backward
-                        output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(
-                            input_tensor_grad, tensor_shape, config, fwd_model_chunk_id)
+            if not only_bwd:
+                fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
+                set_dualpipe_chunk(fwd_model_chunk_id)
+                output_tensor, num_tokens = forward_step_no_model_graph(
+                    forward_step_func,
+                    fwd_model_chunk_id,
+                    data_iterator[fwd_model_chunk_id],
+                    model[fwd_model_chunk_id],
+                    num_microbatches,
+                    input_tensor,
+                    forward_data_store,
+                    config,
+                    collect_non_loss_data,
+                    checkpoint_activations_microbatch,
+                    current_microbatch=fwd_microbatch)
+                input_tensors[fwd_model_chunk_id].append((fwd_microbatch, input_tensor))
+                total_num_tokens += num_tokens.item()
+                output_tensors[fwd_model_chunk_id].append(output_tensor)
+
+                if fwd_model_chunk_id == master_chunk_id:
+                    master_cur_microbatch += 1
+                    fwd_send_only = False
+                else:
+                    slave_cur_microbatch += 1
+                    fwd_send_only = (master_cur_microbatch == master_microbatch_max)
+
+                if fwd_send_only:
+                    fwd_wait_handles = send_forward(
+                        output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
+                else:
+                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
+                        input_tensor = output_tensor
+                    else:
+                        input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
+                            output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
+
+                if firstFB_no_overlp_handle is not None:
+                    for req in firstFB_no_overlp_handle:
+                        req.wait()
+                    firstFB_no_overlp_handle = None
+                if bwd_wait_handles is not None:
+                    for req in bwd_wait_handles:
+                        req.wait()
+                    bwd_wait_handles = None
+
+                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
+                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
+                input_tensor_grad = backward_step(
+                    input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
+
+                if fwd_wait_handles is not None:
+                    for req in fwd_wait_handles:
+                        req.wait()
+                    fwd_wait_handles = None
+                deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
+                if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
+                    input_tensor, _ = recv_forward(tensor_shape, config, slave_chunk_id)
+
+                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
+                    output_tensor_grad_bwd = input_tensor_grad
+                else:
+                    # send_backward_recv_slave_backward
+                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(
+                        input_tensor_grad, tensor_shape, config, fwd_model_chunk_id, async_op=True)
+
+                if fwd_wait_handles_slave_chunk is not None:
+                    for req in fwd_wait_handles_slave_chunk:
+                        # synchronize the last slave-chunk forward send from the previous phase
+                        req.wait()
+                    deallocate_output_tensor(output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
+                    fwd_wait_handles_slave_chunk = None
+            # only run backward
+            else:
+                firstFB_no_overlp = False
+                if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
+                    input_tensor, _ = recv_forward(tensor_shape, config, slave_chunk_id)
+                if bwd_wait_handles is not None:
+                    for req in bwd_wait_handles:
+                        req.wait()
+                    bwd_wait_handles = None
+                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
+                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
+                input_tensor_grad = backward_step(
+                    input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
+                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
+                    output_tensor_grad_bwd = input_tensor_grad
+                else:
+                    # send_backward_recv_slave_backward
+                    output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(
+                        input_tensor_grad, tensor_shape, config, fwd_model_chunk_id)

             # swap fwd & bwd chunks
             fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
@@ -1438,16 +965,9 @@ def forward_backward_pipelining_with_cutinhalf(
         if not args.dualpipe_no_dw_detach:
             WeightGradStore.start_decouple()

-        if args.moe_fb_overlap:
-            model_graph = model_graphs[bwd_model_chunk_id].pop(0)
-            input_tensor_grad = backward_step_with_model_graph(
-                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config, model_graph)
-        else:
-            input_tensor_grad = backward_step(
-                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)
+        input_tensor_grad = backward_step(
+            input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config)

         if not args.dualpipe_no_dw_detach:
             WeightGradStore.end_decouple()
@@ -1465,7 +985,7 @@ def forward_backward_pipelining_with_cutinhalf(
             output_tensor_grad_bwd = input_tensor_grad
         else:
             # send_backward_recv_slave_backward
-            output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(
+            output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(
                 input_tensor_grad,
                 tensor_shape, config, 1 - bwd_model_chunk_id)

     WeightGradStore.flush_chunk_grad()