Unverified Commit 220b69ba authored by Soumith Chintala's avatar Soumith Chintala Committed by GitHub
Browse files

make C extension lazy-import (#971)

* make C extension lazy-import

* add lazy loading to roi_pool
parent 579eebea
......@@ -33,31 +33,3 @@ def get_image_backend():
Gets the name of the package used to load images
"""
return _image_backend
def _check_cuda_matches():
"""
Make sure that CUDA versions match between the pytorch install and torchvision install
"""
import torch
from torchvision import _C
if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None:
tv_version = str(_C.CUDA_VERSION)
if int(tv_version) < 10000:
tv_major = int(tv_version[0])
tv_minor = int(tv_version[2])
else:
tv_major = int(tv_version[0:2])
tv_minor = int(tv_version[3])
t_version = torch.version.cuda
t_version = t_version.split('.')
t_major = int(t_version[0])
t_minor = int(t_version[1])
if t_major != tv_major or t_minor != tv_minor:
raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install."
.format(t_major, t_minor, tv_major, tv_minor))
_check_cuda_matches()
_C = None
def _lazy_import():
"""
Make sure that CUDA versions match between the pytorch install and torchvision install
"""
global _C
if _C is not None:
return _C
import torch
from torchvision import _C as C
_C = C
if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None:
tv_version = str(_C.CUDA_VERSION)
if int(tv_version) < 10000:
tv_major = int(tv_version[0])
tv_minor = int(tv_version[2])
else:
tv_major = int(tv_version[0:2])
tv_minor = int(tv_version[3])
t_version = torch.version.cuda
t_version = t_version.split('.')
t_major = int(t_version[0])
t_minor = int(t_version[1])
if t_major != tv_major or t_minor != tv_minor:
raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install."
.format(t_major, t_minor, tv_major, tv_minor))
return _C
import torch
from torchvision import _C
from torchvision.extension import _lazy_import
def nms(boxes, scores, iou_threshold):
......@@ -22,6 +22,7 @@ def nms(boxes, scores, iou_threshold):
of the elements that have been kept
by NMS, sorted in decreasing order of scores
"""
_C = _lazy_import()
return _C.nms(boxes, scores, iou_threshold)
......
......@@ -6,7 +6,7 @@ from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from torchvision import _C
from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format
......@@ -18,6 +18,7 @@ class _RoIAlignFunction(Function):
ctx.spatial_scale = spatial_scale
ctx.sampling_ratio = sampling_ratio
ctx.input_shape = input.size()
_C = _lazy_import()
output = _C.roi_align_forward(
input, roi, spatial_scale,
output_size[0], output_size[1], sampling_ratio)
......@@ -31,6 +32,7 @@ class _RoIAlignFunction(Function):
spatial_scale = ctx.spatial_scale
sampling_ratio = ctx.sampling_ratio
bs, ch, h, w = ctx.input_shape
_C = _lazy_import()
grad_input = _C.roi_align_backward(
grad_output, rois, spatial_scale,
output_size[0], output_size[1], bs, ch, h, w, sampling_ratio)
......
......@@ -6,7 +6,7 @@ from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from torchvision import _C
from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format
......@@ -16,6 +16,7 @@ class _RoIPoolFunction(Function):
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.input_shape = input.size()
_C = _lazy_import()
output, argmax = _C.roi_pool_forward(
input, rois, spatial_scale,
output_size[0], output_size[1])
......@@ -29,6 +30,7 @@ class _RoIPoolFunction(Function):
output_size = ctx.output_size
spatial_scale = ctx.spatial_scale
bs, ch, h, w = ctx.input_shape
_C = _lazy_import()
grad_input = _C.roi_pool_backward(
grad_output, rois, argmax, spatial_scale,
output_size[0], output_size[1], bs, ch, h, w)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment