Unverified Commit f677ea31 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Remove cpp extensions in favor of torch ops (#1348)

* Remove C++ extensions in favor of custom ops

* Remove unused custom_ops.cpp file

* Rename _custom_ops.py

* Reorganize functions

* Minor improvements and fixes

* Fix lint

* Fully scriptable ops

* Import types used by annotations
parent 0dd55882
...@@ -52,9 +52,9 @@ def write_version_file(): ...@@ -52,9 +52,9 @@ def write_version_file():
with open(version_path, 'w') as f: with open(version_path, 'w') as f:
f.write("__version__ = '{}'\n".format(version)) f.write("__version__ = '{}'\n".format(version))
f.write("git_version = {}\n".format(repr(sha))) f.write("git_version = {}\n".format(repr(sha)))
f.write("from torchvision import _C\n") f.write("from torchvision.extension import _check_cuda_version\n")
f.write("if hasattr(_C, 'CUDA_VERSION'):\n") f.write("if _check_cuda_version() > 0:\n")
f.write(" cuda = _C.CUDA_VERSION\n") f.write(" cuda = _check_cuda_version()\n")
write_version_file() write_version_file()
...@@ -96,21 +96,12 @@ def get_extensions(): ...@@ -96,21 +96,12 @@ def get_extensions():
source_models = [os.path.join(models_dir, s) for s in source_models] source_models = [os.path.join(models_dir, s) for s in source_models]
tests = test_file + source_models tests = test_file + source_models
custom_ops_sources = [os.path.join(extensions_dir, "custom_ops", "custom_ops.cpp"),
os.path.join(extensions_dir, "cpu", "nms_cpu.cpp"),
os.path.join(extensions_dir, "cpu", "ROIAlign_cpu.cpp"),
os.path.join(extensions_dir, "cpu", "ROIPool_cpu.cpp")]
custom_ops_sources_cuda = [os.path.join(extensions_dir, "cuda", "nms_cuda.cu"),
os.path.join(extensions_dir, "cuda", "ROIAlign_cuda.cu"),
os.path.join(extensions_dir, "cuda", "ROIPool_cuda.cu")]
define_macros = [] define_macros = []
extra_compile_args = {} extra_compile_args = {}
if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1': if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
extension = CUDAExtension extension = CUDAExtension
sources += source_cuda sources += source_cuda
custom_ops_sources += custom_ops_sources_cuda
define_macros += [('WITH_CUDA', None)] define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = os.getenv('NVCC_FLAGS', '')
if nvcc_flags == '': if nvcc_flags == '':
...@@ -148,13 +139,6 @@ def get_extensions(): ...@@ -148,13 +139,6 @@ def get_extensions():
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
), ),
extension(
"torchvision._custom_ops",
sources=custom_ops_sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
),
] ]
return ext_modules return ext_modules
......
...@@ -190,7 +190,7 @@ class RoIPoolTester(unittest.TestCase): ...@@ -190,7 +190,7 @@ class RoIPoolTester(unittest.TestCase):
@torch.jit.script @torch.jit.script
def script_func(input, rois): def script_func(input, rois):
return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0] return ops.roi_pool(input, rois, 5, 1.0)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool' assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool'
...@@ -282,7 +282,7 @@ class RoIPoolTester(unittest.TestCase): ...@@ -282,7 +282,7 @@ class RoIPoolTester(unittest.TestCase):
@torch.jit.script @torch.jit.script
def script_func(input, rois): def script_func(input, rois):
return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0] return ops.roi_pool(input, rois, 5, 1.0)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool on CUDA' assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool on CUDA'
...@@ -442,7 +442,7 @@ class RoIAlignTester(unittest.TestCase): ...@@ -442,7 +442,7 @@ class RoIAlignTester(unittest.TestCase):
@torch.jit.script @torch.jit.script
def script_func(input, rois): def script_func(input, rois):
return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0] return ops.roi_align(input, rois, 5, 0.5, 1)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align' assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align'
...@@ -482,7 +482,7 @@ class RoIAlignTester(unittest.TestCase): ...@@ -482,7 +482,7 @@ class RoIAlignTester(unittest.TestCase):
@torch.jit.script @torch.jit.script
def script_func(input, rois): def script_func(input, rois):
return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0] return ops.roi_align(input, rois, 5, 0.5, 1)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align on CUDA' assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align on CUDA'
......
...@@ -5,6 +5,8 @@ from torchvision import transforms ...@@ -5,6 +5,8 @@ from torchvision import transforms
from torchvision import utils from torchvision import utils
from torchvision import io from torchvision import io
from .extension import _HAS_OPS
try: try:
from .version import __version__ # noqa: F401 from .version import __version__ # noqa: F401
except ImportError: except ImportError:
......
...@@ -74,3 +74,74 @@ at::Tensor ROIAlign_backward( ...@@ -74,3 +74,74 @@ at::Tensor ROIAlign_backward(
width, width,
sampling_ratio); sampling_ratio);
} }
using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
// Autograd wrapper for ROIAlign: forwards to the non-differentiable
// ROIAlign_forward kernel and wires up ROIAlign_backward so gradients
// flow to `input`. `rois` and the scalar arguments receive no gradient.
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
public:
// Saves the scalar hyper-parameters, the input shape (needed to size
// grad_input in backward), and `rois`, then runs the forward kernel.
static variable_list forward(
AutogradContext* ctx,
Variable input,
Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
// Only the shape is kept, so the input tensor itself need not be
// held alive for backward.
ctx->saved_data["input_shape"] = input.sizes();
ctx->save_for_backward({rois});
auto result = ROIAlign_forward(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
return {result};
}
// Computes grad_input from grad_output using the data saved in forward.
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = ROIAlign_backward(
grad_output[0],
rois,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
ctx->saved_data["sampling_ratio"].toInt());
// One slot per forward() input (after ctx): only `input` is
// differentiable; the rest get undefined Variables.
return {
grad_in, Variable(), Variable(), Variable(), Variable(), Variable()};
}
};
// Entry point registered as torchvision::roi_align. Delegates to the
// autograd Function and unwraps the single-tensor result list.
Tensor roi_align(
    const Tensor& input,
    const Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio) {
  auto outputs = ROIAlignFunction::apply(
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
  return outputs[0];
}
...@@ -63,4 +63,66 @@ at::Tensor ROIPool_backward( ...@@ -63,4 +63,66 @@ at::Tensor ROIPool_backward(
channels, channels,
height, height,
width); width);
} }
\ No newline at end of file
using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
// Autograd wrapper for ROIPool. Returns both the pooled output and the
// argmax indices; argmax is saved for backward (it records where each
// pooled value came from) and is marked non-differentiable.
class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
public:
static variable_list forward(
AutogradContext* ctx,
Variable input,
Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
// Shape only — backward needs it to allocate grad_input.
ctx->saved_data["input_shape"] = input.sizes();
auto result = ROIPool_forward(
input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto argmax = std::get<1>(result);
ctx->save_for_backward({rois, argmax});
ctx->mark_non_differentiable({argmax});
return {output, argmax};
}
// Routes grad_output back to grad_input via the saved argmax indices.
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto argmax = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = ROIPool_backward(
grad_output[0],
rois,
argmax,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
// One slot per forward() input: only `input` is differentiable.
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
// Entry point registered as torchvision::roi_pool. Unpacks the
// two-element variable_list produced by the autograd Function into the
// (output, argmax) tuple callers expect.
std::tuple<Tensor, Tensor> roi_pool(
    const Tensor& input,
    const Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width) {
  auto outputs = ROIPoolFunction::apply(
      input, rois, spatial_scale, pooled_height, pooled_width);
  return std::tuple<Tensor, Tensor>(outputs[0], outputs[1]);
}
#include <Python.h>
#include <torch/script.h>
#include "ROIAlign.h"
#include "ROIPool.h"
#include "nms.h"
using namespace at;
// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension
#ifdef _WIN32
#if PY_MAJOR_VERSION < 3
// Python 2 module-initialization hook for the _custom_ops extension on
// Windows. The ops themselves are registered by static initializers when
// the library is loaded, so nothing happens here.
// NOTE(review): under Python 2, PyMODINIT_FUNC expands to void, so
// `return NULL;` relies on compiler leniency — confirm if Py2 builds
// are still supported.
PyMODINIT_FUNC init_custom_ops(void) {
// No need to do anything.
// _custom_ops.py will run on load
return NULL;
}
#else
PyMODINIT_FUNC PyInit__custom_ops(void) {
// No need to do anything.
// _custom_ops.py will run on load
return NULL;
}
#endif
#endif
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
// Autograd Function bridging the ROIAlign kernels into PyTorch's
// autograd graph: runs ROIAlign_forward and replays the recorded
// arguments in backward() through ROIAlign_backward. Gradients reach
// only `input`; `rois` and the scalars are treated as constants.
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
public:
// Records hyper-parameters, the input's shape, and `rois` for
// backward, then invokes the forward kernel.
static variable_list forward(
AutogradContext* ctx,
Variable input,
Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
// Saving just the sizes avoids keeping the whole input tensor alive.
ctx->saved_data["input_shape"] = input.sizes();
ctx->save_for_backward({rois});
auto result = ROIAlign_forward(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
return {result};
}
// Produces grad_input sized from the saved input shape.
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = ROIAlign_backward(
grad_output[0],
rois,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
ctx->saved_data["sampling_ratio"].toInt());
// One return slot per forward() input; only `input` gets a gradient.
return {
grad_in, Variable(), Variable(), Variable(), Variable(), Variable()};
}
};
// Public roi_align entry point: runs the autograd Function and returns
// its single output tensor.
Tensor roi_align(
    const Tensor& input,
    const Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio) {
  auto ret = ROIAlignFunction::apply(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio);
  return ret[0];
}
// Autograd Function for ROIPool: the forward kernel returns (output,
// argmax); argmax is saved for backward and marked non-differentiable
// since it is an index tensor, not a value to differentiate.
class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
public:
static variable_list forward(
AutogradContext* ctx,
Variable input,
Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
// Backward only needs the sizes to allocate grad_input.
ctx->saved_data["input_shape"] = input.sizes();
auto result = ROIPool_forward(
input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto argmax = std::get<1>(result);
ctx->save_for_backward({rois, argmax});
ctx->mark_non_differentiable({argmax});
return {output, argmax};
}
// Scatters grad_output into grad_input at the saved argmax positions.
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto argmax = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = ROIPool_backward(
grad_output[0],
rois,
argmax,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
// One slot per forward() input; only `input` is differentiable.
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
// Public roi_pool entry point: runs the autograd Function and converts
// its two-element variable_list into the (output, argmax) tuple.
std::tuple<Tensor, Tensor> roi_pool(
    const Tensor& input,
    const Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width) {
  auto ret = ROIPoolFunction::apply(
      input, rois, spatial_scale, pooled_height, pooled_width);
  return std::tuple<Tensor, Tensor>(ret[0], ret[1]);
}
// Register the custom ops in the "torchvision" namespace so they are
// reachable as torch.ops.torchvision.* from Python and TorchScript.
// roi_align carries an explicit schema string (presumably to pin the
// argument types for scripting — confirm); nms and roi_pool let the
// schema be inferred from the C++ signatures.
static auto registry =
torch::RegisterOperators()
.op("torchvision::nms", &nms)
.op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
&roi_align)
.op("torchvision::roi_pool", &roi_pool);
#include <Python.h>
#include <torch/script.h>
#ifdef WITH_CUDA
#include <cuda.h>
#endif
#include "ROIAlign.h" #include "ROIAlign.h"
#include "ROIPool.h" #include "ROIPool.h"
#include "nms.h" #include "nms.h"
#ifdef WITH_CUDA // If we are in a Windows environment, we need to define
#include <cuda.h> // initialization functions for the _custom_ops extension
#ifdef _WIN32
#if PY_MAJOR_VERSION < 3
// Python 2 module-initialization hook (Windows). Op registration is
// done by static initializers at library load, so this is a no-op stub.
// NOTE(review): in Python 2, PyMODINIT_FUNC expands to void, making
// `return NULL;` dubious — verify whether Py2 builds are supported.
PyMODINIT_FUNC init_custom_ops(void) {
// No need to do anything.
// _custom_ops.py will run on load
return NULL;
}
#else
PyMODINIT_FUNC PyInit__custom_ops(void) {
// No need to do anything.
// _custom_ops.py will run on load
return NULL;
}
#endif
#endif #endif
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { int64_t _cuda_version() {
// TODO: remove nms from here since it is now registered
// and used as a PyTorch custom op
m.def("nms", &nms, "non-maximum suppression");
m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
#ifdef WITH_CUDA #ifdef WITH_CUDA
m.attr("CUDA_VERSION") = CUDA_VERSION; return CUDA_VERSION;
#else
return -1;
#endif #endif
} }
// Expose the custom ops as torch.ops.torchvision.* (usable from Python
// and TorchScript). roi_align is registered with an explicit schema
// string (presumably to pin argument types for scripting — confirm);
// the others rely on schema inference from the C++ signatures.
// _cuda_version lets Python query the CUDA version torchvision was
// compiled against (-1 when built without CUDA, per the code above).
static auto registry =
torch::RegisterOperators()
.op("torchvision::nms", &nms)
.op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
&roi_align)
.op("torchvision::roi_pool", &roi_pool)
.op("torchvision::_cuda_version", &_cuda_version);
_C = None _HAS_OPS = False
def _lazy_import(): def _register_extensions():
import os
import imp
import torch
# load the custom_op_library and register the custom ops
lib_dir = os.path.dirname(__file__)
_, path, _ = imp.find_module("_C", [lib_dir])
torch.ops.load_library(path)
try:
_register_extensions()
_HAS_OPS = True
except (ImportError, OSError):
pass
def _check_cuda_version():
""" """
Make sure that CUDA versions match between the pytorch install and torchvision install Make sure that CUDA versions match between the pytorch install and torchvision install
""" """
global _C if not _HAS_OPS:
if _C is not None: return -1
return _C
import torch import torch
from torchvision import _C as C _version = torch.ops.torchvision._cuda_version()
import torchvision.ops._custom_ops if _version != -1 and torch.version.cuda is not None:
_C = C tv_version = str(_version)
if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None:
tv_version = str(_C.CUDA_VERSION)
if int(tv_version) < 10000: if int(tv_version) < 10000:
tv_major = int(tv_version[0]) tv_major = int(tv_version[0])
tv_minor = int(tv_version[2]) tv_minor = int(tv_version[2])
...@@ -29,4 +44,7 @@ def _lazy_import(): ...@@ -29,4 +44,7 @@ def _lazy_import():
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. " "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install." "Please reinstall the torchvision that matches your PyTorch install."
.format(t_major, t_minor, tv_major, tv_minor)) .format(t_major, t_minor, tv_major, tv_minor))
return _C return _version
_check_cuda_version()
...@@ -4,6 +4,10 @@ from .roi_pool import roi_pool, RoIPool ...@@ -4,6 +4,10 @@ from .roi_pool import roi_pool, RoIPool
from .poolers import MultiScaleRoIAlign from .poolers import MultiScaleRoIAlign
from .feature_pyramid_network import FeaturePyramidNetwork from .feature_pyramid_network import FeaturePyramidNetwork
from ._register_onnx_ops import _register_custom_op
_register_custom_op()
__all__ = [ __all__ = [
'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', 'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool',
......
import os
import sys import sys
import imp
import torch import torch
# load the custom_op_library and register the custom ops def _register_custom_op():
lib_dir = os.path.join(os.path.dirname(__file__), '..')
file, path, description = imp.find_module("_custom_ops", [lib_dir])
torch.ops.load_library(path)
def register_custom_op():
from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx
from torch.onnx.symbolic_opset9 import select, unsqueeze, squeeze, _cast_Long, reshape from torch.onnx.symbolic_opset9 import select, unsqueeze, squeeze, _cast_Long, reshape
...@@ -41,6 +33,3 @@ def register_custom_op(): ...@@ -41,6 +33,3 @@ def register_custom_op():
register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, 10) register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, 10)
register_custom_op_symbolic('torchvision::roi_align', roi_align, 10) register_custom_op_symbolic('torchvision::roi_align', roi_align, 10)
register_custom_op_symbolic('torchvision::roi_pool', roi_pool, 10) register_custom_op_symbolic('torchvision::roi_pool', roi_pool, 10)
register_custom_op()
import torch import torch
from torch import Tensor
from torch.jit.annotations import List
def _cat(tensors, dim=0): def _cat(tensors, dim=0):
# type: (List[Tensor], int) -> Tensor
""" """
Efficient version of torch.cat that avoids a copy if there is only a single element in a list Efficient version of torch.cat that avoids a copy if there is only a single element in a list
""" """
assert isinstance(tensors, (list, tuple)) # TODO add back the assert
# assert isinstance(tensors, (list, tuple))
if len(tensors) == 1: if len(tensors) == 1:
return tensors[0] return tensors[0]
return torch.cat(tensors, dim) return torch.cat(tensors, dim)
def convert_boxes_to_roi_format(boxes): def convert_boxes_to_roi_format(boxes):
# type: (List[Tensor]) -> Tensor
concat_boxes = _cat([b for b in boxes], dim=0) concat_boxes = _cat([b for b in boxes], dim=0)
ids = _cat( temp = []
[ for i, b in enumerate(boxes):
torch.full_like(b[:, :1], i) temp.append(torch.full_like(b[:, :1], i))
for i, b in enumerate(boxes) ids = _cat(temp, dim=0)
],
dim=0,
)
rois = torch.cat([ids, concat_boxes], dim=1) rois = torch.cat([ids, concat_boxes], dim=1)
return rois return rois
import torch import torch
from torchvision.extension import _lazy_import
def nms(boxes, scores, iou_threshold): def nms(boxes, scores, iou_threshold):
...@@ -29,7 +28,6 @@ def nms(boxes, scores, iou_threshold): ...@@ -29,7 +28,6 @@ def nms(boxes, scores, iou_threshold):
of the elements that have been kept of the elements that have been kept
by NMS, sorted in decreasing order of scores by NMS, sorted in decreasing order of scores
""" """
_lazy_import()
return torch.ops.torchvision.nms(boxes, scores, iou_threshold) return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
......
import torch import torch
from torch import nn from torch import nn, Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.jit.annotations import List
from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format from ._utils import convert_boxes_to_roi_format
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1): def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
# type: (Tensor, Tensor, int, float, int) -> Tensor
""" """
Performs Region of Interest (RoI) Align operator described in Mask R-CNN Performs Region of Interest (RoI) Align operator described in Mask R-CNN
...@@ -35,9 +33,9 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1): ...@@ -35,9 +33,9 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
output (Tensor[K, C, output_size[0], output_size[1]]) output (Tensor[K, C, output_size[0], output_size[1]])
""" """
rois = boxes rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor): if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois) rois = convert_boxes_to_roi_format(rois)
_lazy_import()
return torch.ops.torchvision.roi_align(input, rois, spatial_scale, return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
output_size[0], output_size[1], output_size[0], output_size[1],
sampling_ratio) sampling_ratio)
......
import torch import torch
from torch import nn from torch import nn, Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torch.jit.annotations import List
from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format from ._utils import convert_boxes_to_roi_format
def roi_pool(input, boxes, output_size, spatial_scale=1.0): def roi_pool(input, boxes, output_size, spatial_scale=1.0):
# type: (Tensor, Tensor, int, float) -> Tensor
""" """
Performs Region of Interest (RoI) Pool operator described in Fast R-CNN Performs Region of Interest (RoI) Pool operator described in Fast R-CNN
...@@ -30,9 +28,9 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0): ...@@ -30,9 +28,9 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0):
output (Tensor[K, C, output_size[0], output_size[1]]) output (Tensor[K, C, output_size[0], output_size[1]])
""" """
rois = boxes rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor): if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois) rois = convert_boxes_to_roi_format(rois)
_lazy_import()
output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale,
output_size[0], output_size[1]) output_size[0], output_size[1])
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment