import torch
from torch import nn

from torch.autograd import Function
from torch.autograd.function import once_differentiable

from torch.nn.modules.utils import _pair

from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format


class _RoIPoolFunction(Function):
    @staticmethod
    def forward(ctx, input, rois, output_size, spatial_scale):
        # Normalize an int output_size to an (h, w) pair and stash everything
        # backward() will need.
        ctx.output_size = _pair(output_size)
        ctx.spatial_scale = spatial_scale
        ctx.input_shape = input.size()
        _C = _lazy_import()
        # The kernel expects the (height, width) pair; use the normalized value
        # saved on ctx, since output_size itself may still be a plain int here.
        output, argmax = _C.roi_pool_forward(
            input, rois, spatial_scale,
            ctx.output_size[0], ctx.output_size[1])
        ctx.save_for_backward(rois, argmax)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        rois, argmax = ctx.saved_tensors
        output_size = ctx.output_size
        spatial_scale = ctx.spatial_scale
        bs, ch, h, w = ctx.input_shape
        _C = _lazy_import()
        # Scatter grad_output back through the argmax indices saved in
        # forward() to rebuild the gradient w.r.t. the input feature map.
        grad_input = _C.roi_pool_backward(
            grad_output, rois, argmax, spatial_scale,
            output_size[0], output_size[1], bs, ch, h, w)
        return grad_input, None, None, None
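
# A minimal sketch of driving the autograd Function directly (normally you
# would go through roi_pool() below); the shapes and values are illustrative
# assumptions only:
#
#   input = torch.rand(1, 3, 16, 16, requires_grad=True)
#   rois = torch.tensor([[0., 0., 0., 7., 7.]])  # (batch_idx, x1, y1, x2, y2)
#   out = _RoIPoolFunction.apply(input, rois, (2, 2), 1.0)
#   out.sum().backward()  # backward uses the argmax indices saved in forward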


def roi_pool(input, boxes, output_size, spatial_scale=1.0):
    """
    Performs the Region of Interest (RoI) pooling operator described in Fast R-CNN

    Arguments:
        input (Tensor[N, C, H, W]): input tensor
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
            format where the regions will be taken from. If a single Tensor is passed,
            then the first column should contain the batch index. If a list of Tensors
            is passed, then each Tensor will correspond to the boxes for an element i
            in a batch
        output_size (int or Tuple[int, int]): the size of the output after the cropping
            is performed, as (height, width)
        spatial_scale (float): a scaling factor that maps the box coordinates to
            the input coordinates. Default: 1.0

    Returns:
        output (Tensor[K, C, output_size[0], output_size[1]])
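
    Example (an illustrative sketch; assumes the compiled extension is built):
        >>> input = torch.rand(1, 3, 32, 32)
        >>> # one box on image 0 of the batch, as (batch_idx, x1, y1, x2, y2)
        >>> boxes = torch.tensor([[0., 0., 0., 15., 15.]])
        >>> output = roi_pool(input, boxes, output_size=(7, 7))
        >>> output.shape
        torch.Size([1, 3, 7, 7])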
    """
    rois = boxes
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    # Normalize an int output_size to an (h, w) pair before indexing into it.
    output_size = _pair(output_size)
    # TODO: Change this to support backward, which we
    #       do not currently support when JIT tracing.
    if torch._C._get_tracing_state():
        # Ensure the custom ops are registered with torch.ops before dispatch.
        _lazy_import()
        output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale,
                                                   output_size[0], output_size[1])
        return output
    return _RoIPoolFunction.apply(input, rois, output_size, spatial_scale)


class RoIPool(nn.Module):
    """
    See roi_pool
    """
    def __init__(self, output_size, spatial_scale):
        super(RoIPool, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, input, rois):
        return roi_pool(input, rois, self.output_size, self.spatial_scale)

    def __repr__(self):
        tmpstr = self.__class__.__name__ + '('
        tmpstr += 'output_size=' + str(self.output_size)
        tmpstr += ', spatial_scale=' + str(self.spatial_scale)
        tmpstr += ')'
        return tmpstr
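

if __name__ == "__main__":
    # A minimal smoke-test sketch, assuming the compiled torchvision extension
    # (the _C ops used above) is available in this environment.
    x = torch.rand(2, 3, 32, 32)
    # One box per image, as a list of Tensor[L, 4] in (x1, y1, x2, y2) format;
    # convert_boxes_to_roi_format() prepends the batch index for us.
    boxes = [torch.tensor([[0., 0., 15., 15.]]),
             torch.tensor([[8., 8., 31., 31.]])]
    pooler = RoIPool(output_size=(7, 7), spatial_scale=1.0)
    out = pooler(x, boxes)
    print(out.shape)  # expected: torch.Size([2, 3, 7, 7])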