Unverified Commit b2171653 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add compatibility checks for C++ extensions (#2467)

* Add compatibility checks for C++ extensions

* Fix lint
parent 2cc20d74
_HAS_OPS = False
def _has_ops():
return False
def _register_extensions():
import os
import importlib
......@@ -23,10 +27,26 @@ def _register_extensions():
try:
_register_extensions()
_HAS_OPS = True
def _has_ops(): # noqa: F811
return True
except (ImportError, OSError):
pass
def _assert_has_ops():
if not _has_ops():
raise RuntimeError(
"Couldn't load custom C++ ops. This can happen if your PyTorch and "
"torchvision versions are incompatible, or if you had errors while compiling "
"torchvision from source. For further information on the compatible versions, check "
"https://github.com/pytorch/vision#installation for the compatibility matrix. "
"Please check your PyTorch version with torch.__version__ and your torchvision "
"version with torchvision.__version__ and verify if they are compatible, and if not "
"please reinstall torchvision so that it matches your PyTorch install."
)
def _check_cuda_version():
"""
Make sure that CUDA versions match between the pytorch install and torchvision install
......
......@@ -3,6 +3,7 @@ from torch.jit.annotations import Tuple
from torch import Tensor
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
import torchvision
from torchvision.extension import _assert_has_ops
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
......@@ -37,6 +38,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
of the elements that have been kept
by NMS, sorted in decreasing order of scores
"""
_assert_has_ops()
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
......
......@@ -6,6 +6,7 @@ from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from torch.jit.annotations import Optional, Tuple
from torchvision.extension import _assert_has_ops
def deform_conv2d(
......@@ -51,6 +52,7 @@ def deform_conv2d(
>>> torch.Size([4, 5, 8, 8])
"""
_assert_has_ops()
out_channels = weight.shape[0]
if bias is None:
bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
......
......@@ -4,6 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, Tuple
from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
......@@ -38,6 +39,7 @@ def ps_roi_align(
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
_assert_has_ops()
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
......
......@@ -4,6 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, Tuple
from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
......@@ -32,6 +33,7 @@ def ps_roi_pool(
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
_assert_has_ops()
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
......
......@@ -4,6 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2
from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
......@@ -41,6 +42,7 @@ def roi_align(
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
_assert_has_ops()
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
......
......@@ -4,6 +4,7 @@ from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List, BroadcastingList2
from torchvision.extension import _assert_has_ops
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
......@@ -31,6 +32,7 @@ def roi_pool(
Returns:
output (Tensor[K, C, output_size[0], output_size[1]])
"""
_assert_has_ops()
check_roi_boxes_shape(boxes)
rois = boxes
output_size = _pair(output_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment