generate_gemm_config.py 862 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright (c) OpenMMLab. All rights reserved.

import subprocess

import fire


def get_llama_gemm():
    """Return the path of the ``llama_gemm`` binary bundled with lmdeploy.

    The binary is looked up at ``<lmdeploy package dir>/bin/llama_gemm``,
    where the package directory is derived from ``lmdeploy.__file__``.

    Returns:
        str: Path to the ``llama_gemm`` binary.

    Raises:
        AssertionError: If nothing exists at the expected path.
    """
    import os.path as osp

    import lmdeploy
    lmdeploy_dir = osp.dirname(lmdeploy.__file__)
    bin_path = osp.join(lmdeploy_dir, 'bin', 'llama_gemm')
    # Raised explicitly instead of via an ``assert`` statement so the check
    # is not stripped under ``python -O``; the exception type and message
    # stay the same as before.
    if not osp.exists(bin_path):
        raise AssertionError(f'{bin_path} not exists')
    return bin_path


def main(head_num: int = 32,
         size_per_head: int = 128,
         vocab_size: int = 32000,
         inter_size: int = 11008,
         tensor_para_size: int = 1,
         max_batch_size: int = 64):
    """Run the ``llama_gemm`` binary once per batch size.

    For every batch size from 1 to ``max_batch_size`` (inclusive) the binary
    returned by ``get_llama_gemm()`` is invoked with the positional
    arguments::

        <bsz> 1 1 <head_num> <size_per_head> <inter_size> <vocab_size>
        1 <tensor_para_size> <flag>

    where ``<flag>`` is ``0`` for the first batch size and ``1`` for all
    later ones.
    NOTE(review): the meaning of the literal ``1`` arguments and of the
    trailing flag is defined by ``llama_gemm`` itself, not by this script;
    the flag presumably switches between creating and appending to the
    output config -- confirm against the binary's usage text.

    Args:
        head_num (int): Forwarded to ``llama_gemm``.
        size_per_head (int): Forwarded to ``llama_gemm``.
        vocab_size (int): Forwarded to ``llama_gemm``.
        inter_size (int): Forwarded to ``llama_gemm``.
        tensor_para_size (int): Forwarded to ``llama_gemm``.
        max_batch_size (int): Largest batch size to run; nothing is
            executed when it is smaller than 1.

    Raises:
        AssertionError: Propagated from ``get_llama_gemm()`` when the
            binary is missing.
    """
    # Resolve the binary once: the path does not change between iterations,
    # and a missing binary is reported before any process is spawned.
    gemm_bin = get_llama_gemm()
    for bsz in range(1, max_batch_size + 1):
        flag = 0 if bsz == 1 else 1
        argv = [
            gemm_bin, bsz, 1, 1, head_num, size_per_head, inter_size,
            vocab_size, 1, tensor_para_size, flag
        ]
        # An argument list (no shell) keeps install paths containing spaces
        # intact and stops shell metacharacters in CLI values from being
        # interpreted. The exit status is ignored, as before.
        subprocess.call([str(arg) for arg in argv])


# Script entry point: python-fire turns the keyword arguments of ``main``
# into command-line flags (e.g. ``--max_batch_size 32``).
if __name__ == '__main__':
    fire.Fire(main)