evt_fugx1 / dcu_megatron / Commits / 1f7b14ab

Commit 1f7b14ab authored Apr 26, 2025 by sdwldchl

    rewrite mtp

Parent: 89d29a02

Showing 15 changed files with 1429 additions and 616 deletions (+1429 / -616)
Changed files:

  dcu_megatron/adaptor/megatron_adaptor.py                              +30   -32
  dcu_megatron/adaptor/patch_utils.py                                   +21   -3
  dcu_megatron/core/distributed/finalize_model_grads.py                 +6    -1
  dcu_megatron/core/models/common/language_module/language_module.py    +6    -5
  dcu_megatron/core/models/gpt/gpt_layer_specs.py                       +52   -3
  dcu_megatron/core/models/gpt/gpt_model.py                             +383  -350
  dcu_megatron/core/pipeline_parallel/schedules.py                      +52   -0
  dcu_megatron/core/tensor_parallel/__init__.py                         +0    -2
  dcu_megatron/core/tensor_parallel/layers.py                           +8    -96
  dcu_megatron/core/transformer/multi_token_prediction.py               +737  -0
  dcu_megatron/core/transformer/transformer_block.py                    +1    -1
  dcu_megatron/core/transformer/transformer_config.py                   +27   -13
  dcu_megatron/training/arguments.py                                    +10   -8
  dcu_megatron/training/utils.py                                        +82   -88
  pretrain_gpt.py                                                       +14   -14
dcu_megatron/adaptor/megatron_adaptor.py  (view file @ 1f7b14ab)

@@ -5,6 +5,8 @@ import types
 import argparse
 import torch
+from .adaptor_arguments import get_adaptor_args
+

 class MegatronAdaptation:
     """
 ...
@@ -21,6 +23,15 @@ class MegatronAdaptation:
         for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
             adaptation.execute()
         MegatronAdaptation.apply()
+
+        # from .patch_utils import MegatronPatchesManager
+        # args = get_adaptor_args()
+        # for feature in FEATURES_LIST:
+        #     if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
+        #         feature.register_patches(MegatronPatchesManager, args)
+        # MindSpeedPatchesManager.apply_patches()
+
         # MegatronAdaptation.post_execute()

     @classmethod
 ...
@@ -87,38 +98,20 @@ class CoreAdaptation(MegatronAdaptationABC):
         self.patch_miscellaneous()

     def patch_core_distributed(self):
-        # Mtp share embedding
+        # mtp share embedding
         from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
         MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
                                     _allreduce_word_embedding_grads)

     def patch_core_models(self):
-        from ..core.models.common.embeddings.language_model_embedding import (
-            language_model_embedding_forward,
-            language_model_embedding_init_func
-        )
-        from ..core.models.gpt.gpt_model import (
-            gpt_model_forward,
-            gpt_model_init,
-            shared_embedding_or_output_weight,
-        )
         from ..core.models.common.language_module.language_module import (
             setup_embeddings_and_output_layer,
             tie_embeddings_and_output_weights_state_dict,
         )
+        from ..core.models.gpt.gpt_model import GPTModel
         from ..training.utils import get_batch_on_this_tp_rank

-        # Embedding
-        MegatronAdaptation.register(
-            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
-            language_model_embedding_init_func)
-        MegatronAdaptation.register(
-            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
-            language_model_embedding_forward)
-
-        # GPT Model
+        # LanguageModule
         MegatronAdaptation.register(
             'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
             setup_embeddings_and_output_layer)
 ...
@@ -126,17 +119,16 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register(
             'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
             tie_embeddings_and_output_weights_state_dict)
-        MegatronAdaptation.register(
-            'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
-            shared_embedding_or_output_weight)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
+        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank)
+
+        # GPT Model
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
         from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

-        # Transformer block
+        # Transformer block. If mtp_num_layers > 0, move final_layernorm outside
         MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
                                     transformer_block_init_wrapper)
 ...
@@ -174,13 +166,10 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
-        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper

         # VocabParallelEmbedding
-        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
-                                    vocab_parallel_embedding_forward)
-        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
-                                    vocab_parallel_embedding_init_wrapper,
-                                    apply_wrapper=True)
+        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
+                                    torch.compile(mode='max-autotune-no-cudagraphs'),
+                                    apply_wrapper=True)

         # VocabParallelCrossEntropy
 ...
@@ -211,6 +200,14 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                     get_gpt_layer_with_flux_spec)

+    def patch_pipeline_parallel(self):
+        from ..core.pipeline_parallel.schedules import forward_step_wrapper
+
+        # pipeline_parallel.schedules.forward_step
+        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step',
+                                    forward_step_wrapper,
+                                    apply_wrapper=True)
+
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed
 ...
@@ -255,6 +252,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
                                     parallel_mlp_init_wrapper,
                                     apply_wrapper=True)
+        # ParallelAttention
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
                                     parallel_attention_init_wrapper,
                                     apply_wrapper=True)
 ...
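For readers unfamiliar with the adaptor mechanism: MegatronAdaptation.register collects dotted import paths plus replacement objects, and apply() later monkey-patches them into the upstream megatron modules. The real implementation is not part of this diff; the following is only a minimal sketch of the idea, with PatchRegistry as a hypothetical stand-in (the project's manager additionally handles force_patch, dummy creation, and wrapper removal).

    import importlib

    class PatchRegistry:
        """Minimal register/apply patch manager (illustrative only)."""
        _patches = {}

        @classmethod
        def register(cls, dotted_path, replacement, apply_wrapper=False):
            # dotted_path is '<module>.<attr>' or '<module>.<Class>.<attr>'
            cls._patches[dotted_path] = (replacement, apply_wrapper)

        @classmethod
        def apply(cls):
            for dotted_path, (replacement, apply_wrapper) in cls._patches.items():
                module_path, attr = dotted_path.rsplit('.', 1)
                try:
                    target = importlib.import_module(module_path)
                except ModuleNotFoundError:
                    # '<module>.<Class>.<attr>': resolve the class first.
                    module_path, cls_name = module_path.rsplit('.', 1)
                    target = getattr(importlib.import_module(module_path), cls_name)
                original = getattr(target, attr, None)
                # apply_wrapper=True treats `replacement` as a decorator of the original
                # attribute instead of a drop-in replacement.
                setattr(target, attr, replacement(original) if apply_wrapper else replacement)

Under this reading, apply_wrapper=True (as used above for forward_step_wrapper and torch.compile) wraps the existing upstream function rather than replacing it outright.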
dcu_megatron/adaptor/patch_utils.py  (view file @ 1f7b14ab)

 ...
@@ -148,11 +148,29 @@ class MegatronPatchesManager:
     patches_info = {}

     @staticmethod
-    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
+    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False,
+                       apply_wrapper=False, remove_origin_wrappers=False):
         if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
-            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name,
-                                                                               new_func_or_cls, create_dummy)
+            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name,
+                                                                               new_func_or_cls, create_dummy,
+                                                                               apply_wrapper=apply_wrapper,
+                                                                               remove_origin_wrappers=remove_origin_wrappers)
         else:
-            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
+            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch,
+                                                                                          apply_wrapper=apply_wrapper,
+                                                                                          remove_origin_wrappers=remove_origin_wrappers)

     @staticmethod
     def apply_patches():
 ...
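The new remove_origin_wrappers flag is not defined in this hunk. One plausible interpretation, assuming the existing wrappers were built with functools.wraps (which records the wrapped callable on __wrapped__), is sketched below; strip_wrappers and log_calls are illustrative names, not part of the repository.

    import functools

    def strip_wrappers(func):
        # functools.wraps stores the original callable on __wrapped__,
        # so stacked decorators can be peeled off before re-patching.
        while hasattr(func, "__wrapped__"):
            func = func.__wrapped__
        return func

    def log_calls(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            print(f"calling {fn.__name__}")
            return fn(*args, **kwargs)
        return wrapper

    @log_calls
    def forward(x):
        return x * 2

    bare = strip_wrappers(forward)   # original `forward`, decorator removed
    assert bare(3) == 6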
dcu_megatron/core/distributed/finalize_model_grads.py  (view file @ 1f7b14ab)

 ...
@@ -28,7 +28,12 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
     model_module = model[0]
     model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
-    if model_module.share_embeddings_and_output_weights or getattr(config, 'num_nextn_predict_layers', 0):
+    # If share_embeddings_and_output_weights is True, we need to maintain duplicated
+    # embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
+    # we also need to maintain duplicated embedding weights in mtp process stage.
+    # So we need to allreduce grads of embedding in the embedding group in these cases.
+    if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
         weight = model_module.shared_embedding_or_output_weight()
         grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
         orig_grad = getattr(weight, grad_attr)
 ...
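The all-reduce itself happens further down in _allreduce_word_embedding_grads, outside this hunk. As a hedged sketch of what that amounts to: every rank holding a copy of the shared embedding weight (first pipeline stage, last stage, and now any MTP stage) sums its gradient with the others over the embedding process group, so the duplicated copies stay in sync. The helper below is illustrative only; the process-group argument stands in for Megatron's embedding group.

    import torch
    import torch.distributed as dist

    def allreduce_shared_embedding_grad(weight: torch.Tensor, embedding_group) -> None:
        """Sum the gradient of a weight that is duplicated on several pipeline
        stages so every copy applies the same update. `embedding_group` must be
        the process group containing exactly those duplicated ranks."""
        grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
        grad = getattr(weight, grad_attr)
        if grad is not None:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=embedding_group)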
dcu_megatron/core/models/common/language_module/language_module.py  (view file @ 1f7b14ab)

 ...
@@ -4,6 +4,7 @@ import torch
 from megatron.core import parallel_state
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+from megatron.core.models.common.language_module.language_module import LanguageModule
 from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
 ...
@@ -27,7 +28,7 @@ def setup_embeddings_and_output_layer(self) -> None:
     # So we need to copy embedding weights from pre processing stage as initial parameters
     # in these cases.
-    if not self.share_embeddings_and_output_weights and not getattr(self.config, 'num_nextn_predict_layers', 0):
+    if not self.share_embeddings_and_output_weights and not getattr(self.config, 'mtp_num_layers', 0):
         return
 ...
@@ -41,10 +42,10 @@ def setup_embeddings_and_output_layer(self) -> None:
     if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
         self.shared_embedding_or_output_weight().shared_embedding = True

-    if self.post_process and not self.pre_process:
+    if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process:
         assert not parallel_state.is_pipeline_first_stage()
-        # set word_embeddings weights to 0 here, then copy first
-        # stage's weights using all_reduce below.
+        # set weights of the duplicated embedding to 0 here,
+        # then copy weights from pre processing stage using all_reduce below.
         weight = self.shared_embedding_or_output_weight()
         weight.data.fill_(0)
         weight.shared = True
 ...
@@ -114,7 +115,7 @@ def tie_embeddings_and_output_weights_state_dict(
     # layer in mtp process stage. In this case, if share_embeddings_and_output_weights is True,
     # the shared weights will be stored in embedding layer, and output layer will not have
     # any weight.
-    if self.post_process and getattr(self, 'num_nextn_predict_layers', False):
+    if getattr(self, 'mtp_process', False):
         # No output layer
         assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
         return
 ...
dcu_megatron/core/models/gpt/gpt_layer_specs.py  (view file @ 1f7b14ab)

 import warnings
-from typing import Optional
+from typing import Optional, Union

 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
 ...
@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
     MLASelfAttentionSubmodules,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
+from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import (
     TransformerLayer,
     TransformerLayerSubmodules,
 )
-from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
 from megatron.core.utils import is_te_min_version

 try:
 ...
@@ -36,6 +36,55 @@ try:
 except ImportError:
     warnings.warn('Apex is not installed.')

+from dcu_megatron.core.tensor_parallel.layers import (
+    FluxColumnParallelLinear,
+    FluxRowParallelLinear
+)
+from dcu_megatron.core.transformer.multi_token_prediction import (
+    MultiTokenPredictionBlockSubmodules,
+    get_mtp_layer_offset,
+    get_mtp_layer_spec,
+    get_mtp_num_layers_to_build,
+)
+
+
+def get_gpt_mtp_block_spec(
+    config: TransformerConfig,
+    spec: Union[TransformerBlockSubmodules, ModuleSpec],
+    use_transformer_engine: bool,
+) -> MultiTokenPredictionBlockSubmodules:
+    """GPT Multi-Token Prediction (MTP) block spec."""
+    num_layers_to_build = get_mtp_num_layers_to_build(config)
+    if num_layers_to_build == 0:
+        return None
+
+    if isinstance(spec, TransformerBlockSubmodules):
+        # get the spec for the last layer of decoder block
+        transformer_layer_spec = spec.layer_specs[-1]
+    elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
+        transformer_layer_spec = spec
+    else:
+        raise ValueError(f"Invalid spec: {spec}")
+
+    mtp_layer_spec = get_mtp_layer_spec(
+        transformer_layer_spec=transformer_layer_spec, use_transformer_engine=use_transformer_engine
+    )
+    mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
+    mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers
+
+    offset = get_mtp_layer_offset(config)
+    # split the mtp layer specs to only include the layers that are built in this pipeline stage.
+    mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build]
+    if len(mtp_layer_specs) > 0:
+        assert (
+            len(mtp_layer_specs) == config.mtp_num_layers
+        ), "currently all of the mtp layers must stage in the same pipeline stage."
+        mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs)
+    else:
+        mtp_block_spec = None
+
+    return mtp_block_spec
+
+
 def get_gpt_layer_with_flux_spec(
     num_experts: Optional[int] = None,
 ...
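A sketch of how the new spec helper might be wired into a model provider. This wiring is not part of the diff; it assumes get_gpt_layer_with_flux_spec accepts num_experts (as its signature above suggests) and that the config carries mtp_num_layers and num_moe_experts.

    def build_specs(config, use_te: bool = True):
        """Illustrative only: combine the decoder layer spec with an optional MTP block spec."""
        layer_spec = get_gpt_layer_with_flux_spec(num_experts=getattr(config, "num_moe_experts", None))
        mtp_block_spec = (
            get_gpt_mtp_block_spec(config, layer_spec, use_transformer_engine=use_te)
            if getattr(config, "mtp_num_layers", 0)
            else None
        )
        return layer_spec, mtp_block_spec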
dcu_megatron/core/models/gpt/gpt_model.py  (view file @ 1f7b14ab)

The file is rewritten. The previous patch-style functions built around config.num_nextn_predict_layers and per-layer MultiTokenPredictor modules are removed: gpt_model_init, gpt_model_forward, shared_embedding_or_output_weight, slice_inputs, generate_nextn_position_ids and regenerate_position_ids (which sliced input_ids/labels/position_ids/attention_mask with tensor_slide, regenerated position ids after the first zero of concatenated sequences, and accumulated mtp_loss_scale / num_nextn_predict_layers * mtp_loss over self.mtp_layers), together with the imports they needed (logging, functools.wraps, parallel_state, TEColumnParallelLinear, tensor_slide, MultiTokenPredictor, the dcu_megatron TransformerConfig, and the upstream GPTModel that used to be patched). The new version defines a full GPTModel subclass of LanguageModule that takes an optional mtp_block_spec and delegates multi-token prediction to MultiTokenPredictionBlock:

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
from collections import OrderedDict
from typing import Dict, Literal, Optional

import torch
from torch import Tensor

from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig

from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
from dcu_megatron.core.transformer.multi_token_prediction import (
    MultiTokenPredictionBlock,
    tie_output_layer_state_dict,
    tie_word_embeddings_state_dict,
)


class GPTModel(LanguageModule):
    """GPT Transformer language model.

    Args:
        config (TransformerConfig): Transformer config
        transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
        vocab_size (int): Vocabulary size
        max_sequence_length (int): maximum size of sequence. This is used for positional embedding
        pre_process (bool, optional): Include embedding layer (used with pipeline parallelism).
            Defaults to True.
        post_process (bool, optional): Include an output layer (used with pipeline parallelism).
            Defaults to True.
        fp16_lm_cross_entropy (bool, optional): Defaults to False.
        parallel_output (bool, optional): Do not gather the outputs, keep them split across tensor
            parallel ranks. Defaults to True.
        share_embeddings_and_output_weights (bool, optional): When True, input embeddings and
            output logit weights are shared. Defaults to False.
        position_embedding_type (Literal[learned_absolute,rope], optional): Position embedding
            type. Defaults to 'learned_absolute'.
        rotary_percent (float, optional): Percent of rotary dimension to use for rotary position
            embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
        rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless
            position_embedding_type is 'rope'. Defaults to 10000.
        rope_scaling (bool, optional): Toggle RoPE scaling.
        rope_scaling_factor (float): RoPE scaling factor. Default 8.
        scatter_embedding_sequence_parallel (bool, optional): Whether embeddings should be
            scattered across sequence parallel region or not. Defaults to True.
        seq_len_interpolation_factor (Optional[float], optional): scale of linearly interpolating
            RoPE for longer sequences. The value must be a float larger than 1.0. Defaults to None.
    """

    def __init__(
        self,
        config: TransformerConfig,
        transformer_layer_spec: ModuleSpec,
        vocab_size: int,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        rope_scaling_factor: float = 8.0,
        scatter_embedding_sequence_parallel: bool = True,
        seq_len_interpolation_factor: Optional[float] = None,
        mtp_block_spec: Optional[ModuleSpec] = None,
    ) -> None:
        super().__init__(config=config)

        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.position_embedding_type = position_embedding_type

        # megatron core pipelining currently depends on model type
        # TODO: remove this dependency ?
        self.model_type = ModelType.encoder_or_decoder

        # These 4 attributes are needed for TensorRT-LLM export.
        self.max_position_embeddings = max_sequence_length
        self.rotary_percent = rotary_percent
        self.rotary_base = rotary_base
        self.rotary_scaling = rope_scaling

        self.mtp_block_spec = mtp_block_spec
        self.mtp_process = mtp_block_spec is not None

        if self.pre_process or self.mtp_process:
            self.embedding = LanguageModelEmbedding(
                config=self.config,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_sequence_length,
                position_embedding_type=position_embedding_type,
                scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
            )

        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                rope_scaling=rope_scaling,
                rope_scaling_factor=rope_scaling_factor,
                use_cpu_initialization=self.config.use_cpu_initialization,
            )

        # Cache for RoPE tensors which do not change between iterations.
        self.rotary_pos_emb_cache = {}

        # Transformer.
        self.decoder = TransformerBlock(
            config=self.config,
            spec=transformer_layer_spec,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        if self.mtp_process:
            self.mtp = MultiTokenPredictionBlock(config=self.config, spec=self.mtp_block_spec)

        # Output
        if self.post_process or self.mtp_process:
            if self.config.defer_embedding_wgrad_compute:
                # The embedding activation buffer preserves a reference to the input activations
                # of the final embedding projection layer GEMM. It will hold the activations for
                # all the micro-batches of a global batch for the last pipeline stage. Once we are
                # done with all the back props for all the microbatches for the last pipeline stage,
                # it will be in the pipeline flush stage. During this pipeline flush we use the
                # input activations stored in embedding activation buffer and gradient outputs
                # stored in gradient buffer to calculate the weight gradients for the embedding
                # final linear layer.
                self.embedding_activation_buffer = []
                self.grad_output_buffer = []
            else:
                self.embedding_activation_buffer = None
                self.grad_output_buffer = None

            if int(os.getenv("USE_FLUX_OVERLAP", "0")):
                parallel_linear_impl = FluxColumnParallelLinear
            else:
                parallel_linear_impl = tensor_parallel.ColumnParallelLinear

            self.output_layer = parallel_linear_impl(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
            )

        if self.pre_process or self.post_process:
            self.setup_embeddings_and_output_layer()

        if has_config_logger_enabled(self.config):
            log_config_to_disk(
                self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
            )

    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """Sets input tensor to the model.

        See megatron.model.transformer.set_input_tensor()

        Args:
            input_tensor (Tensor): Sets the input tensor for the model.
        """
        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
        self.decoder.set_input_tensor(input_tensor[0])

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        loss_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Forward function of the GPT Model This function passes the input tensors
        through the embedding layer, and then the decoder and finally into the post
        processing layer (optional).

        It either returns the Loss values if labels are given or the final hidden units

        Args:
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `parallel_output` arg in the constructor will be used.
        """
        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = None

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        rotary_pos_cos = None
        rotary_pos_sin = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if not self.training and self.config.flash_decode and inference_params:
                # Flash decoding uses precomputed cos and sin for RoPE
                rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
                    inference_params.max_sequence_length,
                    self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
                )
            else:
                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params
                )
                rotary_pos_emb = self.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                )
        if (
            (self.config.enable_cuda_graph or self.config.flash_decode)
            and rotary_pos_cos is not None
            and inference_params
        ):
            sequence_len_offset = torch.tensor(
                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
                dtype=torch.int32,
                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
            )
        else:
            sequence_len_offset = None

        # Run decoder.
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            **(extra_block_kwargs or {}),
        )

        # logits and loss
        output_weight = None
        if self.share_embeddings_and_output_weights:
            output_weight = self.shared_embedding_or_output_weight()

        if self.mtp_process:
            hidden_states = self.mtp(
                input_ids=input_ids,
                position_ids=position_ids,
                labels=labels,
                loss_mask=loss_mask,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                embedding=self.embedding,
                output_layer=self.output_layer,
                output_weight=output_weight,
                runtime_gather_output=runtime_gather_output,
                compute_language_model_loss=self.compute_language_model_loss,
                **(extra_block_kwargs or {}),
            )

        if not self.post_process:
            return hidden_states

        if (
            self.mtp_process is not None
            and getattr(self.decoder, "main_final_layernorm", None) is not None
        ):
            # move block main model final norms here
            hidden_states = self.decoder.main_final_layernorm(hidden_states)

        logits, _ = self.output_layer(
            hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
        )

        if has_config_logger_enabled(self.config):
            payload = OrderedDict(
                {
                    'input_ids': input_ids,
                    'position_ids': position_ids,
                    'attention_mask': attention_mask,
                    'decoder_input': decoder_input,
                    'logits': logits,
                }
            )
            log_config_to_disk(self.config, payload, prefix='input_and_logits')

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return loss

    def shared_embedding_or_output_weight(self) -> Tensor:
        """Gets the embedding weight or output logit weights when share input embedding and
        output weights set to True or when use Multi-Token Prediction (MTP) feature.

        Returns:
            Tensor: During pre processing or MTP process it returns the input embeddings weight.
            Otherwise, during post processing it returns the final output layers weight.
        """
        if self.pre_process or self.mtp_process:
            # Multi-Token Prediction (MTP) need both embedding layer and output layer.
            # So there will be both embedding layer and output layer in the mtp process stage.
            # In this case, if share_embeddings_and_output_weights is True, the shared weights
            # will be stored in embedding layer, and output layer will not have any weight.
            assert hasattr(
                self, 'embedding'
            ), "embedding is needed in this pipeline stage, but it is not initialized."
            return self.embedding.word_embeddings.weight
        elif self.post_process:
            return self.output_layer.weight
        return None

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
    ) -> ShardedStateDict:
        """Sharded state dict implementation for GPTModel backward-compatibility
        (removing extra state).

        Args:
            prefix (str): Module name prefix.
            sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
            metadata (Optional[Dict]): metadata controlling sharded state dict creation.

        Returns:
            ShardedStateDict: sharded state dict for the GPTModel
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)

        output_layer_extra_state_key = f'{prefix}output_layer._extra_state'

        # Old GPT checkpoints only stored the output layer weight key. So we remove the
        # _extra_state key but check that it doesn't contain any data anyway
        output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
        assert not (
            output_extra_state and output_extra_state.data
        ), f'Expected output layer extra state to be empty, got: {output_extra_state}'

        # Multi-Token Prediction (MTP) need both embedding layer and output layer in
        # mtp process stage.
        # If MTP is not placed in the pre processing stage, we need to maintain a copy of
        # embedding layer in the mtp process stage and tie it to the embedding in the pre
        # processing stage.
        # Also, if MTP is not placed in the post processing stage, we need to maintain a copy
        # of output layer in the mtp process stage and tie it to the output layer in the post
        # processing stage.
        if self.mtp_process and not self.pre_process:
            emb_weight_key = f'{prefix}embedding.word_embeddings.weight'
            emb_weight = self.embedding.word_embeddings.weight
            tie_word_embeddings_state_dict(sharded_state_dict, emb_weight, emb_weight_key)
        if self.mtp_process and not self.post_process:
            # We only need to tie the output layer weight if share_embeddings_and_output_weights
            # is False. Because if share_embeddings_and_output_weights is True, the shared weight
            # will be stored in embedding layer, and output layer will not have any weight.
            if not self.share_embeddings_and_output_weights:
                output_layer_weight_key = f'{prefix}output_layer.weight'
                output_layer_weight = self.output_layer.weight
                tie_output_layer_state_dict(
                    sharded_state_dict, output_layer_weight, output_layer_weight_key
                )

        return sharded_state_dict
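For orientation, a hedged construction example. The argument values are made up and the real call site is the model provider in pretrain_gpt.py (touched by this commit but not expanded here); layer_spec and mtp_block_spec would come from the spec helpers in gpt_layer_specs.py above.

    model = GPTModel(
        config=config,                       # a TransformerConfig with mtp_num_layers set
        transformer_layer_spec=layer_spec,
        vocab_size=32000,
        max_sequence_length=4096,
        share_embeddings_and_output_weights=True,
        position_embedding_type='rope',
        mtp_block_spec=mtp_block_spec,       # None disables the MTP head entirely
    )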
dcu_megatron/core/pipeline_parallel/schedules.py  (new file, 0 → 100644; view file @ 1f7b14ab)

import torch
from functools import wraps

from dcu_megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler


def forward_step_wrapper(fn):
    @wraps(fn)
    def wrapper(
        forward_step_func,
        data_iterator,
        model,
        num_microbatches,
        input_tensor,
        forward_data_store,
        config,
        **kwargs,
    ):
        output, num_tokens = fn(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            **kwargs)

        if not isinstance(input_tensor, list):
            # unwrap_output_tensor True
            output_tensor = output
        else:
            output_tensor = output[0]

        # Set the loss scale for Multi-Token Prediction (MTP) loss.
        if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
            # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
            loss_scale = (
                config.grad_scale_func(torch.ones(1, device=output_tensor.device))
                if config.grad_scale_func is not None
                else torch.ones(1, device=output_tensor.device)
            )
            # Set the loss scale
            if config.calculate_per_token_loss:
                MTPLossAutoScaler.set_loss_scale(loss_scale)
            else:
                MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
        return output, num_tokens
    return wrapper
\ No newline at end of file
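Registered with apply_wrapper=True in the adaptor, this wrapper decorates Megatron's own forward_step rather than replacing it, so the MTP loss scale tracks the main loss scaling (divided by num_microbatches unless per-token loss is enabled). Conceptually the registration amounts to the following hand-applied patch (illustrative only):

    from megatron.core.pipeline_parallel import schedules

    # equivalent of MegatronAdaptation.register(..., forward_step_wrapper, apply_wrapper=True)
    schedules.forward_step = forward_step_wrapper(schedules.forward_step)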
dcu_megatron/core/tensor_parallel/__init__.py  (view file @ 1f7b14ab)

 from .layers import (
     FluxColumnParallelLinear,
     FluxRowParallelLinear,
-    vocab_parallel_embedding_forward,
-    vocab_parallel_embedding_init_wrapper,
 )
\ No newline at end of file
dcu_megatron/core/tensor_parallel/layers.py  (view file @ 1f7b14ab)

 import os
-import copy
 import socket
 import warnings
-from functools import wraps
 from typing import Callable, List, Optional

-try:
-    import flux
-except ImportError:
-    raise ImportError("flux is NOT installed")
+if int(os.getenv("USE_FLUX_OVERLAP", "0")):
+    try:
+        import flux
+    except ImportError:
+        raise ImportError("flux is NOT installed")

 import torch
-import torch.nn.functional as F
-from torch.nn.parameter import Parameter

-from megatron.training import print_rank_0
 from megatron.core.model_parallel_config import ModelParallelConfig
 from megatron.core.parallel_state import (
     get_global_memory_buffer,
     get_tensor_model_parallel_group,
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from megatron.core.utils import (
-    is_torch_min_version,
-    prepare_input_tensors_for_wgrad_compute
-)
-from megatron.core.tensor_parallel.layers import (
-    _initialize_affine_weight_cpu,
-    _initialize_affine_weight_gpu,
-    VocabParallelEmbedding,
-)
+from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
 from megatron.core.tensor_parallel.mappings import (
-    _reduce,
     copy_to_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
-    reduce_scatter_to_sequence_parallel_region,
+    _reduce_scatter_along_first_dim,
+    _gather_along_first_dim,
 )
-from megatron.core.tensor_parallel.utils import VocabUtility
+from megatron.core.tensor_parallel.mappings import _reduce
 from megatron.core.tensor_parallel import (
     ColumnParallelLinear,
     RowParallelLinear,
 ...
@@ -50,9 +30,9 @@ from megatron.core.tensor_parallel.layers import (
     custom_fwd,
     custom_bwd,
     dist_all_gather_func,
-    linear_with_frozen_weight,
+    linear_with_grad_accumulation_and_async_allreduce
 )
+from dcu_megatron.core.utils import is_flux_min_version

 _grad_accum_fusion_available = True
 try:
 ...
@@ -61,74 +41,6 @@ except ImportError:
     _grad_accum_fusion_available = False


-def vocab_parallel_embedding_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self, *args, skip_weight_param_allocation: bool = False, **kwargs):
-        if (
-            skip_weight_param_allocation
-            and "config" in kwargs
-            and hasattr(kwargs["config"], "perform_initialization")
-        ):
-            config = copy.deepcopy(kwargs["config"])
-            config.perform_initialization = False
-            kwargs["config"] = config
-        fn(self, *args, **kwargs)
-        if skip_weight_param_allocation:
-            self.weight = None
-    return wrapper
-
-
-@torch.compile(mode='max-autotune-no-cudagraphs')
-def vocab_parallel_embedding_forward(self, input_, weight=None):
-    """Forward.
-
-    Args:
-        input_ (torch.Tensor): Input tensor.
-    """
-    if weight is None:
-        if self.weight is None:
-            raise RuntimeError(
-                "weight was not supplied to VocabParallelEmbedding forward pass "
-                "and skip_weight_param_allocation is True."
-            )
-        weight = self.weight
-
-    if self.tensor_model_parallel_size > 1:
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
-    else:
-        masked_input = input_
-    # Get the embeddings.
-    if self.deterministic_mode:
-        output_parallel = weight[masked_input]
-    else:
-        # F.embedding currently has a non-deterministic backward function
-        output_parallel = F.embedding(masked_input, weight)
-    # Mask the output embedding.
-    if self.tensor_model_parallel_size > 1:
-        output_parallel[input_mask, :] = 0.0
-
-    if self.reduce_scatter_embeddings:
-        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
-        output_parallel = output_parallel.transpose(0, 1).contiguous()
-        output = reduce_scatter_to_sequence_parallel_region(output_parallel)
-    else:
-        # Reduce across all the model parallel GPUs.
-        output = reduce_from_tensor_model_parallel_region(output_parallel)
-    return output
-
-
 def get_tensor_model_parallel_node_size(group=None):
     """Get the number of nodes."""
 ...
dcu_megatron/core/transformer/multi_token_prediction.py
0 → 100755
View file @
1f7b14ab
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from contextlib import nullcontext
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from torch import Tensor

from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel import (
    all_gather_last_dim_from_tensor_parallel_region,
    scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint, make_viewless_tensor

SUPPORTED_ATTN_MASK = [
    AttnMaskType.padding,
    AttnMaskType.causal,
    AttnMaskType.no_mask,
    AttnMaskType.padding_causal,
]

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelLinear,
        TEDelayedScaling,
        TENorm,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False
    from megatron.core.transformer.torch_norm import WrappedTorchNorm

try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_norm import WrappedTorchNorm

    warnings.warn('Apex is not installed. Falling back to Torch Norm')
    LNImpl = WrappedTorchNorm
def tie_word_embeddings_state_dict(
    sharded_state_dict: ShardedStateDict, word_emb_weight: Tensor, word_emb_weight_key: str
) -> None:
    """tie the embedding of the mtp processing stage in a given sharded state dict.

    Args:
        sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
        word_emb_weight (Tensor): weight of the word embedding.
        word_emb_weight_key (str): key of the word embedding in the sharded state dict.

    Returns: None, acts in-place
    """
    mtp_word_emb_replica_id = (
        1,  # copy of embedding in pre processing stage
        0,
        parallel_state.get_data_parallel_rank(with_context_parallel=True),
    )
    assert word_emb_weight_key in sharded_state_dict
    del sharded_state_dict[word_emb_weight_key]
    sharded_state_dict[word_emb_weight_key] = make_tp_sharded_tensor_for_checkpoint(
        tensor=word_emb_weight,
        key=word_emb_weight_key,
        replica_id=mtp_word_emb_replica_id,
        allow_shape_mismatch=True,
    )


def tie_output_layer_state_dict(
    sharded_state_dict: ShardedStateDict, output_layer_weight: Tensor, output_layer_weight_key: str
) -> None:
    """tie the output layer of the mtp processing stage in a given sharded state dict.

    Args:
        sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
        output_layer_weight (Tensor): weight of the output layer.
        output_layer_weight_key (str): key of the output layer in the sharded state dict.

    Returns: None, acts in-place
    """
    mtp_output_layer_replica_id = (
        1,  # copy of output layer in post processing stage
        0,
        parallel_state.get_data_parallel_rank(with_context_parallel=True),
    )
    assert output_layer_weight_key in sharded_state_dict
    del sharded_state_dict[output_layer_weight_key]
    sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
        tensor=output_layer_weight,
        key=output_layer_weight_key,
        replica_id=mtp_output_layer_replica_id,
        allow_shape_mismatch=True,
    )


def roll_tensor(tensor, shifts=-1, dims=-1):
    """Roll the tensor input along the given dimension(s).
    Inserted elements are set to be 0.0.
    """
    rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
    rolled_tensor.select(dims, shifts).fill_(0)
    return rolled_tensor, rolled_tensor.sum()
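
# Illustrative usage sketch (not part of the committed file): roll_tensor shifts
# everything one position to the left and zeroes the wrapped-around slot, so at
# MTP depth k the labels and loss_mask line up with the (i + k)-th future token
# and the padded tail never contributes to the loss.
def _roll_tensor_demo():
    labels = torch.tensor([[11, 12, 13, 14]])
    rolled, total = roll_tensor(labels, shifts=-1, dims=-1)
    assert rolled.tolist() == [[12, 13, 14, 0]]
    assert total.item() == 39  # sum of the rolled tensor, used as the token count for loss_mask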
class MTPLossLoggingHelper:
    """Helper class for logging MTP losses."""

    tracker = {}

    @staticmethod
    def save_loss_to_tracker(
        loss: torch.Tensor,
        layer_number: int,
        num_layers: int,
        reduce_group: torch.distributed.ProcessGroup = None,
        avg_group: torch.distributed.ProcessGroup = None,
    ):
        """Save the mtp loss for logging.

        Args:
            loss (torch.Tensor): The loss tensor.
            layer_number (int): Layer index of the loss.
            num_layers (int): The number of total layers.
            reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
            avg_group (torch.distributed.ProcessGroup): The group for averaging the loss.
        """
        # Skip mtp loss logging if layer_number is None.
        if layer_number is None:
            return

        tracker = MTPLossLoggingHelper.tracker
        if "values" not in tracker:
            tracker["values"] = torch.zeros(num_layers, device=loss.device)
        tracker["values"][layer_number] += loss.detach()
        tracker["reduce_group"] = reduce_group
        tracker["avg_group"] = avg_group

    @staticmethod
    def clean_loss_in_tracker():
        """Clear the mtp losses."""
        tracker = MTPLossLoggingHelper.tracker
        tracker["values"].zero_()
        tracker["reduce_group"] = None
        tracker["avg_group"] = None

    @staticmethod
    def reduce_loss_in_tracker():
        """Collect and reduce the mtp losses across ranks."""
        tracker = MTPLossLoggingHelper.tracker
        if "values" not in tracker:
            return
        values = tracker["values"]
        # Reduce mtp losses across ranks.
        if tracker.get('reduce_group') is not None:
            torch.distributed.all_reduce(values, group=tracker.get('reduce_group'))
        if tracker.get('avg_group') is not None:
            torch.distributed.all_reduce(
                values, group=tracker['avg_group'], op=torch.distributed.ReduceOp.AVG
            )

    @staticmethod
    def track_mtp_metrics(loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None):
        """Track the Multi-Token Prediction (MTP) metrics for logging."""
        MTPLossLoggingHelper.reduce_loss_in_tracker()
        tracker = MTPLossLoggingHelper.tracker
        if "values" not in tracker:
            return
        mtp_losses = tracker["values"] * loss_scale
        mtp_num_layers = mtp_losses.shape[0]
        for i in range(mtp_num_layers):
            name = f"mtp_{i + 1} loss"
            loss = mtp_losses[i]
            if total_loss_dict is not None:
                total_loss_dict[name] = loss
            if writer is not None:
                writer.add_scalar(name, loss, iteration)
            if wandb_writer is not None:
                wandb_writer.log({f"{name}": loss}, iteration)

        MTPLossLoggingHelper.clean_loss_in_tracker()
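
# Illustrative usage sketch (not part of the committed file): the tracker
# accumulates one scalar per MTP depth during the forward pass; at logging time
# the training loop reduces and emits per-depth values, then clears the tracker.
def _mtp_logging_demo(iteration=0, writer=None):
    MTPLossLoggingHelper.save_loss_to_tracker(torch.tensor(2.1), layer_number=0, num_layers=2)
    MTPLossLoggingHelper.save_loss_to_tracker(torch.tensor(2.4), layer_number=1, num_layers=2)
    totals = {}
    MTPLossLoggingHelper.track_mtp_metrics(
        loss_scale=1.0, iteration=iteration, writer=writer, total_loss_dict=totals
    )
    return totals  # e.g. {'mtp_1 loss': tensor(2.1000), 'mtp_2 loss': tensor(2.4000)}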
@dataclass
class MultiTokenPredictionLayerSubmodules:
    """
    Dataclass for specifying the submodules of a MultiTokenPrediction module.

    Args:
        hnorm (Union[ModuleSpec, type]): Specification or instance of the
            hidden states normalization to be applied.
        enorm (Union[ModuleSpec, type]): Specification or instance of the
            embedding normalization to be applied.
        eh_proj (Union[ModuleSpec, type]): Specification or instance of the
            linear projection to be applied.
        transformer_layer (Union[ModuleSpec, type]): Specification
            or instance of the transformer block to be applied.
    """

    enorm: Union[ModuleSpec, type] = None
    hnorm: Union[ModuleSpec, type] = None
    eh_proj: Union[ModuleSpec, type] = None
    transformer_layer: Union[ModuleSpec, type] = None
    layer_norm: Union[ModuleSpec, type] = None


def get_mtp_layer_spec(transformer_layer_spec: ModuleSpec, use_transformer_engine: bool) -> ModuleSpec:
    """Get the MTP layer spec.

    Returns:
        ModuleSpec: Module specification with TE modules
    """
    if use_transformer_engine:
        assert HAVE_TE, "transformer_engine should be installed if use_transformer_engine is True"
        layer_norm_impl = TENorm
        column_parallel_linear_impl = TEColumnParallelLinear
    else:
        layer_norm_impl = LNImpl
        column_parallel_linear_impl = ColumnParallelLinear

    mtp_layer_spec = ModuleSpec(
        module=MultiTokenPredictionLayer,
        submodules=MultiTokenPredictionLayerSubmodules(
            enorm=layer_norm_impl,
            hnorm=layer_norm_impl,
            eh_proj=column_parallel_linear_impl,
            transformer_layer=transformer_layer_spec,
            layer_norm=layer_norm_impl,
        ),
    )

    return mtp_layer_spec


def get_mtp_layer_offset(config: TransformerConfig) -> int:
    """Get the offset of the MTP layer."""
    # Currently, we only support putting all of the MTP layers on the last pipeline stage.
    return 0


def get_mtp_num_layers_to_build(config: TransformerConfig) -> int:
    """Get the number of MTP layers to build."""
    # Currently, we only support putting all of the MTP layers on the last pipeline stage.
    if mpu.is_pipeline_last_stage():
        return config.mtp_num_layers if config.mtp_num_layers else 0
    else:
        return 0
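
# Illustrative usage sketch (not part of the committed file): the main model's
# decoder layer spec (e.g. from get_gpt_layer_local_spec()) is reused as the
# transformer layer inside each MTP depth; norms and eh_proj follow the TE flag.
def _build_mtp_layer_demo(config, decoder_layer_spec):
    mtp_layer_spec = get_mtp_layer_spec(decoder_layer_spec, use_transformer_engine=HAVE_TE)
    return build_module(mtp_layer_spec, config=config, layer_number=1)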
class MTPLossAutoScaler(torch.autograd.Function):
    """An AutoScaler that triggers the backward pass and scales the grad for mtp loss."""

    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor):
        """Preserve the mtp loss by storing it in the context to avoid garbage collection.

        Args:
            output (torch.Tensor): The output tensor.
            mtp_loss (torch.Tensor): The mtp loss tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        ctx.save_for_backward(mtp_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Compute and scale the gradient for mtp loss.

        Args:
            grad_output (torch.Tensor): The gradient of the output.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled mtp loss
                gradient.
        """
        (mtp_loss,) = ctx.saved_tensors
        mtp_loss_backward_scale = MTPLossAutoScaler.main_loss_backward_scale
        scaled_mtp_loss_grad = torch.ones_like(mtp_loss) * mtp_loss_backward_scale
        return grad_output, scaled_mtp_loss_grad

    @staticmethod
    def set_loss_scale(scale: torch.Tensor):
        """Set the scale of the mtp loss.

        Args:
            scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in
                matches the scale of the main_loss.
        """
        MTPLossAutoScaler.main_loss_backward_scale = scale
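
# Illustrative usage sketch (not part of the committed file): attaching the MTP
# loss to the main-model hidden states lets a single backward() through the main
# loss also back-propagate the (scaled) MTP loss.
def _mtp_loss_autoscaler_demo():
    hidden = torch.randn(4, 2, 8, requires_grad=True)   # stand-in main-model hidden states
    mtp_loss = torch.tensor(1.0, requires_grad=True)    # stand-in per-depth MTP loss
    MTPLossAutoScaler.set_loss_scale(torch.tensor(0.5))
    out = MTPLossAutoScaler.apply(hidden, mtp_loss)      # forward pass returns `hidden` unchanged
    out.sum().backward()
    return mtp_loss.grad  # tensor(0.5000): the configured scale flows to the MTP loss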
class MultiTokenPredictionLayer(MegatronModule):
    """The implementation for Multi-Token Prediction (MTP) which extends
    the prediction scope to multiple future tokens at each position.

    This MTP implementation sequentially predicts additional tokens and keeps the complete
    causal chain at each prediction depth, by using D sequential modules to predict
    D additional tokens.

    The k-th MTP module consists of a shared embedding layer, a projection matrix,
    a Transformer block, and a shared output head.

    For the i-th input token at the (k - 1)-th prediction depth, we first combine
    the representation of the i-th token and the embedding of the (i + K)-th token with
    the linear projection. The combined representation serves as the input of the Transformer
    block at the k-th depth to produce the output representation.

    For more information, please refer to the DeepSeek-V3 Technical Report:
    https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MultiTokenPredictionLayerSubmodules,
        layer_number: int = 1,
    ):
        super().__init__(config=config)
        self.sequence_parallel = config.sequence_parallel
        self.submodules = submodules
        self.layer_number = layer_number

        self_attention_spec = self.submodules.transformer_layer.submodules.self_attention
        attn_mask_type = self_attention_spec.params.get('attn_mask_type', '')
        assert attn_mask_type in SUPPORTED_ATTN_MASK, (
            f"Multi-Token Prediction (MTP) is not yet supported with "
            + f"{attn_mask_type} attention mask type. "
            + f"The supported attention mask types are {SUPPORTED_ATTN_MASK}."
        )

        self.enorm = build_module(
            self.submodules.enorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        self.hnorm = build_module(
            self.submodules.hnorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        # For the linear projection at the (k - 1)-th MTP layer, the input is the concatenation
        # of the i-th token's hidden states and the (i + K)-th token's decoder input,
        # so the input's shape is [s, b, 2*h].
        # The output will be sent to the following transformer layer,
        # so the output's shape should be [s, b, h].
        self.eh_proj = build_module(
            self.submodules.eh_proj,
            self.config.hidden_size * 2,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=False,
            skip_bias_add=False,
            is_expert=False,
        )
        self.transformer_layer = build_module(self.submodules.transformer_layer, config=self.config)

        self.final_layernorm = build_module(
            self.submodules.layer_norm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

    def forward(
        self,
        decoder_input: Tensor,
        hidden_states: Tensor,
        attention_mask: Tensor,
        context: Tensor = None,
        context_mask: Tensor = None,
        rotary_pos_emb: Tensor = None,
        rotary_pos_cos: Tensor = None,
        rotary_pos_sin: Tensor = None,
        attention_bias: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        sequence_len_offset: Tensor = None,
    ):
        """
        Perform the forward pass through the MTP layer.

        Args:
            hidden_states (Tensor): hidden states tensor of shape [s, b, h] where s is the
                sequence length, b is the batch size, and h is the hidden size.
            decoder_input (Tensor): Input tensor of shape [s, b, h] where s is the
                sequence length, b is the batch size, and h is the hidden size.
                At the (k - 1)-th MTP module, the i-th element of decoder input is
                the embedding of the (i + K)-th token.
            attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
                self-attention.
            context (Tensor, optional): Context tensor for cross-attention.
            context_mask (Tensor, optional): Mask for cross-attention context.
            rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
            attention_bias (Tensor): Bias tensor for Q * K.T, in a shape broadcastable
                to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].
                Used as an alternative to apply attention mask for TE cuDNN attention.
            inference_params (InferenceParams, optional): Parameters for inference-time
                optimizations.
            packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence
                processing.

        Returns:
            Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
            [s, b, h], and optionally the updated context tensor if cross-attention is used.
        """
        assert context is None, "multi token prediction + cross attention is not yet supported."
        assert (
            packed_seq_params is None
        ), "multi token prediction + sequence packing is not yet supported."

        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

        if self.config.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        if self.config.fp8:
            import transformer_engine  # To keep out TE dependency when not training in fp8

            if self.config.fp8 == "e4m3":
                fp8_format = transformer_engine.common.recipe.Format.E4M3
            elif self.config.fp8 == "hybrid":
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
            else:
                raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")

            fp8_recipe = TEDelayedScaling(
                config=self.config,
                fp8_format=fp8_format,
                override_linear_precision=(False, False, not self.config.fp8_wgrad),
            )
            fp8_group = None
            if parallel_state.model_parallel_is_initialized():
                fp8_group = parallel_state.get_amax_reduction_group(
                    with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red
                )
            fp8_context = transformer_engine.pytorch.fp8_autocast(
                enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
            )
        else:
            fp8_context = nullcontext()

        with rng_context, fp8_context:
            decoder_input = self.enorm(decoder_input)
            decoder_input = make_viewless_tensor(
                inp=decoder_input, requires_grad=True, keep_graph=True
            )
            hidden_states = self.hnorm(hidden_states)
            hidden_states = make_viewless_tensor(
                inp=hidden_states, requires_grad=True, keep_graph=True
            )
            # At the (k - 1)-th MTP module, concatenate the i-th token's hidden_states
            # and the (i + K)-th token's embedding, and combine them with the linear projection.
            hidden_states = torch.cat((decoder_input, hidden_states), -1)
            hidden_states, _ = self.eh_proj(hidden_states)
            # For tensor parallel, all gather after linear_fc.
            hidden_states = all_gather_last_dim_from_tensor_parallel_region(hidden_states)
            # For sequence parallel, scatter after linear_fc and before the transformer layer.
            if self.sequence_parallel:
                hidden_states = scatter_to_sequence_parallel_region(hidden_states)
            hidden_states, _ = self.transformer_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                attention_bias=attention_bias,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
            )

        # Layer norm before shared head layer.
        hidden_states = self.final_layernorm(hidden_states)
        # TENorm produces a "viewed" tensor. This will result in schedule.py's
        # deallocate_output_tensor() throwing an error, so a viewless tensor is
        # created to prevent this.
        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

        return hidden_states

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """
        Generate a sharded state dictionary for the multi token prediction layer.

        Args:
            prefix (str, optional): Prefix to be added to all keys in the state dict.
            sharded_offsets (tuple, optional): Tuple of sharding offsets.
            metadata (Optional[dict], optional): Additional metadata for sharding.

        Returns:
            ShardedStateDict: A dictionary containing the sharded state of the multi
            token prediction layer.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        return sharded_state_dict
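
# Illustrative shape sketch (not part of the committed file): the core of one MTP
# depth, ignoring parallelism and norms -- concatenate the current hidden states
# with the shifted-token embeddings along the hidden dim, then project back to
# the model width before the transformer layer. torch.nn.Linear stands in for
# the ColumnParallelLinear eh_proj used above.
def _eh_proj_shape_demo(s=8, b=2, h=16):
    decoder_input = torch.randn(s, b, h)   # embedding of the (i + k)-th tokens
    hidden_states = torch.randn(s, b, h)   # representation of the i-th tokens
    combined = torch.cat((decoder_input, hidden_states), dim=-1)   # [s, b, 2h]
    eh_proj = torch.nn.Linear(2 * h, h, bias=False)
    return eh_proj(combined)                                       # [s, b, h]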
@dataclass
class MultiTokenPredictionBlockSubmodules:
    """
    Dataclass for specifying the submodules of a multi token prediction block.

    This class defines the structure for configuring the layers, allowing for
    flexible and customizable architecture designs.

    Args:
        layer_specs (List[ModuleSpec], optional): A list of module specifications for
            the layers within the multi token prediction block. Each specification typically
            defines a complete multi token prediction layer (e.g., shared embedding,
            projection matrix, transformer block, shared output head).
    """

    layer_specs: List[ModuleSpec] = None


def _get_mtp_block_submodules(
    config: TransformerConfig, spec: Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]
) -> MultiTokenPredictionBlockSubmodules:
    """
    Retrieve or construct MultiTokenPredictionBlockSubmodules based on the provided specification.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        spec (Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]): Specification for the
            multi token prediction block submodules.
            Can be either a MultiTokenPredictionBlockSubmodules instance or a ModuleSpec.

    Returns:
        MultiTokenPredictionBlockSubmodules: The submodules for the multi token prediction block.
    """
    # Transformer block submodules.
    if isinstance(spec, MultiTokenPredictionBlockSubmodules):
        return spec
    elif isinstance(spec, ModuleSpec):
        if issubclass(spec.module, MultiTokenPredictionBlock):
            return spec.submodules
        else:
            raise Exception(f"specialize for {spec.module.__name__}.")
    else:
        raise Exception(f"specialize for {type(spec).__name__}.")
class MultiTokenPredictionBlock(MegatronModule):
    """The implementation for Multi-Token Prediction (MTP) which extends
    the prediction scope to multiple future tokens at each position.

    This MTP implementation sequentially predicts additional tokens and keeps the complete
    causal chain at each prediction depth, by using D sequential modules to predict
    D additional tokens.

    The k-th MTP module consists of a shared embedding layer, a projection matrix,
    a Transformer block, and a shared output head.

    For the i-th input token at the (k - 1)-th prediction depth, we first combine
    the representation of the i-th token and the embedding of the (i + K)-th token with
    the linear projection. The combined representation serves as the input of the Transformer
    block at the k-th depth to produce the output representation.

    For more information, please refer to the DeepSeek-V3 Technical Report:
    https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
    """

    def __init__(self, config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec]):
        super().__init__(config=config)
        self.submodules = _get_mtp_block_submodules(config, spec)
        self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor
        self._build_layers()
        assert len(self.layers) > 0, "MultiTokenPredictionBlock must have at least one layer."

    def _build_layers(self):
        def build_layer(layer_spec, layer_number):
            return build_module(layer_spec, config=self.config, layer_number=layer_number)

        self.layers = torch.nn.ModuleList(
            [
                build_layer(layer_spec, i + 1)
                for i, layer_spec in enumerate(self.submodules.layer_specs)
            ]
        )

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        hidden_states: Tensor,
        attention_mask: Tensor,
        labels: Tensor = None,
        context: Tensor = None,
        context_mask: Tensor = None,
        rotary_pos_emb: Tensor = None,
        rotary_pos_cos: Tensor = None,
        rotary_pos_sin: Tensor = None,
        attention_bias: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        sequence_len_offset: Tensor = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        loss_mask: Optional[Tensor] = None,
        embedding=None,
        output_layer=None,
        output_weight: Optional[torch.Tensor] = None,
        compute_language_model_loss=None,
    ) -> Tensor:
        """
        Perform the forward pass through all of the MTP modules.

        Args:
            hidden_states (Tensor): Hidden states for input token with the shape [s, b, h]
                where s is the sequence length, b is the batch size, and h is the hidden size.
            attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
                self-attention.

        Returns:
            (Tensor): The mtp loss tensor of shape [b, s].
        """
        assert (
            labels is not None
        ), "labels should not be None for calculating multi token prediction loss."
        if loss_mask is None:
            # If loss_mask is not provided, use all ones as the loss_mask.
            loss_mask = torch.ones_like(labels)

        hidden_states_main_model = hidden_states
        for layer_number in range(len(self.layers)):
            # Calc logits for the current Multi-Token Prediction (MTP) layers.
            input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1)
            # embedding
            decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
            # norm, linear projection and transformer
            hidden_states = self.layers[layer_number](
                decoder_input=decoder_input,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                **(extra_block_kwargs or {}),
            )
            # output
            mtp_logits, _ = output_layer(
                hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
            )
            # Calc loss for the current Multi-Token Prediction (MTP) layers.
            labels, _ = roll_tensor(labels, shifts=-1, dims=-1)
            loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1)
            mtp_loss = compute_language_model_loss(labels, mtp_logits)
            mtp_loss = loss_mask * mtp_loss
            if self.training:
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(mtp_loss) / num_tokens,
                    layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_tensor_and_context_parallel_group(),
                )
            mtp_loss_scale = self.mtp_loss_scaling_factor / self.config.mtp_num_layers
            if self.config.calculate_per_token_loss:
                hidden_states_main_model = MTPLossAutoScaler.apply(
                    hidden_states_main_model, mtp_loss_scale * mtp_loss
                )
            else:
                hidden_states_main_model = MTPLossAutoScaler.apply(
                    hidden_states_main_model, mtp_loss_scale * mtp_loss / num_tokens
                )

        return hidden_states_main_model

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """
        Generate a sharded state dictionary for the multi token prediction module.

        Args:
            prefix (str, optional): Prefix to be added to all keys in the state dict.
            sharded_offsets (tuple, optional): Tuple of sharding offsets.
            metadata (Optional[dict], optional): Additional metadata for sharding.

        Returns:
            ShardedStateDict: A dictionary containing the sharded state of the multi
            token prediction module.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        layer_prefix = f'{prefix}layers.'
        for layer in self.layers:
            offset = get_mtp_layer_offset(self.config)
            sharded_prefix = f'{layer_prefix}{layer.layer_number - 1}.'
            state_dict_prefix = f'{layer_prefix}{layer.layer_number - 1 - offset}.'
            sharded_pp_offset = []
            layer_sharded_state_dict = layer.sharded_state_dict(
                state_dict_prefix, sharded_pp_offset, metadata
            )
            replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
            sharded_state_dict.update(layer_sharded_state_dict)
        return sharded_state_dict
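
# Illustrative integration sketch (not part of the committed file): on the last
# pipeline stage the GPT model is assumed to hold this block as `model.mtp` and
# to call it with callables borrowed from the main model; the block returns the
# main-model hidden states with the scaled MTP loss attached for backward.
def _mtp_block_call_demo(model, tokens, position_ids, attention_mask, labels, loss_mask, hidden_states):
    return model.mtp(                                     # assumed attribute name
        input_ids=tokens,
        position_ids=position_ids,
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        labels=labels,
        loss_mask=loss_mask,
        embedding=model.embedding,                        # shared word embedding
        output_layer=model.output_layer,                  # shared output head
        output_weight=model.shared_embedding_or_output_weight(),
        compute_language_model_loss=model.compute_language_model_loss,
    )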
dcu_megatron/core/transformer/transformer_block.py
View file @ 1f7b14ab

...
@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
        # mtp requires separate layernorms for the main model and the mtp modules, so move the final norm out of the block
        config = args[0] if len(args) > 1 else kwargs['config']
-       if getattr(config, "num_nextn_predict_layers", 0) > 0:
+       if getattr(config, "mtp_num_layers", 0) > 0:
            self.main_final_layernorm = self.final_layernorm
            self.final_layernorm = None
...
dcu_megatron/core/transformer/transformer_config.py
View file @ 1f7b14ab

+from typing import Optional
+from functools import wraps
from dataclasses import dataclass

+from megatron.training import get_args
from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig


+def transformer_config_post_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self):
+        fn(self)
+        args = get_args()
+
+        """Number of Multi-Token Prediction (MTP) Layers."""
+        self.mtp_num_layers = args.mtp_num_layers
+        """Weighting factor of Multi-Token Prediction (MTP) loss."""
+        self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor
+
+        ##################
+        # flux
+        ##################
+        self.flux_transpose_weight = args.flux_transpose_weight
+
+    return wrapper


@dataclass
class ExtraTransformerConfig:
    ##################
    # multi-token prediction
    ##################
-   num_nextn_predict_layers: int = 0
-   """The number of multi-token prediction layers"""
-
-   mtp_loss_scale: float = 0.3
-   """Multi-token prediction loss scale"""
-
-   recompute_mtp_norm: bool = False
-   """Whether to recompute mtp normalization"""
-
-   recompute_mtp_layer: bool = False
-   """Whether to recompute mtp layer"""
-
-   share_mtp_embedding_and_output_weight: bool = False
-   """share embedding and output weight with mtp layer."""
+   mtp_num_layers: Optional[int] = None
+   """Number of Multi-Token Prediction (MTP) Layers."""
+
+   mtp_loss_scaling_factor: Optional[float] = None
+   """Weighting factor of Multi-Token Prediction (MTP) loss."""

    ##################
    # flux
...
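
# Illustrative sketch (not part of the diff): what the wrapped __post_init__
# effectively does -- copy the MTP knobs from the parsed training args onto the
# transformer config after the stock post-init has run.
def _copy_mtp_args_to_config(config, args):
    config.mtp_num_layers = args.mtp_num_layers                    # number of MTP depths (None = MTP off)
    config.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor  # weight of the averaged MTP loss
    return config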
dcu_megatron/training/arguments.py
View file @ 1f7b14ab

...
@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser):

def _add_mtp_args(parser):
    group = parser.add_argument_group(title='multi token prediction')
-   group.add_argument('--num-nextn-predict-layers', type=int, default=0,
-                      help='Multi-Token prediction layer num')
-   group.add_argument('--mtp-loss-scale', type=float, default=0.3,
-                      help='Multi-Token prediction loss scale')
-   group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
-                      help='Multi-Token prediction recompute norm')
-   group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
-                      help='Multi-Token prediction recompute layer')
-   group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
-                      help='Main model shares embedding and output weight with the mtp layer.')
+   group.add_argument('--mtp-num-layers', type=int, default=None,
+                      help='Number of Multi-Token Prediction (MTP) Layers. '
+                      'MTP extends the prediction scope to multiple future tokens at each position. '
+                      'This MTP implementation sequentially predicts additional tokens '
+                      'by using D sequential modules to predict D additional tokens.')
+   group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
+                      help='Scaling factor of Multi-Token Prediction (MTP) loss. '
+                      'We compute the average of the MTP losses across all depths, '
+                      'and multiply it by the scaling factor to obtain the overall MTP loss, '
+                      'which serves as an additional training objective.')

    return parser
...
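
# Illustrative sketch (not part of the diff): the new flags replace
# --num-nextn-predict-layers / --mtp-loss-scale; parsing them through
# _add_mtp_args shows the resulting attribute names (assumes dcu_megatron is
# importable on PYTHONPATH).
def _mtp_args_demo():
    import argparse
    from dcu_megatron.training.arguments import _add_mtp_args
    parser = _add_mtp_args(argparse.ArgumentParser())
    args = parser.parse_args(['--mtp-num-layers', '1', '--mtp-loss-scaling-factor', '0.3'])
    return args.mtp_num_layers, args.mtp_loss_scaling_factor  # -> (1, 0.3)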
dcu_megatron/training/utils.py
View file @ 1f7b14ab

...
@@ -9,103 +9,97 @@ def get_batch_on_this_tp_rank(data_iterator):

    args = get_args()

    def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(),
                                        group=mpu.get_tensor_model_parallel_group())

    if mpu.get_tensor_model_parallel_rank() == 0:

        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None

        batch = {
            'tokens': data["tokens"].cuda(non_blocking=True),
            'labels': data["labels"].cuda(non_blocking=True),
            'loss_mask': data["loss_mask"].cuda(non_blocking=True),
            'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
            'position_ids': data["position_ids"].cuda(non_blocking=True)
        }

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])

        elif mpu.is_pipeline_last_stage():
-           if args.num_nextn_predict_layers:
+           # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
+           # Currently the Multi-Token Prediction (MTP) layers are fixed on the last stage, so we need
+           # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
+           if args.mtp_num_layers is not None:
                _broadcast(batch['tokens'])
+               _broadcast(batch['position_ids'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
-           if args.reset_position_ids or args.num_nextn_predict_layers:
-               _broadcast(batch['position_ids'])

    else:

-       tokens = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                            dtype=torch.int64, device=torch.cuda.current_device())
-       labels = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                            dtype=torch.int64, device=torch.cuda.current_device())
-       loss_mask = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                               dtype=torch.float32, device=torch.cuda.current_device())
-       if args.create_attention_mask_in_dataloader:
-           attention_mask = torch.empty(
-               (args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers,
-                args.seq_length + args.num_nextn_predict_layers),
-               dtype=torch.bool, device=torch.cuda.current_device()
-           )
-       else:
-           attention_mask = None
-       position_ids = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
-                                  dtype=torch.int64, device=torch.cuda.current_device())
+       tokens = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
+                            device=torch.cuda.current_device())
+       labels = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
+                            device=torch.cuda.current_device())
+       loss_mask = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.float32,
+                               device=torch.cuda.current_device())
+       if args.create_attention_mask_in_dataloader:
+           attention_mask = torch.empty(
+               (args.micro_batch_size, 1, args.seq_length, args.seq_length),
+               dtype=torch.bool, device=torch.cuda.current_device()
+           )
+       else:
+           attention_mask = None
+       position_ids = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
+                                  device=torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_first_stage():
            labels = None
            loss_mask = None
            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)

        elif mpu.is_pipeline_last_stage():
-           if args.num_nextn_predict_layers:
+           # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
+           # Currently the Multi-Token Prediction (MTP) layers are fixed on the last stage, so we need
+           # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
+           if args.mtp_num_layers is not None:
                _broadcast(tokens)
+               _broadcast(position_ids)
            else:
                tokens = None
+               position_ids = None
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
-           if args.reset_position_ids or args.num_nextn_predict_layers:
-               _broadcast(position_ids)
-           else:
-               position_ids = None

        batch = {
            'tokens': tokens,
            'labels': labels,
            'loss_mask': loss_mask,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        }

    return batch
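
# Illustrative sketch (not part of the diff): on a non-source TP rank of the
# last pipeline stage, labels/loss_mask/attention_mask are always received;
# tokens and position_ids are added only when the MTP block has to recompute
# embeddings there (i.e. --mtp-num-layers is set).
def _last_stage_batch_keys(args):
    keys = ['labels', 'loss_mask', 'attention_mask']
    if args.mtp_num_layers is not None:
        keys += ['tokens', 'position_ids']
    return keys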
pretrain_gpt.py
View file @ 1f7b14ab

...
@@ -39,9 +39,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)
-from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
-from dcu_megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
-from dcu_megatron.core.utils import tensor_slide
+from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
from dcu_megatron import megatron_adaptor
...
@@ -133,13 +131,12 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
        raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")

    # Define the mtp layer spec
-   if isinstance(transformer_layer_spec, TransformerBlockSubmodules):
-       mtp_transformer_layer_spec = transformer_layer_spec.layer_specs[-1]
-   else:
-       mtp_transformer_layer_spec = transformer_layer_spec
+   mtp_block_spec = None
+   if args.mtp_num_layers is not None:
+       from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
+       mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)

    with build_model_context(**build_model_context_args):
-       config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
...
@@ -153,7 +150,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base,
-           rope_scaling=args.use_rope_scaling
+           rope_scaling=args.use_rope_scaling,
+           mtp_block_spec=mtp_block_spec,
        )
        # model = torch.compile(model,mode='max-autotune-no-cudagraphs')
        print_rank_0(model)
...
@@ -197,8 +195,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    args = get_args()

    losses = output_tensor.float()
-   if getattr(args, "num_nextn_predict_layers", 0) > 0:
-       loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
    loss_mask = loss_mask.view(-1).float()
    total_tokens = loss_mask.sum()
    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
...
@@ -267,8 +263,12 @@ def forward_step(data_iterator, model: GPTModel):
    timers('batch-generator').stop()

    with stimer:
-       output_tensor = model(tokens, position_ids, attention_mask,
-                             labels=labels)
+       if args.use_legacy_models:
+           output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
+       else:
+           output_tensor = model(tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask)

    return output_tensor, partial(loss_func, loss_mask)
...
@@ -289,7 +289,7 @@ def core_gpt_dataset_config_from_args(args):
    return GPTDatasetConfig(
        random_seed=args.seed,
-       sequence_length=args.seq_length + getattr(args, "num_nextn_predict_layers", 0),
+       sequence_length=args.seq_length,
        blend=blend,
        blend_per_split=blend_per_split,
        split=args.split,
...
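
# Illustrative sketch (not part of the diff): the rewritten model_provider path
# builds the MTP block spec only when --mtp-num-layers is set and hands it to
# the (dcu_megatron-patched) GPTModel; the other constructor arguments from the
# original call are omitted here.
def _build_model_with_mtp(config, transformer_layer_spec, args, use_te):
    mtp_block_spec = None
    if args.mtp_num_layers is not None:
        from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
        mtp_block_spec = get_gpt_mtp_block_spec(
            config, transformer_layer_spec, use_transformer_engine=use_te
        )
    return GPTModel(
        config=config,
        transformer_layer_spec=transformer_layer_spec,
        mtp_block_spec=mtp_block_spec,
        # ... vocab size, sequence length, rotary settings, etc. as in model_provider
    )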