Unverified commit 7c9fb403, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Fix gencode flags in setup (#145)



* Fix gencode flags based on cuda version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review suggestions
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert append_nvcc_threads change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9bc9e68d
...@@ -30,22 +30,22 @@ def get_cuda_bare_metal_version(cuda_dir): ...@@ -30,22 +30,22 @@ def get_cuda_bare_metal_version(cuda_dir):
release = output[release_idx].split(".") release = output[release_idx].split(".")
bare_metal_major = release[0] bare_metal_major = release[0]
bare_metal_minor = release[1][0] bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor return (int(bare_metal_major), int(bare_metal_minor))
def append_nvcc_threads(nvcc_extra_args): def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) cuda_major, cuda_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: if cuda_major >= 11 and cuda_minor >= 2:
return nvcc_extra_args + ["--threads", "4"] return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args return nvcc_extra_args
def extra_gencodes(cc_flag): def extra_gencodes(cc_flag):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) cuda_bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11: if cuda_bare_metal_version >= (11, 0):
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("arch=compute_80,code=sm_80")
if int(bare_metal_minor) >= 8: if cuda_bare_metal_version >= (11, 8):
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90") cc_flag.append("arch=compute_90,code=sm_90")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment