# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ast
from pathlib import Path
from typing import List, Optional


# NOTE(review): because this string literal follows the imports, Python does
# not treat it as the module docstring (``__doc__`` stays None); it is a no-op
# expression statement. Move it above the imports to make it the real docstring.
"""
This module prints a list of test ids (test modules, test classes and test
methods), presumably for shell tab-completion of test names.
It uses only the standard library.
"""


def get_test_files(root: Optional[Path] = None) -> List[Path]:
    """
    Returns the paths of all files matching ``test*.py`` (searched
    recursively) under the test directories, sorted so that the result is
    deterministic regardless of filesystem iteration order.

    Args:
        root: directory that contains the test directories. Defaults to the
            parent of the directory holding this file, which is what the
            function always used before this parameter was added.
    """
    if root is None:
        root = Path(__file__).parent.parent
    dirs = ["tests", "projects/implicitron_trainer"]
    # ``subdir`` rather than ``dir`` so the builtin is not shadowed.
    return sorted(
        path for subdir in dirs for path in (root / subdir).glob("**/test*.py")
    )


def tests_from_file(path: Path, base: str) -> List[str]:
    """
    Returns all the tests in the given file, in format
    expected as arguments when running the tests.
    e.g.
        file_stem
        file_stem.TestFunctionality
        file_stem.TestFunctionality.test_f
        file_stem.TestFunctionality.test_g
    """
    with open(path) as f:
        node = ast.parse(f.read())
    out = [base]
    for cls in node.body:
        if not isinstance(cls, ast.ClassDef):
            continue
        if not cls.name.startswith("Test"):
            continue
        class_base = base + "." + cls.name
        out.append(class_base)
        for method in cls.body:
            if not isinstance(method, ast.FunctionDef):
                continue
            if not method.name.startswith("test"):
                continue
            out.append(class_base + "." + method.name)
    return out


def main() -> None:
    """
    Prints every discovered test id (module, class and method level), one per
    line, in sorted order.

    Module ids are the test file's path relative to the parent of this file's
    directory, without the ``.py`` suffix and with path components joined
    by ".".
    """
    test_root = Path(__file__).parent.parent
    all_tests = []
    for f in get_test_files():
        # Join path components instead of replacing "/" so this also works
        # when the OS path separator is "\\" (Windows).
        file_base = ".".join(f.relative_to(test_root).with_suffix("").parts)
        all_tests.extend(tests_from_file(f, file_base))
    for test in sorted(all_tests):
        print(test)


# Script entry point: print the test list only when run directly, not on import.
if __name__ == "__main__":
    main()