Commit 966884e2 authored by zhuww's avatar zhuww
Browse files

modify setup

parent e4119508
...@@ -14,14 +14,33 @@ from pathlib import Path ...@@ -14,14 +14,33 @@ from pathlib import Path
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/hipcc", "--version"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("version:") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir): def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.hip.split(".")[0] torch_binary_major = torch.version.hip.split(".")[0]
torch_binary_minor = torch.version.hip.split(".")[1] torch_binary_minor = torch.version.hip.split(".")[1]
print("\nCompiling cuda extensions with") print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n")
def append_nvcc_threads(nvcc_extra_args): if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
return nvcc_extra_args raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does " +
"not match the version used to compile Pytorch binaries. " +
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) +
"In some cases, a minor-version mismatch will not cause later errors: " +
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
"You can try commenting out this check (at your own risk).")
if not torch.cuda.is_available(): if not torch.cuda.is_available():
...@@ -74,18 +93,14 @@ else: ...@@ -74,18 +93,14 @@ else:
], ],
extra_compile_args={ extra_compile_args={
'cxx': ['-O3'] + version_dependent_macros, 'cxx': ['-O3'] + version_dependent_macros,
'hipcc': 'nvcc':['-O3'] + version_dependent_macros + extra_cuda_flags
append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros +
extra_cuda_flags)
}) })
cc_flag = []
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
extra_cuda_flags = [ extra_cuda_flags = [
'-std=c++14', '-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__'
'-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', '--expt-extended-lambda'
] ]
ext_modules.append( ext_modules.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment