test_build.py 5.09 KB
Newer Older
Patrick Labatut's avatar
Patrick Labatut committed
1
2
3
4
5
6
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
7
import importlib
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
8
import json
9
import os
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
10
import sys
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
11
import unittest
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
12
import unittest.mock
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
13
from collections import Counter
14

15
from common_testing import get_pytorch3d_dir
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
16

17

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
18
# Every test in this file inspects the source tree instead of executing it.
# The two flags below describe the environment the tests run in; some tests
# are skipped when either one is set.
# True only when CONDA_BUILD_STATE is exactly "TEST" (unset compares unequal).
in_conda_build = os.environ.get("CONDA_BUILD_STATE") == "TEST"
# True whenever INSIDE_RE_WORKER is present in the environment, whatever its value.
in_re_worker = "INSIDE_RE_WORKER" in os.environ
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
21
22
23
24
25
26


class TestBuild(unittest.TestCase):
    """Checks that inspect the pytorch3d source tree rather than run it.

    Apart from ``test_no_import_cycles`` (which imports every module), these
    tests only read files under the directory returned by
    ``get_pytorch3d_dir()``.
    """

    def test_name_clash(self):
        """Fail if two .cu/.cpp files under pytorch3d/ share a file stem."""
        # For setup.py, all translation units need distinct names, so we
        # cannot have foo.cu and foo.cpp, even in different directories.
        source_dir = get_pytorch3d_dir() / "pytorch3d"

        stems = []
        for extension in [".cu", ".cpp"]:
            files = source_dir.glob(f"**/*{extension}")
            stems.extend(f.stem for f in files)

        counter = Counter(stems)
        for k, v in counter.items():
            self.assertEqual(v, 1, f"Too many files with stem {k}.")

    @unittest.skipIf(in_re_worker, "In RE worker")
    def test_copyright(self):
        """Fail, listing every source file that lacks the copyright header.

        The header must be on the first line, or on the second line when the
        first one starts with ``# -*-``, ``#!`` or ``/*``. Only the end of
        that line is compared, so whatever comment prefix precedes the
        header text is not checked.
        """
        root_dir = get_pytorch3d_dir()

        extensions = ("py", "cu", "cuh", "cpp", "h", "hpp", "sh")
        expect = "Copyright (c) Facebook, Inc. and its affiliates.\n"

        # Files exempt from the check, matched as "/"-separated path
        # suffixes. The set does not depend on the file being examined, so
        # it is built once instead of once per file.
        excluded_files = (
            "pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py",
            "pytorch3d/csrc/pulsar/include/fastermath.h",
        )
        if in_conda_build:
            # NOTE(review): presumably helper scripts placed in the tree by
            # the conda test harness; confirm against the conda recipe.
            excluded_files += (
                "run_test.py",
                "run_test.sh",
                "conda_test_runner.sh",
                "conda_test_env_vars.sh",
            )

        files_missing_copyright_header = []

        for extension in extensions:
            for path in root_dir.glob(f"**/*.{extension}"):
                # as_posix() so the "/"-separated exclusions also match on
                # Windows, where str(path) uses backslashes.
                if path.as_posix().endswith(excluded_files):
                    continue
                # Explicit encoding: do not depend on the platform locale.
                with open(path, encoding="utf-8") as f:
                    firstline = f.readline()
                    if firstline.startswith(("# -*-", "#!", "/*")):
                        firstline = f.readline()
                    if not firstline.endswith(expect):
                        files_missing_copyright_header.append(str(path))

        if files_missing_copyright_header:
            self.fail("\n".join(files_missing_copyright_header))

    @unittest.skipIf(in_re_worker, "In RE worker")
    def test_valid_ipynbs(self):
        """Check that the IPython notebooks in docs/tutorials are valid JSON."""
        root_dir = get_pytorch3d_dir()
        tutorials_dir = root_dir / "docs" / "tutorials"

        tutorials = sorted(tutorials_dir.glob("*.ipynb"))

        for tutorial in tutorials:
            # subTest names the offending notebook in the report; a bare
            # JSONDecodeError does not mention the file it came from.
            with self.subTest(name=tutorial.name):
                with open(tutorial, encoding="utf-8") as f:
                    json.load(f)

    @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker")
    def test_enumerated_ipynbs(self):
        """Check that the tutorials are exactly those listed in tutorials.json.

        Every ``id`` in every section of website/tutorials.json must match a
        notebook stem in docs/tutorials, and vice versa.
        """
        root_dir = get_pytorch3d_dir()
        tutorials_dir = root_dir / "docs" / "tutorials"
        tutorials_on_disk = sorted(i.stem for i in tutorials_dir.glob("*.ipynb"))

        json_file = root_dir / "website" / "tutorials.json"
        with open(json_file, encoding="utf-8") as f:
            cfg_dict = json.load(f)
        listed_in_json = []
        for section in cfg_dict.values():
            listed_in_json.extend(item["id"] for item in section)

        self.assertListEqual(sorted(listed_in_json), tutorials_on_disk)

    @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker")
    def test_enumerated_notes(self):
        """Check that the notes are exactly those listed in sidebars.json.

        The entries of every section under the "docs" key of
        website/sidebars.json must match the .md stems in docs/notes, and
        vice versa.
        """
        root_dir = get_pytorch3d_dir()
        notes_dir = root_dir / "docs" / "notes"
        notes_on_disk = sorted(i.stem for i in notes_dir.glob("*.md"))

        json_file = root_dir / "website" / "sidebars.json"
        with open(json_file, encoding="utf-8") as f:
            cfg_dict = json.load(f)
        listed_in_json = []
        for section in cfg_dict["docs"].values():
            listed_in_json.extend(section)

        self.assertListEqual(sorted(listed_in_json), notes_on_disk)

    def test_no_import_cycles(self):
        """Import every non-__init__ module of pytorch3d on its own.

        An import may fail if there are import cycles. Each import runs
        inside ``unittest.mock.patch.dict(sys.modules)``, which restores
        ``sys.modules`` afterwards, so every module is imported from a
        clean slate.
        """
        # First check the setup of the test. If any of pytorch3d
        # was already imported the test would be pointless.
        for module in sys.modules:
            self.assertFalse(module.startswith("pytorch3d"), module)

        root_dir = get_pytorch3d_dir() / "pytorch3d"
        for module_file in root_dir.glob("**/*.py"):
            if module_file.stem == "__init__":
                continue
            # Build the dotted name from path components rather than
            # replacing "/" in a string, which breaks on Windows paths.
            relative_parts = module_file.relative_to(root_dir).with_suffix("").parts
            module = ".".join(("pytorch3d",) + relative_parts)
            with self.subTest(name=module):
                with unittest.mock.patch.dict(sys.modules):
                    importlib.import_module(module)