evt_fugx1 / dcu_megatron · Commits

Commit 69add73b, authored Jun 12, 2025 by dongcl
Parent: 62f16817

    dualpipev support moe a2a overlap

Showing 7 changed files with 250 additions and 191 deletions (+250 / -191).
Changed files:

  .gitignore                                                                      +1    -0
  dcu_megatron/adaptor/features_manager/pipeline_parallel/pipeline_feature.py     +24   -26
  dcu_megatron/core/models/gpt/fine_grained_schedule.py                           +1    -2
  dcu_megatron/core/parallel_state.py                                             +14   -0
  dcu_megatron/core/pipeline_parallel/combined_1f1b.py                            +10   -1
  dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_chunks.py               +1    -1
  dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py            +199  -161
.gitignore

 __pycache__
 *.bak
+*.log
dcu_megatron/adaptor/features_manager/pipeline_parallel/pipeline_feature.py

@@ -69,16 +69,12 @@ class PipelineFeature(AbstractFeature):
         patch_manager.register_patch('megatron.training.training.evaluate', evaluate)
 
-        if args.combined_1f1b:
+        if (args.schedule_method == "interleaved_1f1b" and args.combined_1f1b):
             from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear
             from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
             from dcu_megatron.core.transformer.transformer_layer import TransformerLayer
             from dcu_megatron.core.models.gpt.gpt_model import GPTModel
-            from dcu_megatron.core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
             from dcu_megatron.core.extensions.transformer_engine import (
                 _get_extra_te_kwargs_wrapper,
                 TELinear,

@@ -89,53 +85,55 @@ class PipelineFeature(AbstractFeature):
             from dcu_megatron.core.transformer.moe.experts import TEGroupedMLP
             from dcu_megatron.core.transformer.moe.moe_layer import MoELayer
 
-            # num_warmup_microbatches + 1
-            patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches', get_pp_rank_microbatches)
-            # a2a_overlap
-            patches_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving', forward_backward_pipelining_with_interleaving)
-            patches_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher', MoEAlltoAllTokenDispatcher)
+            patch_manager.register_patch('megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher', MoEAlltoAllTokenDispatcher)
-            patches_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer', TransformerLayer)
+            patch_manager.register_patch('megatron.core.transformer.transformer_layer.TransformerLayer', TransformerLayer)
-            patches_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan', GPTModel.build_schedule_plan, create_dummy=True)
+            patch_manager.register_patch('megatron.core.models.gpt.gpt_model.GPTModel.build_schedule_plan', GPTModel.build_schedule_plan, create_dummy=True)
 
             # backward_dw
-            patches_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs', _get_extra_te_kwargs_wrapper, apply_wrapper=True)
+            patch_manager.register_patch('megatron.core.extensions.transformer_engine._get_extra_te_kwargs', _get_extra_te_kwargs_wrapper, apply_wrapper=True)
-            patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear', TELinear)
+            patch_manager.register_patch('megatron.core.extensions.transformer_engine.TELinear', TELinear)
-            patches_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear', TELayerNormColumnParallelLinear)
+            patch_manager.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear', TELayerNormColumnParallelLinear)
             TEColumnParallelLinear.__bases__ = (TELinear,)
             TERowParallelLinear.__bases__ = (TELinear,)
 
             if is_te_min_version("1.9.0.dev0"):
                 from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear, TERowParallelGroupedLinear
-                from ..core.extensions.transformer_engine import TEGroupedLinear
+                from dcu_megatron.core.extensions.transformer_engine import TEGroupedLinear
-                patches_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear', TEGroupedLinear)
+                patch_manager.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear', TEGroupedLinear)
                 TEColumnParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
                 TERowParallelGroupedLinear.__bases__ = (TEGroupedLinear,)
 
-            patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw', MLASelfAttention.backward_dw, create_dummy=True)
+            patch_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw', MLASelfAttention.backward_dw, create_dummy=True)
-            patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw', MLP.backward_dw, create_dummy=True)
+            patch_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw', MLP.backward_dw, create_dummy=True)
-            patches_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw', TEGroupedMLP.backward_dw, create_dummy=True)
+            patch_manager.register_patch('megatron.core.transformer.moe.experts.TEGroupedMLP.backward_dw', TEGroupedMLP.backward_dw, create_dummy=True)
-            patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw', MoELayer.backward_dw, create_dummy=True)
+            patch_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw', MoELayer.backward_dw, create_dummy=True)
+
+        if args.schedule_method == "interleaved_1f1b":
+            from dcu_megatron.core.pipeline_parallel.schedules import get_pp_rank_microbatches, forward_backward_pipelining_with_interleaving
+            # num_warmup_microbatches + 1
+            patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches', get_pp_rank_microbatches)
+            # a2a_overlap
+            patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving', forward_backward_pipelining_with_interleaving)
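The hunks above route every override through patch_manager.register_patch (the commit also normalizes the earlier patches_manager spelling to patch_manager) and move the interleaved-1f1b schedule patches behind an explicit args.schedule_method check. For readers unfamiliar with the pattern, the following is a minimal, self-contained sketch of a register-then-apply monkey-patch registry; the MiniPatchManager class and the math.sqrt example are illustrative assumptions, not the dcu_megatron implementation.

import importlib
import math


class MiniPatchManager:
    """Hypothetical sketch of a register-then-apply patch registry."""

    def __init__(self):
        self._patches = []  # list of (dotted_target, replacement)

    def register_patch(self, target, replacement):
        # Record which fully qualified attribute should be replaced.
        self._patches.append((target, replacement))

    def apply(self):
        # Resolve "package.module.attr" and rebind attr to the replacement.
        for target, replacement in self._patches:
            module_path, attr = target.rsplit(".", 1)
            module = importlib.import_module(module_path)
            setattr(module, attr, replacement)


def traced_sqrt(x):
    # Replacement implementation used only for this illustration.
    return x ** 0.5


manager = MiniPatchManager()
manager.register_patch("math.sqrt", traced_sqrt)
manager.apply()
assert math.sqrt(9) == 3.0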
dcu_megatron/core/models/gpt/fine_grained_schedule.py

@@ -6,7 +6,6 @@ from typing import Optional
 import torch
 from torch import Tensor
-from megatron.core import parallel_state
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext

@@ -720,7 +719,7 @@ def schedule_chunk_1f1b(
     if f_schedule_plan is not None and post_forward is not None:
         with f_context:
             f_schedule_plan.wait_current_stream()
-            post_forward(None if parallel_state.is_pipeline_last_stage(ignore_virtual=False) else f_input)
+            post_forward(f_input)
     # pp grad send / receive, overlapped with attn dw of cur micro-batch and forward attn of next micro-batch
     if b_schedule_plan is not None and post_backward is not None:
dcu_megatron/core/parallel_state.py (new file, 0 → 100644)

+_DUALPIPE_CHUNK = None
+
+
+def set_dualpipe_chunk(chunk_id):
+    """set_dualpipe_chunk for fp16forward patch"""
+    global _DUALPIPE_CHUNK
+    _DUALPIPE_CHUNK = chunk_id
+
+
+def get_dualpipe_chunk():
+    global _DUALPIPE_CHUNK
+    if _DUALPIPE_CHUNK is not None:
+        return _DUALPIPE_CHUNK
+    else:
+        raise AssertionError("_DUALPIPE_CHUNK is None")
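This new module gives the DualPipe-V code a single piece of process-local state: the scheduler records which of the two logical model chunks (0 or 1) is currently executing, and patched code such as dualpipev_fp16forward can query it without threading an extra argument through every call. A stand-alone usage sketch follows; the fp16forward stand-in is an assumption for illustration, and the helpers are mirrored locally rather than imported from the repository.

_DUALPIPE_CHUNK = None  # mirrors dcu_megatron.core.parallel_state


def set_dualpipe_chunk(chunk_id):
    global _DUALPIPE_CHUNK
    _DUALPIPE_CHUNK = chunk_id


def get_dualpipe_chunk():
    if _DUALPIPE_CHUNK is None:
        raise AssertionError("_DUALPIPE_CHUNK is None")
    return _DUALPIPE_CHUNK


def fp16forward_stub():
    # A patched forward can branch on the active chunk, e.g. to decide
    # whether it is acting as the loss-producing half of the model.
    return f"running model chunk {get_dualpipe_chunk()}"


# Scheduler-side sketch: announce the chunk before running its microbatch.
for chunk_id in (0, 1):
    set_dualpipe_chunk(chunk_id)
    print(fp16forward_stub())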
dcu_megatron/core/pipeline_parallel/combined_1f1b.py

@@ -7,6 +7,7 @@ import torch
 from torch import Tensor
 from torch.autograd.variable import Variable
+from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel

@@ -15,6 +16,8 @@ from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
 from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
 from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor
+from dcu_megatron.core.parallel_state import get_dualpipe_chunk
 
 
 def make_viewless(e):
     """make_viewless util func"""

@@ -432,7 +435,13 @@ def forward_backward_step(
     if f_model:
         with f_context:
             num_tokens = torch.tensor(0, dtype=torch.int)
-            if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+            args = get_args()
+            is_last_stage = False
+            if args.schedule_method == "dualpipev":
+                is_last_stage = parallel_state.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
+            else:
+                is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False)
+            if is_last_stage:
                 if not collect_non_loss_data:
                     loss_node = ScheduleNode(
                         loss_func,
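The new is_last_stage branch encodes where the loss lives under DualPipe-V: the model is cut in half and folded back, so the final layers (logical chunk 1) sit on pipeline rank 0 rather than on the last rank. A small stand-alone sketch of that rank/chunk mapping, with a plain function standing in for parallel_state and get_args (assumed values, illustration only):

def is_loss_stage(schedule_method, pp_rank, pp_size, dualpipe_chunk):
    """Sketch: which (rank, chunk) computes the loss for a given schedule."""
    if schedule_method == "dualpipev":
        # DualPipe-V folds the model: rank 0 holds both the first layers
        # (chunk 0) and the last layers (chunk 1), so the loss is computed
        # on the first pipeline rank while it is running chunk 1.
        return pp_rank == 0 and dualpipe_chunk == 1
    # Conventional schedules compute the loss on the last pipeline rank.
    return pp_rank == pp_size - 1


# Example with 4 pipeline ranks.
assert is_loss_stage("dualpipev", pp_rank=0, pp_size=4, dualpipe_chunk=1)
assert not is_loss_stage("dualpipev", pp_rank=3, pp_size=4, dualpipe_chunk=0)
assert is_loss_stage("1f1b", pp_rank=3, pp_size=4, dualpipe_chunk=0)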
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_chunks.py

@@ -20,7 +20,7 @@ from megatron.training.utils import (
     reduce_max_stat_across_model_parallel_group
 )
-from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import get_dualpipe_chunk
+from dcu_megatron.core.parallel_state import get_dualpipe_chunk
 
 
 def dualpipev_fp16forward(self, *inputs, **kwargs):
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py

@@ -26,7 +26,8 @@ from megatron.core.pipeline_parallel.schedules import (
     finish_embedding_wgrad_compute
 )
 from dcu_megatron.training.utils import print_rank_message
-from dcu_megatron.core.pipeline_parallel.combined_1f1b import forward_backward_step, set_streams
+from dcu_megatron.core.pipeline_parallel.combined_1f1b import forward_backward_step, set_streams, wrap_forward_func
+from dcu_megatron.core.parallel_state import set_dualpipe_chunk
 # from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
@@ -35,23 +36,6 @@ Shape = Union[List[int], torch.Size]
 LOSS_BACKWARD_SCALE = torch.tensor(1.0)
 
-_DUALPIPE_CHUNK = None
-
-
-def set_dualpipe_chunk(chunk_id):
-    """set_dualpipe_chunk for fp16forward patch"""
-    global _DUALPIPE_CHUNK
-    _DUALPIPE_CHUNK = chunk_id
-
-
-def get_dualpipe_chunk():
-    global _DUALPIPE_CHUNK
-    if _DUALPIPE_CHUNK is not None:
-        return _DUALPIPE_CHUNK
-    else:
-        raise AssertionError("_DUALPIPE_CHUNK is None")
-
 
 def is_dualpipev_last_stage(model_chunk_id):
     return parallel_state.is_pipeline_first_stage(ignore_virtual=True) and model_chunk_id == 1
@@ -530,6 +514,13 @@ forward_backward_pipelining_with_cutinhalf(
     config = get_model_config(model[0])
     config.batch_p2p_comm = False
 
+    if (not forward_only and config.combined_1f1b):
+        set_streams()
+        forward_step_func = wrap_forward_func(config, forward_step_func)
+
     # Needed only when gradients are finalized in M-Core
     if config.finalize_model_grads_func is not None and not forward_only:
         embedding_module = clear_embedding_activation_buffer(config, model)
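When combined 1F1B is enabled for a training run, the schedule switches its CUDA streams and wraps the forward step function so that MoE all-to-all (a2a) traffic of one microbatch can overlap with computation of another. The real mechanics live in combined_1f1b.py (set_streams, wrap_forward_func, forward_backward_step); the snippet below is only a generic illustration of the side-stream pattern such an overlap relies on, with made-up tensors and a plain device copy standing in for the all-to-all.

import torch


def overlapped_step(compute_input, a2a_input, comm_stream):
    """Generic sketch: run a communication-like op on a side stream while
    the default stream keeps computing, then synchronize before use."""
    # The side stream must wait until a2a_input is ready on the default stream.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        a2a_output = a2a_input.clone()  # stand-in for the all-to-all dispatch

    # Meanwhile the default stream does useful compute.
    compute_output = compute_input @ compute_input

    # The default stream waits for the side stream before consuming a2a_output.
    torch.cuda.current_stream().wait_stream(comm_stream)
    return compute_output, a2a_output


if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    y = torch.randn(1024, 1024, device="cuda")
    out, dispatched = overlapped_step(x, y, side_stream)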
@@ -582,86 +573,6 @@ forward_backward_pipelining_with_cutinhalf(
         disable_grad_sync()
 
-    def combined_forward_backward_helper(
-        fwd_model_chunk_id,
-        bwd_model_chunk_id,
-        fwd_input_tensor=None,
-        bwd_output_tensor_grad=None,
-        pre_forward=None,
-        pre_backward=None,
-        post_forward=None,
-        post_backward=None,
-    ):
-        """Helper method to run combined forward and backward step"""
-        # forward prepare
-        fwd_microbatch_id = cur_fwd_chunk_microbatch[fwd_model_chunk_id]
-        f_context = contextlib.nullcontext()
-        set_dualpipe_chunk(fwd_model_chunk_id)
-
-        # backward prepare
-        b_context = contextlib.nullcontext()
-        bwd_input_tensor = input_tensors[bwd_model_chunk_id].pop(0)[1]
-        bwd_output_tensor = output_tensors[bwd_model_chunk_id].pop(0)
-
-        output_tensor, num_tokens, input_tensor_grad = forward_backward_step(
-            forward_step_func,
-            data_iterator[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
-            model[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
-            num_microbatches,
-            fwd_input_tensor,
-            forward_data_store,
-            model[bwd_model_chunk_id] if bwd_model_chunk_id is not None else None,
-            bwd_input_tensor,
-            bwd_output_tensor,
-            bwd_output_tensor_grad,
-            config,
-            f_context=f_context,
-            b_context=b_context,
-            collect_non_loss_data=collect_non_loss_data,
-            checkpoint_activations_microbatch=None,
-            is_first_microbatch=False,
-            current_microbatch=fwd_microbatch_id,
-        )
-
-        # forward post process
-        if fwd_model_chunk_id is not None:
-            with f_context:
-                nonlocal total_num_tokens
-                total_num_tokens += num_tokens.item()
-
-                if not forward_only:
-                    input_tensors[fwd_model_chunk_id].append((fwd_microbatch_id, fwd_input_tensor))
-                    output_tensors[fwd_model_chunk_id].append(output_tensor)
-
-        # backward post process
-        if b_model_chunk_id:
-            with b_context:
-                # launch grad synchronization (custom grad sync)
-                # Note: Asynchronous communication tends to slow down compute.
-                # To reduce idling from mismatched microbatch times, we launch
-                # asynchronous communication at the same time across the
-                # pipeline-parallel group.
-                if config.grad_sync_func is not None:
-                    grad_sync_virtual_microbatch_id = (b_virtual_microbatch_id - pipeline_parallel_rank)
-                    if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(grad_sync_virtual_microbatch_id):
-                        grad_sync_chunk_id = get_model_chunk_id(grad_sync_virtual_microbatch_id, forward=False)
-                        enable_grad_sync()
-                        config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
-                        synchronized_model_chunks.add(grad_sync_chunk_id)
-                    disable_grad_sync()
-
-        if input_tensor is not None:
-            assert input_tensor_grad is not None
-        return output_tensor, input_tensor_grad
-
     # Compute number of steps for each stage
     pp_size = parallel_state.get_pipeline_model_parallel_world_size()
     rank = parallel_state.get_pipeline_model_parallel_rank()
@@ -686,8 +597,6 @@ forward_backward_pipelining_with_cutinhalf(
     cur_bwd_chunk_microbatch = [0, num_microbatches]
     num_chunk_max_microbatch = [num_microbatches, num_microbatches * 2]
 
-    checkpoint_activations_microbatch = None
-
     def wait_comm_handles(comm_handles):
         if comm_handles is None:
             return
@@ -697,13 +606,19 @@ forward_backward_pipelining_with_cutinhalf(
                 req_handle.wait()
             comm_handles = None
 
-    def forward_step_helper(model_chunk_id, cur_microbatch, is_first_microbatch=False):
+    def forward_step_helper(model_chunk_id, cur_microbatch, checkpoint_activations_microbatch=False):
         set_dualpipe_chunk(model_chunk_id)
         if not forward_only:
             offset = cur_bwd_chunk_microbatch[model_chunk_id]
             input_tensor = input_tensors[model_chunk_id][cur_microbatch - offset]
         else:
             input_tensor = input_tensors[model_chunk_id][0]
+
+        is_first_microbatch = check_first_val_step(
+            first_val_step,
+            forward_only,
+            cur_fwd_chunk_microbatch[model_chunk_id],
+        )
         output_tensor, num_tokens = forward_step_no_model_graph(
             forward_step_func,
             model_chunk_id,
@@ -759,12 +674,154 @@ forward_backward_pipelining_with_cutinhalf(
         return input_tensor_grad
 
+    def combined_forward_backward_helper(
+        fwd_model_chunk_id=None,
+        bwd_model_chunk_id=None,
+        pre_forward=None,
+        pre_backward=None,
+        post_forward=None,
+        post_backward=None,
+    ):
+        """Helper method to run combined forward and backward step"""
+        # forward prepare
+        f_context = contextlib.nullcontext()
+        fwd_input_tensor = None
+        fwd_microbatch_id = None
+        if fwd_model_chunk_id is not None:
+            fwd_microbatch_id = cur_fwd_chunk_microbatch[fwd_model_chunk_id]
+            set_dualpipe_chunk(fwd_model_chunk_id)
+            offset = cur_bwd_chunk_microbatch[fwd_model_chunk_id]
+            fwd_input_tensor = input_tensors[fwd_model_chunk_id][fwd_microbatch_id - offset]
+
+        # backward prepare
+        b_context = contextlib.nullcontext()
+        bwd_input_tensor = None
+        bwd_output_tensor = None
+        bwd_output_tensor_grad = None
+        if bwd_model_chunk_id is not None:
+            # Enable async grad reduction in the last backward pass
+            # Note: If grad sync function is provided, only enable
+            # async grad reduction in first pipeline stage. Other
+            # pipeline stages do grad reduction during pipeline
+            # bubble.
+            bwd_microbatch_id = cur_bwd_chunk_microbatch[bwd_model_chunk_id]
+            if (bwd_microbatch_id is not None
+                    and bwd_microbatch_id == num_chunk_max_microbatch[bwd_model_chunk_id] - 1):
+                if (config.grad_sync_func is None
+                        or (bwd_model_chunk_id == slave_chunk_id and parallel_state.is_pipeline_last_stage())
+                        or (bwd_model_chunk_id == master_chunk_id and parallel_state.is_pipeline_first_stage())):
+                    enable_grad_sync()
+            bwd_input_tensor = input_tensors[bwd_model_chunk_id].pop(0)
+            bwd_output_tensor = output_tensors[bwd_model_chunk_id].pop(0)
+            bwd_output_tensor_grad = output_tensor_grads[bwd_model_chunk_id].pop(0)
+
+        output_tensor, num_tokens, input_tensor_grad = forward_backward_step(
+            forward_step_func,
+            data_iterator[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
+            model[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
+            num_microbatches,
+            fwd_input_tensor,
+            forward_data_store,
+            model[bwd_model_chunk_id] if bwd_model_chunk_id is not None else None,
+            bwd_input_tensor,
+            bwd_output_tensor,
+            bwd_output_tensor_grad,
+            config,
+            f_context=f_context,
+            b_context=b_context,
+            pre_forward=pre_forward,
+            pre_backward=pre_backward,
+            post_forward=post_forward,
+            post_backward=post_backward,
+            collect_non_loss_data=collect_non_loss_data,
+            checkpoint_activations_microbatch=None,
+            is_first_microbatch=False,
+            current_microbatch=fwd_microbatch_id,
+        )
+
+        # forward post process
+        if fwd_model_chunk_id is not None:
+            cur_fwd_chunk_microbatch[fwd_model_chunk_id] += 1
+            output_tensors[fwd_model_chunk_id].append(output_tensor)
+            nonlocal total_num_tokens
+            total_num_tokens += num_tokens.item()
+            if forward_only:
+                input_tensors[fwd_model_chunk_id].pop(0)
+                output_tensors[fwd_model_chunk_id].pop()
+
+        # backward post process
+        if bwd_model_chunk_id is not None:
+            cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
+
+        return output_tensor, input_tensor_grad
+
+    def forward_backward_helper_wrapper(
+        fwd_model_chunk_id=None,
+        bwd_model_chunk_id=None,
+        pre_forward=None,
+        pre_backward=None,
+        post_forward=None,
+        post_backward=None,
+        checkpoint_activations_microbatch=None,
+    ):
+        """
+        wrap forward_helper, backward_helper, combined_forward_backward_helper in a unified way
+        """
+        if config.combined_1f1b and config.combined_1f1b_recipe == "ep_a2a" and not forward_only:
+            assert (
+                checkpoint_activations_microbatch is None
+            ), "checkpoint_activations_microbatch not supported when combined_1f1b is true"
+            return combined_forward_backward_helper(
+                fwd_model_chunk_id=fwd_model_chunk_id,
+                bwd_model_chunk_id=bwd_model_chunk_id,
+                pre_forward=pre_forward,
+                pre_backward=pre_backward,
+                post_forward=post_forward,
+                post_backward=post_backward,
+            )
+        else:
+            output_tensor = None
+            input_tensor_grad = None
+            if fwd_model_chunk_id is not None:
+                # forward pass
+                if pre_forward is not None:
+                    pre_forward()
+                output_tensor = forward_step_helper(
+                    fwd_model_chunk_id,
+                    cur_fwd_chunk_microbatch[fwd_model_chunk_id],
+                    checkpoint_activations_microbatch
+                )
+                cur_fwd_chunk_microbatch[fwd_model_chunk_id] += 1
+                if post_forward is not None:
+                    output_tensor = post_forward(output_tensor)
+
+            if bwd_model_chunk_id is not None:
+                # Backward pass.
+                if pre_backward is not None:
+                    pre_backward()
+                input_tensor_grad = backward_step_helper(bwd_model_chunk_id, cur_bwd_chunk_microbatch[bwd_model_chunk_id])
+                cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
+                if post_backward is not None:
+                    input_tensor_grad = post_backward(input_tensor_grad)
+
+            return output_tensor, input_tensor_grad
+
+    output_tensor = None
     output_tensor_master_send = None
     output_tensor_slave_send = None
     fwd_wait_recv_handles = [None, None]
     fwd_wait_send_handles = [None, None]
     bwd_wait_recv_handles = [None, None]
     bwd_wait_send_handles = [None, None]
+    checkpoint_activations_microbatch = None
 
     # Run warmup forward passes
     input_tensor, _ = recv_forward(tensor_shape, config, master_chunk_id)
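forward_backward_helper_wrapper gives the rest of the schedule a single entry point for its three cases: forward only, backward only, or a fused forward-plus-backward step (taken when config.combined_1f1b with the "ep_a2a" recipe is enabled and the run is not forward-only). For orientation, these are the three call shapes that appear in the hunks below; they are excerpts of this diff, not standalone code.

# Warmup: forward one microbatch of a chunk, no backward yet.
output_tensor, _ = forward_backward_helper_wrapper(
    fwd_model_chunk_id=master_chunk_id,
    checkpoint_activations_microbatch=checkpoint_activations_microbatch,
)

# Cooldown: drain a pending backward of a chunk, no new forward.
_, input_tensor_grad = forward_backward_helper_wrapper(bwd_model_chunk_id=slave_chunk_id)

# Steady state: one forward and one backward per step, with the pipeline
# send/receive hooks passed in so communication overlaps with compute.
output_tensor, input_tensor_grad = forward_backward_helper_wrapper(
    fwd_model_chunk_id=fwd_model_chunk_id,
    bwd_model_chunk_id=None if forward_only else bwd_model_chunk_id,
    pre_forward=pp_pre_forward,
    pre_backward=pp_pre_backward,
    post_forward=pp_post_forward,
    post_backward=pp_post_backward,
    checkpoint_activations_microbatch=checkpoint_activations_microbatch,
)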
@@ -776,13 +833,10 @@ forward_backward_pipelining_with_cutinhalf(
         input_tensor, fwd_wait_recv_handles[master_chunk_id] = recv_forward(tensor_shape, config, master_chunk_id, async_op=True)
         input_tensors[master_chunk_id].append(input_tensor)
 
-        is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
-        output_tensor = forward_step_helper(
-            master_chunk_id,
-            cur_fwd_chunk_microbatch[master_chunk_id],
-            is_first_microbatch=is_first_microbatch
-        )
-        cur_fwd_chunk_microbatch[master_chunk_id] += 1
+        output_tensor, _ = forward_backward_helper_wrapper(
+            fwd_model_chunk_id=master_chunk_id,
+            checkpoint_activations_microbatch=checkpoint_activations_microbatch,
+        )
 
         if fwd_wait_send_handles[master_chunk_id] is not None:
             for req, req_handle in fwd_wait_send_handles[master_chunk_id].items():
@@ -804,14 +858,10 @@ forward_backward_pipelining_with_cutinhalf(
             input_tensor_slave, fwd_wait_recv_handles[slave_chunk_id] = recv_forward(tensor_shape, config, slave_chunk_id, async_op=True)
             input_tensors[slave_chunk_id].append(input_tensor_slave)
 
-        is_first_microbatch = parallel_state.is_pipeline_last_stage(ignore_virtual=True) and (i == 0)
-        is_first_microbatch = check_first_val_step(first_val_step, forward_only, is_first_microbatch)
-        output_tensor_master = forward_step_helper(
-            master_chunk_id,
-            cur_fwd_chunk_microbatch[master_chunk_id],
-            is_first_microbatch=is_first_microbatch
-        )
-        cur_fwd_chunk_microbatch[master_chunk_id] += 1
+        output_tensor, _ = forward_backward_helper_wrapper(
+            fwd_model_chunk_id=master_chunk_id,
+            checkpoint_activations_microbatch=checkpoint_activations_microbatch,
+        )
 
         if not parallel_state.is_pipeline_last_stage():
             wait_comm_handles(fwd_wait_send_handles[master_chunk_id])
@@ -819,20 +869,20 @@ forward_backward_pipelining_with_cutinhalf(
         if not forward_only:
             deallocate_output_tensor(output_tensor_master_send, config.deallocate_pipeline_outputs)
-        output_tensor_master_send = output_tensor_master
+        output_tensor_master_send = output_tensor
         fwd_wait_send_handles[master_chunk_id] = send_forward(output_tensor_master_send, tensor_shape, config, master_chunk_id, async_op=True)
 
         # prepare input for slave chunk
         if parallel_state.is_pipeline_last_stage():
             if not forward_only:
-                input_tensor_slave = output_tensor_master.detach()
+                input_tensor_slave = output_tensor.detach()
                 input_tensor_slave.requires_grad = True
             else:
-                input_tensor_slave = output_tensor_master
+                input_tensor_slave = output_tensor
             input_tensors[slave_chunk_id].append(input_tensor_slave)
             if not forward_only:
-                deallocate_output_tensor(output_tensor_master, config.deallocate_pipeline_outputs)
+                deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
         else:
             wait_comm_handles(fwd_wait_recv_handles[slave_chunk_id])
@@ -841,19 +891,16 @@ forward_backward_pipelining_with_cutinhalf(
             input_tensors[master_chunk_id].append(input_tensor)
 
         # slave forward
-        is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
-        output_tensor_slave = forward_step_helper(
-            slave_chunk_id,
-            cur_fwd_chunk_microbatch[slave_chunk_id],
-            is_first_microbatch=is_first_microbatch
-        )
-        cur_fwd_chunk_microbatch[slave_chunk_id] += 1
+        output_tensor, _ = forward_backward_helper_wrapper(
+            fwd_model_chunk_id=slave_chunk_id,
+            checkpoint_activations_microbatch=checkpoint_activations_microbatch,
+        )
 
         wait_comm_handles(fwd_wait_send_handles[slave_chunk_id])
         if not forward_only:
             deallocate_output_tensor(output_tensor_slave_send, config.deallocate_pipeline_outputs)
-        output_tensor_slave_send = output_tensor_slave
+        output_tensor_slave_send = output_tensor
         fwd_wait_send_handles[slave_chunk_id] = send_forward(output_tensor_slave_send, tensor_shape, config, slave_chunk_id, async_op=True)
 
         # check whether data transmission is completed.
@@ -884,8 +931,7 @@ forward_backward_pipelining_with_cutinhalf(
             input_tensors[slave_chunk_id].append(input_tensor_slave)
 
         if not forward_only:
-            input_tensor_grad = backward_step_helper(slave_chunk_id)
-            cur_bwd_chunk_microbatch[slave_chunk_id] += 1
+            _, input_tensor_grad = forward_backward_helper_wrapper(bwd_model_chunk_id=slave_chunk_id)
             # If asynchronous, the memory will rise.
             bwd_wait_send_handles[slave_chunk_id] = send_backward(input_tensor_grad, tensor_shape, config, slave_chunk_id)
@@ -905,19 +951,22 @@ forward_backward_pipelining_with_cutinhalf(
                 handle.wait()
             fwd_wait_recv_handles[slave_chunk_id] = None
 
-        output_tensor_slave = forward_step_helper(
-            slave_chunk_id,
-            cur_fwd_chunk_microbatch[slave_chunk_id],
-            is_first_microbatch=False
-        )
-        cur_fwd_chunk_microbatch[slave_chunk_id] += 1
+        output_tensor, _ = forward_backward_helper_wrapper(
+            fwd_model_chunk_id=slave_chunk_id,
+            checkpoint_activations_microbatch=checkpoint_activations_microbatch,
+        )
 
         # check whether backward data transmission is completed.
        wait_comm_handles(bwd_wait_send_handles[slave_chunk_id])
 
-        output_tensor_slave_send = output_tensor_slave
+        output_tensor_slave_send = output_tensor
         fwd_wait_send_handles[slave_chunk_id] = send_forward(output_tensor_slave_send, tensor_shape, config, slave_chunk_id, async_op=True)
+        # check whether forward data transmission is completed.
+        wait_comm_handles(fwd_wait_send_handles[slave_chunk_id])
+        if not forward_only:
+            deallocate_output_tensor(output_tensor_slave_send, config.deallocate_pipeline_outputs)
 
     # Run overlaping f&bw stages
     fwd_wait_send_recv_handles = None
     bwd_wait_send_recv_handles = None
@@ -938,6 +987,9 @@ forward_backward_pipelining_with_cutinhalf(
             # wait input for current step
             wait_comm_handles(fwd_wait_recv_handles[fwd_model_chunk_id])
 
+            if not forward_only:
+                deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+
             def pp_post_forward(output_tensor):
                 nonlocal cur_fwd_chunk_microbatch
                 nonlocal num_chunk_max_microbatch
@@ -973,6 +1025,7 @@ forward_backward_pipelining_with_cutinhalf(
                 wait_comm_handles(bwd_wait_send_recv_handles)
 
             def pp_post_backward(input_tensor_grad):
+                nonlocal output_tensor_grads
                 nonlocal fwd_wait_send_handles
                 nonlocal fwd_wait_send_recv_handles
                 nonlocal bwd_wait_send_recv_handles
@@ -981,9 +1034,6 @@ forward_backward_pipelining_with_cutinhalf(
                 wait_comm_handles(fwd_wait_send_handles[fwd_model_chunk_id])
                 wait_comm_handles(fwd_wait_send_recv_handles)
 
-                if not forward_only:
-                    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-
                 if not forward_only:
                     if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                         output_tensor_grad = input_tensor_grad
@@ -999,31 +1049,21 @@ forward_backward_pipelining_with_cutinhalf(
                     return input_tensor_grad
 
-            # forward
-            pp_pre_forward()
-            output_tensor = forward_step_helper(
-                fwd_model_chunk_id,
-                cur_fwd_chunk_microbatch[fwd_model_chunk_id],
-                is_first_microbatch=False
-            )
-            cur_fwd_chunk_microbatch[fwd_model_chunk_id] += 1
-            output_tensor = pp_post_forward(output_tensor)
-
-            # backward
-            pp_pre_backward()
-            if not forward_only:
-                try:
-                    input_tensor_grad = backward_step_helper(bwd_model_chunk_id)
-                except Exception as e:
-                    print(f"step_id: {step_id}, rank: {torch.distributed.get_rank()}, bwd_model_chunk_id: {bwd_model_chunk_id}", flush=True)
-                    raise Exception(f"{e}")
-                cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
-            else:
-                input_tensor_grad = None
-            _ = pp_post_backward(input_tensor_grad)
+            output_tensor, input_tensor_grad = forward_backward_helper_wrapper(
+                fwd_model_chunk_id=fwd_model_chunk_id,
+                bwd_model_chunk_id=None if forward_only else bwd_model_chunk_id,
+                pre_forward=pp_pre_forward,
+                pre_backward=pp_pre_backward,
+                post_forward=pp_post_forward,
+                post_backward=pp_post_backward,
+                checkpoint_activations_microbatch=checkpoint_activations_microbatch,
+            )
 
         # only run backward
         else:
+            if not forward_only:
+                deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
             if bwd_model_chunk_id == slave_chunk_id and cur_fwd_chunk_microbatch[slave_chunk_id] < num_chunk_max_microbatch[slave_chunk_id]:
                 input_tensor, fwd_wait_recv_handles[slave_chunk_id] = recv_forward(tensor_shape, config, slave_chunk_id, async_op=True)
                 input_tensors[slave_chunk_id].append(input_tensor)
@@ -1031,11 +1071,9 @@ forward_backward_pipelining_with_cutinhalf(
             wait_comm_handles(bwd_wait_send_handles[1 - bwd_model_chunk_id])
             wait_comm_handles(bwd_wait_send_recv_handles)
 
-            input_tensor_grad = backward_step_helper(
-                bwd_model_chunk_id,
-                bwd_cur_microbatch=cur_bwd_chunk_microbatch[bwd_model_chunk_id]
-            )
-            cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
+            _, input_tensor_grad = forward_backward_helper_wrapper(
+                bwd_model_chunk_id=bwd_model_chunk_id,
+            )
 
             if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                 output_tensor_grad = input_tensor_grad