test_build.py 4.24 KB
Newer Older
Patrick Labatut's avatar
Patrick Labatut committed
1
2
3
4
5
6
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
7
import json
8
import os
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
9
10
import unittest
from collections import Counter
11
12

from common_testing import get_pytorch3d_dir, get_tests_dir
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
13

14

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
15
# This file groups together tests which look at the code without running it.

# When running the tests inside conda's build, the code is not available.
# Both flags are computed once, at import time, from the process environment.
in_conda_build = os.getenv("CONDA_BUILD_STATE", "") == "TEST"
# Only the presence of INSIDE_RE_WORKER matters, not its value.
in_re_worker = "INSIDE_RE_WORKER" in os.environ


class TestBuild(unittest.TestCase):
    """Checks on the repository's files that do not run the library code.

    Every test reads files from the source checkout, so each one is skipped
    when ``in_conda_build`` or ``in_re_worker`` is set.
    """

    @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker")
    def test_name_clash(self):
        # For setup.py, all translation units need distinct names, so we
        # cannot have foo.cu and foo.cpp, even in different directories.
        test_dir = get_tests_dir()
        source_dir = test_dir.parent / "pytorch3d"

        stems = []
        for extension in [".cu", ".cpp"]:
            files = source_dir.glob(f"**/*{extension}")
            stems.extend(f.stem for f in files)

        # Report every clashing stem at once instead of stopping at the first.
        counter = Counter(stems)
        clashes = sorted(stem for stem, count in counter.items() if count > 1)
        if clashes:
            self.fail(
                "\n".join(f"Too many files with stem {stem}." for stem in clashes)
            )

    @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker")
    def test_copyright(self):
        test_dir = get_tests_dir()
        root_dir = test_dir.parent

        extensions = ("py", "cu", "cuh", "cpp", "h", "hpp", "sh")

        expect = "Copyright (c) Facebook, Inc. and its affiliates.\n"

        # Files exempt from the header check. These are POSIX-style suffixes and
        # are compared against path.as_posix(), so they also match on Windows,
        # where str(path) would use backslashes.
        exempt_suffixes = (
            "pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py",
            "pytorch3d/csrc/pulsar/include/fastermath.h",
        )

        files_missing_copyright_header = []

        for extension in extensions:
            for path in root_dir.glob(f"**/*.{extension}"):
                if path.as_posix().endswith(exempt_suffixes):
                    continue
                with open(path, encoding="utf-8") as f:
                    firstline = f.readline()
                    # If the file opens with an encoding declaration, a shebang
                    # or the start of a C-style block comment, the copyright
                    # notice is expected on the following line instead.
                    if firstline.startswith(("# -*-", "#!", "/*")):
                        firstline = f.readline()
                    if not firstline.endswith(expect):
                        files_missing_copyright_header.append(str(path))

        if files_missing_copyright_header:
            self.fail("\n".join(files_missing_copyright_header))

    @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker")
    def test_valid_ipynbs(self):
        # Check that the ipython notebooks are valid json
        root_dir = get_pytorch3d_dir()
        tutorials_dir = root_dir / "docs" / "tutorials"
        tutorials = sorted(tutorials_dir.glob("*.ipynb"))

        for tutorial in tutorials:
            # Notebooks are JSON, which is UTF-8; don't rely on the locale.
            with open(tutorial, encoding="utf-8") as f:
                try:
                    json.load(f)
                except json.JSONDecodeError as e:
                    self.fail(f"{tutorial} is not valid JSON: {e}")

    @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker")
    def test_enumerated_ipynbs(self):
        # Check that the tutorials are all referenced in tutorials.json.
        root_dir = get_pytorch3d_dir()
        tutorials_dir = root_dir / "docs" / "tutorials"
        tutorials_on_disk = sorted(i.stem for i in tutorials_dir.glob("*.ipynb"))

        json_file = root_dir / "website" / "tutorials.json"
        with open(json_file, encoding="utf-8") as f:
            cfg_dict = json.load(f)
        # Each top-level value is a list of entries carrying an "id" key.
        listed_in_json = [item["id"] for section in cfg_dict.values() for item in section]

        self.assertListEqual(sorted(listed_in_json), tutorials_on_disk)

    @unittest.skipIf(in_conda_build or in_re_worker, "In conda build, or RE worker")
    def test_enumerated_notes(self):
        # Check that the notes are all referenced in sidebars.json.
        root_dir = get_pytorch3d_dir()
        notes_dir = root_dir / "docs" / "notes"
        notes_on_disk = sorted(i.stem for i in notes_dir.glob("*.md"))

        json_file = root_dir / "website" / "sidebars.json"
        with open(json_file, encoding="utf-8") as f:
            cfg_dict = json.load(f)
        # Each value under "docs" is a list of note ids (markdown file stems).
        listed_in_json = [note for section in cfg_dict["docs"].values() for note in section]

        self.assertListEqual(sorted(listed_in_json), notes_on_disk)