Unverified Commit c1e88fae authored by Jeff Daily's avatar Jeff Daily Committed by GitHub
Browse files

fix cross-compiled ROCm builds when no GPUs detected (#45)

parent 5baa68d3
...@@ -20,7 +20,17 @@ def get_cuda_bare_metal_version(cuda_dir): ...@@ -20,7 +20,17 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_major, bare_metal_minor return raw_output, bare_metal_major, bare_metal_minor
# (old side of this diff row: `if not torch.cuda.is_available():`)
def check_if_rocm_pytorch():
    """Return True when the installed PyTorch is a usable ROCm (HIP) build.

    The result is True only if all of the following hold:
      * the PyTorch version is at least 1.5,
      * ``torch.version.hip`` is not None,
      * ``torch.utils.cpp_extension.ROCM_HOME`` is not None.

    Returns:
        bool: whether extensions should be built against ROCm.
    """
    # Compare versions numerically. A plain string comparison is
    # lexicographic, so '1.10.0' >= '1.5' is False and torch 1.10-1.13 would
    # wrongly be treated as older than 1.5.
    major, minor = (int(part) for part in torch.__version__.split('.')[:2])
    if (major, minor) < (1, 5):
        return False

    # Imported only after the version gate, as in the original code.
    from torch.utils.cpp_extension import ROCM_HOME

    return (torch.version.hip is not None) and (ROCM_HOME is not None)
# Evaluate the ROCm check once at load time; the GPU-availability guards below branch on this flag.
IS_ROCM_PYTORCH = check_if_rocm_pytorch()
if not torch.cuda.is_available() and not IS_ROCM_PYTORCH:
# https://github.com/NVIDIA/apex/issues/486 # https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
# which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
...@@ -37,6 +47,11 @@ if not torch.cuda.is_available(): ...@@ -37,6 +47,11 @@ if not torch.cuda.is_available():
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
else: else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
elif not torch.cuda.is_available() and IS_ROCM_PYTORCH:
print('\nWarning: Torch did not find available GPUs on this system.\n',
'If your intention is to cross-compile, this is not an error.\n'
'By default, Apex will cross-compile for the same gfx targets\n'
'used by default in ROCm PyTorch\n')
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split('.')[0])
...@@ -106,16 +121,6 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): ...@@ -106,16 +121,6 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
"You can try commenting out this check (at your own risk).") "You can try commenting out this check (at your own risk).")
def check_if_rocm_pytorch():
    """Return True when the installed PyTorch is a usable ROCm (HIP) build.

    The result is True only if all of the following hold:
      * the PyTorch version is at least 1.5,
      * ``torch.version.hip`` is not None,
      * ``torch.utils.cpp_extension.ROCM_HOME`` is not None.

    Returns:
        bool: whether extensions should be built against ROCm.
    """
    # Compare versions numerically. A plain string comparison is
    # lexicographic, so '1.10.0' >= '1.5' is False and torch 1.10-1.13 would
    # wrongly be treated as older than 1.5.
    major, minor = (int(part) for part in torch.__version__.split('.')[:2])
    if (major, minor) < (1, 5):
        return False

    # Imported only after the version gate, as in the original code.
    from torch.utils.cpp_extension import ROCM_HOME

    return (torch.version.hip is not None) and (ROCM_HOME is not None)
# Module-level flag: result of the ROCm check, computed once when this script is loaded.
IS_ROCM_PYTORCH = check_if_rocm_pytorch()
# Set up macros for forward/backward compatibility hack around # Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
# and # and
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment