Commit f967a24c authored by dongcl

MLA parameter count calculation

parent 7b78db5d
Pipeline #2491 passed
@@ -18,9 +18,10 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
# MoE.
num_experts = 1 if args.num_experts is None else args.num_experts
gated_linear_multiplier = 3 / 2 if args.swiglu else 1
num_parameters_in_transformer_layers = (
# MLA
if not args.multi_latent_attention:
num_parameters_in_transformer_block = (
2
* args.num_layers
* args.hidden_size
* args.hidden_size
* (
@@ -33,10 +34,43 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
+ ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier)
# Transformer layernorms.
+ (2 / args.hidden_size)
# Final layernorm.
+ (1 / (args.num_layers * args.hidden_size))
)
)
else:
q_head_dim = args.qk_head_dim + args.qk_pos_emb_head_dim
query_projection_size = args.v_head_dim * args.num_attention_heads
num_parameters_in_transformer_block = (
# Attention.
(
# q_down
args.hidden_size * args.q_lora_rank
# q_up
+ args.q_lora_rank * (args.num_attention_heads * q_head_dim)
# kv_down
+ args.hidden_size * (args.kv_lora_rank + args.qk_pos_emb_head_dim)
# kv_up
+ args.kv_lora_rank * (args.num_attention_heads * (args.qk_head_dim + args.v_head_dim))
# q_layernorm
+ 2 * args.q_lora_rank
# kv_layernorm
+ 2 * args.kv_lora_rank
# linear_proj
+ query_projection_size * args.hidden_size
)
# routed experts.
+ (2 * (args.ffn_hidden_size * args.hidden_size) * num_experts * gated_linear_multiplier)
# shared experts.
+ (2 * args.moe_shared_expert_intermediate_size * args.hidden_size)
# Transformer layernorms.
+ (4 * args.hidden_size)
)
num_parameters_in_transformer_layers = (
args.num_layers * num_parameters_in_transformer_block
# Final layernorm.
+ (2 * args.hidden_size)
)
embedding_size = args.hidden_size * args.padded_vocab_size
if args.untie_embeddings_and_output_weights:
num_parameters_in_embedding_layers = 2 * embedding_size
@@ -45,22 +79,14 @@ def compute_weight_and_optimizer_memory(args, verbose=False):
# mtp
num_parameters_in_mtp_layers = (
2
* args.num_nextn_predict_layers
* args.hidden_size
* args.hidden_size
args.num_nextn_predict_layers
* (
# Attention.
(
(1 + (args.num_query_groups / args.num_attention_heads))
* query_projection_to_hidden_size_ratio
)
# MLP.
+ ((args.ffn_hidden_size / args.hidden_size) * num_experts * gated_linear_multiplier)
# transformer block.
num_parameters_in_transformer_block
# layernorms.
+ (3 / args.hidden_size)
+ (6 * args.hidden_size)
# linear projection.
+ 1
+ 2 * args.hidden_size * args.hidden_size
)
)
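
For reference (not part of the commit), the snippet below recomputes the per-layer MLA attention parameter count using the same terms the diff introduces (q_down, q_up, kv_down, kv_up, the two low-rank layernorms, and the output projection). All dimension values are hypothetical assumptions, roughly DeepSeek-V2 scale, chosen only to make the arithmetic concrete.

# Illustrative sketch, not part of the commit: per-layer MLA attention
# parameter count, mirroring the terms added in the diff above.
# All dimension values below are hypothetical example assumptions.
hidden_size = 5120
num_attention_heads = 128
q_lora_rank = 1536
kv_lora_rank = 512
qk_head_dim = 128            # non-RoPE portion of each query/key head
qk_pos_emb_head_dim = 64     # RoPE portion of each query/key head
v_head_dim = 128

q_head_dim = qk_head_dim + qk_pos_emb_head_dim
query_projection_size = v_head_dim * num_attention_heads

mla_attention_parameters = (
    hidden_size * q_lora_rank                                             # q_down
    + q_lora_rank * (num_attention_heads * q_head_dim)                    # q_up
    + hidden_size * (kv_lora_rank + qk_pos_emb_head_dim)                  # kv_down
    + kv_lora_rank * (num_attention_heads * (qk_head_dim + v_head_dim))   # kv_up
    + 2 * q_lora_rank                                                     # q_layernorm
    + 2 * kv_lora_rank                                                    # kv_layernorm
    + query_projection_size * hidden_size                                 # linear_proj
)
print(f"MLA attention parameters per layer: {mla_attention_parameters:,}")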