Commit 7fee854f authored by Casper's avatar Casper
Browse files

Add compute capability flags and n_threads as str

parent 77ca8337
......@@ -35,10 +35,19 @@ if not torch_is_prebuilt:
ext_modules = []
if build_cuda_extension:
n_threads = min(os.cpu_count(), 8)
# figure out compute capability
compute_capabilities = {80, 86, 89, 90}
if torch_is_prebuilt:
compute_capabilities.update({87})
capability_flags = ["-gencode", f"arch=compute_{cap},code=sm_{cap}" for cap in compute_capabilities]
# num threads
n_threads = str(min(os.cpu_count(), 8))
# final args
cxx_args = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"]
nvcc_args = ["-O3", "-std=c++17", "--threads", n_threads]
nvcc_args = ["-O3", "-std=c++17", "--threads", n_threads] + capability_flags
ext_modules.append(
CUDAExtension(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment