python_test.py 934 Bytes
Newer Older
1
2
import os
import subprocess
PanZezhong's avatar
PanZezhong committed
3
from set_env import set_env
4
5
6
7
8
9
10
11
12
13
14
import sys

# Location of the infiniop test scripts, expressed relative to this file:
# <directory of this file>/../test/infiniop.
PROJECT_DIR = os.path.join(os.path.dirname(__file__), "..", "test", "infiniop")
# Resolve to an absolute path so the value no longer depends on the
# directory the script was launched from.
PROJECT_DIR = os.path.abspath(PROJECT_DIR)
# Import-time side effect: make that directory the working directory, so the
# relative script names used by run_tests() resolve against it.
os.chdir(PROJECT_DIR)


def run_tests(args):
    """Run every infiniop operator test script and collect the failures.

    Each script is launched as a child process, by relative name, from the
    current working directory, using the same interpreter that is running
    this file.

    Args:
        args: Extra command-line arguments forwarded unchanged to every
            test script. Either a single string (split with shell-like
            quoting rules via ``shlex.split``) or an already-split sequence
            of strings.

    Returns:
        list[str]: Names of the scripts that exited with a non-zero status,
        in the order they were run. Empty when every script succeeded.
    """
    import shlex  # local import: only this function needs it

    extra_args = shlex.split(args) if isinstance(args, str) else list(args)

    tests = (
        "causal_softmax.py",
        "gemm.py",
        "random_sample.py",
        "rms_norm.py",
        "rope.py",
        "swiglu.py",
    )

    failed = []
    for test in tests:
        # An argv list with no shell keeps forwarded arguments from being
        # re-interpreted (globs, `;`, `$VAR`, ...), and sys.executable
        # guarantees the child uses this interpreter rather than whatever
        # `python` happens to be first on PATH.
        result = subprocess.run(
            [sys.executable, test, *extra_args], text=True, encoding="utf-8"
        )
        if result.returncode != 0:
            failed.append(test)

    return failed


if __name__ == "__main__":
    # Project helper imported from set_env; presumably prepares the
    # environment the test scripts need -- confirm in set_env.py.
    set_env()

    # Forward every command-line argument of this script to each test.
    failed = run_tests(" ".join(sys.argv[1:]))

    if not failed:
        # ANSI green.
        print("\033[92mAll tests passed!\033[0m")
    else:
        # ANSI red, one line per failing script.
        print("\033[91mThe following tests failed:\033[0m")
        for test in failed:
            print(f"\033[91m - {test}\033[0m")

    # The exit status is the number of failed scripts (0 means success).
    # sys.exit is used instead of the `exit` helper, which is injected by
    # the `site` module and is not guaranteed to exist (e.g. `python -S`).
    sys.exit(len(failed))