dcu_megatron · Commits · 770fa304

Commit 770fa304 authored Apr 25, 2025 by dongcl
Modify MTP
parent 8096abd4
Changes: 44 files in this commit. This page shows 20 changed files with 2866 additions and 502 deletions (+2866, −502); the remaining files are on later pages.
dcu_megatron/adaptor/adaptor_arguments.py   +23   −0
dcu_megatron/adaptor/feature_manager/__init__.py   +7   −0
dcu_megatron/adaptor/feature_manager/base_feature.py   +40   −0
dcu_megatron/adaptor/feature_manager/mtp_feature.py   +51   −0
dcu_megatron/adaptor/feature_manager/pipeline_parallel/dualpipev_feature.py   +52   −0
dcu_megatron/adaptor/megatron_adaptor.py   +34   −26
dcu_megatron/adaptor/patch_utils.py   +21   −3
dcu_megatron/core/distributed/finalize_model_grads.py   +6   −13
dcu_megatron/core/models/common/embeddings/language_model_embedding.py   +0   −133
dcu_megatron/core/models/common/language_module/language_module.py   +137   −0
dcu_megatron/core/models/gpt/gpt_layer_specs.py   +51   −2
dcu_megatron/core/models/gpt/gpt_model.py   +378   −325
dcu_megatron/core/pipeline_parallel/dualpipev/__init__.py   +0   −0
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_chunks.py   +236   −0
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py   +1498   −0
dcu_megatron/core/pipeline_parallel/fb_overlap/__init__.py   +5   −0
dcu_megatron/core/pipeline_parallel/fb_overlap/adaptor.py   +33   −0
dcu_megatron/core/pipeline_parallel/fb_overlap/gpt_model.py   +209   −0
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/__init__.py   +0   −0
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/attention.py   +85   −0
dcu_megatron/adaptor/adaptor_arguments.py (new file, mode 100644)

import argparse

from .feature_manager import FEATURES_LIST

_ARGS = None


def process_args(parser):
    parser.conflict_handler = 'resolve'
    for feature in FEATURES_LIST:
        feature.register_args(parser)

    return parser


def get_adaptor_args():
    global _ARGS
    if _ARGS is None:
        parser = argparse.ArgumentParser(description='Adaptor Arguments', allow_abbrev=False)
        _ARGS, _ = process_args(parser).parse_known_args()

    return _ARGS
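For reference, a minimal standalone sketch of the two-stage parsing pattern used above: adaptor flags are parsed with parse_known_args so that Megatron's own flags pass through untouched. The '--schedules-method' flag comes from the features registered below; the command-line values are illustrative.

import argparse

parser = argparse.ArgumentParser(description='Adaptor Arguments', allow_abbrev=False)
parser.conflict_handler = 'resolve'
parser.add_argument('--schedules-method', type=str, default=None, choices=['dualpipev'])

# Unknown Megatron flags such as --num-layers are returned, not rejected.
args, unknown = parser.parse_known_args(['--schedules-method', 'dualpipev', '--num-layers', '32'])
print(args.schedules_method)   # dualpipev
print(unknown)                 # ['--num-layers', '32']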
dcu_megatron/adaptor/feature_manager/__init__.py (new file, mode 100644)

from .pipeline_parallel.dualpipev_feature import DualpipeVFeature

FEATURES_LIST = [
    # Pipeline Parallel features
    DualpipeVFeature()
]
dcu_megatron/adaptor/feature_manager/base_feature.py (new file, mode 100644)

# modified from mindspeed
import argparse


class BaseFeature:
    def __init__(self, feature_name: str, optimization_level: int = 2):
        self.feature_name = feature_name.strip().replace('-', '_')
        self.optimization_level = optimization_level
        self.default_patches = self.optimization_level == 0

    def register_args(self, parser):
        pass

    def pre_validate_args(self, args):
        pass

    def validate_args(self, args):
        pass

    def post_validate_args(self, args):
        pass

    def register_patches(self, patch_manager, args):
        ...

    def incompatible_check(self, global_args, check_args):
        if getattr(global_args, self.feature_name, None) and getattr(global_args, check_args, None):
            raise AssertionError('{} and {} are incompatible.'.format(self.feature_name, check_args))

    def dependency_check(self, global_args, check_args):
        if getattr(global_args, self.feature_name, None) and not getattr(global_args, check_args, None):
            raise AssertionError('{} requires {}.'.format(self.feature_name, check_args))

    @staticmethod
    def add_parser_argument_choices_value(parser, argument_name, new_choice):
        for action in parser._actions:
            exist_arg = isinstance(action, argparse.Action) and argument_name in action.option_strings
            if exist_arg and action.choices is not None and new_choice not in action.choices:
                action.choices.append(new_choice)
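As a reading aid, a minimal sketch of how a subclass is meant to plug into BaseFeature. The ExampleFeature class, its flag, and the 'sequence_parallel' attribute are made up for illustration; only the hook names and the helper methods come from the class above.

from argparse import ArgumentParser

from dcu_megatron.adaptor.feature_manager.base_feature import BaseFeature


class ExampleFeature(BaseFeature):
    def __init__(self):
        # stored as 'example_feature' after strip()/replace('-', '_')
        super().__init__('example-feature', optimization_level=2)

    def register_args(self, parser: ArgumentParser):
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--example-feature', action='store_true')

    def validate_args(self, args):
        # e.g. refuse to combine this feature with sequence parallelism
        self.incompatible_check(args, 'sequence_parallel')

    def register_patches(self, patch_manager, args):
        # patch_manager is expected to expose register_patch(target_path, replacement)
        pass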
dcu_megatron/adaptor/feature_manager/mtp_feature.py (new file, mode 100644)

from argparse import ArgumentParser

from ..base_feature import BaseFeature


class MTPFeature(BaseFeature):
    def __init__(self):
        super().__init__('schedules-method')

    def register_args(self, parser: ArgumentParser):
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--schedules-method', type=str, default=None, choices=['dualpipev'])

    def register_patches(self, patch_manager, args):
        from ...core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
        from ...core.models.common.language_module.language_module import (
            setup_embeddings_and_output_layer,
            tie_embeddings_and_output_weights_state_dict,
        )
        from ...core.models.gpt.gpt_model import GPTModel
        from ...training.utils import get_batch_on_this_tp_rank
        from ...core.pipeline_parallel.schedules import forward_step_wrapper
        from ...core import transformer_block_init_wrapper

        MegatronAdaptation.register(
            'megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
            _allreduce_word_embedding_grads)

        # LanguageModule
        MegatronAdaptation.register(
            'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
            setup_embeddings_and_output_layer)
        MegatronAdaptation.register(
            'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
            tie_embeddings_and_output_weights_state_dict)
        MegatronAdaptation.register(
            'megatron.training.utils.get_batch_on_this_tp_rank',
            get_batch_on_this_tp_rank)

        # GPT Model
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)

        # Transformer block
        MegatronAdaptation.register(
            'megatron.core.transformer.transformer_block.TransformerBlock.__init__',
            transformer_block_init_wrapper)

        # pipeline_parallel.schedules.forward_step
        MegatronAdaptation.register(
            'megatron.core.pipeline_parallel.schedules.forward_step',
            forward_step_wrapper, apply_wrapper=True)
dcu_megatron/adaptor/feature_manager/pipeline_parallel/dualpipev_feature.py (new file, mode 100644)

# Modified from mindspeed.
from argparse import ArgumentParser

from ..base_feature import BaseFeature


class DualpipeVFeature(BaseFeature):
    def __init__(self):
        super().__init__('schedules-method')

    def register_args(self, parser: ArgumentParser):
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--schedules-method', type=str, default=None, choices=['dualpipev'])

    def validate_args(self, args):
        if args.schedules_method == "dualpipev":
            if args.num_layers_per_virtual_pipeline_stage is not None:
                raise AssertionError("The dualpipev and virtual_pipeline are incompatible.")
            if args.num_layers < args.pipeline_model_parallel_size * 2:
                raise AssertionError(
                    'number of layers must be at least 2*pipeline_model_parallel_size in dualpipe')
            num_micro_batch = args.global_batch_size // args.micro_batch_size // args.data_parallel_size
            if num_micro_batch < args.pipeline_model_parallel_size * 2 - 1:
                raise AssertionError(
                    "num_micro_batch should be no less than pipeline_model_parallel_size * 2 - 1")

    def register_patches(self, patch_manager, args):
        from megatron.training.utils import print_rank_0
        from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_schedules import \
            forward_backward_pipelining_with_cutinhalf
        from dcu_megatron.core.pipeline_parallel.dualpipev.dualpipev_chunks import (
            get_model,
            dualpipev_fp16forward,
            get_num_layers_to_build,
            train_step,
            _allreduce_embedding_grads_wrapper
        )

        if args.schedules_method == "dualpipev":
            patch_manager.register_patch('megatron.training.training.get_model', get_model)
            patch_manager.register_patch('megatron.training.training.train_step', train_step)
            patch_manager.register_patch(
                'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving',
                forward_backward_pipelining_with_cutinhalf)
            patch_manager.register_patch(
                'megatron.legacy.model.module.Float16Module.forward',
                dualpipev_fp16forward)
            patch_manager.register_patch(
                'megatron.core.transformer.transformer_block.get_num_layers_to_build',
                get_num_layers_to_build)
            patch_manager.register_patch('megatron.training.utils.print_rank_last', print_rank_0)
            patch_manager.register_patch(
                'megatron.core.distributed.finalize_model_grads._allreduce_embedding_grads',
                _allreduce_embedding_grads_wrapper)
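The micro-batch constraint in validate_args can be checked by hand. A small worked example under the same formula used above (the parallelism and batch sizes are illustrative):

# Worked example of the dualpipev constraints above.
pipeline_model_parallel_size = 4
num_layers = 16                      # must be >= 2 * 4 = 8  -> ok
global_batch_size = 256
micro_batch_size = 2
data_parallel_size = 8

num_micro_batch = global_batch_size // micro_batch_size // data_parallel_size   # 16
required = pipeline_model_parallel_size * 2 - 1                                 # 7
assert num_micro_batch >= required, "not enough micro-batches for dualpipev"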
dcu_megatron/adaptor/megatron_adaptor.py (modified)

@@ -5,6 +5,8 @@ import types
 import argparse
 import torch

+from .adaptor_arguments import get_adaptor_args
+

 class MegatronAdaptation:
     """

@@ -21,6 +23,15 @@ class MegatronAdaptation:
         for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
             adaptation.execute()
         MegatronAdaptation.apply()

+        from .patch_utils import MegatronPatchesManager
+
+        args = get_adaptor_args()
+        for feature in FEATURES_LIST:
+            if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) \
+                    or feature.optimization_level == 0:
+                feature.register_patches(MegatronPatchesManager, args)
+        MindSpeedPatchesManager.apply_patches()
+
         # MegatronAdaptation.post_execute()

     @classmethod

@@ -87,47 +98,37 @@ class CoreAdaptation(MegatronAdaptationABC):
         self.patch_miscellaneous()

     def patch_core_distributed(self):
-        # Mtp share embedding
+        # mtp share embedding
         from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
         MegatronAdaptation.register(
             'megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
             _allreduce_word_embedding_grads)

     def patch_core_models(self):
-        from ..core.models.common.embeddings.language_model_embedding import (
-            language_model_embedding_forward,
-            language_model_embedding_init_func
-        )
-        from ..core.models.gpt.gpt_model import (
-            gpt_model_forward,
-            gpt_model_init_wrapper,
-            shared_embedding_or_mtp_embedding_weight
-        )
+        from ..core.models.common.language_module.language_module import (
+            setup_embeddings_and_output_layer,
+            tie_embeddings_and_output_weights_state_dict,
+        )
+        from ..core.models.gpt.gpt_model import GPTModel
         from ..training.utils import get_batch_on_this_tp_rank

-        # Embedding
+        # LanguageModule
         MegatronAdaptation.register(
-            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
-            language_model_embedding_init_func)
+            'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
+            setup_embeddings_and_output_layer)
         MegatronAdaptation.register(
-            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
-            language_model_embedding_forward)
+            'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
+            tie_embeddings_and_output_weights_state_dict)
         MegatronAdaptation.register(
             'megatron.training.utils.get_batch_on_this_tp_rank',
             get_batch_on_this_tp_rank)

         # GPT Model
-        MegatronAdaptation.register(
-            'megatron.core.models.gpt.gpt_model.GPTModel.forward',
-            gpt_model_forward)
-        MegatronAdaptation.register(
-            'megatron.core.models.gpt.gpt_model.GPTModel.__init__',
-            gpt_model_init_wrapper, apply_wrapper=True)
-        from megatron.core.models.gpt.gpt_model import GPTModel
-        setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight',
-                shared_embedding_or_mtp_embedding_weight)
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
         from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

-        # Transformer block
+        # Transformer block. If mtp_num_layers > 0, move final_layernorm outside
         MegatronAdaptation.register(
             'megatron.core.transformer.transformer_block.TransformerBlock.__init__',
             transformer_block_init_wrapper)

@@ -165,13 +166,11 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
         from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init

         # VocabParallelEmbedding
         MegatronAdaptation.register(
             'megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
             vocab_parallel_embedding_forward)
         MegatronAdaptation.register(
             'megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
             vocab_parallel_embedding_init)
             torch.compile(mode='max-autotune-no-cudagraphs'), apply_wrapper=True)

         # VocabParallelCrossEntropy
         MegatronAdaptation.register(
             'megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits',

@@ -201,6 +200,14 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register(
             "megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
             get_gpt_layer_with_flux_spec)

+    def patch_pipeline_parallel(self):
+        from ..core.pipeline_parallel.schedules import forward_step_wrapper
+
+        # pipeline_parallel.schedules.forward_step
+        MegatronAdaptation.register(
+            'megatron.core.pipeline_parallel.schedules.forward_step',
+            forward_step_wrapper, apply_wrapper=True)
+
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed

@@ -245,6 +252,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
             parallel_mlp_init_wrapper, apply_wrapper=True)
         # ParallelAttention
         MegatronAdaptation.register(
             'megatron.legacy.model.transformer.ParallelAttention.__init__',
             parallel_attention_init_wrapper, apply_wrapper=True)
dcu_megatron/adaptor/patch_utils.py (modified)

@@ -148,11 +148,29 @@ class MegatronPatchesManager:
     patches_info = {}

     @staticmethod
-    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
+    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False,
+                       apply_wrapper=False, remove_origin_wrappers=False):
         if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
-            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
-                orig_func_or_cls_name, new_func_or_cls, create_dummy)
+            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
+                orig_func_or_cls_name, new_func_or_cls, create_dummy,
+                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)
         else:
-            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
+            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
+                new_func_or_cls, force_patch,
+                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)

     @staticmethod
     def apply_patches():
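To make the new keyword concrete: a patch registered with apply_wrapper=True is a decorator-style replacement rather than a drop-in function. A minimal sketch of the two styles, assuming the manager applies a wrapper-style patch as wrapper(original); the function names below are illustrative.

from functools import wraps


def replacement_forward_step(*args, **kwargs):
    """Drop-in replacement: registered with the default apply_wrapper=False."""
    ...


def forward_step_wrapper(fn):
    """Wrapper-style patch: registered with apply_wrapper=True, presumably applied as wrapper(original)."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # pre/post processing around the original forward_step would go here
        return fn(*args, **kwargs)
    return wrapper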
dcu_megatron/core/distributed/finalize_model_grads.py (modified)

@@ -28,20 +28,13 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
     model_module = model[0]
     model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)

-    if model_module.share_embeddings_and_output_weights:
-        weight = model_module.shared_embedding_or_output_weight()
-        grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
-        orig_grad = getattr(weight, grad_attr)
-        grad = _unshard_if_dtensor(orig_grad)
-        torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
-        setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
-    if (hasattr(model_module, "share_mtp_embedding_and_output_weight")
-            and model_module.share_mtp_embedding_and_output_weight
-            and config.num_nextn_predict_layers > 0):
-        weight = model_module.shared_embedding_or_mtp_embedding_weight()
+    # If share_embeddings_and_output_weights is True, we need to maintain duplicated
+    # embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
+    # we also need to maintain duplicated embedding weights in mtp process stage.
+    # So we need to allreduce grads of embedding in the embedding group in these cases.
+    if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
+        weight = model_module.shared_embedding_or_output_weight()
         grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
         orig_grad = getattr(weight, grad_attr)
         grad = _unshard_if_dtensor(orig_grad)
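The merged condition routes both the tied-embedding case and the MTP case through one all-reduce over the embedding group. A tiny sketch of the grad_attr selection used above, runnable without a process group (the all-reduce itself is left as a comment):

import torch

weight = torch.nn.Parameter(torch.zeros(4))
weight.main_grad = torch.ones(4)   # the distributed-optimizer path accumulates into .main_grad

grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
grad = getattr(weight, grad_attr)
# In the real code this tensor is then all-reduced over the embedding group:
# torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
print(grad_attr, grad)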
dcu_megatron/core/models/common/embeddings/language_model_embedding.py (deleted, 100644 → 0)

from typing import Literal

import torch
from torch import Tensor

from megatron.core import tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding


def language_model_embedding_init_func(
    self,
    config: TransformerConfig,
    vocab_size: int,
    max_sequence_length: int,
    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
    num_tokentypes: int = 0,
    scatter_to_sequence_parallel: bool = True,
    skip_weight_param_allocation: bool = False,
):
    """Patch language model embeddings init."""
    super(LanguageModelEmbedding, self).__init__(config=config)

    self.config: TransformerConfig = config
    self.vocab_size: int = vocab_size
    self.max_sequence_length: int = max_sequence_length
    self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
    self.num_tokentypes = num_tokentypes
    self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
    self.reduce_scatter_embeddings = (
        (not self.add_position_embedding)
        and self.num_tokentypes <= 0
        and self.config.sequence_parallel
        and self.scatter_to_sequence_parallel
    )

    # Word embeddings (parallel).
    self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
        num_embeddings=self.vocab_size,
        embedding_dim=self.config.hidden_size,
        init_method=self.config.init_method,
        reduce_scatter_embeddings=self.reduce_scatter_embeddings,
        config=self.config,
        skip_weight_param_allocation=skip_weight_param_allocation,
    )

    # Position embedding (serial).
    if self.add_position_embedding:
        self.position_embeddings = torch.nn.Embedding(self.max_sequence_length, self.config.hidden_size)

        # Initialize the position embeddings.
        if self.config.perform_initialization:
            self.config.init_method(self.position_embeddings.weight)

    if self.num_tokentypes > 0:
        self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.config.hidden_size)
        # Initialize the token-type embeddings.
        if self.config.perform_initialization:
            self.config.init_method(self.tokentype_embeddings.weight)
    else:
        self.tokentype_embeddings = None

    # Embeddings dropout
    self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)


def language_model_embedding_forward(
    self,
    input_ids: Tensor,
    position_ids: Tensor,
    tokentype_ids: int = None,
    weight: Tensor = None,
) -> Tensor:
    """Patch forward pass of the embedding module.

    Args:
        input_ids (Tensor): The input tokens
        position_ids (Tensor): The position id's used to calculate position embeddings
        tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
            set to True. Defaults to None
        weight (Tensor): embedding weight

    Returns:
        Tensor: The output embeddings
    """
    if weight is None:
        if self.word_embeddings.weight is None:
            raise RuntimeError(
                "weight was not supplied to VocabParallelEmbedding forward pass "
                "and skip_weight_param_allocation is True."
            )
        weight = self.word_embeddings.weight

    word_embeddings = self.word_embeddings(input_ids, weight)
    if self.add_position_embedding:
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = word_embeddings + position_embeddings
    else:
        embeddings = word_embeddings

    if not self.reduce_scatter_embeddings:
        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

    if tokentype_ids is not None:
        assert self.tokentype_embeddings is not None
        # [b s h] -> [s b h] (So that it can be added with embeddings)
        tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
        embeddings = embeddings + tokentype_embedding
    else:
        assert self.tokentype_embeddings is None

    # If the input flag for fp32 residual connection is set, convert for float.
    if self.config.fp32_residual_connection:
        embeddings = embeddings.float()

    # Dropout.
    if self.config.sequence_parallel:
        if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
        # `scatter_to_sequence_parallel_region` returns a view, which prevents
        # the original tensor from being garbage collected. Clone to facilitate GC.
        # Has a small runtime cost (~0.5%).
        if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
            embeddings = embeddings.clone()
        with tensor_parallel.get_cuda_rng_tracker().fork():
            embeddings = self.embedding_dropout(embeddings)
    else:
        embeddings = self.embedding_dropout(embeddings)

    return embeddings
dcu_megatron/core/models/common/language_module/language_module.py (new file, mode 100644)

import logging

import torch

from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint


def setup_embeddings_and_output_layer(self) -> None:
    """Sets up embedding layer in first stage and output layer in last stage.

    This function initializes word embeddings in the final stage when we are
    using pipeline parallelism and sharing word embeddings, and sets up param
    attributes on the embedding and output layers.
    """
    # Set `is_embedding_or_output_parameter` attribute.
    if self.pre_process:
        self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True
    if self.post_process and self.output_layer.weight is not None:
        self.output_layer.weight.is_embedding_or_output_parameter = True

    # If share_embeddings_and_output_weights is True, we need to maintain duplicated
    # embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
    # we also need to maintain duplicated embedding weights in mtp process stage.
    # So we need to copy embedding weights from pre processing stage as initial parameters
    # in these cases.
    if not self.share_embeddings_and_output_weights and not getattr(self.config, 'mtp_num_layers', 0):
        return

    if parallel_state.get_pipeline_model_parallel_world_size() == 1:
        # Zero out wgrad if sharing embeddings between two layers on same
        # pipeline stage to make sure grad accumulation into main_grad is
        # correct and does not include garbage values (e.g., from torch.empty).
        self.shared_embedding_or_output_weight().zero_out_wgrad = True
        return

    if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
        self.shared_embedding_or_output_weight().shared_embedding = True

    if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process:
        assert not parallel_state.is_pipeline_first_stage()
        # set weights of the duplicated embedding to 0 here,
        # then copy weights from pre processing stage using all_reduce below.
        weight = self.shared_embedding_or_output_weight()
        weight.data.fill_(0)
        weight.shared = True
        weight.shared_embedding = True

    # Parameters are shared between the word embeddings layers, and the
    # heads at the end of the model. In a pipelined setup with more than
    # one stage, the initial embedding layer and the head are on different
    # workers, so we do the following:
    # 1. Create a second copy of word_embeddings on the last stage, with
    #    initial parameters of 0.0.
    # 2. Do an all-reduce between the first and last stage to ensure that
    #    the two copies of word_embeddings start off with the same
    #    parameter values.
    # 3. In the training loop, before an all-reduce between the grads of
    #    the two word_embeddings layers to ensure that every applied weight
    #    update is the same on both stages.

    # Ensure that first and last stages have the same initial parameter
    # values.
    if torch.distributed.is_initialized():
        if parallel_state.is_rank_in_embedding_group():
            weight = self.shared_embedding_or_output_weight()
            weight.data = weight.data.cuda()
            torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group())
    elif not getattr(LanguageModule, "embedding_warning_printed", False):
        logging.getLogger(__name__).warning(
            "Distributed processes aren't initialized, so the output layer "
            "is not initialized with weights from the word embeddings. "
            "If you are just manipulating a model this is fine, but "
            "this needs to be handled manually. If you are training "
            "something is definitely wrong."
        )
        LanguageModule.embedding_warning_printed = True


def tie_embeddings_and_output_weights_state_dict(
    self,
    sharded_state_dict: ShardedStateDict,
    output_layer_weight_key: str,
    first_stage_word_emb_key: str,
) -> None:
    """Ties the embedding and output weights in a given sharded state dict.

    Args:
        sharded_state_dict (ShardedStateDict): state dict with the weight to tie
        output_layer_weight_key (str): key of the output layer weight in the state dict.
            This entry will be replaced with a tied version
        first_stage_word_emb_key (str): this must be the same as the
            ShardedTensor.key of the first stage word embeddings.

    Returns: None, acts in-place
    """
    if not self.post_process:
        # No output layer
        assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
        return

    if self.pre_process:
        # Output layer is equivalent to the embedding already
        return

    # If use Multi-Token Prediction (MTP), we need maintain both embedding layer and output
    # layer in mtp process stage. In this case, if share_embeddings_and_output_weights is True,
    # the shared weights will be stored in embedding layer, and output layer will not have
    # any weight.
    if getattr(self, 'mtp_process', False):
        # No output layer
        assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
        return

    # Replace the default output layer with one sharing the weights with the embedding
    del sharded_state_dict[output_layer_weight_key]
    tensor = self.shared_embedding_or_output_weight()
    last_stage_word_emb_replica_id = (
        1,  # copy of first stage embedding
        0,
        parallel_state.get_data_parallel_rank(with_context_parallel=True),
    )

    sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
        tensor=tensor,
        key=first_stage_word_emb_key,
        replica_id=last_stage_word_emb_replica_id,
        allow_shape_mismatch=True,
    )
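The zero-fill plus all-reduce step above effectively copies the first-stage embedding into the duplicated copy on the last (or MTP) stage. A single-process toy illustration of why that works, with the all_reduce replaced by an explicit sum of the two "stages":

import torch

first_stage_emb = torch.randn(4, 8)   # initialized on the first pipeline stage
last_stage_emb = torch.zeros(4, 8)    # duplicated copy, filled with zeros

# all_reduce(sum) over {first, last} leaves both stages with identical weights:
summed = first_stage_emb + last_stage_emb
assert torch.equal(summed, first_stage_emb)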
dcu_megatron/core/models/gpt/gpt_layer_specs.py (modified)

@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
     MLASelfAttentionSubmodules,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import (
     TransformerLayer,
     TransformerLayerSubmodules,
 )
-from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
 from megatron.core.utils import is_te_min_version

 try:

@@ -36,6 +36,55 @@ try:
 except ImportError:
     warnings.warn('Apex is not installed.')

+from dcu_megatron.core.tensor_parallel.layers import (
+    FluxColumnParallelLinear,
+    FluxRowParallelLinear
+)
+from dcu_megatron.core.transformer.multi_token_prediction import (
+    MultiTokenPredictionBlockSubmodules,
+    get_mtp_layer_offset,
+    get_mtp_layer_spec,
+    get_mtp_num_layers_to_build,
+)
+
+
+def get_gpt_mtp_block_spec(
+    config: TransformerConfig,
+    spec: Union[TransformerBlockSubmodules, ModuleSpec],
+    use_transformer_engine: bool,
+) -> MultiTokenPredictionBlockSubmodules:
+    """GPT Multi-Token Prediction (MTP) block spec."""
+    num_layers_to_build = get_mtp_num_layers_to_build(config)
+    if num_layers_to_build == 0:
+        return None
+
+    if isinstance(spec, TransformerBlockSubmodules):
+        # get the spec for the last layer of decoder block
+        transformer_layer_spec = spec.layer_specs[-1]
+    elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
+        transformer_layer_spec = spec
+    else:
+        raise ValueError(f"Invalid spec: {spec}")
+
+    mtp_layer_spec = get_mtp_layer_spec(
+        transformer_layer_spec=transformer_layer_spec,
+        use_transformer_engine=use_transformer_engine,
+    )
+    mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
+    mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers
+
+    offset = get_mtp_layer_offset(config)
+    # split the mtp layer specs to only include the layers that are built in this pipeline stage.
+    mtp_layer_specs = mtp_layer_specs[offset: offset + num_layers_to_build]
+    if len(mtp_layer_specs) > 0:
+        assert (
+            len(mtp_layer_specs) == config.mtp_num_layers
+        ), f"currently all of the mtp layers must stage in the same pipeline stage."
+        mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs)
+    else:
+        mtp_block_spec = None
+
+    return mtp_block_spec
+
+
 def get_gpt_layer_with_flux_spec(
     num_experts: Optional[int] = None,
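A worked example of the MTP layer-spec slicing in get_gpt_mtp_block_spec above; the numbers are illustrative and the spec objects are replaced by plain strings:

mtp_num_layers = 2
num_layers_to_build = 2          # what get_mtp_num_layers_to_build returns on this stage
offset = 0                       # what get_mtp_layer_offset returns on this stage

mtp_layer_specs = ['mtp_layer_spec'] * mtp_num_layers
mtp_layer_specs = mtp_layer_specs[offset: offset + num_layers_to_build]
assert len(mtp_layer_specs) == mtp_num_layers   # all MTP layers must sit on one stage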
dcu_megatron/core/models/gpt/gpt_model.py (modified; diff collapsed, not shown)
dcu_megatron/core/pipeline_parallel/dualpipev/__init__.py (new, empty file)
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_chunks.py (new file, mode 100644)

# Modified from mindspeed.
import torch
from functools import wraps
from typing import List, Optional

from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config
from megatron.legacy.model import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType
from megatron.training.global_vars import get_args, get_timers
from megatron.training.utils import unwrap_model
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.legacy.model.module import fp32_to_float16, float16_to_fp32
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _allreduce_layernorm_grads
# Assumed imports for helpers used by train_step below; they were not visible in the diff view.
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.training.utils import (
    logical_and_across_model_parallel_group,
    reduce_max_stat_across_model_parallel_group,
)

from .dualpipev_schedules import get_dualpipe_chunk


def dualpipev_fp16forward(self, *inputs, **kwargs):
    is_pipeline_first_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 0
    if is_pipeline_first_stage:
        inputs = fp32_to_float16(inputs, self.float16_convertor)
    outputs = self.module(*inputs, **kwargs)
    is_pipeline_last_stage = mpu.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
    if is_pipeline_last_stage:
        outputs = float16_to_fp32(outputs)
    return outputs


def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model."""
    args = get_args()
    args.model_type = model_type

    assert model_type != ModelType.encoder_and_decoder, \
        "Interleaved schedule not supported for model with both encoder and decoder"

    model = []
    pre_process, post_process = False, False
    if mpu.is_pipeline_first_stage():
        pre_process = True

    args.dualpipev_first_chunk = True
    first_model = model_provider_func(pre_process=pre_process, post_process=post_process)
    first_model.model_type = model_type
    model.append(first_model)

    args.dualpipev_first_chunk = False
    second_model = model_provider_func(pre_process=post_process, post_process=pre_process)
    second_model.model_type = model_type
    model.append(second_model)

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum([sum([p.nelement() for p in model_module.parameters()])
                       for model_module in model])), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if wrap_with_ddp:
        config = get_model_config(model[0])
        ddp_config = DistributedDataParallelConfig(
            grad_reduce_in_fp32=args.accumulate_allreduce_grads_in_fp32,
            overlap_grad_reduce=args.overlap_grad_reduce,
            use_distributed_optimizer=args.use_distributed_optimizer,
            check_for_nan_in_grad=args.check_for_nan_in_loss_and_grad,
            bucket_size=args.ddp_bucket_size,
            average_in_collective=args.ddp_average_in_collective)
        model = [DDP(config,
                     ddp_config,
                     model_chunk,
                     # Turn off bucketing for model_chunk 2 onwards, since communication for these
                     # model chunks is overlapped with compute anyway.
                     disable_bucketing=(model_chunk_idx > 0))
                 for (model_chunk_idx, model_chunk) in enumerate(model)]

        # Broadcast params from data parallel src rank to other data parallel ranks.
        if args.data_parallel_random_init:
            for model_module in model:
                model_module.broadcast_params()

    return model


def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    rerun_state_machine = get_rerun_state_machine()
    while rerun_state_machine.should_run_forward_backward(data_iterator):
        # Set grad to zero.
        for model_chunk in model:
            model_chunk.zero_grad_buffer()
        optimizer.zero_grad()

        # Forward pass.
        forward_backward_func = get_forward_backward_func()
        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=data_iterator,
            model=model,
            num_microbatches=get_num_microbatches(),
            seq_length=args.seq_length,
            micro_batch_size=args.micro_batch_size,
            decoder_seq_length=args.decoder_seq_length,
            forward_only=False)
    should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
    if should_exit:
        return {}, True, should_checkpoint, should_exit, exit_code, None, None

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Vision gradients.
    if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

    # Update parameters.
    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # when freezing sub-models we may have a mixture of successful and unsuccessful ranks,
    # so we must gather across mp ranks
    update_successful = logical_and_across_model_parallel_group(update_successful)
    # grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
    # so we must gather across mp ranks
    grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
    if args.log_num_zeros_in_grad:
        num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)

    # Vision momentum.
    if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.update_momentum(args.curr_iteration)

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        opt_param_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    dualpipev_last_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)

    if dualpipev_last_stage:
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0].keys():
            numerator = 0
            denominator = 0
            for x in losses_reduced:
                val = x[key]
                # there is one dict per microbatch. in new reporting, we average
                # over the total number of tokens across the global batch.
                if isinstance(val, tuple) or isinstance(val, list):
                    numerator += val[0]
                    denominator += val[1]
                else:
                    # legacy behavior. we average over the number of microbatches,
                    # and so the denominator is 1.
                    numerator += val
                    denominator += 1
            loss_reduced[key] = numerator / denominator
        return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad


def get_num_layers_to_build(config: TransformerConfig) -> int:
    num_layers_per_pipeline_rank = (
        config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
    )
    num_layers_to_build = num_layers_per_pipeline_rank // 2
    return num_layers_to_build


def _allreduce_embedding_grads_wrapper(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if get_args().schedules_method == 'dualpipev':
            # dualpipev has no need to do the embedding allreduce:
            # the embedding and lm head are on the same rank.
            if not get_args().untie_embeddings_and_output_weights:
                raise NotImplementedError
            else:
                return
        else:
            return fn(*args, **kwargs)
    return wrapper
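Under dualpipev, get_model builds two model chunks per pipeline rank and get_num_layers_to_build halves the per-rank layer count accordingly. A worked example with illustrative numbers:

num_layers = 16
pipeline_model_parallel_size = 4

num_layers_per_pipeline_rank = num_layers // pipeline_model_parallel_size   # 4
num_layers_to_build = num_layers_per_pipeline_rank // 2                     # 2 per model chunk
# Each pipeline rank therefore builds two model chunks of 2 layers each,
# matching the two model_provider_func calls in get_model above.
print(num_layers_to_build)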
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py (new file, mode 100644; diff collapsed, not shown)
dcu_megatron/core/pipeline_parallel/fb_overlap/__init__.py (new file, mode 100644)

from .modules.layers import linear_backward_wgrad_detach, ColumnParallelLinear, RowParallelLinear
from .modules.experts import group_mlp_forward_detach
from .transformer_layer import transformer_layer_forward_backward_overlaping
from .gpt_model import gpt_model_forward_backward_overlaping
from .vpp_schedules import forward_backward_pipelining_with_interleaving
dcu_megatron/core/pipeline_parallel/fb_overlap/adaptor.py (new file, mode 100644)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch


def _make_param_hook(
    self,
    param: torch.nn.Parameter,
    param_to_buffer,
):
    """
    Creates the all-reduce / reduce-scatter hook for backprop.
    """

    def param_hook(*unused):
        if param.requires_grad and not getattr(param, 'skip_grad_accum', False):
            if self.ddp_config.overlap_grad_reduce:
                assert (
                    param.grad is not None
                ), 'param.grad being None is not safe when overlap_grad_reduce is True'
            if param.grad is not None and (
                not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
            ):
                param.main_grad.add_(param.grad.data)
            param.grad = None

            # Maybe this should be called after weightgradstore.pop()
            if self.ddp_config.overlap_grad_reduce:
                param_to_buffer[param].register_grad_ready(param)

        if getattr(param, 'skip_grad_accum', False):
            param.skip_grad_accum = False

    return param_hook
dcu_megatron/core/pipeline_parallel/fb_overlap/gpt_model.py (new file, mode 100644)

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import logging
from typing import Dict, Literal, Optional, Tuple, Union, List

import torch
from torch import Tensor

from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.packed_seq_params import PackedSeqParams

from .transformer_block import (
    transformer_block_backward,
    transformer_block_forward_backward_overlaping,
    transformer_block_forward
)
from .modules.utils import (
    LayerGraph, detach_tensor, run_graph_backward
)


class ModelGraph:
    def __init__(
        self,
        layer_graphs: List[LayerGraph],
        block_output,
        preprocess_graph: Tensor = None,
        preprocess_detached_output: Tensor = None,
    ):
        self.preprocess_graph = (preprocess_graph, preprocess_detached_output)
        self.layer_graphs = layer_graphs
        self.block_output = block_output


def gpt_model_forward(
    self,
    input_ids: Tensor,
    position_ids: Tensor,
    attention_mask: Tensor,
    decoder_input: Tensor = None,
    labels: Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    extra_block_kwargs: dict = None,
) -> Tensor:
    """Forward function of the GPT Model. This function passes the input tensors
    through the embedding layer, then the decoder, and finally into the post
    processing layer (optional).

    It either returns the loss values if labels are given or the final hidden units.
    """
    # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
    # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

    # Decoder embedding.
    if decoder_input is not None:
        preprocess_graph = None
    elif self.pre_process:
        decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        preprocess_graph = decoder_input
    else:
        # intermediate stage of pipeline
        # decoder will get hidden_states from encoder.input_tensor
        decoder_input = None
        preprocess_graph = None

    # Rotary positional embeddings (embedding is None for PP intermediate devices)
    rotary_pos_emb = None
    if self.position_embedding_type == 'rope':
        rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
            inference_params, self.decoder, decoder_input, self.config
        )
        rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

    detached_block_input = detach_tensor(decoder_input)
    # Run decoder.
    hidden_states, layer_graphs = transformer_block_forward(
        self.decoder,
        hidden_states=detached_block_input,
        attention_mask=attention_mask,
        inference_params=inference_params,
        rotary_pos_emb=rotary_pos_emb,
        packed_seq_params=packed_seq_params,
        **(extra_block_kwargs or {}),
    )

    if not self.post_process:
        return hidden_states, ModelGraph(
            layer_graphs, hidden_states, preprocess_graph, detached_block_input
        )

    # logits and loss
    output_weight = None
    if self.share_embeddings_and_output_weights:
        output_weight = self.shared_embedding_or_output_weight()
    logits, _ = self.output_layer(hidden_states, weight=output_weight)

    if labels is None:
        # [s b h] => [b s h]
        logits = logits.transpose(0, 1).contiguous()
        graph = ModelGraph(layer_graphs, hidden_states, preprocess_graph, detached_block_input)
        return logits, graph

    loss = self.compute_language_model_loss(labels, logits)
    graph = ModelGraph(layer_graphs, hidden_states, preprocess_graph, detached_block_input)

    return loss, graph


def gpt_model_backward(
    model_grad,
    model_graph: ModelGraph,
):
    block_input_grad = transformer_block_backward(model_grad, model_graph.layer_graphs)

    if model_graph.preprocess_graph[0] is not None:
        run_graph_backward(model_graph.preprocess_graph, block_input_grad, keep_graph=True, keep_grad=True)
        return None
    else:
        return block_input_grad


def gpt_model_forward_backward_overlaping(
    fwd_model,
    input_ids: Tensor,
    position_ids: Tensor,
    attention_mask: Tensor,
    decoder_input: Tensor = None,
    labels: Tensor = None,
    inference_params: InferenceParams = None,
    packed_seq_params: PackedSeqParams = None,
    extra_block_kwargs: dict = None,
):
    if extra_block_kwargs is None or extra_block_kwargs['bwd_model_graph'] is None:
        return gpt_model_forward(
            fwd_model, input_ids, position_ids, attention_mask, decoder_input,
            labels, inference_params, packed_seq_params, extra_block_kwargs
        )

    bwd_model_grad, bwd_model_graph = extra_block_kwargs['bwd_model_grad'], extra_block_kwargs['bwd_model_graph']

    # Fwd Model Decoder embedding.
    if decoder_input is not None:
        preprocess_graph = None
    elif fwd_model.pre_process:
        decoder_input = fwd_model.embedding(input_ids=input_ids, position_ids=position_ids)
        preprocess_graph = decoder_input
    else:
        # intermediate stage of pipeline
        # decoder will get hidden_states from encoder.input_tensor
        decoder_input = None
        preprocess_graph = None

    # Rotary positional embeddings (embedding is None for PP intermediate devices)
    rotary_pos_emb = None
    if fwd_model.position_embedding_type == 'rope':
        rotary_seq_len = fwd_model.rotary_pos_emb.get_rotary_seq_len(
            inference_params, fwd_model.decoder, decoder_input, fwd_model.config
        )
        rotary_pos_emb = fwd_model.rotary_pos_emb(rotary_seq_len)

    detached_block_input = detach_tensor(decoder_input)

    # Run transformer block fwd & bwd overlapping
    (hidden_states, layer_graphs), block_input_grad, pp_comm_output \
        = transformer_block_forward_backward_overlaping(
            fwd_model.decoder,
            detached_block_input,
            attention_mask,
            bwd_model_grad,
            bwd_model_graph.layer_graphs,
            rotary_pos_emb=rotary_pos_emb,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            pp_comm_params=extra_block_kwargs['pp_comm_params'],
            bwd_pp_comm_params=extra_block_kwargs['bwd_pp_comm_params']
        )

    if bwd_model_graph.preprocess_graph[0] is not None:
        run_graph_backward(bwd_model_graph.preprocess_graph, block_input_grad, keep_grad=True, keep_graph=True)

    if not fwd_model.post_process:
        return hidden_states, ModelGraph(
            layer_graphs, hidden_states, preprocess_graph, detached_block_input
        ), pp_comm_output

    # logits and loss
    output_weight = None
    if fwd_model.share_embeddings_and_output_weights:
        output_weight = fwd_model.shared_embedding_or_output_weight()
    logits, _ = fwd_model.output_layer(hidden_states, weight=output_weight)

    if labels is None:
        # [s b h] => [b s h]
        logits = logits.transpose(0, 1).contiguous()
        graph = ModelGraph(layer_graphs, hidden_states, preprocess_graph, detached_block_input)
        return logits, graph, pp_comm_output

    loss = fwd_model.compute_language_model_loss(labels, logits)
    graph = ModelGraph(layer_graphs, hidden_states, preprocess_graph, detached_block_input)

    return loss, graph, pp_comm_output
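The ModelGraph above keeps the detached block input so the embedding ("preprocess") backward can be run later, after the transformer block backward has produced a gradient for it. A toy autograd sketch of that detach-then-backward-later pattern, with no Megatron objects involved:

import torch

x = torch.randn(3, requires_grad=True)
pre = x * 2                      # "preprocess graph" (embedding output)
detached = pre.detach().requires_grad_(True)
out = (detached ** 2).sum()      # "transformer block" forward on the detached input

out.backward()                   # block backward yields the grad w.r.t. the detached input
pre.backward(detached.grad)      # later, run the preprocess backward with that grad
print(x.grad)                    # gradients flow end to end despite the split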
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/__init__.py (new, empty file)
dcu_megatron/core/pipeline_parallel/fb_overlap/modules/attention.py (new file, mode 100644)

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch

from megatron.training import get_args
from mindspeed.core.transformer.moe.comm_utils import async_all_to_all
from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput

AsyncAll2All_INPUT = []
AsyncAll2All_OUTPUT = []


def set_async_alltoall_inputs(*args):
    AsyncAll2All_INPUT.append(args)


def get_async_alltoall_outputs():
    return AsyncAll2All_OUTPUT.pop(0)


def launch_async_all2all():
    global AsyncAll2All_INPUT
    global AsyncAll2All_OUTPUT
    if len(AsyncAll2All_INPUT) > 0:
        input_, input_splits, output_splits, group = AsyncAll2All_INPUT.pop(0)
        _, output, a2a_handle = async_all_to_all(input_, input_splits, output_splits, group)
        AsyncAll2All_OUTPUT.append((output, a2a_handle))


def launch_async_all2all_hook(_):
    launch_async_all2all()


def attention_forward(
    self, hidden_states, residual,
    attention_mask=None,
    inference_params=None,
    rotary_pos_emb=None,
    packed_seq_params=None,
    recompute_norm=False
):
    # Optional Input Layer norm
    def pre_norm(hidden_states):
        args = get_args()
        input_layernorm_output = self.input_layernorm(hidden_states)
        if getattr(args, 'input_layernorm_in_fp32', False):
            input_layernorm_output = input_layernorm_output.float()
        return input_layernorm_output

    if recompute_norm:
        self.norm_ckpt1 = CheckpointWithoutOutput()
        input_layernorm_output = self.norm_ckpt1.checkpoint(pre_norm, False, hidden_states)
    else:
        input_layernorm_output = pre_norm(hidden_states)

    # Self attention.
    attention_output_with_bias = self.self_attention(
        input_layernorm_output,
        attention_mask=attention_mask,
        inference_params=inference_params,
        rotary_pos_emb=rotary_pos_emb,
        packed_seq_params=packed_seq_params,
    )

    # TODO: could we move `bias_dropout_add_exec_handler` itself
    # inside the module provided in the `bias_dropout_add_spec` module?
    with self.bias_dropout_add_exec_handler():
        hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
            attention_output_with_bias, residual, self.hidden_dropout
        )

    if recompute_norm:
        self.norm_ckpt1.discard_output()
        hidden_states.register_hook(self.norm_ckpt1.recompute)

    return hidden_states
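A sketch of the intended producer/consumer flow for the async all-to-all queue above. It assumes an initialized torch.distributed process group, that MindSpeed's async_all_to_all is available, and that the returned handle exposes wait(); the tensor and split sizes are illustrative.

def moe_dispatch_example(tokens, input_splits, output_splits, group):
    # 1. Stash the communication arguments; nothing is launched yet.
    set_async_alltoall_inputs(tokens, input_splits, output_splits, group)

    # 2. Kick off the all-to-all later, e.g. from a backward hook
    #    (launch_async_all2all_hook can be registered on a tensor for this).
    launch_async_all2all()

    # 3. Retrieve the output and wait on the async handle before using it.
    output, handle = get_async_alltoall_outputs()
    handle.wait()   # assumed: the handle behaves like a torch.distributed work object
    return output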