OpenDAS / Megatron-LM · Commits

Commit 9daa400a
authored Mar 10, 2025 by dongcl
parent f00f0256 · Pipeline #2464 passed

Update the FLOPs and parameter-count calculation after adding MTP (multi-token prediction).
Showing 7 changed files with 101 additions and 25 deletions (+101 / -25):

  megatron/core/models/common/embeddings/language_model_embedding.py    +2   -0
  megatron/core/models/gpt/gpt_model.py                                  +6   -16
  megatron/core/tensor_parallel/layers.py                                +2   -0
  megatron/core/transformer/transformer_block.py                         +5   -6
  megatron/core/transformer/transformer_config.py                        +18  -0
  megatron/training/theoretical_memory_usage.py                          +33  -2
  megatron/training/training.py                                          +35  -1
megatron/core/models/common/embeddings/language_model_embedding.py

@@ -23,6 +23,8 @@ class LanguageModelEmbedding(MegatronModule):
         num_tokentypes (int): Set to 0 without binary head, and 2 with a binary head. Defaults to 0.
         scatter_to_sequence_parallel (bool): Set to False to disable scatter of embedding
             across sequence parallel region. Defaults to True.
+        skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed
+            as a keyword argument `weight` during the forward pass. Defaults to False.
     """

     def __init__(
megatron/core/models/gpt/gpt_model.py

@@ -21,7 +21,6 @@ from megatron.core.transformer.spec_utils import build_module
 from megatron.core.transformer.transformer_block import TransformerBlock
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
-from megatron.core.extensions.transformer_engine import TENorm


 class GPTModel(LanguageModule):

@@ -138,8 +137,7 @@ class GPTModel(LanguageModule):
             config=self.config,
             spec=transformer_layer_spec,
             pre_process=self.pre_process,
-            post_process=self.post_process,
-            num_nextn_predict_layers=num_nextn_predict_layers
+            post_process=self.post_process
         )

         # Output

@@ -204,17 +202,6 @@ class GPTModel(LanguageModule):
             ]
         )

-        if self.post_process and self.num_nextn_predict_layers:
-            # move block main model final norms here
-            self.final_layernorm = build_module(
-                TENorm,
-                config=self.config,
-                hidden_size=self.config.hidden_size,
-                eps=self.config.layernorm_epsilon,
-            )
-        else:
-            self.final_layernorm = None
-
         if self.pre_process or self.post_process:
             self.setup_embeddings_and_output_layer()

@@ -472,9 +459,12 @@ class GPTModel(LanguageModule):
                 loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss

-        if self.num_nextn_predict_layers and self.final_layernorm is not None:
+        if (
+            self.num_nextn_predict_layers
+            and getattr(self.decoder, "final_layernorm", None) is not None
+        ):
             # move block main model final norms here
-            hidden_states = self.final_layernorm(hidden_states)
+            hidden_states = self.decoder.final_layernorm(hidden_states)

         logits, _ = self.output_layer(
             hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
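The loss hunk above folds the auxiliary MTP losses into the main language-model loss as `loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss`, i.e. the per-depth MTP losses are averaged and then weighted by `mtp_loss_scale`. A standalone illustration of that arithmetic (not code from the commit; the loss values are made up, the scale is the new config default):

# Illustration only: how per-depth MTP losses combine with the main loss.
mtp_loss_scale = 0.3          # default added to TransformerConfig in this commit
num_nextn_predict_layers = 2  # hypothetical: two MTP depths
main_loss = 2.10              # hypothetical next-token loss
mtp_losses = [2.40, 2.65]     # hypothetical per-depth MTP losses

total_loss = main_loss
for mtp_loss in mtp_losses:
    # Each depth is weighted by mtp_loss_scale / num_nextn_predict_layers,
    # so the MTP losses are averaged before being scaled by mtp_loss_scale.
    total_loss += mtp_loss_scale / num_nextn_predict_layers * mtp_loss

print(f"combined loss: {total_loss:.4f}")  # 2.10 + 0.15 * 2.40 + 0.15 * 2.65 = 2.8575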
megatron/core/tensor_parallel/layers.py

@@ -181,6 +181,8 @@ class VocabParallelEmbedding(torch.nn.Module):
     Keyword Args:
         config: A megatron.core.ModelParallelConfig object
+        skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed
+            as a keyword argument `weight` during the forward pass. Defaults to False.
     """

     def __init__(
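Both docstring additions (here and in language_model_embedding.py) describe the same contract: with `skip_weight_param_allocation=True` the module owns no `weight` parameter and expects one at call time, which is what lets an MTP head reuse the main model's embedding/output weight instead of allocating its own (see `share_mtp_embedding_and_output_weight` below). A minimal sketch of that contract, using a simplified stand-in module rather than the fork's actual `VocabParallelEmbedding` signature:

# Sketch of the skip_weight_param_allocation contract described by the new docstrings.
# Simplified stand-in, not Megatron's VocabParallelEmbedding.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SharedWeightEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_size, skip_weight_param_allocation=False):
        super().__init__()
        if skip_weight_param_allocation:
            # No parameter is created; the caller must pass `weight` in forward().
            self.weight = None
        else:
            self.weight = nn.Parameter(torch.empty(vocab_size, hidden_size))
            nn.init.normal_(self.weight, std=0.02)

    def forward(self, input_ids, weight=None):
        if weight is None:
            if self.weight is None:
                raise RuntimeError(
                    "Built with skip_weight_param_allocation=True; pass `weight` in forward()."
                )
            weight = self.weight
        return F.embedding(input_ids, weight)


# Usage: an MTP head reusing the main model's embedding weight (no duplicated parameter).
main_embedding = SharedWeightEmbedding(vocab_size=128, hidden_size=16)
mtp_embedding = SharedWeightEmbedding(128, 16, skip_weight_param_allocation=True)
tokens = torch.randint(0, 128, (2, 8))
out = mtp_embedding(tokens, weight=main_embedding.weight)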
megatron/core/transformer/transformer_block.py

@@ -177,8 +177,7 @@ class TransformerBlock(MegatronModule):
         spec: Union[TransformerBlockSubmodules, ModuleSpec],
         post_layer_norm: bool = True,
         pre_process: bool = True,
-        post_process: bool = True,
-        num_nextn_predict_layers: int = 0
+        post_process: bool = True
     ):
         super().__init__(config=config)

@@ -225,6 +224,8 @@ class TransformerBlock(MegatronModule):
         self._build_layers()
         self.num_layers_per_pipeline_rank = len(self.layers)
         self.tp_only_amax_red = config.tp_only_amax_red
+        # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
+        self.move_final_norm_out_of_block = getattr(config, "num_nextn_predict_layers", 0) > 0

     def _build_layers(self):
         # Transformer layers.

@@ -247,9 +248,7 @@ class TransformerBlock(MegatronModule):
         # @TODO: add back standalone_embedding_stage (see issue #293)
         # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline
         # self.post_process and self.post_layer_norm guide this behavior
-        # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
-        move_final_norm_out_of_block = self.num_nextn_predict_layers > 0
-        if self.submodules.layer_norm and self.post_process and self.post_layer_norm and not move_final_norm_out_of_block:
+        if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
             self.final_layernorm = build_module(
                 self.submodules.layer_norm,
                 config=self.config,

@@ -549,7 +548,7 @@ class TransformerBlock(MegatronModule):
             hidden_states = self.group_prefetch_offload_commit_async(hidden_states)

         # Final layer norm.
-        if self.final_layernorm is not None:
+        if self.final_layernorm is not None and not self.move_final_norm_out_of_block:
             hidden_states = self.final_layernorm(hidden_states)

         # TENorm produces a "viewed" tensor. This will result in schedule.py's
         # deallocate_output_tensor() throwing an error, so a viewless tensor is
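Taken together, these hunks keep the final layernorm inside the block but skip applying it whenever `config.num_nextn_predict_layers > 0`, so the GPT model can apply `decoder.final_layernorm` on its own output branch while the MTP modules consume the pre-norm hidden states (see the gpt_model.py hunks above). A toy sketch of that gating, not the fork's actual TransformerBlock:

# Simplified illustration of the final-norm gating added in this commit.
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, config, hidden_size=16):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(2)])
        self.final_layernorm = nn.LayerNorm(hidden_size)
        # Mirrors: getattr(config, "num_nextn_predict_layers", 0) > 0
        self.move_final_norm_out_of_block = getattr(config, "num_nextn_predict_layers", 0) > 0

    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        # Mirrors: if self.final_layernorm is not None and not self.move_final_norm_out_of_block:
        if self.final_layernorm is not None and not self.move_final_norm_out_of_block:
            hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


class ToyConfig:
    num_nextn_predict_layers = 1  # MTP enabled -> the block returns un-normalized states


block = ToyBlock(ToyConfig())
h = block(torch.randn(2, 4, 16))
# The caller now applies block.final_layernorm to its own branch,
# while the MTP modules read the pre-norm hidden states h.
h_main = block.final_layernorm(h)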
megatron/core/transformer/transformer_config.py

@@ -367,6 +367,24 @@ class TransformerConfig(ModelParallelConfig):
     moe_layer_recompute: bool = False
     """Memory optimization: checkpointing moe_layer to save activation memory."""

+    ##################
+    # Multi-token prediction
+    ##################
+    num_nextn_predict_layers: int = 0
+    """The number of multi-token prediction layers."""
+
+    mtp_loss_scale: float = 0.3
+    """Multi-token prediction loss scale."""
+
+    recompute_mtp_norm: bool = False
+    """Whether to recompute the MTP normalization."""
+
+    recompute_mtp_layer: bool = False
+    """Whether to recompute the MTP layer."""
+
+    share_mtp_embedding_and_output_weight: bool = False
+    """Share the embedding and output weight with the MTP layer."""
+
     ##################
     # Context Parallel
     ##################
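For reference, a hedged example of turning these options on; `TransformerConfig` is a dataclass, so the new fields can be passed as keyword arguments. The backbone sizes here are toy values, not a recommendation:

# Illustrative only: enabling MTP through the new TransformerConfig fields.
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=4,
    hidden_size=64,
    num_attention_heads=8,
    # New multi-token-prediction fields added in this commit:
    num_nextn_predict_layers=1,                  # one extra-token prediction depth
    mtp_loss_scale=0.3,                          # weight of the MTP loss relative to the main loss
    recompute_mtp_norm=False,
    recompute_mtp_layer=False,
    share_mtp_embedding_and_output_weight=True,  # reuse the main embedding/output weights
)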
megatron/training/theoretical_memory_usage.py

@@ -42,7 +42,33 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
         num_parameters_in_embedding_layers = 2 * embedding_size
     else:
         num_parameters_in_embedding_layers = embedding_size
-    num_total_parameters = num_parameters_in_transformer_layers + num_parameters_in_embedding_layers
+
+    # mtp
+    num_parameters_in_mtp_layers = (
+        2
+        * args.num_nextn_predict_layers
+        * args.hidden_size
+        * args.hidden_size
+        * (
+            # Attention.
+            (
+                (1 + (args.num_query_groups / args.num_attention_heads))
+                * query_projection_to_hidden_size_ratio
+            )
+            # MLP.
+            + ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier)
+            # layernorms.
+            + (3 / args.hidden_size)
+            # linear projection.
+            + 1
+        )
+    )
+    # params of mtp embedding and mtp output layer
+    if not args.share_mtp_embedding_and_output_weight:
+        num_parameters_in_mtp_layers += (
+            2 * args.num_nextn_predict_layers * args.hidden_size * args.padded_vocab_size
+        )
+
+    num_total_parameters = (
+        num_parameters_in_transformer_layers
+        + num_parameters_in_embedding_layers
+        + num_parameters_in_mtp_layers
+    )
     if verbose:
         print(
             f"Number of parameters in transformer layers in billions: "

@@ -52,16 +78,21 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
             f"Number of parameters in embedding layers in billions: "
             f"{num_parameters_in_embedding_layers / 10**9:.2f}"
         )
+        print(
+            f"Number of parameters in mtp layers in billions: "
+            f"{num_parameters_in_mtp_layers / 10**9:.2f}"
+        )
         print(f"Total number of parameters in billions: {num_total_parameters / 10**9:.2f}")

     # Most loaded model shard has (1/pp_size transformer layers + 1 embedding layer) / tp_size.
     num_parameters_on_most_loaded_model_shard = (
-        (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size) + embedding_size
+        (num_parameters_in_transformer_layers / args.pipeline_model_parallel_size)
+        + embedding_size
+        + num_parameters_in_mtp_layers
     ) / args.tensor_model_parallel_size
     if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1:
         num_parameters_on_most_loaded_model_shard += (
             embedding_size / args.tensor_model_parallel_size
         )
     if verbose:
         print(
             f"Number of parameters in most loaded shard in billions: "
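As a sanity check on the new term, the snippet below restates `num_parameters_in_mtp_layers` with hypothetical, dense-model hyperparameters (no MoE, no grouped-query attention, `kv_channels * num_attention_heads == hidden_size`, untied MTP embedding/output); it is not part of the commit:

# Worked example of the MTP parameter term added above, with made-up hyperparameters.
num_nextn_predict_layers = 1
hidden_size = 4096
ffn_hidden_size = 11008
num_attention_heads = 32
num_query_groups = 32                        # no grouped-query attention
padded_vocab_size = 128256
query_projection_to_hidden_size_ratio = 1.0  # kv_channels * num_attention_heads == hidden_size
num_experts = 1                              # dense model
gated_linear_multiplier = 3 / 2              # gated (SwiGLU-style) MLP
share_mtp_embedding_and_output_weight = False

num_parameters_in_mtp_layers = (
    2
    * num_nextn_predict_layers
    * hidden_size
    * hidden_size
    * (
        # Attention.
        ((1 + (num_query_groups / num_attention_heads)) * query_projection_to_hidden_size_ratio)
        # MLP.
        + ((ffn_hidden_size / hidden_size) * num_experts * gated_linear_multiplier)
        # layernorms.
        + (3 / hidden_size)
        # linear projection.
        + 1
    )
)
# Untied MTP embedding and output layer add 2 * depth * hidden_size * vocab parameters.
if not share_mtp_embedding_and_output_weight:
    num_parameters_in_mtp_layers += (
        2 * num_nextn_predict_layers * hidden_size * padded_vocab_size
    )

print(f"MTP parameters in billions: {num_parameters_in_mtp_layers / 10**9:.2f}")  # ~1.29 here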
megatron/training/training.py

@@ -138,7 +138,7 @@ def num_floating_point_operations(args, batch_size):
     # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
     expansion_factor = 3 * 2 * 2

-    return (
+    num_backbone_flops = (
         expansion_factor
         * batch_size
         * args.seq_length

@@ -167,6 +167,40 @@ def num_floating_point_operations(args, batch_size):
             )
         )
     )
+
+    # mtp flops
+    num_mtp_flops = (
+        expansion_factor
+        * batch_size
+        * args.seq_length
+        * args.num_nextn_predict_layers
+        * args.hidden_size
+        * args.hidden_size
+        * (
+            # Attention.
+            (
+                (
+                    1
+                    + (args.num_query_groups / args.num_attention_heads)
+                    + (args.seq_length / args.hidden_size)
+                )
+                * query_projection_to_hidden_size_ratio
+            )
+            # MLP.
+            + (
+                (args.ffn_hidden_size / args.hidden_size)
+                * num_experts_routed_to
+                * gated_linear_multiplier
+            )
+            # Shared Experts.
+            + ((shared_expert_ffn_hidden_size / args.hidden_size) * gated_linear_multiplier)
+            # Logit.
+            + (args.padded_vocab_size / (2 * args.hidden_size))
+            # mtp linear projection
+            + 1
+        )
+    )
+
+    return num_backbone_flops + num_mtp_flops


 def get_start_time_from_progress_log():
     """
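The MTP FLOPs term mirrors the per-layer backbone accounting (same `expansion_factor`, same attention and MLP breakdown) but additionally charges the MTP head's logit GEMM and its extra linear projection. A standalone restatement with hypothetical values, not part of the commit:

# Worked example of num_mtp_flops as added above; all hyperparameters are hypothetical.
expansion_factor = 3 * 2 * 2     # same factor the surrounding function uses for the backbone
batch_size = 8
seq_length = 4096
num_nextn_predict_layers = 1
hidden_size = 4096
ffn_hidden_size = 11008
num_attention_heads = 32
num_query_groups = 32
padded_vocab_size = 128256
query_projection_to_hidden_size_ratio = 1.0
num_experts_routed_to = 1        # dense model (no MoE routing)
gated_linear_multiplier = 3 / 2  # gated MLP
shared_expert_ffn_hidden_size = 0

num_mtp_flops = (
    expansion_factor
    * batch_size
    * seq_length
    * num_nextn_predict_layers
    * hidden_size
    * hidden_size
    * (
        # Attention (the seq_length / hidden_size term covers the score and context GEMMs).
        (
            (
                1
                + (num_query_groups / num_attention_heads)
                + (seq_length / hidden_size)
            )
            * query_projection_to_hidden_size_ratio
        )
        # MLP.
        + ((ffn_hidden_size / hidden_size) * num_experts_routed_to * gated_linear_multiplier)
        # Shared experts.
        + ((shared_expert_ffn_hidden_size / hidden_size) * gated_linear_multiplier)
        # Logit.
        + (padded_vocab_size / (2 * hidden_size))
        # MTP linear projection.
        + 1
    )
)

print(f"MTP FLOPs for this batch: {num_mtp_flops / 10**12:.1f} TFLOPs")  # ~156.3 here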