Commit 6ee72295 authored by Chris Klaiber's avatar Chris Klaiber Committed by Facebook GitHub Bot
Browse files

enable DRTK build in Docker without CUDA (#5)

Summary:
This unblocks building DRTK in a Docker build since BuildKit doesn't yet support GPUs: https://github.com/moby/buildkit/issues/1436

This works because CUDA is needed at run-time but not at build-time. When CUDA is present at build-time, the currently installed cards set which archs the extensions are built for. However, the archs are manually fixed in setup.py, so this commit also adds support for respecting TORCH_CUDA_ARCH_LIST which controls archs within torch.utils.cpp_extension.CUDAExtension

Pull Request resolved: https://github.com/facebookresearch/DRTK/pull/5

Test Plan:
Build without TORCH_CUDA_ARCH_LIST and observe using `ps` that compilation gets the flags specified in setup.py.

Build with TORCH_CUDA_ARCH_LIST=Turing and TORCH_CUDA_ARCH_LIST=8.6 and observe using `ps` that compilation gets only the flags for the requested archs.

Build within a Dockerfile using `docker-compose up --build`, which is using BuildKit, and observe that compilation succeeds and the library is usable when later run with CUDA GPUs available.

Reviewed By: HapeMask

Differential Revision: D63513797

Pulled By: podgorskiy

fbshipit-source-id: 54ce6765ccc37317999f18690a73823eb074f08f
parent dfe3fdfb
......@@ -32,13 +32,18 @@ def main(debug: bool) -> None:
),
}
nvcc_args = [
"-gencode=arch=compute_72,code=sm_72",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_86,code=sm_86",
"-gencode=arch=compute_90,code=sm_90",
] + (["-O0", "-g", "-DDEBUG"] if debug else ["-O3", "--use_fast_math"])
nvcc_args = ["-O0", "-g", "-DDEBUG"] if debug else ["-O3", "--use_fast_math"]
if not os.getenv("TORCH_CUDA_ARCH_LIST"):
# Respect TORCH_CUDA_ARCH_LIST when set, otherwise fall back to a default list of archs
nvcc_args.extend(
[
"-gencode=arch=compute_72,code=sm_72",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_86,code=sm_86",
"-gencode=arch=compute_90,code=sm_90",
]
)
# There is som issue effecting latest NVCC and pytorch 2.3.0 https://github.com/pytorch/pytorch/issues/122169
# The workaround is adding -std=c++20 to NVCC args
......@@ -60,13 +65,6 @@ def main(debug: bool) -> None:
assert len(groups) == 1
version = groups[0]
if get_dist("torch") is None:
raise RuntimeError("Setup requires torch package to be installed")
import torch as th
assert th.cuda.is_available()
target_os = "none"
if sys.platform == "darwin":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment