setup.py 5.82 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
Stanislav Pidhorskyi's avatar
Stanislav Pidhorskyi committed
3
# This source code is licensed under the MIT license found in the
facebook-github-bot's avatar
facebook-github-bot committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# LICENSE file in the root directory of this source tree.

import os
import platform
import re
import sys

from pkg_resources import DistributionNotFound, get_distribution

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


def main(debug: bool) -> None:
    extra_link_args = {
        "linux": ["-static-libgcc"] + ([] if debug else ["-flto"]),
        "win32": ["/DEBUG"] if debug else [],
    }
    cxx_args = {
        "linux": ["-std=c++17", "-Wall"]
        + (["-O0", "-g3", "-DDEBUG"] if debug else ["-O3", "--fast-math"]),
        "win32": (
            ["/std:c++17", "/MT", "/GR-", "/EHsc", '/D "NOMINMAX"']
            + (
                ["/Od", '/D "_DEBUG"']
                if debug
                else ["/O2", "/fp:fast", "/GL", '/D "NDEBUG"']
            )
        ),
    }

35
36
37
38
39
40
41
42
43
44
45
46
    nvcc_args = ["-O0", "-g", "-DDEBUG"] if debug else ["-O3", "--use_fast_math"]
    if not os.getenv("TORCH_CUDA_ARCH_LIST"):
        # Respect TORCH_CUDA_ARCH_LIST when set, otherwise fall back to a default list of archs
        nvcc_args.extend(
            [
                "-gencode=arch=compute_72,code=sm_72",
                "-gencode=arch=compute_75,code=sm_75",
                "-gencode=arch=compute_80,code=sm_80",
                "-gencode=arch=compute_86,code=sm_86",
                "-gencode=arch=compute_90,code=sm_90",
            ]
        )
facebook-github-bot's avatar
facebook-github-bot committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150

    # There is som issue effecting latest NVCC and pytorch 2.3.0 https://github.com/pytorch/pytorch/issues/122169
    # The workaround is adding -std=c++20 to NVCC args
    nvcc_args.append("-std=c++20")

    def get_dist(name):
        try:
            return get_distribution(name)
        except DistributionNotFound:
            return None

    root_path = os.path.dirname(os.path.abspath(__file__))
    package_name = "drtk"
    with open(os.path.join(root_path, "drtk", "__init__.py")) as f:
        init_file = f.read()

    pattern = re.compile(r"__version__\s*=\s*\"(\d*\.\d*.\d*)\"")
    groups = pattern.findall(init_file)
    assert len(groups) == 1
    version = groups[0]

    target_os = "none"

    if sys.platform == "darwin":
        target_os = "macos"
    elif os.name == "posix":
        target_os = "linux"
    elif platform.system() == "Windows":
        target_os = "win32"

    if target_os == "none":
        raise RuntimeError("Could not detect platform")
    if target_os == "macos":
        raise RuntimeError("Platform is not supported")

    include_dir = [os.path.join(root_path, "src", "include")]

    with open("README.md") as f:
        readme = f.read()

    pillow = "pillow" if get_dist("pillow-simd") is None else "pillow-simd"

    setup(
        name=package_name,
        version=version,
        author="Reality Labs, Meta",
        description="Differentiable Rendering Toolkit",
        long_description=readme,
        long_description_content_type="text/markdown",
        license="MIT",
        install_requires=["numpy", "torch", "torchvision", pillow],
        ext_modules=[
            CUDAExtension(
                name="drtk.rasterize_ext",
                sources=[
                    "src/rasterize/rasterize_module.cpp",
                    "src/rasterize/rasterize_kernel.cu",
                ],
                extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
                extra_link_args=extra_link_args[target_os],
                include_dirs=include_dir,
            ),
            CUDAExtension(
                "drtk.render_ext",
                sources=["src/render/render_kernel.cu", "src/render/render_module.cpp"],
                extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
                include_dirs=include_dir,
            ),
            CUDAExtension(
                "drtk.edge_grad_ext",
                sources=[
                    "src/edge_grad/edge_grad_module.cpp",
                    "src/edge_grad/edge_grad_kernel.cu",
                ],
                extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
                include_dirs=include_dir,
            ),
            CUDAExtension(
                "drtk.mipmap_grid_sampler_ext",
                sources=[
                    "src/mipmap_grid_sampler/mipmap_grid_sampler_module.cpp",
                    "src/mipmap_grid_sampler/mipmap_grid_sampler_kernel.cu",
                ],
                extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
                include_dirs=include_dir,
            ),
            CUDAExtension(
                "drtk.msi_ext",
                sources=[
                    "src/msi/msi_module.cpp",
                    "src/msi/msi_kernel.cu",
                ],
                extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
                include_dirs=include_dir,
            ),
            CUDAExtension(
                "drtk.interpolate_ext",
                sources=[
                    "src/interpolate/interpolate_module.cpp",
                    "src/interpolate/interpolate_kernel.cu",
                ],
                extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
                include_dirs=include_dir,
            ),
Stanislav Pidhorskyi's avatar
Stanislav Pidhorskyi committed
151
152
153
154
155
156
157
158
159
            CUDAExtension(
                "drtk.grid_scatter_ext",
                sources=[
                    "src/grid_scatter/grid_scatter_module.cpp",
                    "src/grid_scatter/grid_scatter_kernel.cu",
                ],
                extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
                include_dirs=include_dir,
            ),
facebook-github-bot's avatar
facebook-github-bot committed
160
161
162
163
164
165
166
167
        ],
        cmdclass={"build_ext": BuildExtension},
        packages=["drtk", "drtk.utils"],
    )


if __name__ == "__main__":
    main(any(x in sys.argv for x in ["debug", "-debug", "--debug"]))