Commit f967a24c authored by dongcl

MLA parameter count calculation

parent 7b78db5d
Pipeline #2491 passed
@@ -18,9 +18,10 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
     # MoE.
     num_experts = 1 if args.num_experts is None else args.num_experts
     gated_linear_multiplier = 3 / 2 if args.swiglu else 1
-    num_parameters_in_transformer_layers = (
+    # MLA
+    if not args.multi_latent_attention:
+        num_parameters_in_transformer_block = (
         2
-        * args.num_layers
         * args.hidden_size
         * args.hidden_size
         * (
@@ -33,10 +34,43 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
             + ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier)
             # Transformer layernorms.
             + (2 / args.hidden_size)
-            # Final layernorm.
-            + (1 / (args.num_layers * args.hidden_size))
         )
     )
+    else:
+        q_head_dim = args.qk_head_dim + args.qk_pos_emb_head_dim
+        query_projection_size = args.v_head_dim * args.num_attention_heads
+        num_parameters_in_transformer_block = (
+            # Attention.
+            (
+                # q_down
+                args.hidden_size * args.q_lora_rank
+                # q_up
+                + args.q_lora_rank * (args.num_attention_heads * q_head_dim)
+                # kv_down
+                + args.hidden_size * (args.kv_lora_rank + args.qk_pos_emb_head_dim)
+                # kv_up
+                + args.kv_lora_rank * (args.num_attention_heads * (args.qk_head_dim + args.v_head_dim))
+                # q_layernorm
+                + 2 * args.q_lora_rank
+                # kv_layernorm
+                + 2 * args.kv_lora_rank
+                # linear_proj
+                + query_projection_size * args.hidden_size
+            )
+            # routed experts.
+            + (2 * (args.ffn_hidden_size * args.hidden_size) * num_experts * gated_linear_multiplier)
+            # shared experts.
+            + (2 * args.moe_shared_expert_intermediate_size * args.hidden_size)
+            # Transformer layernorms.
+            + (4 * args.hidden_size)
+        )
+    num_parameters_in_transformer_layers = (
+        args.num_layers * num_parameters_in_transformer_block
+        # Final layernorm.
+        + (2 * args.hidden_size)
+    )
     embedding_size = args.hidden_size * args.padded_vocab_size
     if args.untie_embeddings_and_output_weights:
         num_parameters_in_embedding_layers = 2 * embedding_size
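For a quick sanity check of the MLA branch introduced above, here is a minimal sketch that evaluates the same per-block formula outside the script. The SimpleNamespace stands in for the real argument object, and every hyperparameter value below is an illustrative assumption, not a config taken from this repository:

# Minimal sketch of the MLA per-block parameter count from the hunk above.
# All hyperparameter values are illustrative assumptions.
from types import SimpleNamespace

args = SimpleNamespace(
    hidden_size=2048,
    num_attention_heads=16,
    q_lora_rank=1536,
    kv_lora_rank=512,
    qk_head_dim=128,
    qk_pos_emb_head_dim=64,
    v_head_dim=128,
    ffn_hidden_size=1408,
    moe_shared_expert_intermediate_size=2816,
    num_experts=64,
    swiglu=True,
)

num_experts = 1 if args.num_experts is None else args.num_experts
gated_linear_multiplier = 3 / 2 if args.swiglu else 1

q_head_dim = args.qk_head_dim + args.qk_pos_emb_head_dim            # per-head query dim incl. RoPE part
query_projection_size = args.v_head_dim * args.num_attention_heads  # attention output size before linear_proj

attention_params = (
    args.hidden_size * args.q_lora_rank                                   # q_down
    + args.q_lora_rank * (args.num_attention_heads * q_head_dim)          # q_up
    + args.hidden_size * (args.kv_lora_rank + args.qk_pos_emb_head_dim)   # kv_down
    + args.kv_lora_rank
    * (args.num_attention_heads * (args.qk_head_dim + args.v_head_dim))   # kv_up
    + 2 * args.q_lora_rank                                                # q_layernorm
    + 2 * args.kv_lora_rank                                               # kv_layernorm
    + query_projection_size * args.hidden_size                            # linear_proj
)

block_params = (
    attention_params
    # routed experts
    + 2 * (args.ffn_hidden_size * args.hidden_size) * num_experts * gated_linear_multiplier
    # shared experts
    + 2 * args.moe_shared_expert_intermediate_size * args.hidden_size
    # transformer layernorms
    + 4 * args.hidden_size
)

print(f"MLA attention parameters per block: {attention_params / 1e6:.1f}M")
print(f"Total parameters per transformer block: {block_params / 1e6:.1f}M")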
@@ -45,22 +79,14 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
     # mtp
     num_parameters_in_mtp_layers = (
-        2
-        * args.num_nextn_predict_layers
-        * args.hidden_size
-        * args.hidden_size
+        args.num_nextn_predict_layers
         * (
-            # Attention.
-            (
-                (1 + (args.num_query_groups / args.num_attention_heads))
-                * query_projection_to_hidden_size_ratio
-            )
-            # MLP.
-            + ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier)
+            # transformer block.
+            num_parameters_in_transformer_block
             # layernorms.
-            + (3 / args.hidden_size)
+            + (6 * args.hidden_size)
             # linear projection.
-            + 1
+            + 2 * args.hidden_size * args.hidden_size
         )
     )
...
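The MTP hunk reuses num_parameters_in_transformer_block. Continuing the sketch above (it relies on args and block_params from that snippet; num_nextn_predict_layers = 1 is another illustrative assumption, and reading the 2 * hidden_size * hidden_size term as the projection that maps the concatenated embedding and hidden state back to hidden_size is an interpretation, not something stated in the diff):

# Continuation of the sketch: MTP layer parameter count after this change.
num_nextn_predict_layers = 1          # illustrative assumption
h = args.hidden_size

mtp_params = num_nextn_predict_layers * (
    block_params          # one full transformer block per MTP layer
    + 6 * h               # layernorms, as counted in the hunk above
    + 2 * h * h           # linear projection (2h -> h)
)
print(f"MTP parameters: {mtp_params / 1e6:.1f}M")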