Commit bc64ee83 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Keep --peer_memory and --nccl_p2p CUDA-compatible

parent fd0f7631
......@@ -42,6 +42,55 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_major, bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0]
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. "
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+ "In some cases, a minor-version mismatch will not cause later errors: "
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
"You can try commenting out this check (at your own risk)."
)
def raise_if_cuda_home_none(global_option: str) -> None:
if CUDA_HOME is not None:
return
raise RuntimeError(
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
"only images whose names contain 'devel' will provide nvcc."
)
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
cudnn_available = torch.backends.cudnn.is_available()
cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
warnings.warn(
f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, "
f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
)
return False
return True
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
......@@ -539,6 +588,10 @@ if "--fast_bottleneck" in sys.argv:
if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
if "--peer_memory" in sys.argv:
sys.argv.remove("--peer_memory")
if not IS_ROCM_PYTORCH:
raise_if_cuda_home_none("--peer_memory")
ext_modules.append(
CUDAExtension(
name="peer_memory_cuda",
......@@ -553,6 +606,10 @@ if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv:
if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv:
if "--nccl_p2p" in sys.argv:
sys.argv.remove("--nccl_p2p")
if not IS_ROCM_PYTORCH:
raise_if_cuda_home_none("--nccl_p2p")
ext_modules.append(
CUDAExtension(
name="nccl_p2p_cuda",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment