Unverified Commit 50896ec5 authored by Chirag Jain's avatar Chirag Jain Committed by GitHub
Browse files

Make nvcc threads configurable via environment variable (#885)

parent 6c9e60de
......@@ -55,7 +55,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"]
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
return nvcc_extra_args
......
......@@ -19,7 +19,8 @@ def get_cuda_bare_metal_version(cuda_dir):
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"]
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
return nvcc_extra_args
......
......@@ -22,7 +22,8 @@ def get_cuda_bare_metal_version(cuda_dir):
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
return nvcc_extra_args
......
......@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"]
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
return nvcc_extra_args
......
......@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"]
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
return nvcc_extra_args
......
......@@ -53,7 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"]
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
return nvcc_extra_args
......
......@@ -83,7 +83,8 @@ def check_if_cuda_home_none(global_option: str) -> None:
def append_nvcc_threads(nvcc_extra_args):
return nvcc_extra_args + ["--threads", "4"]
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
cmdclass = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment