evt_fugx1 / dcu_megatron / Commits / 1f7b14ab

Commit 1f7b14ab, authored Apr 26, 2025 by sdwldchl
rewrite mtp
parent 89d29a02
Showing 15 changed files with 1429 additions and 616 deletions (+1429 / -616):

    dcu_megatron/adaptor/megatron_adaptor.py                             +30   -32
    dcu_megatron/adaptor/patch_utils.py                                  +21   -3
    dcu_megatron/core/distributed/finalize_model_grads.py               +6    -1
    dcu_megatron/core/models/common/language_module/language_module.py  +6    -5
    dcu_megatron/core/models/gpt/gpt_layer_specs.py                      +52   -3
    dcu_megatron/core/models/gpt/gpt_model.py                            +383  -350
    dcu_megatron/core/pipeline_parallel/schedules.py                     +52   -0
    dcu_megatron/core/tensor_parallel/__init__.py                        +0    -2
    dcu_megatron/core/tensor_parallel/layers.py                          +8    -96
    dcu_megatron/core/transformer/multi_token_prediction.py              +737  -0
    dcu_megatron/core/transformer/transformer_block.py                   +1    -1
    dcu_megatron/core/transformer/transformer_config.py                  +27   -13
    dcu_megatron/training/arguments.py                                   +10   -8
    dcu_megatron/training/utils.py                                       +82   -88
    pretrain_gpt.py                                                      +14   -14
dcu_megatron/adaptor/megatron_adaptor.py (+30 / -32)

@@ -5,6 +5,8 @@ import types
import argparse

import torch

from .adaptor_arguments import get_adaptor_args


class MegatronAdaptation:
    """
...
@@ -21,6 +23,15 @@ class MegatronAdaptation:
        for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
            adaptation.execute()
        MegatronAdaptation.apply()
        # from .patch_utils import MegatronPatchesManager
        # args = get_adaptor_args()
        # for feature in FEATURES_LIST:
        #     if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
        #         feature.register_patches(MegatronPatchesManager, args)
        # MindSpeedPatchesManager.apply_patches()
        # MegatronAdaptation.post_execute()

    @classmethod
...
@@ -87,38 +98,20 @@ class CoreAdaptation(MegatronAdaptationABC):
        self.patch_miscellaneous()

    def patch_core_distributed(self):
-        # Mtp share embedding
+        # mtp share embedding
        from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
        MegatronAdaptation.register(
            'megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
            _allreduce_word_embedding_grads)

    def patch_core_models(self):
        from ..core.models.common.embeddings.language_model_embedding import (
            language_model_embedding_forward,
            language_model_embedding_init_func
        )
        from ..core.models.gpt.gpt_model import (
            gpt_model_forward,
            gpt_model_init,
            shared_embedding_or_output_weight,
        )
        from ..core.models.common.language_module.language_module import (
            setup_embeddings_and_output_layer,
            tie_embeddings_and_output_weights_state_dict,
        )
        from ..core.models.gpt.gpt_model import GPTModel
        from ..training.utils import get_batch_on_this_tp_rank

        # Embedding
        MegatronAdaptation.register(
            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
            language_model_embedding_init_func)
        MegatronAdaptation.register(
            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
            language_model_embedding_forward)
        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank',
                                    get_batch_on_this_tp_rank)

        # GPT Model
        # LanguageModule
        MegatronAdaptation.register(
            'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
            setup_embeddings_and_output_layer)
@@ -126,17 +119,16 @@ class CoreAdaptation(MegatronAdaptationABC):
            'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
            tie_embeddings_and_output_weights_state_dict)
        MegatronAdaptation.register(
            'megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
            shared_embedding_or_output_weight)
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init)
        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank',
                                    get_batch_on_this_tp_rank)

        # GPT Model
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel)

    def patch_core_transformers(self):
        from ..core import transformer_block_init_wrapper
        from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

-        # Transformer block
+        # Transformer block. If mtp_num_layers > 0, move final_layernorm outside
        MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
                                    transformer_block_init_wrapper)
...
@@ -174,13 +166,10 @@ class CoreAdaptation(MegatronAdaptationABC):
    def patch_tensor_parallel(self):
        from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper

        # VocabParallelEmbedding
        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
                                    vocab_parallel_embedding_forward)
        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
                                    vocab_parallel_embedding_init_wrapper,
                                    torch.compile(mode='max-autotune-no-cudagraphs'),
                                    apply_wrapper=True)

        # VocabParallelCrossEntropy
...
@@ -211,6 +200,14 @@ class CoreAdaptation(MegatronAdaptationABC):
        MegatronAdaptation.register(
            "megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
            get_gpt_layer_with_flux_spec)

    def patch_pipeline_parallel(self):
        from ..core.pipeline_parallel.schedules import forward_step_wrapper

        # pipeline_parallel.schedules.forward_step
        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step',
                                    forward_step_wrapper,
                                    apply_wrapper=True)

    def patch_training(self):
        from ..training.tokenizer import build_tokenizer
        from ..training.initialize import _initialize_distributed
...
@@ -255,6 +252,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
                                    parallel_mlp_init_wrapper,
                                    apply_wrapper=True)

        # ParallelAttention
        MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
                                    parallel_attention_init_wrapper,
                                    apply_wrapper=True)
...
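The adaptor above registers replacement functions and classes against dotted import paths and applies them later in one pass. As a rough, hypothetical sketch of that register/apply pattern (not the project's MegatronAdaptation implementation), a dotted target can be resolved with importlib and swapped with setattr; apply_wrapper corresponds to installing the replacement as a wrapper around the original callable rather than as a drop-in substitute:

import importlib


def _resolve_owner_and_attr(dotted_name):
    # Import the longest importable module prefix, then walk the remaining attributes.
    parts = dotted_name.split('.')
    for i in range(len(parts) - 1, 0, -1):
        try:
            owner = importlib.import_module('.'.join(parts[:i]))
            break
        except ImportError:
            continue
    else:
        raise ImportError(f"cannot import any module prefix of {dotted_name}")
    for attr in parts[i:-1]:
        owner = getattr(owner, attr)
    return owner, parts[-1]


def apply_patch(dotted_name, replacement, apply_wrapper=False):
    # Replace the attribute in place; with apply_wrapper=True the replacement is a
    # decorator that receives the original callable and returns the wrapped version.
    owner, attr = _resolve_owner_and_attr(dotted_name)
    original = getattr(owner, attr, None)
    setattr(owner, attr, replacement(original) if apply_wrapper else replacement)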
dcu_megatron/adaptor/patch_utils.py (+21 / -3)

@@ -148,11 +148,29 @@ class MegatronPatchesManager:
    patches_info = {}

    @staticmethod
-    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
+    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False,
+                       apply_wrapper=False, remove_origin_wrappers=False):
        if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
-            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
-                orig_func_or_cls_name, new_func_or_cls, create_dummy)
+            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
+                orig_func_or_cls_name, new_func_or_cls, create_dummy,
+                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)
        else:
-            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
-                new_func_or_cls, force_patch)
+            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
+                new_func_or_cls, force_patch,
+                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)

    @staticmethod
    def apply_patches():
...
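register_patch now threads apply_wrapper and remove_origin_wrappers through to the Patch object. For context, a hedged example of what an apply_wrapper=True patch target usually looks like: the registered object is a decorator that receives the original function and returns a wrapped version rather than replacing it outright (my_forward_wrapper below is a hypothetical name, not part of this diff):

from functools import wraps


def my_forward_wrapper(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Call the original implementation, then post-process its result.
        output = fn(*args, **kwargs)
        return output
    return wrapper


# Registering with apply_wrapper=True is expected to install
# my_forward_wrapper(original_forward) in place of the original:
# MegatronPatchesManager.register_patch('megatron.core.pipeline_parallel.schedules.forward_step',
#                                       my_forward_wrapper, apply_wrapper=True)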
dcu_megatron/core/distributed/finalize_model_grads.py (+6 / -1)

@@ -28,7 +28,12 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
        model_module = model[0]

    model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
-    if model_module.share_embeddings_and_output_weights or getattr(config, 'num_nextn_predict_layers', 0):
+    # If share_embeddings_and_output_weights is True, we need to maintain duplicated
+    # embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
+    # we also need to maintain duplicated embedding weights in mtp process stage.
+    # So we need to allreduce grads of embedding in the embedding group in these cases.
+    if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
        weight = model_module.shared_embedding_or_output_weight()
        grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
        orig_grad = getattr(weight, grad_attr)
...
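The condition change above widens the existing tied-embedding all-reduce to the MTP case. A minimal sketch of the underlying operation, assuming an embedding process group is available (names here are illustrative, not the patched function itself): the first and last pipeline stages each hold a copy of the word-embedding weight, and summing their gradients keeps both copies identical after the optimizer step.

import torch


def allreduce_shared_embedding_grad(weight, embedding_group):
    # Prefer the fused main_grad buffer when the DDP wrapper provides one.
    grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
    grad = getattr(weight, grad_attr)
    torch.distributed.all_reduce(grad, group=embedding_group)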
dcu_megatron/core/models/common/language_module/language_module.py (+6 / -5)

@@ -4,6 +4,7 @@ import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint
...
@@ -27,7 +28,7 @@ def setup_embeddings_and_output_layer(self) -> None:
    # So we need to copy embedding weights from pre processing stage as initial parameters
    # in these cases.
-    if not self.share_embeddings_and_output_weights and not getattr(self.config, 'num_nextn_predict_layers', 0):
+    if not self.share_embeddings_and_output_weights and not getattr(self.config, 'mtp_num_layers', 0):
        return
...
@@ -41,10 +42,10 @@ def setup_embeddings_and_output_layer(self) -> None:
    if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
        self.shared_embedding_or_output_weight().shared_embedding = True

-    if self.post_process and not self.pre_process:
+    if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process:
        assert not parallel_state.is_pipeline_first_stage()
-        # set word_embeddings weights to 0 here, then copy first
-        # stage's weights using all_reduce below.
+        # set weights of the duplicated embedding to 0 here,
+        # then copy weights from pre processing stage using all_reduce below.
        weight = self.shared_embedding_or_output_weight()
        weight.data.fill_(0)
        weight.shared = True
...
@@ -114,7 +115,7 @@ def tie_embeddings_and_output_weights_state_dict(
    # layer in mtp process stage. In this case, if share_embeddings_and_output_weights is True,
    # the shared weights will be stored in embedding layer, and output layer will not have
    # any weight.
-    if self.post_process and getattr(self, 'num_nextn_predict_layers', False):
+    if getattr(self, 'mtp_process', False):
        # No output layer
        assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
        return
...
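The zero-fill above relies on the all-reduce that follows it elsewhere in setup_embeddings_and_output_layer: summing a zeroed duplicate with the pre-processing stage's weight reproduces that stage's values exactly, so the copy needs no explicit broadcast. A toy, single-process illustration of that identity:

import torch

src_stage_weight = torch.randn(4, 8)           # weight held by the pre-processing stage
dup_stage_weight = torch.zeros(4, 8)           # duplicated copy on the post/MTP stage
summed = src_stage_weight + dup_stage_weight   # what all_reduce(SUM) computes on each rank
assert torch.equal(summed, src_stage_weight)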
dcu_megatron/core/models/gpt/gpt_layer_specs.py (+52 / -3)

import warnings
-from typing import Optional
+from typing import Optional, Union

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
...
@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
    MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
    TransformerLayer,
    TransformerLayerSubmodules,
)
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
from megatron.core.utils import is_te_min_version

try:
...
@@ -36,6 +36,55 @@ try:
except ImportError:
    warnings.warn('Apex is not installed.')

from dcu_megatron.core.tensor_parallel.layers import (
    FluxColumnParallelLinear,
    FluxRowParallelLinear
)
from dcu_megatron.core.transformer.multi_token_prediction import (
    MultiTokenPredictionBlockSubmodules,
    get_mtp_layer_offset,
    get_mtp_layer_spec,
    get_mtp_num_layers_to_build,
)


def get_gpt_mtp_block_spec(
    config: TransformerConfig,
    spec: Union[TransformerBlockSubmodules, ModuleSpec],
    use_transformer_engine: bool,
) -> MultiTokenPredictionBlockSubmodules:
    """GPT Multi-Token Prediction (MTP) block spec."""
    num_layers_to_build = get_mtp_num_layers_to_build(config)
    if num_layers_to_build == 0:
        return None

    if isinstance(spec, TransformerBlockSubmodules):
        # get the spec for the last layer of decoder block
        transformer_layer_spec = spec.layer_specs[-1]
    elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
        transformer_layer_spec = spec
    else:
        raise ValueError(f"Invalid spec: {spec}")

    mtp_layer_spec = get_mtp_layer_spec(
        transformer_layer_spec=transformer_layer_spec, use_transformer_engine=use_transformer_engine
    )
    mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
    mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers

    offset = get_mtp_layer_offset(config)
    # split the mtp layer specs to only include the layers that are built in this pipeline stage.
    mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build]
    if len(mtp_layer_specs) > 0:
        assert (
            len(mtp_layer_specs) == config.mtp_num_layers
        ), f"currently all of the mtp layers must stage in the same pipeline stage."
        mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs)
    else:
        mtp_block_spec = None

    return mtp_block_spec


def get_gpt_layer_with_flux_spec(
    num_experts: Optional[int] = None,
...
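A hedged usage sketch for the new helper: the transformer layer spec is assumed to be built by the existing layer-spec helpers elsewhere, and the returned block spec (or None) is passed straight to the rewritten GPTModel via its new mtp_block_spec argument (see gpt_model.py below); model_provider here is a hypothetical wrapper, not part of this diff.

from dcu_megatron.core.models.gpt.gpt_model import GPTModel


def model_provider(config, transformer_layer_spec, vocab_size, seq_len):
    # get_gpt_mtp_block_spec returns None when config.mtp_num_layers is 0 or this
    # rank is not on the last pipeline stage, which disables the MTP head entirely.
    mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec,
                                            use_transformer_engine=True)
    return GPTModel(
        config=config,
        transformer_layer_spec=transformer_layer_spec,
        vocab_size=vocab_size,
        max_sequence_length=seq_len,
        mtp_block_spec=mtp_block_spec,
    )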
dcu_megatron/core/models/gpt/gpt_model.py (+383 / -350)

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
import logging
from typing import Literal, Optional
from functools import wraps
from collections import OrderedDict
from typing import Dict, Literal, Optional

import torch
from torch import Tensor

from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.transformer.transformer_config import TransformerConfig

from dcu_megatron.core.utils import tensor_slide
from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
from dcu_megatron.core.transformer.transformer_config import TransformerConfig
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
from dcu_megatron.core.transformer.multi_token_prediction import (
    MultiTokenPredictionBlock,
    tie_output_layer_state_dict,
    tie_word_embeddings_state_dict,
)


-def gpt_model_init(
-    self,
-    config: TransformerConfig,
-    transformer_layer_spec: ModuleSpec,
-    vocab_size: int,
-    max_sequence_length: int,
-    pre_process: bool = True,
-    post_process: bool = True,
-    fp16_lm_cross_entropy: bool = False,
-    parallel_output: bool = True,
-    share_embeddings_and_output_weights: bool = False,
-    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
-    rotary_percent: float = 1.0,
-    rotary_base: int = 10000,
-    rope_scaling: bool = False,
-    rope_scaling_factor: float = 8.0,
-    scatter_embedding_sequence_parallel: bool = True,
-    seq_len_interpolation_factor: Optional[float] = None,
-) -> None:
-    super(GPTModel, self).__init__(config=config)
-
-    if has_config_logger_enabled(config):
-        log_config_to_disk(config, locals(), prefix=type(self).__name__)
-
-    self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
-    self.vocab_size = vocab_size
-    self.max_sequence_length = max_sequence_length
-    self.pre_process = pre_process
-    self.post_process = post_process
-    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
-    self.parallel_output = parallel_output
-    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
-    self.position_embedding_type = position_embedding_type
-
-    # megatron core pipelining currently depends on model type
-    # TODO: remove this dependency ?
-    self.model_type = ModelType.encoder_or_decoder
-
-    # These 4 attributes are needed for TensorRT-LLM export.
-    self.max_position_embeddings = max_sequence_length
-    self.rotary_percent = rotary_percent
-    self.rotary_base = rotary_base
-    self.rotary_scaling = rope_scaling
-
-    self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
-
-    if self.pre_process:
-        self.embedding = LanguageModelEmbedding(
-            config=self.config,
-            vocab_size=self.vocab_size,
-            max_sequence_length=self.max_sequence_length,
-            position_embedding_type=position_embedding_type,
-            scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
-        )
-
-    if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
-        self.rotary_pos_emb = RotaryEmbedding(
-            kv_channels=self.config.kv_channels,
-            rotary_percent=rotary_percent,
-            rotary_interleaved=self.config.rotary_interleaved,
-            seq_len_interpolation_factor=seq_len_interpolation_factor,
-            rotary_base=rotary_base,
-            rope_scaling=rope_scaling,
-            rope_scaling_factor=rope_scaling_factor,
-            use_cpu_initialization=self.config.use_cpu_initialization,
-        )
-
-    # Cache for RoPE tensors which do not change between iterations.
-    self.rotary_pos_emb_cache = {}
-
-    # Transformer.
-    self.decoder = TransformerBlock(
-        config=self.config,
-        spec=transformer_layer_spec,
-        pre_process=self.pre_process,
-        post_process=self.post_process,
-    )
-
-    if self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0):
-        self.embedding = LanguageModelEmbedding(
-            config=self.config,
-            vocab_size=self.vocab_size,
-            max_sequence_length=self.max_sequence_length,
-            position_embedding_type=position_embedding_type,
-            scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
-        )
-
-    # Output
-    if post_process:
-        if self.config.defer_embedding_wgrad_compute:
-            # The embedding activation buffer preserves a reference to the input activations
-            # of the final embedding projection layer GEMM. It will hold the activations for
-            # all the micro-batches of a global batch for the last pipeline stage. Once we are
-            # done with all the back props for all the microbatches for the last pipeline stage,
-            # it will be in the pipeline flush stage. During this pipeline flush we use the
-            # input activations stored in embedding activation buffer and gradient outputs
-            # stored in gradient buffer to calculate the weight gradients for the embedding
-            # final linear layer.
-            self.embedding_activation_buffer = []
-            self.grad_output_buffer = []
-        else:
-            self.embedding_activation_buffer = None
-            self.grad_output_buffer = None
-
-        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
-            column_parallel_linear_impl = FluxColumnParallelLinear
-        else:
-            column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
-
-        self.output_layer = column_parallel_linear_impl(
-            config.hidden_size,
-            self.vocab_size,
-            config=config,
-            init_method=config.init_method,
-            bias=False,
-            skip_bias_add=False,
-            gather_output=not self.parallel_output,
-            skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights,
-            embedding_activation_buffer=self.embedding_activation_buffer,
-            grad_output_buffer=self.grad_output_buffer,
-        )
-
-    # add mtp
-    if self.num_nextn_predict_layers:
-        assert hasattr(self.config, "mtp_spec")
-        self.mtp_spec = self.config.mtp_spec
-        self.recompute_mtp_norm = self.config.recompute_mtp_norm
-        self.recompute_mtp_layer = self.config.recompute_mtp_layer
-        self.mtp_loss_scale = self.config.mtp_loss_scale
-        if self.post_process and self.training:
-            self.mtp_layers = torch.nn.ModuleList(
-                [
-                    MultiTokenPredictor(
-                        self.config,
-                        self.mtp_spec.submodules,
-                        vocab_size=self.vocab_size,
-                        max_sequence_length=self.max_sequence_length,
-                        layer_number=i,
-                        pre_process=self.pre_process,
-                        fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
-                        parallel_output=self.parallel_output,
-                        position_embedding_type=self.position_embedding_type,
-                        rotary_percent=self.rotary_percent,
-                        seq_len_interpolation_factor=seq_len_interpolation_factor,
-                        recompute_mtp_norm=self.recompute_mtp_norm,
-                        recompute_mtp_layer=self.recompute_mtp_layer,
-                        add_output_layer_bias=False)
-                    for i in range(self.num_nextn_predict_layers)
-                ]
-            )
-
-    if self.pre_process or self.post_process:
-        self.setup_embeddings_and_output_layer()
-
-    if has_config_logger_enabled(self.config):
-        log_config_to_disk(self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt')
-
-
-def shared_embedding_or_output_weight(self) -> Tensor:
-    """Gets the embedding weight or output logit weights when share embedding and output weights set to True.
-
-    Returns:
-        Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight
-    """
-    if self.pre_process or (self.post_process and getattr(self.config, 'num_nextn_predict_layers', 0)):
-        return self.embedding.word_embeddings.weight
-    elif self.post_process:
-        return self.output_layer.weight
-    return None
-
-
-def slice_inputs(self, input_ids, labels, position_ids, attention_mask):
-    if self.num_nextn_predict_layers == 0:
-        return (
-            [input_ids],
-            [labels],
-            [position_ids],
-            [attention_mask],
-        )
-    return (
-        tensor_slide(input_ids, self.num_nextn_predict_layers),
-        tensor_slide(labels, self.num_nextn_predict_layers),
-        generate_nextn_position_ids(position_ids, self.num_nextn_predict_layers),
-        # not compatible with ppo attn_mask
-        tensor_slide(attention_mask, self.num_nextn_predict_layers, dims=[-2, -1]),
-    )
-
-
-def generate_nextn_position_ids(tensor, slice_num):
-    slides = tensor_slide(tensor, slice_num)
-    if slides[0] is None:
-        return slides
-    for idx in range(1, len(slides)):
-        slides[idx] = regenerate_position_ids(slides[idx], idx)
-    return slides
-
-
-def regenerate_position_ids(tensor, offset):
-    if tensor is None:
-        return None
-    tensor = tensor.clone()
-    for i in range(tensor.size(0)):
-        row = tensor[i]
-        zero_mask = (row == 0)
-        # case where two sequences are concatenated
-        if zero_mask.any():
-            first_zero_idx = torch.argmax(zero_mask.int()).item()
-            tensor[i, :first_zero_idx] = torch.arange(first_zero_idx)
-        else:
-            tensor[i] = tensor[i] - offset
-    return tensor
-
-
-def gpt_model_forward(
-    self,
-    input_ids: Tensor,
-    position_ids: Tensor,
-    attention_mask: Tensor,
-    decoder_input: Tensor = None,
-    labels: Tensor = None,
-    inference_params: InferenceParams = None,
-    packed_seq_params: PackedSeqParams = None,
-    extra_block_kwargs: dict = None,
-    runtime_gather_output: Optional[bool] = None,
-) -> Tensor:
-    """Forward function of the GPT Model This function passes the input tensors
-    through the embedding layer, and then the decoder and finally into the post
-    processing layer (optional).
-
-    It either returns the Loss values if labels are given or the final hidden units
-
-    Args:
-        runtime_gather_output (bool): Gather output at runtime. Default None means
-            `parallel_output` arg in the constructor will be used.
-    """
-    # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
-    # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
-
-    # generate inputs for main and mtps
-    input_ids, labels, position_ids, attention_mask = slice_inputs(
-        input_ids, labels, position_ids, attention_mask)
-
-    # Decoder embedding.
-    if decoder_input is not None:
-        pass
-    elif self.pre_process:
-        decoder_input = self.embedding(input_ids=input_ids[0], position_ids=position_ids[0])
-    else:
-        # intermediate stage of pipeline
-        # decoder will get hidden_states from encoder.input_tensor
-        decoder_input = None
-
-    # Rotary positional embeddings (embedding is None for PP intermediate devices)
-    rotary_pos_emb = None
-    rotary_pos_cos = None
-    rotary_pos_sin = None
-    if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
-        if not self.training and self.config.flash_decode and inference_params:
-            # Flash decoding uses precomputed cos and sin for RoPE
-            rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
-                inference_params.max_sequence_length,
-                self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
-            )
-        else:
-            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
-                inference_params, self.decoder, decoder_input, self.config, packed_seq_params)
-            rotary_pos_emb = self.rotary_pos_emb(
-                rotary_seq_len,
-                packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd',
-            )
-    if ((self.config.enable_cuda_graph or self.config.flash_decode)
-            and rotary_pos_cos is not None
-            and inference_params):
-        sequence_len_offset = torch.tensor(
-            [inference_params.sequence_len_offset] * inference_params.current_batch_size,
-            dtype=torch.int32,
-            device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
-        )
-    else:
-        sequence_len_offset = None
-
-    # Run decoder.
-    hidden_states = self.decoder(
-        hidden_states=decoder_input,
-        attention_mask=attention_mask[0],
-        inference_params=inference_params,
-        rotary_pos_emb=rotary_pos_emb,
-        rotary_pos_cos=rotary_pos_cos,
-        rotary_pos_sin=rotary_pos_sin,
-        packed_seq_params=packed_seq_params,
-        sequence_len_offset=sequence_len_offset,
-        **(extra_block_kwargs or {}),
-    )
-
-    if not self.post_process:
-        return hidden_states
-
-    # logits and loss
-    output_weight = None
-    if self.share_embeddings_and_output_weights:
-        output_weight = self.shared_embedding_or_output_weight()
-
-    loss = 0
-    # Multi token prediction module
-    if self.num_nextn_predict_layers and self.training:
-        mtp_hidden_states = hidden_states
-        for i in range(self.num_nextn_predict_layers):
-            mtp_hidden_states, mtp_loss = self.mtp_layers[i](
-                mtp_hidden_states,  # [s,b,h]
-                input_ids[i + 1],
-                position_ids[i + 1] if position_ids[0] is not None else None,
-                attention_mask[i + 1] if attention_mask[0] is not None else None,
-                labels[i + 1] if labels[0] is not None else None,
-                inference_params,
-                packed_seq_params,
-                extra_block_kwargs,
-                embedding_layer=self.embedding,
-            )
-            loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss
-
-    if (self.num_nextn_predict_layers
-            and getattr(self.decoder, "main_final_layernorm", None) is not None):
-        # move block main model final norms here
-        hidden_states = self.decoder.main_final_layernorm(hidden_states)
-
-    logits, _ = self.output_layer(hidden_states, weight=output_weight,
-                                  runtime_gather_output=runtime_gather_output)
-
-    if has_config_logger_enabled(self.config):
-        payload = OrderedDict(
-            {
-                'input_ids': input_ids[0],
-                'position_ids': position_ids[0],
-                'attention_mask': attention_mask[0],
-                'decoder_input': decoder_input,
-                'logits': logits,
-            }
-        )
-        log_config_to_disk(self.config, payload, prefix='input_and_logits')
-
-    if labels[0] is None:
-        # [s b h] => [b s h]
-        return logits.transpose(0, 1).contiguous()
-
-    loss += self.compute_language_model_loss(labels[0], logits)
-
-    return loss


+class GPTModel(LanguageModule):
+    """GPT Transformer language model.
+
+    Args:
+        config (TransformerConfig):
+            Transformer config
+        transformer_layer_spec (ModuleSpec):
+            Specifies module to use for transformer layers
+        vocab_size (int):
+            Vocabulary size
+        max_sequence_length (int):
+            maximum size of sequence. This is used for positional embedding
+        pre_process (bool, optional):
+            Include embedding layer (used with pipeline parallelism). Defaults to True.
+        post_process (bool, optional):
+            Include an output layer (used with pipeline parallelism). Defaults to True.
+        fp16_lm_cross_entropy (bool, optional):
+            Defaults to False.
+        parallel_output (bool, optional):
+            Do not gather the outputs, keep them split across tensor
+            parallel ranks. Defaults to True.
+        share_embeddings_and_output_weights (bool, optional):
+            When True, input embeddings and output logit weights are shared. Defaults to False.
+        position_embedding_type (Literal[learned_absolute,rope], optional):
+            Position embedding type. Defaults to 'learned_absolute'.
+        rotary_percent (float, optional):
+            Percent of rotary dimension to use for rotary position embeddings.
+            Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
+        rotary_base (int, optional):
+            Base period for rotary position embeddings. Ignored unless
+            position_embedding_type is 'rope'.
+            Defaults to 10000.
+        rope_scaling (bool, optional): Toggle RoPE scaling.
+        rope_scaling_factor (float): RoPE scaling factor. Default 8.
+        scatter_embedding_sequence_parallel (bool, optional):
+            Whether embeddings should be scattered across sequence parallel
+            region or not. Defaults to True.
+        seq_len_interpolation_factor (Optional[float], optional):
+            scale of linearly interpolating RoPE for longer sequences.
+            The value must be a float larger than 1.0. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        transformer_layer_spec: ModuleSpec,
+        vocab_size: int,
+        max_sequence_length: int,
+        pre_process: bool = True,
+        post_process: bool = True,
+        fp16_lm_cross_entropy: bool = False,
+        parallel_output: bool = True,
+        share_embeddings_and_output_weights: bool = False,
+        position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
+        rotary_percent: float = 1.0,
+        rotary_base: int = 10000,
+        rope_scaling: bool = False,
+        rope_scaling_factor: float = 8.0,
+        scatter_embedding_sequence_parallel: bool = True,
+        seq_len_interpolation_factor: Optional[float] = None,
+        mtp_block_spec: Optional[ModuleSpec] = None,
+    ) -> None:
+        super().__init__(config=config)
+
+        if has_config_logger_enabled(config):
+            log_config_to_disk(config, locals(), prefix=type(self).__name__)
+
+        self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
+        self.vocab_size = vocab_size
+        self.max_sequence_length = max_sequence_length
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
+        self.parallel_output = parallel_output
+        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+        self.position_embedding_type = position_embedding_type
+
+        # megatron core pipelining currently depends on model type
+        # TODO: remove this dependency ?
+        self.model_type = ModelType.encoder_or_decoder
+
+        # These 4 attributes are needed for TensorRT-LLM export.
+        self.max_position_embeddings = max_sequence_length
+        self.rotary_percent = rotary_percent
+        self.rotary_base = rotary_base
+        self.rotary_scaling = rope_scaling
+
+        self.mtp_block_spec = mtp_block_spec
+        self.mtp_process = mtp_block_spec is not None
+
+        if self.pre_process or self.mtp_process:
+            self.embedding = LanguageModelEmbedding(
+                config=self.config,
+                vocab_size=self.vocab_size,
+                max_sequence_length=self.max_sequence_length,
+                position_embedding_type=position_embedding_type,
+                scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
+            )
+
+        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
+            self.rotary_pos_emb = RotaryEmbedding(
+                kv_channels=self.config.kv_channels,
+                rotary_percent=rotary_percent,
+                rotary_interleaved=self.config.rotary_interleaved,
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
+                rotary_base=rotary_base,
+                rope_scaling=rope_scaling,
+                rope_scaling_factor=rope_scaling_factor,
+                use_cpu_initialization=self.config.use_cpu_initialization,
+            )
+
+        # Cache for RoPE tensors which do not change between iterations.
+        self.rotary_pos_emb_cache = {}
+
+        # Transformer.
+        self.decoder = TransformerBlock(
+            config=self.config,
+            spec=transformer_layer_spec,
+            pre_process=self.pre_process,
+            post_process=self.post_process,
+        )
+
+        if self.mtp_process:
+            self.mtp = MultiTokenPredictionBlock(config=self.config, spec=self.mtp_block_spec)
+
+        # Output
+        if self.post_process or self.mtp_process:
+            if self.config.defer_embedding_wgrad_compute:
+                # The embedding activation buffer preserves a reference to the input activations
+                # of the final embedding projection layer GEMM. It will hold the activations for
+                # all the micro-batches of a global batch for the last pipeline stage. Once we are
+                # done with all the back props for all the microbatches for the last pipeline stage,
+                # it will be in the pipeline flush stage. During this pipeline flush we use the
+                # input activations stored in embedding activation buffer and gradient outputs
+                # stored in gradient buffer to calculate the weight gradients for the embedding
+                # final linear layer.
+                self.embedding_activation_buffer = []
+                self.grad_output_buffer = []
+            else:
+                self.embedding_activation_buffer = None
+                self.grad_output_buffer = None
+
+            if int(os.getenv("USE_FLUX_OVERLAP", "0")):
+                parallel_linear_impl = FluxColumnParallelLinear
+            else:
+                parallel_linear_impl = tensor_parallel.ColumnParallelLinear
+
+            self.output_layer = parallel_linear_impl(
+                config.hidden_size,
+                self.vocab_size,
+                config=config,
+                init_method=config.init_method,
+                bias=False,
+                skip_bias_add=False,
+                gather_output=not self.parallel_output,
+                skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights,
+                embedding_activation_buffer=self.embedding_activation_buffer,
+                grad_output_buffer=self.grad_output_buffer,
+            )
+
+        if self.pre_process or self.post_process:
+            self.setup_embeddings_and_output_layer()
+
+        if has_config_logger_enabled(self.config):
+            log_config_to_disk(self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt')
+
+    def set_input_tensor(self, input_tensor: Tensor) -> None:
+        """Sets input tensor to the model.
+
+        See megatron.model.transformer.set_input_tensor()
+
+        Args:
+            input_tensor (Tensor): Sets the input tensor for the model.
+        """
+        # This is usually handled in schedules.py but some inference code still
+        # gives us non-lists or None
+        if not isinstance(input_tensor, list):
+            input_tensor = [input_tensor]
+
+        assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
+        self.decoder.set_input_tensor(input_tensor[0])
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        position_ids: Tensor,
+        attention_mask: Tensor,
+        decoder_input: Tensor = None,
+        labels: Tensor = None,
+        inference_params: InferenceParams = None,
+        packed_seq_params: PackedSeqParams = None,
+        extra_block_kwargs: dict = None,
+        runtime_gather_output: Optional[bool] = None,
+        loss_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward function of the GPT Model This function passes the input tensors
+        through the embedding layer, and then the decoder and finally into the post
+        processing layer (optional).
+
+        It either returns the Loss values if labels are given or the final hidden units
+
+        Args:
+            runtime_gather_output (bool): Gather output at runtime. Default None means
+                `parallel_output` arg in the constructor will be used.
+        """
+        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
+        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
+
+        # Decoder embedding.
+        if decoder_input is not None:
+            pass
+        elif self.pre_process:
+            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
+        else:
+            # intermediate stage of pipeline
+            # decoder will get hidden_states from encoder.input_tensor
+            decoder_input = None
+
+        # Rotary positional embeddings (embedding is None for PP intermediate devices)
+        rotary_pos_emb = None
+        rotary_pos_cos = None
+        rotary_pos_sin = None
+        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
+            if not self.training and self.config.flash_decode and inference_params:
+                # Flash decoding uses precomputed cos and sin for RoPE
+                rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
+                    inference_params.max_sequence_length,
+                    self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
+                )
+            else:
+                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
+                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params)
+                rotary_pos_emb = self.rotary_pos_emb(
+                    rotary_seq_len,
+                    packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd',
+                )
+        if ((self.config.enable_cuda_graph or self.config.flash_decode)
+                and rotary_pos_cos is not None
+                and inference_params):
+            sequence_len_offset = torch.tensor(
+                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
+                dtype=torch.int32,
+                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
+            )
+        else:
+            sequence_len_offset = None
+
+        # Run decoder.
+        hidden_states = self.decoder(
+            hidden_states=decoder_input,
+            attention_mask=attention_mask,
+            inference_params=inference_params,
+            rotary_pos_emb=rotary_pos_emb,
+            rotary_pos_cos=rotary_pos_cos,
+            rotary_pos_sin=rotary_pos_sin,
+            packed_seq_params=packed_seq_params,
+            sequence_len_offset=sequence_len_offset,
+            **(extra_block_kwargs or {}),
+        )
+
+        # logits and loss
+        output_weight = None
+        if self.share_embeddings_and_output_weights:
+            output_weight = self.shared_embedding_or_output_weight()
+
+        if self.mtp_process:
+            hidden_states = self.mtp(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                labels=labels,
+                loss_mask=loss_mask,
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                inference_params=inference_params,
+                rotary_pos_emb=rotary_pos_emb,
+                rotary_pos_cos=rotary_pos_cos,
+                rotary_pos_sin=rotary_pos_sin,
+                packed_seq_params=packed_seq_params,
+                sequence_len_offset=sequence_len_offset,
+                embedding=self.embedding,
+                output_layer=self.output_layer,
+                output_weight=output_weight,
+                runtime_gather_output=runtime_gather_output,
+                compute_language_model_loss=self.compute_language_model_loss,
+                **(extra_block_kwargs or {}),
+            )
+
+        if (self.mtp_process is not None
+                and getattr(self.decoder, "main_final_layernorm", None) is not None):
+            # move block main model final norms here
+            hidden_states = self.decoder.main_final_layernorm(hidden_states)
+
+        if not self.post_process:
+            return hidden_states
+
+        logits, _ = self.output_layer(hidden_states, weight=output_weight,
+                                      runtime_gather_output=runtime_gather_output)
+
+        if has_config_logger_enabled(self.config):
+            payload = OrderedDict(
+                {
+                    'input_ids': input_ids,
+                    'position_ids': position_ids,
+                    'attention_mask': attention_mask,
+                    'decoder_input': decoder_input,
+                    'logits': logits,
+                }
+            )
+            log_config_to_disk(self.config, payload, prefix='input_and_logits')
+
+        if labels is None:
+            # [s b h] => [b s h]
+            return logits.transpose(0, 1).contiguous()
+
+        loss = self.compute_language_model_loss(labels, logits)
+
+        return loss
+
+    def shared_embedding_or_output_weight(self) -> Tensor:
+        """Gets the embedding weight or output logit weights when share input embedding and
+        output weights set to True or when use Multi-Token Prediction (MTP) feature.
+
+        Returns:
+            Tensor: During pre processing or MTP process it returns the input embeddings weight.
+                Otherwise, during post processing it returns the final output layers weight.
+        """
+        if self.pre_process or self.mtp_process:
+            # Multi-Token Prediction (MTP) need both embedding layer and output layer.
+            # So there will be both embedding layer and output layer in the mtp process stage.
+            # In this case, if share_embeddings_and_output_weights is True, the shared weights
+            # will be stored in embedding layer, and output layer will not have any weight.
+            assert hasattr(
+                self, 'embedding'
+            ), f"embedding is needed in this pipeline stage, but it is not initialized."
+            return self.embedding.word_embeddings.weight
+        elif self.post_process:
+            return self.output_layer.weight
+        return None
+
+    def sharded_state_dict(
+        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
+    ) -> ShardedStateDict:
+        """Sharded state dict implementation for GPTModel backward-compatibility
+        (removing extra state).
+
+        Args:
+            prefix (str): Module name prefix.
+            sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
+            metadata (Optional[Dict]): metadata controlling sharded state dict creation.
+
+        Returns:
+            ShardedStateDict: sharded state dict for the GPTModel
+        """
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
+
+        output_layer_extra_state_key = f'{prefix}output_layer._extra_state'
+
+        # Old GPT checkpoints only stored the output layer weight key. So we remove the
+        # _extra_state key but check that it doesn't contain any data anyway
+        output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
+        assert not (
+            output_extra_state and output_extra_state.data
+        ), f'Expected output layer extra state to be empty, got: {output_extra_state}'
+
+        # Multi-Token Prediction (MTP) need both embedding layer and output layer in
+        # mtp process stage.
+        # If MTP is not placed in the pre processing stage, we need to maintain a copy of
+        # embedding layer in the mtp process stage and tie it to the embedding in the pre
+        # processing stage.
+        # Also, if MTP is not placed in the post processing stage, we need to maintain a copy
+        # of output layer in the mtp process stage and tie it to the output layer in the post
+        # processing stage.
+        if self.mtp_process and not self.pre_process:
+            emb_weight_key = f'{prefix}embedding.word_embeddings.weight'
+            emb_weight = self.embedding.word_embeddings.weight
+            tie_word_embeddings_state_dict(sharded_state_dict, emb_weight, emb_weight_key)
+        if self.mtp_process and not self.post_process:
+            # We only need to tie the output layer weight if share_embeddings_and_output_weights
+            # is False. Because if share_embeddings_and_output_weights is True, the shared weight
+            # will be stored in embedding layer, and output layer will not have any weight.
+            if not self.share_embeddings_and_output_weights:
+                output_layer_weight_key = f'{prefix}output_layer.weight'
+                output_layer_weight = self.output_layer.weight
+                tie_output_layer_state_dict(
+                    sharded_state_dict, output_layer_weight, output_layer_weight_key
+                )
+
+        return sharded_state_dict
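Note on the data-flow change: the removed code pre-sliced input_ids/labels/position_ids per prediction depth with tensor_slide before calling per-depth MultiTokenPredictor modules, while the added code passes the full tensors and lets MultiTokenPredictionBlock shift them internally (see roll_tensor in multi_token_prediction.py below). A toy illustration of that one-token shift per extra depth:

import torch

tokens = torch.tensor([[11, 12, 13, 14, 15]])
depth1 = torch.roll(tokens, shifts=-1, dims=-1)   # shift left by one position
depth1[:, -1] = 0                                 # the vacated slot is zeroed
print(depth1)                                     # tensor([[12, 13, 14, 15,  0]])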
dcu_megatron/core/pipeline_parallel/schedules.py (new file, mode 100644, +52 / -0)

import torch
from functools import wraps

from dcu_megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler


def forward_step_wrapper(fn):
    @wraps(fn)
    def wrapper(
        forward_step_func,
        data_iterator,
        model,
        num_microbatches,
        input_tensor,
        forward_data_store,
        config,
        **kwargs,
    ):
        output, num_tokens = fn(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            **kwargs
        )

        if not isinstance(input_tensor, list):
            # unwrap_output_tensor True
            output_tensor = output
        else:
            output_tensor = output[0]

        # Set the loss scale for Multi-Token Prediction (MTP) loss.
        if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
            # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
            loss_scale = (
                config.grad_scale_func(torch.ones(1, device=output_tensor.device))
                if config.grad_scale_func is not None
                else torch.ones(1, device=output_tensor.device)
            )
            # Set the loss scale
            if config.calculate_per_token_loss:
                MTPLossAutoScaler.set_loss_scale(loss_scale)
            else:
                MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
        return output, num_tokens

    return wrapper
\ No newline at end of file
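The wrapper only adjusts bookkeeping: it recomputes the scale that MTPLossAutoScaler applies to the MTP loss gradient so it matches the main loss (the grad-scaler value, divided by the number of microbatches unless per-token loss accounting is on). A small numeric check of that rule, using an assumed grad scale of 4096 and 8 microbatches:

import torch

grad_scale = torch.ones(1) * 4096.0     # stand-in for config.grad_scale_func(torch.ones(1))
num_microbatches = 8
per_microbatch_scale = grad_scale / num_microbatches
print(per_microbatch_scale)             # tensor([512.])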
dcu_megatron/core/tensor_parallel/__init__.py (+0 / -2)

from .layers import (
    FluxColumnParallelLinear,
    FluxRowParallelLinear,
    vocab_parallel_embedding_forward,
    vocab_parallel_embedding_init_wrapper,
)
\ No newline at end of file
dcu_megatron/core/tensor_parallel/layers.py (+8 / -96)

import os
import copy
import socket
import warnings
from functools import wraps
from typing import Callable, List, Optional

+if int(os.getenv("USE_FLUX_OVERLAP", "0")):
+    try:
+        import flux
+        from dcu_megatron.core.utils import is_flux_min_version
+    except ImportError:
+        raise ImportError("flux is NOT installed")
-try:
-    import flux
-except ImportError:
-    raise ImportError("flux is NOT installed")

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from megatron.training import print_rank_0
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
    get_global_memory_buffer,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from megatron.core.utils import (
    is_torch_min_version,
    prepare_input_tensors_for_wgrad_compute
)
from megatron.core.tensor_parallel.layers import (
    _initialize_affine_weight_cpu,
    _initialize_affine_weight_gpu,
    VocabParallelEmbedding,
)
from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
from megatron.core.tensor_parallel.mappings import (
    _reduce,
    copy_to_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
    _reduce_scatter_along_first_dim,
    _gather_along_first_dim,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
from megatron.core.tensor_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
...
@@ -50,9 +30,9 @@ from megatron.core.tensor_parallel.layers import (
    custom_fwd,
    custom_bwd,
    dist_all_gather_func,
    linear_with_frozen_weight,
    linear_with_grad_accumulation_and_async_allreduce
)

from dcu_megatron.core.utils import is_flux_min_version

_grad_accum_fusion_available = True
try:
...
@@ -61,74 +41,6 @@ except ImportError:
    _grad_accum_fusion_available = False


def vocab_parallel_embedding_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, skip_weight_param_allocation: bool = False, **kwargs):
        if (
            skip_weight_param_allocation
            and "config" in kwargs
            and hasattr(kwargs["config"], "perform_initialization")
        ):
            config = copy.deepcopy(kwargs["config"])
            config.perform_initialization = False
            kwargs["config"] = config
        fn(self, *args, **kwargs)
        if skip_weight_param_allocation:
            self.weight = None
    return wrapper


@torch.compile(mode='max-autotune-no-cudagraphs')
def vocab_parallel_embedding_forward(self, input_, weight=None):
    """Forward.

    Args:
        input_ (torch.Tensor): Input tensor.
    """
    if weight is None:
        if self.weight is None:
            raise RuntimeError(
                "weight was not supplied to VocabParallelEmbedding forward pass "
                "and skip_weight_param_allocation is True."
            )
        weight = self.weight

    if self.tensor_model_parallel_size > 1:
        # Build the mask.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
    else:
        masked_input = input_
    # Get the embeddings.
    if self.deterministic_mode:
        output_parallel = weight[masked_input]
    else:
        # F.embedding currently has a non-deterministic backward function
        output_parallel = F.embedding(masked_input, weight)
    # Mask the output embedding.
    if self.tensor_model_parallel_size > 1:
        output_parallel[input_mask, :] = 0.0

    if self.reduce_scatter_embeddings:
        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        output_parallel = output_parallel.transpose(0, 1).contiguous()
        output = reduce_scatter_to_sequence_parallel_region(output_parallel)
    else:
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
    return output


def get_tensor_model_parallel_node_size(group=None):
    """Get the number of nodes."""
...
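For reference, a toy, self-contained walk-through of the vocab-partition masking performed in vocab_parallel_embedding_forward above (the shard ranges are chosen arbitrarily): token ids outside this rank's vocab shard are clamped to row 0 and their embedding rows zeroed, so the later cross-rank reduce reassembles the full lookup.

import torch

vocab_start_index, vocab_end_index = 4, 8      # this rank owns vocab rows 4..7
weight = torch.randn(4, 3)                     # local shard of the embedding table
input_ = torch.tensor([1, 5, 7, 9])

input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
masked_input = input_.clone() - vocab_start_index
masked_input[input_mask] = 0                   # out-of-shard ids point at row 0
output_parallel = weight[masked_input]
output_parallel[input_mask, :] = 0.0           # and contribute nothing to the sum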
dcu_megatron/core/transformer/multi_token_prediction.py
0 → 100755
View file @
1f7b14ab
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from
contextlib
import
nullcontext
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Union
import
torch
from
torch
import
Tensor
from
megatron.core
import
InferenceParams
,
mpu
,
parallel_state
,
tensor_parallel
from
megatron.core.dist_checkpointing.mapping
import
ShardedStateDict
from
megatron.core.dist_checkpointing.utils
import
replace_prefix_for_sharding
from
megatron.core.fusions.fused_layer_norm
import
FusedLayerNorm
from
megatron.core.packed_seq_params
import
PackedSeqParams
from
megatron.core.tensor_parallel
import
(
all_gather_last_dim_from_tensor_parallel_region
,
scatter_to_sequence_parallel_region
,
)
from
megatron.core.tensor_parallel.layers
import
ColumnParallelLinear
from
megatron.core.transformer.enums
import
AttnMaskType
from
megatron.core.transformer.module
import
MegatronModule
from
megatron.core.transformer.spec_utils
import
ModuleSpec
,
build_module
from
megatron.core.transformer.transformer_block
import
TransformerBlockSubmodules
from
megatron.core.transformer.transformer_config
import
TransformerConfig
from
megatron.core.utils
import
make_tp_sharded_tensor_for_checkpoint
,
make_viewless_tensor
SUPPORTED_ATTN_MASK
=
[
AttnMaskType
.
padding
,
AttnMaskType
.
causal
,
AttnMaskType
.
no_mask
,
AttnMaskType
.
padding_causal
,
]
try
:
from
megatron.core.extensions.transformer_engine
import
(
TEColumnParallelLinear
,
TEDelayedScaling
,
TENorm
,
)
HAVE_TE
=
True
except
ImportError
:
HAVE_TE
=
False
from
megatron.core.transformer.torch_norm
import
WrappedTorchNorm
try
:
import
apex
# pylint: disable=unused-import
from
megatron.core.fusions.fused_layer_norm
import
FusedLayerNorm
HAVE_APEX
=
True
LNImpl
=
FusedLayerNorm
except
ImportError
:
import
warnings
from
megatron.core.transformer.torch_norm
import
WrappedTorchNorm
warnings
.
warn
(
'Apex is not installed. Falling back to Torch Norm'
)
LNImpl
=
WrappedTorchNorm
def
tie_word_embeddings_state_dict
(
sharded_state_dict
:
ShardedStateDict
,
word_emb_weight
:
Tensor
,
word_emb_weight_key
:
str
)
->
None
:
"""tie the embedding of the mtp processing stage in a given sharded state dict.
Args:
sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
word_emb_weight (Tensor): weight of the word embedding.
word_emb_weight_key (str): key of the word embedding in the sharded state dict.
Returns: None, acts in-place
"""
mtp_word_emb_replica_id
=
(
1
,
# copy of embedding in pre processing stage
0
,
parallel_state
.
get_data_parallel_rank
(
with_context_parallel
=
True
),
)
assert
word_emb_weight_key
in
sharded_state_dict
del
sharded_state_dict
[
word_emb_weight_key
]
sharded_state_dict
[
word_emb_weight_key
]
=
make_tp_sharded_tensor_for_checkpoint
(
tensor
=
word_emb_weight
,
key
=
word_emb_weight_key
,
replica_id
=
mtp_word_emb_replica_id
,
allow_shape_mismatch
=
True
,
)
def
tie_output_layer_state_dict
(
sharded_state_dict
:
ShardedStateDict
,
output_layer_weight
:
Tensor
,
output_layer_weight_key
:
str
)
->
None
:
"""tie the output layer of the mtp processing stage in a given sharded state dict.
Args:
sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
output_layer_weight (Tensor): weight of the output layer.
output_layer_weight_key (str): key of the output layer in the sharded state dict.
Returns: None, acts in-place
"""
mtp_output_layer_replica_id
=
(
1
,
# copy of output layer in post processing stage
0
,
parallel_state
.
get_data_parallel_rank
(
with_context_parallel
=
True
),
)
assert
output_layer_weight_key
in
sharded_state_dict
del
sharded_state_dict
[
output_layer_weight_key
]
sharded_state_dict
[
output_layer_weight_key
]
=
make_tp_sharded_tensor_for_checkpoint
(
tensor
=
output_layer_weight
,
key
=
output_layer_weight_key
,
replica_id
=
mtp_output_layer_replica_id
,
allow_shape_mismatch
=
True
,
)
def
roll_tensor
(
tensor
,
shifts
=-
1
,
dims
=-
1
):
"""Roll the tensor input along the given dimension(s).
Inserted elements are set to be 0.0.
"""
rolled_tensor
=
torch
.
roll
(
tensor
,
shifts
=
shifts
,
dims
=
dims
)
rolled_tensor
.
select
(
dims
,
shifts
).
fill_
(
0
)
return
rolled_tensor
,
rolled_tensor
.
sum
()
class
MTPLossLoggingHelper
:
"""Helper class for logging MTP losses."""
tracker
=
{}
@
staticmethod
def
save_loss_to_tracker
(
loss
:
torch
.
Tensor
,
layer_number
:
int
,
num_layers
:
int
,
reduce_group
:
torch
.
distributed
.
ProcessGroup
=
None
,
avg_group
:
torch
.
distributed
.
ProcessGroup
=
None
,
):
"""Save the mtp loss for logging.
Args:
loss (torch.Tensor): The loss tensor.
layer_number (int): Layer index of the loss.
num_layers (int): The number of total layers.
reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
mean_group (torch.distributed.ProcessGroup): The group for averaging the loss.
"""
# Skip mtp loss logging if layer_number is None.
if
layer_number
is
None
:
return
tracker
=
MTPLossLoggingHelper
.
tracker
if
"values"
not
in
tracker
:
tracker
[
"values"
]
=
torch
.
zeros
(
num_layers
,
device
=
loss
.
device
)
tracker
[
"values"
][
layer_number
]
+=
loss
.
detach
()
tracker
[
"reduce_group"
]
=
reduce_group
tracker
[
"avg_group"
]
=
avg_group
def
clean_loss_in_tracker
():
"""Clear the mtp losses."""
tracker
=
MTPLossLoggingHelper
.
tracker
tracker
[
"values"
].
zero_
()
tracker
[
"reduce_group"
]
=
None
tracker
[
"avg_group"
]
=
None
def
reduce_loss_in_tracker
():
"""Collect and reduce the mtp losses across ranks."""
tracker
=
MTPLossLoggingHelper
.
tracker
if
"values"
not
in
tracker
:
return
values
=
tracker
[
"values"
]
# Reduce mtp losses across ranks.
if
tracker
.
get
(
'reduce_group'
)
is
not
None
:
torch
.
distributed
.
all_reduce
(
values
,
group
=
tracker
.
get
(
'reduce_group'
))
if
tracker
.
get
(
'avg_group'
)
is
not
None
:
torch
.
distributed
.
all_reduce
(
values
,
group
=
tracker
[
'avg_group'
],
op
=
torch
.
distributed
.
ReduceOp
.
AVG
)
def
track_mtp_metrics
(
loss_scale
,
iteration
,
writer
,
wandb_writer
=
None
,
total_loss_dict
=
None
):
"""Track the Multi-Token Prediction (MTP) metrics for logging."""
MTPLossLoggingHelper
.
reduce_loss_in_tracker
()
tracker
=
MTPLossLoggingHelper
.
tracker
if
"values"
not
in
tracker
:
return
mtp_losses
=
tracker
[
"values"
]
*
loss_scale
mtp_num_layers
=
mtp_losses
.
shape
[
0
]
for
i
in
range
(
mtp_num_layers
):
name
=
f
"mtp_
{
i
+
1
}
loss"
loss
=
mtp_losses
[
i
]
if
total_loss_dict
is
not
None
:
total_loss_dict
[
name
]
=
loss
if
writer
is
not
None
:
writer
.
add_scalar
(
name
,
loss
,
iteration
)
if
wandb_writer
is
not
None
:
wandb_writer
.
log
({
f
"
{
name
}
"
:
loss
},
iteration
)
MTPLossLoggingHelper
.
clean_loss_in_tracker
()
@
dataclass
class
MultiTokenPredictionLayerSubmodules
:
"""
Dataclass for specifying the submodules of a MultiTokenPrediction module.
Args:
hnorm (Union[ModuleSpec, type]): Specification or instance of the
hidden states normalization to be applied.
enorm (Union[ModuleSpec, type]): Specification or instance of the
embedding normalization to be applied.
eh_proj (Union[ModuleSpec, type]): Specification or instance of the
linear projection to be applied.
transformer_layer (Union[ModuleSpec, type]): Specification
or instance of the transformer block to be applied.
"""
enorm
:
Union
[
ModuleSpec
,
type
]
=
None
hnorm
:
Union
[
ModuleSpec
,
type
]
=
None
eh_proj
:
Union
[
ModuleSpec
,
type
]
=
None
transformer_layer
:
Union
[
ModuleSpec
,
type
]
=
None
layer_norm
:
Union
[
ModuleSpec
,
type
]
=
None
def get_mtp_layer_spec(
    transformer_layer_spec: ModuleSpec, use_transformer_engine: bool
) -> ModuleSpec:
    """Get the MTP layer spec.

    Returns:
        ModuleSpec: Module specification with TE modules.
    """
    if use_transformer_engine:
        assert HAVE_TE, "transformer_engine should be installed if use_transformer_engine is True"
        layer_norm_impl = TENorm
        column_parallel_linear_impl = TEColumnParallelLinear
    else:
        layer_norm_impl = LNImpl
        column_parallel_linear_impl = ColumnParallelLinear

    mtp_layer_spec = ModuleSpec(
        module=MultiTokenPredictionLayer,
        submodules=MultiTokenPredictionLayerSubmodules(
            enorm=layer_norm_impl,
            hnorm=layer_norm_impl,
            eh_proj=column_parallel_linear_impl,
            transformer_layer=transformer_layer_spec,
            layer_norm=layer_norm_impl,
        ),
    )

    return mtp_layer_spec
def get_mtp_layer_offset(config: TransformerConfig) -> int:
    """Get the offset of the MTP layer."""
    # Currently, we only support putting all of the MTP layers on the last pipeline stage.
    return 0


def get_mtp_num_layers_to_build(config: TransformerConfig) -> int:
    """Get the number of MTP layers to build."""
    # Currently, we only support putting all of the MTP layers on the last pipeline stage.
    if mpu.is_pipeline_last_stage():
        return config.mtp_num_layers if config.mtp_num_layers else 0
    else:
        return 0
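A hedged sketch of how these helpers are expected to fit together (the real assembly lives in get_gpt_mtp_block_spec in dcu_megatron/core/models/gpt/gpt_layer_specs.py): one MTP layer spec per depth, collected into block submodules on the last pipeline stage. `config` and `transformer_layer_spec` are assumed to be already built.

# Hypothetical assembly, for illustration only.
num_mtp_layers = get_mtp_num_layers_to_build(config)
mtp_block_submodules = MultiTokenPredictionBlockSubmodules(
    layer_specs=[get_mtp_layer_spec(transformer_layer_spec, use_transformer_engine=True)] * num_mtp_layers
)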
class MTPLossAutoScaler(torch.autograd.Function):
    """An AutoScaler that triggers the backward pass and scales the grad for mtp loss."""

    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)

    @staticmethod
    def forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor):
        """Preserve the mtp loss by storing it in the context to avoid garbage collection.

        Args:
            output (torch.Tensor): The output tensor.
            mtp_loss (torch.Tensor): The mtp loss tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        ctx.save_for_backward(mtp_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Compute and scale the gradient for mtp loss.

        Args:
            grad_output (torch.Tensor): The gradient of the output.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled mtp loss
                gradient.
        """
        (mtp_loss,) = ctx.saved_tensors
        mtp_loss_backward_scale = MTPLossAutoScaler.main_loss_backward_scale
        scaled_mtp_loss_grad = torch.ones_like(mtp_loss) * mtp_loss_backward_scale
        return grad_output, scaled_mtp_loss_grad

    @staticmethod
    def set_loss_scale(scale: torch.Tensor):
        """Set the scale of the mtp loss.

        Args:
            scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in
                matches the scale of the main_loss.
        """
        MTPLossAutoScaler.main_loss_backward_scale = scale
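A minimal, self-contained demonstration of the gradient-attachment trick using toy tensors (in the real model, `output` is the main model's hidden states and `mtp_loss` comes from the MTP block): grafting the MTP loss onto the main output means a single backward pass also populates gradients along the MTP-loss path.

# Toy stand-ins, for illustration only.
main_out = torch.randn(3, requires_grad=True)
toy_mtp_loss = (main_out * 2).sum()
attached = MTPLossAutoScaler.apply(main_out, toy_mtp_loss)
MTPLossAutoScaler.set_loss_scale(torch.tensor(1.0))
attached.sum().backward()
# main_out.grad now holds 1.0 from the main path plus 2.0 from the scaled MTP-loss path.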
class MultiTokenPredictionLayer(MegatronModule):
    """The implementation of Multi-Token Prediction (MTP), which extends
    the prediction scope to multiple future tokens at each position.

    This MTP implementation sequentially predicts additional tokens and keeps the complete
    causal chain at each prediction depth, by using D sequential modules to predict
    D additional tokens.

    The k-th MTP module consists of a shared embedding layer, a projection matrix,
    a Transformer block, and a shared output head.

    For the i-th input token at the (k - 1)-th prediction depth, we first combine
    the representation of the i-th token and the embedding of the (i + K)-th token with
    the linear projection. The combined representation serves as the input of the Transformer
    block at the k-th depth to produce the output representation.

    For more information, please refer to the DeepSeek-V3 Technical Report:
    https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
    """
    def __init__(
        self,
        config: TransformerConfig,
        submodules: MultiTokenPredictionLayerSubmodules,
        layer_number: int = 1,
    ):
        super().__init__(config=config)
        self.sequence_parallel = config.sequence_parallel
        self.submodules = submodules
        self.layer_number = layer_number

        self_attention_spec = self.submodules.transformer_layer.submodules.self_attention
        attn_mask_type = self_attention_spec.params.get('attn_mask_type', '')
        assert attn_mask_type in SUPPORTED_ATTN_MASK, (
            f"Multi-Token Prediction (MTP) is not yet supported with "
            + f"{attn_mask_type} attention mask type. "
            + f"The supported attention mask types are {SUPPORTED_ATTN_MASK}."
        )

        self.enorm = build_module(
            self.submodules.enorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        self.hnorm = build_module(
            self.submodules.hnorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        # For the linear projection at the (k - 1)-th MTP layer, the input is the concatenation
        # of the i-th token's hidden states and the (i + K)-th token's decoder input,
        # so the input's shape is [s, b, 2*h].
        # The output will be sent to the following transformer layer,
        # so the output's shape should be [s, b, h].
        self.eh_proj = build_module(
            self.submodules.eh_proj,
            self.config.hidden_size * 2,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=False,
            skip_bias_add=False,
            is_expert=False,
        )
        self.transformer_layer = build_module(self.submodules.transformer_layer, config=self.config)

        self.final_layernorm = build_module(
            self.submodules.layer_norm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )
    def forward(
        self,
        decoder_input: Tensor,
        hidden_states: Tensor,
        attention_mask: Tensor,
        context: Tensor = None,
        context_mask: Tensor = None,
        rotary_pos_emb: Tensor = None,
        rotary_pos_cos: Tensor = None,
        rotary_pos_sin: Tensor = None,
        attention_bias: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        sequence_len_offset: Tensor = None,
    ):
        """
        Perform the forward pass through the MTP layer.

        Args:
            hidden_states (Tensor): Hidden states tensor of shape [s, b, h] where s is the
                sequence length, b is the batch size, and h is the hidden size.
            decoder_input (Tensor): Input tensor of shape [s, b, h] where s is the
                sequence length, b is the batch size, and h is the hidden size.
                At the (k - 1)-th MTP module, the i-th element of decoder input is
                the embedding of the (i + K)-th token.
            attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
                self-attention.
            context (Tensor, optional): Context tensor for cross-attention.
            context_mask (Tensor, optional): Mask for cross-attention context.
            rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
            attention_bias (Tensor): Bias tensor for Q * K.T, in a shape broadcastable
                to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].
                Used as an alternative to the attention mask for TE cuDNN attention.
            inference_params (InferenceParams, optional): Parameters for inference-time
                optimizations.
            packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence
                processing.

        Returns:
            Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
            [s, b, h], and optionally the updated context tensor if cross-attention is used.
        """
        assert context is None, f"multi token prediction + cross attention is not yet supported."
        assert (
            packed_seq_params is None
        ), f"multi token prediction + sequence packing is not yet supported."

        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

        if self.config.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        if self.config.fp8:
            import transformer_engine  # To keep out TE dependency when not training in fp8

            if self.config.fp8 == "e4m3":
                fp8_format = transformer_engine.common.recipe.Format.E4M3
            elif self.config.fp8 == "hybrid":
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
            else:
                raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")

            fp8_recipe = TEDelayedScaling(
                config=self.config,
                fp8_format=fp8_format,
                override_linear_precision=(False, False, not self.config.fp8_wgrad),
            )
            fp8_group = None
            if parallel_state.model_parallel_is_initialized():
                fp8_group = parallel_state.get_amax_reduction_group(
                    with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red
                )
            fp8_context = transformer_engine.pytorch.fp8_autocast(
                enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
            )
        else:
            fp8_context = nullcontext()

        with rng_context, fp8_context:
            decoder_input = self.enorm(decoder_input)
            decoder_input = make_viewless_tensor(inp=decoder_input, requires_grad=True, keep_graph=True)
            hidden_states = self.hnorm(hidden_states)
            hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
            # At the (k - 1)-th MTP module, concatenate the i-th token's hidden_states
            # and the (i + K)-th token's embedding, and combine them with the linear projection.
            hidden_states = torch.cat((decoder_input, hidden_states), -1)
            hidden_states, _ = self.eh_proj(hidden_states)
            # For tensor parallel, all gather after linear_fc.
            hidden_states = all_gather_last_dim_from_tensor_parallel_region(hidden_states)
            # For sequence parallel, scatter after linear_fc and before the transformer layer.
            if self.sequence_parallel:
                hidden_states = scatter_to_sequence_parallel_region(hidden_states)
            hidden_states, _ = self.transformer_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                attention_bias=attention_bias,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
            )

        # Layer norm before shared head layer.
        hidden_states = self.final_layernorm(hidden_states)
        # TENorm produces a "viewed" tensor. This will result in schedule.py's
        # deallocate_output_tensor() throwing an error, so a viewless tensor is
        # created to prevent this.
        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

        return hidden_states
    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """
        Generate a sharded state dictionary for the multi token prediction layer.

        Args:
            prefix (str, optional): Prefix to be added to all keys in the state dict.
            sharded_offsets (tuple, optional): Tuple of sharding offsets.
            metadata (Optional[dict], optional): Additional metadata for sharding.

        Returns:
            ShardedStateDict: A dictionary containing the sharded state of the multi
            token prediction layer.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        return sharded_state_dict
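In shape terms, the projection step of the forward pass above reduces the concatenation back to the hidden size before the transformer layer; a toy check (plain torch, no tensor or sequence parallelism, sizes chosen arbitrarily):

s, b, h = 8, 2, 16                                            # toy sizes
normed_embedding = torch.zeros(s, b, h)                       # enorm(decoder_input): (i + K)-th token embeddings
normed_hidden = torch.zeros(s, b, h)                          # hnorm(hidden_states): previous depth's representation
combined = torch.cat((normed_embedding, normed_hidden), -1)   # [8, 2, 32], the input expected by eh_proj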
@dataclass
class MultiTokenPredictionBlockSubmodules:
    """
    Dataclass for specifying the submodules of a multi token prediction block.

    This class defines the structure for configuring the layers, allowing for
    flexible and customizable architecture designs.

    Args:
        layer_specs (List[ModuleSpec], optional): A list of module specifications for
            the layers within the multi token prediction block. Each specification typically
            defines a complete multi token prediction layer (e.g., shared embedding,
            projection matrix, transformer block, shared output head).
    """

    layer_specs: List[ModuleSpec] = None
def _get_mtp_block_submodules(
    config: TransformerConfig, spec: Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]
) -> MultiTokenPredictionBlockSubmodules:
    """
    Retrieve or construct MultiTokenPredictionBlockSubmodules based on the provided specification.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        spec (Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]): Specification for the
            multi token prediction block submodules.
            Can be either a MultiTokenPredictionBlockSubmodules instance or a ModuleSpec.

    Returns:
        MultiTokenPredictionBlockSubmodules: The submodules for the multi token prediction block.
    """

    # Transformer block submodules.
    if isinstance(spec, MultiTokenPredictionBlockSubmodules):
        return spec
    elif isinstance(spec, ModuleSpec):
        if issubclass(spec.module, MultiTokenPredictionBlock):
            return spec.submodules
        else:
            raise Exception(f"specialize for {spec.module.__name__}.")
    else:
        raise Exception(f"specialize for {type(spec).__name__}.")
class MultiTokenPredictionBlock(MegatronModule):
    """The implementation of Multi-Token Prediction (MTP), which extends
    the prediction scope to multiple future tokens at each position.

    This MTP implementation sequentially predicts additional tokens and keeps the complete
    causal chain at each prediction depth, by using D sequential modules to predict
    D additional tokens.

    The k-th MTP module consists of a shared embedding layer, a projection matrix,
    a Transformer block, and a shared output head.

    For the i-th input token at the (k - 1)-th prediction depth, we first combine
    the representation of the i-th token and the embedding of the (i + K)-th token with
    the linear projection. The combined representation serves as the input of the Transformer
    block at the k-th depth to produce the output representation.

    For more information, please refer to the DeepSeek-V3 Technical Report:
    https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf
    """
    def __init__(self, config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec]):
        super().__init__(config=config)
        self.submodules = _get_mtp_block_submodules(config, spec)
        self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor
        self._build_layers()
        assert len(self.layers) > 0, "MultiTokenPredictionBlock must have at least one layer."

    def _build_layers(self):
        def build_layer(layer_spec, layer_number):
            return build_module(layer_spec, config=self.config, layer_number=layer_number)

        self.layers = torch.nn.ModuleList(
            [build_layer(layer_spec, i + 1) for i, layer_spec in enumerate(self.submodules.layer_specs)]
        )
    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        hidden_states: Tensor,
        attention_mask: Tensor,
        labels: Tensor = None,
        context: Tensor = None,
        context_mask: Tensor = None,
        rotary_pos_emb: Tensor = None,
        rotary_pos_cos: Tensor = None,
        rotary_pos_sin: Tensor = None,
        attention_bias: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        sequence_len_offset: Tensor = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        loss_mask: Optional[Tensor] = None,
        embedding=None,
        output_layer=None,
        output_weight: Optional[torch.Tensor] = None,
        compute_language_model_loss=None,
    ) -> Tensor:
        """
        Perform the forward pass through all of the MTP modules.

        Args:
            hidden_states (Tensor): Hidden states for the input tokens with shape [s, b, h],
                where s is the sequence length, b is the batch size, and h is the hidden size.
            attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
                self-attention.

        Returns:
            (Tensor): The hidden states of the main model, with the scaled MTP losses of all
            depths attached for the backward pass via MTPLossAutoScaler.
        """
        assert (
            labels is not None
        ), f"labels should not be None for calculating multi token prediction loss."

        if loss_mask is None:
            # if loss_mask is not provided, use all ones as loss_mask
            loss_mask = torch.ones_like(labels)

        hidden_states_main_model = hidden_states
        for layer_number in range(len(self.layers)):
            # Calc logits for the current Multi-Token Prediction (MTP) layers.
            input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1)
            # embedding
            decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
            # norm, linear projection and transformer
            hidden_states = self.layers[layer_number](
                decoder_input=decoder_input,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                **(extra_block_kwargs or {}),
            )
            # output
            mtp_logits, _ = output_layer(
                hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
            )
            # Calc loss for the current Multi-Token Prediction (MTP) layers.
            labels, _ = roll_tensor(labels, shifts=-1, dims=-1)
            loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1)
            mtp_loss = compute_language_model_loss(labels, mtp_logits)
            mtp_loss = loss_mask * mtp_loss
            if self.training:
                MTPLossLoggingHelper.save_loss_to_tracker(
                    torch.sum(mtp_loss) / num_tokens,
                    layer_number,
                    self.config.mtp_num_layers,
                    avg_group=parallel_state.get_tensor_and_context_parallel_group(),
                )
            mtp_loss_scale = self.mtp_loss_scaling_factor / self.config.mtp_num_layers
            if self.config.calculate_per_token_loss:
                hidden_states_main_model = MTPLossAutoScaler.apply(
                    hidden_states_main_model, mtp_loss_scale * mtp_loss
                )
            else:
                hidden_states_main_model = MTPLossAutoScaler.apply(
                    hidden_states_main_model, mtp_loss_scale * mtp_loss / num_tokens
                )

        return hidden_states_main_model
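A worked toy example of the rolling scheme in the loop above, for a single sequence: at each extra depth, input_ids and labels shift left once more, so depth k is trained to predict the token k + 1 steps ahead of each position, and the zero-padded tail is masked out through the rolled loss_mask.

# Illustrative ids only.
input_ids = torch.tensor([[10, 11, 12, 13]])                 # main model inputs
labels = torch.tensor([[11, 12, 13, 14]])                    # main model next-token targets
input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1)    # [[11, 12, 13, 0]] at MTP depth 1
labels, _ = roll_tensor(labels, shifts=-1, dims=-1)          # [[12, 13, 14, 0]] at MTP depth 1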
    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """
        Generate a sharded state dictionary for the multi token prediction module.

        Args:
            prefix (str, optional): Prefix to be added to all keys in the state dict.
            sharded_offsets (tuple, optional): Tuple of sharding offsets.
            metadata (Optional[dict], optional): Additional metadata for sharding.

        Returns:
            ShardedStateDict: A dictionary containing the sharded state of the multi
            token prediction module.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        layer_prefix = f'{prefix}layers.'
        for layer in self.layers:
            offset = get_mtp_layer_offset(self.config)
            sharded_prefix = f'{layer_prefix}{layer.layer_number - 1}.'
            state_dict_prefix = f'{layer_prefix}{layer.layer_number - 1 - offset}.'
            sharded_pp_offset = []
            layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata)
            replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)
            sharded_state_dict.update(layer_sharded_state_dict)
        return sharded_state_dict
dcu_megatron/core/transformer/transformer_block.py
View file @
1f7b14ab
...
...
@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
        # mtp requires separate layernorms for the main model and the mtp modules, thus move the final norm out of the block
        config = args[0] if len(args) > 1 else kwargs['config']
        if getattr(config, "num_nextn_predict_layers", 0) > 0:
        if getattr(config, "mtp_num_layers", 0) > 0:
            self.main_final_layernorm = self.final_layernorm
            self.final_layernorm = None
...
...
dcu_megatron/core/transformer/transformer_config.py
View file @
1f7b14ab
from typing import Optional
from functools import wraps
from dataclasses import dataclass

from megatron.training import get_args
from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig


def transformer_config_post_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self):
        fn(self)
        args = get_args()

        """Number of Multi-Token Prediction (MTP) Layers."""
        self.mtp_num_layers = args.mtp_num_layers
        """Weighting factor of Multi-Token Prediction (MTP) loss."""
        self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor

        ##################
        # flux
        ##################
        self.flux_transpose_weight = args.flux_transpose_weight

    return wrapper


@dataclass
class ExtraTransformerConfig:
    ##################
    # multi-token prediction
    ##################
    num_nextn_predict_layers: int = 0
    """The number of multi-token prediction layers"""

    mtp_loss_scale: float = 0.3
    """Multi-token prediction loss scale"""

    recompute_mtp_norm: bool = False
    """Whether to recompute mtp normalization"""

    recompute_mtp_layer: bool = False
    """Whether to recompute mtp layer"""

    mtp_num_layers: Optional[int] = None
    """Number of Multi-Token Prediction (MTP) Layers."""

    share_mtp_embedding_and_output_weight: bool = False
    """Share embedding and output weight with mtp layer."""

    mtp_loss_scaling_factor: Optional[float] = None
    """Weighting factor of Multi-Token Prediction (MTP) loss."""

    ##################
    # flux
...
...
dcu_megatron/training/arguments.py
View file @
1f7b14ab
...
...
@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser):
def _add_mtp_args(parser):
    group = parser.add_argument_group(title='multi token prediction')
    group.add_argument('--num-nextn-predict-layers', type=int, default=0,
                       help='Multi-Token prediction layer num')
    group.add_argument('--mtp-loss-scale', type=float, default=0.3,
                       help='Multi-Token prediction loss scale')
    group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
                       help='Multi-Token prediction recompute norm')
    group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
                       help='Multi-Token prediction recompute layer')
    group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
                       help='Main model shares embedding and output weight with mtp layer.')
    group.add_argument('--mtp-num-layers', type=int, default=None,
                       help='Number of Multi-Token Prediction (MTP) Layers. '
                            'MTP extends the prediction scope to multiple future tokens at each position. '
                            'This MTP implementation sequentially predicts additional tokens '
                            'by using D sequential modules to predict D additional tokens.')
    group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
                       help='Scaling factor of Multi-Token Prediction (MTP) loss. '
                            'We compute the average of the MTP losses across all depths, '
                            'and multiply it by the scaling factor to obtain the overall MTP loss, '
                            'which serves as an additional training objective.')

    return parser
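A quick, hypothetical way to sanity-check the new flags in isolation (standalone argparse, not part of this commit):

import argparse

parser = argparse.ArgumentParser()
parser = _add_mtp_args(parser)
args = parser.parse_args(['--mtp-num-layers', '1', '--mtp-loss-scaling-factor', '0.1'])
assert args.mtp_num_layers == 1 and args.mtp_loss_scaling_factor == 0.1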
...
...
dcu_megatron/training/utils.py
View file @
1f7b14ab
...
...
@@ -9,103 +9,97 @@ def get_batch_on_this_tp_rank(data_iterator):
    args = get_args()

    def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())
        if item is not None:
            torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

    if mpu.get_tensor_model_parallel_rank() == 0:
        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None

        batch = {
            'tokens': data["tokens"].cuda(non_blocking=True),
            'labels': data["labels"].cuda(non_blocking=True),
            'loss_mask': data["loss_mask"].cuda(non_blocking=True),
            'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
            'position_ids': data["position_ids"].cuda(non_blocking=True)
        }

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
        elif mpu.is_pipeline_last_stage():
            if args.num_nextn_predict_layers:

        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None

        batch = {
            'tokens': data["tokens"].cuda(non_blocking=True),
            'labels': data["labels"].cuda(non_blocking=True),
            'loss_mask': data["loss_mask"].cuda(non_blocking=True),
            'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
            'position_ids': data["position_ids"].cuda(non_blocking=True)
        }

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
        elif mpu.is_pipeline_last_stage():
            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
            if args.mtp_num_layers is not None:
                _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            if args.reset_position_ids or args.num_nextn_predict_layers:
                _broadcast(batch['position_ids'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
    else:
        tokens = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.int64, device=torch.cuda.current_device())
        labels = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.int64, device=torch.cuda.current_device())
        loss_mask = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.float32, device=torch.cuda.current_device())
        if args.create_attention_mask_in_dataloader:
            attention_mask = torch.empty(
                (args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers, args.seq_length + args.num_nextn_predict_layers),
                dtype=torch.bool, device=torch.cuda.current_device()

        tokens = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device())
        labels = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device())
        loss_mask = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.float32, device=torch.cuda.current_device())
        if args.create_attention_mask_in_dataloader:
            attention_mask = torch.empty(
                (args.micro_batch_size, 1, args.seq_length, args.seq_length),
                dtype=torch.bool, device=torch.cuda.current_device()
            )
        else:
            attention_mask = None
        position_ids = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.int64, device=torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(position_ids)
        elif mpu.is_pipeline_first_stage():
            labels = None
            loss_mask = None
            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)
        elif mpu.is_pipeline_last_stage():
            if args.num_nextn_predict_layers:

        else:
            attention_mask = None
        position_ids = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(position_ids)
        elif mpu.is_pipeline_first_stage():
            labels = None
            loss_mask = None
            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)
        elif mpu.is_pipeline_last_stage():
            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
            if args.mtp_num_layers is not None:
                _broadcast(tokens)
            else:
                tokens = None
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            if args.reset_position_ids or args.num_nextn_predict_layers:
                _broadcast(position_ids)
            else:
                position_ids = None

            batch = {
                'tokens': tokens,
                'labels': labels,
                'loss_mask': loss_mask,
                'attention_mask': attention_mask,
                'position_ids': position_ids
            }
        else:
            tokens = None
            position_ids = None
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)

        batch = {
            'tokens': tokens,
            'labels': labels,
            'loss_mask': loss_mask,
            'attention_mask': attention_mask,
            'position_ids': position_ids
        }

    return batch
pretrain_gpt.py
View file @
1f7b14ab
...
...
@@ -39,9 +39,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from dcu_megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
from dcu_megatron.core.utils import tensor_slide
from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
from dcu_megatron import megatron_adaptor
...
...
@@ -133,13 +131,12 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
        raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")

    # Define the mtp layer spec
    if isinstance(transformer_layer_spec, TransformerBlockSubmodules):
        mtp_transformer_layer_spec = transformer_layer_spec.layer_specs[-1]
    else:
        mtp_transformer_layer_spec = transformer_layer_spec
    mtp_block_spec = None
    if args.mtp_num_layers is not None:
        from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
        mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)

    with build_model_context(**build_model_context_args):
        config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
...
...
@@ -153,7 +150,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base,
            rope_scaling=args.use_rope_scaling
            rope_scaling=args.use_rope_scaling,
            mtp_block_spec=mtp_block_spec,
        )
        # model = torch.compile(model, mode='max-autotune-no-cudagraphs')

    print_rank_0(model)
...
...
@@ -197,8 +195,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    args = get_args()

    losses = output_tensor.float()
    if getattr(args, "num_nextn_predict_layers", 0) > 0:
        loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
    loss_mask = loss_mask.view(-1).float()
    total_tokens = loss_mask.sum()
    loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
...
...
@@ -267,8 +263,12 @@ def forward_step(data_iterator, model: GPTModel):
    timers('batch-generator').stop()

    with stimer:
        output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        if args.use_legacy_models:
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        else:
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask)

    return output_tensor, partial(loss_func, loss_mask)
...
...
@@ -289,7 +289,7 @@ def core_gpt_dataset_config_from_args(args):
    return GPTDatasetConfig(
        random_seed=args.seed,
        sequence_length=args.seq_length + getattr(args, "num_nextn_predict_layers", 0),
        sequence_length=args.seq_length,
        blend=blend,
        blend_per_split=blend_per_split,
        split=args.split,
...
...