"""Generate MLA paged-attention (pa) table entries.

Each entry is a tuple ``(B, SEQ, QH, KVH, QKD, VD)`` printed on its own line
followed by a trailing comma, so the output can be pasted into a table literal.

Parameter meaning:
    QH  = num_attention_heads
    KVH = 1
    QKD = kv_lora_rank + qk_rope_head_dim
    VD  = kv_lora_rank

    deepseek v2: QH=16,  QKD=512+64=576, VD=512
    deepseek v3: QH=128, QKD=512+64=576, VD=512

Usage:
    python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 1
    python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 2
    python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 4
    python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 8
    python3 gen_mla_pa_tables.py --QH 128 --KVH 1 --QKD 576 --VD 512 --TP 32
"""

import argparse
from typing import List, Optional, Tuple

# Sequence lengths below this value are sampled with the finer step, lengths
# at or above it with the coarser step.
_SEQ_STEP_SWITCH = 5000

# One table row: (B, SEQ, QH, KVH, QKD, VD).
PaEntry = Tuple[int, int, int, int, int, int]


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        Namespace with integer attributes SEQSTART, SEQEND, QH, KVH, QKD,
        VD and TP.
    """
    parser = argparse.ArgumentParser(description="Generate mla pa tables (B, SEQ, QH, KVH, QKD, VD).")
    parser.add_argument('--SEQSTART', type=int, required=False, default=100, help='Value for SEQSTART')
    parser.add_argument('--SEQEND', type=int, required=False, default=8193, help='Value for SEQEND')
    parser.add_argument('--QH', type=int, required=True, help='Value for QH')
    parser.add_argument('--KVH', type=int, required=True, help='Value for KVH')
    parser.add_argument('--QKD', type=int, required=True, help='Value for QKD')
    parser.add_argument('--VD', type=int, required=True, help='Value for VD')
    parser.add_argument('--TP', type=int, required=True, help='Value for TP')
    return parser.parse_args(argv)


def generate_B() -> List[int]:
    """Return the batch sizes to sweep: 1, 4, then 8..128 in steps of 8.

    A wider sweep (144..256 in steps of 16 and 288..512 in steps of 32) was
    listed in an earlier revision but is currently disabled.
    """
    return [1, 4] + [i * 8 for i in range(1, 17)]


def _generate_seq(seq_start: int, seq_end: int, short_step: int, long_step: int) -> List[int]:
    """Build a SEQ sweep over ``[seq_start, seq_end)`` plus the value 1.

    Values below ``_SEQ_STEP_SWITCH`` advance by ``short_step``; from
    ``_SEQ_STEP_SWITCH`` on, a fresh range advances by ``long_step``.
    """
    # Clamp both ranges to the requested bounds so that a seq_end below the
    # switch point, or a seq_start above it, is honoured instead of ignored.
    fine = range(seq_start, min(_SEQ_STEP_SWITCH, seq_end), short_step)
    coarse = range(max(_SEQ_STEP_SWITCH, seq_start), seq_end, long_step)
    values = list(fine) + list(coarse)
    # SEQ == 1 is always part of the sweep; avoid a duplicate row when the
    # ranges already produced it (e.g. seq_start == 1).
    if 1 not in values:
        values.insert(0, 1)
    return values


def generate_SEQ_bs1_32(seq_start, seq_end):
    """Return the SEQ sweep used for batch sizes up to 32 (steps 300 / 500)."""
    return _generate_seq(seq_start, seq_end, 300, 500)


def generate_SEQ_bs40_96(seq_start, seq_end):
    """Return the denser SEQ sweep used for batch sizes 40..96 (steps 200 / 300)."""
    return _generate_seq(seq_start, seq_end, 200, 300)


def generate_SEQ_other(seq_start, seq_end):
    """Return the SEQ sweep used for batch sizes above 96 (steps 300 / 500)."""
    return _generate_seq(seq_start, seq_end, 300, 500)


def generate_tuples(seq_start, seq_end, QH, KVH, QKD, VD) -> List[PaEntry]:
    """Return every ``(B, SEQ, QH, KVH, QKD, VD)`` combination to tabulate.

    Entries are ordered by batch size (as returned by ``generate_B``) and,
    within a batch size, by the SEQ sweep selected for that batch size.

    Args:
        seq_start: First sequence length of the sweep (besides the fixed 1).
        seq_end: Exclusive upper bound of the sequence-length sweep.
        QH: Query head count placed in every entry.
        KVH: KV head count placed in every entry.
        QKD: Query/key head dim placed in every entry.
        VD: Value head dim placed in every entry.
    """
    seq_small_batch = generate_SEQ_bs1_32(seq_start, seq_end)
    seq_mid_batch = generate_SEQ_bs40_96(seq_start, seq_end)
    seq_large_batch = generate_SEQ_other(seq_start, seq_end)

    pa_tables = []
    for B in generate_B():
        if B <= 32:
            seq_values = seq_small_batch
        elif B <= 96:
            seq_values = seq_mid_batch
        else:
            seq_values = seq_large_batch
        pa_tables.extend((B, SEQ, QH, KVH, QKD, VD) for SEQ in seq_values)
    return pa_tables


def main(argv: Optional[List[str]] = None) -> None:
    """Parse the options, validate them and print one table entry per line."""
    args = parse_args(argv)
    # Explicit checks instead of ``assert``: assertions are removed under
    # ``python -O``, which would let an invalid TP silently truncate QH.
    if args.TP <= 0:
        raise SystemExit("error TP, TP must be a positive integer")
    if args.QH % args.TP != 0:
        raise SystemExit("error QH, QH % TP must be 0")

    tuple_sizes = generate_tuples(
        args.SEQSTART,
        args.SEQEND,
        args.QH // args.TP,
        # KV heads never drop below one per rank.
        max(args.KVH // args.TP, 1),
        args.QKD,
        args.VD,
    )
    for t in tuple_sizes:
        print(f"{t},")


if __name__ == "__main__":
    main()