evt_fugx1 / dcu_megatron / Commits

Commit 770fa304
Authored Apr 25, 2025 by dongcl

    Modify MTP (修改mtp)

Parent: 8096abd4
Changes: 44

Showing 20 changed files with 2866 additions and 502 deletions (+2866 / -502)
dcu_megatron/adaptor/adaptor_arguments.py                                      +23     -0
dcu_megatron/adaptor/feature_manager/__init__.py                                +7     -0
dcu_megatron/adaptor/feature_manager/base_feature.py                           +40     -0
dcu_megatron/adaptor/feature_manager/mtp_feature.py                            +51     -0
dcu_megatron/adaptor/feature_manager/pipeline_parallel/dualpipev_feature.py    +52     -0
dcu_megatron/adaptor/megatron_adaptor.py                                       +34    -26
dcu_megatron/adaptor/patch_utils.py                                            +21     -3
dcu_megatron/core/distributed/finalize_model_grads.py                           +6    -13
dcu_megatron/core/models/common/embeddings/language_model_embedding.py          +0   -133
dcu_megatron/core/models/common/language_module/language_module.py            +137     -0
dcu_megatron/core/models/gpt/gpt_layer_specs.py                                +51     -2
dcu_megatron/core/models/gpt/gpt_model.py                                     +378   -325
dcu_megatron/core/pipeline_parallel/dualpipev/__init__.py                       +0     -0
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_chunks.py             +236     -0
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py         +1498     -0
dcu_megatron/core/pipeline_parallel/fb_overlap/__init__.py                      +5     -0
dcu_megatron/core/pipeline_parallel/fb_overlap/adaptor.py                      +33     -0
dcu_megatron/core/pipeline_parallel/fb_overlap/gpt_model.py                   +209     -0
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/__init__.py              +0     -0
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/attention.py            +85     -0
dcu_megatron/adaptor/adaptor_arguments.py (new file, 0 → 100644)

import argparse

from .feature_manager import FEATURES_LIST

_ARGS = None


def process_args(parser):
    parser.conflict_handler = 'resolve'
    for feature in FEATURES_LIST:
        feature.register_args(parser)
    return parser


def get_adaptor_args():
    global _ARGS
    if _ARGS is None:
        parser = argparse.ArgumentParser(description='Adaptor Arguments', allow_abbrev=False)
        _ARGS, _ = process_args(parser).parse_known_args()
    return _ARGS
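Illustration only, not part of the commit: a minimal sketch of how this parser behaves. Because each feature registers its own flags and the module calls parse_known_args(), unrecognized Megatron flags are ignored rather than raising, and the parsed namespace is cached in _ARGS. The sys.argv contents below are made up for the example.

    import sys

    from dcu_megatron.adaptor.adaptor_arguments import get_adaptor_args

    sys.argv = ["pretrain_gpt.py", "--schedules-method", "dualpipev", "--num-layers", "64"]
    args = get_adaptor_args()           # picks up --schedules-method, silently ignores --num-layers
    print(args.schedules_method)        # -> "dualpipev"
    print(get_adaptor_args() is args)   # -> True: the namespace is parsed once and cached in _ARGS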
dcu_megatron/adaptor/feature_manager/__init__.py (new file, 0 → 100644)

from .pipeline_parallel.dualpipev_feature import DualpipeVFeature

FEATURES_LIST = [
    # Pipeline Parallel features
    DualpipeVFeature()
]
dcu_megatron/adaptor/feature_manager/base_feature.py (new file, 0 → 100644)

# modified from mindspeed
import argparse


class BaseFeature:
    def __init__(self, feature_name: str, optimization_level: int = 2):
        self.feature_name = feature_name.strip().replace('-', '_')
        self.optimization_level = optimization_level
        self.default_patches = self.optimization_level == 0

    def register_args(self, parser):
        pass

    def pre_validate_args(self, args):
        pass

    def validate_args(self, args):
        pass

    def post_validate_args(self, args):
        pass

    def register_patches(self, patch_manager, args):
        ...

    def incompatible_check(self, global_args, check_args):
        if getattr(global_args, self.feature_name, None) and getattr(global_args, check_args, None):
            raise AssertionError('{} and {} are incompatible.'.format(self.feature_name, check_args))

    def dependency_check(self, global_args, check_args):
        if getattr(global_args, self.feature_name, None) and not getattr(global_args, check_args, None):
            raise AssertionError('{} requires {}.'.format(self.feature_name, check_args))

    @staticmethod
    def add_parser_argument_choices_value(parser, argument_name, new_choice):
        for action in parser._actions:
            exist_arg = isinstance(action, argparse.Action) and argument_name in action.option_strings
            if exist_arg and action.choices is not None and new_choice not in action.choices:
                action.choices.append(new_choice)
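Illustration only, not part of the commit: a sketch of how a new feature would plug into this base class. The class name, flag, and check below are hypothetical; only BaseFeature and its methods come from the file above.

    from argparse import ArgumentParser

    from dcu_megatron.adaptor.feature_manager.base_feature import BaseFeature


    class ExampleFusedNormFeature(BaseFeature):
        def __init__(self):
            # 'example-fused-norm' is stored as 'example_fused_norm', which is also the
            # args attribute name used by incompatible_check / dependency_check.
            super().__init__('example-fused-norm', optimization_level=2)

        def register_args(self, parser: ArgumentParser):
            group = parser.add_argument_group(title=self.feature_name)
            group.add_argument('--example-fused-norm', action='store_true', default=False)

        def validate_args(self, args):
            # hypothetical constraint: cannot be combined with --schedules-method
            self.incompatible_check(args, 'schedules_method')

        def register_patches(self, patch_manager, args):
            if getattr(args, self.feature_name, False):
                # patch_manager is expected to expose register_patch(), as MegatronPatchesManager does
                pass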
dcu_megatron/adaptor/feature_manager/mtp_feature.py (new file, 0 → 100644)

from argparse import ArgumentParser

from ..base_feature import BaseFeature


class MTPFeature(BaseFeature):
    def __init__(self):
        super().__init__('schedules-method')

    def register_args(self, parser: ArgumentParser):
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--schedules-method', type=str, default=None, choices=['dualpipev'])

    def register_patches(self, patch_manager, args):
        from ...core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
        from ...core.models.common.language_module.language_module import (
            setup_embeddings_and_output_layer,
            tie_embeddings_and_output_weights_state_dict,
        )
        from ...core.models.gpt.gpt_model import GPTModel
        from ...training.utils import get_batch_on_this_tp_rank
        from ...core.pipeline_parallel.schedules import forward_step_wrapper
        from ...core import transformer_block_init_wrapper

        MegatronAdaptation.register(
            'megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
            _allreduce_word_embedding_grads)

        # LanguageModule
        MegatronAdaptation.register(
            'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
            setup_embeddings_and_output_layer)
        MegatronAdaptation.register(
            'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
            tie_embeddings_and_output_weights_state_dict)

        MegatronAdaptation.register(
            'megatron.training.utils.get_batch_on_this_tp_rank',
            get_batch_on_this_tp_rank)

        # GPT Model
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)

        # Transformer block
        MegatronAdaptation.register(
            'megatron.core.transformer.transformer_block.TransformerBlock.__init__',
            transformer_block_init_wrapper)

        # pipeline_parallel.schedules.forward_step
        MegatronAdaptation.register(
            'megatron.core.pipeline_parallel.schedules.forward_step',
            forward_step_wrapper,
            apply_wrapper=True)
dcu_megatron/adaptor/feature_manager/pipeline_parallel/dualpipev_feature.py (new file, 0 → 100644)

# Modified from mindspeed.
from argparse import ArgumentParser

from ..base_feature import BaseFeature


class DualpipeVFeature(BaseFeature):
    def __init__(self):
        super().__init__('schedules-method')

    def register_args(self, parser: ArgumentParser):
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--schedules-method', type=str, default=None, choices=['dualpipev'])

    def validate_args(self, args):
        if args.schedules_method == "dualpipev":
            if args.num_layers_per_virtual_pipeline_stage is not None:
                raise AssertionError("The dualpipev and virtual_pipeline are incompatible.")
            if args.num_layers < args.pipeline_model_parallel_size * 2:
                raise AssertionError(
                    'number of layers must be at least 2*pipeline_model_parallel_size in dualpipe')
            num_micro_batch = args.global_batch_size // args.micro_batch_size // args.data_parallel_size
            if num_micro_batch < args.pipeline_model_parallel_size * 2 - 1:
                raise AssertionError(
                    "num_micro_batch should more than pipeline_model_parallel_size * 2 - 1")

    def register_patches(self, patch_manager, args):
        from megatron.training.utils import print_rank_0
        from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import (
            forward_backward_pipelining_with_cutinhalf
        )
        from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_chunks import (
            get_model,
            dualpipev_fp16forward,
            get_num_layers_to_build,
            train_step,
            _allreduce_embedding_grads_wrapper
        )

        if args.schedules_method == "dualpipev":
            patch_manager.register_patch('megatron.training.training.get_model', get_model)
            patch_manager.register_patch('megatron.training.training.train_step', train_step)
            patch_manager.register_patch(
                'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving',
                forward_backward_pipelining_with_cutinhalf)
            patch_manager.register_patch(
                'megatron.legacy.model.module.Float16Module.forward', dualpipev_fp16forward)
            patch_manager.register_patch(
                'megatron.core.transformer.transformer_block.get_num_layers_to_build',
                get_num_layers_to_build)
            patch_manager.register_patch('megatron.training.utils.print_rank_last', print_rank_0)
            patch_manager.register_patch(
                'megatron.core.distributed.finalize_model_grads._allreduce_embedding_grads',
                _allreduce_embedding_grads_wrapper)
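Illustration only, not part of the commit: the micro-batch constraint enforced by validate_args above, worked through for one made-up configuration.

    global_batch_size = 512
    micro_batch_size = 1
    data_parallel_size = 8
    pipeline_model_parallel_size = 16

    num_micro_batch = global_batch_size // micro_batch_size // data_parallel_size   # 64
    required = pipeline_model_parallel_size * 2 - 1                                  # 31
    assert num_micro_batch >= required   # 64 >= 31, so this configuration passes the dualpipev check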
dcu_megatron/adaptor/megatron_adaptor.py (modified)

@@ -5,6 +5,8 @@ import types
 import argparse
 import torch

+from .adaptor_arguments import get_adaptor_args
+

 class MegatronAdaptation:
     """

@@ -21,6 +23,15 @@ class MegatronAdaptation:
         for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
             adaptation.execute()
         MegatronAdaptation.apply()

+        from .patch_utils import MegatronPatchesManager
+        args = get_adaptor_args()
+        for feature in FEATURES_LIST:
+            if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) \
+                    or feature.optimization_level == 0:
+                feature.register_patches(MegatronPatchesManager, args)
+        MindSpeedPatchesManager.apply_patches()
+        # MegatronAdaptation.post_execute()
+
     @classmethod

@@ -87,47 +98,37 @@ class CoreAdaptation(MegatronAdaptationABC):
         self.patch_miscellaneous()

     def patch_core_distributed(self):
-        # Mtp share embedding
+        # mtp share embedding
         from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
         MegatronAdaptation.register(
             'megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
             _allreduce_word_embedding_grads)

     def patch_core_models(self):
-        from ..core.models.common.embeddings.language_model_embedding import (
-            language_model_embedding_forward,
-            language_model_embedding_init_func
-        )
-        from ..core.models.gpt.gpt_model import (
-            gpt_model_forward,
-            gpt_model_init_wrapper,
-            shared_embedding_or_mtp_embedding_weight)
+        from ..core.models.common.language_module.language_module import (
+            setup_embeddings_and_output_layer,
+            tie_embeddings_and_output_weights_state_dict,
+        )
+        from ..core.models.gpt.gpt_model import GPTModel
         from ..training.utils import get_batch_on_this_tp_rank

-        # Embedding
+        # LanguageModule
         MegatronAdaptation.register(
-            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
-            language_model_embedding_init_func)
+            'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
+            setup_embeddings_and_output_layer)
         MegatronAdaptation.register(
-            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
-            language_model_embedding_forward)
+            'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
+            tie_embeddings_and_output_weights_state_dict)
         MegatronAdaptation.register(
             'megatron.training.utils.get_batch_on_this_tp_rank',
             get_batch_on_this_tp_rank)

         # GPT Model
-        MegatronAdaptation.register(
-            'megatron.core.models.gpt.gpt_model.GPTModel.forward',
-            gpt_model_forward)
-        MegatronAdaptation.register(
-            'megatron.core.models.gpt.gpt_model.GPTModel.__init__',
-            gpt_model_init_wrapper,
-            apply_wrapper=True)
-
-        from megatron.core.models.gpt.gpt_model import GPTModel
-        setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
         from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

-        # Transformer block
+        # Transformer block. If mtp_num_layers > 0, move final_layernorm outside
         MegatronAdaptation.register(
             'megatron.core.transformer.transformer_block.TransformerBlock.__init__',
             transformer_block_init_wrapper)

@@ -165,13 +166,11 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
         from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init

         # VocabParallelEmbedding
         MegatronAdaptation.register(
             'megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
             vocab_parallel_embedding_forward)
         MegatronAdaptation.register(
             'megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
-            vocab_parallel_embedding_init,
-            torch.compile(mode='max-autotune-no-cudagraphs'),
-            apply_wrapper=True)
+            vocab_parallel_embedding_init)

         # VocabParallelCrossEntropy
         MegatronAdaptation.register(
             'megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',

@@ -201,6 +200,14 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register(
             "megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
             get_gpt_layer_with_flux_spec)

+    def patch_pipeline_parallel(self):
+        from ..core.pipeline_parallel.schedules import forward_step_wrapper
+
+        # pipeline_parallel.schedules.forward_step
+        MegatronAdaptation.register(
+            'megatron.core.pipeline_parallel.schedules.forward_step',
+            forward_step_wrapper,
+            apply_wrapper=True)
+
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed

@@ -245,6 +252,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
             parallel_mlp_init_wrapper,
             apply_wrapper=True)

+        # ParallelAttention
         MegatronAdaptation.register(
             'megatron.legacy.model.transformer.ParallelAttention.__init__',
             parallel_attention_init_wrapper,
             apply_wrapper=True)
dcu_megatron/adaptor/patch_utils.py (modified)

@@ -148,11 +148,29 @@ class MegatronPatchesManager:
     patches_info = {}

     @staticmethod
-    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
+    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False,
+                       apply_wrapper=False, remove_origin_wrappers=False):
         if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
-            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
-                orig_func_or_cls_name, new_func_or_cls, create_dummy)
+            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
+                orig_func_or_cls_name, new_func_or_cls, create_dummy,
+                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)
         else:
-            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
+            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
+                new_func_or_cls, force_patch,
+                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)

     @staticmethod
     def apply_patches():
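Illustration only, not part of the commit: a sketch of how a caller would use the new apply_wrapper flag, mirroring the forward_step_wrapper registration added in megatron_adaptor.py. The logging wrapper below is hypothetical; only register_patch/apply_patches come from the file above.

    from functools import wraps

    from dcu_megatron.adaptor.patch_utils import MegatronPatchesManager


    def forward_step_logging_wrapper(forward_step):
        @wraps(forward_step)
        def wrapper(*args, **kwargs):
            # With apply_wrapper=True, the registered object is treated as a wrapper
            # around the original function rather than a plain replacement.
            return forward_step(*args, **kwargs)
        return wrapper


    MegatronPatchesManager.register_patch(
        'megatron.core.pipeline_parallel.schedules.forward_step',
        forward_step_logging_wrapper,
        apply_wrapper=True)
    MegatronPatchesManager.apply_patches()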
dcu_megatron/core/distributed/finalize_model_grads.py (modified)

@@ -28,20 +28,13 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
         model_module = model[0]
         model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
-        if model_module.share_embeddings_and_output_weights:
-            weight = model_module.shared_embedding_or_output_weight()
-            grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
-            orig_grad = getattr(weight, grad_attr)
-            grad = _unshard_if_dtensor(orig_grad)
-            torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
-            setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
-        if (hasattr(model_module, "share_mtp_embedding_and_output_weight")
-                and model_module.share_mtp_embedding_and_output_weight
-                and config.num_nextn_predict_layers > 0):
-            weight = model_module.shared_embedding_or_mtp_embedding_weight()
+        # If share_embeddings_and_output_weights is True, we need to maintain duplicated
+        # embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
+        # we also need to maintain duplicated embedding weights in mtp process stage.
+        # So we need to allreduce grads of embedding in the embedding group in these cases.
+        if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
+            weight = model_module.shared_embedding_or_output_weight()
             grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
             orig_grad = getattr(weight, grad_attr)
             grad = _unshard_if_dtensor(orig_grad)
dcu_megatron/core/models/common/embeddings/language_model_embedding.py (deleted, 100644 → 0; last content at parent 8096abd4)

from typing import Literal

import torch
from torch import Tensor

from megatron.core import tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding


def language_model_embedding_init_func(
    self,
    config: TransformerConfig,
    vocab_size: int,
    max_sequence_length: int,
    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
    num_tokentypes: int = 0,
    scatter_to_sequence_parallel: bool = True,
    skip_weight_param_allocation: bool = False
):
    """Patch language model embeddings init."""
    super(LanguageModelEmbedding, self).__init__(config=config)

    self.config: TransformerConfig = config
    self.vocab_size: int = vocab_size
    self.max_sequence_length: int = max_sequence_length
    self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
    self.num_tokentypes = num_tokentypes
    self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
    self.reduce_scatter_embeddings = (
        (not self.add_position_embedding)
        and self.num_tokentypes <= 0
        and self.config.sequence_parallel
        and self.scatter_to_sequence_parallel
    )

    # Word embeddings (parallel).
    self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
        num_embeddings=self.vocab_size,
        embedding_dim=self.config.hidden_size,
        init_method=self.config.init_method,
        reduce_scatter_embeddings=self.reduce_scatter_embeddings,
        config=self.config,
        skip_weight_param_allocation=skip_weight_param_allocation
    )

    # Position embedding (serial).
    if self.add_position_embedding:
        self.position_embeddings = torch.nn.Embedding(self.max_sequence_length, self.config.hidden_size)

        # Initialize the position embeddings.
        if self.config.perform_initialization:
            self.config.init_method(self.position_embeddings.weight)

    if self.num_tokentypes > 0:
        self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.config.hidden_size)
        # Initialize the token-type embeddings.
        if self.config.perform_initialization:
            self.config.init_method(self.tokentype_embeddings.weight)
    else:
        self.tokentype_embeddings = None

    # Embeddings dropout
    self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)


def language_model_embedding_forward(self, input_ids: Tensor, position_ids: Tensor,
                                     tokentype_ids: int = None, weight: Tensor = None) -> Tensor:
    """Patch forward pass of the embedding module.

    Args:
        input_ids (Tensor): The input tokens
        position_ids (Tensor): The position id's used to calculate position embeddings
        tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
            set to True. Defaults to None
        weight (Tensor): embedding weight

    Returns:
        Tensor: The output embeddings
    """
    if weight is None:
        if self.word_embeddings.weight is None:
            raise RuntimeError(
                "weight was not supplied to VocabParallelEmbedding forward pass "
                "and skip_weight_param_allocation is True."
            )
        weight = self.word_embeddings.weight

    word_embeddings = self.word_embeddings(input_ids, weight)
    if self.add_position_embedding:
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = word_embeddings + position_embeddings
    else:
        embeddings = word_embeddings

    if not self.reduce_scatter_embeddings:
        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

    if tokentype_ids is not None:
        assert self.tokentype_embeddings is not None
        # [b s h] -> [s b h] (So that it can be added with embeddings)
        tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
        embeddings = embeddings + tokentype_embedding
    else:
        assert self.tokentype_embeddings is None

    # If the input flag for fp32 residual connection is set, convert for float.
    if self.config.fp32_residual_connection:
        embeddings = embeddings.float()

    # Dropout.
    if self.config.sequence_parallel:
        if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
        # `scatter_to_sequence_parallel_region` returns a view, which prevents
        # the original tensor from being garbage collected. Clone to facilitate GC.
        # Has a small runtime cost (~0.5%).
        if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
            embeddings = embeddings.clone()
        with tensor_parallel.get_cuda_rng_tracker().fork():
            embeddings = self.embedding_dropout(embeddings)
    else:
        embeddings = self.embedding_dropout(embeddings)

    return embeddings
dcu_megatron/core/models/common/language_module/language_module.py (new file, 0 → 100644)

import logging

import torch

from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint


def setup_embeddings_and_output_layer(self) -> None:
    """Sets up embedding layer in first stage and output layer in last stage.

    This function initializes word embeddings in the final stage when we are
    using pipeline parallelism and sharing word embeddings, and sets up param
    attributes on the embedding and output layers.
    """
    # Set `is_embedding_or_output_parameter` attribute.
    if self.pre_process:
        self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
    if self.post_process and self.output_layer.weight is not None:
        self.output_layer.weight.is_embedding_or_output_parameter = True

    # If share_embeddings_and_output_weights is True, we need to maintain duplicated
    # embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
    # we also need to maintain duplicated embedding weights in mtp process stage.
    # So we need to copy embedding weights from pre processing stage as initial parameters
    # in these cases.
    if not self.share_embeddings_and_output_weights and not getattr(self.config, 'mtp_num_layers', 0):
        return

    if parallel_state.get_pipeline_model_parallel_world_size() == 1:
        # Zero out wgrad if sharing embeddings between two layers on same
        # pipeline stage to make sure grad accumulation into main_grad is
        # correct and does not include garbage values (e.g., from torch.empty).
        self.shared_embedding_or_output_weight().zero_out_wgrad = True
        return

    if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
        self.shared_embedding_or_output_weight().shared_embedding = True

    if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process:
        assert not parallel_state.is_pipeline_first_stage()
        # set weights of the duplicated embedding to 0 here,
        # then copy weights from pre processing stage using all_reduce below.
        weight = self.shared_embedding_or_output_weight()
        weight.data.fill_(0)
        weight.shared = True
        weight.shared_embedding = True

    # Parameters are shared between the word embeddings layers, and the
    # heads at the end of the model. In a pipelined setup with more than
    # one stage, the initial embedding layer and the head are on different
    # workers, so we do the following:
    # 1. Create a second copy of word_embeddings on the last stage, with
    #    initial parameters of 0.0.
    # 2. Do an all-reduce between the first and last stage to ensure that
    #    the two copies of word_embeddings start off with the same
    #    parameter values.
    # 3. In the training loop, before an all-reduce between the grads of
    #    the two word_embeddings layers to ensure that every applied weight
    #    update is the same on both stages.

    # Ensure that first and last stages have the same initial parameter
    # values.
    if torch.distributed.is_initialized():
        if parallel_state.is_rank_in_embedding_group():
            weight = self.shared_embedding_or_output_weight()
            weight.data = weight.data.cuda()
            torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group())

    elif not getattr(LanguageModule, "embedding_warning_printed", False):
        logging.getLogger(__name__).warning(
            "Distributed processes aren't initialized, so the output layer "
            "is not initialized with weights from the word embeddings. "
            "If you are just manipulating a model this is fine, but "
            "this needs to be handled manually. If you are training "
            "something is definitely wrong."
        )
        LanguageModule.embedding_warning_printed = True


def tie_embeddings_and_output_weights_state_dict(
    self,
    sharded_state_dict: ShardedStateDict,
    output_layer_weight_key: str,
    first_stage_word_emb_key: str,
) -> None:
    """Ties the embedding and output weights in a given sharded state dict.

    Args:
        sharded_state_dict (ShardedStateDict): state dict with the weight to tie
        output_layer_weight_key (str): key of the output layer weight in the state dict.
            This entry will be replaced with a tied version
        first_stage_word_emb_key (str): this must be the same as the
            ShardedTensor.key of the first stage word embeddings.

    Returns: None, acts in-place
    """
    if not self.post_process:
        # No output layer
        assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
        return

    if self.pre_process:
        # Output layer is equivalent to the embedding already
        return

    # If use Multi-Token Prediction (MTP), we need to maintain both embedding layer and output
    # layer in mtp process stage. In this case, if share_embeddings_and_output_weights is True,
    # the shared weights will be stored in embedding layer, and output layer will not have
    # any weight.
    if getattr(self, 'mtp_process', False):
        # No output layer
        assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
        return

    # Replace the default output layer with one sharing the weights with the embedding
    del sharded_state_dict[output_layer_weight_key]
    tensor = self.shared_embedding_or_output_weight()
    last_stage_word_emb_replica_id = (
        1,  # copy of first stage embedding
        0,
        parallel_state.get_data_parallel_rank(with_context_parallel=True),
    )

    sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
        tensor=tensor,
        key=first_stage_word_emb_key,
        replica_id=last_stage_word_emb_replica_id,
        allow_shape_mismatch=True,
    )
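Illustration only, not part of the commit: the weight-copy trick used in setup_embeddings_and_output_layer, reduced to plain torch.distributed. The stage holding the duplicated embedding fills it with zeros, then an all_reduce (SUM) over the embedding group leaves every member of that group holding the first stage's values; the later gradient all-reduce keeps the copies in sync during training. The helper name and arguments are made up for the sketch.

    import torch
    import torch.distributed as dist


    def sync_duplicated_embedding(weight: torch.Tensor, is_first_stage: bool, embedding_group) -> None:
        if not is_first_stage:
            weight.data.fill_(0)                             # duplicated copy starts at 0.0
        dist.all_reduce(weight.data, group=embedding_group)  # 0 + first_stage_weight = first_stage_weight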
dcu_megatron/core/models/gpt/gpt_layer_specs.py (modified)

@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
     MLASelfAttentionSubmodules,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import (
     TransformerLayer,
     TransformerLayerSubmodules,
 )
-from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
 from megatron.core.utils import is_te_min_version

 try:

@@ -36,6 +36,55 @@ try:
 except ImportError:
     warnings.warn('Apex is not installed.')

+from dcu_megatron.core.tensor_parallel.layers import (
+    FluxColumnParallelLinear,
+    FluxRowParallelLinear
+)
+from dcu_megatron.core.transformer.multi_token_prediction import (
+    MultiTokenPredictionBlockSubmodules,
+    get_mtp_layer_offset,
+    get_mtp_layer_spec,
+    get_mtp_num_layers_to_build,
+)
+
+
+def get_gpt_mtp_block_spec(
+    config: TransformerConfig,
+    spec: Union[TransformerBlockSubmodules, ModuleSpec],
+    use_transformer_engine: bool,
+) -> MultiTokenPredictionBlockSubmodules:
+    """GPT Multi-Token Prediction (MTP) block spec."""
+    num_layers_to_build = get_mtp_num_layers_to_build(config)
+    if num_layers_to_build == 0:
+        return None
+
+    if isinstance(spec, TransformerBlockSubmodules):
+        # get the spec for the last layer of decoder block
+        transformer_layer_spec = spec.layer_specs[-1]
+    elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
+        transformer_layer_spec = spec
+    else:
+        raise ValueError(f"Invalid spec: {spec}")
+
+    mtp_layer_spec = get_mtp_layer_spec(
+        transformer_layer_spec=transformer_layer_spec,
+        use_transformer_engine=use_transformer_engine)
+    mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
+    mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers
+
+    offset = get_mtp_layer_offset(config)
+    # split the mtp layer specs to only include the layers that are built in this pipeline stage.
+    mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build]
+    if len(mtp_layer_specs) > 0:
+        assert (
+            len(mtp_layer_specs) == config.mtp_num_layers
+        ), f"currently all of the mtp layers must stage in the same pipeline stage."
+        mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs)
+    else:
+        mtp_block_spec = None
+
+    return mtp_block_spec
+
+
 def get_gpt_layer_with_flux_spec(
     num_experts: Optional[int] = None,
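Illustration only, not part of the commit: the list slicing that get_gpt_mtp_block_spec performs, worked through for a made-up configuration with two MTP layers all placed on the MTP stage. On stages with no MTP layers, get_mtp_num_layers_to_build(config) would return 0 and the function returns None before reaching this point.

    mtp_num_layers = 2
    num_layers_to_build = 2    # assumed value of get_mtp_num_layers_to_build(config) on the MTP stage
    offset = 0                 # assumed value of get_mtp_layer_offset(config) on that stage

    mtp_layer_specs = ['mtp_layer_spec'] * mtp_num_layers
    mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build]
    assert len(mtp_layer_specs) == mtp_num_layers   # the assert in the diff: all MTP layers on one stage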
dcu_megatron/core/models/gpt/gpt_model.py
View file @
770fa304
import
os
import
logging
from
typing
import
Literal
,
Optional
from
functools
import
wraps
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from
collections
import
OrderedDict
from
typing
import
Dict
,
Literal
,
Optional
import
torch
from
torch
import
Tensor
from
megatron.core
import
InferenceParams
,
parallel_state
,
tensor_parallel
from
megatron.core
import
InferenceParams
,
tensor_parallel
from
megatron.core.config_logger
import
has_config_logger_enabled
,
log_config_to_disk
from
megatron.core.models.gpt.gpt_model
import
GPTModel
from
megatron.core.models.common.language_module.language_module
import
LanguageModule
from
megatron.core.dist_checkpointing.mapping
import
ShardedStateDict
from
megatron.core.models.common.embeddings.language_model_embedding
import
LanguageModelEmbedding
from
megatron.core.models.common.embeddings.rotary_pos_embedding
import
RotaryEmbedding
from
megatron.core.models.common.language_module.language_module
import
LanguageModule
from
megatron.core.packed_seq_params
import
PackedSeqParams
from
megatron.core.transformer.enums
import
ModelType
from
megatron.core.transformer.spec_utils
import
ModuleSpec
from
megatron.core.transformer.transformer_block
import
TransformerBlock
from
megatron.core.
extensions
.transformer_
engine
import
T
EColumnParallelLinear
from
megatron.core.
transformer
.transformer_
config
import
T
ransformerConfig
from
dcu_megatron.core.utils
import
tensor_slide
from
dcu_megatron.core.transformer.mtp.multi_token_predictor
import
MultiTokenPredictor
from
dcu_megatron.core.transformer.transformer_config
import
TransformerConfig
from
dcu_megatron.core.tensor_parallel
import
FluxColumnParallelLinear
from
dcu_megatron.core.transformer.multi_token_prediction
import
(
MultiTokenPredictionBlock
,
tie_output_layer_state_dict
,
tie_word_embeddings_state_dict
,
)
def
gpt_model_init_wrapper
(
fn
):
@
wraps
(
fn
)
def
wrapper
(
self
,
*
args
,
**
kwargs
):
fn
(
self
,
*
args
,
**
kwargs
)
class
GPTModel
(
LanguageModule
):
"""GPT Transformer language model.
if
(
self
.
post_process
and
int
(
os
.
getenv
(
"USE_FLUX_OVERLAP"
,
"0"
))
):
self
.
output_layer
=
FluxColumnParallelLinear
(
self
.
config
.
hidden_size
,
self
.
vocab_size
,
Args:
config (TransformerConfig):
Transformer config
transformer_layer_spec (ModuleSpec):
Specifies module to use for transformer layers
vocab_size (int):
Vocabulary size
max_sequence_length (int):
maximum size of sequence. This is used for positional embedding
pre_process (bool, optional):
Include embedding layer (used with pipeline parallelism). Defaults to True.
post_process (bool, optional):
Include an output layer (used with pipeline parallelism). Defaults to True.
fp16_lm_cross_entropy (bool, optional):
Defaults to False.
parallel_output (bool, optional):
Do not gather the outputs, keep them split across tensor
parallel ranks. Defaults to True.
share_embeddings_and_output_weights (bool, optional):
When True, input embeddings and output logit weights are shared. Defaults to False.
position_embedding_type (Literal[learned_absolute,rope], optional):
Position embedding type.. Defaults to 'learned_absolute'.
rotary_percent (float, optional):
Percent of rotary dimension to use for rotary position embeddings.
Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
rotary_base (int, optional):
Base period for rotary position embeddings. Ignored unless
position_embedding_type is 'rope'.
Defaults to 10000.
rope_scaling (bool, optional): Toggle RoPE scaling.
rope_scaling_factor (float): RoPE scaling factor. Default 8.
scatter_embedding_sequence_parallel (bool, optional):
Whether embeddings should be scattered across sequence parallel
region or not. Defaults to True.
seq_len_interpolation_factor (Optional[float], optional):
scale of linearly interpolating RoPE for longer sequences.
The value must be a float larger than 1.0. Defaults to None.
"""
def
__init__
(
self
,
config
:
TransformerConfig
,
transformer_layer_spec
:
ModuleSpec
,
vocab_size
:
int
,
max_sequence_length
:
int
,
pre_process
:
bool
=
True
,
post_process
:
bool
=
True
,
fp16_lm_cross_entropy
:
bool
=
False
,
parallel_output
:
bool
=
True
,
share_embeddings_and_output_weights
:
bool
=
False
,
position_embedding_type
:
Literal
[
'learned_absolute'
,
'rope'
,
'none'
]
=
'learned_absolute'
,
rotary_percent
:
float
=
1.0
,
rotary_base
:
int
=
10000
,
rope_scaling
:
bool
=
False
,
rope_scaling_factor
:
float
=
8.0
,
scatter_embedding_sequence_parallel
:
bool
=
True
,
seq_len_interpolation_factor
:
Optional
[
float
]
=
None
,
mtp_block_spec
:
Optional
[
ModuleSpec
]
=
None
,
)
->
None
:
super
().
__init__
(
config
=
config
)
if
has_config_logger_enabled
(
config
):
log_config_to_disk
(
config
,
locals
(),
prefix
=
type
(
self
).
__name__
)
self
.
transformer_layer_spec
:
ModuleSpec
=
transformer_layer_spec
self
.
vocab_size
=
vocab_size
self
.
max_sequence_length
=
max_sequence_length
self
.
pre_process
=
pre_process
self
.
post_process
=
post_process
self
.
fp16_lm_cross_entropy
=
fp16_lm_cross_entropy
self
.
parallel_output
=
parallel_output
self
.
share_embeddings_and_output_weights
=
share_embeddings_and_output_weights
self
.
position_embedding_type
=
position_embedding_type
# megatron core pipelining currently depends on model type
# TODO: remove this dependency ?
self
.
model_type
=
ModelType
.
encoder_or_decoder
# These 4 attributes are needed for TensorRT-LLM export.
self
.
max_position_embeddings
=
max_sequence_length
self
.
rotary_percent
=
rotary_percent
self
.
rotary_base
=
rotary_base
self
.
rotary_scaling
=
rope_scaling
self
.
mtp_block_spec
=
mtp_block_spec
self
.
mtp_process
=
mtp_block_spec
is
not
None
if
self
.
pre_process
or
self
.
mtp_process
:
self
.
embedding
=
LanguageModelEmbedding
(
config
=
self
.
config
,
init_method
=
self
.
config
.
init_method
,
vocab_size
=
self
.
vocab_size
,
max_sequence_length
=
self
.
max_sequence_length
,
position_embedding_type
=
position_embedding_type
,
scatter_to_sequence_parallel
=
scatter_embedding_sequence_parallel
,
)
if
self
.
position_embedding_type
==
'rope'
and
not
self
.
config
.
multi_latent_attention
:
self
.
rotary_pos_emb
=
RotaryEmbedding
(
kv_channels
=
self
.
config
.
kv_channels
,
rotary_percent
=
rotary_percent
,
rotary_interleaved
=
self
.
config
.
rotary_interleaved
,
seq_len_interpolation_factor
=
seq_len_interpolation_factor
,
rotary_base
=
rotary_base
,
rope_scaling
=
rope_scaling
,
rope_scaling_factor
=
rope_scaling_factor
,
use_cpu_initialization
=
self
.
config
.
use_cpu_initialization
,
)
# Cache for RoPE tensors which do not change between iterations.
self
.
rotary_pos_emb_cache
=
{}
# Transformer.
self
.
decoder
=
TransformerBlock
(
config
=
self
.
config
,
spec
=
transformer_layer_spec
,
pre_process
=
self
.
pre_process
,
post_process
=
self
.
post_process
,
)
if
self
.
mtp_process
:
self
.
mtp
=
MultiTokenPredictionBlock
(
config
=
self
.
config
,
spec
=
self
.
mtp_block_spec
)
# Output
if
self
.
post_process
or
self
.
mtp_process
:
if
self
.
config
.
defer_embedding_wgrad_compute
:
# The embedding activation buffer preserves a reference to the input activations
# of the final embedding projection layer GEMM. It will hold the activations for
# all the micro-batches of a global batch for the last pipeline stage. Once we are
# done with all the back props for all the microbatches for the last pipeline stage,
# it will be in the pipeline flush stage. During this pipeline flush we use the
# input activations stored in embedding activation buffer and gradient outputs
# stored in gradient buffer to calculate the weight gradients for the embedding
# final linear layer.
self
.
embedding_activation_buffer
=
[]
self
.
grad_output_buffer
=
[]
else
:
self
.
embedding_activation_buffer
=
None
self
.
grad_output_buffer
=
None
if
int
(
os
.
getenv
(
"USE_FLUX_OVERLAP"
,
"0"
)):
parallel_linear_impl
=
FluxColumnParallelLinear
else
:
parallel_linear_impl
=
tensor_parallel
.
ColumnParallelLinear
self
.
output_layer
=
parallel_linear_impl
(
config
.
hidden_size
,
self
.
vocab_size
,
config
=
config
,
init_method
=
config
.
init_method
,
bias
=
False
,
skip_bias_add
=
False
,
gather_output
=
not
self
.
parallel_output
,
...
...
@@ -48,324 +186,239 @@ def gpt_model_init_wrapper(fn):
grad_output_buffer
=
self
.
grad_output_buffer
,
)
if
self
.
pre_process
or
self
.
post_process
:
self
.
setup_embeddings_and_output_layer
()
# add mtp
self
.
num_nextn_predict_layers
=
self
.
config
.
num_nextn_predict_layers
if
self
.
num_nextn_predict_layers
:
assert
hasattr
(
self
.
config
,
"mtp_spec"
)
self
.
mtp_spec
:
ModuleSpec
=
self
.
config
.
mtp_spec
self
.
share_mtp_embedding_and_output_weight
=
self
.
config
.
share_mtp_embedding_and_output_weight
self
.
recompute_mtp_norm
=
self
.
config
.
recompute_mtp_norm
self
.
recompute_mtp_layer
=
self
.
config
.
recompute_mtp_layer
self
.
mtp_loss_scale
=
self
.
config
.
mtp_loss_scale
if
self
.
post_process
and
self
.
training
:
self
.
mtp_layers
=
torch
.
nn
.
ModuleList
(
[
MultiTokenPredictor
(
self
.
config
,
self
.
mtp_spec
.
submodules
,
vocab_size
=
self
.
vocab_size
,
max_sequence_length
=
self
.
max_sequence_length
,
layer_number
=
i
,
pre_process
=
self
.
pre_process
,
fp16_lm_cross_entropy
=
self
.
fp16_lm_cross_entropy
,
parallel_output
=
self
.
parallel_output
,
position_embedding_type
=
self
.
position_embedding_type
,
rotary_percent
=
self
.
rotary_percent
,
seq_len_interpolation_factor
=
seq_len_interpolation_factor
,
share_mtp_embedding_and_output_weight
=
self
.
share_mtp_embedding_and_output_weight
,
recompute_mtp_norm
=
self
.
recompute_mtp_norm
,
recompute_mtp_layer
=
self
.
recompute_mtp_layer
,
add_output_layer_bias
=
False
)
for
i
in
range
(
self
.
num_nextn_predict_layers
)
]
)
if
self
.
pre_process
or
self
.
post_process
:
setup_mtp_embeddings
(
self
)
return
wrapper
def
shared_embedding_or_mtp_embedding_weight
(
self
)
->
Tensor
:
"""Gets the embedding weight when share embedding and mtp embedding weights set to True.
Returns:
Tensor: During pre processing it returns the input embeddings weight while during post processing it returns
mtp embedding layers weight
"""
assert
self
.
num_nextn_predict_layers
>
0
if
self
.
pre_process
:
return
self
.
embedding
.
word_embeddings
.
weight
elif
self
.
post_process
:
return
self
.
mtp_layers
[
0
].
embedding
.
word_embeddings
.
weight
return
None
def
setup_mtp_embeddings
(
self
):
"""
Share embedding layer in mtp layer.
"""
if
self
.
pre_process
:
self
.
embedding
.
word_embeddings
.
weight
.
is_embedding_or_output_parameter
=
True
# Set `is_embedding_or_output_parameter` attribute.
for
i
in
range
(
self
.
num_nextn_predict_layers
):
if
self
.
post_process
and
self
.
mtp_layers
[
i
].
embedding
.
word_embeddings
.
weight
is
not
None
:
self
.
mtp_layers
[
i
].
embedding
.
word_embeddings
.
weight
.
is_embedding_or_output_parameter
=
True
if
not
self
.
share_mtp_embedding_and_output_weight
:
return
if
self
.
pre_process
and
self
.
post_process
:
# Zero out wgrad if sharing embeddings between two layers on same
# pipeline stage to make sure grad accumulation into main_grad is
# correct and does not include garbage values (e.g., from torch.empty).
self
.
shared_embedding_or_mtp_embedding_weight
().
zero_out_wgrad
=
True
return
if
self
.
pre_process
and
not
self
.
post_process
:
assert
parallel_state
.
is_pipeline_first_stage
()
self
.
shared_embedding_or_mtp_embedding_weight
().
shared_embedding
=
True
if
self
.
post_process
and
not
self
.
pre_process
:
assert
not
parallel_state
.
is_pipeline_first_stage
()
for
i
in
range
(
self
.
num_nextn_predict_layers
):
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self
.
mtp_layers
[
i
].
embedding
.
word_embeddings
.
weight
.
data
.
fill_
(
0
)
self
.
mtp_layers
[
i
].
embedding
.
word_embeddings
.
weight
.
shared
=
True
self
.
mtp_layers
[
i
].
embedding
.
word_embeddings
.
weight
.
shared_embedding
=
True
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
# Ensure that first and last stages have the same initial parameter
# values.
if
torch
.
distributed
.
is_initialized
():
if
parallel_state
.
is_rank_in_embedding_group
():
weight
=
self
.
shared_embedding_or_mtp_embedding_weight
()
weight
.
data
=
weight
.
data
.
cuda
()
torch
.
distributed
.
all_reduce
(
weight
.
data
,
group
=
parallel_state
.
get_embedding_group
()
if
has_config_logger_enabled
(
self
.
config
):
log_config_to_disk
(
self
.
config
,
self
.
state_dict
(),
prefix
=
f
'
{
type
(
self
).
__name__
}
_init_ckpt'
)
elif
not
getattr
(
LanguageModule
,
"embedding_warning_printed"
,
False
):
logging
.
getLogger
(
__name__
).
warning
(
"Distributed processes aren't initialized, so the output layer "
"is not initialized with weights from the word embeddings. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong."
)
LanguageModule
.
embedding_warning_printed
=
True
def
slice_inputs
(
self
,
input_ids
,
labels
,
position_ids
,
attention_mask
):
if
self
.
num_nextn_predict_layers
==
0
:
return
(
[
input_ids
],
[
labels
],
[
position_ids
],
[
attention_mask
],
)
def
set_input_tensor
(
self
,
input_tensor
:
Tensor
)
->
None
:
"""Sets input tensor to the model.
return
(
tensor_slide
(
input_ids
,
self
.
num_nextn_predict_layers
),
tensor_slide
(
labels
,
self
.
num_nextn_predict_layers
),
generate_nextn_position_ids
(
position_ids
,
self
.
num_nextn_predict_layers
),
# not compatible with ppo attn_mask
tensor_slide
(
attention_mask
,
self
.
num_nextn_predict_layers
,
dims
=
[
-
2
,
-
1
]),
)
See megatron.model.transformer.set_input_tensor()
Args:
input_tensor (Tensor): Sets the input tensor for the model.
"""
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if
not
isinstance
(
input_tensor
,
list
):
input_tensor
=
[
input_tensor
]
def
generate_nextn_position_ids
(
tensor
,
slice_num
):
slides
=
tensor_slide
(
tensor
,
slice_num
)
if
slides
[
0
]
is
None
:
return
slides
for
idx
in
range
(
1
,
len
(
slides
)):
slides
[
idx
]
=
regenerate_position_ids
(
slides
[
idx
],
idx
)
return
slides
def
regenerate_position_ids
(
tensor
,
offset
):
if
tensor
is
None
:
return
None
assert
len
(
input_tensor
)
==
1
,
'input_tensor should only be length 1 for gpt/bert'
self
.
decoder
.
set_input_tensor
(
input_tensor
[
0
])
tensor
=
tensor
.
clone
()
for
i
in
range
(
tensor
.
size
(
0
)):
row
=
tensor
[
i
]
zero_mask
=
(
row
==
0
)
# 两句拼接情形
if
zero_mask
.
any
():
first_zero_idx
=
torch
.
argmax
(
zero_mask
.
int
()).
item
()
tensor
[
i
,
:
first_zero_idx
]
=
torch
.
arange
(
first_zero_idx
)
else
:
tensor
[
i
]
=
tensor
[
i
]
-
offset
return
tensor
def
gpt_model_forward
(
self
,
input_ids
:
Tensor
,
position_ids
:
Tensor
,
attention_mask
:
Tensor
,
decoder_input
:
Tensor
=
None
,
labels
:
Tensor
=
None
,
inference_params
:
InferenceParams
=
None
,
packed_seq_params
:
PackedSeqParams
=
None
,
extra_block_kwargs
:
dict
=
None
,
runtime_gather_output
:
Optional
[
bool
]
=
None
,
)
->
Tensor
:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# generate inputs for main and mtps
input_ids
,
labels
,
position_ids
,
attention_mask
=
slice_inputs
(
def
forward
(
self
,
input_ids
,
labels
,
position_ids
,
attention_mask
)
# Decoder embedding.
if
decoder_input
is
not
None
:
pass
elif
self
.
pre_process
:
decoder_input
=
self
.
embedding
(
input_ids
=
input_ids
[
0
],
position_ids
=
position_ids
[
0
])
else
:
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input
=
None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb
=
None
rotary_pos_cos
=
None
rotary_pos_sin
=
None
if
self
.
position_embedding_type
==
'rope'
and
not
self
.
config
.
multi_latent_attention
:
if
not
self
.
training
and
self
.
config
.
flash_decode
and
inference_params
:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos
,
rotary_pos_sin
=
self
.
rotary_pos_emb_cache
.
setdefault
(
inference_params
.
max_sequence_length
,
self
.
rotary_pos_emb
.
get_cos_sin
(
inference_params
.
max_sequence_length
),
)
input_ids
:
Tensor
,
position_ids
:
Tensor
,
attention_mask
:
Tensor
,
decoder_input
:
Tensor
=
None
,
labels
:
Tensor
=
None
,
inference_params
:
InferenceParams
=
None
,
packed_seq_params
:
PackedSeqParams
=
None
,
extra_block_kwargs
:
dict
=
None
,
runtime_gather_output
:
Optional
[
bool
]
=
None
,
loss_mask
:
Optional
[
Tensor
]
=
None
,
)
->
Tensor
:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoeder and finally into the post
processing layer (optional).
It either returns the Loss values if labels are given or the final hidden units
Args:
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
# Decoder embedding.
if
decoder_input
is
not
None
:
pass
elif
self
.
pre_process
:
decoder_input
=
self
.
embedding
(
input_ids
=
input_ids
,
position_ids
=
position_ids
)
else
:
rotary_seq_len
=
self
.
rotary_pos_emb
.
get_rotary_seq_len
(
inference_params
,
self
.
decoder
,
decoder_input
,
self
.
config
,
packed_seq_params
)
rotary_pos_emb
=
self
.
rotary_pos_emb
(
rotary_seq_len
,
packed_seq
=
packed_seq_params
is
not
None
and
packed_seq_params
.
qkv_format
==
'thd'
,
# intermediate stage of pipeline
# decoder will get hidden_states from encoder.input_tensor
decoder_input
=
None
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb
=
None
rotary_pos_cos
=
None
rotary_pos_sin
=
None
if
self
.
position_embedding_type
==
'rope'
and
not
self
.
config
.
multi_latent_attention
:
if
not
self
.
training
and
self
.
config
.
flash_decode
and
inference_params
:
# Flash decoding uses precomputed cos and sin for RoPE
rotary_pos_cos
,
rotary_pos_sin
=
self
.
rotary_pos_emb_cache
.
setdefault
(
inference_params
.
max_sequence_length
,
self
.
rotary_pos_emb
.
get_cos_sin
(
inference_params
.
max_sequence_length
),
)
else
:
rotary_seq_len
=
self
.
rotary_pos_emb
.
get_rotary_seq_len
(
inference_params
,
self
.
decoder
,
decoder_input
,
self
.
config
,
packed_seq_params
)
rotary_pos_emb
=
self
.
rotary_pos_emb
(
rotary_seq_len
,
packed_seq
=
packed_seq_params
is
not
None
and
packed_seq_params
.
qkv_format
==
'thd'
,
)
if
(
(
self
.
config
.
enable_cuda_graph
or
self
.
config
.
flash_decode
)
and
rotary_pos_cos
is
not
None
and
inference_params
):
sequence_len_offset
=
torch
.
tensor
(
[
inference_params
.
sequence_len_offset
]
*
inference_params
.
current_batch_size
,
dtype
=
torch
.
int32
,
device
=
rotary_pos_cos
.
device
,
# Co-locate this with the rotary tensors
)
if
(
(
self
.
config
.
enable_cuda_graph
or
self
.
config
.
flash_decode
)
and
rotary_pos_cos
is
not
None
and
inference_params
):
sequence_len_offset
=
torch
.
tensor
(
[
inference_params
.
sequence_len_offset
]
*
inference_params
.
current_batch_size
,
dtype
=
torch
.
int32
,
device
=
rotary_pos_cos
.
device
,
# Co-locate this with the rotary tensors
else
:
sequence_len_offset
=
None
# Run decoder.
hidden_states
=
self
.
decoder
(
hidden_states
=
decoder_input
,
attention_mask
=
attention_mask
,
inference_params
=
inference_params
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_cos
=
rotary_pos_cos
,
rotary_pos_sin
=
rotary_pos_sin
,
packed_seq_params
=
packed_seq_params
,
sequence_len_offset
=
sequence_len_offset
,
**
(
extra_block_kwargs
or
{}),
)
else
:
sequence_len_offset
=
None
# Run decoder.
hidden_states
=
self
.
decoder
(
hidden_states
=
decoder_input
,
attention_mask
=
attention_mask
[
0
],
inference_params
=
inference_params
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_cos
=
rotary_pos_cos
,
rotary_pos_sin
=
rotary_pos_sin
,
packed_seq_params
=
packed_seq_params
,
sequence_len_offset
=
sequence_len_offset
,
**
(
extra_block_kwargs
or
{}),
)
if
not
self
.
post_process
:
return
hidden_states
# logits and loss
output_weight
=
None
if
self
.
share_embeddings_and_output_weights
:
output_weight
=
self
.
shared_embedding_or_output_weight
()
loss
=
0
# Multi token prediction module
if
self
.
num_nextn_predict_layers
and
self
.
training
:
if
not
self
.
share_embeddings_and_output_weights
and
self
.
share_mtp_embedding_and_output_weight
:
output_weight
=
self
.
output_layer
.
weight
output_weight
.
zero_out_wgrad
=
True
embedding_weight
=
self
.
shared_embedding_or_mtp_embedding_weight
()
if
self
.
share_mtp_embedding_and_output_weight
else
None
mtp_hidden_states
=
hidden_states
for
i
in
range
(
self
.
num_nextn_predict_layers
):
mtp_hidden_states
,
mtp_loss
=
self
.
mtp_layers
[
i
](
mtp_hidden_states
,
# [s,b,h]
input_ids
[
i
+
1
],
position_ids
[
i
+
1
]
if
position_ids
[
0
]
is
not
None
else
None
,
attention_mask
[
i
+
1
]
if
attention_mask
[
0
]
is
not
None
else
None
,
labels
[
i
+
1
]
if
labels
[
0
]
is
not
None
else
None
,
inference_params
,
packed_seq_params
,
extra_block_kwargs
,
embeding_weight
=
embedding_weight
,
# logits and loss
output_weight
=
None
if
self
.
share_embeddings_and_output_weights
:
output_weight
=
self
.
shared_embedding_or_output_weight
()
if
self
.
mtp_process
:
hidden_states
=
self
.
mtp
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
labels
=
labels
,
loss_mask
=
loss_mask
,
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
inference_params
=
inference_params
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_cos
=
rotary_pos_cos
,
rotary_pos_sin
=
rotary_pos_sin
,
packed_seq_params
=
packed_seq_params
,
sequence_len_offset
=
sequence_len_offset
,
embedding
=
self
.
embedding
,
output_layer
=
self
.
output_layer
,
output_weight
=
output_weight
,
runtime_gather_output
=
runtime_gather_output
,
compute_language_model_loss
=
self
.
compute_language_model_loss
,
**
(
extra_block_kwargs
or
{}),
)
loss
+=
self
.
mtp_loss_scale
/
self
.
num_nextn_predict_layers
*
mtp_loss
if
(
self
.
num_nextn_predict_layers
and
getattr
(
self
.
decoder
,
"main_final_layernorm"
,
None
)
is
not
None
):
# move block main model final norms here
hidden_states
=
self
.
decoder
.
main_final_layernorm
(
hidden_states
)
logits
,
_
=
self
.
output_layer
(
hidden_states
,
weight
=
output_weight
,
runtime_gather_output
=
runtime_gather_output
)
if
has_config_logger_enabled
(
self
.
config
):
payload
=
OrderedDict
(
{
'input_ids'
:
input_ids
[
0
],
'position_ids'
:
position_ids
[
0
],
'attention_mask'
:
attention_mask
[
0
],
'decoder_input'
:
decoder_input
,
'logits'
:
logits
,
}
if
(
self
.
num_nextn_predict_layers
and
getattr
(
self
.
decoder
,
"main_final_layernorm"
,
None
)
is
not
None
):
# move block main model final norms here
hidden_states
=
self
.
decoder
.
main_final_layernorm
(
hidden_states
)
if
not
self
.
post_process
:
return
hidden_states
logits
,
_
=
self
.
output_layer
(
hidden_states
,
weight
=
output_weight
,
runtime_gather_output
=
runtime_gather_output
)
log_config_to_disk
(
self
.
config
,
payload
,
prefix
=
'input_and_logits'
)
if
labels
[
0
]
is
None
:
# [s b h] => [b s h]
return
logits
.
transpose
(
0
,
1
).
contiguous
()
if
has_config_logger_enabled
(
self
.
config
):
payload
=
OrderedDict
(
{
'input_ids'
:
input_ids
,
'position_ids'
:
position_ids
,
'attention_mask'
:
attention_mask
,
'decoder_input'
:
decoder_input
,
'logits'
:
logits
,
}
)
log_config_to_disk
(
self
.
config
,
payload
,
prefix
=
'input_and_logits'
)
if
labels
is
None
:
# [s b h] => [b s h]
return
logits
.
transpose
(
0
,
1
).
contiguous
()
loss
=
self
.
compute_language_model_loss
(
labels
,
logits
)
return
loss
    def shared_embedding_or_output_weight(self) -> Tensor:
        """Gets the embedding weight or output logit weights when share input embedding and
        output weights set to True or when use Multi-Token Prediction (MTP) feature.

        Returns:
            Tensor: During pre processing or MTP process it returns the input embeddings weight.
            Otherwise, during post processing it returns the final output layers weight.
        """
        if self.pre_process or self.mtp_process:
            # Multi-Token Prediction (MTP) need both embedding layer and output layer.
            # So there will be both embedding layer and output layer in the mtp process stage.
            # In this case, if share_embeddings_and_output_weights is True, the shared weights
            # will be stored in embedding layer, and output layer will not have any weight.
            assert hasattr(
                self, 'embedding'
            ), "embedding is needed in this pipeline stage, but it is not initialized."
            return self.embedding.word_embeddings.weight
        elif self.post_process:
            return self.output_layer.weight
        return None
    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
    ) -> ShardedStateDict:
        """Sharded state dict implementation for GPTModel backward-compatibility
        (removing extra state).

        Args:
            prefix (str): Module name prefix.
            sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
            metadata (Optional[Dict]): metadata controlling sharded state dict creation.

        Returns:
            ShardedStateDict: sharded state dict for the GPTModel
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)

        output_layer_extra_state_key = f'{prefix}output_layer._extra_state'

        # Old GPT checkpoints only stored the output layer weight key. So we remove the
        # _extra_state key but check that it doesn't contain any data anyway
        output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
        assert not (
            output_extra_state and output_extra_state.data
        ), f'Expected output layer extra state to be empty, got: {output_extra_state}'

        # Multi-Token Prediction (MTP) need both embedding layer and output layer in
        # mtp process stage.
        # If MTP is not placed in the pre processing stage, we need to maintain a copy of
        # embedding layer in the mtp process stage and tie it to the embedding in the pre
        # processing stage.
        # Also, if MTP is not placed in the post processing stage, we need to maintain a copy
        # of output layer in the mtp process stage and tie it to the output layer in the post
        # processing stage.
        if self.mtp_process and not self.pre_process:
            emb_weight_key = f'{prefix}embedding.word_embeddings.weight'
            emb_weight = self.embedding.word_embeddings.weight
            tie_word_embeddings_state_dict(sharded_state_dict, emb_weight, emb_weight_key)
        if self.mtp_process and not self.post_process:
            # We only need to tie the output layer weight if share_embeddings_and_output_weights
            # is False. Because if share_embeddings_and_output_weights is True, the shared weight
            # will be stored in embedding layer, and output layer will not have any weight.
            if not self.share_embeddings_and_output_weights:
                output_layer_weight_key = f'{prefix}output_layer.weight'
                output_layer_weight = self.output_layer.weight
                tie_output_layer_state_dict(
                    sharded_state_dict, output_layer_weight, output_layer_weight_key
                )

        return sharded_state_dict
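A minimal stand-alone sketch, not part of this commit, of how the weight resolution in shared_embedding_or_output_weight() above plays out per pipeline stage when MTP is enabled. The StageFlags/resolve_shared_weight names are hypothetical; the branch logic simply mirrors the method above.

# Illustrative sketch only: mirrors shared_embedding_or_output_weight() for toy stage configs.
from dataclasses import dataclass

@dataclass
class StageFlags:
    pre_process: bool
    post_process: bool
    mtp_process: bool

def resolve_shared_weight(stage: StageFlags) -> str:
    # MTP stages behave like the pre-process stage and hand out the word-embedding weight;
    # a pure post-process stage hands out the output-layer weight; middle stages own no
    # shared weight.
    if stage.pre_process or stage.mtp_process:
        return "embedding.word_embeddings.weight"
    elif stage.post_process:
        return "output_layer.weight"
    return "none"

for flags in (StageFlags(True, False, False),
              StageFlags(False, False, True),
              StageFlags(False, True, False),
              StageFlags(False, False, False)):
    print(flags, "->", resolve_shared_weight(flags))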
dcu_megatron/core/pipeline_parallel/dualpipev/__init__.py
0 → 100644
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_chunks.py
0 → 100644
# Modified from mindspeed.
import torch
from functools import wraps
from typing import List, Optional

from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config
from megatron.legacy.model import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType
from megatron.training.global_vars import get_args, get_timers
from megatron.training.utils import unwrap_model
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.legacy.model.module import fp32_to_float16, float16_to_fp32
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _allreduce_layernorm_grads

# Assumed import locations (not shown in the original hunk) for helpers used in train_step below.
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.training.utils import (
    logical_and_across_model_parallel_group,
    reduce_max_stat_across_model_parallel_group,
)

from .dualpipev_schedules import get_dualpipe_chunk


def dualpipev_fp16forward(self, *inputs, **kwargs):
    is_pipeline_first_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 0
    if is_pipeline_first_stage:
        inputs = fp32_to_float16(inputs, self.float16_convertor)
    outputs = self.module(*inputs, **kwargs)
    is_pipeline_last_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
    if is_pipeline_last_stage:
        outputs = float16_to_fp32(outputs)
    return outputs


def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model."""
    args = get_args()
    args.model_type = model_type

    assert model_type != ModelType.encoder_and_decoder, \
        "Interleaved schedule not supported for model with both encoder and decoder"

    model = []
    pre_process, post_process = False, False
    if mpu.is_pipeline_first_stage():
        pre_process = True

    args.dualpipev_first_chunk = True
    first_model = model_provider_func(pre_process=pre_process, post_process=post_process)
    first_model.model_type = model_type
    model.append(first_model)

    args.dualpipev_first_chunk = False
    second_model = model_provider_func(pre_process=post_process, post_process=pre_process)
    second_model.model_type = model_type
    model.append(second_model)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum([sum([p.nelement() for p in model_module.parameters()])
                       for model_module in model])), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if wrap_with_ddp:
        config = get_model_config(model[0])
        ddp_config = DistributedDataParallelConfig(
            grad_reduce_in_fp32=args.accumulate_allreduce_grads_in_fp32,
            overlap_grad_reduce=args.overlap_grad_reduce,
            use_distributed_optimizer=args.use_distributed_optimizer,
            check_for_nan_in_grad=args.check_for_nan_in_loss_and_grad,
            bucket_size=args.ddp_bucket_size,
            average_in_collective=args.ddp_average_in_collective)
        model = [DDP(config,
                     ddp_config,
                     model_chunk,
                     # Turn off bucketing for model_chunk 2 onwards, since communication for these
                     # model chunks is overlapped with compute anyway.
                     disable_bucketing=(model_chunk_idx > 0))
                 for (model_chunk_idx, model_chunk) in enumerate(model)]

        # Broadcast params from data parallel src rank to other data parallel ranks.
        if args.data_parallel_random_init:
            for model_module in model:
                model_module.broadcast_params()

    return model


def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    rerun_state_machine = get_rerun_state_machine()
    while rerun_state_machine.should_run_forward_backward(data_iterator):
        # Set grad to zero.
        for model_chunk in model:
            model_chunk.zero_grad_buffer()
        optimizer.zero_grad()

        # Forward pass.
        forward_backward_func = get_forward_backward_func()
        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=data_iterator,
            model=model,
            num_microbatches=get_num_microbatches(),
            seq_length=args.seq_length,
            micro_batch_size=args.micro_batch_size,
            decoder_seq_length=args.decoder_seq_length,
            forward_only=False)
    should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
    if should_exit:
        return {}, True, should_checkpoint, should_exit, exit_code, None, None

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Vision gradients.
    if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

    # Update parameters.
    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # when freezing sub-models we may have a mixture of successful and unsuccessful ranks,
    # so we must gather across mp ranks
    update_successful = logical_and_across_model_parallel_group(update_successful)
    # grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
    # so we must gather across mp ranks
    grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
    if args.log_num_zeros_in_grad:
        num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)

    # Vision momentum.
    if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.update_momentum(args.curr_iteration)

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        opt_param_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    dualpipev_last_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)
    if dualpipev_last_stage:
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0].keys():
            numerator = 0
            denominator = 0
            for x in losses_reduced:
                val = x[key]
                # there is one dict per microbatch. in new reporting, we average
                # over the total number of tokens across the global batch.
                if isinstance(val, tuple) or isinstance(val, list):
                    numerator += val[0]
                    denominator += val[1]
                else:
                    # legacy behavior. we average over the number of microbatches,
                    # and so the denominator is 1.
                    numerator += val
                    denominator += 1
            loss_reduced[key] = numerator / denominator
        return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad


def get_num_layers_to_build(config: TransformerConfig) -> int:
    num_layers_per_pipeline_rank = (
        config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
    )
    num_layers_to_build = num_layers_per_pipeline_rank // 2
    return num_layers_to_build


def _allreduce_embedding_grads_wrapper(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if get_args().schedules_method == 'dualpipev':
            # dualpipev does not need the embedding allreduce:
            # the embedding and the lm head live on the same rank.
            if not get_args().untie_embeddings_and_output_weights:
                raise NotImplementedError
            else:
                return
        else:
            return fn(*args, **kwargs)

    return wrapper
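A minimal stand-alone sketch, not part of this commit, of the V-shaped layout that get_model and get_num_layers_to_build above imply: every pipeline rank builds two model chunks of num_layers // pp_size // 2 decoder layers each, and on the pipeline-first rank chunk 0 carries pre_process (the embedding) while chunk 1 carries post_process (the loss head). The dualpipev_chunk_layout name is hypothetical and the concrete layer-to-stage mapping is decided elsewhere.

# Illustrative sketch only: per-rank chunk layout derived from the functions above.
def dualpipev_chunk_layout(num_layers: int, pp_size: int):
    layers_per_chunk = (num_layers // pp_size) // 2
    layout = []
    for rank in range(pp_size):
        is_first = rank == 0
        layout.append({
            'rank': rank,
            'layers_per_chunk': layers_per_chunk,
            'chunk0_pre_process': is_first,   # embedding lives with chunk 0 of the first rank
            'chunk1_post_process': is_first,  # loss head lives with chunk 1 of the same rank
        })
    return layout

for row in dualpipev_chunk_layout(num_layers=32, pp_size=4):
    print(row)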
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py
0 → 100644
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
import contextlib
from functools import wraps
from typing import Iterator, List, Union

import torch

from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.training import get_args
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
    get_attr_wrapped_model,
    get_model_config,
    get_model_type,
)
from megatron.core.pipeline_parallel.schedules import clear_embedding_activation_buffer, deallocate_output_tensor
from megatron.core import ModelParallelConfig
from megatron.core.pipeline_parallel.p2p_communication import _communicate
from megatron.core.pipeline_parallel.schedules import backward_step, set_current_microbatch, custom_backward, finish_embedding_wgrad_compute
from megatron.core.models.gpt import GPTModel
from mindspeed.core.pipeline_parallel.fb_overlap.gpt_model import gpt_model_backward
from mindspeed.core.pipeline_parallel.fb_overlap.transformer_layer import P2PCommParams
from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore

# Types
Shape = Union[List[int], torch.Size]

LOSS_BACKWARD_SCALE = torch.tensor(1.0)

_DUALPIPE_CHUNK = None


def set_dualpipe_chunk(chunkid):
    """set_dualpipe_chunk for fp16forward patch"""
    global _DUALPIPE_CHUNK
    _DUALPIPE_CHUNK = chunkid


def get_dualpipe_chunk():
    global _DUALPIPE_CHUNK
    if _DUALPIPE_CHUNK is not None:
        return _DUALPIPE_CHUNK
    else:
        raise AssertionError("_DUALPIPE_CHUNK is None")


def is_dualpipev_last_stgae(model_chunk_id):
    return parallel_state.is_pipeline_first_stage() and model_chunk_id == 1
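# Illustrative note (not part of the original change): in the V-shaped schedule the
# pipeline-first rank hosts both the global first stage (chunk 0) and the global last
# stage (chunk 1), which is why "last stage" above is tested with
# is_pipeline_first_stage() and model_chunk_id == 1. A minimal stand-alone sketch of
# the same mapping, with the rank flag passed in rather than queried from parallel_state:
def _sketch_is_last_stage(rank_is_pipeline_first: bool, model_chunk_id: int) -> bool:
    # Mirrors is_dualpipev_last_stgae() without touching parallel_state.
    return rank_is_pipeline_first and model_chunk_id == 1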
def send_forward(output_tensor: torch.Tensor, tensor_shape, config: ModelParallelConfig,
                 model_chunk_id, async_op=False) -> None:
    """Send tensor to the next stage of the given model chunk (forward send).

    See _communicate for argument details.
    """
    tensor_send_next, tensor_send_prev = None, None
    if model_chunk_id == 0:
        if parallel_state.is_pipeline_last_stage():
            return None
        tensor_send_next = output_tensor
    else:
        if parallel_state.is_pipeline_first_stage():
            return None
        tensor_send_prev = output_tensor

    if config.timers is not None:
        config.timers('forward-send', log_level=2).start()
    _, _, fwd_wait_handles = _communicate(
        tensor_send_next=tensor_send_next,
        tensor_send_prev=tensor_send_prev,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape,
        config=config,
        wait_on_reqs=(not async_op)
    )
    if config.timers is not None:
        config.timers('forward-send').stop()

    return fwd_wait_handles


def send_backward(input_tensor_grad: torch.Tensor, tensor_shape, config: ModelParallelConfig,
                  model_chunk_id, async_op=False) -> None:
    """Send input gradient to the previous stage of the given model chunk (backward send).

    See _communicate for argument details.
    """
    tensor_send_next, tensor_send_prev = None, None
    if model_chunk_id == 0:
        if parallel_state.is_pipeline_first_stage():
            return None
        tensor_send_prev = input_tensor_grad
    else:
        if parallel_state.is_pipeline_last_stage():
            return None
        tensor_send_next = input_tensor_grad

    if config.timers is not None:
        config.timers('backward-send', log_level=2).start()
    _, _, reqs = _communicate(
        tensor_send_next=tensor_send_next,
        tensor_send_prev=tensor_send_prev,
        recv_prev=False,
        recv_next=False,
        tensor_shape=tensor_shape,
        config=config,
        wait_on_reqs=(not async_op)
    )
    if config.timers is not None:
        config.timers('backward-send').stop()

    return reqs


def recv_forward(tensor_shape: Shape, config: ModelParallelConfig,
                 model_chunk_id, async_op=False) -> torch.Tensor:
    """Receive tensor from the previous stage of the given model chunk (forward receive).

    See _communicate for argument details.
    """
    recv_prev, recv_next = False, False
    if model_chunk_id == 0:
        recv_prev = True
    else:
        recv_next = True

    if (parallel_state.is_pipeline_first_stage() and recv_prev) or \
            (parallel_state.is_pipeline_last_stage() and recv_next):
        fwd_wait_handles = None
        return None, fwd_wait_handles
    else:
        if config.timers is not None:
            config.timers('forward-recv', log_level=2).start()
        tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
            config=config,
            wait_on_reqs=(not async_op),
        )
        if config.timers is not None:
            config.timers('forward-recv').stop()

        if recv_prev:
            return tensor_recv_prev, fwd_wait_handles
        else:
            return tensor_recv_next, fwd_wait_handles


def recv_backward(tensor_shape: Shape, config: ModelParallelConfig,
                  model_chunk_id, async_op=False) -> torch.Tensor:
    """Receive gradient tensor from the next stage of the given model chunk (backward receive).

    See _communicate for argument details.
    """
    recv_prev, recv_next = False, False
    if model_chunk_id == 0:
        recv_next = True
    else:
        recv_prev = True

    if (parallel_state.is_pipeline_first_stage() and recv_prev) or \
            (parallel_state.is_pipeline_last_stage() and recv_next):
        output_tensor_grad = None
        bwd_wait_handles = None
        return output_tensor_grad, bwd_wait_handles
    else:
        if config.timers is not None:
            config.timers('backward-recv', log_level=2).start()
        tensor_recv_prev, tensor_recv_next, bwd_wait_handles = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=recv_prev,
            recv_next=recv_next,
            tensor_shape=tensor_shape,
            config=config,
            wait_on_reqs=(not async_op)
        )
        if config.timers is not None:
            config.timers('backward-recv').stop()

        if recv_prev:
            return tensor_recv_prev, bwd_wait_handles
        else:
            return tensor_recv_next, bwd_wait_handles
def send_forward_recv_forward(
    output_tensor: torch.Tensor,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    model_chunk_id,
    async_op=False
) -> torch.Tensor:
    """Batched send of the current output to the next stage of the same chunk and
    receive of the next input from its previous stage.

    See _communicate for argument details.
    """
    recv_prev, recv_next = False, False
    tensor_send_next, tensor_send_prev = None, None
    if model_chunk_id == 0:
        if not parallel_state.is_pipeline_last_stage():
            tensor_send_next = output_tensor
        if not parallel_state.is_pipeline_first_stage():
            recv_prev = True
    if model_chunk_id == 1:
        if not parallel_state.is_pipeline_first_stage():
            tensor_send_prev = output_tensor
        if not parallel_state.is_pipeline_last_stage():
            recv_next = True

    if config.timers is not None:
        config.timers('forward-send-forward-recv', log_level=2).start()
    tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
        tensor_send_next=tensor_send_next,
        tensor_send_prev=tensor_send_prev,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not async_op),
        config=config
    )
    if config.timers is not None:
        config.timers('forward-send-forward-recv').stop()

    if model_chunk_id == 0:
        if not parallel_state.is_pipeline_first_stage():
            return tensor_recv_prev, fwd_wait_handles
        else:
            return None, fwd_wait_handles
    else:
        if not parallel_state.is_pipeline_last_stage():
            return tensor_recv_next, fwd_wait_handles
        else:
            return None, fwd_wait_handles


def send_forward_recv_slave_forward(
    output_tensor: torch.Tensor,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    model_chunk_id,
    async_op=False,
) -> torch.Tensor:
    """Batched send of the current chunk's output and receive of the other (slave)
    chunk's input from the same neighbor rank.

    See _communicate for argument details.
    """
    recv_prev, recv_next = False, False
    tensor_send_next, tensor_send_prev = None, None
    if model_chunk_id == 0:
        if parallel_state.is_pipeline_last_stage():
            return None, None
        tensor_send_next = output_tensor
        recv_next = True
    if model_chunk_id == 1:
        if parallel_state.is_pipeline_first_stage():
            return None, None
        tensor_send_prev = output_tensor
        recv_prev = True

    if config.timers is not None:
        config.timers('forward-send-slave-forward-recv', log_level=2).start()
    tensor_recv_prev, tensor_recv_next, fwd_wait_handles = _communicate(
        tensor_send_next=tensor_send_next,
        tensor_send_prev=tensor_send_prev,
        recv_prev=recv_prev,
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        wait_on_reqs=(not async_op),
        config=config,
    )
    if config.timers is not None:
        config.timers('forward-send-slave-forward-recv').stop()

    if model_chunk_id == 0:
        return tensor_recv_next, fwd_wait_handles
    else:
        return tensor_recv_prev, fwd_wait_handles
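# Illustrative note (not part of the original change): the helpers above all follow one
# direction rule -- chunk 0 flows "down" the pipeline (send next / recv prev in forward)
# while chunk 1 flows back "up" (send prev / recv next in forward); backward traffic is
# mirrored. A stand-alone sketch of that rule, with the stage flags passed in explicitly
# instead of read from parallel_state:
def _sketch_forward_p2p_directions(model_chunk_id: int, is_first: bool, is_last: bool):
    """Return (send_next, send_prev, recv_prev, recv_next) for a forward send/recv pair."""
    if model_chunk_id == 0:
        return (not is_last, False, not is_first, False)
    return (False, not is_first, False, not is_last)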
def generate_dualpipev_schedule(pp_size, num_microbatches):
    num_microbatches = num_microbatches * 2
    num_warmup_stages = [0] * pp_size
    num_interleaved_forward_stages = [0] * pp_size
    num_1b1w1f_stages = [0] * pp_size
    num_overlap_stages = [0] * pp_size
    num_1b1overlap_stages = [0] * pp_size
    num_interleaved_backward_stages = [0] * pp_size
    num_cooldown_stages = [0] * pp_size

    pp_size *= 2

    for i in range(pp_size // 2):
        num_warmup_stages[i] = pp_size - 2 - i * 2
        num_interleaved_forward_stages[i] = i + 1
        # each unit is one group of 1f1f
        num_1b1w1f_stages[i] = pp_size // 2 - i - 1
        num_overlap_stages[i] = num_microbatches - pp_size * 2 + i * 2 + 2
        num_1b1overlap_stages[i] = (pp_size // 2 - i - 1) * 2
        num_interleaved_backward_stages[i] = i + 1
        num_cooldown_stages[i] = [i + 1, pp_size - 2 * i - 2, i + 1]

    schedule_all_stages = {
        'warmup': num_warmup_stages,
        'interleaved_forward': num_interleaved_forward_stages,
        '1b1w1f': num_1b1w1f_stages,
        'overlap': num_overlap_stages,
        '1b1overlap': num_1b1overlap_stages,
        'interleaved_backward': num_interleaved_backward_stages,
        'cooldown': num_cooldown_stages
    }

    return schedule_all_stages
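# Illustrative sketch (not part of the original change): materialize the per-rank phase
# counts for a small configuration. With pp_size=4 pipeline ranks the doubled V schedule
# has 8 virtual stages, so rank 0 gets 2 * 4 - 2 = 6 warmup steps while the last rank
# gets none. The helper name below is hypothetical.
def _print_dualpipev_schedule_example(pp_size=4, num_microbatches=8):
    schedule = generate_dualpipev_schedule(pp_size, num_microbatches)
    for phase, per_rank in schedule.items():
        print(f"{phase}: {per_rank}")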
def pretrain_gpt_forward_step_dualpipe(data_iterator, model: GPTModel, extra_block_kwargs=None):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
    from megatron.training import get_timers
    from functools import partial
    from pretrain_gpt import get_batch, loss_func

    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch-generator').stop()

    if extra_block_kwargs is not None:
        # execute forward/backward overlapping
        output_tensor, model_graph, pp_comm_output = \
            model(tokens, position_ids, attention_mask, labels=labels,
                  extra_block_kwargs=extra_block_kwargs)
        return (output_tensor, model_graph, pp_comm_output), partial(loss_func, loss_mask)
    else:
        output_tensor, model_graph = model(tokens, position_ids, attention_mask, labels=labels)
        return (output_tensor, model_graph), partial(loss_func, loss_mask)
def forward_step_no_model_graph(
    forward_step_func,
    model_chunk_id,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
):
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = forward_step_func(data_iterator, model)
        else:
            output_tensor, loss_func = forward_step_func(
                data_iterator, model, checkpoint_activations_microbatch
            )

    num_tokens = torch.tensor(0, dtype=torch.int)
    if is_dualpipev_last_stgae(model_chunk_id):
        if not collect_non_loss_data:
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    output_tensor /= num_tokens
                    output_tensor /= num_microbatches
            else:
                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)
        else:
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the loss scale for the auxiliary loss of the MoE layer.
    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.tensor(1.0)
        )
        # Set the loss scale
        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    model_type = get_model_type(model)
    if (
        parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        return [output_tensor, input_tensor[-1]], num_tokens

    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens
def backward_step_with_model_graph(
    input_tensor, output_tensor, output_tensor_grad, model_type, config, model_graph=None
):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""

    # NOTE: This code currently can handle at most one skip connection. It
    # needs to be modified slightly to support arbitrary numbers of skip
    # connections.

    if config.timers is not None:
        config.timers('backward-compute', log_level=2).start()

    # Retain the grad on the input_tensor.
    unwrap_input_tensor_grad = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_input_tensor_grad = True
    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]

    # Backward pass.
    if output_tensor_grad[0] is None and config.grad_scale_func is not None and model_graph is None:
        output_tensor[0] = config.grad_scale_func(output_tensor[0])

    if config.deallocate_pipeline_outputs:
        if model_graph is None:
            custom_backward(output_tensor[0], output_tensor_grad[0])
        else:
            layer_output_grad = gpt_model_backward(output_tensor_grad[0], model_graph)
    else:
        torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])

    # Collect the grad of the input_tensor.
    input_tensor_grad = [None]
    if input_tensor is not None:
        input_tensor_grad = []
        if model_graph is not None:
            input_tensor_grad.append(layer_output_grad)
        else:
            for x in input_tensor:
                if x is None:
                    input_tensor_grad.append(None)
                else:
                    input_tensor_grad.append(x.grad)

    # Handle single skip connection if it exists (encoder_hidden_state in
    # model with encoder and decoder).
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        if output_tensor_grad[1] is not None:
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]

    if config.timers is not None:
        config.timers('backward-compute').stop()

    return input_tensor_grad
def forward_step_with_model_graph(
    forward_step_func,
    model_chunk_id,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    extra_block_kwargs=None,
):
    """Forward step for passed-in model.

    If it is the first stage, the input tensor is obtained from the data_iterator.
    Otherwise, the passed-in input_tensor is used.

    Args:
        forward_step_func (callable): The forward step function for the model that takes the
            data iterator as the first argument, and model as the second.
            This user's forward step is expected to output a tuple of two elements:

                1. The output object from the forward step. This output object needs to be a
                    tensor or some kind of collection of tensors. The only hard requirement
                    for this object is that it needs to be acceptable as input into the second
                    function.
                2. A function to reduce (optionally) the output from the forward step. This
                    could be a reduction over the loss from the model, it could be a function that
                    grabs the output from the model and reformats, it could be a function that just
                    passes through the model output. This function must have one of the following
                    patterns, and depending on the pattern different things happen internally.

                        a. A tuple of reduced loss and some other data. Note that in this case
                            the first argument is divided by the number of global microbatches,
                            assuming it is a loss, so that the loss is stable as a function of
                            the number of devices the step is split across.
                        b. A triple of reduced loss, number of tokens, and some other data. This
                            is similar to case (a), but the loss is further averaged across the
                            number of tokens in the batch. If the user is not already averaging
                            across the number of tokens, this pattern is useful to use.
                        c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
                            of tensors, etc in the case of inference). To trigger case 3 you need
                            to specify `collect_non_loss_data=True` and you may also want to
                            specify `forward_only=True` in the call to the parent forward_backward
                            function.
        data_iterator (iterator): The data iterator.
        model (nn.Module): The model to perform the forward step on.
        num_microbatches (int): The number of microbatches.
        input_tensor (Tensor or list[Tensor]): The input tensor(s) for the forward step.
        forward_data_store (list): The list to store the forward data. If you go down path 2.a or
            2.b for the return of your forward reduction function then this will store only the
            final dimension of the output, for example the metadata output by the loss function.
            If you go down the path of 2.c then this will store the entire output of the forward
            reduction function applied to the model output.
        config (object): The configuration object.
        collect_non_loss_data (bool, optional): Whether to collect non-loss data. Defaults to False.
            This is the path to use if you want to collect arbitrary output from the model forward,
            such as with inference use cases. Defaults to False.
        checkpoint_activations_microbatch (int, optional): The microbatch to checkpoint activations.
            Defaults to None.
        is_first_microbatch (bool, optional): Whether it is the first microbatch. Defaults to False.
        current_microbatch (int, optional): The current microbatch. Defaults to None.

    Returns:
        Tensor or list[Tensor]: The output object(s) from the forward step.
        Tensor: The number of tokens.
    """
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
        unwrap_output_tensor = True

    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
    set_input_tensor(input_tensor)

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()
    with context_manager:
        if checkpoint_activations_microbatch is None:
            output_tensor, loss_func = pretrain_gpt_forward_step_dualpipe(
                data_iterator, model, extra_block_kwargs
            )
        else:
            output_tensor, loss_func = pretrain_gpt_forward_step_dualpipe(
                data_iterator, model, checkpoint_activations_microbatch, extra_block_kwargs
            )

    num_tokens = torch.tensor(0, dtype=torch.int)
    if is_dualpipev_last_stgae(model_chunk_id):
        if not collect_non_loss_data:
            next_info = None
            if isinstance(output_tensor, tuple):
                # using pp overlapping
                if len(output_tensor) == 2:
                    output_tensor, model_graph = output_tensor
                elif len(output_tensor) == 3:
                    output_tensor, model_graph, next_info = output_tensor
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    output_tensor /= num_tokens
                    output_tensor /= num_microbatches
            else:
                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)
            output_tensor = (
                (output_tensor, model_graph, next_info)
                if next_info is not None
                else (output_tensor, model_graph)
            )
        else:
            data = loss_func(output_tensor, non_loss_data=True)
            forward_data_store.append(data)

    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the loss scale for the auxiliary loss of the MoE layer.
    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(LOSS_BACKWARD_SCALE)
            if config.grad_scale_func is not None
            else torch.tensor(1.0)
        )
        # Set the loss scale
        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # If T5 model (or other model with encoder and decoder)
    # and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    model_type = get_model_type(model)
    if (
        parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
    ):
        return [output_tensor, input_tensor[-1]], num_tokens

    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens
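# Illustrative sketch (not part of the original change): a minimal forward_step_func /
# loss_func pair matching pattern (b) in the docstring above -- the loss function returns
# (loss, num_tokens, reporting_dict), so the scheduler divides by both the token count and
# the number of microbatches. All names below are hypothetical.
def _example_forward_step(data_iterator, model):
    tokens, labels, loss_mask = next(data_iterator)
    output_tensor = model(tokens, labels=labels)

    def _loss_func(output_tensor):
        losses = output_tensor.float()
        loss = (losses * loss_mask).sum()
        num_tokens = loss_mask.sum().to(torch.int)
        return loss, num_tokens, {'lm loss': (loss.clone().detach(), num_tokens.clone().detach())}

    return output_tensor, _loss_func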
shared_embedding = None


def get_shared_embedding_from_dual_chunk():
    assert shared_embedding is not None
    return shared_embedding


def set_shared_embedding_from_dual_chunk(model1, model2):
    global shared_embedding
    if shared_embedding is not None:
        return
    if model1.module.module.pre_process:
        shared_embedding = model1.module.module.embedding.word_embeddings.weight
    elif model2.module.module.pre_process:
        shared_embedding = model2.module.module.embedding.word_embeddings.weight
def
forward_backward_pipelining_with_cutinhalf
(
*
,
forward_step_func
,
data_iterator
:
Union
[
Iterator
,
List
[
Iterator
]],
model
:
Union
[
torch
.
nn
.
Module
,
List
[
torch
.
nn
.
Module
]],
num_microbatches
:
int
,
seq_length
:
int
,
micro_batch_size
:
int
,
decoder_seq_length
:
int
=
None
,
forward_only
:
bool
=
False
,
collect_non_loss_data
:
bool
=
False
,
first_val_step
:
bool
=
None
,
):
args
=
get_args
()
args
.
moe_fb_overlap
=
True
args
.
dualpipe_no_dw_detach
=
True
set_shared_embedding_from_dual_chunk
(
model
[
0
],
model
[
1
])
assert
(
isinstance
(
model
,
list
)
and
len
(
model
)
==
2
),
'Dualpipe Schedule only support chunk model for two consecutive chunks'
assert
(
isinstance
(
data_iterator
,
list
)
and
len
(
data_iterator
)
==
2
),
'Dualpipe Schedule only support two data_iterators'
config
=
get_model_config
(
model
[
0
])
config
.
batch_p2p_comm
=
False
# Needed only when gradients are finalized in M-Core
if
config
.
finalize_model_grads_func
is
not
None
and
not
forward_only
:
embedding_module
=
clear_embedding_activation_buffer
(
config
,
model
)
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-backward'
,
log_level
=
1
).
start
(
barrier
=
config
.
barrier_with_L1_time
)
# Disable async grad reductions
no_sync_func
=
config
.
no_sync_func
if
no_sync_func
is
None
:
no_sync_func
=
contextlib
.
nullcontext
no_sync_context
=
None
def
disable_grad_sync
():
"""Disable asynchronous grad reductions"""
nonlocal
no_sync_context
if
no_sync_context
is
None
:
no_sync_context
=
no_sync_func
()
no_sync_context
.
__enter__
()
def
enable_grad_sync
():
"""Enable asynchronous grad reductions"""
nonlocal
no_sync_context
if
no_sync_context
is
not
None
:
no_sync_context
.
__exit__
(
None
,
None
,
None
)
no_sync_context
=
None
disable_grad_sync
()
# Compute number of steps for each stage
pp_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
rank
=
parallel_state
.
get_pipeline_model_parallel_rank
()
schedule
=
generate_dualpipev_schedule
(
pp_size
,
num_microbatches
)
model_type
=
get_model_type
(
model
[
0
])
tensor_shape
=
[
seq_length
,
micro_batch_size
,
config
.
hidden_size
]
tensor_shape
[
0
]
=
tensor_shape
[
0
]
//
parallel_state
.
get_context_parallel_world_size
()
if
config
.
sequence_parallel
:
tensor_shape
[
0
]
=
tensor_shape
[
0
]
//
parallel_state
.
get_tensor_model_parallel_world_size
()
total_num_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
int
).
cuda
()
input_tensors
=
[[],
[]]
output_tensors
=
[[],
[]]
model_graphs
=
[[],
[]]
logits_inputs
=
[]
forward_data_store
=
[]
master_chunk_id
=
0
slave_chunk_id
=
1
master_cur_microbatch
=
0
slave_cur_microbatch
=
num_microbatches
master_microbatch_max
=
num_microbatches
slave_microbatch_max
=
num_microbatches
*
2
set_dualpipe_chunk
(
master_chunk_id
)
checkpoint_activations_microbatch
=
None
def
forward_step_helper
(
model_chunk_id
,
current_microbatch
,
checkpoint_activations_microbatch
,
is_first_microbatch
=
False
,
extra_block_kwargs
=
None
):
input_tensor
=
input_tensors
[
model_chunk_id
][
-
1
][
1
]
output_tensor
,
num_tokens
=
forward_step_with_model_graph
(
forward_step_func
,
model_chunk_id
,
data_iterator
[
model_chunk_id
],
model
[
model_chunk_id
],
num_microbatches
,
input_tensor
,
forward_data_store
,
config
,
collect_non_loss_data
,
checkpoint_activations_microbatch
,
is_first_microbatch
,
current_microbatch
=
current_microbatch
,
extra_block_kwargs
=
extra_block_kwargs
)
if
isinstance
(
output_tensor
,
tuple
):
if
len
(
output_tensor
)
==
2
:
output_tensor_
,
model_graph
=
output_tensor
elif
len
(
output_tensor
)
==
3
:
output_tensor_
,
model_graph
,
pp_comm_output
=
output_tensor
if
is_dualpipev_last_stgae
(
model_chunk_id
):
logits_inputs
.
append
(
model_graph
.
layer_graphs
[
-
1
].
unperm2_graph
[
1
])
model_graphs
[
model_chunk_id
].
append
(
model_graph
)
else
:
output_tensor_
=
output_tensor
output_tensors
[
model_chunk_id
].
append
(
output_tensor_
)
if
extra_block_kwargs
is
not
None
:
input_tensors
[
1
-
model_chunk_id
].
pop
(
0
)
output_tensors
[
1
-
model_chunk_id
].
pop
(
0
)
nonlocal
total_num_tokens
total_num_tokens
+=
num_tokens
.
item
()
# if forward-only, no need to save tensors for a backward pass
if
forward_only
:
input_tensors
[
model_chunk_id
].
pop
()
output_tensors
[
model_chunk_id
].
pop
()
return
output_tensor
def
check_pipeline_stage
(
model_chunk_id
,
fwd_send_only
):
send_next
,
recv_next
,
send_prev
,
recv_prev
=
True
,
True
,
True
,
True
if
parallel_state
.
is_pipeline_first_stage
():
send_prev
,
recv_prev
=
False
,
False
if
parallel_state
.
is_pipeline_last_stage
():
send_next
,
recv_next
=
False
,
False
if
model_chunk_id
==
0
:
return
P2PCommParams
(
send_next
=
send_next
,
recv_next
=
not
fwd_send_only
and
recv_next
),
P2PCommParams
(
send_next
=
send_next
,
recv_next
=
recv_next
)
else
:
return
P2PCommParams
(
send_prev
=
send_prev
,
recv_prev
=
not
fwd_send_only
and
recv_prev
),
P2PCommParams
(
send_prev
=
send_prev
,
recv_prev
=
recv_prev
)
input_tensor
=
recv_forward
(
tensor_shape
,
config
,
master_chunk_id
)[
0
]
fwd_wait_handles_warmup
=
None
# Run warmup forward passes
for
i
in
range
(
schedule
[
'warmup'
][
rank
]):
if
args
.
moe_fb_overlap
:
input_tensors
[
master_chunk_id
].
append
(
(
master_cur_microbatch
,
input_tensor
))
output_tensor_warmup
,
_
=
forward_step_helper
(
master_chunk_id
,
master_cur_microbatch
,
checkpoint_activations_microbatch
,
is_first_microbatch
=
(
i
==
0
))
else
:
output_tensor_warmup
,
num_tokens
=
forward_step_no_model_graph
(
forward_step_func
,
master_chunk_id
,
data_iterator
[
master_chunk_id
],
model
[
master_chunk_id
],
num_microbatches
,
input_tensor
,
forward_data_store
,
config
,
collect_non_loss_data
,
checkpoint_activations_microbatch
,
is_first_microbatch
=
(
i
==
0
),
current_microbatch
=
master_cur_microbatch
)
total_num_tokens
+=
num_tokens
.
item
()
input_tensors
[
master_chunk_id
].
append
(
(
master_cur_microbatch
,
input_tensor
))
output_tensors
[
master_chunk_id
].
append
(
output_tensor_warmup
)
master_cur_microbatch
+=
1
if
i
!=
schedule
[
'warmup'
][
rank
]
-
1
:
input_tensor
,
_
=
send_forward_recv_forward
(
output_tensor_warmup
,
tensor_shape
,
config
,
master_chunk_id
)
deallocate_output_tensor
(
output_tensor_warmup
,
config
.
deallocate_pipeline_outputs
)
else
:
input_tensor
,
_
=
recv_forward
(
tensor_shape
,
config
,
master_chunk_id
)
fwd_wait_handles_warmup
=
send_forward
(
output_tensor_warmup
,
tensor_shape
,
config
,
master_chunk_id
,
async_op
=
True
)
# Run interleaved forward passes for two model chunk
fwd_wait_handles
=
None
fwd_wait_handles_slave_chunk
=
None
fwd_wait_handles_send
=
None
for
i
in
range
(
schedule
[
'interleaved_forward'
][
rank
]):
if
fwd_wait_handles
is
not
None
:
for
req
in
fwd_wait_handles
:
req
.
wait
()
fwd_wait_handles
=
None
is_first_microbatch
=
parallel_state
.
is_pipeline_last_stage
()
and
(
i
==
0
)
set_dualpipe_chunk
(
master_chunk_id
)
if
args
.
moe_fb_overlap
:
input_tensors
[
master_chunk_id
].
append
(
(
master_cur_microbatch
,
input_tensor
))
output_tensor
,
_
=
forward_step_helper
(
master_chunk_id
,
master_cur_microbatch
,
checkpoint_activations_microbatch
,
is_first_microbatch
=
is_first_microbatch
)
else
:
output_tensor
,
num_tokens
=
forward_step_no_model_graph
(
forward_step_func
,
master_chunk_id
,
data_iterator
[
master_chunk_id
],
model
[
master_chunk_id
],
num_microbatches
,
input_tensor
,
forward_data_store
,
config
,
collect_non_loss_data
,
checkpoint_activations_microbatch
,
is_first_microbatch
=
is_first_microbatch
,
current_microbatch
=
master_cur_microbatch
)
total_num_tokens
+=
num_tokens
.
item
()
input_tensors
[
master_chunk_id
].
append
(
(
master_cur_microbatch
,
input_tensor
))
output_tensors
[
master_chunk_id
].
append
(
output_tensor
)
master_cur_microbatch
+=
1
if
not
parallel_state
.
is_pipeline_last_stage
()
and
fwd_wait_handles_send
is
not
None
:
for
req
in
fwd_wait_handles_send
:
req
.
wait
()
deallocate_output_tensor
(
output_tensor_send
,
config
.
deallocate_pipeline_outputs
)
fwd_wait_handles_send
=
None
if
parallel_state
.
is_pipeline_last_stage
():
input_tensor_slave_chunk
=
output_tensor
input_tensor
,
fwd_wait_handles
=
recv_forward
(
tensor_shape
,
config
,
master_chunk_id
,
async_op
=
True
)
else
:
input_tensor_slave_chunk
,
_
=
recv_forward
(
tensor_shape
,
config
,
slave_chunk_id
)
input_tensor
,
fwd_wait_handles
=
recv_forward
(
tensor_shape
,
config
,
master_chunk_id
,
async_op
=
True
)
if
fwd_wait_handles_warmup
is
not
None
:
for
req
in
fwd_wait_handles_warmup
:
req
.
wait
()
deallocate_output_tensor
(
output_tensor_warmup
,
config
.
deallocate_pipeline_outputs
)
fwd_wait_handles_warmup
=
None
if
fwd_wait_handles_slave_chunk
is
not
None
:
for
req
in
fwd_wait_handles_slave_chunk
:
req
.
wait
()
deallocate_output_tensor
(
output_tensor_slave_chunk
,
config
.
deallocate_pipeline_outputs
)
fwd_wait_handles_slave_chunk
=
None
set_dualpipe_chunk
(
slave_chunk_id
)
if
args
.
moe_fb_overlap
:
input_tensors
[
slave_chunk_id
].
append
(
(
slave_cur_microbatch
,
input_tensor_slave_chunk
))
output_tensor_slave_chunk
,
_
=
forward_step_helper
(
slave_chunk_id
,
slave_cur_microbatch
,
checkpoint_activations_microbatch
)
else
:
output_tensor_slave_chunk
,
num_tokens
=
forward_step_no_model_graph
(
forward_step_func
,
slave_chunk_id
,
data_iterator
[
slave_chunk_id
],
model
[
slave_chunk_id
],
num_microbatches
,
input_tensor_slave_chunk
,
forward_data_store
,
config
,
collect_non_loss_data
,
checkpoint_activations_microbatch
,
current_microbatch
=
slave_cur_microbatch
,
)
input_tensors
[
slave_chunk_id
].
append
(
(
slave_cur_microbatch
,
input_tensor_slave_chunk
))
total_num_tokens
+=
num_tokens
.
item
()
output_tensors
[
slave_chunk_id
].
append
(
output_tensor_slave_chunk
)
slave_cur_microbatch
+=
1
if
i
==
schedule
[
'interleaved_forward'
][
rank
]
-
1
:
firstFB_no_overlp
=
False
firstFB_no_overlp_handle
=
None
# last rank not overlap first F&B
if
parallel_state
.
is_pipeline_last_stage
():
firstFB_no_overlp
=
True
output_tensor_grad_bwd
,
firstFB_no_overlp_handle
=
recv_backward
(
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
else
:
output_tensor_grad_bwd
,
_
=
recv_backward
(
tensor_shape
,
config
,
slave_chunk_id
)
fwd_wait_handles_slave_chunk
=
send_forward
(
output_tensor_slave_chunk
,
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
if
not
parallel_state
.
is_pipeline_last_stage
():
output_tensor_send
=
output_tensor
fwd_wait_handles_send
=
send_forward
(
output_tensor_send
,
tensor_shape
,
config
,
master_chunk_id
,
async_op
=
True
)
if
fwd_wait_handles
is
not
None
:
for
req
in
fwd_wait_handles
:
req
.
wait
()
fwd_wait_handles
=
None
# Run 1b1w1f stages for slave chunk
bwd_wait_handles
=
None
for
_
in
range
(
schedule
[
'1b1w1f'
][
rank
]):
WeightGradStore
.
start_decouple
()
if
args
.
moe_fb_overlap
:
if
is_dualpipev_last_stgae
(
slave_chunk_id
):
input_tensor_bwd
=
logits_inputs
.
pop
(
0
)
output_tensor_bwd
=
output_tensors
[
slave_chunk_id
][
0
]
model_graph
=
None
output_tensor_grad_bwd
=
backward_step_with_model_graph
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
model_type
,
config
,
model_graph
)
input_tensor_bwd
=
input_tensors
[
slave_chunk_id
].
pop
(
0
)[
1
]
output_tensor_bwd
=
output_tensors
[
slave_chunk_id
].
pop
(
0
)
model_graph
=
model_graphs
[
slave_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step_with_model_graph
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
model_type
,
config
,
model_graph
)
else
:
input_tensor_bwd
=
input_tensors
[
slave_chunk_id
].
pop
(
0
)[
1
]
output_tensor_bwd
=
output_tensors
[
slave_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
model_type
,
config
)
WeightGradStore
.
end_decouple
()
# If asynchronous, the memory will rise.
bwd_wait_handles
=
send_backward
(
input_tensor_grad
,
tensor_shape
,
config
,
slave_chunk_id
)
if
fwd_wait_handles_slave_chunk
is
not
None
:
for
req
in
fwd_wait_handles_slave_chunk
:
req
.
wait
()
deallocate_output_tensor
(
output_tensor_slave_chunk
,
config
.
deallocate_pipeline_outputs
)
fwd_wait_handles_slave_chunk
=
None
if
fwd_wait_handles_send
is
not
None
:
for
req
in
fwd_wait_handles_send
:
req
.
wait
()
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
fwd_wait_handles_send
=
None
# If asynchronous, the memory will rise.
input_tensor_slave_chunk
,
recv_forward_handle
=
recv_forward
(
tensor_shape
,
config
,
slave_chunk_id
)
# 1w: Weight Grad Compute
WeightGradStore
.
pop
()
if
recv_forward_handle
is
not
None
:
for
req
in
recv_forward_handle
:
req
.
wait
()
recv_forward_handle
=
None
# 1F: Forward pass
set_dualpipe_chunk
(
slave_chunk_id
)
if
args
.
moe_fb_overlap
:
input_tensors
[
slave_chunk_id
].
append
(
(
slave_cur_microbatch
,
input_tensor_slave_chunk
))
output_tensor_slave_chunk
,
_
=
forward_step_helper
(
slave_chunk_id
,
slave_cur_microbatch
,
checkpoint_activations_microbatch
)
else
:
output_tensor_slave_chunk
,
num_tokens
=
forward_step_no_model_graph
(
forward_step_func
,
slave_chunk_id
,
data_iterator
[
slave_chunk_id
],
model
[
slave_chunk_id
],
num_microbatches
,
input_tensor_slave_chunk
,
forward_data_store
,
config
,
collect_non_loss_data
,
checkpoint_activations_microbatch
,
current_microbatch
=
slave_cur_microbatch
)
input_tensors
[
slave_chunk_id
].
append
(
(
slave_cur_microbatch
,
input_tensor_slave_chunk
))
total_num_tokens
+=
num_tokens
.
item
()
output_tensors
[
slave_chunk_id
].
append
(
output_tensor_slave_chunk
)
slave_cur_microbatch
+=
1
output_tensor_grad_bwd
,
_
=
recv_backward
(
tensor_shape
,
config
,
slave_chunk_id
)
fwd_wait_handles_slave_chunk
=
send_forward
(
output_tensor_slave_chunk
,
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
fwd_wait_handles_recv
=
None
# Run overlaping f&bw stages
fwd_model_chunk_id
=
master_chunk_id
bwd_model_chunk_id
=
slave_chunk_id
for
_
in
range
(
schedule
[
'overlap'
][
rank
]
+
schedule
[
'1b1overlap'
][
rank
]
+
schedule
[
'interleaved_backward'
][
rank
]):
only_bwd
=
False
if
fwd_model_chunk_id
==
master_chunk_id
and
master_cur_microbatch
==
master_microbatch_max
:
only_bwd
=
True
if
fwd_model_chunk_id
==
slave_chunk_id
and
slave_cur_microbatch
==
slave_microbatch_max
:
only_bwd
=
True
if
args
.
moe_fb_overlap
and
not
firstFB_no_overlp
:
if
not
only_bwd
:
if
fwd_wait_handles
is
not
None
:
for
req
in
fwd_wait_handles
:
req
.
wait
()
fwd_wait_handles
=
None
if
fwd_wait_handles_recv
is
not
None
:
for
req
in
fwd_wait_handles_recv
:
req
.
wait
()
fwd_wait_handles_recv
=
None
if
bwd_wait_handles
is
not
None
:
for
req
in
bwd_wait_handles
:
req
.
wait
()
bwd_wait_handles
=
None
if
not
parallel_state
.
is_pipeline_last_stage
()
or
fwd_model_chunk_id
==
master_chunk_id
:
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
fwd_microbatch
=
master_cur_microbatch
if
fwd_model_chunk_id
==
master_chunk_id
else
slave_cur_microbatch
set_dualpipe_chunk
(
fwd_model_chunk_id
)
fwd_send_only
=
False
if
fwd_model_chunk_id
==
slave_chunk_id
and
master_cur_microbatch
==
master_microbatch_max
:
fwd_send_only
=
True
extra_block_kwargs
=
{}
if
is_dualpipev_last_stgae
(
bwd_model_chunk_id
):
input_tensor_bwd
=
logits_inputs
.
pop
(
0
)
output_tensor_bwd
=
output_tensors
[
bwd_model_chunk_id
][
0
]
model_graph
=
None
input_tensor_grad
=
backward_step_with_model_graph
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
model_type
,
config
,
model_graph
)
extra_block_kwargs
.
setdefault
(
'bwd_model_grad'
,
input_tensor_grad
)
else
:
extra_block_kwargs
.
setdefault
(
'bwd_model_grad'
,
output_tensor_grad_bwd
)
fwd_pp_comm_params
,
bwd_pp_comm_params
=
check_pipeline_stage
(
fwd_model_chunk_id
,
fwd_send_only
)
fwd_pp_comm_params
.
config
,
bwd_pp_comm_params
.
config
=
config
,
config
fwd_pp_comm_params
.
tensor_shape
,
bwd_pp_comm_params
.
tensor_shape
=
tensor_shape
,
tensor_shape
extra_block_kwargs
.
setdefault
(
'bwd_model_graph'
,
model_graphs
[
bwd_model_chunk_id
].
pop
(
0
))
extra_block_kwargs
.
setdefault
(
'pp_comm_params'
,
fwd_pp_comm_params
)
extra_block_kwargs
.
setdefault
(
'bwd_pp_comm_params'
,
bwd_pp_comm_params
)
input_tensors
[
fwd_model_chunk_id
].
append
(
(
fwd_microbatch
,
input_tensor
))
output_tensor
,
model_graph
,
pp_comm_output
=
forward_step_helper
(
fwd_model_chunk_id
,
fwd_microbatch
,
checkpoint_activations_microbatch
,
extra_block_kwargs
=
extra_block_kwargs
)
if
parallel_state
.
is_pipeline_last_stage
()
and
fwd_model_chunk_id
==
master_chunk_id
:
input_tensor
=
output_tensor
output_tensor_grad_bwd
=
pp_comm_output
.
input_tensor_grad
else
:
input_tensor
,
fwd_wait_handles
=
pp_comm_output
.
input_tensor
,
pp_comm_output
.
fwd_wait_handles
output_tensor_grad_bwd
,
bwd_wait_handles
=
pp_comm_output
.
output_tensor_grad
,
pp_comm_output
.
bwd_wait_handles
if
fwd_model_chunk_id
==
master_chunk_id
:
master_cur_microbatch
+=
1
else
:
slave_cur_microbatch
+=
1
if
fwd_wait_handles_slave_chunk
is
not
None
:
for
req
in
fwd_wait_handles_slave_chunk
:
# 同步上个阶段最后一个slave前向send
req
.
wait
()
deallocate_output_tensor
(
output_tensor_slave_chunk
,
config
.
deallocate_pipeline_outputs
)
fwd_wait_handles_slave_chunk
=
None
else
:
if
fwd_wait_handles
is
not
None
:
for
req
in
fwd_wait_handles
:
req
.
wait
()
fwd_wait_handles
=
None
if
bwd_wait_handles
is
not
None
:
for
req
in
bwd_wait_handles
:
req
.
wait
()
bwd_wait_handles
=
None
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
if
bwd_model_chunk_id
==
slave_chunk_id
and
slave_cur_microbatch
<
slave_microbatch_max
:
input_tensor
,
fwd_wait_handles_recv
=
recv_forward
(
tensor_shape
,
config
,
slave_chunk_id
,
async_op
=
True
)
if
is_dualpipev_last_stgae
(
bwd_model_chunk_id
):
input_tensor_bwd
=
logits_inputs
.
pop
(
0
)
output_tensor_bwd
=
output_tensors
[
bwd_model_chunk_id
][
0
]
model_graph
=
None
output_tensor_grad_bwd
=
backward_step_with_model_graph
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
model_type
,
config
,
model_graph
)
input_tensor_bwd
=
input_tensors
[
bwd_model_chunk_id
].
pop
(
0
)[
1
]
output_tensor_bwd
=
output_tensors
[
bwd_model_chunk_id
].
pop
(
0
)
model_graph
=
model_graphs
[
bwd_model_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step_with_model_graph
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
model_type
,
config
,
model_graph
)
if
parallel_state
.
is_pipeline_last_stage
()
and
fwd_model_chunk_id
==
master_chunk_id
:
output_tensor_grad_bwd
=
input_tensor_grad
else
:
# send_backward_recv_slave_backward
output_tensor_grad_bwd
,
bwd_wait_handles
=
send_forward_recv_slave_forward
(
input_tensor_grad
,
tensor_shape
,
config
,
fwd_model_chunk_id
)
else
:
firstFB_no_overlp
=
False
if
not
only_bwd
:
fwd_microbatch
=
master_cur_microbatch
if
fwd_model_chunk_id
==
master_chunk_id
else
slave_cur_microbatch
set_dualpipe_chunk
(
fwd_model_chunk_id
)
if
args
.
moe_fb_overlap
:
input_tensors
[
fwd_model_chunk_id
].
append
(
(
fwd_microbatch
,
input_tensor
))
output_tensor
,
_
=
forward_step_helper
(
fwd_model_chunk_id
,
fwd_microbatch
,
checkpoint_activations_microbatch
)
else
:
output_tensor
,
num_tokens
=
forward_step_no_model_graph
(
forward_step_func
,
fwd_model_chunk_id
,
data_iterator
[
fwd_model_chunk_id
],
model
[
fwd_model_chunk_id
],
num_microbatches
,
input_tensor
,
forward_data_store
,
config
,
collect_non_loss_data
,
checkpoint_activations_microbatch
,
current_microbatch
=
fwd_microbatch
)
input_tensors
[
fwd_model_chunk_id
].
append
(
(
fwd_microbatch
,
input_tensor
))
total_num_tokens
+=
num_tokens
.
item
()
output_tensors
[
fwd_model_chunk_id
].
append
(
output_tensor
)
if
fwd_model_chunk_id
==
master_chunk_id
:
master_cur_microbatch
+=
1
fwd_send_only
=
False
else
:
slave_cur_microbatch
+=
1
fwd_send_only
=
(
master_cur_microbatch
==
master_microbatch_max
)
if
fwd_send_only
:
fwd_wait_handles
=
send_forward
(
output_tensor
,
tensor_shape
,
config
,
fwd_model_chunk_id
,
async_op
=
True
)
else
:
if
parallel_state
.
is_pipeline_last_stage
()
and
fwd_model_chunk_id
==
master_chunk_id
:
input_tensor
=
output_tensor
else
:
input_tensor
,
fwd_wait_handles
=
send_forward_recv_slave_forward
(
output_tensor
,
tensor_shape
,
config
,
fwd_model_chunk_id
,
async_op
=
True
)
if
firstFB_no_overlp_handle
is
not
None
:
for
req
in
firstFB_no_overlp_handle
:
req
.
wait
()
firstFB_no_overlp_handle
=
None
if
bwd_wait_handles
is
not
None
:
for
req
in
bwd_wait_handles
:
req
.
wait
()
bwd_wait_handles
=
None
if
args
.
moe_fb_overlap
:
if
is_dualpipev_last_stgae
(
bwd_model_chunk_id
):
input_tensor_bwd
=
logits_inputs
.
pop
(
0
)
output_tensor_bwd
=
output_tensors
[
bwd_model_chunk_id
][
0
]
model_graph
=
None
output_tensor_grad_bwd
=
backward_step_with_model_graph
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
model_type
,
config
,
model_graph
)
input_tensor_bwd
=
input_tensors
[
bwd_model_chunk_id
].
pop
(
0
)[
1
]
output_tensor_bwd
=
output_tensors
[
bwd_model_chunk_id
].
pop
(
0
)
model_graph
=
model_graphs
[
bwd_model_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step_with_model_graph
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
model_type
,
config
,
model_graph
)
else
:
input_tensor_bwd
=
input_tensors
[
bwd_model_chunk_id
].
pop
(
0
)[
1
]
output_tensor_bwd
=
output_tensors
[
bwd_model_chunk_id
].
pop
(
0
)
input_tensor_grad
=
backward_step
(
input_tensor_bwd
,
output_tensor_bwd
,
output_tensor_grad_bwd
,
model_type
,
config
)
if
fwd_wait_handles
is
not
None
:
for
req
in
fwd_wait_handles
:
req
.
wait
()
fwd_wait_handles
=
None
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
if
parallel_state
.
is_pipeline_last_stage
()
and
fwd_model_chunk_id
==
master_chunk_id
:
output_tensor_grad_bwd
=
input_tensor_grad
else
:
# send_backward_recv_slave_backward
output_tensor_grad_bwd
,
bwd_wait_handles
=
send_forward_recv_slave_forward
(
input_tensor_grad
,
tensor_shape
,
config
,
fwd_model_chunk_id
,
async_op
=
True
)
if
fwd_wait_handles_slave_chunk
is
not
None
:
for
req
in
fwd_wait_handles_slave_chunk
:
# 同步上个阶段最后一个slave前向send
req
.
wait
()
deallocate_output_tensor
(
output_tensor_slave_chunk
,
config
.
deallocate_pipeline_outputs
)
fwd_wait_handles_slave_chunk
=
None
        # only run backward
        else:
            if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
                input_tensor, _ = recv_forward(tensor_shape, config, slave_chunk_id)

            if bwd_wait_handles is not None:
                for req in bwd_wait_handles:
                    req.wait()
                bwd_wait_handles = None

            if args.moe_fb_overlap:
                if is_dualpipev_last_stgae(bwd_model_chunk_id):
                    input_tensor_bwd = logits_inputs.pop(0)
                    output_tensor_bwd = output_tensors[bwd_model_chunk_id][0]
                    model_graph = None
                    output_tensor_grad_bwd = backward_step_with_model_graph(
                        input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd,
                        model_type, config, model_graph
                    )
                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
                model_graph = model_graphs[bwd_model_chunk_id].pop(0)
                input_tensor_grad = backward_step_with_model_graph(
                    input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd,
                    model_type, config, model_graph
                )
            else:
                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
                input_tensor_grad = backward_step(
                    input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
                )

            if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                output_tensor_grad_bwd = input_tensor_grad
            else:
                # send_backward_recv_slave_backward
                output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(
                    input_tensor_grad, tensor_shape, config, fwd_model_chunk_id
                )

        # swap fwd & bwd chunks
        fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
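    # Cooldown: no forward work remains, so drain the outstanding backward passes of
    # both chunks. The two per-chunk queues are merged so backwards alternate between
    # chunks, and weight-gradient computation is decoupled through WeightGradStore
    # unless args.dualpipe_no_dw_detach is set.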
    # Run cooldown phases
    merged_input_tensors = []
    merged_output_tensors = []
    while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
        if len(input_tensors[bwd_model_chunk_id]) > 0:
            merged_input_tensors.append(input_tensors[bwd_model_chunk_id].pop(0))
            merged_output_tensors.append(
                (output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id)
            )
        if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
            merged_input_tensors.append(input_tensors[1 - bwd_model_chunk_id].pop(0))
            merged_output_tensors.append(
                (output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id)
            )

    bwd_wait_handles_recv = None
    for i in range(pp_size):
        if bwd_wait_handles is not None:
            for req in bwd_wait_handles:
                req.wait()
            bwd_wait_handles = None
        if bwd_wait_handles_recv is not None:
            for req in bwd_wait_handles_recv:
                req.wait()
            bwd_wait_handles_recv = None

        input_tensor_bwd = merged_input_tensors.pop(0)[1]
        output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)

        if not args.dualpipe_no_dw_detach:
            WeightGradStore.start_decouple()

        if args.moe_fb_overlap:
            model_graph = model_graphs[bwd_model_chunk_id].pop(0)
            input_tensor_grad = backward_step_with_model_graph(
                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd,
                model_type, config, model_graph
            )
        else:
            input_tensor_grad = backward_step(
                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
            )

        if not args.dualpipe_no_dw_detach:
            WeightGradStore.end_decouple()

        if i == pp_size - 1:
            bwd_wait_handles = send_backward(
                input_tensor_grad, tensor_shape, config, bwd_model_chunk_id, async_op=True
            )
        elif i >= schedule['cooldown'][rank][0] - 1:
            bwd_wait_handles = send_backward(
                input_tensor_grad, tensor_shape, config, bwd_model_chunk_id, async_op=True
            )
            output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(
                tensor_shape, config, bwd_model_chunk_id, async_op=True
            )
        else:
            if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
                output_tensor_grad_bwd = input_tensor_grad
            else:
                # send_backward_recv_slave_backward
                output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(
                    input_tensor_grad, tensor_shape, config, 1 - bwd_model_chunk_id
                )

        WeightGradStore.flush_chunk_grad()
        if i >= schedule['cooldown'][rank][0] - 1:
            WeightGradStore.pop_single()

    for _ in range(schedule['cooldown'][rank][2] - 1):
        WeightGradStore.pop_single()
    assert WeightGradStore.weight_grad_queue.empty()

    if bwd_wait_handles is not None:
        for req in bwd_wait_handles:
            req.wait()
        bwd_wait_handles = None

    if config.finalize_model_grads_func is not None and not forward_only:
        # If defer_embedding_wgrad_compute is enabled we need to do the
        # weight gradient GEMM's here.
        finish_embedding_wgrad_compute(config, embedding_module)

        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism, layernorm all-reduce for sequence parallelism, and
        # embedding all-reduce for pipeline parallelism).
        config.finalize_model_grads_func(
            model, total_num_tokens if config.calculate_per_token_loss else None
        )

    return forward_data_store
dcu_megatron/core/pipeline_parallel/fb_overlap/__init__.py  0 → 100644
from .modules.layers import linear_backward_wgrad_detach, ColumnParallelLinear, RowParallelLinear
from .modules.experts import group_mlp_forward_detach
from .transformer_layer import transformer_layer_forward_backward_overlaping
from .gpt_model import gpt_model_forward_backward_overlaping
from .vpp_schedules import forward_backward_pipelining_with_interleaving
dcu_megatron/core/pipeline_parallel/fb_overlap/adaptor.py  0 → 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch


def _make_param_hook(
    self,
    param: torch.nn.Parameter,
    param_to_buffer,
):
    """
    Creates the all-reduce / reduce-scatter hook for backprop.
    """

    def param_hook(*unused):
        if param.requires_grad and not getattr(param, 'skip_grad_accum', False):
            if self.ddp_config.overlap_grad_reduce:
                assert (
                    param.grad is not None
                ), 'param.grad being None is not safe when overlap_grad_reduce is True'
            if param.grad is not None and (
                not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
            ):
                param.main_grad.add_(param.grad.data)
            param.grad = None

            # Maybe this should be called after WeightGradStore.pop()
            if self.ddp_config.overlap_grad_reduce:
                param_to_buffer[param].register_grad_ready(param)

        if getattr(param, 'skip_grad_accum', False):
            param.skip_grad_accum = False

    return param_hook
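The adaptor hook above mirrors Megatron's DistributedDataParallel._make_param_hook, adding a `skip_grad_accum` escape hatch so a deferred weight gradient (replayed later through WeightGradStore) is neither accumulated into `main_grad` nor marked ready for reduce. A minimal sketch of how this is presumably wired in through the repo's patching mechanism; the import of Megatron's DDP class and the direct attribute assignment are assumptions, not code from this commit:

# Hedged sketch: swap in the adaptor's param hook for Megatron's DDP wrapper.
# Assumes megatron.core's DistributedDataParallel exposes _make_param_hook with
# the same (self, param, param_to_buffer) signature as the function above.
from megatron.core.distributed.distributed_data_parallel import DistributedDataParallel
from dcu_megatron.core.pipeline_parallel.fb_overlap.adaptor import _make_param_hook

DistributedDataParallel._make_param_hook = _make_param_hook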
dcu_megatron/core/pipeline_parallel/fb_overlap/gpt_model.py  0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import logging
from typing import Dict, Literal, Optional, Tuple, Union, List

import torch
from torch import Tensor

from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.packed_seq_params import PackedSeqParams

from .transformer_block import (
    transformer_block_backward,
    transformer_block_forward_backward_overlaping,
    transformer_block_forward
)
from .modules.utils import (
    LayerGraph,
    detach_tensor,
    run_graph_backward
)


class ModelGraph:
    def __init__(
        self,
        layer_graphs: List[LayerGraph],
        block_output,
        preprocess_graph: Tensor = None,
        preprocess_detached_output: Tensor = None,
    ):
        self.preprocess_graph = (preprocess_graph, preprocess_detached_output)
        self.layer_graphs = layer_graphs
        self.block_output = block_output
def gpt_model_forward(
    self,
    input_ids: Tensor,
    position_ids: Tensor,
    attention_mask: Tensor,
    decoder_input: Tensor = None,
    labels: Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    extra_block_kwargs: dict = None,
) -> Tensor:
    """Forward function of the GPT Model. This function passes the input tensors
    through the embedding layer, then the decoder, and finally into the post
    processing layer (optional).

    It returns the loss values if labels are given, otherwise the final hidden units.
    """
    # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
    # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

    # Decoder embedding.
    if decoder_input is not None:
        preprocess_graph = None
    elif self.pre_process:
        decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        preprocess_graph = decoder_input
    else:
        # intermediate stage of pipeline
        # decoder will get hidden_states from encoder.input_tensor
        decoder_input = None
        preprocess_graph = None

    # Rotary positional embeddings (embedding is None for PP intermediate devices)
    rotary_pos_emb = None
    if self.position_embedding_type == 'rope':
        rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
            inference_params, self.decoder, decoder_input, self.config
        )
        rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

    detached_block_input = detach_tensor(decoder_input)

    # Run decoder.
    hidden_states, layer_graphs = transformer_block_forward(
        self.decoder,
        hidden_states=detached_block_input,
        attention_mask=attention_mask,
        inference_params=inference_params,
        rotary_pos_emb=rotary_pos_emb,
        packed_seq_params=packed_seq_params,
        **(extra_block_kwargs or {}),
    )

    if not self.post_process:
        return hidden_states, ModelGraph(
            layer_graphs, hidden_states, preprocess_graph, detached_block_input
        )

    # logits and loss
    output_weight = None
    if self.share_embeddings_and_output_weights:
        output_weight = self.shared_embedding_or_output_weight()
    logits, _ = self.output_layer(hidden_states, weight=output_weight)

    if labels is None:
        # [s b h] => [b s h]
        logits = logits.transpose(0, 1).contiguous()
        graph = ModelGraph(layer_graphs, hidden_states, preprocess_graph, detached_block_input)
        return logits, graph

    loss = self.compute_language_model_loss(labels, logits)
    graph = ModelGraph(layer_graphs, hidden_states, preprocess_graph, detached_block_input)
    return loss, graph
def gpt_model_backward(
    model_grad,
    model_graph: ModelGraph,
):
    block_input_grad = transformer_block_backward(model_grad, model_graph.layer_graphs)

    if model_graph.preprocess_graph[0] is not None:
        run_graph_backward(
            model_graph.preprocess_graph, block_input_grad, keep_graph=True, keep_grad=True
        )
        return None
    else:
        return block_input_grad
def gpt_model_forward_backward_overlaping(
    fwd_model,
    input_ids: Tensor,
    position_ids: Tensor,
    attention_mask: Tensor,
    decoder_input: Tensor = None,
    labels: Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    extra_block_kwargs: dict = None,
):
    if extra_block_kwargs is None or extra_block_kwargs['bwd_model_graph'] is None:
        return gpt_model_forward(
            fwd_model, input_ids, position_ids, attention_mask, decoder_input,
            labels, inference_params, packed_seq_params, extra_block_kwargs
        )

    bwd_model_grad, bwd_model_graph = (
        extra_block_kwargs['bwd_model_grad'], extra_block_kwargs['bwd_model_graph']
    )

    # Fwd Model Decoder embedding.
    if decoder_input is not None:
        preprocess_graph = None
    elif fwd_model.pre_process:
        decoder_input = fwd_model.embedding(input_ids=input_ids, position_ids=position_ids)
        preprocess_graph = decoder_input
    else:
        # intermediate stage of pipeline
        # decoder will get hidden_states from encoder.input_tensor
        decoder_input = None
        preprocess_graph = None

    # Rotary positional embeddings (embedding is None for PP intermediate devices)
    rotary_pos_emb = None
    if fwd_model.position_embedding_type == 'rope':
        rotary_seq_len = fwd_model.rotary_pos_emb.get_rotary_seq_len(
            inference_params, fwd_model.decoder, decoder_input, fwd_model.config
        )
        rotary_pos_emb = fwd_model.rotary_pos_emb(rotary_seq_len)

    detached_block_input = detach_tensor(decoder_input)

    # Run transformer block fwd & bwd overlapping
    (hidden_states, layer_graphs), block_input_grad, pp_comm_output \
        = transformer_block_forward_backward_overlaping(
            fwd_model.decoder,
            detached_block_input,
            attention_mask,
            bwd_model_grad,
            bwd_model_graph.layer_graphs,
            rotary_pos_emb=rotary_pos_emb,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            pp_comm_params=extra_block_kwargs['pp_comm_params'],
            bwd_pp_comm_params=extra_block_kwargs['bwd_pp_comm_params']
        )

    if bwd_model_graph.preprocess_graph[0] is not None:
        run_graph_backward(
            bwd_model_graph.preprocess_graph, block_input_grad, keep_grad=True, keep_graph=True
        )

    if not fwd_model.post_process:
        return hidden_states, ModelGraph(
            layer_graphs, hidden_states, preprocess_graph, detached_block_input
        ), pp_comm_output

    # logits and loss
    output_weight = None
    if fwd_model.share_embeddings_and_output_weights:
        output_weight = fwd_model.shared_embedding_or_output_weight()
    logits, _ = fwd_model.output_layer(hidden_states, weight=output_weight)

    if labels is None:
        # [s b h] => [b s h]
        logits = logits.transpose(0, 1).contiguous()
        graph = ModelGraph(layer_graphs, hidden_states, preprocess_graph, detached_block_input)
        return logits, graph, pp_comm_output

    loss = fwd_model.compute_language_model_loss(labels, logits)
    graph = ModelGraph(layer_graphs, hidden_states, preprocess_graph, detached_block_input)
    return loss, graph, pp_comm_output
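Taken together, the schedule is expected to run a plain forward for one microbatch, keep the returned ModelGraph, and hand it back (together with the gradient of that microbatch's output) so the next microbatch's forward overlaps the previous backward inside the transformer block. A minimal sketch of that call pattern; the input tensors, the `compute_output_grad` helper, and the two communication-parameter dicts are illustrative stand-ins prepared elsewhere by the schedule, not code from this commit:

# Microbatch i: plain forward, keep the graph for a later overlapped backward.
out_i, graph_i = gpt_model_forward(model, input_ids_i, position_ids_i, attention_mask)
grad_i = compute_output_grad(out_i)  # hypothetical: d(loss)/d(out_i) supplied by the schedule

# Microbatch i+1: forward, while the backward of microbatch i runs inside the block.
out_next, graph_next, pp_comm_output = gpt_model_forward_backward_overlaping(
    model, input_ids_next, position_ids_next, attention_mask,
    extra_block_kwargs={
        'bwd_model_grad': grad_i,
        'bwd_model_graph': graph_i,
        'pp_comm_params': pp_comm_params,          # assumed: built by the pipeline schedule
        'bwd_pp_comm_params': bwd_pp_comm_params,  # assumed: built by the pipeline schedule
    },
)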
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/__init__.py  0 → 100644
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/attention.py  0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch

from megatron.training import get_args
from mindspeed.core.transformer.moe.comm_utils import async_all_to_all
from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput

AsyncAll2All_INPUT = []
AsyncAll2All_OUTPUT = []


def set_async_alltoall_inputs(*args):
    AsyncAll2All_INPUT.append(args)


def get_async_alltoall_outputs():
    return AsyncAll2All_OUTPUT.pop(0)


def launch_async_all2all():
    global AsyncAll2All_INPUT
    global AsyncAll2All_OUTPUT
    if len(AsyncAll2All_INPUT) > 0:
        input_, input_splits, output_splits, group = AsyncAll2All_INPUT.pop(0)
        _, output, a2a_handle = async_all_to_all(input_, input_splits, output_splits, group)
        AsyncAll2All_OUTPUT.append((output, a2a_handle))


def launch_async_all2all_hook(_):
    launch_async_all2all()
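These helpers form a small staging queue for a deferred MoE all-to-all: the dispatcher records its arguments, a hook on some earlier tensor actually launches the communication, and the consumer later collects the output together with its wait handle. A minimal sketch of that flow; every tensor, split list, and process group below is an illustrative stand-in rather than code from this commit:

# Stage the all-to-all arguments instead of launching immediately.
set_async_alltoall_inputs(tokens, input_splits, output_splits, ep_group)

# Launch it from an autograd hook so the communication overlaps other work.
some_activation.register_hook(launch_async_all2all_hook)

# ... later, once the collective has had time to run in the background ...
output, handle = get_async_alltoall_outputs()
handle.wait()  # ensure the all-to-all has completed before using `output`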
def attention_forward(
    self,
    hidden_states,
    residual,
    attention_mask=None,
    inference_params=None,
    rotary_pos_emb=None,
    packed_seq_params=None,
    recompute_norm=False
):
    # Optional Input Layer norm
    def pre_norm(hidden_states):
        args = get_args()
        input_layernorm_output = self.input_layernorm(hidden_states)
        if getattr(args, 'input_layernorm_in_fp32', False):
            input_layernorm_output = input_layernorm_output.float()
        return input_layernorm_output

    if recompute_norm:
        self.norm_ckpt1 = CheckpointWithoutOutput()
        input_layernorm_output = self.norm_ckpt1.checkpoint(pre_norm, False, hidden_states)
    else:
        input_layernorm_output = pre_norm(hidden_states)

    # Self attention.
    attention_output_with_bias = self.self_attention(
        input_layernorm_output,
        attention_mask=attention_mask,
        inference_params=inference_params,
        rotary_pos_emb=rotary_pos_emb,
        packed_seq_params=packed_seq_params,
    )

    # TODO: could we move `bias_dropout_add_exec_handler` itself
    # inside the module provided in the `bias_dropout_add_spec` module?
    with self.bias_dropout_add_exec_handler():
        hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
            attention_output_with_bias, residual, self.hidden_dropout
        )

    if recompute_norm:
        self.norm_ckpt1.discard_output()
        hidden_states.register_hook(self.norm_ckpt1.recompute)

    return hidden_states