Commit 9daa400a authored by dongcl

Update FLOPs and parameter-count calculation after adding MTP

parent f00f0256
Pipeline #2464 passed with stage
......@@ -23,6 +23,8 @@ class LanguageModelEmbedding(MegatronModule):
num_tokentypes (int): Set to 0 without binary head, and 2 with a binary head. Defaults to 0.
scatter_to_sequence_parallel (bool): Set to False to disable scatter of embedding
across sequence parallel region. Defaults to True.
skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed
as a keyword argument `weight` during the forward pass. Defaults to False.
"""
def __init__(
......
......@@ -21,7 +21,6 @@ from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.mtp.multi_token_predictor import MultiTokenPredictor
from megatron.core.extensions.transformer_engine import TENorm
class GPTModel(LanguageModule):
......@@ -138,8 +137,7 @@ class GPTModel(LanguageModule):
config=self.config,
spec=transformer_layer_spec,
pre_process=self.pre_process,
post_process=self.post_process,
num_nextn_predict_layers=num_nextn_predict_layers
post_process=self.post_process
)
# Output
......@@ -204,17 +202,6 @@ class GPTModel(LanguageModule):
]
)
if self.post_process and self.num_nextn_predict_layers:
# move block main model final norms here
self.final_layernorm = build_module(
TENorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
else:
self.final_layernorm = None
if self.pre_process or self.post_process:
self.setup_embeddings_and_output_layer()
......@@ -472,9 +459,12 @@ class GPTModel(LanguageModule):
loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss
if self.num_nextn_predict_layers and self.final_layernorm is not None:
if (
self.num_nextn_predict_layers
and getattr(self.decoder, "final_layernorm", None) is not None
):
# the block's final norm was moved out for MTP; apply the main model's final norm here
hidden_states = self.final_layernorm(hidden_states)
hidden_states = self.decoder.final_layernorm(hidden_states)
logits, _ = self.output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
......
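For reference, a minimal standalone illustration (hypothetical values, plain PyTorch, not part of the patch) of how an MTP auxiliary loss is folded into the main loss with the mtp_loss_scale / num_nextn_predict_layers factor seen in the context line above:

import torch

# Hypothetical numbers, illustrating only the scaling used in
# "loss += self.mtp_loss_scale / self.num_nextn_predict_layers * mtp_loss".
mtp_loss_scale = 0.3
num_nextn_predict_layers = 2
main_loss = torch.tensor(2.10)   # main next-token loss
mtp_loss = torch.tensor(2.45)    # loss from one MTP predictor
loss = main_loss + mtp_loss_scale / num_nextn_predict_layers * mtp_loss
print(loss)  # ~tensor(2.4675)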
......@@ -181,6 +181,8 @@ class VocabParallelEmbedding(torch.nn.Module):
Keyword Args:
config: A megatron.core.ModelParallelConfig object
skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed
as a keyword argument `weight` during the forward pass. Defaults to False.
"""
def __init__(
......
......@@ -177,8 +177,7 @@ class TransformerBlock(MegatronModule):
spec: Union[TransformerBlockSubmodules, ModuleSpec],
post_layer_norm: bool = True,
pre_process: bool = True,
post_process: bool = True,
num_nextn_predict_layers: int = 0
post_process: bool = True
):
super().__init__(config=config)
......@@ -225,6 +224,8 @@ class TransformerBlock(MegatronModule):
self._build_layers()
self.num_layers_per_pipeline_rank = len(self.layers)
self.tp_only_amax_red = config.tp_only_amax_red
# MTP requires separate layernorms for the main model and the MTP modules, so the final norm is moved out of the block
self.move_final_norm_out_of_block = getattr(config, "num_nextn_predict_layers", 0) > 0
def _build_layers(self):
# Transformer layers.
......@@ -247,9 +248,7 @@ class TransformerBlock(MegatronModule):
# @TODO: add back standalone_embedding_stage (see issue #293)
# In pipeline parallelism, we want to add this LN only to the last stage of the pipeline
# self.post_process and self.post_layer_norm guide this behavior
# MTP requires separate layernorms for the main model and the MTP modules, so the final norm is moved out of the block
move_final_norm_out_of_block = self.num_nextn_predict_layers > 0
if self.submodules.layer_norm and self.post_process and self.post_layer_norm and not move_final_norm_out_of_block:
if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
self.final_layernorm = build_module(
self.submodules.layer_norm,
config=self.config,
......@@ -549,7 +548,7 @@ class TransformerBlock(MegatronModule):
hidden_states = self.group_prefetch_offload_commit_async(hidden_states)
# Final layer norm.
if self.final_layernorm is not None:
if self.final_layernorm is not None and not self.move_final_norm_out_of_block:
hidden_states = self.final_layernorm(hidden_states)
# TENorm produces a "viewed" tensor. This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
......
......@@ -367,6 +367,24 @@ class TransformerConfig(ModelParallelConfig):
moe_layer_recompute: bool = False
"""Memory optimization: checkpointing moe_layer to save actiavtion memory."""
##################
# multi-token prediction
##################
num_nextn_predict_layers: int = 0
"""Number of multi-token prediction (MTP) layers."""
mtp_loss_scale: float = 0.3
"""Scale applied to the multi-token prediction loss."""
recompute_mtp_norm: bool = False
"""Whether to recompute the MTP normalization."""
recompute_mtp_layer: bool = False
"""Whether to recompute the MTP layer."""
share_mtp_embedding_and_output_weight: bool = False
"""Share the embedding and output weights with the MTP layers."""
##################
# Context Parallel
##################
......
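A minimal sketch of how the new fields might be set, assuming the patched TransformerConfig from this commit; the non-MTP values below are placeholders rather than settings taken from any real run:

from megatron.core.transformer.transformer_config import TransformerConfig

# Sketch only: the MTP fields exist only in the patched TransformerConfig above.
config = TransformerConfig(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    num_nextn_predict_layers=1,                # one extra next-token predictor
    mtp_loss_scale=0.3,                        # weight of the MTP auxiliary loss
    recompute_mtp_norm=False,
    recompute_mtp_layer=False,
    share_mtp_embedding_and_output_weight=True,
)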
......@@ -42,7 +42,33 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
num_parameters_in_embedding_layers = 2 * embedding_size
else:
num_parameters_in_embedding_layers = embedding_size
num_total_parameters = num_parameters_in_transformer_layers + num_parameters_in_embedding_layers
# mtp
num_parameters_in_mtp_layers = (
2
* args.num_nextn_predict_layers
* args.hidden_size
* args.hidden_size
* (
# Attention.
(
(1 + (args.num_query_groups / args.num_attention_heads))
* query_projection_to_hidden_size_ratio
)
# MLP.
+ ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier)
# layernorms.
+ (3 / args.hidden_size)
# linear projection.
+ 1
)
)
# params of mtp embedding and mtp output layer
if not args.share_mtp_embedding_and_output_weight:
num_parameters_in_mtp_layers += 2 * args.num_nextn_predict_layers * args.hidden_size * args.padded_vocab_size
num_total_parameters = num_parameters_in_transformer_layers + num_parameters_in_embedding_layers + num_parameters_in_mtp_layers
if verbose:
print(
f"Number of parameters in transformer layers in billions: "
......@@ -52,16 +78,21 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
f"Number of parameters in embedding layers in billions: "
f"{num_parameters_in_embedding_layers / 10**9:.2f}"
)
print(
f"Number of parameters in mtp layers in billions: "
f"{num_parameters_in_mtp_layers / 10**9:.2f}"
)
print(f"Total number of parameters in billions: {num_total_parameters / 10**9:.2f}")
# Most loaded model shard has (1/pp_size transformer layers + 1 embedding layer) / tp_size.
num_parameters_on_most_loaded_model_shard = (
(num_parameters_in_transformer_layers / args.pipeline_model_parallel_size) + embedding_size
(num_parameters_in_transformer_layers / args.pipeline_model_parallel_size) + embedding_size + num_parameters_in_mtp_layers
) / args.tensor_model_parallel_size
if args.untie_embeddings_and_output_weights and args.pipeline_model_parallel_size == 1:
num_parameters_on_most_loaded_model_shard += (
embedding_size / args.tensor_model_parallel_size
)
if verbose:
print(
f"Number of parameters in most loaded shard in billions: "
......
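As a standalone cross-check of the parameter term added above, a hypothetical helper (not part of the patch) that mirrors the num_parameters_in_mtp_layers expression; argument names follow the args fields used in the hunk, with dense, non-gated defaults:

def mtp_parameter_count(
    num_nextn_predict_layers,
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    num_query_groups,
    padded_vocab_size,
    num_experts=1,
    gated_linear_multiplier=1,
    query_projection_to_hidden_size_ratio=1,
    share_mtp_embedding_and_output_weight=False,
):
    # Per-MTP-module transformer parameters, mirroring the expression above:
    # attention + MLP + layernorms + the extra MTP linear projection.
    params = (
        2
        * num_nextn_predict_layers
        * hidden_size
        * hidden_size
        * (
            (1 + num_query_groups / num_attention_heads)
            * query_projection_to_hidden_size_ratio
            + (ffn_hidden_size / hidden_size) * num_experts * gated_linear_multiplier
            + 3 / hidden_size
            + 1
        )
    )
    if not share_mtp_embedding_and_output_weight:
        # Separate embedding and output matrices for the MTP modules.
        params += 2 * num_nextn_predict_layers * hidden_size * padded_vocab_size
    return params

Calling it with a model's hidden size, FFN size, head counts, and padded vocabulary size should reproduce the MTP contribution reported by the verbose print above.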
......@@ -138,7 +138,7 @@ def num_floating_point_operations(args, batch_size):
# - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
expansion_factor = 3 * 2 * 2
return (
num_backbone_flops = (
expansion_factor
* batch_size
* args.seq_length
......@@ -167,6 +167,40 @@ def num_floating_point_operations(args, batch_size):
)
)
# mtp flops
num_mtp_flops = (
expansion_factor
* batch_size
* args.seq_length
* args.num_nextn_predict_layers
* args.hidden_size
* args.hidden_size
* (
# Attention.
(
(
1
+ (args.num_query_groups / args.num_attention_heads)
+ (args.seq_length / args.hidden_size)
) * query_projection_to_hidden_size_ratio
)
# MLP.
+ (
(args.ffn_hidden_size / args.hidden_size)
* num_experts_routed_to
* gated_linear_multiplier
)
# Shared Experts.
+ ((shared_expert_ffn_hidden_size / args.hidden_size) * gated_linear_multiplier)
# Logit.
+ (args.padded_vocab_size / (2 * args.hidden_size))
# mtp linear projection
+ 1
)
)
return num_backbone_flops + num_mtp_flops
def get_start_time_from_progress_log():
"""
......
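Similarly, a hypothetical standalone helper (not part of the patch) mirroring the new num_mtp_flops term, with the same default assumptions for the MoE-related factors:

def mtp_flops(
    batch_size,
    seq_length,
    num_nextn_predict_layers,
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    num_query_groups,
    padded_vocab_size,
    num_experts_routed_to=1,
    shared_expert_ffn_hidden_size=0,
    gated_linear_multiplier=1,
    query_projection_to_hidden_size_ratio=1,
):
    # 3x (forward + wgrad + dgrad) * 2x (stacked GEMMs) * 2x (multiply-add per
    # GEMM element): the same expansion_factor used for the backbone term.
    expansion_factor = 3 * 2 * 2
    return (
        expansion_factor
        * batch_size
        * seq_length
        * num_nextn_predict_layers
        * hidden_size
        * hidden_size
        * (
            # Attention.
            (1 + num_query_groups / num_attention_heads + seq_length / hidden_size)
            * query_projection_to_hidden_size_ratio
            # Routed-expert MLP.
            + (ffn_hidden_size / hidden_size) * num_experts_routed_to * gated_linear_multiplier
            # Shared experts.
            + (shared_expert_ffn_hidden_size / hidden_size) * gated_linear_multiplier
            # Output logits of each MTP module.
            + padded_vocab_size / (2 * hidden_size)
            # MTP linear projection.
            + 1
        )
    )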