# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import os
import unittest
from collections import Counter
from pathlib import Path

# Tests in this module inspect the source tree statically; none of them
# execute the code under test.

# Inside conda's build (test phase) the repository sources are not available,
# so tests that read them must be skipped when this flag is set.
in_conda_build = os.environ.get("CONDA_BUILD_STATE") == "TEST"


class TestBuild(unittest.TestCase):
    """Static checks on the repository's source files.

    Every test locates files relative to this module's own path and is
    skipped when ``in_conda_build`` is true.
    """

    @unittest.skipIf(in_conda_build, "In conda build")
    def test_name_clash(self):
        """Check that no two .cu/.cpp files under pytorch3d/ share a file stem."""
        # For setup.py, all translation units need distinct names, so we
        # cannot have foo.cu and foo.cpp, even in different directories.
        test_dir = Path(__file__).resolve().parent
        source_dir = test_dir.parent / "pytorch3d"

        stems = Counter()
        for extension in (".cu", ".cpp"):
            stems.update(path.stem for path in source_dir.glob(f"**/*{extension}"))

        for stem, count in stems.items():
            # subTest reports every clashing stem rather than only the first.
            with self.subTest(stem=stem):
                self.assertEqual(count, 1, f"Too many files with stem {stem}.")

    @unittest.skipIf(in_conda_build, "In conda build")
    def test_deprecated_usage(self):
        """Check that listed deprecated expressions are absent from csrc."""
        # Check certain expressions do not occur in the csrc code
        test_dir = Path(__file__).resolve().parent
        source_dir = test_dir.parent / "pytorch3d" / "csrc"

        # "**/*.*" also matches directories whose names contain a dot;
        # reading one of those would raise IsADirectoryError.
        files = sorted(p for p in source_dir.glob("**/*.*") if p.is_file())
        # Sanity check that the glob actually found source files.
        self.assertGreater(len(files), 4)

        # NOTE(review): presumably deprecated ATen Tensor methods -- confirm.
        patterns = [".type()", ".data()"]

        for file in files:
            # Explicit encoding: the platform default is not always UTF-8.
            text = file.read_text(encoding="utf-8")
            for pattern in patterns:
                msg = f"{pattern} found in {file.name}, this has been deprecated."
                with self.subTest(file=file.name, pattern=pattern):
                    self.assertFalse(pattern in text, msg)

    @unittest.skipIf(in_conda_build, "In conda build")
    def test_copyright(self):
        """Check that source files start with the copyright line.

        A leading shebang or ``# -*-`` line is allowed before the header.
        """
        test_dir = Path(__file__).resolve().parent
        root_dir = test_dir.parent

        extensions = ("py", "cu", "cuh", "cpp", "h", "hpp", "sh")

        expect = (
            "Copyright (c) Facebook, Inc. and its affiliates."
            + " All rights reserved.\n"
        )

        for extension in extensions:
            for path in root_dir.glob(f"**/*.{extension}"):
                # Skip directories whose names happen to match the pattern.
                if not path.is_file():
                    continue
                with open(path, encoding="utf-8") as f:
                    firstline = f.readline()
                    if firstline.startswith(("# -*-", "#!")):
                        firstline = f.readline()
                # endswith() accepts any comment prefix ("# ", "// ", ...).
                with self.subTest(path=str(path)):
                    self.assertTrue(
                        firstline.endswith(expect),
                        f"{path} missing copyright header.",
                    )