# test_mpu.py (committed by Masaki Kozuki)
import os
import subprocess
import sys
import unittest


def run_mpu_tests(directory=None):
    """Run each ``run_*`` MPU test script in its own Python subprocess.

    Every regular file whose name starts with ``run_`` in *directory* is
    executed, in sorted order, with the current interpreter
    (``sys.executable``). Each script gets ``NVIDIA_TF32_OVERRIDE=0`` in its
    environment and a fixed set of small model-size command-line flags.
    A script passes only if it exits with status 0 AND its combined
    stdout/stderr contains ``>> passed the test :-)``.

    Args:
        directory: Directory to scan for ``run_*`` scripts. Defaults to the
            directory containing this module.

    Raises:
        RuntimeError: If any script failed. The message reads
            ``"<failed> out of <total> tests failed"``; per-script logs are
            printed before raising.
    """
    python_executable_path = sys.executable
    if directory is None:
        # abspath: __file__ can be relative (e.g. "test_mpu.py"), in which case
        # dirname() is "" and os.listdir("") raises FileNotFoundError.
        directory = os.path.dirname(os.path.abspath(__file__))
    # Sorted so the run order (and the printed log) is deterministic.
    files = sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.startswith("run_")
        and os.path.isfile(os.path.join(directory, name))
    )
    # Small model configuration passed identically to every script.
    test_args = [
        "--micro-batch-size", "2",
        "--num-layers", "1",
        "--hidden-size", "256",
        "--num-attention-heads", "8",
        "--max-position-embeddings", "32",
        "--encoder-seq-length", "32",
        "--use-cpu-initialization",
    ]
    # Disable TF32 only for the child processes. Setting it through ``env``
    # instead of a ``VAR=value`` shell prefix avoids ``shell=True``, which
    # mis-split paths containing spaces and required a POSIX shell.
    env = dict(os.environ, NVIDIA_TF32_OVERRIDE="0")

    print("#######################################################")
    print(f"# Python executable path: {python_executable_path}")
    print(f"# {len(files)} tests: {files}")
    print("#######################################################")

    errors = []  # list of (test_file, log) tuples
    for i, test_file in enumerate(files, 1):
        cmd = [python_executable_path, test_file] + test_args
        cmd_str = " ".join(cmd)
        print(f"### {i} / {len(files)}: cmd: NVIDIA_TF32_OVERRIDE=0 {cmd_str}")
        try:
            # stderr is merged into stdout so a failing script's traceback
            # ends up in the recorded log.
            result = subprocess.run(
                cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
            )
        except OSError as e:
            # The process could not be launched at all.
            errors.append((test_file, str(e)))
            continue
        # errors="replace": a stray undecodable byte must not mask the result;
        # the pass marker itself is plain ASCII.
        output = result.stdout.decode("utf-8", errors="replace").strip()
        if result.returncode != 0:
            errors.append(
                (test_file, f"exit status {result.returncode}\n{output}")
            )
        elif ">> passed the test :-)" not in output:
            errors.append((test_file, output))

    if not errors:
        print("### PASSED")
        return

    print("### FAILED")
    short_msg = f"{len(errors)} out of {len(files)} tests failed"
    print(short_msg)
    for filename, log in errors:
        print(f"File: {filename}\nLog: {log}")
    raise RuntimeError(short_msg)


class TestMPU(unittest.TestCase):

    def test_mpu(self):
        run_mpu_tests()


if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner.
    unittest.main()