gen_mla_pa_tables.py 3.26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
'''
Generate mla pa tables as (B, SEQ, QH, KVH, QKD, VD) tuples, printed one per line.

Argument mapping:
    QH  = num_attention_heads
    KVH = 1
    QKD = kv_lora_rank + qk_rope_head_dim
    VD  = kv_lora_rank

Model values:
    deepseek v2: QH=16,  QKD=512+64=576, VD=512
    deepseek v3: QH=128, QKD=512+64=576, VD=512

Example usage (QH and KVH are divided by TP before the tuples are generated):
    python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 1
    python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 2
    python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 4
    python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 8
    python3 gen_mla_pa_tables.py --QH 128 --KVH 1 --QKD 576 --VD 512 --TP 32
'''
import argparse


def parse_args(argv=None):
    """Parse the command-line options of the table generator.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the arguments from sys.argv[1:], which
            is what the script entry point relies on.

    Returns:
        argparse.Namespace with integer attributes SEQSTART, SEQEND, QH,
        KVH, QKD, VD and TP.
    """
    parser = argparse.ArgumentParser(description="Generate mla pa tables (B, SEQ, QH, KVH, QKD, VD).")
    # The sequence-range bounds are optional and fall back to defaults.
    parser.add_argument('--SEQSTART', type=int, required=False, default=100, help='Value for SEQSTART')
    parser.add_argument('--SEQEND', type=int, required=False, default=8193, help='Value for SEQEND')
    # The remaining options share one shape: a required integer.
    for name in ('QH', 'KVH', 'QKD', 'VD', 'TP'):
        parser.add_argument(f'--{name}', type=int, required=True, help=f'Value for {name}')
    return parser.parse_args(argv)


def generate_B():
    """Return the batch sizes to tabulate: 1, 4, then every multiple of 8 up to 128."""
    # Larger batch sizes (144..256 in steps of 16 and 288..512 in steps of 32)
    # are intentionally not generated at the moment.
    multiples_of_eight = list(range(8, 129, 8))
    return [1, 4] + multiples_of_eight


def generate_SEQ_bs1_32(seq_start, seq_end):
    """Return the sequence lengths used for batch sizes up to 32.

    The list always starts with 1, followed by seq_start up to (but not
    including) 5000 in steps of 300, then 5000 up to (but not including)
    seq_end in steps of 500.
    """
    return [
        1,
        *range(seq_start, 5000, 300),  # finer sampling below 5000
        *range(5000, seq_end, 500),    # coarser sampling from 5000 upwards
    ]

def generate_SEQ_bs40_96(seq_start, seq_end):
    """Return the sequence lengths used for batch sizes above 32 and up to 96.

    The list always starts with 1, followed by seq_start up to (but not
    including) 5000 in steps of 200, then 5000 up to (but not including)
    seq_end in steps of 300.
    """
    return [
        1,
        *range(seq_start, 5000, 200),  # finer sampling below 5000
        *range(5000, seq_end, 300),    # coarser sampling from 5000 upwards
    ]

def generate_SEQ_other(seq_start, seq_end):
    """Return the sequence lengths used for batch sizes above 96.

    The list always starts with 1, followed by seq_start up to (but not
    including) 5000 in steps of 300, then 5000 up to (but not including)
    seq_end in steps of 500.
    """
    return [
        1,
        *range(seq_start, 5000, 300),  # finer sampling below 5000
        *range(5000, seq_end, 500),    # coarser sampling from 5000 upwards
    ]


def generate_tuples(seq_start, seq_end, QH, KVH, QKD, VD):
    """Build the list of (B, SEQ, QH, KVH, QKD, VD) table entries.

    Every batch size from generate_B() is paired with a batch-dependent list
    of sequence lengths; QH, KVH, QKD and VD are copied unchanged into each
    tuple.

    Args:
        seq_start: First sequence length of the stepped range.
        seq_end: Exclusive upper bound of the stepped range.
        QH, KVH, QKD, VD: Values repeated in every generated tuple.

    Returns:
        List of 6-tuples, ordered by batch size and then by sequence length.
    """
    batch_sizes = generate_B()
    seqs_small_batch = generate_SEQ_bs1_32(seq_start, seq_end)
    seqs_mid_batch = generate_SEQ_bs40_96(seq_start, seq_end)
    seqs_large_batch = generate_SEQ_other(seq_start, seq_end)

    def seqs_for(batch):
        # Pick the sequence-length list that matches the batch-size bucket.
        if batch <= 32:
            return seqs_small_batch
        if batch <= 96:
            return seqs_mid_batch
        return seqs_large_batch

    return [
        (batch, seq, QH, KVH, QKD, VD)
        for batch in batch_sizes
        for seq in seqs_for(batch)
    ]


if __name__ == "__main__":
    args = parse_args()
    # Validate with explicit checks instead of `assert`: assertions are
    # removed when Python runs with -O, which would let bad input through.
    if args.TP <= 0:
        # Also guards the modulo / floor division below against TP == 0.
        raise SystemExit("error TP, TP must be a positive integer")
    if args.QH % args.TP != 0:
        raise SystemExit("error QH, QH % TP must be 0")

    # Heads are split across TP; KVH is clamped so it never drops below 1.
    qh_per_tp = args.QH // args.TP
    kvh_per_tp = max(args.KVH // args.TP, 1)
    tuple_sizes = generate_tuples(args.SEQSTART, args.SEQEND, qh_per_tp, kvh_per_tp, args.QKD, args.VD)

    # Emit each tuple followed by a comma, one per line.
    for t in tuple_sizes:
        print(f"{t},")