Commit 65f7a409 authored by dongcl

Dense/MoE parameter count calculation

parent 557b84d8
Pipeline #2494 passed
@@ -18,7 +18,38 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
     # MoE.
     num_experts = 1 if args.num_experts is None else args.num_experts
     gated_linear_multiplier = 3 / 2 if args.swiglu else 1
+    # MoE layer pattern.
+    # Determine which layers use a MoE MLP (1) and which use a dense MLP (0).
+    if isinstance(args.moe_layer_freq, int):
+        moe_layer_pattern = [
+            1 if (i % args.moe_layer_freq == 0) else 0 for i in range(args.num_layers)
+        ]
+    elif isinstance(args.moe_layer_freq, list):
+        moe_layer_pattern = args.moe_layer_freq
+        assert len(moe_layer_pattern) == args.num_layers, (
+            f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
+            f"expected {args.num_layers}, "
+            f"current moe layer pattern: {args.moe_layer_freq}"
+        )
+    else:
+        raise ValueError(
+            f"Invalid moe_layer_freq: {type(args.moe_layer_freq)}, {args.moe_layer_freq}"
+        )
+    # MLP.
+    num_parameters_in_mlps = 0
+    num_parameters_in_dense_mlp = 2 * args.ffn_hidden_size * args.hidden_size * gated_linear_multiplier
+    num_parameters_in_moe_mlp = (
+        # routed experts.
+        (2 * (args.moe_ffn_hidden_size * args.hidden_size) * num_experts * gated_linear_multiplier)
+        # router.
+        + args.hidden_size * num_experts
+        # shared experts.
+        + (2 * args.moe_shared_expert_intermediate_size * args.hidden_size)
+    )
+    for pattern in moe_layer_pattern:
+        num_parameters_in_mlps += num_parameters_in_dense_mlp if pattern == 0 else num_parameters_in_moe_mlp
     if not args.multi_latent_attention:
         num_parameters_in_transformer_block = (
             2
@@ -30,8 +61,6 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
                     (1 + (args.num_query_groups / args.num_attention_heads))
                     * query_projection_to_hidden_size_ratio
                 )
-                # MLP.
-                + ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier)
                 # Transformer layernorms.
                 + (2 / args.hidden_size)
             )
@@ -57,18 +86,15 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
                 # linear_proj
                 + query_projection_size * args.hidden_size
             )
-            # routed experts.
-            + (2 * (args.ffn_hidden_size * args.hidden_size) * num_experts * gated_linear_multiplier)
-            # router
-            + args.hidden_size * num_experts
-            # shared experts.
-            + (2 * args.moe_shared_expert_intermediate_size * args.hidden_size)
             # Transformer layernorms.
             + (4 * args.hidden_size)
         )
     num_parameters_in_transformer_layers = (
+        # attention + layernorm
         args.num_layers * num_parameters_in_transformer_block
+        # mlp
+        + num_parameters_in_mlps
         # Final layernorm.
         + (2 * args.hidden_size)
     )
...
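
For reference, a minimal standalone sketch of the moe_layer_freq handling added above. The args object here is a hypothetical stand-in (a SimpleNamespace) for the parsed Megatron arguments, and build_moe_layer_pattern is an illustrative helper name, not code from this repo: an integer frequency marks every N-th layer as MoE, while a list is used verbatim as a per-layer 0/1 mask.

from types import SimpleNamespace


def build_moe_layer_pattern(args):
    # Integer frequency: every args.moe_layer_freq-th layer (starting at layer 0) is MoE.
    if isinstance(args.moe_layer_freq, int):
        return [1 if (i % args.moe_layer_freq == 0) else 0 for i in range(args.num_layers)]
    # Explicit per-layer mask: 1 = MoE MLP, 0 = dense MLP.
    if isinstance(args.moe_layer_freq, list):
        assert len(args.moe_layer_freq) == args.num_layers
        return args.moe_layer_freq
    raise ValueError(f"Invalid moe_layer_freq: {type(args.moe_layer_freq)}")


args = SimpleNamespace(num_layers=8, moe_layer_freq=2)
print(build_moe_layer_pattern(args))  # [1, 0, 1, 0, 1, 0, 1, 0]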
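
A quick numeric check of the per-layer MLP terms, using illustrative sizes that are assumptions for this example only (not taken from any particular model config). With swiglu enabled, the gated_linear_multiplier of 3/2 accounts for the extra gate projection.

# Worked example of the dense vs. MoE per-layer MLP parameter counts.
hidden_size = 4096
ffn_hidden_size = 11008                     # dense MLP intermediate size (assumed)
moe_ffn_hidden_size = 1408                  # per-expert intermediate size (assumed)
moe_shared_expert_intermediate_size = 2816  # shared-expert intermediate size (assumed)
num_experts = 64
gated_linear_multiplier = 3 / 2             # swiglu: gate + up + down projections

# Dense layer: up + down projections, scaled for the gated activation.
dense_mlp = 2 * ffn_hidden_size * hidden_size * gated_linear_multiplier

# MoE layer: routed experts + router + shared experts.
moe_mlp = (
    2 * moe_ffn_hidden_size * hidden_size * num_experts * gated_linear_multiplier
    + hidden_size * num_experts
    + 2 * moe_shared_expert_intermediate_size * hidden_size
)

print(f"dense MLP params per layer: {dense_mlp / 1e6:.1f}M")  # ~135.3M
print(f"MoE MLP params per layer:   {moe_mlp / 1e6:.1f}M")    # ~1130.6M

With moe_layer_freq=2 and num_layers=8, four layers contribute the dense count and four the MoE count, which is what the loop over moe_layer_pattern accumulates into num_parameters_in_mlps.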