import os
import shlex
import subprocess
import sys

from set_env import set_env

# Locate ../test/infiniop relative to this script file (not the caller's
# working directory) and switch into it, so that relative test-script paths
# resolve inside that directory.
_SCRIPT_DIR = os.path.dirname(__file__)
PROJECT_DIR = os.path.abspath(
    os.path.join(_SCRIPT_DIR, "..", "test", "infiniop")
)
os.chdir(PROJECT_DIR)


def run_tests(args):
    failed = []
    for test in [
PanZezhong's avatar
PanZezhong committed
15
        "add.py",
16
        "gemm.py",
17
        "random_sample.py",
18
        "rms_norm.py",
19
        "rope.py",
Pepe's avatar
Pepe committed
20
        "sub.py",
21
        "swiglu.py",
PanZezhong's avatar
PanZezhong committed
22
        "attention.py",
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
    ]:
        result = subprocess.run(
            f"python {test} {args}", text=True, encoding="utf-8", shell=True
        )
        if result.returncode != 0:
            failed.append(test)

    return failed


if __name__ == "__main__":
    set_env()
    # Forward every command-line argument after the script name to each test.
    failed = run_tests(" ".join(sys.argv[1:]))
    if not failed:
        print("\033[92mAll tests passed!\033[0m")
    else:
        print("\033[91mThe following tests failed:\033[0m")
        for test in failed:
            print(f"\033[91m - {test}\033[0m")
    # Exit status = number of failed scripts. sys.exit is used because the
    # bare exit() builtin is injected by the site module and may be absent.
    # POSIX truncates exit statuses to 8 bits, so cap the count to keep a
    # large number of failures from wrapping around to 0 ("success").
    sys.exit(min(len(failed), 255))