Unverified Commit 220b69ba authored by Soumith Chintala's avatar Soumith Chintala Committed by GitHub
Browse files

make C extension lazy-import (#971)

* make C extension lazy-import

* add lazy loading to roi_pool
parent 579eebea
...@@ -33,31 +33,3 @@ def get_image_backend(): ...@@ -33,31 +33,3 @@ def get_image_backend():
Gets the name of the package used to load images Gets the name of the package used to load images
""" """
return _image_backend return _image_backend
def _check_cuda_matches():
"""
Make sure that CUDA versions match between the pytorch install and torchvision install
"""
import torch
from torchvision import _C
if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None:
tv_version = str(_C.CUDA_VERSION)
if int(tv_version) < 10000:
tv_major = int(tv_version[0])
tv_minor = int(tv_version[2])
else:
tv_major = int(tv_version[0:2])
tv_minor = int(tv_version[3])
t_version = torch.version.cuda
t_version = t_version.split('.')
t_major = int(t_version[0])
t_minor = int(t_version[1])
if t_major != tv_major or t_minor != tv_minor:
raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install."
.format(t_major, t_minor, tv_major, tv_minor))
_check_cuda_matches()
_C = None
def _lazy_import():
"""
Make sure that CUDA versions match between the pytorch install and torchvision install
"""
global _C
if _C is not None:
return _C
import torch
from torchvision import _C as C
_C = C
if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None:
tv_version = str(_C.CUDA_VERSION)
if int(tv_version) < 10000:
tv_major = int(tv_version[0])
tv_minor = int(tv_version[2])
else:
tv_major = int(tv_version[0:2])
tv_minor = int(tv_version[3])
t_version = torch.version.cuda
t_version = t_version.split('.')
t_major = int(t_version[0])
t_minor = int(t_version[1])
if t_major != tv_major or t_minor != tv_minor:
raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install."
.format(t_major, t_minor, tv_major, tv_minor))
return _C
import torch import torch
from torchvision import _C from torchvision.extension import _lazy_import
def nms(boxes, scores, iou_threshold): def nms(boxes, scores, iou_threshold):
...@@ -22,6 +22,7 @@ def nms(boxes, scores, iou_threshold): ...@@ -22,6 +22,7 @@ def nms(boxes, scores, iou_threshold):
of the elements that have been kept of the elements that have been kept
by NMS, sorted in decreasing order of scores by NMS, sorted in decreasing order of scores
""" """
_C = _lazy_import()
return _C.nms(boxes, scores, iou_threshold) return _C.nms(boxes, scores, iou_threshold)
......
...@@ -6,7 +6,7 @@ from torch.autograd.function import once_differentiable ...@@ -6,7 +6,7 @@ from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torchvision import _C from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format from ._utils import convert_boxes_to_roi_format
...@@ -18,6 +18,7 @@ class _RoIAlignFunction(Function): ...@@ -18,6 +18,7 @@ class _RoIAlignFunction(Function):
ctx.spatial_scale = spatial_scale ctx.spatial_scale = spatial_scale
ctx.sampling_ratio = sampling_ratio ctx.sampling_ratio = sampling_ratio
ctx.input_shape = input.size() ctx.input_shape = input.size()
_C = _lazy_import()
output = _C.roi_align_forward( output = _C.roi_align_forward(
input, roi, spatial_scale, input, roi, spatial_scale,
output_size[0], output_size[1], sampling_ratio) output_size[0], output_size[1], sampling_ratio)
...@@ -31,6 +32,7 @@ class _RoIAlignFunction(Function): ...@@ -31,6 +32,7 @@ class _RoIAlignFunction(Function):
spatial_scale = ctx.spatial_scale spatial_scale = ctx.spatial_scale
sampling_ratio = ctx.sampling_ratio sampling_ratio = ctx.sampling_ratio
bs, ch, h, w = ctx.input_shape bs, ch, h, w = ctx.input_shape
_C = _lazy_import()
grad_input = _C.roi_align_backward( grad_input = _C.roi_align_backward(
grad_output, rois, spatial_scale, grad_output, rois, spatial_scale,
output_size[0], output_size[1], bs, ch, h, w, sampling_ratio) output_size[0], output_size[1], bs, ch, h, w, sampling_ratio)
......
...@@ -6,7 +6,7 @@ from torch.autograd.function import once_differentiable ...@@ -6,7 +6,7 @@ from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torchvision import _C from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format from ._utils import convert_boxes_to_roi_format
...@@ -16,6 +16,7 @@ class _RoIPoolFunction(Function): ...@@ -16,6 +16,7 @@ class _RoIPoolFunction(Function):
ctx.output_size = _pair(output_size) ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale ctx.spatial_scale = spatial_scale
ctx.input_shape = input.size() ctx.input_shape = input.size()
_C = _lazy_import()
output, argmax = _C.roi_pool_forward( output, argmax = _C.roi_pool_forward(
input, rois, spatial_scale, input, rois, spatial_scale,
output_size[0], output_size[1]) output_size[0], output_size[1])
...@@ -29,6 +30,7 @@ class _RoIPoolFunction(Function): ...@@ -29,6 +30,7 @@ class _RoIPoolFunction(Function):
output_size = ctx.output_size output_size = ctx.output_size
spatial_scale = ctx.spatial_scale spatial_scale = ctx.spatial_scale
bs, ch, h, w = ctx.input_shape bs, ch, h, w = ctx.input_shape
_C = _lazy_import()
grad_input = _C.roi_pool_backward( grad_input = _C.roi_pool_backward(
grad_output, rois, argmax, spatial_scale, grad_output, rois, argmax, spatial_scale,
output_size[0], output_size[1], bs, ch, h, w) output_size[0], output_size[1], bs, ch, h, w)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment