evt_fugx1 / dcu_megatron — commit a8a2bbea
Authored Apr 29, 2025 by dongcl
Parent: 2ddbd4be

    patch for megatron 4429e8ebe

Showing 20 changed files with 424 additions and 1310 deletions (+424 / -1310)
Changed files:

  dcu_megatron/adaptor/megatron_adaptor.py                                  +18   -38
  dcu_megatron/adaptor/patch_utils.py                                       +21   -3
  dcu_megatron/core/distributed/finalize_model_grads.py                     +0    -49
  dcu_megatron/core/extensions/transformer_engine.py                        +27   -4
  dcu_megatron/core/models/common/embeddings/language_model_embedding.py    +0    -133
  dcu_megatron/core/models/gpt/gpt_layer_specs.py                           +14   -3
  dcu_megatron/core/models/gpt/gpt_model.py                                 +45   -228
  dcu_megatron/core/tensor_parallel/__init__.py                             +0    -2
  dcu_megatron/core/tensor_parallel/layers.py                               +8    -96
  dcu_megatron/core/tensor_parallel/random.py                               +0    -104
  dcu_megatron/core/transformer/mtp/mtp_spec.py                             +0    -51
  dcu_megatron/core/transformer/mtp/multi_token_predictor.py                +0    -286
  dcu_megatron/core/transformer/transformer_block.py                        +1    -1
  dcu_megatron/core/transformer/transformer_config.py                       +20   -15
  dcu_megatron/core/utils.py                                                +0    -30
  dcu_megatron/training/arguments.py                                        +10   -8
  dcu_megatron/training/initialize.py                                       +73   -75
  dcu_megatron/training/training.py                                         +172  -73
  dcu_megatron/training/utils.py                                            +0    -111
  examples/deepseek_v3/run_deepseek_v3_1node.sh                             +15   -0
dcu_megatron/adaptor/megatron_adaptor.py:

@@ -21,7 +21,6 @@ class MegatronAdaptation:
        for adaptation in [CoreAdaptation(), LegacyAdaptation()]:
            adaptation.execute()
        MegatronAdaptation.apply()
        # MegatronAdaptation.post_execute()

    @classmethod
    def register(cls, orig_func_name, new_func=None, force_patch=False, create_dummy=False,
                 apply_wrapper=False, remove_origin_wrappers=False):

@@ -87,47 +86,23 @@ class CoreAdaptation(MegatronAdaptationABC):
        self.patch_miscellaneous()

    def patch_core_distributed(self):
        # Mtp share embedding
        from ..core.distributed.finalize_model_grads import _allreduce_word_embedding_grads
        MegatronAdaptation.register(
            'megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
            _allreduce_word_embedding_grads)
        pass

    def patch_core_models(self):
        from ..core.models.common.embeddings.language_model_embedding import (
            language_model_embedding_forward,
            language_model_embedding_init_func)
        from ..core.models.gpt.gpt_model import (
            gpt_model_forward,
            gpt_model_init_wrapper,
            shared_embedding_or_mtp_embedding_weight)
        from ..training.utils import get_batch_on_this_tp_rank

        # Embedding
        MegatronAdaptation.register(
            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__',
            language_model_embedding_init_func)
        MegatronAdaptation.register(
            'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward',
            language_model_embedding_forward)
        MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank',
                                    get_batch_on_this_tp_rank)
        from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, gpt_model_forward

        # GPT Model
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__',
                                    gpt_model_init_wrapper, apply_wrapper=True)
        from megatron.core.models.gpt.gpt_model import GPTModel
        setattr(GPTModel, 'shared_embedding_or_mtp_embedding_weight', shared_embedding_or_mtp_embedding_weight)
        MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward)

    def patch_core_transformers(self):
        from ..core import transformer_block_init_wrapper
        from ..core.transformer.transformer_config import TransformerConfigPatch, MLATransformerConfigPatch

        # Transformer block
        # Transformer block. If mtp_num_layers > 0, move final_layernorm outside
        MegatronAdaptation.register('megatron.core.transformer.transformer_block.TransformerBlock.__init__',
                                    transformer_block_init_wrapper)

@@ -141,9 +116,9 @@ class CoreAdaptation(MegatronAdaptationABC):
        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False}),
                                    apply_wrapper=True)
        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
                                    torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False,
                                                           "triton.cudagraph_support_input_mutation": True}),
                                    apply_wrapper=True)
        # MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
        #                             torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation": True}),
        #                             apply_wrapper=True)
        MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute',
                                    torch.compile(mode='max-autotune-no-cudagraphs'), apply_wrapper=True)

@@ -157,6 +132,7 @@ class CoreAdaptation(MegatronAdaptationABC):
        from ..core.extensions.transformer_engine import TEDotProductAttentionPatch
        from megatron.core.extensions.transformer_engine import TEGroupedLinear

        # kv channels, te_min_version 1.10.0 -> 1.9.0
        MegatronAdaptation.register('megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__',
                                    TEDotProductAttentionPatch.__init__)

@@ -165,13 +141,10 @@ class CoreAdaptation(MegatronAdaptationABC):
    def patch_tensor_parallel(self):
        from ..core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
        from ..core.tensor_parallel import vocab_parallel_embedding_forward, vocab_parallel_embedding_init_wrapper

        # VocabParallelEmbedding
        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward',
                                    vocab_parallel_embedding_forward)
        MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__',
                                    vocab_parallel_embedding_init_wrapper,
                                    torch.compile(mode='max-autotune-no-cudagraphs'), apply_wrapper=True)
        # VocabParallelCrossEntropy

@@ -202,6 +175,9 @@ class CoreAdaptation(MegatronAdaptationABC):
        MegatronAdaptation.register("megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec",
                                    get_gpt_layer_with_flux_spec)

    def patch_pipeline_parallel(self):
        pass

    def patch_training(self):
        from ..training.tokenizer import build_tokenizer
        from ..training.initialize import _initialize_distributed

@@ -210,12 +186,14 @@ class CoreAdaptation(MegatronAdaptationABC):
        MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer', build_tokenizer)
        # specify init_method
        MegatronAdaptation.register('megatron.training.initialize._initialize_distributed', _initialize_distributed)
        # remove fused_kernels
        MegatronAdaptation.register('megatron.training.initialize._compile_dependencies', _compile_dependencies)
        # traing.train
        # add trace_handler
        MegatronAdaptation.register('megatron.training.training.train', train)

@@ -245,6 +223,8 @@ class LegacyAdaptation(MegatronAdaptationABC):
        MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelMLP.__init__',
                                    parallel_mlp_init_wrapper, apply_wrapper=True)
        # ParallelAttention
        MegatronAdaptation.register('megatron.legacy.model.transformer.ParallelAttention.__init__',
                                    parallel_attention_init_wrapper, apply_wrapper=True)
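Note: MegatronAdaptation.register takes a dotted import path plus a replacement (or wrapper) callable; the registry and Patch internals are not part of this commit. The effect of applying such a registration is roughly equivalent to resolving the path and overwriting the final attribute, as in this hedged, illustrative sketch (apply_patch is not the project's API):

import importlib

def apply_patch(dotted_name: str, replacement):
    """Resolve 'pkg.module.Class.attr' and overwrite the last attribute with `replacement`."""
    parts = dotted_name.split('.')
    # Find the longest importable module prefix, then walk the remaining attributes.
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module('.'.join(parts[:i]))
            break
        except ModuleNotFoundError:
            continue
    else:
        raise ModuleNotFoundError(dotted_name)
    for attr in parts[i:-1]:
        obj = getattr(obj, attr)
    setattr(obj, parts[-1], replacement)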
dcu_megatron/adaptor/patch_utils.py:

@@ -148,11 +148,29 @@ class MegatronPatchesManager:
    patches_info = {}

    @staticmethod
    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False):
    def register_patch(orig_func_or_cls_name, new_func_or_cls=None, force_patch=False, create_dummy=False,
                       apply_wrapper=False, remove_origin_wrappers=False):
        if orig_func_or_cls_name not in MegatronPatchesManager.patches_info:
            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(orig_func_or_cls_name, new_func_or_cls, create_dummy)
            MegatronPatchesManager.patches_info[orig_func_or_cls_name] = Patch(
                orig_func_or_cls_name, new_func_or_cls, create_dummy,
                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)
        else:
            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(new_func_or_cls, force_patch)
            MegatronPatchesManager.patches_info.get(orig_func_or_cls_name).set_patch_func(
                new_func_or_cls, force_patch,
                apply_wrapper=apply_wrapper, remove_origin_wrappers=remove_origin_wrappers)

    @staticmethod
    def apply_patches():
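Note: the new apply_wrapper / remove_origin_wrappers arguments suggest two patching modes: installing the registered object as a drop-in replacement, versus treating it as a decorator that receives the original callable (the way gpt_model_init_wrapper above is registered with apply_wrapper=True). The Patch class itself is not shown in this diff; the sketch below only illustrates the assumed wrapper calling convention:

from functools import wraps

def logging_wrapper(fn):
    # Wrapper-style patch: takes the original callable and returns a new one,
    # so the original behaviour is preserved and extended rather than replaced.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        print(f"calling {fn.__qualname__}")
        return fn(*args, **kwargs)
    return wrapper

# With apply_wrapper=True the manager would presumably install logging_wrapper(original);
# with the default it installs the registered object as-is.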
dcu_megatron/core/distributed/finalize_model_grads.py (deleted, 100644 → 0):

from typing import List

import torch

from megatron.core import parallel_state
from megatron.core.distributed.finalize_model_grads import _unshard_if_dtensor, _reshard_if_dtensor
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import get_attr_wrapped_model


def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce word embedding grads.

    Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
    sync.
    """
    if (
        parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
        and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1
    ):
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            model_module = model[0]
        elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            model_module = model[-1]
        else:
            # We do not support an interleaved schedule for models with encoders yet.
            model_module = model[0]

        model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
        if model_module.share_embeddings_and_output_weights:
            weight = model_module.shared_embedding_or_output_weight()
            grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
            orig_grad = getattr(weight, grad_attr)
            grad = _unshard_if_dtensor(orig_grad)
            torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
            setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))

        if (
            hasattr(model_module, "share_mtp_embedding_and_output_weight")
            and model_module.share_mtp_embedding_and_output_weight
            and config.num_nextn_predict_layers > 0
        ):
            weight = model_module.shared_embedding_or_mtp_embedding_weight()
            grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
            orig_grad = getattr(weight, grad_attr)
            grad = _unshard_if_dtensor(orig_grad)
            torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
            setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
dcu_megatron/core/extensions/transformer_engine.py:

import os
import torch
import dataclasses
import transformer_engine as te

@@ -11,6 +12,7 @@ from megatron.core.utils import get_te_version, is_te_min_version
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.parallel_state import (
    get_context_parallel_global_ranks,

@@ -32,6 +34,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
        k_channels: Optional[int] = None,
        v_channels: Optional[int] = None,
        cp_comm_type: str = "p2p",
        model_comm_pgs: ModelCommProcessGroups = None,
    ):
        self.config = config
        self.te_forward_mask_type = False

@@ -58,6 +61,26 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
                f"num_attention_heads ({self.config.num_attention_heads}))"
            )

        if model_comm_pgs is None:
            # For backward compatibility, remove in v0.14 and raise error
            # raise ValueError("TEDotProductAttention was called without ModelCommProcessGroups")
            model_comm_pgs = ModelCommProcessGroups(
                tp=get_tensor_model_parallel_group(check_initialized=False),
                cp=get_context_parallel_group(check_initialized=False),
                hcp=get_hierarchical_context_parallel_groups(check_initialized=False),
            )
        else:
            assert hasattr(model_comm_pgs, 'tp'), "TEDotProductAttention model_comm_pgs must have tp pg"
            assert hasattr(model_comm_pgs, 'cp'), "TEDotProductAttention model_comm_pgs must have cp pg"
            if cp_comm_type == "a2a+p2p":
                assert hasattr(model_comm_pgs, 'hcp'), "TEDotProductAttention model_comm_pgs must have hierarchical cp pg"

        if is_te_min_version("0.10.0"):
            extra_kwargs["attention_type"] = attention_type
            # older version don't need attention_type

@@ -73,9 +96,9 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
            ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
            if getattr(TEDotProductAttention, "cp_stream") is None:
                TEDotProductAttention.cp_stream = torch.cuda.Stream()
            extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
            extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(check_initialized=False)
            extra_kwargs["cp_group"] = model_comm_pgs.cp
            extra_kwargs["cp_global_ranks"] = torch.distributed.get_process_group_ranks(model_comm_pgs.cp)
            extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
            if is_te_min_version("1.10.0"):

@@ -149,7 +172,7 @@ class TEDotProductAttentionPatch(te.pytorch.DotProductAttention):
            get_rng_state_tracker=(
                get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
            ),
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            tp_group=model_comm_pgs.tp,
            layer_number=layer_number,
            **extra_kwargs,
        )
dcu_megatron/core/models/common/embeddings/language_model_embedding.py (deleted, 100644 → 0):

from typing import Literal

import torch
from torch import Tensor

from megatron.core import tensor_parallel
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding


def language_model_embedding_init_func(
    self,
    config: TransformerConfig,
    vocab_size: int,
    max_sequence_length: int,
    position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
    num_tokentypes: int = 0,
    scatter_to_sequence_parallel: bool = True,
    skip_weight_param_allocation: bool = False,
):
    """Patch language model embeddings init."""
    super(LanguageModelEmbedding, self).__init__(config=config)

    self.config: TransformerConfig = config
    self.vocab_size: int = vocab_size
    self.max_sequence_length: int = max_sequence_length
    self.add_position_embedding: bool = position_embedding_type == 'learned_absolute'
    self.num_tokentypes = num_tokentypes
    self.scatter_to_sequence_parallel = scatter_to_sequence_parallel
    self.reduce_scatter_embeddings = (
        (not self.add_position_embedding)
        and self.num_tokentypes <= 0
        and self.config.sequence_parallel
        and self.scatter_to_sequence_parallel
    )

    # Word embeddings (parallel).
    self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
        num_embeddings=self.vocab_size,
        embedding_dim=self.config.hidden_size,
        init_method=self.config.init_method,
        reduce_scatter_embeddings=self.reduce_scatter_embeddings,
        config=self.config,
        skip_weight_param_allocation=skip_weight_param_allocation,
    )

    # Position embedding (serial).
    if self.add_position_embedding:
        self.position_embeddings = torch.nn.Embedding(self.max_sequence_length, self.config.hidden_size)

        # Initialize the position embeddings.
        if self.config.perform_initialization:
            self.config.init_method(self.position_embeddings.weight)

    if self.num_tokentypes > 0:
        self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.config.hidden_size)
        # Initialize the token-type embeddings.
        if self.config.perform_initialization:
            self.config.init_method(self.tokentype_embeddings.weight)
    else:
        self.tokentype_embeddings = None

    # Embeddings dropout
    self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)


def language_model_embedding_forward(self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: int = None,
                                     weight: Tensor = None) -> Tensor:
    """Pacth forward pass of the embedding module.

    Args:
        input_ids (Tensor): The input tokens
        position_ids (Tensor): The position id's used to calculate position embeddings
        tokentype_ids (int): The token type ids. Used when args.bert_binary_head is
            set to True. Defaults to None
        weight (Tensor): embedding weight

    Returns:
        Tensor: The output embeddings
    """
    if weight is None:
        if self.word_embeddings.weight is None:
            raise RuntimeError(
                "weight was not supplied to VocabParallelEmbedding forward pass "
                "and skip_weight_param_allocation is True."
            )
        weight = self.word_embeddings.weight

    word_embeddings = self.word_embeddings(input_ids, weight)
    if self.add_position_embedding:
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = word_embeddings + position_embeddings
    else:
        embeddings = word_embeddings

    if not self.reduce_scatter_embeddings:
        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

    if tokentype_ids is not None:
        assert self.tokentype_embeddings is not None
        # [b s h] -> [s b h] (So that it can be added with embeddings)
        tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
        embeddings = embeddings + tokentype_embedding
    else:
        assert self.tokentype_embeddings is None

    # If the input flag for fp32 residual connection is set, convert for float.
    if self.config.fp32_residual_connection:
        embeddings = embeddings.float()

    # Dropout.
    if self.config.sequence_parallel:
        if not self.reduce_scatter_embeddings and self.scatter_to_sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
        # `scatter_to_sequence_parallel_region` returns a view, which prevents
        # the original tensor from being garbage collected. Clone to facilitate GC.
        # Has a small runtime cost (~0.5%).
        if self.config.clone_scatter_output_in_embedding and self.scatter_to_sequence_parallel:
            embeddings = embeddings.clone()
        with tensor_parallel.get_cuda_rng_tracker().fork():
            embeddings = self.embedding_dropout(embeddings)
    else:
        embeddings = self.embedding_dropout(embeddings)

    return embeddings
dcu_megatron/core/models/gpt/gpt_layer_specs.py:

import warnings
from typing import Optional
from typing import Optional, Union

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec

@@ -12,13 +12,13 @@ from megatron.core.transformer.multi_latent_attention import (
    MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
    TransformerLayer,
    TransformerLayerSubmodules,
)
from dcu_megatron.core.tensor_parallel.layers import FluxColumnParallelLinear, FluxRowParallelLinear
from megatron.core.utils import is_te_min_version

try:

@@ -36,6 +36,17 @@ try:
except ImportError:
    warnings.warn('Apex is not installed.')

from dcu_megatron.core.tensor_parallel.layers import (
    FluxColumnParallelLinear,
    FluxRowParallelLinear
)
from dcu_megatron.core.transformer.multi_token_prediction import (
    MultiTokenPredictionBlockSubmodules,
    get_mtp_layer_offset,
    get_mtp_layer_spec,
    get_mtp_num_layers_to_build,
)


def get_gpt_layer_with_flux_spec(
    num_experts: Optional[int] = None,
dcu_megatron/core/models/gpt/gpt_model.py:

import os
import logging
from typing import Literal, Optional
from functools import wraps
from collections import OrderedDict
from typing import Optional

import torch
from torch import Tensor

from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from dcu_megatron.core.utils import tensor_slide
from dcu_megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
from dcu_megatron.core.transformer.transformer_config import TransformerConfig
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear

@@ -30,11 +18,13 @@ def gpt_model_init_wrapper(fn):
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)

        if (self.post_process and int(os.getenv("USE_FLUX_OVERLAP", "0"))):
            self.output_layer = FluxColumnParallelLinear(
        # Output
        if self.post_process or self.mtp_process:
            if int(os.getenv("USE_FLUX_OVERLAP", "0")):
                parallel_linear_impl = FluxColumnParallelLinear
            else:
                parallel_linear_impl = tensor_parallel.ColumnParallelLinear
            self.output_layer = parallel_linear_impl(
                self.config.hidden_size,
                self.vocab_size,
                config=self.config,

@@ -48,174 +38,12 @@ def gpt_model_init_wrapper(fn):
                grad_output_buffer=self.grad_output_buffer,
            )

            if self.pre_process or self.post_process:
                self.setup_embeddings_and_output_layer()

        # add mtp
        self.num_nextn_predict_layers = self.config.num_nextn_predict_layers
        if self.num_nextn_predict_layers:
            assert hasattr(self.config, "mtp_spec")
            self.mtp_spec: ModuleSpec = self.config.mtp_spec
            self.share_mtp_embedding_and_output_weight = self.config.share_mtp_embedding_and_output_weight
            self.recompute_mtp_norm = self.config.recompute_mtp_norm
            self.recompute_mtp_layer = self.config.recompute_mtp_layer
            self.mtp_loss_scale = self.config.mtp_loss_scale
            if self.post_process and self.training:
                self.mtp_layers = torch.nn.ModuleList(
                    [
                        MultiTokenPredictor(
                            self.config,
                            self.mtp_spec.submodules,
                            vocab_size=self.vocab_size,
                            max_sequence_length=self.max_sequence_length,
                            layer_number=i,
                            pre_process=self.pre_process,
                            fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
                            parallel_output=self.parallel_output,
                            position_embedding_type=self.position_embedding_type,
                            rotary_percent=self.rotary_percent,
                            seq_len_interpolation_factor=kwargs.get("seq_len_interpolation_factor", None),
                            share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight,
                            recompute_mtp_norm=self.recompute_mtp_norm,
                            recompute_mtp_layer=self.recompute_mtp_layer,
                            add_output_layer_bias=False
                        )
                        for i in range(self.num_nextn_predict_layers)
                    ]
                )
            if self.pre_process or self.post_process:
                setup_mtp_embeddings(self)

    return wrapper


def shared_embedding_or_mtp_embedding_weight(self) -> Tensor:
    """Gets the embedding weight when share embedding and mtp embedding weights set to True.

    Returns:
        Tensor: During pre processing it returns the input embeddings weight while during post processing it returns
        mtp embedding layers weight
    """
    assert self.num_nextn_predict_layers > 0
    if self.pre_process:
        return self.embedding.word_embeddings.weight
    elif self.post_process:
        return self.mtp_layers[0].embedding.word_embeddings.weight
    return None


def setup_mtp_embeddings(self):
    """
    Share embedding layer in mtp layer.
    """
    if self.pre_process:
        self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True

    # Set `is_embedding_or_output_parameter` attribute.
    for i in range(self.num_nextn_predict_layers):
        if self.post_process and self.mtp_layers[i].embedding.word_embeddings.weight is not None:
            self.mtp_layers[i].embedding.word_embeddings.weight.is_embedding_or_output_parameter = True

    if not self.share_mtp_embedding_and_output_weight:
        return

    if self.pre_process and self.post_process:
        # Zero out wgrad if sharing embeddings between two layers on same
        # pipeline stage to make sure grad accumulation into main_grad is
        # correct and does not include garbage values (e.g., from torch.empty).
        self.shared_embedding_or_mtp_embedding_weight().zero_out_wgrad = True
        return

    if self.pre_process and not self.post_process:
        assert parallel_state.is_pipeline_first_stage()
        self.shared_embedding_or_mtp_embedding_weight().shared_embedding = True

    if self.post_process and not self.pre_process:
        assert not parallel_state.is_pipeline_first_stage()
        for i in range(self.num_nextn_predict_layers):
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.mtp_layers[i].embedding.word_embeddings.weight.data.fill_(0)
            self.mtp_layers[i].embedding.word_embeddings.weight.shared = True
            self.mtp_layers[i].embedding.word_embeddings.weight.shared_embedding = True

    # Parameters are shared between the word embeddings layers, and the
    # heads at the end of the model. In a pipelined setup with more than
    # one stage, the initial embedding layer and the head are on different
    # workers, so we do the following:
    # 1. Create a second copy of word_embeddings on the last stage, with
    #    initial parameters of 0.0.
    # 2. Do an all-reduce between the first and last stage to ensure that
    #    the two copies of word_embeddings start off with the same
    #    parameter values.
    # 3. In the training loop, before an all-reduce between the grads of
    #    the two word_embeddings layers to ensure that every applied weight
    #    update is the same on both stages.

    # Ensure that first and last stages have the same initial parameter
    # values.
    if torch.distributed.is_initialized():
        if parallel_state.is_rank_in_embedding_group():
            weight = self.shared_embedding_or_mtp_embedding_weight()
            weight.data = weight.data.cuda()
            torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group())

    elif not getattr(LanguageModule, "embedding_warning_printed", False):
        logging.getLogger(__name__).warning(
            "Distributed processes aren't initialized, so the output layer "
            "is not initialized with weights from the word embeddings. "
            "If you are just manipulating a model this is fine, but "
            "this needs to be handled manually. If you are training "
            "something is definitely wrong."
        )
        LanguageModule.embedding_warning_printed = True


def slice_inputs(self, input_ids, labels, position_ids, attention_mask):
    if self.num_nextn_predict_layers == 0:
        return (
            [input_ids],
            [labels],
            [position_ids],
            [attention_mask],
        )

    return (
        tensor_slide(input_ids, self.num_nextn_predict_layers),
        tensor_slide(labels, self.num_nextn_predict_layers),
        generate_nextn_position_ids(position_ids, self.num_nextn_predict_layers),
        # not compatible with ppo attn_mask
        tensor_slide(attention_mask, self.num_nextn_predict_layers, dims=[-2, -1]),
    )


def generate_nextn_position_ids(tensor, slice_num):
    slides = tensor_slide(tensor, slice_num)
    if slides[0] is None:
        return slides

    for idx in range(1, len(slides)):
        slides[idx] = regenerate_position_ids(slides[idx], idx)
    return slides


def regenerate_position_ids(tensor, offset):
    if tensor is None:
        return None
    tensor = tensor.clone()
    for i in range(tensor.size(0)):
        row = tensor[i]
        zero_mask = (row == 0)
        # case where two packed sequences are concatenated in one row
        if zero_mask.any():
            first_zero_idx = torch.argmax(zero_mask.int()).item()
            tensor[i, :first_zero_idx] = torch.arange(first_zero_idx)
        else:
            tensor[i] = tensor[i] - offset
    return tensor


def gpt_model_forward(
    self,
    input_ids: Tensor,

@@ -227,9 +55,10 @@ def gpt_model_forward(
    packed_seq_params: PackedSeqParams = None,
    extra_block_kwargs: dict = None,
    runtime_gather_output: Optional[bool] = None,
    loss_mask: Optional[Tensor] = None,
) -> Tensor:
    """Forward function of the GPT Model This function passes the input tensors
    through the embedding layer, and then the decoeder and finally into the post
    through the embedding layer, and then the decoder and finally into the post
    processing layer (optional).

    It either returns the Loss values if labels are given or the final hidden units

@@ -241,20 +70,11 @@ def gpt_model_forward(
    # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
    # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

    # generate inputs for main and mtps
    input_ids, labels, position_ids, attention_mask = slice_inputs(self, input_ids, labels, position_ids, attention_mask)

    # Decoder embedding.
    if decoder_input is not None:
        pass
    elif self.pre_process:
        decoder_input = self.embedding(input_ids=input_ids[0], position_ids=position_ids[0])
        decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
    else:
        # intermediate stage of pipeline
        # decoder will get hidden_states from encoder.input_tensor

@@ -296,7 +116,7 @@ def gpt_model_forward(
    # Run decoder.
    hidden_states = self.decoder(
        hidden_states=decoder_input,
        attention_mask=attention_mask[0],
        attention_mask=attention_mask,
        inference_params=inference_params,
        rotary_pos_emb=rotary_pos_emb,
        rotary_pos_cos=rotary_pos_cos,

@@ -306,46 +126,43 @@ def gpt_model_forward(
        **(extra_block_kwargs or {}),
    )

    if not self.post_process:
        return hidden_states

    # logits and loss
    output_weight = None
    if self.share_embeddings_and_output_weights:
        output_weight = self.shared_embedding_or_output_weight()

    loss = 0
    # Multi token prediction module
    if self.num_nextn_predict_layers and self.training:
        if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight:
            output_weight = self.output_layer.weight
            output_weight.zero_out_wgrad = True
        embedding_weight = self.shared_embedding_or_mtp_embedding_weight() if self.share_mtp_embedding_and_output_weight else None
        mtp_hidden_states = hidden_states
        for i in range(self.num_nextn_predict_layers):
            mtp_hidden_states, mtp_loss = self.mtp_layers[i](
                mtp_hidden_states,  # [s,b,h]
                input_ids[i + 1],
                position_ids[i + 1] if position_ids[0] is not None else None,
                attention_mask[i + 1] if attention_mask[0] is not None else None,
                labels[i + 1] if labels[0] is not None else None,
                inference_params,
                packed_seq_params,
                extra_block_kwargs,
                embeding_weight=embedding_weight,
                output_weight=output_weight,
            )
    if self.mtp_process:
        hidden_states = self.mtp(
            input_ids=input_ids,
            position_ids=position_ids,
            labels=labels,
            loss_mask=loss_mask,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            embedding=self.embedding,
            output_layer=self.output_layer,
            output_weight=output_weight,
            runtime_gather_output=runtime_gather_output,
            compute_language_model_loss=self.compute_language_model_loss,
            **(extra_block_kwargs or {}),
        )
            loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss

    if (
        self.num_nextn_predict_layers
        self.mtp_process is not None
        and getattr(self.decoder, "main_final_layernorm", None) is not None
    ):
        # move block main model final norms here
        hidden_states = self.decoder.main_final_layernorm(hidden_states)

    if not self.post_process:
        return hidden_states

    logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)

@@ -353,19 +170,19 @@ def gpt_model_forward(
    if has_config_logger_enabled(self.config):
        payload = OrderedDict(
            {
                'input_ids': input_ids[0],
                'position_ids': position_ids[0],
                'attention_mask': attention_mask[0],
                'input_ids': input_ids,
                'position_ids': position_ids,
                'attention_mask': attention_mask,
                'decoder_input': decoder_input,
                'logits': logits,
            }
        )
        log_config_to_disk(self.config, payload, prefix='input_and_logits')

    if labels[0] is None:
    if labels is None:
        # [s b h] => [b s h]
        return logits.transpose(0, 1).contiguous()

    loss += self.compute_language_model_loss(labels[0], logits)
    loss = self.compute_language_model_loss(labels, logits)

    return loss
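Note: the MTP input slicing removed above shifts labels and position ids once per predicted depth; regenerate_position_ids subtracts the depth offset except where a row restarts at zero (two sequences packed into one row). A self-contained illustration of that arithmetic (it mirrors the deleted helper rather than importing it):

import torch

def regenerate_position_ids(tensor, offset):
    # Rows containing a 0 mark a second packed sequence: renumber up to that point.
    # Other rows are simply shifted down by the depth offset.
    tensor = tensor.clone()
    for i in range(tensor.size(0)):
        zero_mask = tensor[i] == 0
        if zero_mask.any():
            first_zero_idx = torch.argmax(zero_mask.int()).item()
            tensor[i, :first_zero_idx] = torch.arange(first_zero_idx)
        else:
            tensor[i] = tensor[i] - offset
    return tensor

print(regenerate_position_ids(torch.tensor([[1, 2, 3, 4]]), offset=1))  # tensor([[0, 1, 2, 3]])
print(regenerate_position_ids(torch.tensor([[1, 2, 0, 1]]), offset=1))  # tensor([[0, 1, 0, 1]])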
dcu_megatron/core/tensor_parallel/__init__.py:

from .layers import (
    FluxColumnParallelLinear,
    FluxRowParallelLinear,
    vocab_parallel_embedding_forward,
    vocab_parallel_embedding_init_wrapper,
)
\ No newline at end of file
dcu_megatron/core/tensor_parallel/layers.py:

import os
import copy
import socket
import warnings
from functools import wraps
from typing import Callable, List, Optional

if int(os.getenv("USE_FLUX_OVERLAP", "0")):
    try:
        import flux
        from dcu_megatron.core.utils import is_flux_min_version
    except ImportError:
        raise ImportError("flux is NOT installed")
try:
    import flux
except ImportError:
    raise ImportError("flux is NOT installed")

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from megatron.training import print_rank_0
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
    get_global_memory_buffer,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from megatron.core.utils import (
    is_torch_min_version,
    prepare_input_tensors_for_wgrad_compute
)
from megatron.core.tensor_parallel.layers import (
    _initialize_affine_weight_cpu,
    _initialize_affine_weight_gpu,
    VocabParallelEmbedding,
)
from megatron.core.utils import prepare_input_tensors_for_wgrad_compute
from megatron.core.tensor_parallel.mappings import (
    _reduce,
    copy_to_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
    _reduce_scatter_along_first_dim,
    _gather_along_first_dim,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.tensor_parallel.mappings import _reduce
from megatron.core.tensor_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,

@@ -50,9 +30,9 @@ from megatron.core.tensor_parallel.layers import (
    custom_fwd,
    custom_bwd,
    dist_all_gather_func,
    linear_with_frozen_weight,
    linear_with_grad_accumulation_and_async_allreduce
)
from dcu_megatron.core.utils import is_flux_min_version

_grad_accum_fusion_available = True
try:

@@ -61,74 +41,6 @@ except ImportError:
    _grad_accum_fusion_available = False


def vocab_parallel_embedding_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self, *args, skip_weight_param_allocation: bool = False, **kwargs):
        if (
            skip_weight_param_allocation
            and "config" in kwargs
            and hasattr(kwargs["config"], "perform_initialization")
        ):
            config = copy.deepcopy(kwargs["config"])
            config.perform_initialization = False
            kwargs["config"] = config
        fn(self, *args, **kwargs)
        if skip_weight_param_allocation:
            self.weight = None

    return wrapper


@torch.compile(mode='max-autotune-no-cudagraphs')
def vocab_parallel_embedding_forward(self, input_, weight=None):
    """Forward.

    Args:
        input_ (torch.Tensor): Input tensor.
    """
    if weight is None:
        if self.weight is None:
            raise RuntimeError(
                "weight was not supplied to VocabParallelEmbedding forward pass "
                "and skip_weight_param_allocation is True."
            )
        weight = self.weight

    if self.tensor_model_parallel_size > 1:
        # Build the mask.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
    else:
        masked_input = input_
    # Get the embeddings.
    if self.deterministic_mode:
        output_parallel = weight[masked_input]
    else:
        # F.embedding currently has a non-deterministic backward function
        output_parallel = F.embedding(masked_input, weight)
    # Mask the output embedding.
    if self.tensor_model_parallel_size > 1:
        output_parallel[input_mask, :] = 0.0

    if self.reduce_scatter_embeddings:
        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
        output_parallel = output_parallel.transpose(0, 1).contiguous()
        output = reduce_scatter_to_sequence_parallel_region(output_parallel)
    else:
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
    return output


def get_tensor_model_parallel_node_size(group=None):
    """Get the number of nodes."""
dcu_megatron/core/tensor_parallel/random.py (deleted, 100644 → 0):

import torch

from megatron.core.tensor_parallel.random import (
    get_cuda_rng_tracker,
    _set_cuda_rng_state
)


class CheckpointFunctionWithoutOutput(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, checkpoint, *args):
        with torch.no_grad():
            outputs = run_function(*args)

        # Store everything
        ctx.save_for_backward(*detach_variable(args))
        checkpoint.ctx = ctx

        return outputs

    @staticmethod
    def backward(ctx, *args):
        inputs = ctx.saved_tensors
        outputs = ctx.outputs
        torch.autograd.backward(outputs, args)
        ctx.outputs = None
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
        return (None, None) + grads


class CheckpointWithoutOutput:
    def __init__(self):
        self.run_function = None
        self.fwd_cpu_rng_state = None
        self.fwd_cuda_rng_state = None
        self.fwd_cuda_rng_state_tracker = None
        self.outputs = None

    def checkpoint(self, run_function, distribute_saved_activations, *args):
        self.run_function = run_function

        if distribute_saved_activations:
            raise RuntimeError(
                "CheckpointFunctionWithoutOutput does not support "
                "distribute_saved_activations"
            )

        # Copy the rng states.
        self.fwd_cpu_rng_state = torch.get_rng_state()
        self.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        self.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        outputs = CheckpointFunctionWithoutOutput.apply(run_function, self, *args)
        self.outputs = outputs
        if isinstance(self.outputs, torch.Tensor):
            self.outputs = (self.outputs,)

        return outputs

    def discard_output(self):
        for output in self.outputs:
            output.untyped_storage().resize_(0)

    def recompute(self, _):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "
                "please use .backward() if possible"
            )

        # Store the current states.
        cur_cpu_rng_state = torch.get_rng_state()
        cur_cuda_rng_state = torch.cuda.get_rng_state()
        cur_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(self.fwd_cpu_rng_state)
        _set_cuda_rng_state(self.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(self.fwd_cuda_rng_state_tracker)

        with torch.enable_grad():
            outputs = self.run_function(*self.ctx.saved_tensors)
        self.run_function = None
        self.fwd_cpu_rng_state = None
        self.fwd_cuda_rng_state = None
        self.fwd_cuda_rng_state_tracker = None

        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(cur_cpu_rng_state)
        _set_cuda_rng_state(cur_cuda_rng_state)
        get_cuda_rng_tracker().set_states(cur_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        for output, recomputation_output in zip(self.outputs, outputs):
            output_size = recomputation_output.untyped_storage().size()
            output.untyped_storage().resize_(output_size)
            with torch.no_grad():
                output.untyped_storage().copy_(recomputation_output.untyped_storage())

        self.ctx.outputs = outputs
        self.outputs = None
        self.ctx = None
dcu_megatron/core/transformer/mtp/mtp_spec.py (deleted, 100644 → 0):

import warnings

from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer import ModuleSpec

from .multi_token_predictor import (
    MultiTokenPredicationSubmodules,
    MultiTokenPredictor
)

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelLinear,
        TENorm
    )
    HAVE_TE = True
except ImportError:
    HAVE_TE = False

try:
    import apex
    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
    LNImpl = FusedLayerNorm
except ImportError:
    from megatron.core.transformer.torch_norm import WrappedTorchNorm
    warnings.warn('Apex is not installed. Falling back to Torch Norm')
    LNImpl = WrappedTorchNorm


def get_mtp_spec(transformer_layer, use_te=False):
    """
    Multi Token Predication Layer Specification.
    """
    use_te = use_te & HAVE_TE

    mtp_spec = ModuleSpec(
        module=MultiTokenPredictor,
        submodules=MultiTokenPredicationSubmodules(
            embedding=None,
            enorm=TENorm if use_te else LNImpl,
            hnorm=TENorm if use_te else LNImpl,
            eh_proj=TEColumnParallelLinear if use_te else ColumnParallelLinear,
            transformer_layer=transformer_layer,
            final_layernorm=TENorm if use_te else LNImpl,
            output_layer=None,
        )
    )
    return mtp_spec
dcu_megatron/core/transformer/mtp/multi_token_predictor.py (deleted, 100644 → 0):

import os
import logging
from dataclasses import dataclass
from typing import Union, Optional, Literal

import torch
from torch import Tensor

from megatron.core import tensor_parallel, InferenceParams
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.module import MegatronModule
from megatron.core.extensions.transformer_engine import TEColumnParallelLinear
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module

from ...tensor_parallel.random import CheckpointWithoutOutput
from ...tensor_parallel import FluxColumnParallelLinear


@dataclass
class MultiTokenPredicationSubmodules:
    embedding: Union[ModuleSpec, type] = None
    output_layer: Union[ModuleSpec, type] = None
    eh_proj: Union[ModuleSpec, type] = None
    enorm: Union[ModuleSpec, type] = None
    hnorm: Union[ModuleSpec, type] = None
    transformer_layer: Union[ModuleSpec, type] = None
    final_layernorm: Union[ModuleSpec, type] = None


class MultiTokenPredictor(MegatronModule):
    def __init__(
        self,
        config: TransformerConfig,
        submodules: MultiTokenPredicationSubmodules,
        vocab_size: int,
        max_sequence_length: int,
        layer_number: int = 1,
        hidden_dropout: float = None,
        pre_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        seq_len_interpolation_factor: Optional[float] = None,
        share_mtp_embedding_and_output_weight=True,
        recompute_mtp_norm=False,
        recompute_mtp_layer=False,
        add_output_layer_bias=False
    ):
        super().__init__(config=config)
        self.config = config
        self.submodules = submodules
        self.layer_number = layer_number
        self.hidden_dropout = hidden_dropout
        self.hidden_size = self.config.hidden_size
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.position_embedding_type = position_embedding_type
        # share with main model
        self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight
        self.recompute_layer_norm = recompute_mtp_norm
        self.recompute_mtp_layer = recompute_mtp_layer
        self.add_output_layer_bias = add_output_layer_bias

        self.embedding = LanguageModelEmbedding(
            config=self.config,
            vocab_size=self.vocab_size,
            max_sequence_length=self.max_sequence_length,
            position_embedding_type=self.position_embedding_type,
            skip_weight_param_allocation=self.pre_process and self.share_mtp_embedding_and_output_weight
        )

        if self.position_embedding_type == 'rope':
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                use_cpu_initialization=self.config.use_cpu_initialization,
            )

        self.enorm = build_module(
            self.submodules.enorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )
        self.hnorm = build_module(
            self.submodules.hnorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )
        self.eh_proj = build_module(
            self.submodules.eh_proj,
            self.hidden_size + self.hidden_size,
            self.hidden_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name='eh',
        )
        self.transformer_layer = build_module(
            self.submodules.transformer_layer,
            config=self.config,
        )
        if self.submodules.final_layernorm:
            self.final_layernorm = build_module(
                self.submodules.final_layernorm,
                config=self.config,
                hidden_size=self.config.hidden_size,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.final_layernorm = None

        if self.config.defer_embedding_wgrad_compute:
            self.embedding_activation_buffer = []
            self.grad_output_buffer = []
        else:
            self.embedding_activation_buffer = None
            self.grad_output_buffer = None

        if int(os.getenv("USE_FLUX_OVERLAP", "0")):
            column_parallel_linear_impl = FluxColumnParallelLinear
        else:
            column_parallel_linear_impl = tensor_parallel.ColumnParallelLinear
        self.output_layer = column_parallel_linear_impl(
            self.config.hidden_size,
            self.vocab_size,
            config=self.config,
            init_method=self.config.init_method,
            bias=False,
            skip_bias_add=False,
            gather_output=not self.parallel_output,
            skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight,
            embedding_activation_buffer=self.embedding_activation_buffer,
            grad_output_buffer=self.grad_output_buffer,
        )

    def forward(
        self,
        hidden_input_ids: Tensor,
        embed_input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        embeding_weight: Optional[torch.Tensor] = None,
        output_weight: Optional[torch.Tensor] = None,
    ):
        """Forward function of the MTP module"""
        # Decoder embedding.
        decoder_input = self.embedding(
            input_ids=embed_input_ids,
            position_ids=position_ids,
            weight=embeding_weight,
        )

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if inference_params is not None:
                rotary_seq_len = inference_params.max_sequence_length
            else:
                rotary_seq_len = decoder_input.size(0)

                if self.config.sequence_parallel:
                    rotary_seq_len *= self.config.tensor_model_parallel_size

            rotary_seq_len *= self.config.context_parallel_size
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

        if self.recompute_layer_norm:
            self.enorm_ckpt = CheckpointWithoutOutput()
            enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
            self.hnorm_ckpt = CheckpointWithoutOutput()
            hnorm_output = self.hnorm_ckpt.checkpoint(self.hnorm, False, hidden_input_ids)
        else:
            enorm_output = self.enorm(decoder_input)
            hnorm_output = self.hnorm(hidden_input_ids)

        # [s, b, h] -> [s, b, 2h]
        hidden_states = torch.concat([hnorm_output, enorm_output], dim=-1)

        if self.recompute_layer_norm:
            self.enorm_ckpt.discard_output()
            self.hnorm_ckpt.discard_output()
            hidden_states.register_hook(self.enorm_ckpt.recompute)
            hidden_states.register_hook(self.hnorm_ckpt.recompute)

        # hidden_states -> [s, b, h]
        hidden_states, _ = self.eh_proj(hidden_states)
        if self.config.tensor_model_parallel_size > 1:
            hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states)
            if self.config.sequence_parallel:
                hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states)

        if self.recompute_mtp_layer:
            hidden_states, context = tensor_parallel.checkpoint(
                self.transformer_layer,
                self.config.distribute_saved_activations,
                hidden_states,
                attention_mask,
                None,
                None,
                rotary_pos_emb,
                inference_params,
                packed_seq_params,
            )
        else:
            hidden_states, _ = self.transformer_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                rotary_pos_emb=rotary_pos_emb,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                **(extra_block_kwargs or {}),
            )

        # Final layer norm.
        if self.final_layernorm is not None:
            if self.recompute_layer_norm:
                self.finalnorm_ckpt = CheckpointWithoutOutput()
                finalnorm_output = self.finalnorm_ckpt.checkpoint(self.final_layernorm, False, hidden_states)
            else:
                finalnorm_output = self.final_layernorm(hidden_states)
        else:
            finalnorm_output = hidden_states

        logits, _ = self.output_layer(finalnorm_output, weight=output_weight)

        if self.recompute_layer_norm:
            self.finalnorm_ckpt.discard_output()
            logits.register_hook(self.finalnorm_ckpt.recompute)

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return hidden_states, loss

    def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor:
        """Computes the language model loss (Cross entropy across vocabulary)

        Args:
            labels (Tensor): The labels of dimension [batch size, seq length]
            logits (Tensor): The final logits returned by the output layer of the transformer model

        Returns:
            Tensor: Loss tensor of dimensions [batch size, sequence_length]
        """
        # [b s] => [s b]
        labels = labels.transpose(0, 1).contiguous()
        if self.config.cross_entropy_loss_fusion:
            loss = fused_vocab_parallel_cross_entropy(logits, labels)
        else:
            loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)

        # [s b] => [b, s]
        loss = loss.transpose(0, 1).contiguous()
        return loss
\ No newline at end of file
dcu_megatron/core/transformer/transformer_block.py:

@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
        # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
        config = args[0] if len(args) > 1 else kwargs['config']
        if getattr(config, "num_nextn_predict_layers", 0) > 0:
        if getattr(config, "mtp_num_layers", 0) > 0:
            self.main_final_layernorm = self.final_layernorm
            self.final_layernorm = None
dcu_megatron/core/transformer/transformer_config.py:

from typing import Optional
from functools import wraps
from dataclasses import dataclass

from megatron.training import get_args
from megatron.core.transformer.transformer_config import TransformerConfig, MLATransformerConfig


@dataclass
class ExtraTransformerConfig:
    ##################
    # multi-token prediction
    ##################
    num_nextn_predict_layers: int = 0
    """The number of multi-token prediction layers"""
def transformer_config_post_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self):
        fn(self)
        args = get_args()

    mtp_loss_scale: float = 0.3
    """Multi-token prediction loss scale"""
        """Number of Multi-Token Prediction (MTP) Layers."""
        self.mtp_num_layers = args.mtp_num_layers

    recompute_mtp_norm: bool = False
    """Whether to recompute mtp normalization"""
        """Weighting factor of Multi-Token Prediction (MTP) loss."""
        self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor

    recompute_mtp_layer: bool = False
    """Whether to recompute mtp layer"""
        ##################
        # flux
        ##################
        self.flux_transpose_weight = args.flux_transpose_weight

    share_mtp_embedding_and_output_weight: bool = False
    """share embedding and output weight with mtp layer."""
    return wrapper


@dataclass
class ExtraTransformerConfig:
    ##################
    # flux
    ##################
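Note: the hunk above interleaves the old dataclass fields with the new __post_init__ wrapper that copies the parsed training arguments onto the config object. In isolation the pattern looks like the following sketch (only mtp_num_layers and mtp_loss_scaling_factor are taken from the diff; everything else is illustrative):

from functools import wraps
from megatron.training import get_args

def post_init_wrapper(fn):
    @wraps(fn)
    def wrapper(self):
        fn(self)                 # run the original __post_init__ first
        args = get_args()        # then pull extra settings from the global args
        self.mtp_num_layers = args.mtp_num_layers
        self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor
    return wrapper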
dcu_megatron/core/utils.py:

@@ -30,33 +30,3 @@ def is_flux_min_version(version, check_equality=True):
    if check_equality:
        return get_flux_version() >= PkgVersion(version)
    return get_flux_version() > PkgVersion(version)


def tensor_slide(
    tensor: Optional[torch.Tensor],
    num_slice: int,
    dims: Union[int, List[int]] = -1,
    step: int = 1,
    return_first=False,
) -> List[Union[torch.Tensor, None]]:
    """General sliding-window helper that supports arbitrary dimensions."""
    if tensor is None:
        # return `List[None]` to avoid NoneType Error
        return [None] * (num_slice + 1)

    if num_slice == 0:
        return [tensor]

    window_size = tensor.shape[-1] - num_slice
    dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True)

    # slide over contiguous dimensions
    slices = []
    for i in range(0, tensor.size(dims[-1]) - window_size + 1, step):
        slice_obj = [slice(None)] * tensor.dim()
        for dim in dims:
            slice_obj[dim] = slice(i, i + window_size)
        slices.append(tensor[tuple(slice_obj)])

    if return_first:
        return slices
    return slices
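Note: tensor_slide, removed here, produced num_slice + 1 windows shifted along the last dimension; the old MTP path used it to build the main-model inputs plus one shifted input per predicted depth. A small self-contained equivalent of that windowing (it re-implements the deleted behaviour for dims=-1, step=1 rather than importing it):

import torch

def slide_last_dim(tensor, num_slice):
    # The window shrinks by num_slice so that num_slice + 1 shifted windows fit.
    window = tensor.shape[-1] - num_slice
    return [tensor[..., i:i + window] for i in range(num_slice + 1)]

tokens = torch.tensor([[10, 11, 12, 13]])
for s in slide_last_dim(tokens, num_slice=1):
    print(s)
# tensor([[10, 11, 12]])
# tensor([[11, 12, 13]])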
dcu_megatron/training/arguments.py
View file @
a8a2bbea
...
...
@@ -170,14 +170,16 @@ def _add_extra_tokenizer_args(parser):
 def _add_mtp_args(parser):
     group = parser.add_argument_group(title='multi token prediction')
-    group.add_argument('--num-nextn-predict-layers', type=int, default=0,
-                       help='Multi-Token prediction layer num')
-    group.add_argument('--mtp-loss-scale', type=float, default=0.3,
-                       help='Multi-Token prediction loss scale')
-    group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
-                       help='Multi-Token prediction recompute norm')
-    group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
-                       help='Multi-Token prediction recompute layer')
-    group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
-                       help='Main model share embedding and output weight with mtp layer.')
+    group.add_argument('--mtp-num-layers', type=int, default=None,
+                       help='Number of Multi-Token Prediction (MTP) Layers.'
+                       'MTP extends the prediction scope to multiple future tokens at each position.'
+                       'This MTP implementation sequentially predict additional tokens '
+                       'by using D sequential modules to predict D additional tokens.')
+    group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3,
+                       help='Scaling factor of Multi-Token Prediction (MTP) loss. '
+                       'We compute the average of the MTP losses across all depths, '
+                       'and multiply it the scaling factor to obtain the overall MTP loss, '
+                       'which serves as an additional training objective.')
     return parser
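A small, standalone sketch of how the two new flags parse (the values and the bare argparse.ArgumentParser below are illustrative assumptions, not Megatron's full parser):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='multi token prediction')
group.add_argument('--mtp-num-layers', type=int, default=None)
group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.3)

args = parser.parse_args(['--mtp-num-layers', '1', '--mtp-loss-scaling-factor', '0.1'])
assert args.mtp_num_layers == 1 and args.mtp_loss_scaling_factor == 0.1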
...
dcu_megatron/training/initialize.py  View file @ a8a2bbea
...
@@ -7,6 +7,72 @@ from megatron.training import get_args
from megatron.core import mpu


def _compile_dependencies():
    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling dataset index builder ...")
        from megatron.core.datasets.utils import compile_helpers

        compile_helpers()
        print(
            ">>> done with dataset index builder. Compilation time: {:.3f} "
            "seconds".format(time.time() - start_time),
            flush=True,
        )

    # ==================
    # Load fused kernels
    # ==================
    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = (
        args.num_attention_heads / args.tensor_model_parallel_size
    ) * args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = (
        seq_len > 16
        and seq_len <= 16384
        and seq_len % 4 == 0
        and attn_batch_size % 4 == 0
    )
    # Print a warning.
    if not ((args.fp16 or args.bf16) and custom_kernel_constraint and args.masked_softmax_fusion):
        if args.rank == 0:
            print(
                "WARNING: constraints for invoking optimized"
                " fused softmax kernel are not met. We default"
                " back to unfused kernel invocations.",
                flush=True,
            )

    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print("> compiling and loading fused kernels ...", flush=True)
        # fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        # fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(
            ">>> done with compiling and loading fused kernels. "
            "Compilation time: {:.3f} seconds".format(time.time() - start_time),
            flush=True,
        )
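The function above follows the usual "rank 0 builds first, everyone else waits at a barrier" pattern. A minimal sketch of that pattern in isolation (build_fn and an already-initialized process group are assumptions for the example, not part of the patch):

import torch.distributed as dist


def build_on_rank0_first(build_fn):
    """Run a compile/build step on rank 0 before the other ranks touch its output."""
    if dist.get_rank() == 0:
        build_fn()            # rank 0 compiles while the others wait below
        dist.barrier()
    else:
        dist.barrier()        # wait until rank 0 has produced the artifact
        build_fn()            # other ranks now hit the already-built cache
    dist.barrier()            # everyone is past the build phase before training continues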
def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
    """Initialize torch.distributed and core model parallel."""
    args = get_args()
...
@@ -16,8 +82,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
        if args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        args.rank = torch.distributed.get_rank()
...
@@ -34,6 +99,10 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
        else:
            device_id = None

        # Set to non-default stream for cudagraph capturing.
        if args.external_cuda_graph:
            torch.cuda.set_stream(torch.cuda.Stream())

        # Call the init process
        init_process_group_kwargs = {
            'backend': args.distributed_backend,
...
@@ -56,6 +125,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
                args.pipeline_model_parallel_split_rank,
                pipeline_model_parallel_comm_backend=args.pipeline_model_parallel_comm_backend,
                context_parallel_size=args.context_parallel_size,
                hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
                expert_model_parallel_size=args.expert_model_parallel_size,
...
@@ -68,6 +138,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
                encoder_pipeline_model_parallel_size=args.encoder_pipeline_model_parallel_size,
                get_embedding_ranks=get_embedding_ranks,
                get_position_embedding_ranks=get_position_embedding_ranks,
                create_gloo_process_groups=args.enable_gloo_process_groups,
            )
            if args.rank == 0:
                print(
...
@@ -78,76 +149,3 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
                    f"> initialized pipeline model parallel with size "
                    f"{mpu.get_pipeline_model_parallel_world_size()}"
                )
def _compile_dependencies():
    ...  # removed from the end of the file; the body is identical to the _compile_dependencies now defined near the top
dcu_megatron/training/training.py  View file @ a8a2bbea
...
@@ -50,14 +50,34 @@ from megatron.training.training import (
stimer = StragglerDetector()


def train(
    forward_step_func,
    model,
    optimizer,
    opt_param_scheduler,
    train_data_iterator,
    valid_data_iterator,
    process_non_loss_data_func,
    config,
    checkpointing_context,
    non_loss_data_func,
):
    """Training function: run train_step desired number of times, run validation, checkpoint."""
    args = get_args()
    timers = get_timers()
    one_logger = get_one_logger()

    if args.run_workload_inspector_server:
        try:
            from workload_inspector.utils.webserver import run_server
            import threading

            threading.Thread(
                target=run_server, daemon=True, args=(torch.distributed.get_rank(),)
            ).start()
        except ModuleNotFoundError:
            print_rank_0("workload inspector module not found.")

    # Write args to tensorboard
    write_args_to_tensorboard()
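The workload-inspector hookup above just spins a per-rank web server on a daemon thread. A self-contained sketch of that pattern (the handler, port scheme, and rank value are illustrative assumptions, not the workload_inspector API):

import threading
from http.server import BaseHTTPRequestHandler, HTTPServer


def run_monitor_server(rank, base_port=8000):
    class Handler(BaseHTTPRequestHandler):
        def do_GET(self):
            self.send_response(200)
            self.end_headers()
            self.wfile.write(f"rank {rank} alive\n".encode())

    HTTPServer(("0.0.0.0", base_port + rank), Handler).serve_forever()


# Daemon thread: it never blocks shutdown of the training process.
threading.Thread(target=run_monitor_server, daemon=True, args=(0,)).start()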
...
@@ -70,23 +90,35 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
    # Iterations.
    iteration = args.iteration
    # Make sure rerun_state_machine has the right iteration loaded from checkpoint.
    rerun_state_machine = get_rerun_state_machine()
    if rerun_state_machine.current_iteration != iteration:
        print_rank_0(f"Setting rerun_state_machine.current_iteration to {iteration}...")
        rerun_state_machine.current_iteration = iteration

    # Track E2E metrics at the start of training.
    one_logger_utils.on_train_start(
        iteration=iteration,
        consumed_train_samples=args.consumed_train_samples,
        train_samples=args.train_samples,
        seq_length=args.seq_length,
        train_iters=args.train_iters,
        save=args.save,
        async_save=args.async_save,
        log_throughput=args.log_throughput,
        num_floating_point_operations_so_far=args.num_floating_point_operations_so_far,
    )

    num_floating_point_operations_so_far = args.num_floating_point_operations_so_far

    # Setup some training config params.
    config.grad_scale_func = optimizer.scale_loss
    config.timers = timers
-   if isinstance(model[0], DDP) and args.overlap_grad_reduce:
+   if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
        assert config.no_sync_func is None, (
            'When overlap_grad_reduce is True, config.no_sync_func must be None; '
            'a custom no_sync_func is not supported when overlapping grad-reduce'
        )
        config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
        if len(model) == 1:
            config.no_sync_func = config.no_sync_func[0]
...
@@ -110,8 +142,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
    if args.manual_gc:
        # Disable the default garbage collector and perform the collection manually.
        # This is to align the timing of garbage collection across ranks.
        assert (
            args.manual_gc_interval >= 0
        ), 'Manual garbage collection interval should be larger than or equal to 0'
        gc.disable()
        gc.collect()
...
@@ -121,10 +154,13 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
        world = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        mmcnt = args.straggler_minmax_count
        stimer.configure(
            world,
            rank,
            mmcnt=mmcnt,
            enabled=not args.disable_straggler_on_startup,
            port=args.straggler_ctrlr_port,
        )
    num_floating_point_operations_since_last_log_event = 0.0

    num_microbatches = get_num_microbatches()
...
@@ -132,10 +168,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
    eval_iterations = 0

    def get_e2e_base_metrics():
        """Get base metrics values for one-logger to calculate E2E tracking metrics."""
        num_floating_point_operations_since_current_train_start = (
            num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
        )
        return {
            'iteration': iteration,
            'train_duration': timers('interval-time').active_time(),
...
@@ -145,7 +181,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
            'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
            'consumed_train_samples': args.consumed_train_samples,
            'world_size': args.world_size,
            'seq_length': args.seq_length,
        }

    # Cache into one-logger for callback.
    if one_logger:
...
@@ -153,7 +189,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
        one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)

    prof = None
    if (
        args.profile
        and torch.distributed.get_rank() in args.profile_ranks
        and args.use_pytorch_profiler
    ):

        def trace_handler(p):
            from pathlib import Path

            Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
...
@@ -178,9 +218,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                warmup=1 if args.profile_step_start > 0 else 0,
                active=args.profile_step_end - args.profile_step_start,
                repeat=1,
            ),
            on_trace_ready=trace_handler,
            record_shapes=True,
-           # on_trace_ready=torch.profiler.tensorboard_trace_handler('./torch_prof_data'))
        )
        prof.start()
    elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
        import ctypes
...
# Disable forward pre-hook to start training to ensure that errors in checkpoint loading
# or random initialization don't propagate to all ranks in first all-gather (which is a
# no-op if things work correctly).
if
args
.
use_distributed_optimizer
and
args
.
overlap_param_gather
:
if
should_disable_forward_pre_hook
(
args
)
:
disable_forward_pre_hook
(
model
,
param_sync
=
False
)
# Also remove param_sync_func temporarily so that sync calls made in
# `forward_backward_func` are no-ops.
...
...
@@ -199,8 +239,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
        pre_hook_enabled = False
        # Also, check weight hash across DP replicas to be very pedantic.
        if args.check_weight_hash_across_dp_replicas_interval is not None:
            assert check_param_hashes_across_dp_replicas(
                model, cross_check=True
            ), "Parameter hashes not matching across DP replicas"
            torch.distributed.barrier()
            print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
...
@@ -226,33 +267,60 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
        # to make sure training configuration is still valid.
        update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
        if get_num_microbatches() != num_microbatches and iteration != 0:
            assert get_num_microbatches() > num_microbatches, (
                f"Number of microbatches should be increasing due to batch size rampup; "
                f"instead going from {num_microbatches} to {get_num_microbatches()}"
            )
            if args.save is not None:
                save_checkpoint_and_time(
                    iteration,
                    model,
                    optimizer,
                    opt_param_scheduler,
                    num_floating_point_operations_so_far,
                    checkpointing_context,
                    train_data_iterator=train_data_iterator,
                )
        num_microbatches = get_num_microbatches()
        update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)

        # Completely skip iteration if needed.
        if iteration in args.iterations_to_skip:
            # Dummy train_step to fast forward train_data_iterator.
            dummy_train_step(train_data_iterator)
            iteration += 1
            batch_size = (
                mpu.get_data_parallel_world_size()
                * args.micro_batch_size
                * get_num_microbatches()
            )
            args.consumed_train_samples += batch_size
            args.skipped_train_samples += batch_size
            continue

        # Run training step.
        args.curr_iteration = iteration
        ft_integration.on_training_step_start()
        (
            loss_dict,
            skipped_iter,
            should_checkpoint,
            should_exit,
            exit_code,
            grad_norm,
            num_zeros_in_grad,
        ) = train_step(
            forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config
        )
        ft_integration.on_training_step_end()
        if should_checkpoint:
            save_checkpoint_and_time(
                iteration,
                model,
                optimizer,
                opt_param_scheduler,
                num_floating_point_operations_so_far,
                checkpointing_context,
                train_data_iterator=train_data_iterator,
            )
        if should_exit:
            break
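For context on the sample accounting in this loop, batch_size is the global batch consumed per iteration; a worked example with illustrative values (assumptions, not from the patch):

data_parallel_world_size = 8      # mpu.get_data_parallel_world_size()
micro_batch_size = 2              # args.micro_batch_size
num_microbatches = 4              # get_num_microbatches()
batch_size = data_parallel_world_size * micro_batch_size * num_microbatches
assert batch_size == 64           # samples added to args.consumed_train_samples each iteration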
...
@@ -269,18 +337,19 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
        # Enable forward pre-hook after training step has successfully run. All subsequent
        # forward passes will use the forward pre-hook / `param_sync_func` in
        # `forward_backward_func`.
-       if args.use_distributed_optimizer and args.overlap_param_gather:
+       if should_disable_forward_pre_hook(args):
            enable_forward_pre_hook(model)
            config.param_sync_func = param_sync_func
            pre_hook_enabled = True

        iteration += 1
        batch_size = (
            mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
        )
        args.consumed_train_samples += batch_size
        num_skipped_samples_in_batch = (
            get_current_global_batch_size() - get_current_running_global_batch_size()
        )
        if args.decrease_batch_size_if_needed:
            assert num_skipped_samples_in_batch >= 0
        else:
...
@@ -306,18 +375,24 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                    decoupled_learning_rate = param_group['lr']
                else:
                    learning_rate = param_group['lr']
        report_memory_flag = training_log(
            loss_dict,
            total_loss_dict,
            learning_rate,
            decoupled_learning_rate,
            iteration,
            loss_scale,
            report_memory_flag,
            skipped_iter,
            grad_norm,
            params_norm,
            num_zeros_in_grad,
        )

        # Evaluation.
        if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
            timers('interval-time').stop()
-           if args.use_distributed_optimizer and args.overlap_param_gather:
+           if should_disable_forward_pre_hook(args):
                disable_forward_pre_hook(model)
                pre_hook_enabled = False
            if args.manual_gc and args.manual_gc_eval:
...
@@ -325,11 +400,18 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                gc.collect()
            prefix = f'iteration {iteration}'
            timers('eval-time', log_level=0).start(barrier=True)
            evaluate_and_print_results(
                prefix,
                forward_step_func,
                valid_data_iterator,
                model,
                iteration,
                process_non_loss_data_func,
                config,
                verbose=False,
                write_to_tensorboard=True,
                non_loss_data_func=non_loss_data_func,
            )
            eval_duration += timers('eval-time').elapsed()
            eval_iterations += args.eval_iters
            timers('eval-time').stop()
...
@@ -338,20 +420,32 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
            if args.manual_gc and args.manual_gc_eval:
                # Collect only the objects created and used in evaluation.
                gc.collect(generation=0)
-           if args.use_distributed_optimizer and args.overlap_param_gather:
+           if should_disable_forward_pre_hook(args):
                enable_forward_pre_hook(model)
                pre_hook_enabled = True
            timers('interval-time', log_level=0).start(barrier=True)

        # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
        # Some of these only happen at specific iterations.
        post_training_step_callbacks(
            model,
            optimizer,
            opt_param_scheduler,
            iteration,
            prof,
            num_floating_point_operations_since_last_log_event,
        )

        # Checkpoint and decide whether to exit.
        should_exit = checkpoint_and_decide_exit(
            model,
            optimizer,
            opt_param_scheduler,
            iteration,
            num_floating_point_operations_so_far,
            checkpointing_context,
            train_data_iterator,
        )
        if should_exit:
            break
...
@@ -367,8 +461,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
            disable_forward_pre_hook(model)

    ft_integration.on_checkpointing_start()
-   maybe_finalize_async_save(blocking=True)
+   # This will finalize all unfinalized async request and terminate
+   # a persistent async worker if persistent ckpt worker is enabled
+   maybe_finalize_async_save(blocking=True, terminate=True)
    ft_integration.on_checkpointing_end(is_async_finalization=True)
    if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
        ft_integration.get_rank_monitor_client().shutdown_workload_monitoring()

    # If any exit conditions (signal handler, duration, iterations) have been reached, exit.
    if should_exit:
...
@@ -376,6 +474,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
        if wandb_writer:
            wandb_writer.finish()
        ft_integration.shutdown()
        one_logger_utils.finish()
        sys.exit(exit_code)

    return iteration, num_floating_point_operations_so_far
dcu_megatron/training/utils.py  deleted 100644 → 0  View file @ 2ddbd4be
import torch

from megatron.core import mpu
from megatron.training import get_args


def get_batch_on_this_tp_rank(data_iterator):

    args = get_args()

    def _broadcast(item):
        if item is not None:
            torch.distributed.broadcast(
                item,
                mpu.get_tensor_model_parallel_src_rank(),
                group=mpu.get_tensor_model_parallel_group(),
            )

    if mpu.get_tensor_model_parallel_rank() == 0:

        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None

        batch = {
            'tokens': data["tokens"].cuda(non_blocking=True),
            'labels': data["labels"].cuda(non_blocking=True),
            'loss_mask': data["loss_mask"].cuda(non_blocking=True),
            'attention_mask': None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
            'position_ids': data["position_ids"].cuda(non_blocking=True),
        }

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
        elif mpu.is_pipeline_last_stage():
            if args.num_nextn_predict_layers:
                _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            if args.reset_position_ids or args.num_nextn_predict_layers:
                _broadcast(batch['position_ids'])

    else:

        tokens = torch.empty(
            (args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
            dtype=torch.int64,
            device=torch.cuda.current_device(),
        )
        labels = torch.empty(
            (args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
            dtype=torch.int64,
            device=torch.cuda.current_device(),
        )
        loss_mask = torch.empty(
            (args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
            dtype=torch.float32,
            device=torch.cuda.current_device(),
        )
        if args.create_attention_mask_in_dataloader:
            attention_mask = torch.empty(
                (
                    args.micro_batch_size,
                    1,
                    args.seq_length + args.num_nextn_predict_layers,
                    args.seq_length + args.num_nextn_predict_layers,
                ),
                dtype=torch.bool,
                device=torch.cuda.current_device(),
            )
        else:
            attention_mask = None
        position_ids = torch.empty(
            (args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers),
            dtype=torch.int64,
            device=torch.cuda.current_device(),
        )

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(position_ids)
        elif mpu.is_pipeline_first_stage():
            labels = None
            loss_mask = None
            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)
        elif mpu.is_pipeline_last_stage():
            if args.num_nextn_predict_layers:
                _broadcast(tokens)
            else:
                tokens = None
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            if args.reset_position_ids or args.num_nextn_predict_layers:
                _broadcast(position_ids)
            else:
                position_ids = None

        batch = {
            'tokens': tokens,
            'labels': labels,
            'loss_mask': loss_mask,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
        }

    return batch
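For intuition about the seq_length + num_nextn_predict_layers padding the deleted helper allocated, here is a shape-only sketch with illustrative values (the concrete numbers are assumptions, not from the patch):

micro_batch_size, seq_length, num_nextn_predict_layers = 2, 4096, 1

tokens_shape = (micro_batch_size, seq_length + num_nextn_predict_layers)       # (2, 4097)
attention_mask_shape = (
    micro_batch_size,
    1,
    seq_length + num_nextn_predict_layers,
    seq_length + num_nextn_predict_layers,
)                                                                               # (2, 1, 4097, 4097)
assert tokens_shape == (2, 4097)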
examples/deepseek_v3/run_deepseek_v3_1node.sh  0 → 100644  View file @ a8a2bbea
for para in $*
do
    if [[ $para == --profiling* ]]; then
        profiling=${para#*=}
        export GPU_FLUSH_ON_EXECUTION=1
        export HIP_DIRECT_DISPATCH=0
    fi
done

mpirun -np 8 --allow-run-as-root \
    train_deepseek_v3_1node.sh localhost --profiling=$profiling > output.log 2>&1
wait

rm -rf CKPT
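If the launcher is run directly, an invocation along the lines of `bash run_deepseek_v3_1node.sh --profiling=1` (the value is illustrative) sets `profiling`, exports the two HIP-related environment variables, and forwards `--profiling=$profiling` to `train_deepseek_v3_1node.sh` across the 8 ranks; the trailing `rm -rf CKPT` then clears the checkpoint directory after the run.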