# python_test.py (1.47 KB) — run every infiniop operator test script and report failures.
# Source-view notice: "git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist
# on "7a27732e97eb7e20ac7039578498fabdfa9e05e2".
import os
import shlex
import subprocess
import sys

from set_env import set_env

# Absolute path of the infiniop test-script directory: <this file's dir>/../test/infiniop.
PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test", "infiniop"))
# Enter that directory at import time: run_tests launches the scripts by bare filename,
# so they are resolved relative to the current working directory.
os.chdir(PROJECT_DIR)


# Operator test scripts run by default, in this order. Commented-out entries are
# currently disabled (the reason is not recorded in this file).
_DEFAULT_TESTS = (
    "add.py",
    "attention.py",
    "causal_softmax.py",
    "clip.py",
    "conv.py",
    # "dequantize_awq.py",
    "gelu.py",
    "gemm.py",
    # "layer_norm.py",
    "logsoftmax.py",
    # "lp_norm.py",
    "mul.py",
    "ones.py",
    "random_sample.py",
    "rearrange.py",
    "relu.py",
    "rms_norm.py",
    "rope.py",
    "sigmoid.py",
    # "softmax.py",
    "softplus.py",
    "sub.py",
    "swiglu.py",
    "tanh.py",
    "topkrouter.py",
    "topksoftmax.py",
    "zeros.py",
    # "paged_attention.py",
    # "paged_caching.py",
    # "paged_attention_prefill.py",
)


def _shell_quote(token):
    """Quote *token* as a single word for the shell used by ``subprocess.run(shell=True)``."""
    if os.name == "nt":
        # cmd.exe understands double quotes, and Windows paths cannot contain '"'.
        return f'"{token}"'
    return shlex.quote(token)


def run_tests(args, *, tests=None):
    """Run each test script in turn and return the ones that failed.

    Every script is launched through the shell as
    ``<current interpreter> <script> <args> --debug``. Its stdout/stderr are not
    captured, so its output streams straight to the console.

    Args:
        args: Extra command-line text inserted after the script name. It is parsed
            by the shell (quoting and splitting), so it must come from a trusted
            source, such as this script's own command line.
        tests: Script paths to run. Defaults to ``_DEFAULT_TESTS``. Relative paths
            are resolved against the current working directory.

    Returns:
        A list of the entries of *tests* whose process exited with a non-zero
        status, in the order they were run.
    """
    if tests is None:
        tests = _DEFAULT_TESTS
    # Use the interpreter running this script, not whatever ``python`` resolves to
    # on PATH. That may be missing, a different version, or a different environment.
    interpreter = _shell_quote(sys.executable or "python")
    failed = []
    for test in tests:
        command = f"{interpreter} {_shell_quote(test)} {args} --debug"
        result = subprocess.run(command, shell=True, check=False)
        if result.returncode != 0:
            failed.append(test)
    return failed


if __name__ == "__main__":
    # Project helper from set_env.py; presumably prepares the environment the test
    # scripts need — confirm against set_env.py.
    set_env()
    # This script's own CLI arguments are forwarded (space-joined) to every test script.
    failed = run_tests(" ".join(sys.argv[1:]))
    if not failed:
        print("\033[92mAll tests passed!\033[0m")
    else:
        print("\033[91mThe following tests failed:\033[0m")
        for test in failed:
            print(f"\033[91m - {test}\033[0m")
    # Exit status = number of failed tests. It is clamped because POSIX truncates exit
    # codes modulo 256, which would turn exactly 256 failures into "success" (0).
    # sys.exit is used because the bare exit() builtin is only injected by `site`.
    sys.exit(min(len(failed), 255))