Unverified Commit 7c9fb403 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix gencode flags in setup (#145)



* Fix gencode flags based on cuda version
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* review suggestions
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert append_nvcc_threads change
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9bc9e68d
......@@ -30,24 +30,24 @@ def get_cuda_bare_metal_version(cuda_dir):
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
return (int(bare_metal_major), int(bare_metal_minor))
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
cuda_major, cuda_minor = get_cuda_bare_metal_version(CUDA_HOME)
if cuda_major >= 11 and cuda_minor >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
def extra_gencodes(cc_flag):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cuda_bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if cuda_bare_metal_version >= (11, 0):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if int(bare_metal_minor) >= 8:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
if cuda_bare_metal_version >= (11, 8):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
def extra_compiler_flags():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment