Commit e6fd8fda authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-pa' of http://10.6.10.68/dcutoolkit/deeplearing/vllm into v0.7.2-pa

parents fbde1e5a 9c7191c3
'''
QH=num_attention_heads KVH=1 QKD=kv_lora_rank + qk_rope_head_dim VD=kv_lora_rank
deepseek v2: QH=16, QKD=512+64=576, VD=512
deepseek v3: QH=128, QKD=512+64=576, VD=512
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 1
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 2
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 4
python3 gen_mla_pa_tables.py --QH 16 --KVH 1 --QKD 576 --VD 512 --TP 8
python3 gen_mla_pa_tables.py --QH 128 --KVH 1 --QKD 576 --VD 512 --TP 32
'''
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="Generate mla pa tables (B, SEQ, QH, KVH, QKD, VD).")
parser.add_argument('--SEQSTART', type=int, required=False, default=100, help='Value for SEQSTART')
parser.add_argument('--SEQEND', type=int, required=False, default=8193, help='Value for SEQEND')
parser.add_argument('--QH', type=int, required=True, help='Value for QH')
parser.add_argument('--KVH', type=int, required=True, help='Value for KVH')
parser.add_argument('--QKD', type=int, required=True, help='Value for QKD')
parser.add_argument('--VD', type=int, required=True, help='Value for VD')
parser.add_argument('--TP', type=int, required=True, help='Value for TP')
return parser.parse_args()
def generate_B():
# 1 4 8 16 ... 128,144 160 ... 256, 288 320 ... 512
B_values = [1, 4] + [i * 8 for i in range(1, 17)]
# B_values = [1, 4] + [i * 8 for i in range(1, 17)] + [(i + 1) * 16 for i in range(8, 16)] + [(i + 1) * 32 for i in range(8, 16)]
return B_values
def generate_SEQ_bs1_32(seq_start, seq_end):
SEQ_values_part1 = list(range(seq_start, 5000, 300))
SEQ_values_part2 = list(range(5000, seq_end, 500))
SEQ_values = [1] + SEQ_values_part1 + SEQ_values_part2
return SEQ_values
def generate_SEQ_bs40_96(seq_start, seq_end):
SEQ_values_part1 = list(range(seq_start, 5000, 200))
SEQ_values_part2 = list(range(5000, seq_end, 300))
SEQ_values = [1] + SEQ_values_part1 + SEQ_values_part2
return SEQ_values
def generate_SEQ_other(seq_start, seq_end):
SEQ_values_part1 = list(range(seq_start, 5000, 300))
SEQ_values_part2 = list(range(5000, seq_end, 500))
SEQ_values = [1] + SEQ_values_part1 + SEQ_values_part2
return SEQ_values
def generate_tuples(seq_start, seq_end, QH, KVH, QKD, VD):
B_values = generate_B()
SEQ_values_b1_32 = generate_SEQ_bs1_32(seq_start, seq_end)
SEQ_values_b40_96 = generate_SEQ_bs40_96(seq_start, seq_end)
SEQ_values_other = generate_SEQ_other(seq_start, seq_end)
pa_tables = []
for B in B_values:
if B <= 32:
for SEQ in SEQ_values_b1_32:
pa_tables.append((B, SEQ, QH, KVH, QKD, VD))
elif B <= 96 and B > 32:
for SEQ in SEQ_values_b40_96:
pa_tables.append((B, SEQ, QH, KVH, QKD, VD))
else:
for SEQ in SEQ_values_other:
pa_tables.append((B, SEQ, QH, KVH, QKD, VD))
return pa_tables
if __name__ == "__main__":
args = parse_args()
assert args.QH % args.TP == 0, "error QH, QH % TP must be 0"
tuple_sizes = generate_tuples(args.SEQSTART, args.SEQEND, args.QH // args.TP, max(args.KVH // args.TP, 1), args.QKD, args.VD)
for t in tuple_sizes:
print(f"{t},")
'''
bs1-128, seqlen:1-8k
'''
deepseek_v2_16b_tp1_configs=[
(1, 1, 16, 1, 576, 512),
(1, 100, 16, 1, 576, 512),
(1, 400, 16, 1, 576, 512),
(1, 700, 16, 1, 576, 512),
(1, 1000, 16, 1, 576, 512),
(1, 1300, 16, 1, 576, 512),
(1, 1600, 16, 1, 576, 512),
(1, 1900, 16, 1, 576, 512),
(1, 2200, 16, 1, 576, 512),
(1, 2500, 16, 1, 576, 512),
(1, 2800, 16, 1, 576, 512),
(1, 3100, 16, 1, 576, 512),
(1, 3400, 16, 1, 576, 512),
(1, 3700, 16, 1, 576, 512),
(1, 4000, 16, 1, 576, 512),
(1, 4300, 16, 1, 576, 512),
(1, 4600, 16, 1, 576, 512),
(1, 4900, 16, 1, 576, 512),
(1, 5000, 16, 1, 576, 512),
(1, 5500, 16, 1, 576, 512),
(1, 6000, 16, 1, 576, 512),
(1, 6500, 16, 1, 576, 512),
(1, 7000, 16, 1, 576, 512),
(1, 7500, 16, 1, 576, 512),
(1, 8000, 16, 1, 576, 512),
(4, 1, 16, 1, 576, 512),
(4, 100, 16, 1, 576, 512),
(4, 400, 16, 1, 576, 512),
(4, 700, 16, 1, 576, 512),
(4, 1000, 16, 1, 576, 512),
(4, 1300, 16, 1, 576, 512),
(4, 1600, 16, 1, 576, 512),
(4, 1900, 16, 1, 576, 512),
(4, 2200, 16, 1, 576, 512),
(4, 2500, 16, 1, 576, 512),
(4, 2800, 16, 1, 576, 512),
(4, 3100, 16, 1, 576, 512),
(4, 3400, 16, 1, 576, 512),
(4, 3700, 16, 1, 576, 512),
(4, 4000, 16, 1, 576, 512),
(4, 4300, 16, 1, 576, 512),
(4, 4600, 16, 1, 576, 512),
(4, 4900, 16, 1, 576, 512),
(4, 5000, 16, 1, 576, 512),
(4, 5500, 16, 1, 576, 512),
(4, 6000, 16, 1, 576, 512),
(4, 6500, 16, 1, 576, 512),
(4, 7000, 16, 1, 576, 512),
(4, 7500, 16, 1, 576, 512),
(4, 8000, 16, 1, 576, 512),
(8, 1, 16, 1, 576, 512),
(8, 100, 16, 1, 576, 512),
(8, 400, 16, 1, 576, 512),
(8, 700, 16, 1, 576, 512),
(8, 1000, 16, 1, 576, 512),
(8, 1300, 16, 1, 576, 512),
(8, 1600, 16, 1, 576, 512),
(8, 1900, 16, 1, 576, 512),
(8, 2200, 16, 1, 576, 512),
(8, 2500, 16, 1, 576, 512),
(8, 2800, 16, 1, 576, 512),
(8, 3100, 16, 1, 576, 512),
(8, 3400, 16, 1, 576, 512),
(8, 3700, 16, 1, 576, 512),
(8, 4000, 16, 1, 576, 512),
(8, 4300, 16, 1, 576, 512),
(8, 4600, 16, 1, 576, 512),
(8, 4900, 16, 1, 576, 512),
(8, 5000, 16, 1, 576, 512),
(8, 5500, 16, 1, 576, 512),
(8, 6000, 16, 1, 576, 512),
(8, 6500, 16, 1, 576, 512),
(8, 7000, 16, 1, 576, 512),
(8, 7500, 16, 1, 576, 512),
(8, 8000, 16, 1, 576, 512),
(16, 1, 16, 1, 576, 512),
(16, 100, 16, 1, 576, 512),
(16, 400, 16, 1, 576, 512),
(16, 700, 16, 1, 576, 512),
(16, 1000, 16, 1, 576, 512),
(16, 1300, 16, 1, 576, 512),
(16, 1600, 16, 1, 576, 512),
(16, 1900, 16, 1, 576, 512),
(16, 2200, 16, 1, 576, 512),
(16, 2500, 16, 1, 576, 512),
(16, 2800, 16, 1, 576, 512),
(16, 3100, 16, 1, 576, 512),
(16, 3400, 16, 1, 576, 512),
(16, 3700, 16, 1, 576, 512),
(16, 4000, 16, 1, 576, 512),
(16, 4300, 16, 1, 576, 512),
(16, 4600, 16, 1, 576, 512),
(16, 4900, 16, 1, 576, 512),
(16, 5000, 16, 1, 576, 512),
(16, 5500, 16, 1, 576, 512),
(16, 6000, 16, 1, 576, 512),
(16, 6500, 16, 1, 576, 512),
(16, 7000, 16, 1, 576, 512),
(16, 7500, 16, 1, 576, 512),
(16, 8000, 16, 1, 576, 512),
(24, 1, 16, 1, 576, 512),
(24, 100, 16, 1, 576, 512),
(24, 400, 16, 1, 576, 512),
(24, 700, 16, 1, 576, 512),
(24, 1000, 16, 1, 576, 512),
(24, 1300, 16, 1, 576, 512),
(24, 1600, 16, 1, 576, 512),
(24, 1900, 16, 1, 576, 512),
(24, 2200, 16, 1, 576, 512),
(24, 2500, 16, 1, 576, 512),
(24, 2800, 16, 1, 576, 512),
(24, 3100, 16, 1, 576, 512),
(24, 3400, 16, 1, 576, 512),
(24, 3700, 16, 1, 576, 512),
(24, 4000, 16, 1, 576, 512),
(24, 4300, 16, 1, 576, 512),
(24, 4600, 16, 1, 576, 512),
(24, 4900, 16, 1, 576, 512),
(24, 5000, 16, 1, 576, 512),
(24, 5500, 16, 1, 576, 512),
(24, 6000, 16, 1, 576, 512),
(24, 6500, 16, 1, 576, 512),
(24, 7000, 16, 1, 576, 512),
(24, 7500, 16, 1, 576, 512),
(24, 8000, 16, 1, 576, 512),
(32, 1, 16, 1, 576, 512),
(32, 100, 16, 1, 576, 512),
(32, 400, 16, 1, 576, 512),
(32, 700, 16, 1, 576, 512),
(32, 1000, 16, 1, 576, 512),
(32, 1300, 16, 1, 576, 512),
(32, 1600, 16, 1, 576, 512),
(32, 1900, 16, 1, 576, 512),
(32, 2200, 16, 1, 576, 512),
(32, 2500, 16, 1, 576, 512),
(32, 2800, 16, 1, 576, 512),
(32, 3100, 16, 1, 576, 512),
(32, 3400, 16, 1, 576, 512),
(32, 3700, 16, 1, 576, 512),
(32, 4000, 16, 1, 576, 512),
(32, 4300, 16, 1, 576, 512),
(32, 4600, 16, 1, 576, 512),
(32, 4900, 16, 1, 576, 512),
(32, 5000, 16, 1, 576, 512),
(32, 5500, 16, 1, 576, 512),
(32, 6000, 16, 1, 576, 512),
(32, 6500, 16, 1, 576, 512),
(32, 7000, 16, 1, 576, 512),
(32, 7500, 16, 1, 576, 512),
(32, 8000, 16, 1, 576, 512),
(40, 1, 16, 1, 576, 512),
(40, 100, 16, 1, 576, 512),
(40, 300, 16, 1, 576, 512),
(40, 500, 16, 1, 576, 512),
(40, 700, 16, 1, 576, 512),
(40, 900, 16, 1, 576, 512),
(40, 1100, 16, 1, 576, 512),
(40, 1300, 16, 1, 576, 512),
(40, 1500, 16, 1, 576, 512),
(40, 1700, 16, 1, 576, 512),
(40, 1900, 16, 1, 576, 512),
(40, 2100, 16, 1, 576, 512),
(40, 2300, 16, 1, 576, 512),
(40, 2500, 16, 1, 576, 512),
(40, 2700, 16, 1, 576, 512),
(40, 2900, 16, 1, 576, 512),
(40, 3100, 16, 1, 576, 512),
(40, 3300, 16, 1, 576, 512),
(40, 3500, 16, 1, 576, 512),
(40, 3700, 16, 1, 576, 512),
(40, 3900, 16, 1, 576, 512),
(40, 4100, 16, 1, 576, 512),
(40, 4300, 16, 1, 576, 512),
(40, 4500, 16, 1, 576, 512),
(40, 4700, 16, 1, 576, 512),
(40, 4900, 16, 1, 576, 512),
(40, 5000, 16, 1, 576, 512),
(40, 5300, 16, 1, 576, 512),
(40, 5600, 16, 1, 576, 512),
(40, 5900, 16, 1, 576, 512),
(40, 6200, 16, 1, 576, 512),
(40, 6500, 16, 1, 576, 512),
(40, 6800, 16, 1, 576, 512),
(40, 7100, 16, 1, 576, 512),
(40, 7400, 16, 1, 576, 512),
(40, 7700, 16, 1, 576, 512),
(40, 8000, 16, 1, 576, 512),
(48, 1, 16, 1, 576, 512),
(48, 100, 16, 1, 576, 512),
(48, 300, 16, 1, 576, 512),
(48, 500, 16, 1, 576, 512),
(48, 700, 16, 1, 576, 512),
(48, 900, 16, 1, 576, 512),
(48, 1100, 16, 1, 576, 512),
(48, 1300, 16, 1, 576, 512),
(48, 1500, 16, 1, 576, 512),
(48, 1700, 16, 1, 576, 512),
(48, 1900, 16, 1, 576, 512),
(48, 2100, 16, 1, 576, 512),
(48, 2300, 16, 1, 576, 512),
(48, 2500, 16, 1, 576, 512),
(48, 2700, 16, 1, 576, 512),
(48, 2900, 16, 1, 576, 512),
(48, 3100, 16, 1, 576, 512),
(48, 3300, 16, 1, 576, 512),
(48, 3500, 16, 1, 576, 512),
(48, 3700, 16, 1, 576, 512),
(48, 3900, 16, 1, 576, 512),
(48, 4100, 16, 1, 576, 512),
(48, 4300, 16, 1, 576, 512),
(48, 4500, 16, 1, 576, 512),
(48, 4700, 16, 1, 576, 512),
(48, 4900, 16, 1, 576, 512),
(48, 5000, 16, 1, 576, 512),
(48, 5300, 16, 1, 576, 512),
(48, 5600, 16, 1, 576, 512),
(48, 5900, 16, 1, 576, 512),
(48, 6200, 16, 1, 576, 512),
(48, 6500, 16, 1, 576, 512),
(48, 6800, 16, 1, 576, 512),
(48, 7100, 16, 1, 576, 512),
(48, 7400, 16, 1, 576, 512),
(48, 7700, 16, 1, 576, 512),
(48, 8000, 16, 1, 576, 512),
(56, 1, 16, 1, 576, 512),
(56, 100, 16, 1, 576, 512),
(56, 300, 16, 1, 576, 512),
(56, 500, 16, 1, 576, 512),
(56, 700, 16, 1, 576, 512),
(56, 900, 16, 1, 576, 512),
(56, 1100, 16, 1, 576, 512),
(56, 1300, 16, 1, 576, 512),
(56, 1500, 16, 1, 576, 512),
(56, 1700, 16, 1, 576, 512),
(56, 1900, 16, 1, 576, 512),
(56, 2100, 16, 1, 576, 512),
(56, 2300, 16, 1, 576, 512),
(56, 2500, 16, 1, 576, 512),
(56, 2700, 16, 1, 576, 512),
(56, 2900, 16, 1, 576, 512),
(56, 3100, 16, 1, 576, 512),
(56, 3300, 16, 1, 576, 512),
(56, 3500, 16, 1, 576, 512),
(56, 3700, 16, 1, 576, 512),
(56, 3900, 16, 1, 576, 512),
(56, 4100, 16, 1, 576, 512),
(56, 4300, 16, 1, 576, 512),
(56, 4500, 16, 1, 576, 512),
(56, 4700, 16, 1, 576, 512),
(56, 4900, 16, 1, 576, 512),
(56, 5000, 16, 1, 576, 512),
(56, 5300, 16, 1, 576, 512),
(56, 5600, 16, 1, 576, 512),
(56, 5900, 16, 1, 576, 512),
(56, 6200, 16, 1, 576, 512),
(56, 6500, 16, 1, 576, 512),
(56, 6800, 16, 1, 576, 512),
(56, 7100, 16, 1, 576, 512),
(56, 7400, 16, 1, 576, 512),
(56, 7700, 16, 1, 576, 512),
(56, 8000, 16, 1, 576, 512),
(64, 1, 16, 1, 576, 512),
(64, 100, 16, 1, 576, 512),
(64, 300, 16, 1, 576, 512),
(64, 500, 16, 1, 576, 512),
(64, 700, 16, 1, 576, 512),
(64, 900, 16, 1, 576, 512),
(64, 1100, 16, 1, 576, 512),
(64, 1300, 16, 1, 576, 512),
(64, 1500, 16, 1, 576, 512),
(64, 1700, 16, 1, 576, 512),
(64, 1900, 16, 1, 576, 512),
(64, 2100, 16, 1, 576, 512),
(64, 2300, 16, 1, 576, 512),
(64, 2500, 16, 1, 576, 512),
(64, 2700, 16, 1, 576, 512),
(64, 2900, 16, 1, 576, 512),
(64, 3100, 16, 1, 576, 512),
(64, 3300, 16, 1, 576, 512),
(64, 3500, 16, 1, 576, 512),
(64, 3700, 16, 1, 576, 512),
(64, 3900, 16, 1, 576, 512),
(64, 4100, 16, 1, 576, 512),
(64, 4300, 16, 1, 576, 512),
(64, 4500, 16, 1, 576, 512),
(64, 4700, 16, 1, 576, 512),
(64, 4900, 16, 1, 576, 512),
(64, 5000, 16, 1, 576, 512),
(64, 5300, 16, 1, 576, 512),
(64, 5600, 16, 1, 576, 512),
(64, 5900, 16, 1, 576, 512),
(64, 6200, 16, 1, 576, 512),
(64, 6500, 16, 1, 576, 512),
(64, 6800, 16, 1, 576, 512),
(64, 7100, 16, 1, 576, 512),
(64, 7400, 16, 1, 576, 512),
(64, 7700, 16, 1, 576, 512),
(64, 8000, 16, 1, 576, 512),
(72, 1, 16, 1, 576, 512),
(72, 100, 16, 1, 576, 512),
(72, 300, 16, 1, 576, 512),
(72, 500, 16, 1, 576, 512),
(72, 700, 16, 1, 576, 512),
(72, 900, 16, 1, 576, 512),
(72, 1100, 16, 1, 576, 512),
(72, 1300, 16, 1, 576, 512),
(72, 1500, 16, 1, 576, 512),
(72, 1700, 16, 1, 576, 512),
(72, 1900, 16, 1, 576, 512),
(72, 2100, 16, 1, 576, 512),
(72, 2300, 16, 1, 576, 512),
(72, 2500, 16, 1, 576, 512),
(72, 2700, 16, 1, 576, 512),
(72, 2900, 16, 1, 576, 512),
(72, 3100, 16, 1, 576, 512),
(72, 3300, 16, 1, 576, 512),
(72, 3500, 16, 1, 576, 512),
(72, 3700, 16, 1, 576, 512),
(72, 3900, 16, 1, 576, 512),
(72, 4100, 16, 1, 576, 512),
(72, 4300, 16, 1, 576, 512),
(72, 4500, 16, 1, 576, 512),
(72, 4700, 16, 1, 576, 512),
(72, 4900, 16, 1, 576, 512),
(72, 5000, 16, 1, 576, 512),
(72, 5300, 16, 1, 576, 512),
(72, 5600, 16, 1, 576, 512),
(72, 5900, 16, 1, 576, 512),
(72, 6200, 16, 1, 576, 512),
(72, 6500, 16, 1, 576, 512),
(72, 6800, 16, 1, 576, 512),
(72, 7100, 16, 1, 576, 512),
(72, 7400, 16, 1, 576, 512),
(72, 7700, 16, 1, 576, 512),
(72, 8000, 16, 1, 576, 512),
(80, 1, 16, 1, 576, 512),
(80, 100, 16, 1, 576, 512),
(80, 300, 16, 1, 576, 512),
(80, 500, 16, 1, 576, 512),
(80, 700, 16, 1, 576, 512),
(80, 900, 16, 1, 576, 512),
(80, 1100, 16, 1, 576, 512),
(80, 1300, 16, 1, 576, 512),
(80, 1500, 16, 1, 576, 512),
(80, 1700, 16, 1, 576, 512),
(80, 1900, 16, 1, 576, 512),
(80, 2100, 16, 1, 576, 512),
(80, 2300, 16, 1, 576, 512),
(80, 2500, 16, 1, 576, 512),
(80, 2700, 16, 1, 576, 512),
(80, 2900, 16, 1, 576, 512),
(80, 3100, 16, 1, 576, 512),
(80, 3300, 16, 1, 576, 512),
(80, 3500, 16, 1, 576, 512),
(80, 3700, 16, 1, 576, 512),
(80, 3900, 16, 1, 576, 512),
(80, 4100, 16, 1, 576, 512),
(80, 4300, 16, 1, 576, 512),
(80, 4500, 16, 1, 576, 512),
(80, 4700, 16, 1, 576, 512),
(80, 4900, 16, 1, 576, 512),
(80, 5000, 16, 1, 576, 512),
(80, 5300, 16, 1, 576, 512),
(80, 5600, 16, 1, 576, 512),
(80, 5900, 16, 1, 576, 512),
(80, 6200, 16, 1, 576, 512),
(80, 6500, 16, 1, 576, 512),
(80, 6800, 16, 1, 576, 512),
(80, 7100, 16, 1, 576, 512),
(80, 7400, 16, 1, 576, 512),
(80, 7700, 16, 1, 576, 512),
(80, 8000, 16, 1, 576, 512),
(88, 1, 16, 1, 576, 512),
(88, 100, 16, 1, 576, 512),
(88, 300, 16, 1, 576, 512),
(88, 500, 16, 1, 576, 512),
(88, 700, 16, 1, 576, 512),
(88, 900, 16, 1, 576, 512),
(88, 1100, 16, 1, 576, 512),
(88, 1300, 16, 1, 576, 512),
(88, 1500, 16, 1, 576, 512),
(88, 1700, 16, 1, 576, 512),
(88, 1900, 16, 1, 576, 512),
(88, 2100, 16, 1, 576, 512),
(88, 2300, 16, 1, 576, 512),
(88, 2500, 16, 1, 576, 512),
(88, 2700, 16, 1, 576, 512),
(88, 2900, 16, 1, 576, 512),
(88, 3100, 16, 1, 576, 512),
(88, 3300, 16, 1, 576, 512),
(88, 3500, 16, 1, 576, 512),
(88, 3700, 16, 1, 576, 512),
(88, 3900, 16, 1, 576, 512),
(88, 4100, 16, 1, 576, 512),
(88, 4300, 16, 1, 576, 512),
(88, 4500, 16, 1, 576, 512),
(88, 4700, 16, 1, 576, 512),
(88, 4900, 16, 1, 576, 512),
(88, 5000, 16, 1, 576, 512),
(88, 5300, 16, 1, 576, 512),
(88, 5600, 16, 1, 576, 512),
(88, 5900, 16, 1, 576, 512),
(88, 6200, 16, 1, 576, 512),
(88, 6500, 16, 1, 576, 512),
(88, 6800, 16, 1, 576, 512),
(88, 7100, 16, 1, 576, 512),
(88, 7400, 16, 1, 576, 512),
(88, 7700, 16, 1, 576, 512),
(88, 8000, 16, 1, 576, 512),
(96, 1, 16, 1, 576, 512),
(96, 100, 16, 1, 576, 512),
(96, 300, 16, 1, 576, 512),
(96, 500, 16, 1, 576, 512),
(96, 700, 16, 1, 576, 512),
(96, 900, 16, 1, 576, 512),
(96, 1100, 16, 1, 576, 512),
(96, 1300, 16, 1, 576, 512),
(96, 1500, 16, 1, 576, 512),
(96, 1700, 16, 1, 576, 512),
(96, 1900, 16, 1, 576, 512),
(96, 2100, 16, 1, 576, 512),
(96, 2300, 16, 1, 576, 512),
(96, 2500, 16, 1, 576, 512),
(96, 2700, 16, 1, 576, 512),
(96, 2900, 16, 1, 576, 512),
(96, 3100, 16, 1, 576, 512),
(96, 3300, 16, 1, 576, 512),
(96, 3500, 16, 1, 576, 512),
(96, 3700, 16, 1, 576, 512),
(96, 3900, 16, 1, 576, 512),
(96, 4100, 16, 1, 576, 512),
(96, 4300, 16, 1, 576, 512),
(96, 4500, 16, 1, 576, 512),
(96, 4700, 16, 1, 576, 512),
(96, 4900, 16, 1, 576, 512),
(96, 5000, 16, 1, 576, 512),
(96, 5300, 16, 1, 576, 512),
(96, 5600, 16, 1, 576, 512),
(96, 5900, 16, 1, 576, 512),
(96, 6200, 16, 1, 576, 512),
(96, 6500, 16, 1, 576, 512),
(96, 6800, 16, 1, 576, 512),
(96, 7100, 16, 1, 576, 512),
(96, 7400, 16, 1, 576, 512),
(96, 7700, 16, 1, 576, 512),
(96, 8000, 16, 1, 576, 512),
(104, 1, 16, 1, 576, 512),
(104, 100, 16, 1, 576, 512),
(104, 400, 16, 1, 576, 512),
(104, 700, 16, 1, 576, 512),
(104, 1000, 16, 1, 576, 512),
(104, 1300, 16, 1, 576, 512),
(104, 1600, 16, 1, 576, 512),
(104, 1900, 16, 1, 576, 512),
(104, 2200, 16, 1, 576, 512),
(104, 2500, 16, 1, 576, 512),
(104, 2800, 16, 1, 576, 512),
(104, 3100, 16, 1, 576, 512),
(104, 3400, 16, 1, 576, 512),
(104, 3700, 16, 1, 576, 512),
(104, 4000, 16, 1, 576, 512),
(104, 4300, 16, 1, 576, 512),
(104, 4600, 16, 1, 576, 512),
(104, 4900, 16, 1, 576, 512),
(104, 5000, 16, 1, 576, 512),
(104, 5500, 16, 1, 576, 512),
(104, 6000, 16, 1, 576, 512),
(104, 6500, 16, 1, 576, 512),
(104, 7000, 16, 1, 576, 512),
(104, 7500, 16, 1, 576, 512),
(104, 8000, 16, 1, 576, 512),
(112, 1, 16, 1, 576, 512),
(112, 100, 16, 1, 576, 512),
(112, 400, 16, 1, 576, 512),
(112, 700, 16, 1, 576, 512),
(112, 1000, 16, 1, 576, 512),
(112, 1300, 16, 1, 576, 512),
(112, 1600, 16, 1, 576, 512),
(112, 1900, 16, 1, 576, 512),
(112, 2200, 16, 1, 576, 512),
(112, 2500, 16, 1, 576, 512),
(112, 2800, 16, 1, 576, 512),
(112, 3100, 16, 1, 576, 512),
(112, 3400, 16, 1, 576, 512),
(112, 3700, 16, 1, 576, 512),
(112, 4000, 16, 1, 576, 512),
(112, 4300, 16, 1, 576, 512),
(112, 4600, 16, 1, 576, 512),
(112, 4900, 16, 1, 576, 512),
(112, 5000, 16, 1, 576, 512),
(112, 5500, 16, 1, 576, 512),
(112, 6000, 16, 1, 576, 512),
(112, 6500, 16, 1, 576, 512),
(112, 7000, 16, 1, 576, 512),
(112, 7500, 16, 1, 576, 512),
(112, 8000, 16, 1, 576, 512),
(120, 1, 16, 1, 576, 512),
(120, 100, 16, 1, 576, 512),
(120, 400, 16, 1, 576, 512),
(120, 700, 16, 1, 576, 512),
(120, 1000, 16, 1, 576, 512),
(120, 1300, 16, 1, 576, 512),
(120, 1600, 16, 1, 576, 512),
(120, 1900, 16, 1, 576, 512),
(120, 2200, 16, 1, 576, 512),
(120, 2500, 16, 1, 576, 512),
(120, 2800, 16, 1, 576, 512),
(120, 3100, 16, 1, 576, 512),
(120, 3400, 16, 1, 576, 512),
(120, 3700, 16, 1, 576, 512),
(120, 4000, 16, 1, 576, 512),
(120, 4300, 16, 1, 576, 512),
(120, 4600, 16, 1, 576, 512),
(120, 4900, 16, 1, 576, 512),
(120, 5000, 16, 1, 576, 512),
(120, 5500, 16, 1, 576, 512),
(120, 6000, 16, 1, 576, 512),
(120, 6500, 16, 1, 576, 512),
(120, 7000, 16, 1, 576, 512),
(120, 7500, 16, 1, 576, 512),
(120, 8000, 16, 1, 576, 512),
(128, 1, 16, 1, 576, 512),
(128, 100, 16, 1, 576, 512),
(128, 400, 16, 1, 576, 512),
(128, 700, 16, 1, 576, 512),
(128, 1000, 16, 1, 576, 512),
(128, 1300, 16, 1, 576, 512),
(128, 1600, 16, 1, 576, 512),
(128, 1900, 16, 1, 576, 512),
(128, 2200, 16, 1, 576, 512),
(128, 2500, 16, 1, 576, 512),
(128, 2800, 16, 1, 576, 512),
(128, 3100, 16, 1, 576, 512),
(128, 3400, 16, 1, 576, 512),
(128, 3700, 16, 1, 576, 512),
(128, 4000, 16, 1, 576, 512),
(128, 4300, 16, 1, 576, 512),
(128, 4600, 16, 1, 576, 512),
(128, 4900, 16, 1, 576, 512),
(128, 5000, 16, 1, 576, 512),
(128, 5500, 16, 1, 576, 512),
(128, 6000, 16, 1, 576, 512),
(128, 6500, 16, 1, 576, 512),
(128, 7000, 16, 1, 576, 512),
(128, 7500, 16, 1, 576, 512),
(128, 8000, 16, 1, 576, 512),
]
deepseek_v2_16b_tp2_configs=[
(1, 1, 8, 1, 576, 512),
(1, 100, 8, 1, 576, 512),
(1, 400, 8, 1, 576, 512),
(1, 700, 8, 1, 576, 512),
(1, 1000, 8, 1, 576, 512),
(1, 1300, 8, 1, 576, 512),
(1, 1600, 8, 1, 576, 512),
(1, 1900, 8, 1, 576, 512),
(1, 2200, 8, 1, 576, 512),
(1, 2500, 8, 1, 576, 512),
(1, 2800, 8, 1, 576, 512),
(1, 3100, 8, 1, 576, 512),
(1, 3400, 8, 1, 576, 512),
(1, 3700, 8, 1, 576, 512),
(1, 4000, 8, 1, 576, 512),
(1, 4300, 8, 1, 576, 512),
(1, 4600, 8, 1, 576, 512),
(1, 4900, 8, 1, 576, 512),
(1, 5000, 8, 1, 576, 512),
(1, 5500, 8, 1, 576, 512),
(1, 6000, 8, 1, 576, 512),
(1, 6500, 8, 1, 576, 512),
(1, 7000, 8, 1, 576, 512),
(1, 7500, 8, 1, 576, 512),
(1, 8000, 8, 1, 576, 512),
(4, 1, 8, 1, 576, 512),
(4, 100, 8, 1, 576, 512),
(4, 400, 8, 1, 576, 512),
(4, 700, 8, 1, 576, 512),
(4, 1000, 8, 1, 576, 512),
(4, 1300, 8, 1, 576, 512),
(4, 1600, 8, 1, 576, 512),
(4, 1900, 8, 1, 576, 512),
(4, 2200, 8, 1, 576, 512),
(4, 2500, 8, 1, 576, 512),
(4, 2800, 8, 1, 576, 512),
(4, 3100, 8, 1, 576, 512),
(4, 3400, 8, 1, 576, 512),
(4, 3700, 8, 1, 576, 512),
(4, 4000, 8, 1, 576, 512),
(4, 4300, 8, 1, 576, 512),
(4, 4600, 8, 1, 576, 512),
(4, 4900, 8, 1, 576, 512),
(4, 5000, 8, 1, 576, 512),
(4, 5500, 8, 1, 576, 512),
(4, 6000, 8, 1, 576, 512),
(4, 6500, 8, 1, 576, 512),
(4, 7000, 8, 1, 576, 512),
(4, 7500, 8, 1, 576, 512),
(4, 8000, 8, 1, 576, 512),
(8, 1, 8, 1, 576, 512),
(8, 100, 8, 1, 576, 512),
(8, 400, 8, 1, 576, 512),
(8, 700, 8, 1, 576, 512),
(8, 1000, 8, 1, 576, 512),
(8, 1300, 8, 1, 576, 512),
(8, 1600, 8, 1, 576, 512),
(8, 1900, 8, 1, 576, 512),
(8, 2200, 8, 1, 576, 512),
(8, 2500, 8, 1, 576, 512),
(8, 2800, 8, 1, 576, 512),
(8, 3100, 8, 1, 576, 512),
(8, 3400, 8, 1, 576, 512),
(8, 3700, 8, 1, 576, 512),
(8, 4000, 8, 1, 576, 512),
(8, 4300, 8, 1, 576, 512),
(8, 4600, 8, 1, 576, 512),
(8, 4900, 8, 1, 576, 512),
(8, 5000, 8, 1, 576, 512),
(8, 5500, 8, 1, 576, 512),
(8, 6000, 8, 1, 576, 512),
(8, 6500, 8, 1, 576, 512),
(8, 7000, 8, 1, 576, 512),
(8, 7500, 8, 1, 576, 512),
(8, 8000, 8, 1, 576, 512),
(16, 1, 8, 1, 576, 512),
(16, 100, 8, 1, 576, 512),
(16, 400, 8, 1, 576, 512),
(16, 700, 8, 1, 576, 512),
(16, 1000, 8, 1, 576, 512),
(16, 1300, 8, 1, 576, 512),
(16, 1600, 8, 1, 576, 512),
(16, 1900, 8, 1, 576, 512),
(16, 2200, 8, 1, 576, 512),
(16, 2500, 8, 1, 576, 512),
(16, 2800, 8, 1, 576, 512),
(16, 3100, 8, 1, 576, 512),
(16, 3400, 8, 1, 576, 512),
(16, 3700, 8, 1, 576, 512),
(16, 4000, 8, 1, 576, 512),
(16, 4300, 8, 1, 576, 512),
(16, 4600, 8, 1, 576, 512),
(16, 4900, 8, 1, 576, 512),
(16, 5000, 8, 1, 576, 512),
(16, 5500, 8, 1, 576, 512),
(16, 6000, 8, 1, 576, 512),
(16, 6500, 8, 1, 576, 512),
(16, 7000, 8, 1, 576, 512),
(16, 7500, 8, 1, 576, 512),
(16, 8000, 8, 1, 576, 512),
(24, 1, 8, 1, 576, 512),
(24, 100, 8, 1, 576, 512),
(24, 400, 8, 1, 576, 512),
(24, 700, 8, 1, 576, 512),
(24, 1000, 8, 1, 576, 512),
(24, 1300, 8, 1, 576, 512),
(24, 1600, 8, 1, 576, 512),
(24, 1900, 8, 1, 576, 512),
(24, 2200, 8, 1, 576, 512),
(24, 2500, 8, 1, 576, 512),
(24, 2800, 8, 1, 576, 512),
(24, 3100, 8, 1, 576, 512),
(24, 3400, 8, 1, 576, 512),
(24, 3700, 8, 1, 576, 512),
(24, 4000, 8, 1, 576, 512),
(24, 4300, 8, 1, 576, 512),
(24, 4600, 8, 1, 576, 512),
(24, 4900, 8, 1, 576, 512),
(24, 5000, 8, 1, 576, 512),
(24, 5500, 8, 1, 576, 512),
(24, 6000, 8, 1, 576, 512),
(24, 6500, 8, 1, 576, 512),
(24, 7000, 8, 1, 576, 512),
(24, 7500, 8, 1, 576, 512),
(24, 8000, 8, 1, 576, 512),
(32, 1, 8, 1, 576, 512),
(32, 100, 8, 1, 576, 512),
(32, 400, 8, 1, 576, 512),
(32, 700, 8, 1, 576, 512),
(32, 1000, 8, 1, 576, 512),
(32, 1300, 8, 1, 576, 512),
(32, 1600, 8, 1, 576, 512),
(32, 1900, 8, 1, 576, 512),
(32, 2200, 8, 1, 576, 512),
(32, 2500, 8, 1, 576, 512),
(32, 2800, 8, 1, 576, 512),
(32, 3100, 8, 1, 576, 512),
(32, 3400, 8, 1, 576, 512),
(32, 3700, 8, 1, 576, 512),
(32, 4000, 8, 1, 576, 512),
(32, 4300, 8, 1, 576, 512),
(32, 4600, 8, 1, 576, 512),
(32, 4900, 8, 1, 576, 512),
(32, 5000, 8, 1, 576, 512),
(32, 5500, 8, 1, 576, 512),
(32, 6000, 8, 1, 576, 512),
(32, 6500, 8, 1, 576, 512),
(32, 7000, 8, 1, 576, 512),
(32, 7500, 8, 1, 576, 512),
(32, 8000, 8, 1, 576, 512),
(40, 1, 8, 1, 576, 512),
(40, 100, 8, 1, 576, 512),
(40, 300, 8, 1, 576, 512),
(40, 500, 8, 1, 576, 512),
(40, 700, 8, 1, 576, 512),
(40, 900, 8, 1, 576, 512),
(40, 1100, 8, 1, 576, 512),
(40, 1300, 8, 1, 576, 512),
(40, 1500, 8, 1, 576, 512),
(40, 1700, 8, 1, 576, 512),
(40, 1900, 8, 1, 576, 512),
(40, 2100, 8, 1, 576, 512),
(40, 2300, 8, 1, 576, 512),
(40, 2500, 8, 1, 576, 512),
(40, 2700, 8, 1, 576, 512),
(40, 2900, 8, 1, 576, 512),
(40, 3100, 8, 1, 576, 512),
(40, 3300, 8, 1, 576, 512),
(40, 3500, 8, 1, 576, 512),
(40, 3700, 8, 1, 576, 512),
(40, 3900, 8, 1, 576, 512),
(40, 4100, 8, 1, 576, 512),
(40, 4300, 8, 1, 576, 512),
(40, 4500, 8, 1, 576, 512),
(40, 4700, 8, 1, 576, 512),
(40, 4900, 8, 1, 576, 512),
(40, 5000, 8, 1, 576, 512),
(40, 5300, 8, 1, 576, 512),
(40, 5600, 8, 1, 576, 512),
(40, 5900, 8, 1, 576, 512),
(40, 6200, 8, 1, 576, 512),
(40, 6500, 8, 1, 576, 512),
(40, 6800, 8, 1, 576, 512),
(40, 7100, 8, 1, 576, 512),
(40, 7400, 8, 1, 576, 512),
(40, 7700, 8, 1, 576, 512),
(40, 8000, 8, 1, 576, 512),
(48, 1, 8, 1, 576, 512),
(48, 100, 8, 1, 576, 512),
(48, 300, 8, 1, 576, 512),
(48, 500, 8, 1, 576, 512),
(48, 700, 8, 1, 576, 512),
(48, 900, 8, 1, 576, 512),
(48, 1100, 8, 1, 576, 512),
(48, 1300, 8, 1, 576, 512),
(48, 1500, 8, 1, 576, 512),
(48, 1700, 8, 1, 576, 512),
(48, 1900, 8, 1, 576, 512),
(48, 2100, 8, 1, 576, 512),
(48, 2300, 8, 1, 576, 512),
(48, 2500, 8, 1, 576, 512),
(48, 2700, 8, 1, 576, 512),
(48, 2900, 8, 1, 576, 512),
(48, 3100, 8, 1, 576, 512),
(48, 3300, 8, 1, 576, 512),
(48, 3500, 8, 1, 576, 512),
(48, 3700, 8, 1, 576, 512),
(48, 3900, 8, 1, 576, 512),
(48, 4100, 8, 1, 576, 512),
(48, 4300, 8, 1, 576, 512),
(48, 4500, 8, 1, 576, 512),
(48, 4700, 8, 1, 576, 512),
(48, 4900, 8, 1, 576, 512),
(48, 5000, 8, 1, 576, 512),
(48, 5300, 8, 1, 576, 512),
(48, 5600, 8, 1, 576, 512),
(48, 5900, 8, 1, 576, 512),
(48, 6200, 8, 1, 576, 512),
(48, 6500, 8, 1, 576, 512),
(48, 6800, 8, 1, 576, 512),
(48, 7100, 8, 1, 576, 512),
(48, 7400, 8, 1, 576, 512),
(48, 7700, 8, 1, 576, 512),
(48, 8000, 8, 1, 576, 512),
(56, 1, 8, 1, 576, 512),
(56, 100, 8, 1, 576, 512),
(56, 300, 8, 1, 576, 512),
(56, 500, 8, 1, 576, 512),
(56, 700, 8, 1, 576, 512),
(56, 900, 8, 1, 576, 512),
(56, 1100, 8, 1, 576, 512),
(56, 1300, 8, 1, 576, 512),
(56, 1500, 8, 1, 576, 512),
(56, 1700, 8, 1, 576, 512),
(56, 1900, 8, 1, 576, 512),
(56, 2100, 8, 1, 576, 512),
(56, 2300, 8, 1, 576, 512),
(56, 2500, 8, 1, 576, 512),
(56, 2700, 8, 1, 576, 512),
(56, 2900, 8, 1, 576, 512),
(56, 3100, 8, 1, 576, 512),
(56, 3300, 8, 1, 576, 512),
(56, 3500, 8, 1, 576, 512),
(56, 3700, 8, 1, 576, 512),
(56, 3900, 8, 1, 576, 512),
(56, 4100, 8, 1, 576, 512),
(56, 4300, 8, 1, 576, 512),
(56, 4500, 8, 1, 576, 512),
(56, 4700, 8, 1, 576, 512),
(56, 4900, 8, 1, 576, 512),
(56, 5000, 8, 1, 576, 512),
(56, 5300, 8, 1, 576, 512),
(56, 5600, 8, 1, 576, 512),
(56, 5900, 8, 1, 576, 512),
(56, 6200, 8, 1, 576, 512),
(56, 6500, 8, 1, 576, 512),
(56, 6800, 8, 1, 576, 512),
(56, 7100, 8, 1, 576, 512),
(56, 7400, 8, 1, 576, 512),
(56, 7700, 8, 1, 576, 512),
(56, 8000, 8, 1, 576, 512),
(64, 1, 8, 1, 576, 512),
(64, 100, 8, 1, 576, 512),
(64, 300, 8, 1, 576, 512),
(64, 500, 8, 1, 576, 512),
(64, 700, 8, 1, 576, 512),
(64, 900, 8, 1, 576, 512),
(64, 1100, 8, 1, 576, 512),
(64, 1300, 8, 1, 576, 512),
(64, 1500, 8, 1, 576, 512),
(64, 1700, 8, 1, 576, 512),
(64, 1900, 8, 1, 576, 512),
(64, 2100, 8, 1, 576, 512),
(64, 2300, 8, 1, 576, 512),
(64, 2500, 8, 1, 576, 512),
(64, 2700, 8, 1, 576, 512),
(64, 2900, 8, 1, 576, 512),
(64, 3100, 8, 1, 576, 512),
(64, 3300, 8, 1, 576, 512),
(64, 3500, 8, 1, 576, 512),
(64, 3700, 8, 1, 576, 512),
(64, 3900, 8, 1, 576, 512),
(64, 4100, 8, 1, 576, 512),
(64, 4300, 8, 1, 576, 512),
(64, 4500, 8, 1, 576, 512),
(64, 4700, 8, 1, 576, 512),
(64, 4900, 8, 1, 576, 512),
(64, 5000, 8, 1, 576, 512),
(64, 5300, 8, 1, 576, 512),
(64, 5600, 8, 1, 576, 512),
(64, 5900, 8, 1, 576, 512),
(64, 6200, 8, 1, 576, 512),
(64, 6500, 8, 1, 576, 512),
(64, 6800, 8, 1, 576, 512),
(64, 7100, 8, 1, 576, 512),
(64, 7400, 8, 1, 576, 512),
(64, 7700, 8, 1, 576, 512),
(64, 8000, 8, 1, 576, 512),
(72, 1, 8, 1, 576, 512),
(72, 100, 8, 1, 576, 512),
(72, 300, 8, 1, 576, 512),
(72, 500, 8, 1, 576, 512),
(72, 700, 8, 1, 576, 512),
(72, 900, 8, 1, 576, 512),
(72, 1100, 8, 1, 576, 512),
(72, 1300, 8, 1, 576, 512),
(72, 1500, 8, 1, 576, 512),
(72, 1700, 8, 1, 576, 512),
(72, 1900, 8, 1, 576, 512),
(72, 2100, 8, 1, 576, 512),
(72, 2300, 8, 1, 576, 512),
(72, 2500, 8, 1, 576, 512),
(72, 2700, 8, 1, 576, 512),
(72, 2900, 8, 1, 576, 512),
(72, 3100, 8, 1, 576, 512),
(72, 3300, 8, 1, 576, 512),
(72, 3500, 8, 1, 576, 512),
(72, 3700, 8, 1, 576, 512),
(72, 3900, 8, 1, 576, 512),
(72, 4100, 8, 1, 576, 512),
(72, 4300, 8, 1, 576, 512),
(72, 4500, 8, 1, 576, 512),
(72, 4700, 8, 1, 576, 512),
(72, 4900, 8, 1, 576, 512),
(72, 5000, 8, 1, 576, 512),
(72, 5300, 8, 1, 576, 512),
(72, 5600, 8, 1, 576, 512),
(72, 5900, 8, 1, 576, 512),
(72, 6200, 8, 1, 576, 512),
(72, 6500, 8, 1, 576, 512),
(72, 6800, 8, 1, 576, 512),
(72, 7100, 8, 1, 576, 512),
(72, 7400, 8, 1, 576, 512),
(72, 7700, 8, 1, 576, 512),
(72, 8000, 8, 1, 576, 512),
(80, 1, 8, 1, 576, 512),
(80, 100, 8, 1, 576, 512),
(80, 300, 8, 1, 576, 512),
(80, 500, 8, 1, 576, 512),
(80, 700, 8, 1, 576, 512),
(80, 900, 8, 1, 576, 512),
(80, 1100, 8, 1, 576, 512),
(80, 1300, 8, 1, 576, 512),
(80, 1500, 8, 1, 576, 512),
(80, 1700, 8, 1, 576, 512),
(80, 1900, 8, 1, 576, 512),
(80, 2100, 8, 1, 576, 512),
(80, 2300, 8, 1, 576, 512),
(80, 2500, 8, 1, 576, 512),
(80, 2700, 8, 1, 576, 512),
(80, 2900, 8, 1, 576, 512),
(80, 3100, 8, 1, 576, 512),
(80, 3300, 8, 1, 576, 512),
(80, 3500, 8, 1, 576, 512),
(80, 3700, 8, 1, 576, 512),
(80, 3900, 8, 1, 576, 512),
(80, 4100, 8, 1, 576, 512),
(80, 4300, 8, 1, 576, 512),
(80, 4500, 8, 1, 576, 512),
(80, 4700, 8, 1, 576, 512),
(80, 4900, 8, 1, 576, 512),
(80, 5000, 8, 1, 576, 512),
(80, 5300, 8, 1, 576, 512),
(80, 5600, 8, 1, 576, 512),
(80, 5900, 8, 1, 576, 512),
(80, 6200, 8, 1, 576, 512),
(80, 6500, 8, 1, 576, 512),
(80, 6800, 8, 1, 576, 512),
(80, 7100, 8, 1, 576, 512),
(80, 7400, 8, 1, 576, 512),
(80, 7700, 8, 1, 576, 512),
(80, 8000, 8, 1, 576, 512),
(88, 1, 8, 1, 576, 512),
(88, 100, 8, 1, 576, 512),
(88, 300, 8, 1, 576, 512),
(88, 500, 8, 1, 576, 512),
(88, 700, 8, 1, 576, 512),
(88, 900, 8, 1, 576, 512),
(88, 1100, 8, 1, 576, 512),
(88, 1300, 8, 1, 576, 512),
(88, 1500, 8, 1, 576, 512),
(88, 1700, 8, 1, 576, 512),
(88, 1900, 8, 1, 576, 512),
(88, 2100, 8, 1, 576, 512),
(88, 2300, 8, 1, 576, 512),
(88, 2500, 8, 1, 576, 512),
(88, 2700, 8, 1, 576, 512),
(88, 2900, 8, 1, 576, 512),
(88, 3100, 8, 1, 576, 512),
(88, 3300, 8, 1, 576, 512),
(88, 3500, 8, 1, 576, 512),
(88, 3700, 8, 1, 576, 512),
(88, 3900, 8, 1, 576, 512),
(88, 4100, 8, 1, 576, 512),
(88, 4300, 8, 1, 576, 512),
(88, 4500, 8, 1, 576, 512),
(88, 4700, 8, 1, 576, 512),
(88, 4900, 8, 1, 576, 512),
(88, 5000, 8, 1, 576, 512),
(88, 5300, 8, 1, 576, 512),
(88, 5600, 8, 1, 576, 512),
(88, 5900, 8, 1, 576, 512),
(88, 6200, 8, 1, 576, 512),
(88, 6500, 8, 1, 576, 512),
(88, 6800, 8, 1, 576, 512),
(88, 7100, 8, 1, 576, 512),
(88, 7400, 8, 1, 576, 512),
(88, 7700, 8, 1, 576, 512),
(88, 8000, 8, 1, 576, 512),
(96, 1, 8, 1, 576, 512),
(96, 100, 8, 1, 576, 512),
(96, 300, 8, 1, 576, 512),
(96, 500, 8, 1, 576, 512),
(96, 700, 8, 1, 576, 512),
(96, 900, 8, 1, 576, 512),
(96, 1100, 8, 1, 576, 512),
(96, 1300, 8, 1, 576, 512),
(96, 1500, 8, 1, 576, 512),
(96, 1700, 8, 1, 576, 512),
(96, 1900, 8, 1, 576, 512),
(96, 2100, 8, 1, 576, 512),
(96, 2300, 8, 1, 576, 512),
(96, 2500, 8, 1, 576, 512),
(96, 2700, 8, 1, 576, 512),
(96, 2900, 8, 1, 576, 512),
(96, 3100, 8, 1, 576, 512),
(96, 3300, 8, 1, 576, 512),
(96, 3500, 8, 1, 576, 512),
(96, 3700, 8, 1, 576, 512),
(96, 3900, 8, 1, 576, 512),
(96, 4100, 8, 1, 576, 512),
(96, 4300, 8, 1, 576, 512),
(96, 4500, 8, 1, 576, 512),
(96, 4700, 8, 1, 576, 512),
(96, 4900, 8, 1, 576, 512),
(96, 5000, 8, 1, 576, 512),
(96, 5300, 8, 1, 576, 512),
(96, 5600, 8, 1, 576, 512),
(96, 5900, 8, 1, 576, 512),
(96, 6200, 8, 1, 576, 512),
(96, 6500, 8, 1, 576, 512),
(96, 6800, 8, 1, 576, 512),
(96, 7100, 8, 1, 576, 512),
(96, 7400, 8, 1, 576, 512),
(96, 7700, 8, 1, 576, 512),
(96, 8000, 8, 1, 576, 512),
(104, 1, 8, 1, 576, 512),
(104, 100, 8, 1, 576, 512),
(104, 400, 8, 1, 576, 512),
(104, 700, 8, 1, 576, 512),
(104, 1000, 8, 1, 576, 512),
(104, 1300, 8, 1, 576, 512),
(104, 1600, 8, 1, 576, 512),
(104, 1900, 8, 1, 576, 512),
(104, 2200, 8, 1, 576, 512),
(104, 2500, 8, 1, 576, 512),
(104, 2800, 8, 1, 576, 512),
(104, 3100, 8, 1, 576, 512),
(104, 3400, 8, 1, 576, 512),
(104, 3700, 8, 1, 576, 512),
(104, 4000, 8, 1, 576, 512),
(104, 4300, 8, 1, 576, 512),
(104, 4600, 8, 1, 576, 512),
(104, 4900, 8, 1, 576, 512),
(104, 5000, 8, 1, 576, 512),
(104, 5500, 8, 1, 576, 512),
(104, 6000, 8, 1, 576, 512),
(104, 6500, 8, 1, 576, 512),
(104, 7000, 8, 1, 576, 512),
(104, 7500, 8, 1, 576, 512),
(104, 8000, 8, 1, 576, 512),
(112, 1, 8, 1, 576, 512),
(112, 100, 8, 1, 576, 512),
(112, 400, 8, 1, 576, 512),
(112, 700, 8, 1, 576, 512),
(112, 1000, 8, 1, 576, 512),
(112, 1300, 8, 1, 576, 512),
(112, 1600, 8, 1, 576, 512),
(112, 1900, 8, 1, 576, 512),
(112, 2200, 8, 1, 576, 512),
(112, 2500, 8, 1, 576, 512),
(112, 2800, 8, 1, 576, 512),
(112, 3100, 8, 1, 576, 512),
(112, 3400, 8, 1, 576, 512),
(112, 3700, 8, 1, 576, 512),
(112, 4000, 8, 1, 576, 512),
(112, 4300, 8, 1, 576, 512),
(112, 4600, 8, 1, 576, 512),
(112, 4900, 8, 1, 576, 512),
(112, 5000, 8, 1, 576, 512),
(112, 5500, 8, 1, 576, 512),
(112, 6000, 8, 1, 576, 512),
(112, 6500, 8, 1, 576, 512),
(112, 7000, 8, 1, 576, 512),
(112, 7500, 8, 1, 576, 512),
(112, 8000, 8, 1, 576, 512),
(120, 1, 8, 1, 576, 512),
(120, 100, 8, 1, 576, 512),
(120, 400, 8, 1, 576, 512),
(120, 700, 8, 1, 576, 512),
(120, 1000, 8, 1, 576, 512),
(120, 1300, 8, 1, 576, 512),
(120, 1600, 8, 1, 576, 512),
(120, 1900, 8, 1, 576, 512),
(120, 2200, 8, 1, 576, 512),
(120, 2500, 8, 1, 576, 512),
(120, 2800, 8, 1, 576, 512),
(120, 3100, 8, 1, 576, 512),
(120, 3400, 8, 1, 576, 512),
(120, 3700, 8, 1, 576, 512),
(120, 4000, 8, 1, 576, 512),
(120, 4300, 8, 1, 576, 512),
(120, 4600, 8, 1, 576, 512),
(120, 4900, 8, 1, 576, 512),
(120, 5000, 8, 1, 576, 512),
(120, 5500, 8, 1, 576, 512),
(120, 6000, 8, 1, 576, 512),
(120, 6500, 8, 1, 576, 512),
(120, 7000, 8, 1, 576, 512),
(120, 7500, 8, 1, 576, 512),
(120, 8000, 8, 1, 576, 512),
(128, 1, 8, 1, 576, 512),
(128, 100, 8, 1, 576, 512),
(128, 400, 8, 1, 576, 512),
(128, 700, 8, 1, 576, 512),
(128, 1000, 8, 1, 576, 512),
(128, 1300, 8, 1, 576, 512),
(128, 1600, 8, 1, 576, 512),
(128, 1900, 8, 1, 576, 512),
(128, 2200, 8, 1, 576, 512),
(128, 2500, 8, 1, 576, 512),
(128, 2800, 8, 1, 576, 512),
(128, 3100, 8, 1, 576, 512),
(128, 3400, 8, 1, 576, 512),
(128, 3700, 8, 1, 576, 512),
(128, 4000, 8, 1, 576, 512),
(128, 4300, 8, 1, 576, 512),
(128, 4600, 8, 1, 576, 512),
(128, 4900, 8, 1, 576, 512),
(128, 5000, 8, 1, 576, 512),
(128, 5500, 8, 1, 576, 512),
(128, 6000, 8, 1, 576, 512),
(128, 6500, 8, 1, 576, 512),
(128, 7000, 8, 1, 576, 512),
(128, 7500, 8, 1, 576, 512),
(128, 8000, 8, 1, 576, 512),
]
deepseek_v2_16b_tp4_configs=[
(1, 1, 4, 1, 576, 512),
(1, 100, 4, 1, 576, 512),
(1, 400, 4, 1, 576, 512),
(1, 700, 4, 1, 576, 512),
(1, 1000, 4, 1, 576, 512),
(1, 1300, 4, 1, 576, 512),
(1, 1600, 4, 1, 576, 512),
(1, 1900, 4, 1, 576, 512),
(1, 2200, 4, 1, 576, 512),
(1, 2500, 4, 1, 576, 512),
(1, 2800, 4, 1, 576, 512),
(1, 3100, 4, 1, 576, 512),
(1, 3400, 4, 1, 576, 512),
(1, 3700, 4, 1, 576, 512),
(1, 4000, 4, 1, 576, 512),
(1, 4300, 4, 1, 576, 512),
(1, 4600, 4, 1, 576, 512),
(1, 4900, 4, 1, 576, 512),
(1, 5000, 4, 1, 576, 512),
(1, 5500, 4, 1, 576, 512),
(1, 6000, 4, 1, 576, 512),
(1, 6500, 4, 1, 576, 512),
(1, 7000, 4, 1, 576, 512),
(1, 7500, 4, 1, 576, 512),
(1, 8000, 4, 1, 576, 512),
(4, 1, 4, 1, 576, 512),
(4, 100, 4, 1, 576, 512),
(4, 400, 4, 1, 576, 512),
(4, 700, 4, 1, 576, 512),
(4, 1000, 4, 1, 576, 512),
(4, 1300, 4, 1, 576, 512),
(4, 1600, 4, 1, 576, 512),
(4, 1900, 4, 1, 576, 512),
(4, 2200, 4, 1, 576, 512),
(4, 2500, 4, 1, 576, 512),
(4, 2800, 4, 1, 576, 512),
(4, 3100, 4, 1, 576, 512),
(4, 3400, 4, 1, 576, 512),
(4, 3700, 4, 1, 576, 512),
(4, 4000, 4, 1, 576, 512),
(4, 4300, 4, 1, 576, 512),
(4, 4600, 4, 1, 576, 512),
(4, 4900, 4, 1, 576, 512),
(4, 5000, 4, 1, 576, 512),
(4, 5500, 4, 1, 576, 512),
(4, 6000, 4, 1, 576, 512),
(4, 6500, 4, 1, 576, 512),
(4, 7000, 4, 1, 576, 512),
(4, 7500, 4, 1, 576, 512),
(4, 8000, 4, 1, 576, 512),
(8, 1, 4, 1, 576, 512),
(8, 100, 4, 1, 576, 512),
(8, 400, 4, 1, 576, 512),
(8, 700, 4, 1, 576, 512),
(8, 1000, 4, 1, 576, 512),
(8, 1300, 4, 1, 576, 512),
(8, 1600, 4, 1, 576, 512),
(8, 1900, 4, 1, 576, 512),
(8, 2200, 4, 1, 576, 512),
(8, 2500, 4, 1, 576, 512),
(8, 2800, 4, 1, 576, 512),
(8, 3100, 4, 1, 576, 512),
(8, 3400, 4, 1, 576, 512),
(8, 3700, 4, 1, 576, 512),
(8, 4000, 4, 1, 576, 512),
(8, 4300, 4, 1, 576, 512),
(8, 4600, 4, 1, 576, 512),
(8, 4900, 4, 1, 576, 512),
(8, 5000, 4, 1, 576, 512),
(8, 5500, 4, 1, 576, 512),
(8, 6000, 4, 1, 576, 512),
(8, 6500, 4, 1, 576, 512),
(8, 7000, 4, 1, 576, 512),
(8, 7500, 4, 1, 576, 512),
(8, 8000, 4, 1, 576, 512),
(16, 1, 4, 1, 576, 512),
(16, 100, 4, 1, 576, 512),
(16, 400, 4, 1, 576, 512),
(16, 700, 4, 1, 576, 512),
(16, 1000, 4, 1, 576, 512),
(16, 1300, 4, 1, 576, 512),
(16, 1600, 4, 1, 576, 512),
(16, 1900, 4, 1, 576, 512),
(16, 2200, 4, 1, 576, 512),
(16, 2500, 4, 1, 576, 512),
(16, 2800, 4, 1, 576, 512),
(16, 3100, 4, 1, 576, 512),
(16, 3400, 4, 1, 576, 512),
(16, 3700, 4, 1, 576, 512),
(16, 4000, 4, 1, 576, 512),
(16, 4300, 4, 1, 576, 512),
(16, 4600, 4, 1, 576, 512),
(16, 4900, 4, 1, 576, 512),
(16, 5000, 4, 1, 576, 512),
(16, 5500, 4, 1, 576, 512),
(16, 6000, 4, 1, 576, 512),
(16, 6500, 4, 1, 576, 512),
(16, 7000, 4, 1, 576, 512),
(16, 7500, 4, 1, 576, 512),
(16, 8000, 4, 1, 576, 512),
(24, 1, 4, 1, 576, 512),
(24, 100, 4, 1, 576, 512),
(24, 400, 4, 1, 576, 512),
(24, 700, 4, 1, 576, 512),
(24, 1000, 4, 1, 576, 512),
(24, 1300, 4, 1, 576, 512),
(24, 1600, 4, 1, 576, 512),
(24, 1900, 4, 1, 576, 512),
(24, 2200, 4, 1, 576, 512),
(24, 2500, 4, 1, 576, 512),
(24, 2800, 4, 1, 576, 512),
(24, 3100, 4, 1, 576, 512),
(24, 3400, 4, 1, 576, 512),
(24, 3700, 4, 1, 576, 512),
(24, 4000, 4, 1, 576, 512),
(24, 4300, 4, 1, 576, 512),
(24, 4600, 4, 1, 576, 512),
(24, 4900, 4, 1, 576, 512),
(24, 5000, 4, 1, 576, 512),
(24, 5500, 4, 1, 576, 512),
(24, 6000, 4, 1, 576, 512),
(24, 6500, 4, 1, 576, 512),
(24, 7000, 4, 1, 576, 512),
(24, 7500, 4, 1, 576, 512),
(24, 8000, 4, 1, 576, 512),
(32, 1, 4, 1, 576, 512),
(32, 100, 4, 1, 576, 512),
(32, 400, 4, 1, 576, 512),
(32, 700, 4, 1, 576, 512),
(32, 1000, 4, 1, 576, 512),
(32, 1300, 4, 1, 576, 512),
(32, 1600, 4, 1, 576, 512),
(32, 1900, 4, 1, 576, 512),
(32, 2200, 4, 1, 576, 512),
(32, 2500, 4, 1, 576, 512),
(32, 2800, 4, 1, 576, 512),
(32, 3100, 4, 1, 576, 512),
(32, 3400, 4, 1, 576, 512),
(32, 3700, 4, 1, 576, 512),
(32, 4000, 4, 1, 576, 512),
(32, 4300, 4, 1, 576, 512),
(32, 4600, 4, 1, 576, 512),
(32, 4900, 4, 1, 576, 512),
(32, 5000, 4, 1, 576, 512),
(32, 5500, 4, 1, 576, 512),
(32, 6000, 4, 1, 576, 512),
(32, 6500, 4, 1, 576, 512),
(32, 7000, 4, 1, 576, 512),
(32, 7500, 4, 1, 576, 512),
(32, 8000, 4, 1, 576, 512),
(40, 1, 4, 1, 576, 512),
(40, 100, 4, 1, 576, 512),
(40, 300, 4, 1, 576, 512),
(40, 500, 4, 1, 576, 512),
(40, 700, 4, 1, 576, 512),
(40, 900, 4, 1, 576, 512),
(40, 1100, 4, 1, 576, 512),
(40, 1300, 4, 1, 576, 512),
(40, 1500, 4, 1, 576, 512),
(40, 1700, 4, 1, 576, 512),
(40, 1900, 4, 1, 576, 512),
(40, 2100, 4, 1, 576, 512),
(40, 2300, 4, 1, 576, 512),
(40, 2500, 4, 1, 576, 512),
(40, 2700, 4, 1, 576, 512),
(40, 2900, 4, 1, 576, 512),
(40, 3100, 4, 1, 576, 512),
(40, 3300, 4, 1, 576, 512),
(40, 3500, 4, 1, 576, 512),
(40, 3700, 4, 1, 576, 512),
(40, 3900, 4, 1, 576, 512),
(40, 4100, 4, 1, 576, 512),
(40, 4300, 4, 1, 576, 512),
(40, 4500, 4, 1, 576, 512),
(40, 4700, 4, 1, 576, 512),
(40, 4900, 4, 1, 576, 512),
(40, 5000, 4, 1, 576, 512),
(40, 5300, 4, 1, 576, 512),
(40, 5600, 4, 1, 576, 512),
(40, 5900, 4, 1, 576, 512),
(40, 6200, 4, 1, 576, 512),
(40, 6500, 4, 1, 576, 512),
(40, 6800, 4, 1, 576, 512),
(40, 7100, 4, 1, 576, 512),
(40, 7400, 4, 1, 576, 512),
(40, 7700, 4, 1, 576, 512),
(40, 8000, 4, 1, 576, 512),
(48, 1, 4, 1, 576, 512),
(48, 100, 4, 1, 576, 512),
(48, 300, 4, 1, 576, 512),
(48, 500, 4, 1, 576, 512),
(48, 700, 4, 1, 576, 512),
(48, 900, 4, 1, 576, 512),
(48, 1100, 4, 1, 576, 512),
(48, 1300, 4, 1, 576, 512),
(48, 1500, 4, 1, 576, 512),
(48, 1700, 4, 1, 576, 512),
(48, 1900, 4, 1, 576, 512),
(48, 2100, 4, 1, 576, 512),
(48, 2300, 4, 1, 576, 512),
(48, 2500, 4, 1, 576, 512),
(48, 2700, 4, 1, 576, 512),
(48, 2900, 4, 1, 576, 512),
(48, 3100, 4, 1, 576, 512),
(48, 3300, 4, 1, 576, 512),
(48, 3500, 4, 1, 576, 512),
(48, 3700, 4, 1, 576, 512),
(48, 3900, 4, 1, 576, 512),
(48, 4100, 4, 1, 576, 512),
(48, 4300, 4, 1, 576, 512),
(48, 4500, 4, 1, 576, 512),
(48, 4700, 4, 1, 576, 512),
(48, 4900, 4, 1, 576, 512),
(48, 5000, 4, 1, 576, 512),
(48, 5300, 4, 1, 576, 512),
(48, 5600, 4, 1, 576, 512),
(48, 5900, 4, 1, 576, 512),
(48, 6200, 4, 1, 576, 512),
(48, 6500, 4, 1, 576, 512),
(48, 6800, 4, 1, 576, 512),
(48, 7100, 4, 1, 576, 512),
(48, 7400, 4, 1, 576, 512),
(48, 7700, 4, 1, 576, 512),
(48, 8000, 4, 1, 576, 512),
(56, 1, 4, 1, 576, 512),
(56, 100, 4, 1, 576, 512),
(56, 300, 4, 1, 576, 512),
(56, 500, 4, 1, 576, 512),
(56, 700, 4, 1, 576, 512),
(56, 900, 4, 1, 576, 512),
(56, 1100, 4, 1, 576, 512),
(56, 1300, 4, 1, 576, 512),
(56, 1500, 4, 1, 576, 512),
(56, 1700, 4, 1, 576, 512),
(56, 1900, 4, 1, 576, 512),
(56, 2100, 4, 1, 576, 512),
(56, 2300, 4, 1, 576, 512),
(56, 2500, 4, 1, 576, 512),
(56, 2700, 4, 1, 576, 512),
(56, 2900, 4, 1, 576, 512),
(56, 3100, 4, 1, 576, 512),
(56, 3300, 4, 1, 576, 512),
(56, 3500, 4, 1, 576, 512),
(56, 3700, 4, 1, 576, 512),
(56, 3900, 4, 1, 576, 512),
(56, 4100, 4, 1, 576, 512),
(56, 4300, 4, 1, 576, 512),
(56, 4500, 4, 1, 576, 512),
(56, 4700, 4, 1, 576, 512),
(56, 4900, 4, 1, 576, 512),
(56, 5000, 4, 1, 576, 512),
(56, 5300, 4, 1, 576, 512),
(56, 5600, 4, 1, 576, 512),
(56, 5900, 4, 1, 576, 512),
(56, 6200, 4, 1, 576, 512),
(56, 6500, 4, 1, 576, 512),
(56, 6800, 4, 1, 576, 512),
(56, 7100, 4, 1, 576, 512),
(56, 7400, 4, 1, 576, 512),
(56, 7700, 4, 1, 576, 512),
(56, 8000, 4, 1, 576, 512),
(64, 1, 4, 1, 576, 512),
(64, 100, 4, 1, 576, 512),
(64, 300, 4, 1, 576, 512),
(64, 500, 4, 1, 576, 512),
(64, 700, 4, 1, 576, 512),
(64, 900, 4, 1, 576, 512),
(64, 1100, 4, 1, 576, 512),
(64, 1300, 4, 1, 576, 512),
(64, 1500, 4, 1, 576, 512),
(64, 1700, 4, 1, 576, 512),
(64, 1900, 4, 1, 576, 512),
(64, 2100, 4, 1, 576, 512),
(64, 2300, 4, 1, 576, 512),
(64, 2500, 4, 1, 576, 512),
(64, 2700, 4, 1, 576, 512),
(64, 2900, 4, 1, 576, 512),
(64, 3100, 4, 1, 576, 512),
(64, 3300, 4, 1, 576, 512),
(64, 3500, 4, 1, 576, 512),
(64, 3700, 4, 1, 576, 512),
(64, 3900, 4, 1, 576, 512),
(64, 4100, 4, 1, 576, 512),
(64, 4300, 4, 1, 576, 512),
(64, 4500, 4, 1, 576, 512),
(64, 4700, 4, 1, 576, 512),
(64, 4900, 4, 1, 576, 512),
(64, 5000, 4, 1, 576, 512),
(64, 5300, 4, 1, 576, 512),
(64, 5600, 4, 1, 576, 512),
(64, 5900, 4, 1, 576, 512),
(64, 6200, 4, 1, 576, 512),
(64, 6500, 4, 1, 576, 512),
(64, 6800, 4, 1, 576, 512),
(64, 7100, 4, 1, 576, 512),
(64, 7400, 4, 1, 576, 512),
(64, 7700, 4, 1, 576, 512),
(64, 8000, 4, 1, 576, 512),
(72, 1, 4, 1, 576, 512),
(72, 100, 4, 1, 576, 512),
(72, 300, 4, 1, 576, 512),
(72, 500, 4, 1, 576, 512),
(72, 700, 4, 1, 576, 512),
(72, 900, 4, 1, 576, 512),
(72, 1100, 4, 1, 576, 512),
(72, 1300, 4, 1, 576, 512),
(72, 1500, 4, 1, 576, 512),
(72, 1700, 4, 1, 576, 512),
(72, 1900, 4, 1, 576, 512),
(72, 2100, 4, 1, 576, 512),
(72, 2300, 4, 1, 576, 512),
(72, 2500, 4, 1, 576, 512),
(72, 2700, 4, 1, 576, 512),
(72, 2900, 4, 1, 576, 512),
(72, 3100, 4, 1, 576, 512),
(72, 3300, 4, 1, 576, 512),
(72, 3500, 4, 1, 576, 512),
(72, 3700, 4, 1, 576, 512),
(72, 3900, 4, 1, 576, 512),
(72, 4100, 4, 1, 576, 512),
(72, 4300, 4, 1, 576, 512),
(72, 4500, 4, 1, 576, 512),
(72, 4700, 4, 1, 576, 512),
(72, 4900, 4, 1, 576, 512),
(72, 5000, 4, 1, 576, 512),
(72, 5300, 4, 1, 576, 512),
(72, 5600, 4, 1, 576, 512),
(72, 5900, 4, 1, 576, 512),
(72, 6200, 4, 1, 576, 512),
(72, 6500, 4, 1, 576, 512),
(72, 6800, 4, 1, 576, 512),
(72, 7100, 4, 1, 576, 512),
(72, 7400, 4, 1, 576, 512),
(72, 7700, 4, 1, 576, 512),
(72, 8000, 4, 1, 576, 512),
(80, 1, 4, 1, 576, 512),
(80, 100, 4, 1, 576, 512),
(80, 300, 4, 1, 576, 512),
(80, 500, 4, 1, 576, 512),
(80, 700, 4, 1, 576, 512),
(80, 900, 4, 1, 576, 512),
(80, 1100, 4, 1, 576, 512),
(80, 1300, 4, 1, 576, 512),
(80, 1500, 4, 1, 576, 512),
(80, 1700, 4, 1, 576, 512),
(80, 1900, 4, 1, 576, 512),
(80, 2100, 4, 1, 576, 512),
(80, 2300, 4, 1, 576, 512),
(80, 2500, 4, 1, 576, 512),
(80, 2700, 4, 1, 576, 512),
(80, 2900, 4, 1, 576, 512),
(80, 3100, 4, 1, 576, 512),
(80, 3300, 4, 1, 576, 512),
(80, 3500, 4, 1, 576, 512),
(80, 3700, 4, 1, 576, 512),
(80, 3900, 4, 1, 576, 512),
(80, 4100, 4, 1, 576, 512),
(80, 4300, 4, 1, 576, 512),
(80, 4500, 4, 1, 576, 512),
(80, 4700, 4, 1, 576, 512),
(80, 4900, 4, 1, 576, 512),
(80, 5000, 4, 1, 576, 512),
(80, 5300, 4, 1, 576, 512),
(80, 5600, 4, 1, 576, 512),
(80, 5900, 4, 1, 576, 512),
(80, 6200, 4, 1, 576, 512),
(80, 6500, 4, 1, 576, 512),
(80, 6800, 4, 1, 576, 512),
(80, 7100, 4, 1, 576, 512),
(80, 7400, 4, 1, 576, 512),
(80, 7700, 4, 1, 576, 512),
(80, 8000, 4, 1, 576, 512),
(88, 1, 4, 1, 576, 512),
(88, 100, 4, 1, 576, 512),
(88, 300, 4, 1, 576, 512),
(88, 500, 4, 1, 576, 512),
(88, 700, 4, 1, 576, 512),
(88, 900, 4, 1, 576, 512),
(88, 1100, 4, 1, 576, 512),
(88, 1300, 4, 1, 576, 512),
(88, 1500, 4, 1, 576, 512),
(88, 1700, 4, 1, 576, 512),
(88, 1900, 4, 1, 576, 512),
(88, 2100, 4, 1, 576, 512),
(88, 2300, 4, 1, 576, 512),
(88, 2500, 4, 1, 576, 512),
(88, 2700, 4, 1, 576, 512),
(88, 2900, 4, 1, 576, 512),
(88, 3100, 4, 1, 576, 512),
(88, 3300, 4, 1, 576, 512),
(88, 3500, 4, 1, 576, 512),
(88, 3700, 4, 1, 576, 512),
(88, 3900, 4, 1, 576, 512),
(88, 4100, 4, 1, 576, 512),
(88, 4300, 4, 1, 576, 512),
(88, 4500, 4, 1, 576, 512),
(88, 4700, 4, 1, 576, 512),
(88, 4900, 4, 1, 576, 512),
(88, 5000, 4, 1, 576, 512),
(88, 5300, 4, 1, 576, 512),
(88, 5600, 4, 1, 576, 512),
(88, 5900, 4, 1, 576, 512),
(88, 6200, 4, 1, 576, 512),
(88, 6500, 4, 1, 576, 512),
(88, 6800, 4, 1, 576, 512),
(88, 7100, 4, 1, 576, 512),
(88, 7400, 4, 1, 576, 512),
(88, 7700, 4, 1, 576, 512),
(88, 8000, 4, 1, 576, 512),
(96, 1, 4, 1, 576, 512),
(96, 100, 4, 1, 576, 512),
(96, 300, 4, 1, 576, 512),
(96, 500, 4, 1, 576, 512),
(96, 700, 4, 1, 576, 512),
(96, 900, 4, 1, 576, 512),
(96, 1100, 4, 1, 576, 512),
(96, 1300, 4, 1, 576, 512),
(96, 1500, 4, 1, 576, 512),
(96, 1700, 4, 1, 576, 512),
(96, 1900, 4, 1, 576, 512),
(96, 2100, 4, 1, 576, 512),
(96, 2300, 4, 1, 576, 512),
(96, 2500, 4, 1, 576, 512),
(96, 2700, 4, 1, 576, 512),
(96, 2900, 4, 1, 576, 512),
(96, 3100, 4, 1, 576, 512),
(96, 3300, 4, 1, 576, 512),
(96, 3500, 4, 1, 576, 512),
(96, 3700, 4, 1, 576, 512),
(96, 3900, 4, 1, 576, 512),
(96, 4100, 4, 1, 576, 512),
(96, 4300, 4, 1, 576, 512),
(96, 4500, 4, 1, 576, 512),
(96, 4700, 4, 1, 576, 512),
(96, 4900, 4, 1, 576, 512),
(96, 5000, 4, 1, 576, 512),
(96, 5300, 4, 1, 576, 512),
(96, 5600, 4, 1, 576, 512),
(96, 5900, 4, 1, 576, 512),
(96, 6200, 4, 1, 576, 512),
(96, 6500, 4, 1, 576, 512),
(96, 6800, 4, 1, 576, 512),
(96, 7100, 4, 1, 576, 512),
(96, 7400, 4, 1, 576, 512),
(96, 7700, 4, 1, 576, 512),
(96, 8000, 4, 1, 576, 512),
(104, 1, 4, 1, 576, 512),
(104, 100, 4, 1, 576, 512),
(104, 400, 4, 1, 576, 512),
(104, 700, 4, 1, 576, 512),
(104, 1000, 4, 1, 576, 512),
(104, 1300, 4, 1, 576, 512),
(104, 1600, 4, 1, 576, 512),
(104, 1900, 4, 1, 576, 512),
(104, 2200, 4, 1, 576, 512),
(104, 2500, 4, 1, 576, 512),
(104, 2800, 4, 1, 576, 512),
(104, 3100, 4, 1, 576, 512),
(104, 3400, 4, 1, 576, 512),
(104, 3700, 4, 1, 576, 512),
(104, 4000, 4, 1, 576, 512),
(104, 4300, 4, 1, 576, 512),
(104, 4600, 4, 1, 576, 512),
(104, 4900, 4, 1, 576, 512),
(104, 5000, 4, 1, 576, 512),
(104, 5500, 4, 1, 576, 512),
(104, 6000, 4, 1, 576, 512),
(104, 6500, 4, 1, 576, 512),
(104, 7000, 4, 1, 576, 512),
(104, 7500, 4, 1, 576, 512),
(104, 8000, 4, 1, 576, 512),
(112, 1, 4, 1, 576, 512),
(112, 100, 4, 1, 576, 512),
(112, 400, 4, 1, 576, 512),
(112, 700, 4, 1, 576, 512),
(112, 1000, 4, 1, 576, 512),
(112, 1300, 4, 1, 576, 512),
(112, 1600, 4, 1, 576, 512),
(112, 1900, 4, 1, 576, 512),
(112, 2200, 4, 1, 576, 512),
(112, 2500, 4, 1, 576, 512),
(112, 2800, 4, 1, 576, 512),
(112, 3100, 4, 1, 576, 512),
(112, 3400, 4, 1, 576, 512),
(112, 3700, 4, 1, 576, 512),
(112, 4000, 4, 1, 576, 512),
(112, 4300, 4, 1, 576, 512),
(112, 4600, 4, 1, 576, 512),
(112, 4900, 4, 1, 576, 512),
(112, 5000, 4, 1, 576, 512),
(112, 5500, 4, 1, 576, 512),
(112, 6000, 4, 1, 576, 512),
(112, 6500, 4, 1, 576, 512),
(112, 7000, 4, 1, 576, 512),
(112, 7500, 4, 1, 576, 512),
(112, 8000, 4, 1, 576, 512),
(120, 1, 4, 1, 576, 512),
(120, 100, 4, 1, 576, 512),
(120, 400, 4, 1, 576, 512),
(120, 700, 4, 1, 576, 512),
(120, 1000, 4, 1, 576, 512),
(120, 1300, 4, 1, 576, 512),
(120, 1600, 4, 1, 576, 512),
(120, 1900, 4, 1, 576, 512),
(120, 2200, 4, 1, 576, 512),
(120, 2500, 4, 1, 576, 512),
(120, 2800, 4, 1, 576, 512),
(120, 3100, 4, 1, 576, 512),
(120, 3400, 4, 1, 576, 512),
(120, 3700, 4, 1, 576, 512),
(120, 4000, 4, 1, 576, 512),
(120, 4300, 4, 1, 576, 512),
(120, 4600, 4, 1, 576, 512),
(120, 4900, 4, 1, 576, 512),
(120, 5000, 4, 1, 576, 512),
(120, 5500, 4, 1, 576, 512),
(120, 6000, 4, 1, 576, 512),
(120, 6500, 4, 1, 576, 512),
(120, 7000, 4, 1, 576, 512),
(120, 7500, 4, 1, 576, 512),
(120, 8000, 4, 1, 576, 512),
(128, 1, 4, 1, 576, 512),
(128, 100, 4, 1, 576, 512),
(128, 400, 4, 1, 576, 512),
(128, 700, 4, 1, 576, 512),
(128, 1000, 4, 1, 576, 512),
(128, 1300, 4, 1, 576, 512),
(128, 1600, 4, 1, 576, 512),
(128, 1900, 4, 1, 576, 512),
(128, 2200, 4, 1, 576, 512),
(128, 2500, 4, 1, 576, 512),
(128, 2800, 4, 1, 576, 512),
(128, 3100, 4, 1, 576, 512),
(128, 3400, 4, 1, 576, 512),
(128, 3700, 4, 1, 576, 512),
(128, 4000, 4, 1, 576, 512),
(128, 4300, 4, 1, 576, 512),
(128, 4600, 4, 1, 576, 512),
(128, 4900, 4, 1, 576, 512),
(128, 5000, 4, 1, 576, 512),
(128, 5500, 4, 1, 576, 512),
(128, 6000, 4, 1, 576, 512),
(128, 6500, 4, 1, 576, 512),
(128, 7000, 4, 1, 576, 512),
(128, 7500, 4, 1, 576, 512),
(128, 8000, 4, 1, 576, 512),
]
deepseek_v2_16b_tp8_configs=[
(1, 1, 2, 1, 576, 512),
(1, 100, 2, 1, 576, 512),
(1, 400, 2, 1, 576, 512),
(1, 700, 2, 1, 576, 512),
(1, 1000, 2, 1, 576, 512),
(1, 1300, 2, 1, 576, 512),
(1, 1600, 2, 1, 576, 512),
(1, 1900, 2, 1, 576, 512),
(1, 2200, 2, 1, 576, 512),
(1, 2500, 2, 1, 576, 512),
(1, 2800, 2, 1, 576, 512),
(1, 3100, 2, 1, 576, 512),
(1, 3400, 2, 1, 576, 512),
(1, 3700, 2, 1, 576, 512),
(1, 4000, 2, 1, 576, 512),
(1, 4300, 2, 1, 576, 512),
(1, 4600, 2, 1, 576, 512),
(1, 4900, 2, 1, 576, 512),
(1, 5000, 2, 1, 576, 512),
(1, 5500, 2, 1, 576, 512),
(1, 6000, 2, 1, 576, 512),
(1, 6500, 2, 1, 576, 512),
(1, 7000, 2, 1, 576, 512),
(1, 7500, 2, 1, 576, 512),
(1, 8000, 2, 1, 576, 512),
(4, 1, 2, 1, 576, 512),
(4, 100, 2, 1, 576, 512),
(4, 400, 2, 1, 576, 512),
(4, 700, 2, 1, 576, 512),
(4, 1000, 2, 1, 576, 512),
(4, 1300, 2, 1, 576, 512),
(4, 1600, 2, 1, 576, 512),
(4, 1900, 2, 1, 576, 512),
(4, 2200, 2, 1, 576, 512),
(4, 2500, 2, 1, 576, 512),
(4, 2800, 2, 1, 576, 512),
(4, 3100, 2, 1, 576, 512),
(4, 3400, 2, 1, 576, 512),
(4, 3700, 2, 1, 576, 512),
(4, 4000, 2, 1, 576, 512),
(4, 4300, 2, 1, 576, 512),
(4, 4600, 2, 1, 576, 512),
(4, 4900, 2, 1, 576, 512),
(4, 5000, 2, 1, 576, 512),
(4, 5500, 2, 1, 576, 512),
(4, 6000, 2, 1, 576, 512),
(4, 6500, 2, 1, 576, 512),
(4, 7000, 2, 1, 576, 512),
(4, 7500, 2, 1, 576, 512),
(4, 8000, 2, 1, 576, 512),
(8, 1, 2, 1, 576, 512),
(8, 100, 2, 1, 576, 512),
(8, 400, 2, 1, 576, 512),
(8, 700, 2, 1, 576, 512),
(8, 1000, 2, 1, 576, 512),
(8, 1300, 2, 1, 576, 512),
(8, 1600, 2, 1, 576, 512),
(8, 1900, 2, 1, 576, 512),
(8, 2200, 2, 1, 576, 512),
(8, 2500, 2, 1, 576, 512),
(8, 2800, 2, 1, 576, 512),
(8, 3100, 2, 1, 576, 512),
(8, 3400, 2, 1, 576, 512),
(8, 3700, 2, 1, 576, 512),
(8, 4000, 2, 1, 576, 512),
(8, 4300, 2, 1, 576, 512),
(8, 4600, 2, 1, 576, 512),
(8, 4900, 2, 1, 576, 512),
(8, 5000, 2, 1, 576, 512),
(8, 5500, 2, 1, 576, 512),
(8, 6000, 2, 1, 576, 512),
(8, 6500, 2, 1, 576, 512),
(8, 7000, 2, 1, 576, 512),
(8, 7500, 2, 1, 576, 512),
(8, 8000, 2, 1, 576, 512),
(16, 1, 2, 1, 576, 512),
(16, 100, 2, 1, 576, 512),
(16, 400, 2, 1, 576, 512),
(16, 700, 2, 1, 576, 512),
(16, 1000, 2, 1, 576, 512),
(16, 1300, 2, 1, 576, 512),
(16, 1600, 2, 1, 576, 512),
(16, 1900, 2, 1, 576, 512),
(16, 2200, 2, 1, 576, 512),
(16, 2500, 2, 1, 576, 512),
(16, 2800, 2, 1, 576, 512),
(16, 3100, 2, 1, 576, 512),
(16, 3400, 2, 1, 576, 512),
(16, 3700, 2, 1, 576, 512),
(16, 4000, 2, 1, 576, 512),
(16, 4300, 2, 1, 576, 512),
(16, 4600, 2, 1, 576, 512),
(16, 4900, 2, 1, 576, 512),
(16, 5000, 2, 1, 576, 512),
(16, 5500, 2, 1, 576, 512),
(16, 6000, 2, 1, 576, 512),
(16, 6500, 2, 1, 576, 512),
(16, 7000, 2, 1, 576, 512),
(16, 7500, 2, 1, 576, 512),
(16, 8000, 2, 1, 576, 512),
(24, 1, 2, 1, 576, 512),
(24, 100, 2, 1, 576, 512),
(24, 400, 2, 1, 576, 512),
(24, 700, 2, 1, 576, 512),
(24, 1000, 2, 1, 576, 512),
(24, 1300, 2, 1, 576, 512),
(24, 1600, 2, 1, 576, 512),
(24, 1900, 2, 1, 576, 512),
(24, 2200, 2, 1, 576, 512),
(24, 2500, 2, 1, 576, 512),
(24, 2800, 2, 1, 576, 512),
(24, 3100, 2, 1, 576, 512),
(24, 3400, 2, 1, 576, 512),
(24, 3700, 2, 1, 576, 512),
(24, 4000, 2, 1, 576, 512),
(24, 4300, 2, 1, 576, 512),
(24, 4600, 2, 1, 576, 512),
(24, 4900, 2, 1, 576, 512),
(24, 5000, 2, 1, 576, 512),
(24, 5500, 2, 1, 576, 512),
(24, 6000, 2, 1, 576, 512),
(24, 6500, 2, 1, 576, 512),
(24, 7000, 2, 1, 576, 512),
(24, 7500, 2, 1, 576, 512),
(24, 8000, 2, 1, 576, 512),
(32, 1, 2, 1, 576, 512),
(32, 100, 2, 1, 576, 512),
(32, 400, 2, 1, 576, 512),
(32, 700, 2, 1, 576, 512),
(32, 1000, 2, 1, 576, 512),
(32, 1300, 2, 1, 576, 512),
(32, 1600, 2, 1, 576, 512),
(32, 1900, 2, 1, 576, 512),
(32, 2200, 2, 1, 576, 512),
(32, 2500, 2, 1, 576, 512),
(32, 2800, 2, 1, 576, 512),
(32, 3100, 2, 1, 576, 512),
(32, 3400, 2, 1, 576, 512),
(32, 3700, 2, 1, 576, 512),
(32, 4000, 2, 1, 576, 512),
(32, 4300, 2, 1, 576, 512),
(32, 4600, 2, 1, 576, 512),
(32, 4900, 2, 1, 576, 512),
(32, 5000, 2, 1, 576, 512),
(32, 5500, 2, 1, 576, 512),
(32, 6000, 2, 1, 576, 512),
(32, 6500, 2, 1, 576, 512),
(32, 7000, 2, 1, 576, 512),
(32, 7500, 2, 1, 576, 512),
(32, 8000, 2, 1, 576, 512),
(40, 1, 2, 1, 576, 512),
(40, 100, 2, 1, 576, 512),
(40, 300, 2, 1, 576, 512),
(40, 500, 2, 1, 576, 512),
(40, 700, 2, 1, 576, 512),
(40, 900, 2, 1, 576, 512),
(40, 1100, 2, 1, 576, 512),
(40, 1300, 2, 1, 576, 512),
(40, 1500, 2, 1, 576, 512),
(40, 1700, 2, 1, 576, 512),
(40, 1900, 2, 1, 576, 512),
(40, 2100, 2, 1, 576, 512),
(40, 2300, 2, 1, 576, 512),
(40, 2500, 2, 1, 576, 512),
(40, 2700, 2, 1, 576, 512),
(40, 2900, 2, 1, 576, 512),
(40, 3100, 2, 1, 576, 512),
(40, 3300, 2, 1, 576, 512),
(40, 3500, 2, 1, 576, 512),
(40, 3700, 2, 1, 576, 512),
(40, 3900, 2, 1, 576, 512),
(40, 4100, 2, 1, 576, 512),
(40, 4300, 2, 1, 576, 512),
(40, 4500, 2, 1, 576, 512),
(40, 4700, 2, 1, 576, 512),
(40, 4900, 2, 1, 576, 512),
(40, 5000, 2, 1, 576, 512),
(40, 5300, 2, 1, 576, 512),
(40, 5600, 2, 1, 576, 512),
(40, 5900, 2, 1, 576, 512),
(40, 6200, 2, 1, 576, 512),
(40, 6500, 2, 1, 576, 512),
(40, 6800, 2, 1, 576, 512),
(40, 7100, 2, 1, 576, 512),
(40, 7400, 2, 1, 576, 512),
(40, 7700, 2, 1, 576, 512),
(40, 8000, 2, 1, 576, 512),
(48, 1, 2, 1, 576, 512),
(48, 100, 2, 1, 576, 512),
(48, 300, 2, 1, 576, 512),
(48, 500, 2, 1, 576, 512),
(48, 700, 2, 1, 576, 512),
(48, 900, 2, 1, 576, 512),
(48, 1100, 2, 1, 576, 512),
(48, 1300, 2, 1, 576, 512),
(48, 1500, 2, 1, 576, 512),
(48, 1700, 2, 1, 576, 512),
(48, 1900, 2, 1, 576, 512),
(48, 2100, 2, 1, 576, 512),
(48, 2300, 2, 1, 576, 512),
(48, 2500, 2, 1, 576, 512),
(48, 2700, 2, 1, 576, 512),
(48, 2900, 2, 1, 576, 512),
(48, 3100, 2, 1, 576, 512),
(48, 3300, 2, 1, 576, 512),
(48, 3500, 2, 1, 576, 512),
(48, 3700, 2, 1, 576, 512),
(48, 3900, 2, 1, 576, 512),
(48, 4100, 2, 1, 576, 512),
(48, 4300, 2, 1, 576, 512),
(48, 4500, 2, 1, 576, 512),
(48, 4700, 2, 1, 576, 512),
(48, 4900, 2, 1, 576, 512),
(48, 5000, 2, 1, 576, 512),
(48, 5300, 2, 1, 576, 512),
(48, 5600, 2, 1, 576, 512),
(48, 5900, 2, 1, 576, 512),
(48, 6200, 2, 1, 576, 512),
(48, 6500, 2, 1, 576, 512),
(48, 6800, 2, 1, 576, 512),
(48, 7100, 2, 1, 576, 512),
(48, 7400, 2, 1, 576, 512),
(48, 7700, 2, 1, 576, 512),
(48, 8000, 2, 1, 576, 512),
(56, 1, 2, 1, 576, 512),
(56, 100, 2, 1, 576, 512),
(56, 300, 2, 1, 576, 512),
(56, 500, 2, 1, 576, 512),
(56, 700, 2, 1, 576, 512),
(56, 900, 2, 1, 576, 512),
(56, 1100, 2, 1, 576, 512),
(56, 1300, 2, 1, 576, 512),
(56, 1500, 2, 1, 576, 512),
(56, 1700, 2, 1, 576, 512),
(56, 1900, 2, 1, 576, 512),
(56, 2100, 2, 1, 576, 512),
(56, 2300, 2, 1, 576, 512),
(56, 2500, 2, 1, 576, 512),
(56, 2700, 2, 1, 576, 512),
(56, 2900, 2, 1, 576, 512),
(56, 3100, 2, 1, 576, 512),
(56, 3300, 2, 1, 576, 512),
(56, 3500, 2, 1, 576, 512),
(56, 3700, 2, 1, 576, 512),
(56, 3900, 2, 1, 576, 512),
(56, 4100, 2, 1, 576, 512),
(56, 4300, 2, 1, 576, 512),
(56, 4500, 2, 1, 576, 512),
(56, 4700, 2, 1, 576, 512),
(56, 4900, 2, 1, 576, 512),
(56, 5000, 2, 1, 576, 512),
(56, 5300, 2, 1, 576, 512),
(56, 5600, 2, 1, 576, 512),
(56, 5900, 2, 1, 576, 512),
(56, 6200, 2, 1, 576, 512),
(56, 6500, 2, 1, 576, 512),
(56, 6800, 2, 1, 576, 512),
(56, 7100, 2, 1, 576, 512),
(56, 7400, 2, 1, 576, 512),
(56, 7700, 2, 1, 576, 512),
(56, 8000, 2, 1, 576, 512),
(64, 1, 2, 1, 576, 512),
(64, 100, 2, 1, 576, 512),
(64, 300, 2, 1, 576, 512),
(64, 500, 2, 1, 576, 512),
(64, 700, 2, 1, 576, 512),
(64, 900, 2, 1, 576, 512),
(64, 1100, 2, 1, 576, 512),
(64, 1300, 2, 1, 576, 512),
(64, 1500, 2, 1, 576, 512),
(64, 1700, 2, 1, 576, 512),
(64, 1900, 2, 1, 576, 512),
(64, 2100, 2, 1, 576, 512),
(64, 2300, 2, 1, 576, 512),
(64, 2500, 2, 1, 576, 512),
(64, 2700, 2, 1, 576, 512),
(64, 2900, 2, 1, 576, 512),
(64, 3100, 2, 1, 576, 512),
(64, 3300, 2, 1, 576, 512),
(64, 3500, 2, 1, 576, 512),
(64, 3700, 2, 1, 576, 512),
(64, 3900, 2, 1, 576, 512),
(64, 4100, 2, 1, 576, 512),
(64, 4300, 2, 1, 576, 512),
(64, 4500, 2, 1, 576, 512),
(64, 4700, 2, 1, 576, 512),
(64, 4900, 2, 1, 576, 512),
(64, 5000, 2, 1, 576, 512),
(64, 5300, 2, 1, 576, 512),
(64, 5600, 2, 1, 576, 512),
(64, 5900, 2, 1, 576, 512),
(64, 6200, 2, 1, 576, 512),
(64, 6500, 2, 1, 576, 512),
(64, 6800, 2, 1, 576, 512),
(64, 7100, 2, 1, 576, 512),
(64, 7400, 2, 1, 576, 512),
(64, 7700, 2, 1, 576, 512),
(64, 8000, 2, 1, 576, 512),
(72, 1, 2, 1, 576, 512),
(72, 100, 2, 1, 576, 512),
(72, 300, 2, 1, 576, 512),
(72, 500, 2, 1, 576, 512),
(72, 700, 2, 1, 576, 512),
(72, 900, 2, 1, 576, 512),
(72, 1100, 2, 1, 576, 512),
(72, 1300, 2, 1, 576, 512),
(72, 1500, 2, 1, 576, 512),
(72, 1700, 2, 1, 576, 512),
(72, 1900, 2, 1, 576, 512),
(72, 2100, 2, 1, 576, 512),
(72, 2300, 2, 1, 576, 512),
(72, 2500, 2, 1, 576, 512),
(72, 2700, 2, 1, 576, 512),
(72, 2900, 2, 1, 576, 512),
(72, 3100, 2, 1, 576, 512),
(72, 3300, 2, 1, 576, 512),
(72, 3500, 2, 1, 576, 512),
(72, 3700, 2, 1, 576, 512),
(72, 3900, 2, 1, 576, 512),
(72, 4100, 2, 1, 576, 512),
(72, 4300, 2, 1, 576, 512),
(72, 4500, 2, 1, 576, 512),
(72, 4700, 2, 1, 576, 512),
(72, 4900, 2, 1, 576, 512),
(72, 5000, 2, 1, 576, 512),
(72, 5300, 2, 1, 576, 512),
(72, 5600, 2, 1, 576, 512),
(72, 5900, 2, 1, 576, 512),
(72, 6200, 2, 1, 576, 512),
(72, 6500, 2, 1, 576, 512),
(72, 6800, 2, 1, 576, 512),
(72, 7100, 2, 1, 576, 512),
(72, 7400, 2, 1, 576, 512),
(72, 7700, 2, 1, 576, 512),
(72, 8000, 2, 1, 576, 512),
(80, 1, 2, 1, 576, 512),
(80, 100, 2, 1, 576, 512),
(80, 300, 2, 1, 576, 512),
(80, 500, 2, 1, 576, 512),
(80, 700, 2, 1, 576, 512),
(80, 900, 2, 1, 576, 512),
(80, 1100, 2, 1, 576, 512),
(80, 1300, 2, 1, 576, 512),
(80, 1500, 2, 1, 576, 512),
(80, 1700, 2, 1, 576, 512),
(80, 1900, 2, 1, 576, 512),
(80, 2100, 2, 1, 576, 512),
(80, 2300, 2, 1, 576, 512),
(80, 2500, 2, 1, 576, 512),
(80, 2700, 2, 1, 576, 512),
(80, 2900, 2, 1, 576, 512),
(80, 3100, 2, 1, 576, 512),
(80, 3300, 2, 1, 576, 512),
(80, 3500, 2, 1, 576, 512),
(80, 3700, 2, 1, 576, 512),
(80, 3900, 2, 1, 576, 512),
(80, 4100, 2, 1, 576, 512),
(80, 4300, 2, 1, 576, 512),
(80, 4500, 2, 1, 576, 512),
(80, 4700, 2, 1, 576, 512),
(80, 4900, 2, 1, 576, 512),
(80, 5000, 2, 1, 576, 512),
(80, 5300, 2, 1, 576, 512),
(80, 5600, 2, 1, 576, 512),
(80, 5900, 2, 1, 576, 512),
(80, 6200, 2, 1, 576, 512),
(80, 6500, 2, 1, 576, 512),
(80, 6800, 2, 1, 576, 512),
(80, 7100, 2, 1, 576, 512),
(80, 7400, 2, 1, 576, 512),
(80, 7700, 2, 1, 576, 512),
(80, 8000, 2, 1, 576, 512),
(88, 1, 2, 1, 576, 512),
(88, 100, 2, 1, 576, 512),
(88, 300, 2, 1, 576, 512),
(88, 500, 2, 1, 576, 512),
(88, 700, 2, 1, 576, 512),
(88, 900, 2, 1, 576, 512),
(88, 1100, 2, 1, 576, 512),
(88, 1300, 2, 1, 576, 512),
(88, 1500, 2, 1, 576, 512),
(88, 1700, 2, 1, 576, 512),
(88, 1900, 2, 1, 576, 512),
(88, 2100, 2, 1, 576, 512),
(88, 2300, 2, 1, 576, 512),
(88, 2500, 2, 1, 576, 512),
(88, 2700, 2, 1, 576, 512),
(88, 2900, 2, 1, 576, 512),
(88, 3100, 2, 1, 576, 512),
(88, 3300, 2, 1, 576, 512),
(88, 3500, 2, 1, 576, 512),
(88, 3700, 2, 1, 576, 512),
(88, 3900, 2, 1, 576, 512),
(88, 4100, 2, 1, 576, 512),
(88, 4300, 2, 1, 576, 512),
(88, 4500, 2, 1, 576, 512),
(88, 4700, 2, 1, 576, 512),
(88, 4900, 2, 1, 576, 512),
(88, 5000, 2, 1, 576, 512),
(88, 5300, 2, 1, 576, 512),
(88, 5600, 2, 1, 576, 512),
(88, 5900, 2, 1, 576, 512),
(88, 6200, 2, 1, 576, 512),
(88, 6500, 2, 1, 576, 512),
(88, 6800, 2, 1, 576, 512),
(88, 7100, 2, 1, 576, 512),
(88, 7400, 2, 1, 576, 512),
(88, 7700, 2, 1, 576, 512),
(88, 8000, 2, 1, 576, 512),
(96, 1, 2, 1, 576, 512),
(96, 100, 2, 1, 576, 512),
(96, 300, 2, 1, 576, 512),
(96, 500, 2, 1, 576, 512),
(96, 700, 2, 1, 576, 512),
(96, 900, 2, 1, 576, 512),
(96, 1100, 2, 1, 576, 512),
(96, 1300, 2, 1, 576, 512),
(96, 1500, 2, 1, 576, 512),
(96, 1700, 2, 1, 576, 512),
(96, 1900, 2, 1, 576, 512),
(96, 2100, 2, 1, 576, 512),
(96, 2300, 2, 1, 576, 512),
(96, 2500, 2, 1, 576, 512),
(96, 2700, 2, 1, 576, 512),
(96, 2900, 2, 1, 576, 512),
(96, 3100, 2, 1, 576, 512),
(96, 3300, 2, 1, 576, 512),
(96, 3500, 2, 1, 576, 512),
(96, 3700, 2, 1, 576, 512),
(96, 3900, 2, 1, 576, 512),
(96, 4100, 2, 1, 576, 512),
(96, 4300, 2, 1, 576, 512),
(96, 4500, 2, 1, 576, 512),
(96, 4700, 2, 1, 576, 512),
(96, 4900, 2, 1, 576, 512),
(96, 5000, 2, 1, 576, 512),
(96, 5300, 2, 1, 576, 512),
(96, 5600, 2, 1, 576, 512),
(96, 5900, 2, 1, 576, 512),
(96, 6200, 2, 1, 576, 512),
(96, 6500, 2, 1, 576, 512),
(96, 6800, 2, 1, 576, 512),
(96, 7100, 2, 1, 576, 512),
(96, 7400, 2, 1, 576, 512),
(96, 7700, 2, 1, 576, 512),
(96, 8000, 2, 1, 576, 512),
(104, 1, 2, 1, 576, 512),
(104, 100, 2, 1, 576, 512),
(104, 400, 2, 1, 576, 512),
(104, 700, 2, 1, 576, 512),
(104, 1000, 2, 1, 576, 512),
(104, 1300, 2, 1, 576, 512),
(104, 1600, 2, 1, 576, 512),
(104, 1900, 2, 1, 576, 512),
(104, 2200, 2, 1, 576, 512),
(104, 2500, 2, 1, 576, 512),
(104, 2800, 2, 1, 576, 512),
(104, 3100, 2, 1, 576, 512),
(104, 3400, 2, 1, 576, 512),
(104, 3700, 2, 1, 576, 512),
(104, 4000, 2, 1, 576, 512),
(104, 4300, 2, 1, 576, 512),
(104, 4600, 2, 1, 576, 512),
(104, 4900, 2, 1, 576, 512),
(104, 5000, 2, 1, 576, 512),
(104, 5500, 2, 1, 576, 512),
(104, 6000, 2, 1, 576, 512),
(104, 6500, 2, 1, 576, 512),
(104, 7000, 2, 1, 576, 512),
(104, 7500, 2, 1, 576, 512),
(104, 8000, 2, 1, 576, 512),
(112, 1, 2, 1, 576, 512),
(112, 100, 2, 1, 576, 512),
(112, 400, 2, 1, 576, 512),
(112, 700, 2, 1, 576, 512),
(112, 1000, 2, 1, 576, 512),
(112, 1300, 2, 1, 576, 512),
(112, 1600, 2, 1, 576, 512),
(112, 1900, 2, 1, 576, 512),
(112, 2200, 2, 1, 576, 512),
(112, 2500, 2, 1, 576, 512),
(112, 2800, 2, 1, 576, 512),
(112, 3100, 2, 1, 576, 512),
(112, 3400, 2, 1, 576, 512),
(112, 3700, 2, 1, 576, 512),
(112, 4000, 2, 1, 576, 512),
(112, 4300, 2, 1, 576, 512),
(112, 4600, 2, 1, 576, 512),
(112, 4900, 2, 1, 576, 512),
(112, 5000, 2, 1, 576, 512),
(112, 5500, 2, 1, 576, 512),
(112, 6000, 2, 1, 576, 512),
(112, 6500, 2, 1, 576, 512),
(112, 7000, 2, 1, 576, 512),
(112, 7500, 2, 1, 576, 512),
(112, 8000, 2, 1, 576, 512),
(120, 1, 2, 1, 576, 512),
(120, 100, 2, 1, 576, 512),
(120, 400, 2, 1, 576, 512),
(120, 700, 2, 1, 576, 512),
(120, 1000, 2, 1, 576, 512),
(120, 1300, 2, 1, 576, 512),
(120, 1600, 2, 1, 576, 512),
(120, 1900, 2, 1, 576, 512),
(120, 2200, 2, 1, 576, 512),
(120, 2500, 2, 1, 576, 512),
(120, 2800, 2, 1, 576, 512),
(120, 3100, 2, 1, 576, 512),
(120, 3400, 2, 1, 576, 512),
(120, 3700, 2, 1, 576, 512),
(120, 4000, 2, 1, 576, 512),
(120, 4300, 2, 1, 576, 512),
(120, 4600, 2, 1, 576, 512),
(120, 4900, 2, 1, 576, 512),
(120, 5000, 2, 1, 576, 512),
(120, 5500, 2, 1, 576, 512),
(120, 6000, 2, 1, 576, 512),
(120, 6500, 2, 1, 576, 512),
(120, 7000, 2, 1, 576, 512),
(120, 7500, 2, 1, 576, 512),
(120, 8000, 2, 1, 576, 512),
(128, 1, 2, 1, 576, 512),
(128, 100, 2, 1, 576, 512),
(128, 400, 2, 1, 576, 512),
(128, 700, 2, 1, 576, 512),
(128, 1000, 2, 1, 576, 512),
(128, 1300, 2, 1, 576, 512),
(128, 1600, 2, 1, 576, 512),
(128, 1900, 2, 1, 576, 512),
(128, 2200, 2, 1, 576, 512),
(128, 2500, 2, 1, 576, 512),
(128, 2800, 2, 1, 576, 512),
(128, 3100, 2, 1, 576, 512),
(128, 3400, 2, 1, 576, 512),
(128, 3700, 2, 1, 576, 512),
(128, 4000, 2, 1, 576, 512),
(128, 4300, 2, 1, 576, 512),
(128, 4600, 2, 1, 576, 512),
(128, 4900, 2, 1, 576, 512),
(128, 5000, 2, 1, 576, 512),
(128, 5500, 2, 1, 576, 512),
(128, 6000, 2, 1, 576, 512),
(128, 6500, 2, 1, 576, 512),
(128, 7000, 2, 1, 576, 512),
(128, 7500, 2, 1, 576, 512),
(128, 8000, 2, 1, 576, 512),
]
deepseek_v2_671b_tp32_configs=[
(1, 1, 4, 1, 576, 512),
(1, 100, 4, 1, 576, 512),
(1, 400, 4, 1, 576, 512),
(1, 700, 4, 1, 576, 512),
(1, 1000, 4, 1, 576, 512),
(1, 1300, 4, 1, 576, 512),
(1, 1600, 4, 1, 576, 512),
(1, 1900, 4, 1, 576, 512),
(1, 2200, 4, 1, 576, 512),
(1, 2500, 4, 1, 576, 512),
(1, 2800, 4, 1, 576, 512),
(1, 3100, 4, 1, 576, 512),
(1, 3400, 4, 1, 576, 512),
(1, 3700, 4, 1, 576, 512),
(1, 4000, 4, 1, 576, 512),
(1, 4300, 4, 1, 576, 512),
(1, 4600, 4, 1, 576, 512),
(1, 4900, 4, 1, 576, 512),
(1, 5000, 4, 1, 576, 512),
(1, 5500, 4, 1, 576, 512),
(1, 6000, 4, 1, 576, 512),
(1, 6500, 4, 1, 576, 512),
(1, 7000, 4, 1, 576, 512),
(1, 7500, 4, 1, 576, 512),
(1, 8000, 4, 1, 576, 512),
(4, 1, 4, 1, 576, 512),
(4, 100, 4, 1, 576, 512),
(4, 400, 4, 1, 576, 512),
(4, 700, 4, 1, 576, 512),
(4, 1000, 4, 1, 576, 512),
(4, 1300, 4, 1, 576, 512),
(4, 1600, 4, 1, 576, 512),
(4, 1900, 4, 1, 576, 512),
(4, 2200, 4, 1, 576, 512),
(4, 2500, 4, 1, 576, 512),
(4, 2800, 4, 1, 576, 512),
(4, 3100, 4, 1, 576, 512),
(4, 3400, 4, 1, 576, 512),
(4, 3700, 4, 1, 576, 512),
(4, 4000, 4, 1, 576, 512),
(4, 4300, 4, 1, 576, 512),
(4, 4600, 4, 1, 576, 512),
(4, 4900, 4, 1, 576, 512),
(4, 5000, 4, 1, 576, 512),
(4, 5500, 4, 1, 576, 512),
(4, 6000, 4, 1, 576, 512),
(4, 6500, 4, 1, 576, 512),
(4, 7000, 4, 1, 576, 512),
(4, 7500, 4, 1, 576, 512),
(4, 8000, 4, 1, 576, 512),
(8, 1, 4, 1, 576, 512),
(8, 100, 4, 1, 576, 512),
(8, 400, 4, 1, 576, 512),
(8, 700, 4, 1, 576, 512),
(8, 1000, 4, 1, 576, 512),
(8, 1300, 4, 1, 576, 512),
(8, 1600, 4, 1, 576, 512),
(8, 1900, 4, 1, 576, 512),
(8, 2200, 4, 1, 576, 512),
(8, 2500, 4, 1, 576, 512),
(8, 2800, 4, 1, 576, 512),
(8, 3100, 4, 1, 576, 512),
(8, 3400, 4, 1, 576, 512),
(8, 3700, 4, 1, 576, 512),
(8, 4000, 4, 1, 576, 512),
(8, 4300, 4, 1, 576, 512),
(8, 4600, 4, 1, 576, 512),
(8, 4900, 4, 1, 576, 512),
(8, 5000, 4, 1, 576, 512),
(8, 5500, 4, 1, 576, 512),
(8, 6000, 4, 1, 576, 512),
(8, 6500, 4, 1, 576, 512),
(8, 7000, 4, 1, 576, 512),
(8, 7500, 4, 1, 576, 512),
(8, 8000, 4, 1, 576, 512),
(16, 1, 4, 1, 576, 512),
(16, 100, 4, 1, 576, 512),
(16, 400, 4, 1, 576, 512),
(16, 700, 4, 1, 576, 512),
(16, 1000, 4, 1, 576, 512),
(16, 1300, 4, 1, 576, 512),
(16, 1600, 4, 1, 576, 512),
(16, 1900, 4, 1, 576, 512),
(16, 2200, 4, 1, 576, 512),
(16, 2500, 4, 1, 576, 512),
(16, 2800, 4, 1, 576, 512),
(16, 3100, 4, 1, 576, 512),
(16, 3400, 4, 1, 576, 512),
(16, 3700, 4, 1, 576, 512),
(16, 4000, 4, 1, 576, 512),
(16, 4300, 4, 1, 576, 512),
(16, 4600, 4, 1, 576, 512),
(16, 4900, 4, 1, 576, 512),
(16, 5000, 4, 1, 576, 512),
(16, 5500, 4, 1, 576, 512),
(16, 6000, 4, 1, 576, 512),
(16, 6500, 4, 1, 576, 512),
(16, 7000, 4, 1, 576, 512),
(16, 7500, 4, 1, 576, 512),
(16, 8000, 4, 1, 576, 512),
(24, 1, 4, 1, 576, 512),
(24, 100, 4, 1, 576, 512),
(24, 400, 4, 1, 576, 512),
(24, 700, 4, 1, 576, 512),
(24, 1000, 4, 1, 576, 512),
(24, 1300, 4, 1, 576, 512),
(24, 1600, 4, 1, 576, 512),
(24, 1900, 4, 1, 576, 512),
(24, 2200, 4, 1, 576, 512),
(24, 2500, 4, 1, 576, 512),
(24, 2800, 4, 1, 576, 512),
(24, 3100, 4, 1, 576, 512),
(24, 3400, 4, 1, 576, 512),
(24, 3700, 4, 1, 576, 512),
(24, 4000, 4, 1, 576, 512),
(24, 4300, 4, 1, 576, 512),
(24, 4600, 4, 1, 576, 512),
(24, 4900, 4, 1, 576, 512),
(24, 5000, 4, 1, 576, 512),
(24, 5500, 4, 1, 576, 512),
(24, 6000, 4, 1, 576, 512),
(24, 6500, 4, 1, 576, 512),
(24, 7000, 4, 1, 576, 512),
(24, 7500, 4, 1, 576, 512),
(24, 8000, 4, 1, 576, 512),
(32, 1, 4, 1, 576, 512),
(32, 100, 4, 1, 576, 512),
(32, 400, 4, 1, 576, 512),
(32, 700, 4, 1, 576, 512),
(32, 1000, 4, 1, 576, 512),
(32, 1300, 4, 1, 576, 512),
(32, 1600, 4, 1, 576, 512),
(32, 1900, 4, 1, 576, 512),
(32, 2200, 4, 1, 576, 512),
(32, 2500, 4, 1, 576, 512),
(32, 2800, 4, 1, 576, 512),
(32, 3100, 4, 1, 576, 512),
(32, 3400, 4, 1, 576, 512),
(32, 3700, 4, 1, 576, 512),
(32, 4000, 4, 1, 576, 512),
(32, 4300, 4, 1, 576, 512),
(32, 4600, 4, 1, 576, 512),
(32, 4900, 4, 1, 576, 512),
(32, 5000, 4, 1, 576, 512),
(32, 5500, 4, 1, 576, 512),
(32, 6000, 4, 1, 576, 512),
(32, 6500, 4, 1, 576, 512),
(32, 7000, 4, 1, 576, 512),
(32, 7500, 4, 1, 576, 512),
(32, 8000, 4, 1, 576, 512),
(40, 1, 4, 1, 576, 512),
(40, 100, 4, 1, 576, 512),
(40, 300, 4, 1, 576, 512),
(40, 500, 4, 1, 576, 512),
(40, 700, 4, 1, 576, 512),
(40, 900, 4, 1, 576, 512),
(40, 1100, 4, 1, 576, 512),
(40, 1300, 4, 1, 576, 512),
(40, 1500, 4, 1, 576, 512),
(40, 1700, 4, 1, 576, 512),
(40, 1900, 4, 1, 576, 512),
(40, 2100, 4, 1, 576, 512),
(40, 2300, 4, 1, 576, 512),
(40, 2500, 4, 1, 576, 512),
(40, 2700, 4, 1, 576, 512),
(40, 2900, 4, 1, 576, 512),
(40, 3100, 4, 1, 576, 512),
(40, 3300, 4, 1, 576, 512),
(40, 3500, 4, 1, 576, 512),
(40, 3700, 4, 1, 576, 512),
(40, 3900, 4, 1, 576, 512),
(40, 4100, 4, 1, 576, 512),
(40, 4300, 4, 1, 576, 512),
(40, 4500, 4, 1, 576, 512),
(40, 4700, 4, 1, 576, 512),
(40, 4900, 4, 1, 576, 512),
(40, 5000, 4, 1, 576, 512),
(40, 5300, 4, 1, 576, 512),
(40, 5600, 4, 1, 576, 512),
(40, 5900, 4, 1, 576, 512),
(40, 6200, 4, 1, 576, 512),
(40, 6500, 4, 1, 576, 512),
(40, 6800, 4, 1, 576, 512),
(40, 7100, 4, 1, 576, 512),
(40, 7400, 4, 1, 576, 512),
(40, 7700, 4, 1, 576, 512),
(40, 8000, 4, 1, 576, 512),
(48, 1, 4, 1, 576, 512),
(48, 100, 4, 1, 576, 512),
(48, 300, 4, 1, 576, 512),
(48, 500, 4, 1, 576, 512),
(48, 700, 4, 1, 576, 512),
(48, 900, 4, 1, 576, 512),
(48, 1100, 4, 1, 576, 512),
(48, 1300, 4, 1, 576, 512),
(48, 1500, 4, 1, 576, 512),
(48, 1700, 4, 1, 576, 512),
(48, 1900, 4, 1, 576, 512),
(48, 2100, 4, 1, 576, 512),
(48, 2300, 4, 1, 576, 512),
(48, 2500, 4, 1, 576, 512),
(48, 2700, 4, 1, 576, 512),
(48, 2900, 4, 1, 576, 512),
(48, 3100, 4, 1, 576, 512),
(48, 3300, 4, 1, 576, 512),
(48, 3500, 4, 1, 576, 512),
(48, 3700, 4, 1, 576, 512),
(48, 3900, 4, 1, 576, 512),
(48, 4100, 4, 1, 576, 512),
(48, 4300, 4, 1, 576, 512),
(48, 4500, 4, 1, 576, 512),
(48, 4700, 4, 1, 576, 512),
(48, 4900, 4, 1, 576, 512),
(48, 5000, 4, 1, 576, 512),
(48, 5300, 4, 1, 576, 512),
(48, 5600, 4, 1, 576, 512),
(48, 5900, 4, 1, 576, 512),
(48, 6200, 4, 1, 576, 512),
(48, 6500, 4, 1, 576, 512),
(48, 6800, 4, 1, 576, 512),
(48, 7100, 4, 1, 576, 512),
(48, 7400, 4, 1, 576, 512),
(48, 7700, 4, 1, 576, 512),
(48, 8000, 4, 1, 576, 512),
(56, 1, 4, 1, 576, 512),
(56, 100, 4, 1, 576, 512),
(56, 300, 4, 1, 576, 512),
(56, 500, 4, 1, 576, 512),
(56, 700, 4, 1, 576, 512),
(56, 900, 4, 1, 576, 512),
(56, 1100, 4, 1, 576, 512),
(56, 1300, 4, 1, 576, 512),
(56, 1500, 4, 1, 576, 512),
(56, 1700, 4, 1, 576, 512),
(56, 1900, 4, 1, 576, 512),
(56, 2100, 4, 1, 576, 512),
(56, 2300, 4, 1, 576, 512),
(56, 2500, 4, 1, 576, 512),
(56, 2700, 4, 1, 576, 512),
(56, 2900, 4, 1, 576, 512),
(56, 3100, 4, 1, 576, 512),
(56, 3300, 4, 1, 576, 512),
(56, 3500, 4, 1, 576, 512),
(56, 3700, 4, 1, 576, 512),
(56, 3900, 4, 1, 576, 512),
(56, 4100, 4, 1, 576, 512),
(56, 4300, 4, 1, 576, 512),
(56, 4500, 4, 1, 576, 512),
(56, 4700, 4, 1, 576, 512),
(56, 4900, 4, 1, 576, 512),
(56, 5000, 4, 1, 576, 512),
(56, 5300, 4, 1, 576, 512),
(56, 5600, 4, 1, 576, 512),
(56, 5900, 4, 1, 576, 512),
(56, 6200, 4, 1, 576, 512),
(56, 6500, 4, 1, 576, 512),
(56, 6800, 4, 1, 576, 512),
(56, 7100, 4, 1, 576, 512),
(56, 7400, 4, 1, 576, 512),
(56, 7700, 4, 1, 576, 512),
(56, 8000, 4, 1, 576, 512),
(64, 1, 4, 1, 576, 512),
(64, 100, 4, 1, 576, 512),
(64, 300, 4, 1, 576, 512),
(64, 500, 4, 1, 576, 512),
(64, 700, 4, 1, 576, 512),
(64, 900, 4, 1, 576, 512),
(64, 1100, 4, 1, 576, 512),
(64, 1300, 4, 1, 576, 512),
(64, 1500, 4, 1, 576, 512),
(64, 1700, 4, 1, 576, 512),
(64, 1900, 4, 1, 576, 512),
(64, 2100, 4, 1, 576, 512),
(64, 2300, 4, 1, 576, 512),
(64, 2500, 4, 1, 576, 512),
(64, 2700, 4, 1, 576, 512),
(64, 2900, 4, 1, 576, 512),
(64, 3100, 4, 1, 576, 512),
(64, 3300, 4, 1, 576, 512),
(64, 3500, 4, 1, 576, 512),
(64, 3700, 4, 1, 576, 512),
(64, 3900, 4, 1, 576, 512),
(64, 4100, 4, 1, 576, 512),
(64, 4300, 4, 1, 576, 512),
(64, 4500, 4, 1, 576, 512),
(64, 4700, 4, 1, 576, 512),
(64, 4900, 4, 1, 576, 512),
(64, 5000, 4, 1, 576, 512),
(64, 5300, 4, 1, 576, 512),
(64, 5600, 4, 1, 576, 512),
(64, 5900, 4, 1, 576, 512),
(64, 6200, 4, 1, 576, 512),
(64, 6500, 4, 1, 576, 512),
(64, 6800, 4, 1, 576, 512),
(64, 7100, 4, 1, 576, 512),
(64, 7400, 4, 1, 576, 512),
(64, 7700, 4, 1, 576, 512),
(64, 8000, 4, 1, 576, 512),
(72, 1, 4, 1, 576, 512),
(72, 100, 4, 1, 576, 512),
(72, 300, 4, 1, 576, 512),
(72, 500, 4, 1, 576, 512),
(72, 700, 4, 1, 576, 512),
(72, 900, 4, 1, 576, 512),
(72, 1100, 4, 1, 576, 512),
(72, 1300, 4, 1, 576, 512),
(72, 1500, 4, 1, 576, 512),
(72, 1700, 4, 1, 576, 512),
(72, 1900, 4, 1, 576, 512),
(72, 2100, 4, 1, 576, 512),
(72, 2300, 4, 1, 576, 512),
(72, 2500, 4, 1, 576, 512),
(72, 2700, 4, 1, 576, 512),
(72, 2900, 4, 1, 576, 512),
(72, 3100, 4, 1, 576, 512),
(72, 3300, 4, 1, 576, 512),
(72, 3500, 4, 1, 576, 512),
(72, 3700, 4, 1, 576, 512),
(72, 3900, 4, 1, 576, 512),
(72, 4100, 4, 1, 576, 512),
(72, 4300, 4, 1, 576, 512),
(72, 4500, 4, 1, 576, 512),
(72, 4700, 4, 1, 576, 512),
(72, 4900, 4, 1, 576, 512),
(72, 5000, 4, 1, 576, 512),
(72, 5300, 4, 1, 576, 512),
(72, 5600, 4, 1, 576, 512),
(72, 5900, 4, 1, 576, 512),
(72, 6200, 4, 1, 576, 512),
(72, 6500, 4, 1, 576, 512),
(72, 6800, 4, 1, 576, 512),
(72, 7100, 4, 1, 576, 512),
(72, 7400, 4, 1, 576, 512),
(72, 7700, 4, 1, 576, 512),
(72, 8000, 4, 1, 576, 512),
(80, 1, 4, 1, 576, 512),
(80, 100, 4, 1, 576, 512),
(80, 300, 4, 1, 576, 512),
(80, 500, 4, 1, 576, 512),
(80, 700, 4, 1, 576, 512),
(80, 900, 4, 1, 576, 512),
(80, 1100, 4, 1, 576, 512),
(80, 1300, 4, 1, 576, 512),
(80, 1500, 4, 1, 576, 512),
(80, 1700, 4, 1, 576, 512),
(80, 1900, 4, 1, 576, 512),
(80, 2100, 4, 1, 576, 512),
(80, 2300, 4, 1, 576, 512),
(80, 2500, 4, 1, 576, 512),
(80, 2700, 4, 1, 576, 512),
(80, 2900, 4, 1, 576, 512),
(80, 3100, 4, 1, 576, 512),
(80, 3300, 4, 1, 576, 512),
(80, 3500, 4, 1, 576, 512),
(80, 3700, 4, 1, 576, 512),
(80, 3900, 4, 1, 576, 512),
(80, 4100, 4, 1, 576, 512),
(80, 4300, 4, 1, 576, 512),
(80, 4500, 4, 1, 576, 512),
(80, 4700, 4, 1, 576, 512),
(80, 4900, 4, 1, 576, 512),
(80, 5000, 4, 1, 576, 512),
(80, 5300, 4, 1, 576, 512),
(80, 5600, 4, 1, 576, 512),
(80, 5900, 4, 1, 576, 512),
(80, 6200, 4, 1, 576, 512),
(80, 6500, 4, 1, 576, 512),
(80, 6800, 4, 1, 576, 512),
(80, 7100, 4, 1, 576, 512),
(80, 7400, 4, 1, 576, 512),
(80, 7700, 4, 1, 576, 512),
(80, 8000, 4, 1, 576, 512),
(88, 1, 4, 1, 576, 512),
(88, 100, 4, 1, 576, 512),
(88, 300, 4, 1, 576, 512),
(88, 500, 4, 1, 576, 512),
(88, 700, 4, 1, 576, 512),
(88, 900, 4, 1, 576, 512),
(88, 1100, 4, 1, 576, 512),
(88, 1300, 4, 1, 576, 512),
(88, 1500, 4, 1, 576, 512),
(88, 1700, 4, 1, 576, 512),
(88, 1900, 4, 1, 576, 512),
(88, 2100, 4, 1, 576, 512),
(88, 2300, 4, 1, 576, 512),
(88, 2500, 4, 1, 576, 512),
(88, 2700, 4, 1, 576, 512),
(88, 2900, 4, 1, 576, 512),
(88, 3100, 4, 1, 576, 512),
(88, 3300, 4, 1, 576, 512),
(88, 3500, 4, 1, 576, 512),
(88, 3700, 4, 1, 576, 512),
(88, 3900, 4, 1, 576, 512),
(88, 4100, 4, 1, 576, 512),
(88, 4300, 4, 1, 576, 512),
(88, 4500, 4, 1, 576, 512),
(88, 4700, 4, 1, 576, 512),
(88, 4900, 4, 1, 576, 512),
(88, 5000, 4, 1, 576, 512),
(88, 5300, 4, 1, 576, 512),
(88, 5600, 4, 1, 576, 512),
(88, 5900, 4, 1, 576, 512),
(88, 6200, 4, 1, 576, 512),
(88, 6500, 4, 1, 576, 512),
(88, 6800, 4, 1, 576, 512),
(88, 7100, 4, 1, 576, 512),
(88, 7400, 4, 1, 576, 512),
(88, 7700, 4, 1, 576, 512),
(88, 8000, 4, 1, 576, 512),
(96, 1, 4, 1, 576, 512),
(96, 100, 4, 1, 576, 512),
(96, 300, 4, 1, 576, 512),
(96, 500, 4, 1, 576, 512),
(96, 700, 4, 1, 576, 512),
(96, 900, 4, 1, 576, 512),
(96, 1100, 4, 1, 576, 512),
(96, 1300, 4, 1, 576, 512),
(96, 1500, 4, 1, 576, 512),
(96, 1700, 4, 1, 576, 512),
(96, 1900, 4, 1, 576, 512),
(96, 2100, 4, 1, 576, 512),
(96, 2300, 4, 1, 576, 512),
(96, 2500, 4, 1, 576, 512),
(96, 2700, 4, 1, 576, 512),
(96, 2900, 4, 1, 576, 512),
(96, 3100, 4, 1, 576, 512),
(96, 3300, 4, 1, 576, 512),
(96, 3500, 4, 1, 576, 512),
(96, 3700, 4, 1, 576, 512),
(96, 3900, 4, 1, 576, 512),
(96, 4100, 4, 1, 576, 512),
(96, 4300, 4, 1, 576, 512),
(96, 4500, 4, 1, 576, 512),
(96, 4700, 4, 1, 576, 512),
(96, 4900, 4, 1, 576, 512),
(96, 5000, 4, 1, 576, 512),
(96, 5300, 4, 1, 576, 512),
(96, 5600, 4, 1, 576, 512),
(96, 5900, 4, 1, 576, 512),
(96, 6200, 4, 1, 576, 512),
(96, 6500, 4, 1, 576, 512),
(96, 6800, 4, 1, 576, 512),
(96, 7100, 4, 1, 576, 512),
(96, 7400, 4, 1, 576, 512),
(96, 7700, 4, 1, 576, 512),
(96, 8000, 4, 1, 576, 512),
(104, 1, 4, 1, 576, 512),
(104, 100, 4, 1, 576, 512),
(104, 400, 4, 1, 576, 512),
(104, 700, 4, 1, 576, 512),
(104, 1000, 4, 1, 576, 512),
(104, 1300, 4, 1, 576, 512),
(104, 1600, 4, 1, 576, 512),
(104, 1900, 4, 1, 576, 512),
(104, 2200, 4, 1, 576, 512),
(104, 2500, 4, 1, 576, 512),
(104, 2800, 4, 1, 576, 512),
(104, 3100, 4, 1, 576, 512),
(104, 3400, 4, 1, 576, 512),
(104, 3700, 4, 1, 576, 512),
(104, 4000, 4, 1, 576, 512),
(104, 4300, 4, 1, 576, 512),
(104, 4600, 4, 1, 576, 512),
(104, 4900, 4, 1, 576, 512),
(104, 5000, 4, 1, 576, 512),
(104, 5500, 4, 1, 576, 512),
(104, 6000, 4, 1, 576, 512),
(104, 6500, 4, 1, 576, 512),
(104, 7000, 4, 1, 576, 512),
(104, 7500, 4, 1, 576, 512),
(104, 8000, 4, 1, 576, 512),
(112, 1, 4, 1, 576, 512),
(112, 100, 4, 1, 576, 512),
(112, 400, 4, 1, 576, 512),
(112, 700, 4, 1, 576, 512),
(112, 1000, 4, 1, 576, 512),
(112, 1300, 4, 1, 576, 512),
(112, 1600, 4, 1, 576, 512),
(112, 1900, 4, 1, 576, 512),
(112, 2200, 4, 1, 576, 512),
(112, 2500, 4, 1, 576, 512),
(112, 2800, 4, 1, 576, 512),
(112, 3100, 4, 1, 576, 512),
(112, 3400, 4, 1, 576, 512),
(112, 3700, 4, 1, 576, 512),
(112, 4000, 4, 1, 576, 512),
(112, 4300, 4, 1, 576, 512),
(112, 4600, 4, 1, 576, 512),
(112, 4900, 4, 1, 576, 512),
(112, 5000, 4, 1, 576, 512),
(112, 5500, 4, 1, 576, 512),
(112, 6000, 4, 1, 576, 512),
(112, 6500, 4, 1, 576, 512),
(112, 7000, 4, 1, 576, 512),
(112, 7500, 4, 1, 576, 512),
(112, 8000, 4, 1, 576, 512),
(120, 1, 4, 1, 576, 512),
(120, 100, 4, 1, 576, 512),
(120, 400, 4, 1, 576, 512),
(120, 700, 4, 1, 576, 512),
(120, 1000, 4, 1, 576, 512),
(120, 1300, 4, 1, 576, 512),
(120, 1600, 4, 1, 576, 512),
(120, 1900, 4, 1, 576, 512),
(120, 2200, 4, 1, 576, 512),
(120, 2500, 4, 1, 576, 512),
(120, 2800, 4, 1, 576, 512),
(120, 3100, 4, 1, 576, 512),
(120, 3400, 4, 1, 576, 512),
(120, 3700, 4, 1, 576, 512),
(120, 4000, 4, 1, 576, 512),
(120, 4300, 4, 1, 576, 512),
(120, 4600, 4, 1, 576, 512),
(120, 4900, 4, 1, 576, 512),
(120, 5000, 4, 1, 576, 512),
(120, 5500, 4, 1, 576, 512),
(120, 6000, 4, 1, 576, 512),
(120, 6500, 4, 1, 576, 512),
(120, 7000, 4, 1, 576, 512),
(120, 7500, 4, 1, 576, 512),
(120, 8000, 4, 1, 576, 512),
(128, 1, 4, 1, 576, 512),
(128, 100, 4, 1, 576, 512),
(128, 400, 4, 1, 576, 512),
(128, 700, 4, 1, 576, 512),
(128, 1000, 4, 1, 576, 512),
(128, 1300, 4, 1, 576, 512),
(128, 1600, 4, 1, 576, 512),
(128, 1900, 4, 1, 576, 512),
(128, 2200, 4, 1, 576, 512),
(128, 2500, 4, 1, 576, 512),
(128, 2800, 4, 1, 576, 512),
(128, 3100, 4, 1, 576, 512),
(128, 3400, 4, 1, 576, 512),
(128, 3700, 4, 1, 576, 512),
(128, 4000, 4, 1, 576, 512),
(128, 4300, 4, 1, 576, 512),
(128, 4600, 4, 1, 576, 512),
(128, 4900, 4, 1, 576, 512),
(128, 5000, 4, 1, 576, 512),
(128, 5500, 4, 1, 576, 512),
(128, 6000, 4, 1, 576, 512),
(128, 6500, 4, 1, 576, 512),
(128, 7000, 4, 1, 576, 512),
(128, 7500, 4, 1, 576, 512),
(128, 8000, 4, 1, 576, 512),
]
\ No newline at end of file
import functools
import json
import torch
import os
from enum import Enum
from typing import Any, Dict, Optional, Tuple
import bisect
from vllm.logger import init_logger
logger = init_logger(__name__)
class KERNLE_KINDS(Enum):
v1 = 0
v1_tc = 1
v1_tc_2 = 2
v1_2stages = 3
v1_2stages_tc = 4
v2 = 5
v2_tc = 6
TOTAL_KIND = 7
class BestConfig():
def __init__(self):
self.batch_size = 0
self.seq_len = 0
self.kernel_kind = KERNLE_KINDS.TOTAL_KIND
self.BLOCK_N = 0
self.BLOCK_SEQ = 0
self.SPLIT_K = 0
self.num_stages = 0
self.num_warps = 0
self.NUM_KV_SPLITS = 0
self.BLOCK_N_2 = 0
self.num_stages_2 = 0
self.num_warps_2 = 0
self.best_us = 0
self.decode_fwd_stage1 = None
self.decode_fwd_stage2 = None
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
def get_config_file_name(QH: int, KVH: int, D: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_D={D}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_D={D}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_D={D}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
def get_attention_mla_configs_json(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)
json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
raise ValueError("Please surpport default config can match 16 1 576 512")
# If no optimized configuration is available, we will use the default
# configuration
return None
def get_config_map(attention_configs):
ret_map = {}
for bs in attention_configs.keys():
int_bs = int(bs)
seq_map = {}
seq_configs = attention_configs[bs]
ret_map[int_bs] = seq_map
for seq_len in seq_configs.keys():
int_seq_len = int(seq_len)
kind_config = seq_configs[seq_len]
configs = BestConfig()
configs.batch_size = int_bs
configs.seq_len = int_seq_len
configs.best_us = kind_config['best_us']
seq_map[int_seq_len] = configs
if kind_config['kernel_kind'] == 'v1':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_tc':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1_tc
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_tc_2':
best_config = kind_config['best_config']
configs.kernel_kind = KERNLE_KINDS.v1_tc_2
configs.BLOCK_N = best_config['BLOCK_N']
configs.num_stages = best_config['num_stages']
configs.num_warps = best_config['num_warps']
elif kind_config['kernel_kind'] == 'v1_2stages':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v1_2stages
# configs.SPLIT_K = stage1['SPLIT_K']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.BLOCK_N_2 = stage2['BLOCK_N']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
elif kind_config['kernel_kind'] == 'v1_2stages_tc':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v1_2stages_tc
# configs.SPLIT_K = stage1['SPLIT_K']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.BLOCK_N_2 = stage2['BLOCK_N']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
elif kind_config['kernel_kind'] == 'v2':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v2
if 'BLOCK_SEQ' in stage1:
configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
else:
configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
elif kind_config['kernel_kind'] == 'v2_tc':
best_config = kind_config['best_config']
stage1 = best_config['stage1']
stage2 = best_config['stage2']
configs.kernel_kind = KERNLE_KINDS.v2_tc
if 'BLOCK_SEQ' in stage1:
configs.BLOCK_SEQ = stage1['BLOCK_SEQ']
else:
configs.NUM_KV_SPLITS = stage1['NUM_KV_SPLITS']
configs.BLOCK_N = stage1['BLOCK_N']
configs.num_stages = stage1['num_stages']
configs.num_warps = stage1['num_warps']
configs.num_stages_2 = stage2['num_stages']
configs.num_warps_2 = stage2['num_warps']
return ret_map
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
attention_configs = get_attention_mla_configs_json(QH, KVH, QKD, VD, cache_dtype)
return get_config_map(attention_configs)
def get_closest_key(dic_keys, target_key):
keys = list(dic_keys)
idx = bisect.bisect_left(keys, target_key)
if idx == 0:
return keys[0]
if idx == len(keys):
return keys[-1]
left_key = keys[idx - 1]
right_key = keys[idx]
if target_key - left_key <= right_key - target_key:
return left_key
else:
return right_key
def get_nearest_config(bs_key, mean_kv_seqlen_key, config):
closest_bs_key = get_closest_key(config.keys(), bs_key)
closest_mean_kv_seqlen_key = get_closest_key(config[closest_bs_key].keys(), mean_kv_seqlen_key)
return config[closest_bs_key][closest_mean_kv_seqlen_key]
def get_config(bs_key, mean_kv_seqlen_key, config):
if bs_key in config and mean_kv_seqlen_key in config[bs_key]:
return config[bs_key][mean_kv_seqlen_key]
else:
raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import functools
import json
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -10,6 +7,7 @@ from itertools import accumulate ...@@ -10,6 +7,7 @@ from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from .triton_config import get_nearest_config, get_attention_mla_configs, get_config
try: try:
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
...@@ -39,65 +37,6 @@ if TYPE_CHECKING: ...@@ -39,65 +37,6 @@ if TYPE_CHECKING:
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
def get_config(bs_key, mean_kv_seqlen_key, config):
# 转换参数为字符串以匹配字典的键
bs_key_str = str(bs_key)
mean_kv_seqlen_key_str = str(mean_kv_seqlen_key)
# 检查字典中是否存在对应的配置
if bs_key_str in config and mean_kv_seqlen_key_str in config[bs_key_str]:
return config[bs_key_str][mean_kv_seqlen_key_str]
else:
raise ValueError(f"No matching configuration found for bs key: {bs_key} and mean kv seq key: {mean_kv_seqlen_key} when init decode attention db")
def get_mla_config_file_name(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> str:
if cache_dtype == "default":
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_default.json"
device_name = torch.cuda.get_device_name().replace(" ", "_")
if "K100_AI" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_K100AI.json"
elif "BW" in device_name:
return f"QH={QH}_KVH={KVH}_QKD={QKD}_VD={VD}_{cache_dtype}_BW.json"
else:
raise ValueError(f"Unsurpport device name: {device_name}")
@functools.lru_cache
def get_attention_mla_configs(QH: int, KVH: int, QKD: int, VD: int, cache_dtype: Optional[str]) -> Optional[Dict[Any, Any]]:
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_mla_config_file_name(QH, KVH, QKD, VD, cache_dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using decode attention configuration from %s for attention layer.", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
logger.warning("Can not find best decode attention configuration %s for attention layer, it may not have the best performance to use default json. Please tune one. ", config_file_path)
json_file_name = get_mla_config_file_name(16, 1, 576, 512, "default")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.warning("Using default decode attention configuration from %s for attention layer. It may not have the best performance to use default json. ", config_file_path)
# If a configuration has been found, return it
return json.load(f)
else:
raise ValueError("Please surpport default config can match 16 1 576 512")
# If no optimized configuration is available, we will use the default
# configuration
return None
class TritonMLABackend(AttentionBackend): class TritonMLABackend(AttentionBackend):
@staticmethod @staticmethod
......
...@@ -36,6 +36,7 @@ import triton.language as tl ...@@ -36,6 +36,7 @@ import triton.language as tl
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm import envs from vllm import envs
from ..backends.triton_config import KERNLE_KINDS
is_hip_ = current_platform.is_rocm() is_hip_ = current_platform.is_rocm()
os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0" os.environ["TRITON_HIP_USE_NEW_STREAM_PIPELINE"] = f"0"
...@@ -897,8 +898,8 @@ def _decode_v1_stage1_use_tc( ...@@ -897,8 +898,8 @@ def _decode_v1_stage1_use_tc(
sm_scale, sm_scale,
page_size, page_size,
num_kv_splits, num_kv_splits,
logit_cap,
best_config, best_config,
logit_cap,
): ):
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
...@@ -916,45 +917,64 @@ def _decode_v1_stage1_use_tc( ...@@ -916,45 +917,64 @@ def _decode_v1_stage1_use_tc(
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2] kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_N = best_config['BLOCK_N'] BLOCK_N = best_config.BLOCK_N
SPLIT_K = num_kv_splits # best_config['SPLIT_K'] ? SPLIT_K = num_kv_splits # best_config.SPLIT_K
num_stages = best_config['num_stages'] num_stages = best_config.num_stages
num_warps = best_config['num_warps'] num_warps = best_config.num_warps
BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
grid = lambda META: ( grid = lambda META: (
batch, batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
SPLIT_K, SPLIT_K,
) )
_decode_v1_kernel_stage1_use_tc[grid]( if best_config.decode_fwd_stage1 is None:
q, best_config.decode_fwd_stage1 = _decode_v1_kernel_stage1_use_tc[grid](
k_buffer, q,
sm_scale, k_buffer,
Req_to_tokens, sm_scale,
#B_req_idx, Req_to_tokens,
B_Start_Loc, #B_req_idx,
B_Seqlen, B_Start_Loc,
att_out, B_Seqlen,
Req_to_tokens.stride(0), att_out,
q.stride(0), Req_to_tokens.stride(0),
q.stride(1), q.stride(0),
k_buffer.stride(-3), q.stride(1),
k_buffer.stride(-2), k_buffer.stride(-3),
att_out.stride(0), k_buffer.stride(-2),
kv_group_num=kv_group_num, att_out.stride(0),
q_head_num=head_num, kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL, q_head_num=head_num,
BLOCK_DPE=BLOCK_DPE, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_N=BLOCK_N, BLOCK_DPE=BLOCK_DPE,
BLOCK_H=BLOCK_H, BLOCK_N=BLOCK_N,
SPLIT_K=SPLIT_K, BLOCK_H=BLOCK_H,
PAGE_SIZE=page_size, SPLIT_K=SPLIT_K,
logit_cap=logit_cap, PAGE_SIZE=page_size,
num_warps=num_warps, logit_cap=logit_cap,
num_stages=num_stages, num_warps=num_warps,
Lk=Lk, num_stages=num_stages,
kpack=2, Lk=Lk,
) kpack=2,
)
else:
best_config.decode_fwd_stage1 = _decode_v1_kernel_stage1_use_tc[grid](
q,
k_buffer,
sm_scale,
Req_to_tokens,
#B_req_idx,
B_Start_Loc,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
att_out.stride(0),
)
# return _decode_v1_kernel_stage1_use_tc.best_config # return _decode_v1_kernel_stage1_use_tc.best_config
...@@ -966,44 +986,62 @@ def _decode_v1_stage2_use_tc( ...@@ -966,44 +986,62 @@ def _decode_v1_stage2_use_tc(
#b_req_idx, #b_req_idx,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
page_size,
best_config, best_config,
page_size,
): ):
batch, head_num = b_seq_len.shape[0], logits.shape[0] batch, head_num = b_seq_len.shape[0], logits.shape[0]
kv_group_num = logits.shape[0] // v_buffer.shape[-2] kv_group_num = logits.shape[0] // v_buffer.shape[-2]
BLOCK_N = best_config['BLOCK_N'] BLOCK_N = best_config.BLOCK_N_2
num_stages = best_config['num_stages'] num_stages = best_config.num_stages_2
num_warps = best_config['num_warps'] num_warps = best_config.num_warps_2
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num)) BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
BLOCK_DMODEL = triton.next_power_of_2(Lv) BLOCK_DMODEL = triton.next_power_of_2(Lv)
_decode_v1_kernel_stage2_use_tc[grid]( if best_config.decode_fwd_stage2 is None:
logits, best_config.decode_fwd_stage2 = _decode_v1_kernel_stage2_use_tc[grid](
v_buffer, logits,
o, v_buffer,
req_to_tokens, o,
#b_req_idx, req_to_tokens,
b_start_loc, #b_req_idx,
b_seq_len, b_start_loc,
logits.stride(0), b_seq_len,
v_buffer.stride(-3), logits.stride(0),
v_buffer.stride(-2), v_buffer.stride(-3),
o.stride(0), v_buffer.stride(-2),
o.stride(1), o.stride(0),
req_to_tokens.stride(0), o.stride(1),
kv_group_num=kv_group_num, req_to_tokens.stride(0),
q_head_num=head_num, kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL, q_head_num=head_num,
BLOCK_N=BLOCK_N, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_H=BLOCK_H, BLOCK_N=BLOCK_N,
PAGE_SIZE=page_size, BLOCK_H=BLOCK_H,
Lv=Lv, PAGE_SIZE=page_size,
num_warps=num_warps, Lv=Lv,
num_stages=num_stages, num_warps=num_warps,
) num_stages=num_stages,
)
else:
best_config.decode_fwd_stage2 = _decode_v1_kernel_stage2_use_tc[grid](
logits,
v_buffer,
o,
req_to_tokens,
#b_req_idx,
b_start_loc,
b_seq_len,
logits.stride(0),
v_buffer.stride(-3),
v_buffer.stride(-2),
o.stride(0),
o.stride(1),
req_to_tokens.stride(0),
)
# return _decode_v1_kernel_stage2_use_tc.best_config # return _decode_v1_kernel_stage2_use_tc.best_config
...@@ -1059,8 +1097,8 @@ def decode_attention_v1( ...@@ -1059,8 +1097,8 @@ def decode_attention_v1(
sm_scale, sm_scale,
page_size, page_size,
num_kv_splits, num_kv_splits,
best_config,
logit_cap, logit_cap,
best_config['stage1'],
) )
_decode_v1_stage2_use_tc( _decode_v1_stage2_use_tc(
attn_logits, attn_logits,
...@@ -1070,12 +1108,11 @@ def decode_attention_v1( ...@@ -1070,12 +1108,11 @@ def decode_attention_v1(
#b_req_idx, #b_req_idx,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
best_config,
page_size, page_size,
best_config['stage2'],
) )
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1), # triton.Config({"BLOCK_N": 16}, num_warps=2, num_stages=1),
...@@ -1274,9 +1311,9 @@ def _decode_v2_stage1_use_tc( ...@@ -1274,9 +1311,9 @@ def _decode_v2_stage1_use_tc(
logit_cap, logit_cap,
): ):
BLOCK = best_config['BLOCK_N'] BLOCK = best_config.BLOCK_N
num_stages = best_config['num_stages'] num_stages = best_config.num_stages
num_warps = best_config['num_warps'] num_warps = best_config.num_warps
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
...@@ -1304,43 +1341,66 @@ def _decode_v2_stage1_use_tc( ...@@ -1304,43 +1341,66 @@ def _decode_v2_stage1_use_tc(
NUM_KV_SPLITS, NUM_KV_SPLITS,
) )
_decode_v2_kernel_stage1_use_tc[grid]( if best_config.decode_fwd_stage1 is None:
q, best_config.decode_fwd_stage1 =_decode_v2_kernel_stage1_use_tc[grid](
k_buffer, q,
v_buffer, k_buffer,
sm_scale, v_buffer,
Req_to_tokens, sm_scale,
# B_req_idx, Req_to_tokens,
B_Seqlen, # B_req_idx,
att_out, B_Seqlen,
Req_to_tokens.stride(0), att_out,
q.stride(0), Req_to_tokens.stride(0),
q.stride(1), q.stride(0),
k_buffer.stride(-3), q.stride(1),
k_buffer.stride(-2), k_buffer.stride(-3),
v_buffer.stride(-3), k_buffer.stride(-2),
v_buffer.stride(-2), v_buffer.stride(-3),
att_out.stride(0), v_buffer.stride(-2),
att_out.stride(1), att_out.stride(0),
att_out.stride(2), att_out.stride(1),
kv_group_num=kv_group_num, att_out.stride(2),
q_head_num=head_num, kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL, q_head_num=head_num,
BLOCK_DPE=BLOCK_DPE, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DV=BLOCK_DV, BLOCK_DPE=BLOCK_DPE,
BLOCK_N=BLOCK, BLOCK_DV=BLOCK_DV,
BLOCK_H=BLOCK_H, BLOCK_N=BLOCK,
NUM_KV_SPLITS=NUM_KV_SPLITS, BLOCK_H=BLOCK_H,
PAGE_SIZE=page_size, NUM_KV_SPLITS=NUM_KV_SPLITS,
logit_cap=logit_cap, PAGE_SIZE=page_size,
num_warps=num_warps, logit_cap=logit_cap,
num_stages=num_stages, num_warps=num_warps,
Lk=Lk, num_stages=num_stages,
Lv=Lv, Lk=Lk,
kpack=2, Lv=Lv,
) kpack=2,
)
else:
best_config.decode_fwd_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
# B_req_idx,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-3),
k_buffer.stride(-2),
v_buffer.stride(-3),
v_buffer.stride(-2),
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
)
# return _decode_v2_kernel_stage1_use_tc.best_config # return _decode_v2_kernel_stage1_use_tc.best_config
# @triton.autotune( # @triton.autotune(
# configs=[ # configs=[
# triton.Config({}, num_warps=1, num_stages=1), # triton.Config({}, num_warps=1, num_stages=1),
...@@ -1416,8 +1476,8 @@ def _decode_v2_stage2_use_tc( ...@@ -1416,8 +1476,8 @@ def _decode_v2_stage2_use_tc(
num_kv_splits, num_kv_splits,
best_config, best_config,
): ):
num_stages = best_config['num_stages'] num_stages = best_config.num_stages_2
num_warps = best_config['num_warps'] num_warps = best_config.num_warps_2
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
...@@ -1425,22 +1485,34 @@ def _decode_v2_stage2_use_tc( ...@@ -1425,22 +1485,34 @@ def _decode_v2_stage2_use_tc(
NUM_KV_SPLITS = num_kv_splits NUM_KV_SPLITS = num_kv_splits
grid = (batch, head_num) grid = (batch, head_num, 1)
_decode_v2_kernel_stage2[grid]( if best_config.decode_fwd_stage2 is None:
logits, best_config.decode_fwd_stage2 = _decode_v2_kernel_stage2[grid](
o, logits,
b_seq_len, o,
logits.stride(0), b_seq_len,
logits.stride(1), logits.stride(0),
logits.stride(2), logits.stride(1),
o.stride(0), logits.stride(2),
o.stride(1), o.stride(0),
NUM_KV_SPLITS=NUM_KV_SPLITS, o.stride(1),
BLOCK_DV=BLOCK_DV, NUM_KV_SPLITS=NUM_KV_SPLITS,
Lv=Lv, BLOCK_DV=BLOCK_DV,
num_warps=num_warps, Lv=Lv,
num_stages=num_stages, num_warps=num_warps,
) num_stages=num_stages,
)
else:
best_config.decode_fwd_stage2[grid](
logits,
o,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
)
# return _decode_v2_kernel_stage2.best_config # return _decode_v2_kernel_stage2.best_config
...@@ -1485,11 +1557,11 @@ def decode_attention_v2( ...@@ -1485,11 +1557,11 @@ def decode_attention_v2(
b_seq_len, b_seq_len,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config,
page_size, page_size,
logit_cap, logit_cap,
best_config['stage1'],
) )
_decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits, best_config['stage2']) _decode_v2_stage2_use_tc(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits, best_config)
def decode_attention_fwd( def decode_attention_fwd(
...@@ -1560,7 +1632,7 @@ def decode_attention_fwd( ...@@ -1560,7 +1632,7 @@ def decode_attention_fwd(
logit_cap, logit_cap,
)''' )'''
if best_config['kernel_kind'] == 'v1_2stages_tc': if best_config.kernel_kind == KERNLE_KINDS.v1_2stages_tc:
attn_logits_v1 = torch.empty( attn_logits_v1 = torch.empty(
(q.shape[1],k_buffer.shape[0]*page_size), (q.shape[1],k_buffer.shape[0]*page_size),
dtype=torch.float16, dtype=torch.float16,
...@@ -1576,11 +1648,11 @@ def decode_attention_fwd( ...@@ -1576,11 +1648,11 @@ def decode_attention_fwd(
attn_logits_v1, attn_logits_v1,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config=best_config['best_config'], best_config=best_config,
page_size=page_size, page_size=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
) )
elif best_config['kernel_kind'] == 'v2_tc': elif best_config.kernel_kind == KERNLE_KINDS.v2_tc:
decode_attention_v2( decode_attention_v2(
q, q,
k_buffer, k_buffer,
...@@ -1591,7 +1663,7 @@ def decode_attention_fwd( ...@@ -1591,7 +1663,7 @@ def decode_attention_fwd(
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
best_config=best_config['best_config'], best_config=best_config,
page_size=page_size, page_size=page_size,
logit_cap=logit_cap, logit_cap=logit_cap,
) )
......
...@@ -14,6 +14,9 @@ from vllm import _custom_ops as ops ...@@ -14,6 +14,9 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8) per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -265,6 +268,7 @@ def fused_moe_kernel( ...@@ -265,6 +268,7 @@ def fused_moe_kernel(
top_k: tl.constexpr, top_k: tl.constexpr,
compute_type: tl.constexpr, compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr, use_fp8_w8a8: tl.constexpr,
use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr): use_int8_w8a16: tl.constexpr):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
...@@ -346,7 +350,7 @@ def fused_moe_kernel( ...@@ -346,7 +350,7 @@ def fused_moe_kernel(
None, :] * stride_bsn None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs) b_scale = tl.load(b_scale_ptrs)
if use_fp8_w8a8: if use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n offs_bsn = offs_bn // group_n
...@@ -376,7 +380,7 @@ def fused_moe_kernel( ...@@ -376,7 +380,7 @@ def fused_moe_kernel(
# We accumulate along the K dimension. # We accumulate along the K dimension.
if use_int8_w8a16: if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8: elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k offs_ks = k_start // group_k
...@@ -402,7 +406,7 @@ def fused_moe_kernel( ...@@ -402,7 +406,7 @@ def fused_moe_kernel(
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16: if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type) accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8: elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0: if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
else: else:
...@@ -709,6 +713,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -709,6 +713,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config: Dict[str, Any], config: Dict[str, Any],
compute_type: tl.dtype, compute_type: tl.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
...@@ -727,6 +732,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -727,6 +732,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a8:
assert B_scale is not None
if block_shape is None:
A, A_scale = ops.scaled_int8_quant(A, A_scale)
else:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
elif use_int8_w8a16 or use_int4_w4a16: elif use_int8_w8a16 or use_int4_w4a16:
assert B_scale is not None assert B_scale is not None
assert block_shape is None or block_shape[0] == 0 assert block_shape is None or block_shape[0] == 0
...@@ -826,6 +844,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -826,6 +844,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
**config, **config,
) )
...@@ -1060,9 +1079,12 @@ def grouped_topk(hidden_states: torch.Tensor, ...@@ -1060,9 +1079,12 @@ def grouped_topk(hidden_states: torch.Tensor,
def get_config_dtype_str(dtype: torch.dtype, def get_config_dtype_str(dtype: torch.dtype,
use_int4_w4a16: Optional[bool] = False, use_int4_w4a16: Optional[bool] = False,
use_int8_w8a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False): use_fp8_w8a8: Optional[bool] = False,
use_int8_w8a8: Optional[bool] = False):
if use_fp8_w8a8: if use_fp8_w8a8:
return "fp8_w8a8" return "fp8_w8a8"
elif use_int8_w8a8:
return "int8_w8a8"
elif use_int8_w8a16: elif use_int8_w8a16:
return "int8_w8a16" return "int8_w8a16"
elif use_int4_w4a16: elif use_int4_w4a16:
...@@ -1080,6 +1102,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1080,6 +1102,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1094,7 +1117,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1094,7 +1117,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
start_expert: Optional[int] = -1, start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1) -> None: end_expert: Optional[int] = -1) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale, use_fp8_w8a8,use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape,
use_nn_moe, moe_ep_size=moe_ep_size, use_nn_moe, moe_ep_size=moe_ep_size,
start_expert=start_expert, end_expert=end_expert) start_expert=start_expert, end_expert=end_expert)
...@@ -1107,6 +1130,7 @@ def inplace_fused_experts_fake( ...@@ -1107,6 +1130,7 @@ def inplace_fused_experts_fake(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1138,6 +1162,7 @@ def outplace_fused_experts( ...@@ -1138,6 +1162,7 @@ def outplace_fused_experts(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1152,7 +1177,7 @@ def outplace_fused_experts( ...@@ -1152,7 +1177,7 @@ def outplace_fused_experts(
start_expert: Optional[int] = -1, start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1) -> torch.Tensor: end_expert: Optional[int] = -1) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, use_fp8_w8a8, use_int8_w8a16, False, use_fp8_w8a8,use_int8_w8a8,use_int8_w8a16,
use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, a1_scale, a2_scale, block_shape,
use_nn_moe, moe_ep_size=moe_ep_size, use_nn_moe, moe_ep_size=moe_ep_size,
...@@ -1166,6 +1191,7 @@ def outplace_fused_experts_fake( ...@@ -1166,6 +1191,7 @@ def outplace_fused_experts_fake(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1197,6 +1223,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1197,6 +1223,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False, inplace: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1213,7 +1240,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1213,7 +1240,7 @@ def fused_experts(hidden_states: torch.Tensor,
if inplace: if inplace:
torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
topk_weights, topk_ids, topk_weights, topk_ids,
use_fp8_w8a8, use_int8_w8a16, use_fp8_w8a8,use_int8_w8a8,use_int8_w8a16,
use_int4_w4a16, w1_scale, use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape, a2_scale, block_shape,
...@@ -1224,7 +1251,7 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1224,7 +1251,7 @@ def fused_experts(hidden_states: torch.Tensor,
return hidden_states return hidden_states
else: else:
return torch.ops.vllm.outplace_fused_experts( return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,use_int8_w8a8,
use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape, a1_scale, a2_scale, block_shape,
use_nn_moe, moe_ep_size=moe_ep_size, use_nn_moe, moe_ep_size=moe_ep_size,
...@@ -1239,6 +1266,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1239,6 +1266,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False, inplace: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1279,6 +1307,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1279,6 +1307,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
...@@ -1369,6 +1398,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1369,6 +1398,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape, block_shape=block_shape,
...@@ -1393,6 +1423,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1393,6 +1423,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape, block_shape=block_shape,
...@@ -1416,6 +1447,7 @@ def fused_moe( ...@@ -1416,6 +1447,7 @@ def fused_moe(
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -1492,6 +1524,7 @@ def fused_moe( ...@@ -1492,6 +1524,7 @@ def fused_moe(
topk_ids, topk_ids,
inplace=inplace, inplace=inplace,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
w1_scale=w1_scale, w1_scale=w1_scale,
......
...@@ -363,6 +363,9 @@ class FusedMoE(torch.nn.Module): ...@@ -363,6 +363,9 @@ class FusedMoE(torch.nn.Module):
if (self.quant_method.__class__.__name__ == if (self.quant_method.__class__.__name__ ==
"CompressedTensorsWNA16MoEMethod"): "CompressedTensorsWNA16MoEMethod"):
moe_quant_params["intermediate_size_full"] = intermediate_size moe_quant_params["intermediate_size_full"] = intermediate_size
if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod")):
moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
......
...@@ -37,7 +37,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ ...@@ -37,7 +37,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod", "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod", "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
"HQQMarlinMethod", "QuarkLinearMethod" "HQQMarlinMethod", "QuarkLinearMethod", "BlockInt8LinearMethod",
] ]
...@@ -664,9 +664,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -664,9 +664,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if isinstance(param, BlockQuantScaleParameter): if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod) Fp8LinearMethod, Fp8MoEMethod)
from vllm.model_executor.layers.quantization.blockwise_int8 import (
BlockInt8LinearMethod, BlockInt8MoEMethod)
assert self.quant_method is not None assert self.quant_method is not None
assert isinstance(self.quant_method, assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod)) (Fp8LinearMethod, Fp8MoEMethod, BlockInt8LinearMethod, BlockInt8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1] block_n, _ = weight_block_size[0], weight_block_size[1]
......
...@@ -29,7 +29,8 @@ QUANTIZATION_METHODS: List[str] = [ ...@@ -29,7 +29,8 @@ QUANTIZATION_METHODS: List[str] = [
"neuron_quant", "neuron_quant",
"ipex", "ipex",
"quark", "quark",
"moe_wna16" "moe_wna16",
"blockwise_int8"
] ]
# The customized quantization methods which will be added to this dict. # The customized quantization methods which will be added to this dict.
...@@ -101,6 +102,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -101,6 +102,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .neuron_quant import NeuronQuantConfig from .neuron_quant import NeuronQuantConfig
from .qqq import QQQConfig from .qqq import QQQConfig
from .tpu_int8 import Int8TpuConfig from .tpu_int8 import Int8TpuConfig
from .blockwise_int8 import BlockInt8Config
method_to_config: Dict[str, Type[QuantizationConfig]] = { method_to_config: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
...@@ -127,6 +129,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: ...@@ -127,6 +129,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"ipex": IPEXConfig, "ipex": IPEXConfig,
"quark": QuarkConfig, "quark": QuarkConfig,
"moe_wna16": MoeWNA16Config, "moe_wna16": MoeWNA16Config,
"blockwise_int8": BlockInt8Config,
} }
# Update the `method_to_config` with customized quantization methods. # Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/3730
import logging
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
apply_w8a8_block_int8_linear)
from vllm.model_executor.utils import set_weight_attrs
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = logging.getLogger(__name__)
class BlockInt8Config(QuantizationConfig):
"""Config class for INT8."""
def __init__(
self,
is_checkpoint_int8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
weight_block_size: Optional[List[int]] = None,
) -> None:
self.is_checkpoint_int8_serialized = is_checkpoint_int8_serialized
if is_checkpoint_int8_serialized:
logger.warning(
"Detected int8 checkpoint. Please note that the "
"format is experimental and subject to change."
)
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError("Unsupported activation scheme"
f" {activation_scheme}")
self.activation_scheme = activation_scheme
self.ignored_layers = ignored_layers or []
if weight_block_size is not None:
if not is_checkpoint_int8_serialized:
raise ValueError(
f"The block-wise quantization only supports "
"int8-serialized checkpoint for now."
)
if len(weight_block_size) != 2:
raise ValueError(
f"The quantization block size of weight must have 2 "
"dimensions, but got {len(weight_block_size)} dimensions."
)
if activation_scheme != "dynamic":
raise ValueError(
f"The block-wise quantization only supports dynamic "
"activation scheme for now, but got "
"{activation_scheme} activation scheme."
)
self.weight_block_size = weight_block_size
@classmethod
def get_name(cls) -> str:
return "blockwise_int8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_int8_serialized = "int8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
weight_block_size = cls.get_from_keys_or(config,
["weight_block_size"], None)
return cls(
is_checkpoint_int8_serialized=is_checkpoint_int8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
return BlockInt8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return BlockInt8MoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class BlockInt8LinearMethod(LinearMethodBase):
"""Linear method for INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic activation scale.
Limitations:
Only support block-wise int8 quantization and int8 checkpoint
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: BlockInt8Config):
self.quant_config = quant_config
assert self.quant_config.weight_block_size is not None
assert self.quant_config.is_checkpoint_int8_serialized
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: Optional[List[int]],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
# assert output_partition_sizes is not None, (
# "output_partition_sizes must be provided for quantization")
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# Required by row parallel
if tp_size > 1 and input_size // input_size_per_partition == tp_size:
if input_size_per_partition % block_k != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}."
)
# Required by collum parallel or enabling merged weights
if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len(
output_partition_sizes
) > 1:
for output_partition_size in output_partition_sizes:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "
f"{output_partition_size} is not divisible by "
f"weight quantization block_n = {block_n}."
)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight_dtype = (
torch.int8
if self.quant_config.is_checkpoint_int8_serialized
else params_dtype
)
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition, input_size_per_partition, dtype=weight_dtype
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
scale = BlockQuantScaleParameter(
data=torch.empty(
(output_size_per_partition + block_n - 1) // block_n,
(input_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE
assert self.quant_config.activation_scheme == "dynamic"
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
# Use torch Parameter to avoid cuda graph capturing issue
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
layer.weight_scale_inv.data, requires_grad=False
)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_w8a8_block_int8_linear(
input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
weight_scale=layer.weight_scale_inv,
input_scale=None,
bias=bias,
)
class BlockInt8MoEMethod:
"""MoE method for INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic activation scale.
Limitations:
Only support block-wise int8 quantization and int8 checkpoint
Args:
quant_config: The quantization config.
"""
def __new__(cls, *args, **kwargs):
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config):
self.quant_config = quant_config
assert self.quant_config.weight_block_size is not None
assert self.quant_config.is_checkpoint_int8_serialized
def create_weights(
self,
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from vllm.model_executor.layers.fused_moe import FusedMoeWeightScaleSupported
if self.quant_config.is_checkpoint_int8_serialized:
params_dtype = torch.int8
tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = (
self.quant_config.weight_block_size[0],
self.quant_config.weight_block_size[1],
)
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
# Required by collum parallel or enabling merged weights
if intermediate_size % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1:
# Required by row parallel
if intermediate_size % block_k != 0:
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size} is not divisible by "
f"weight quantization block_k = {block_k}."
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size, dtype=params_dtype
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
assert self.quant_config.activation_scheme == "dynamic"
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
return
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
use_nn_moe: Optional[bool] = False,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
moe_ep_size: Optional[int] = 1,
start_expert: Optional[int] = -1,
end_expert: Optional[int] = -1
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
#print("===========fused_experts========================")
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias
)
# Expert fusion with INT8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale_inv),
w2_scale=(layer.w2_weight_scale_inv),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
use_nn_moe=use_nn_moe,
moe_ep_size=moe_ep_size,
start_expert=start_expert,
end_expert=end_expert
)
# SPDX-License-Identifier: Apache-2.0
import functools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
# from sglang.srt.utils import get_device_name
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
@triton.jit
def _per_token_quant_int8(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
BLOCK: tl.constexpr,
):
row_id = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols,
mask=mask, other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
def per_token_quant_int8(x):
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1,),
device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
assert x.is_contiguous()
_per_token_quant_int8[(M,)](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
@triton.jit
def _per_token_group_quant_int8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
# Stride of input
y_stride,
# Collums of input
N,
# Avoid to divide zero
eps,
# Information for int8
int8_min,
int8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform
per-token-group quantization on a tensor.
This function converts the tensor values into int8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / int8_max
y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def per_token_group_quant_int8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.int8,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tenosr with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.int8`
is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
int8_max = iinfo.max
int8_min = iinfo.min
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_int8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
int8_min=int8_min,
int8_max=int8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
@triton.jit
def _w8a8_block_int8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization,
and store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_int8_configs(
N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
"Using configuration from %s for W8A8 Block INT8 kernel.",
config_file_path,
)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
logger.warning(
(
"Using default W8A8 Block INT8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"
),
config_file_path,
)
return None
def w8a8_block_int8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""matrix multiplication with block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
#configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
#if configs:
# # If an optimal configuration map has been found, look up the
# # optimal config
# config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
#else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
#print("block_size[0]:{},block_size[1]:{}".format(block_size[0],block_size[1]))
# config = {
# "BLOCK_SIZE_M": 32, #64
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "GROUP_SIZE_M": 32,
# "num_warps": 4,
# "num_stages": 3,
# }
if M<=64:
config = {
"BLOCK_SIZE_M": 16, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M<128:
config = {
"BLOCK_SIZE_M": 32, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M<=256:
config = {
"BLOCK_SIZE_M": 64, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
else :
config = {
"BLOCK_SIZE_M": 64, #64
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 0,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) *
triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
_w8a8_block_int8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
return C
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
def apply_w8a8_block_int8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1])
output = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=input.dtype
)
# output = native_w8a8_block_int8_matmul(
# q_input, weight, x_scale, weight_scale, block_size,
# output_dtype=input.dtype
# )
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def input_to_int8(
x: torch.Tensor, dtype: torch.dtype = torch.int8
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to
int8 values with tensor-wise quantization.
"""
iinfo = torch.iinfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
int8_min, int8_max = iinfo.min, iinfo.max
scale = int8_max / amax
x_scl_sat = (x * scale).clamp(min=int8_min, max=int8_max)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
def block_dequant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
) -> torch.Tensor:
"""This function conducts block-wise dequantization.
The inputs are block-wise quantization tensor `x_q_block`,
block-wise quantization scale and the block size.
The outputs are dequantized tensor.
"""
block_n, block_k = block_size[0], block_size[1]
n, k = x_q_block.shape
n_tiles = (n + block_n - 1) // block_n
k_tiles = (k + block_k - 1) // block_k
assert n_tiles == x_s.shape[0]
assert k_tiles == x_s.shape[1]
x_dq_block = x_q_block.to(torch.float32)
for i in range(k_tiles):
for j in range(n_tiles):
x_dq_block[
j * block_n : min((j + 1) * block_n, n),
i * block_k : min((i + 1) * block_k, k),
] *= x_s[j][i]
return x_dq_block
...@@ -59,7 +59,9 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter, ...@@ -59,7 +59,9 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.int8_utils import (
block_dequant as int8_block_dequant,
)
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
......
...@@ -72,7 +72,7 @@ class RocmPlatform(Platform): ...@@ -72,7 +72,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf", "quark", "moe_wna16" "fbgemm_fp8", "gguf", "quark", "moe_wna16","blockwise_int8"
] ]
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment