"torchvision/vscode:/vscode.git/clone" did not exist on "97b53f969ac12fad34a05539a85c58ebae13f027"
Commit b3cb86f2 authored by Soumith Chintala's avatar Soumith Chintala
Browse files

version check against PyTorch's CUDA version

parent c7a4ca99
......@@ -56,6 +56,9 @@ def write_version_file():
with open(version_path, 'w') as f:
f.write("__version__ = '{}'\n".format(version))
f.write("git_version = {}\n".format(repr(sha)))
f.write("from torchvision import _C\n")
f.write("if hasattr(_C, 'CUDA_VERSION'):\n")
f.write(" cuda = _C.CUDA_VERSION\n")
write_version_file()
......
......@@ -33,3 +33,31 @@ def get_image_backend():
Gets the name of the package used to load images
"""
return _image_backend
def _check_cuda_matches():
"""
Make sure that CUDA versions match between the pytorch install and torchvision install
"""
import torch
from torchvision import _C
if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None:
tv_version = str(_C.CUDA_VERSION)
if int(tv_version) < 10000:
tv_major = int(tv_version[0])
tv_minor = int(tv_version[2])
else:
tv_major = int(tv_version[0:2])
tv_minor = int(tv_version[3])
t_version = torch.version.cuda
t_version = t_version.split('.')
t_major = int(t_version[0])
t_minor = int(t_version[1])
if t_major != tv_major or t_minor != tv_minor:
raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install."
.format(t_major, t_minor, tv_major, tv_minor))
_check_cuda_matches()
......@@ -2,10 +2,17 @@
#include "ROIPool.h"
#include "nms.h"
#ifdef WITH_CUDA
#include <cuda.h>
#endif
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms", &nms, "non-maximum suppression");
m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
#ifdef WITH_CUDA
m.attr("CUDA_VERSION") = CUDA_VERSION;
#endif
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment