Unverified Commit dc3ac290 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add C++ ops to torchvision (#826)

* Initial layout for layers with cpp extensions

* Move files around

* Fix import after move

* Add support for multiple types to ROIAlign

* Different organization

CUDA extensions work now

* Cleanups

* Reduce memory requirements for backwards

* Replace runtime_error by AT_ERROR

* Add nms test

* Add support for compilation using CPP extensions

* Change folder structure

* Add ROIPool cuda

* Cleanups

* Add roi_pool.py

* Fix lint

* Add initial structures folder for bounding boxes

* Assertion macros compatible with pytorch master (#540)

* Support for ROI Pooling (#592)

* ROI Pooling with tests. Fix for cuda context in ROI Align.

* renamed bottom and top to follow torch conventions

* remove .type().tensor() calls in favor of the new approach to tensor initialization (#626)

* Consistent naming for rois variable (#627)

* remove .type().tensor() calls in favor of the new approach to tensor initialization

* Consistent naming for rois variable in ROIPool

* ROIPool: Support for all datatypes (#632)

* Use of torch7 naming scheme for ROIAlign forward and backward

* use common cuda helpers in ROIAlign

* use .options() in favor of .type() where applicable

* Added tests for forward pass of ROIAlign, as well as more consistent naming scheme for CPU vs CUDA

* working ROIAlign cuda backwards pass

* working ROIAlign backwards pass for CPU

* added relevant headers for ROIAlign backwards

* tests for ROIAlign layer

* replace .type() with .options() for tensor initialization in ROIAlign layers

* support for Half types in ROIAlign

* gradcheck tests for ROIAlign

* updated ROIPool on CPU to work with all datatypes

* updated and cleaned tests for ROI Pooling

* Fix rebase problem

* Remove structures folder

* Improve cleanup and bugfix in test_layers

* Update C++ headers

* Add CUDAGuard to cu files

* Add more checks to layers

* Add CUDA NMS and tests

* Add multi-type support for NMS CUDA

* Avoid using THCudaMalloc

* Add clang-format and reformat c++ code

* Remove THC includes

* Rename layers to ops

* Add documentation and rename functions

* Improve the documentation a bit

* Fix some lint errors

* Fix remaining lint issues

* Area computation doesn't add +1 in NMS

* Update CI to use PyTorch nightly

* Make NMS return indices sorted according to the score

* Address reviewer comments

* Lint fixes

* Improve doc for roi_align and roi_pool

* move to xenial

* Fix bug pointed out by @lopuhin

* Fix RoIPool reference implementation in Python 2

Also fixes a bug in the clip_boxes_to_image -- this function needs a test!

* Remove change in .travis
parent 0564df43
import torch
from torchvision import _C
def nms(boxes, scores, iou_threshold):
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than iou_threshold with another (higher scoring)
    box.

    Arguments:
        boxes (Tensor[N, 4]): boxes to perform NMS on
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping
            boxes with IoU > iou_threshold

    Returns:
        keep (Tensor): int64 tensor with the indices
            of the elements that have been kept
            by NMS, sorted in decreasing order of scores
    """
    # The whole computation is delegated to the compiled extension.
    return _C.nms(boxes, scores, iou_threshold)
def batched_nms(boxes, scores, idxs, iou_threshold):
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Arguments:
        boxes (Tensor[N, 4]): boxes where NMS will be performed
        scores (Tensor[N]): scores for each one of the boxes
        idxs (Tensor[N]): indices of the categories for each
            one of the boxes.
        iou_threshold (float): discards all overlapping boxes
            with IoU > iou_threshold

    Returns:
        keep (Tensor): int64 tensor with the indices of
            the elements that have been kept by NMS, sorted
            in decreasing order of scores
    """
    # boxes.max() raises on an empty tensor, so handle the "no boxes"
    # case up front and return an empty set of kept indices.
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # strategy: in order to perform NMS independently per class,
    # we add an offset to all the boxes. The offset is dependent
    # only on the class idx, and is large enough so that boxes
    # from different classes do not overlap
    max_coordinate = boxes.max()
    # idxs.to(boxes) matches the dtype and device of the boxes so the
    # offsets can be added to the coordinates directly
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep
def remove_small_boxes(boxes, min_size):
    """
    Finds the boxes whose width and height are both at least min_size.

    Arguments:
        boxes (Tensor[N, 4]): boxes; the width is taken as column 2 minus
            column 0 and the height as column 3 minus column 1
        min_size (number): minimum allowed width and height

    Returns:
        keep (Tensor[K]): indices of the boxes that satisfy the constraint
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    large_enough = (widths >= min_size) & (heights >= min_size)
    # nonzero yields a [K, 1] tensor of positions; drop the trailing axis
    return torch.nonzero(large_enough).squeeze(1)
def clip_boxes_to_image(boxes, size):
    """
    Clamps box coordinates so that they lie within an image of the
    given size.

    Arguments:
        boxes (Tensor[N, 4]): boxes whose last dimension holds x values
            at even positions and y values at odd positions
        size (Tuple[height, width])

    Returns:
        clipped_boxes (Tensor[N, 4])
    """
    height, width = size
    clipped_x = boxes[..., 0::2].clamp(min=0, max=width)
    clipped_y = boxes[..., 1::2].clamp(min=0, max=height)
    # Stacking on a new trailing axis pairs every x with its y, so the
    # reshape below restores the original interleaved coordinate order.
    paired = torch.stack((clipped_x, clipped_y), dim=boxes.dim())
    return paired.reshape(boxes.shape)
def box_area(boxes):
    """
    Computes the area of each bounding box in a set, where every box is
    given by its (x0, y0, x1, y1) coordinates.

    Arguments:
        boxes (Tensor[N, 4]): boxes for which the area will be computed,
            expected in (x0, y0, x1, y1) format

    Returns:
        area (Tensor[N]): area for each box
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def box_iou(boxes1, boxes2):
    """
    Computes the pairwise intersection-over-union (Jaccard index)
    between two sets of boxes.

    Arguments:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    # Broadcasting boxes1 over a new middle axis compares every box in
    # boxes1 against every box in boxes2.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Non-overlapping pairs give negative extents, which are clamped to 0
    overlap = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    inter = overlap[..., 0] * overlap[..., 1]  # [N,M]

    union = box_area(boxes1)[:, None] + box_area(boxes2) - inter
    return inter / union
import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from torchvision import _C
from ._utils import convert_boxes_to_roi_format
class _RoIAlignFunction(Function):
    """
    Autograd function wrapping the compiled roi_align_forward and
    roi_align_backward operators.
    """

    @staticmethod
    def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
        """
        Runs the forward operator and records what backward needs.

        Arguments:
            ctx: autograd context used to carry state to backward
            input (Tensor): input tensor; its size is unpacked as
                (batch, channels, height, width) in backward
            roi (Tensor): regions of interest, passed to the operator as is
            output_size (int or Tuple[int, int]): output size as
                (height, width)
            spatial_scale (float): forwarded to the operator
            sampling_ratio (int): forwarded to the operator

        Returns:
            output (Tensor): result of _C.roi_align_forward
        """
        # Normalize first so that an int output_size can be indexed below;
        # indexing the raw argument would fail for a plain int.
        output_size = _pair(output_size)
        ctx.save_for_backward(roi)
        ctx.output_size = output_size
        ctx.spatial_scale = spatial_scale
        ctx.sampling_ratio = sampling_ratio
        ctx.input_shape = input.size()
        output = _C.roi_align_forward(
            input, roi, spatial_scale,
            output_size[0], output_size[1], sampling_ratio)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """
        Computes the gradient with respect to the input tensor.

        Returns one entry per forward argument; only input receives a
        gradient, the remaining arguments get None.
        """
        rois, = ctx.saved_tensors
        output_size = ctx.output_size
        spatial_scale = ctx.spatial_scale
        sampling_ratio = ctx.sampling_ratio
        bs, ch, h, w = ctx.input_shape
        grad_input = _C.roi_align_backward(
            grad_output, rois, spatial_scale,
            output_size[0], output_size[1], bs, ch, h, w, sampling_ratio)
        return grad_input, None, None, None, None
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
    """
    Applies the Region of Interest (RoI) Align operator described in Mask R-CNN

    Arguments:
        input (Tensor[N, C, H, W]): input tensor
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in x1,y1,x2,y2
            format where the regions will be taken from. When a single Tensor is
            given, its first column should contain the batch index. When a list of
            Tensors is given, each Tensor holds the boxes for one element i of
            the batch
        output_size (int or Tuple[int, int]): the size of the output after the cropping
            is performed, as (height, width)
        spatial_scale (float): a scaling factor relating the box coordinates and
            the input coordinates. Default: 1.0
        sampling_ratio (int): number of sampling points in the interpolation grid
            used to compute the output value of each pooled output bin. If > 0,
            then exactly sampling_ratio x sampling_ratio grid points are used. If
            <= 0, then an adaptive number of grid points are used (computed as
            ceil(roi_width / pooled_w), and likewise for height). Default: -1

    Returns:
        output (Tensor[K, C, output_size[0], output_size[1]])
    """
    if isinstance(boxes, torch.Tensor):
        rois = boxes
    else:
        # anything that is not a Tensor is converted to the single-Tensor
        # roi format first
        rois = convert_boxes_to_roi_format(boxes)
    return _RoIAlignFunction.apply(
        input, rois, output_size, spatial_scale, sampling_ratio)
class RoIAlign(nn.Module):
    """
    Module form of roi_align that stores output_size, spatial_scale and
    sampling_ratio at construction time. See roi_align
    """

    def __init__(self, output_size, spatial_scale, sampling_ratio):
        super(RoIAlign, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio

    def forward(self, input, rois):
        # delegate to the functional form with the stored settings
        return roi_align(
            input, rois,
            self.output_size, self.spatial_scale, self.sampling_ratio)

    def __repr__(self):
        settings = [
            'output_size=' + str(self.output_size),
            'spatial_scale=' + str(self.spatial_scale),
            'sampling_ratio=' + str(self.sampling_ratio),
        ]
        return self.__class__.__name__ + '(' + ', '.join(settings) + ')'
import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from torchvision import _C
from ._utils import convert_boxes_to_roi_format
class _RoIPoolFunction(Function):
    """
    Autograd function wrapping the compiled roi_pool_forward and
    roi_pool_backward operators.
    """

    @staticmethod
    def forward(ctx, input, rois, output_size, spatial_scale):
        """
        Runs the forward operator and records what backward needs.

        Arguments:
            ctx: autograd context used to carry state to backward
            input (Tensor): input tensor; its size is unpacked as
                (batch, channels, height, width) in backward
            rois (Tensor): regions of interest, passed to the operator as is
            output_size (int or Tuple[int, int]): output size as
                (height, width)
            spatial_scale (float): forwarded to the operator

        Returns:
            output (Tensor): first result of _C.roi_pool_forward
        """
        # Normalize first so that an int output_size can be indexed below;
        # indexing the raw argument would fail for a plain int.
        output_size = _pair(output_size)
        ctx.output_size = output_size
        ctx.spatial_scale = spatial_scale
        ctx.input_shape = input.size()
        output, argmax = _C.roi_pool_forward(
            input, rois, spatial_scale,
            output_size[0], output_size[1])
        # argmax is the second result of the forward operator and is
        # handed back to the backward operator
        ctx.save_for_backward(rois, argmax)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """
        Computes the gradient with respect to the input tensor.

        Returns one entry per forward argument; only input receives a
        gradient, the remaining arguments get None.
        """
        rois, argmax = ctx.saved_tensors
        output_size = ctx.output_size
        spatial_scale = ctx.spatial_scale
        bs, ch, h, w = ctx.input_shape
        grad_input = _C.roi_pool_backward(
            grad_output, rois, argmax, spatial_scale,
            output_size[0], output_size[1], bs, ch, h, w)
        return grad_input, None, None, None
def roi_pool(input, boxes, output_size, spatial_scale=1.0):
    """
    Applies the Region of Interest (RoI) Pool operator described in Fast R-CNN

    Arguments:
        input (Tensor[N, C, H, W]): input tensor
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in x1,y1,x2,y2
            format where the regions will be taken from. When a single Tensor is
            given, its first column should contain the batch index. When a list of
            Tensors is given, each Tensor holds the boxes for one element i of
            the batch
        output_size (int or Tuple[int, int]): the size of the output after the cropping
            is performed, as (height, width)
        spatial_scale (float): a scaling factor relating the box coordinates and
            the input coordinates. Default: 1.0

    Returns:
        output (Tensor[K, C, output_size[0], output_size[1]])
    """
    if isinstance(boxes, torch.Tensor):
        rois = boxes
    else:
        # anything that is not a Tensor is converted to the single-Tensor
        # roi format first
        rois = convert_boxes_to_roi_format(boxes)
    return _RoIPoolFunction.apply(input, rois, output_size, spatial_scale)
class RoIPool(nn.Module):
    """
    Module form of roi_pool that stores output_size and spatial_scale
    at construction time. See roi_pool
    """

    def __init__(self, output_size, spatial_scale):
        super(RoIPool, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, input, rois):
        # delegate to the functional form with the stored settings
        return roi_pool(
            input, rois, self.output_size, self.spatial_scale)

    def __repr__(self):
        settings = [
            'output_size=' + str(self.output_size),
            'spatial_scale=' + str(self.spatial_scale),
        ]
        return self.__class__.__name__ + '(' + ', '.join(settings) + ')'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment