f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) "
f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})."
"Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ."
)
ifbare_metal_minor!=torch_cuda_minor:
warnings.warn(
f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
"The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
"If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions"
)
returnTrue
defget_pytorch_version()->List[int]:
"""
This functions finds the PyTorch version.
Returns:
A tuple of integers in the form of (major, minor, patch).
The CUDA version required by PyTorch, in the form of tuple (major, minor).
"""
nvcc_path=os.path.join(cuda_dir,"bin/nvcc")
ifcuda_dirisNone:
raiseValueError(
f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly."
)
# check for nvcc path
ifnotos.path.exists(nvcc_path):
raiseFileNotFoundError(
f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
)
# parse the nvcc -v output to obtain the system cuda version
"[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions"