generate_gemm_config.py 583 Bytes
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
# Copyright (c) OpenMMLab. All rights reserved.

import subprocess
AllentDan's avatar
AllentDan committed
4

Li Zhang's avatar
Li Zhang committed
5
6
7
import fire


Li Zhang's avatar
Li Zhang committed
8
def main(head_num: int = 32,
         size_per_head: int = 128,
         vocab_size: int = 32000,
         inter_size: int = 11008,
         tensor_para_size: int = 1,
         max_batch_size: int = 64):
    """Run ``bin/llama_gemm`` once for every batch size up to a maximum.

    The executable is resolved relative to the current working directory
    and receives ten positional arguments, in this order: batch size,
    ``1``, ``1``, ``head_num``, ``size_per_head``, ``inter_size``,
    ``vocab_size``, ``1``, ``tensor_para_size`` and a trailing ``0``/``1``
    flag. The meaning of the literal ``1`` arguments and of the trailing
    flag is defined by ``llama_gemm`` itself, not by this script.

    Args:
        head_num: Forwarded as the 4th argument.
        size_per_head: Forwarded as the 5th argument.
        vocab_size: Forwarded as the 7th argument.
        inter_size: Forwarded as the 6th argument.
        tensor_para_size: Forwarded as the 9th argument.
        max_batch_size: Largest batch size to run. Every size from 1 up to
            and including this value is run; values below 1 run nothing.

    Raises:
        OSError: If ``bin/llama_gemm`` cannot be executed, e.g. it does not
            exist under the current working directory.
    """
    for bsz in range(1, max_batch_size + 1):
        # 0 on the very first invocation, 1 on every later one.
        # NOTE(review): presumably tells the tool to start a fresh config
        # on the first run and append afterwards - confirm against
        # llama_gemm's argument parsing.
        trailing_flag = '0' if bsz == 1 else '1'
        command = [
            'bin/llama_gemm',
            str(bsz),
            '1',
            '1',
            str(head_num),
            str(size_per_head),
            str(inter_size),
            str(vocab_size),
            '1',
            str(tensor_para_size),
            trailing_flag,
        ]
        # Pass an argv list instead of a shell string so each value reaches
        # the tool as exactly one argument and is never re-parsed by a
        # shell. The exit status is intentionally not checked, so one
        # failing batch size does not stop the remaining runs.
        subprocess.call(command)


# Script entry point: hand `main` to python-fire, which builds the
# command-line interface from the function's parameters.
if __name__ == '__main__':
    fire.Fire(main)