evt_fugx1 / dcu_megatron / Commits / 1f7b14ab

Commit 1f7b14ab ("rewrite mtp"), authored Apr 26, 2025 by sdwldchl
Parent: 89d29a02

Showing 15 changed files with 1429 additions and 616 deletions (+1429 / -616)
dcu_megatron/adaptor/megatron_adaptor.py                              +30   -32
dcu_megatron/adaptor/patch_utils.py                                   +21    -3
dcu_megatron/core/distributed/finalize_model_grads.py                  +6    -1
dcu_megatron/core/models/common/language_module/language_module.py     +6    -5
dcu_megatron/core/models/gpt/gpt_layer_specs.py                       +52    -3
dcu_megatron/core/models/gpt/gpt_model.py                            +383  -350
dcu_megatron/core/pipeline_parallel/schedules.py                      +52    -0
dcu_megatron/core/tensor_parallel/__init__.py                          +0    -2
dcu_megatron/core/tensor_parallel/layers.py                            +8   -96
dcu_megatron/core/transformer/multi_token_prediction.py              +737    -0
dcu_megatron/core/transformer/transformer_block.py                     +1    -1
dcu_megatron/core/transformer/transformer_config.py                   +27   -13
dcu_megatron/training/arguments.py                                    +10    -8
dcu_megatron/training/utils.py                                        +82   -88
pretrain_gpt.py                                                       +14   -14
dcu_megatron/adaptor/megatron_adaptor.py

@@ -5,6 +5,8 @@ import types
 import argparse
 import torch
+from .adaptor_arguments import get_adaptor_args


 class MegatronAdaptation:
     """

@@ -21,6 +23,15 @@ class MegatronAdaptation:
         for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
             adaptation.execute()
         MegatronAdaptation.apply()
+        # from .patch_utils import MegatronPatchesManager
+        # args = get_adaptor_args()
+        # for feature in FEATURES_LIST:
+        #     if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
+        #         feature.register_patches(MegatronPatchesManager, args)
+        # MindSpeedPatchesManager.apply_patches()

         # MegatronAdaptation.post_execute()

     @classmethod

@@ -87,38 +98,20 @@ class CoreAdaptation(MegatronAdaptationABC):
         self.patch_miscellaneous()

     def patch_core_distributed(self):
-        # Mtp share embedding
+        # mtp share embedding
         from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
         MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
                                     _allreduce_word_embedding_grads)

     def patch_core_models(self):
-        from ..core.models.common.embeddings.language_model_embedding import (
-            language_model_embedding_forward,
-            language_model_embedding_init_func
-        )
-        from ..core.models.gpt.gpt_model import (
-            gpt_model_forward,
-            gpt_model_init,
-            shared_embedding_or_output_weight,
-        )
         from ..core.models.common.language_module.language_module import (
             setup_embeddings_and_output_layer,
-            tie_embeddings_and_output_weights_state_dict
+            tie_embeddings_and_output_weights_state_dict,
         )
+        from ..core.models.gpt.gpt_model import GPTModel
         from ..training.utils import get_batch_on_this_tp_rank

-        # Embedding
-        MegatronAdaptation.register('megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
-                                    language_model_embedding_init_func)
-        MegatronAdaptation.register('megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
-                                    language_model_embedding_forward)
-        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank',
-                                    get_batch_on_this_tp_rank)
-
-        # GPT Model
+        # LanguageModule
         MegatronAdaptation.register(
             'megatron.core.models.common.language_module.language_module.LanguageModule.setup_embeddings_and_output_layer',
             setup_embeddings_and_output_layer)

@@ -126,17 +119,16 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register(
             'megatron.core.models.common.language_module.language_module.LanguageModule.tie_embeddings_and_output_weights_state_dict',
             tie_embeddings_and_output_weights_state_dict)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.shared_embedding_or_output_weight',
-                                    shared_embedding_or_output_weight)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward',
-                                    gpt_model_forward)
-        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
-                                    gpt_model_init)
+        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank',
+                                    get_batch_on_this_tp_rank)
+
+        # GPT Model
+        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel',
+                                    GPTModel)

     def patch_core_transformers(self):
         from ..core import transformer_block_init_wrapper
         from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

-        # Transformer block
+        # Transformer block. If mtp_num_layers > 0, move final_layernorm outside
         MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
                                     transformer_block_init_wrapper)

@@ -174,13 +166,10 @@ class CoreAdaptation(MegatronAdaptationABC):
     def patch_tensor_parallel(self):
         from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
-        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper

         # VocabParallelEmbedding
         MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
-                                    vocab_parallel_embedding_forward)
-        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
-                                    vocab_parallel_embedding_init_wrapper,
-                                    apply_wrapper=True)
+                                    torch.compile(mode='max-autotune-no-cudagraphs'),
+                                    apply_wrapper=True)

         # VocabParallelCrossEntropy

@@ -211,6 +200,14 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                     get_gpt_layer_with_flux_spec)

+    def patch_pipeline_parallel(self):
+        from ..core.pipeline_parallel.schedules import forward_step_wrapper
+
+        # pipeline_parallel.schedules.forward_step
+        MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step',
+                                    forward_step_wrapper,
+                                    apply_wrapper=True)
+
     def patch_training(self):
         from ..training.tokenizer import build_tokenizer
         from ..training.initialize import _initialize_distributed

@@ -255,6 +252,7 @@ class LegacyAdaptation(MegatronAdaptationABC):
                                     parallel_mlp_init_wrapper,
                                     apply_wrapper=True)

+        # ParallelAttention
         MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
                                     parallel_attention_init_wrapper,
                                     apply_wrapper=True)
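For orientation: the MegatronAdaptation.register(...) calls above all reduce to resolving a dotted import path and swapping the named attribute. The following is a minimal sketch of that pattern, not the repository's actual Patch/MegatronAdaptation implementation; the apply_wrapper semantics (call the replacement with the original) are inferred from how forward_step_wrapper and torch.compile(...) are registered in this commit.

    import importlib


    def apply_patch(dotted_path: str, replacement, apply_wrapper: bool = False) -> None:
        """Monkey-patch `pkg.module[.Class].attr` with `replacement` (illustrative only)."""
        parts = dotted_path.split('.')
        owner = None
        # Find the longest importable module prefix, e.g. 'megatron.training.utils'.
        for i in range(len(parts) - 1, 0, -1):
            try:
                owner = importlib.import_module('.'.join(parts[:i]))
                break
            except ImportError:
                continue
        if owner is None:
            raise ImportError(f"cannot import any prefix of {dotted_path}")
        # Walk the remaining attributes except the last one (e.g. a class name).
        for name in parts[i:-1]:
            owner = getattr(owner, name)
        attr = parts[-1]
        original = getattr(owner, attr, None)
        # apply_wrapper=True treats the replacement as a decorator over the original.
        setattr(owner, attr, replacement(original) if apply_wrapper else replacement)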
dcu_megatron/adaptor/patch_utils.py

@@ -148,11 +148,29 @@ class MegatronPatchesManager:
     patches_info = {}

     @staticmethod
-    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
+    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False,
+                       apply_wrapper=False, remove_origin_wrappers=False):
         if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
-            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name, new_func_or_cls, create_dummy)
+            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
+                orig_func_or_cls_name, new_func_or_cls, create_dummy,
+                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)
         else:
-            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
+            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
+                new_func_or_cls, force_patch,
+                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)

     @staticmethod
     def apply_patches():
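The two new keyword arguments distinguish "replace the target" from "wrap the target". The Patch class itself is not part of this diff; a hedged sketch of what the distinction presumably amounts to, using only the names from the register_patch() signature above:

    def resolve_patched_object(orig_func, new_func, apply_wrapper=False, remove_origin_wrappers=False):
        if remove_origin_wrappers:
            # Unwind decorators that recorded their target in __wrapped__ (functools.wraps does this).
            while hasattr(orig_func, '__wrapped__'):
                orig_func = orig_func.__wrapped__
        if apply_wrapper:
            # new_func is a decorator such as forward_step_wrapper or torch.compile(...):
            # calling it with the original keeps the original behaviour inside the wrapper.
            return new_func(orig_func)
        # Otherwise new_func simply replaces the original.
        return new_func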
dcu_megatron/core/distributed/finalize_model_grads.py

@@ -28,7 +28,12 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
     model_module = model[0]
     model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
-    if model_module.share_embeddings_and_output_weights or getattr(config, 'num_nextn_predict_layers', 0):
+    # If share_embeddings_and_output_weights is True, we need to maintain duplicated
+    # embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
+    # we also need to maintain duplicated embedding weights in mtp process stage.
+    # So we need to allreduce grads of embedding in the embedding group in these cases.
+    if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
         weight = model_module.shared_embedding_or_output_weight()
         grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
         orig_grad = getattr(weight, grad_attr)
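A sketch of roughly what the all-reduce under that condition amounts to, for readers unfamiliar with the embedding group; the parallel_state helper names come from megatron.core, the function itself is illustrative and not the patched code:

    import torch
    from megatron.core import parallel_state


    def allreduce_shared_embedding_grad(weight: torch.nn.Parameter) -> None:
        if parallel_state.is_rank_in_embedding_group(ignore_virtual=True):
            grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
            grad = getattr(weight, grad_attr)
            # First and last pipeline stages (and the MTP stage) hold duplicated embedding
            # weights, so their gradients are summed to keep the copies in sync.
            torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())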
dcu_megatron/core/models/common/language_module/language_module.py

@@ -4,6 +4,7 @@ import torch
 from megatron.core import parallel_state
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+from megatron.core.models.common.language_module.language_module import LanguageModule
 from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint

@@ -27,7 +28,7 @@ def setup_embeddings_and_output_layer(self) -> None:
     # So we need to copy embedding weights from pre processing stage as initial parameters
     # in these cases.
-    if not self.share_embeddings_and_output_weights and not getattr(self.config, 'num_nextn_predict_layers', 0):
+    if not self.share_embeddings_and_output_weights and not getattr(self.config, 'mtp_num_layers', 0):
         return

@@ -41,10 +42,10 @@ def setup_embeddings_and_output_layer(self) -> None:
     if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process:
         self.shared_embedding_or_output_weight().shared_embedding = True

-    if self.post_process and not self.pre_process:
+    if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process:
         assert not parallel_state.is_pipeline_first_stage()
-        # set word_embeddings weights to 0 here, then copy first
-        # stage's weights using all_reduce below.
+        # set weights of the duplicated embedding to 0 here,
+        # then copy weights from pre processing stage using all_reduce below.
         weight = self.shared_embedding_or_output_weight()
         weight.data.fill_(0)
         weight.shared = True

@@ -114,7 +115,7 @@ def tie_embeddings_and_output_weights_state_dict(
     # layer in mtp process stage. In this case, if share_embeddings_and_output_weights is True,
     # the shared weights will be stored in embedding layer, and output layer will not have
     # any weight.
-    if self.post_process and getattr(self, 'num_nextn_predict_layers', False):
+    if getattr(self, 'mtp_process', False):
         # No output layer
         assert output_layer_weight_key not in sharded_state_dict, sharded_state_dict.keys()
         return
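Why the duplicated embedding is zero-filled before the all-reduce: summing a zero tensor with the first stage's weights leaves every duplicate holding an exact copy. A tiny self-contained illustration (not repository code):

    import torch

    first_stage_weight = torch.randn(8, 4)                    # weights owned by the pre-processing stage
    duplicate_weight = torch.zeros_like(first_stage_weight)   # zero-filled duplicate on the last/MTP stage

    # all_reduce(sum) over the embedding group is element-wise equivalent to this:
    synced = first_stage_weight + duplicate_weight
    assert torch.equal(synced, first_stage_weight)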
dcu_megatron/core/models/gpt/gpt_layer_specs.py

 import warnings
-from typing import Optional
+from typing import Optional, Union

 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec

@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
     MLASelfAttentionSubmodules,
 )
 from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
+from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import (
     TransformerLayer,
     TransformerLayerSubmodules,
 )
-from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
 from megatron.core.utils import is_te_min_version

 try:

@@ -36,6 +36,55 @@ try:
 except ImportError:
     warnings.warn('Apex is not installed.')

+from dcu_megatron.core.tensor_parallel.layers import (
+    FluxColumnParallelLinear,
+    FluxRowParallelLinear
+)
+from dcu_megatron.core.transformer.multi_token_prediction import (
+    MultiTokenPredictionBlockSubmodules,
+    get_mtp_layer_offset,
+    get_mtp_layer_spec,
+    get_mtp_num_layers_to_build,
+)
+
+
+def get_gpt_mtp_block_spec(
+    config: TransformerConfig,
+    spec: Union[TransformerBlockSubmodules, ModuleSpec],
+    use_transformer_engine: bool,
+) -> MultiTokenPredictionBlockSubmodules:
+    """GPT Multi-Token Prediction (MTP) block spec."""
+    num_layers_to_build = get_mtp_num_layers_to_build(config)
+    if num_layers_to_build == 0:
+        return None
+
+    if isinstance(spec, TransformerBlockSubmodules):
+        # get the spec for the last layer of decoder block
+        transformer_layer_spec = spec.layer_specs[-1]
+    elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer:
+        transformer_layer_spec = spec
+    else:
+        raise ValueError(f"Invalid spec: {spec}")
+
+    mtp_layer_spec = get_mtp_layer_spec(
+        transformer_layer_spec=transformer_layer_spec, use_transformer_engine=use_transformer_engine
+    )
+    mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0
+    mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers
+
+    offset = get_mtp_layer_offset(config)
+    # split the mtp layer specs to only include the layers that are built in this pipeline stage.
+    mtp_layer_specs = mtp_layer_specs[offset : offset + num_layers_to_build]
+    if len(mtp_layer_specs) > 0:
+        assert (
+            len(mtp_layer_specs) == config.mtp_num_layers
+        ), f"currently all of the mtp layers must stage in the same pipeline stage."
+        mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs)
+    else:
+        mtp_block_spec = None
+
+    return mtp_block_spec
+
+
 def get_gpt_layer_with_flux_spec(
     num_experts: Optional[int] = None,
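Usage sketch of the new helper, mirroring the pretrain_gpt.py change later in this commit (config, transformer_layer_spec and use_te are assumed to exist already):

    mtp_block_spec = None
    if config.mtp_num_layers is not None:
        mtp_block_spec = get_gpt_mtp_block_spec(
            config, transformer_layer_spec, use_transformer_engine=use_te
        )
    # mtp_block_spec is then handed to GPTModel(..., mtp_block_spec=mtp_block_spec);
    # a value of None keeps the model free of MTP layers.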
dcu_megatron/core/models/gpt/gpt_model.py
(diff collapsed; not shown)
dcu_megatron/core/pipeline_parallel/schedules.py (new file, 0 → 100644)

import torch
from functools import wraps

from dcu_megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler


def forward_step_wrapper(fn):
    @wraps(fn)
    def wrapper(
        forward_step_func,
        data_iterator,
        model,
        num_microbatches,
        input_tensor,
        forward_data_store,
        config,
        **kwargs,
    ):
        output, num_tokens = fn(
            forward_step_func,
            data_iterator,
            model,
            num_microbatches,
            input_tensor,
            forward_data_store,
            config,
            **kwargs)

        if not isinstance(input_tensor, list):
            # unwrap_output_tensor True
            output_tensor = output
        else:
            output_tensor = output[0]

        # Set the loss scale for Multi-Token Prediction (MTP) loss.
        if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
            # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
            loss_scale = (
                config.grad_scale_func(torch.ones(1, device=output_tensor.device))
                if config.grad_scale_func is not None
                else torch.ones(1, device=output_tensor.device)
            )
            # Set the loss scale
            if config.calculate_per_token_loss:
                MTPLossAutoScaler.set_loss_scale(loss_scale)
            else:
                MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

        return output, num_tokens

    return wrapper
\ No newline at end of file
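The real MTPLossAutoScaler lives in the (collapsed) multi_token_prediction.py diff above. For orientation only, a minimal sketch of the pattern such a scaler typically follows, so the set_loss_scale() calls in the wrapper are easier to place; this is an assumption, not the file's actual contents:

    import torch


    class _MTPLossAutoScalerSketch(torch.autograd.Function):
        main_loss_backward_scale = torch.tensor(1.0)

        @staticmethod
        def forward(ctx, output, mtp_loss):
            ctx.save_for_backward(mtp_loss)
            return output  # identity on the main output in the forward pass

        @staticmethod
        def backward(ctx, grad_output):
            (mtp_loss,) = ctx.saved_tensors
            # Scale the gradient flowing into the auxiliary MTP loss with the value
            # installed by set_loss_scale() in forward_step_wrapper above.
            scale = _MTPLossAutoScalerSketch.main_loss_backward_scale
            return grad_output, torch.ones_like(mtp_loss) * scale

        @staticmethod
        def set_loss_scale(scale: torch.Tensor):
            _MTPLossAutoScalerSketch.main_loss_backward_scale = scale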
dcu_megatron/core/tensor_parallel/__init__.py

 from .layers import (
     FluxColumnParallelLinear,
     FluxRowParallelLinear,
-    vocab_parallel_embedding_forward,
-    vocab_parallel_embedding_init_wrapper,
 )
\ No newline at end of file
dcu_megatron/core/tensor_parallel/layers.py

 import os
-import copy
 import socket
 import warnings
-from functools import wraps
 from typing import Callable, List, Optional

-try:
-    import flux
-except ImportError:
-    raise ImportError("flux is NOT installed")
+if int(os.getenv("USE_FLUX_OVERLAP", "0")):
+    try:
+        import flux
+        from dcu_megatron.core.utils import is_flux_min_version
+    except ImportError:
+        raise ImportError("flux is NOT installed")

 import torch
-import torch.nn.functional as F
-from torch.nn.parameter import Parameter
-from megatron.training import print_rank_0
 from megatron.core.model_parallel_config import ModelParallelConfig
 from megatron.core.parallel_state import (
     get_global_memory_buffer,
     get_tensor_model_parallel_group,
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from megatron.core.utils import (
-    is_torch_min_version,
-    prepare_input_tensors_for_wgrad_compute
-)
-from megatron.core.tensor_parallel.layers import (
-    _initialize_affine_weight_cpu,
-    _initialize_affine_weight_gpu,
-    VocabParallelEmbedding,
-)
+from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
 from megatron.core.tensor_parallel.mappings import (
+    _reduce,
     copy_to_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
-    reduce_scatter_to_sequence_parallel_region,
-    _reduce_scatter_along_first_dim,
-    _gather_along_first_dim,
 )
-from megatron.core.tensor_parallel.utils import VocabUtility
-from megatron.core.tensor_parallel.mappings import _reduce
 from megatron.core.tensor_parallel import (
     ColumnParallelLinear,
     RowParallelLinear,

@@ -50,9 +30,9 @@ from megatron.core.tensor_parallel.layers import (
     custom_fwd,
     custom_bwd,
     dist_all_gather_func,
-    linear_with_frozen_weight,
+    linear_with_grad_accumulation_and_async_allreduce
 )
-from dcu_megatron.core.utils import is_flux_min_version

 _grad_accum_fusion_available = True
 try:

@@ -61,74 +41,6 @@ except ImportError:
     _grad_accum_fusion_available = False


-def vocab_parallel_embedding_init_wrapper(fn):
-    @wraps(fn)
-    def wrapper(self, *args, skip_weight_param_allocation: bool = False, **kwargs):
-        if (
-            skip_weight_param_allocation
-            and "config" in kwargs
-            and hasattr(kwargs["config"], "perform_initialization")
-        ):
-            config = copy.deepcopy(kwargs["config"])
-            config.perform_initialization = False
-            kwargs["config"] = config
-        fn(self, *args, **kwargs)
-        if skip_weight_param_allocation:
-            self.weight = None
-    return wrapper
-
-
-@torch.compile(mode='max-autotune-no-cudagraphs')
-def vocab_parallel_embedding_forward(self, input_, weight=None):
-    """Forward.
-
-    Args:
-        input_ (torch.Tensor): Input tensor.
-    """
-    if weight is None:
-        if self.weight is None:
-            raise RuntimeError(
-                "weight was not supplied to VocabParallelEmbedding forward pass "
-                "and skip_weight_param_allocation is True."
-            )
-        weight = self.weight
-
-    if self.tensor_model_parallel_size > 1:
-        # Build the mask.
-        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
-        # Mask the input.
-        masked_input = input_.clone() - self.vocab_start_index
-        masked_input[input_mask] = 0
-    else:
-        masked_input = input_
-    # Get the embeddings.
-    if self.deterministic_mode:
-        output_parallel = weight[masked_input]
-    else:
-        # F.embedding currently has a non-deterministic backward function
-        output_parallel = F.embedding(masked_input, weight)
-    # Mask the output embedding.
-    if self.tensor_model_parallel_size > 1:
-        output_parallel[input_mask, :] = 0.0
-
-    if self.reduce_scatter_embeddings:
-        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
-        output_parallel = output_parallel.transpose(0, 1).contiguous()
-        output = reduce_scatter_to_sequence_parallel_region(output_parallel)
-    else:
-        # Reduce across all the model parallel GPUs.
-        output = reduce_from_tensor_model_parallel_region(output_parallel)
-    return output
-
-
 def get_tensor_model_parallel_node_size(group=None):
     """Get the number of nodes.
     """
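With the custom vocab_parallel_embedding_forward removed, the adaptor now registers torch.compile(mode='max-autotune-no-cudagraphs') with apply_wrapper=True for VocabParallelEmbedding.forward. Assuming apply_wrapper means the registered object is called with the original function, this is equivalent to the following direct wrapping (shown outside the registry for clarity; illustrative, not repository code):

    import torch
    from megatron.core.tensor_parallel.layers import VocabParallelEmbedding

    # Keep Megatron's stock forward, but run it through torch.compile.
    VocabParallelEmbedding.forward = torch.compile(mode='max-autotune-no-cudagraphs')(
        VocabParallelEmbedding.forward
    )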
dcu_megatron/core/transformer/multi_token_prediction.py (new file, 0 → 100755)
(diff collapsed; not shown)
dcu_megatron/core/transformer/transformer_block.py

@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
         # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
         config = args[0] if len(args) > 1 else kwargs['config']
-        if getattr(config, "num_nextn_predict_layers", 0) > 0:
+        if getattr(config, "mtp_num_layers", 0) > 0:
             self.main_final_layernorm = self.final_layernorm
             self.final_layernorm = None
dcu_megatron/core/transformer/transformer_config.py

+from typing import Optional
+from functools import wraps
 from dataclasses import dataclass

+from megatron.training import get_args
 from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig


+def transformer_config_post_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self):
+        fn(self)
+        args = get_args()
+
+        """Number of Multi-Token Prediction (MTP) Layers."""
+        self.mtp_num_layers = args.mtp_num_layers
+
+        """Weighting factor of Multi-Token Prediction (MTP) loss."""
+        self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor
+
+        ##################
+        # flux
+        ##################
+        self.flux_transpose_weight = args.flux_transpose_weight
+
+    return wrapper
+
+
 @dataclass
 class ExtraTransformerConfig:
     ##################
     # multi-token prediction
     ##################
-    num_nextn_predict_layers: int = 0
-    """The number of multi-token prediction layers"""
-
-    mtp_loss_scale: float = 0.3
-    """Multi-token prediction loss scale"""
-
-    recompute_mtp_norm: bool = False
-    """Whether to recompute mtp normalization"""
-
-    recompute_mtp_layer: bool = False
-    """Whether to recompute mtp layer"""
-
-    share_mtp_embedding_and_output_weight: bool = False
-    """share embedding and output weight with mtp layer."""
+    mtp_num_layers: Optional[int] = None
+    """Number of Multi-Token Prediction (MTP) Layers."""
+
+    mtp_loss_scaling_factor: Optional[float] = None
+    """Weighting factor of Multi-Token Prediction (MTP) loss."""

     ##################
     # flux
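As far as this hunk shows, the two pieces are meant to combine as follows: ExtraTransformerConfig contributes the new fields with defaults, and transformer_config_post_init_wrapper is applied to the config's __post_init__ so the parsed CLI values land on every config instance. A hedged sketch only; the actual TransformerConfigPatch/MLATransformerConfigPatch classes imported by the adaptor are outside the shown lines, and the names below assume the definitions above are in scope:

    from dataclasses import dataclass
    from megatron.core.transformer.transformer_config import TransformerConfig


    @dataclass
    class PatchedTransformerConfig(ExtraTransformerConfig, TransformerConfig):
        """TransformerConfig extended with the MTP/flux fields defined above."""


    # Applying the wrapper by hand; the adaptor does this via
    # MegatronAdaptation.register(..., apply_wrapper=True).
    PatchedTransformerConfig.__post_init__ = transformer_config_post_init_wrapper(
        TransformerConfig.__post_init__
    )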
dcu_megatron/training/arguments.py

@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser):
 def _add_mtp_args(parser):
     group = parser.add_argument_group(title='multi token prediction')
-    group.add_argument('--num-nextn-predict-layers', type=int, default=0,
-                       help='Multi-Token prediction layer num')
-    group.add_argument('--mtp-loss-scale', type=float, default=0.3,
-                       help='Multi-Token prediction loss scale')
-    group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
-                       help='Multi-Token prediction recompute norm')
-    group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
-                       help='Multi-Token prediction recompute layer')
-    group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
-                       help='Main model share embedding and output weight with mtp layer.')
+    group.add_argument('--mtp-num-layers', type=int, default=None,
+                       help='Number of Multi-Token Prediction (MTP) Layers. '
+                       'MTP extends the prediction scope to multiple future tokens at each position. '
+                       'This MTP implementation sequentially predict additional tokens '
+                       'by using D sequential modules to predict D additional tokens.')
+    group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
+                       help='Scaling factor of Multi-Token Prediction (MTP) loss. '
+                       'We compute the average of the MTP losses across all depths, '
+                       'and multiply it the scaling factor to obtain the overall MTP loss, '
+                       'which serves as an additional training objective.')
     return parser
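In practice an MTP run would now pass, for example, --mtp-num-layers 1 --mtp-loss-scaling-factor 0.3 to pretrain_gpt.py (example values); per this hunk, the old --num-nextn-predict-layers, --mtp-loss-scale, --recompute-mtp-norm, --recompute-mtp-layer and --share-mtp-embedding-and-output-weight switches no longer exist.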
dcu_megatron/training/utils.py

@@ -9,103 +9,97 @@ def get_batch_on_this_tp_rank(data_iterator):
     args = get_args()

     def _broadcast(item):
         if item is not None:
             torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group())

     if mpu.get_tensor_model_parallel_rank() == 0:
         if data_iterator is not None:
             data = next(data_iterator)
         else:
             data = None

         batch = {
             'tokens': data["tokens"].cuda(non_blocking=True),
             'labels': data["labels"].cuda(non_blocking=True),
             'loss_mask': data["loss_mask"].cuda(non_blocking=True),
             'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
             'position_ids': data["position_ids"].cuda(non_blocking=True)
         }

         if args.pipeline_model_parallel_size == 1:
             _broadcast(batch['tokens'])
             _broadcast(batch['labels'])
             _broadcast(batch['loss_mask'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])

         elif mpu.is_pipeline_first_stage():
             _broadcast(batch['tokens'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])

         elif mpu.is_pipeline_last_stage():
-            if args.num_nextn_predict_layers:
+            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
+            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
+            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
+            if args.mtp_num_layers is not None:
                 _broadcast(batch['tokens'])
+                _broadcast(batch['position_ids'])
             _broadcast(batch['labels'])
             _broadcast(batch['loss_mask'])
             _broadcast(batch['attention_mask'])
-            if args.reset_position_ids or args.num_nextn_predict_layers:
-                _broadcast(batch['position_ids'])

     else:
-        tokens = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.int64,
+        tokens = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
                              device=torch.cuda.current_device())
-        labels = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.int64,
+        labels = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
                              device=torch.cuda.current_device())
-        loss_mask = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.float32,
+        loss_mask = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.float32,
                                 device=torch.cuda.current_device())
         if args.create_attention_mask_in_dataloader:
             attention_mask = torch.empty(
-                (args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers, args.seq_length + args.num_nextn_predict_layers), dtype=torch.bool,
+                (args.micro_batch_size, 1, args.seq_length, args.seq_length), dtype=torch.bool,
                 device=torch.cuda.current_device()
             )
         else:
             attention_mask = None
-        position_ids = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), dtype=torch.int64,
+        position_ids = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
                                    device=torch.cuda.current_device())

         if args.pipeline_model_parallel_size == 1:
             _broadcast(tokens)
             _broadcast(labels)
             _broadcast(loss_mask)
             _broadcast(attention_mask)
             _broadcast(position_ids)

         elif mpu.is_pipeline_first_stage():
             labels = None
             loss_mask = None

             _broadcast(tokens)
             _broadcast(attention_mask)
             _broadcast(position_ids)

         elif mpu.is_pipeline_last_stage():
-            if args.num_nextn_predict_layers:
+            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
+            # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need
+            # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
+            if args.mtp_num_layers is not None:
                 _broadcast(tokens)
+                _broadcast(position_ids)
             else:
                 tokens = None
+                position_ids = None

             _broadcast(labels)
             _broadcast(loss_mask)
             _broadcast(attention_mask)
-            if args.reset_position_ids or args.num_nextn_predict_layers:
-                _broadcast(position_ids)
-            else:
-                position_ids = None

         batch = {
             'tokens': tokens,
             'labels': labels,
             'loss_mask': loss_mask,
             'attention_mask': attention_mask,
             'position_ids': position_ids
         }

     return batch
pretrain_gpt.py

@@ -39,9 +39,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_layer_with_transformer_engine_spec,
 )
-from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
-from dcu_megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
-from dcu_megatron.core.utils import tensor_slide
+from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
 from dcu_megatron import megatron_adaptor

@@ -133,13 +131,12 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
             raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")

     # Define the mtp layer spec
-    if isinstance(transformer_layer_spec, TransformerBlockSubmodules):
-        mtp_transformer_layer_spec = transformer_layer_spec.layer_specs[-1]
-    else:
-        mtp_transformer_layer_spec = transformer_layer_spec
+    mtp_block_spec = None
+    if args.mtp_num_layers is not None:
+        from dcu_megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
+        mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)

     with build_model_context(**build_model_context_args):
-        config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
         model = GPTModel(
             config=config,
             transformer_layer_spec=transformer_layer_spec,

@@ -153,7 +150,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
             position_embedding_type=args.position_embedding_type,
             rotary_percent=args.rotary_percent,
             rotary_base=args.rotary_base,
-            rope_scaling=args.use_rope_scaling
+            rope_scaling=args.use_rope_scaling,
+            mtp_block_spec=mtp_block_spec,
         )
         # model = torch.compile(model,mode='max-autotune-no-cudagraphs')
         print_rank_0(model)

@@ -197,8 +195,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
     args = get_args()

     losses = output_tensor.float()
-    if getattr(args, "num_nextn_predict_layers", 0) > 0:
-        loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
     loss_mask = loss_mask.view(-1).float()
     total_tokens = loss_mask.sum()
     loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

@@ -267,8 +263,12 @@ def forward_step(data_iterator, model: GPTModel):
     timers('batch-generator').stop()

     with stimer:
-        output_tensor = model(tokens, position_ids, attention_mask,
-                              labels=labels)
+        if args.use_legacy_models:
+            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
+        else:
+            output_tensor = model(tokens, position_ids, attention_mask, labels=labels,
+                                  loss_mask=loss_mask)

     return output_tensor, partial(loss_func, loss_mask)

@@ -289,7 +289,7 @@ def core_gpt_dataset_config_from_args(args):
     return GPTDatasetConfig(
         random_seed=args.seed,
-        sequence_length=args.seq_length + getattr(args, "num_nextn_predict_layers", 0),
+        sequence_length=args.seq_length,
         blend=blend,
         blend_per_split=blend_per_split,
         split=args.split,