test_build.py 4.26 KB
Newer Older
Patrick Labatut's avatar
Patrick Labatut committed
1
2
3
4
5
6
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
7
import json
8
import os
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
9
10
import unittest
from collections import Counter
11

12
from common_testing import get_pytorch3d_dir
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
13

14

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
15
# This file groups together tests which look at the code without running it.
16
# True when CONDA_BUILD_STATE is exactly "TEST" (presumably set by
# conda-build while it runs a package's test phase -- confirm).
in_conda_build = os.environ.get("CONDA_BUILD_STATE", "") == "TEST"
# True when INSIDE_RE_WORKER is present in the environment, whatever its
# value (an empty string still counts as set).
in_re_worker = os.environ.get("INSIDE_RE_WORKER") is not None
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
18
19
20


class TestBuild(unittest.TestCase):
    """Static checks on the repository's files.

    Every test here inspects files found under ``get_pytorch3d_dir()``
    (source names, headers, notebooks, website configs) without running
    the code they contain.
    """

    @unittest.skipIf(in_re_worker, "In RE worker")
    def test_name_clash(self):
        """Check that no two ``.cu``/``.cpp`` files under pytorch3d/ share a stem."""
        # For setup.py, all translation units need distinct names, so we
        # cannot have foo.cu and foo.cpp, even in different directories.
        source_dir = get_pytorch3d_dir() / "pytorch3d"

        counter = Counter(
            path.stem
            for extension in (".cu", ".cpp")
            for path in source_dir.glob(f"**/*{extension}")
        )
        for stem, count in counter.items():
            self.assertEqual(count, 1, f"Too many files with stem {stem}.")

    @unittest.skipIf(in_re_worker, "In RE worker")
    def test_copyright(self):
        """Check that source files start with the expected copyright line.

        A file passes when its first line ends with the expected text. If
        the first line starts with ``# -*-``, ``#!`` or ``/*``, the second
        line is checked instead. All offending paths are collected and
        reported together in a single failure.
        """
        root_dir = get_pytorch3d_dir()

        extensions = ("py", "cu", "cuh", "cpp", "h", "hpp", "sh")

        expect = "Copyright (c) Facebook, Inc. and its affiliates.\n"

        # Path suffixes exempt from the check. Built once up front: the
        # tuple does not depend on the file being examined.
        excluded_files = (
            "pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py",
            "pytorch3d/csrc/pulsar/include/fastermath.h",
        )
        if in_conda_build:
            # Extra files only skipped during a conda build (presumably
            # scaffolding dropped in by the conda test harness -- confirm).
            excluded_files += (
                "run_test.py",
                "run_test.sh",
                "conda_test_runner.sh",
                "conda_test_env_vars.sh",
            )

        files_missing_copyright_header = []

        for extension in extensions:
            for path in root_dir.glob(f"**/*.{extension}"):
                # as_posix() keeps the "/"-separated suffixes above matching
                # on platforms whose native separator is a backslash.
                if path.as_posix().endswith(excluded_files):
                    continue
                # Decode explicitly so the result does not depend on the
                # locale. The expected text is pure ASCII, so replacing
                # undecodable bytes elsewhere cannot change the comparison.
                with open(path, encoding="utf-8", errors="replace") as f:
                    firstline = f.readline()
                    if firstline.startswith(("# -*-", "#!", "/*")):
                        firstline = f.readline()
                    if not firstline.endswith(expect):
                        files_missing_copyright_header.append(str(path))

        if files_missing_copyright_header:
            self.fail("\n".join(files_missing_copyright_header))

    @unittest.skipIf(in_re_worker, "In RE worker")
    def test_valid_ipynbs(self):
        """Check that every notebook in docs/tutorials parses as JSON."""
        root_dir = get_pytorch3d_dir()
        tutorials_dir = root_dir / "docs" / "tutorials"
        tutorials = sorted(tutorials_dir.glob("*.ipynb"))

        for tutorial in tutorials:
            # subTest names the offending notebook in the failure report.
            with self.subTest(tutorial=tutorial.name):
                with open(tutorial, encoding="utf-8") as f:
                    json.load(f)

    @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker")
    def test_enumerated_ipynbs(self):
        """Check that tutorials.json lists exactly the notebooks on disk."""
        # Check that the tutorials are all referenced in tutorials.json.
        root_dir = get_pytorch3d_dir()
        tutorials_dir = root_dir / "docs" / "tutorials"
        tutorials_on_disk = sorted(i.stem for i in tutorials_dir.glob("*.ipynb"))

        json_file = root_dir / "website" / "tutorials.json"
        with open(json_file, encoding="utf-8") as f:
            cfg_dict = json.load(f)
        # Each top-level value is iterated as a sequence of items, and each
        # item's "id" is compared against a notebook stem.
        listed_in_json = []
        for section in cfg_dict.values():
            listed_in_json.extend(item["id"] for item in section)

        self.assertListEqual(sorted(listed_in_json), tutorials_on_disk)

    @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker")
    def test_enumerated_notes(self):
        """Check that sidebars.json lists exactly the notes on disk."""
        # Check that the notes are all referenced in sidebars.json.
        root_dir = get_pytorch3d_dir()
        notes_dir = root_dir / "docs" / "notes"
        notes_on_disk = sorted(i.stem for i in notes_dir.glob("*.md"))

        json_file = root_dir / "website" / "sidebars.json"
        with open(json_file, encoding="utf-8") as f:
            cfg_dict = json.load(f)
        # Every value under "docs" is flattened into one list of entries,
        # which is compared against the note stems.
        listed_in_json = []
        for section in cfg_dict["docs"].values():
            listed_in_json.extend(section)

        self.assertListEqual(sorted(listed_in_json), notes_on_disk)