from typing import List, Union

import torch
import torch._dynamo
import torch.fx
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops, _has_ops

from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


# NB: all inputs are tensors
def _bilinear_interpolate(
    input,  # [N, C, H, W]
    roi_batch_ind,  # [K]
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
    ymask,  # [K, IY]
    xmask,  # [K, IX]
):
    _, channels, height, width = input.size()

    # deal with sampling points that fall outside the feature map boundary
    y = y.clamp(min=0)
    x = x.clamp(min=0)
    y_low = y.int()
    x_low = x.int()
    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    y = torch.where(y_low >= height - 1, y.to(input.dtype), y)

    x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
    x_low = torch.where(x_low >= width - 1, width - 1, x_low)
    x = torch.where(x_low >= width - 1, x.to(input.dtype), x)

    ly = y - y_low
    lx = x - x_low
    hy = 1.0 - ly
    hx = 1.0 - lx

    # do bilinear interpolation, but respect the masking!
    # TODO: It's possible the masking here is unnecessary if y and
    # x were clamped appropriately; hard to tell
    def masked_index(
        y,  # [K, PH, IY]
        x,  # [K, PW, IX]
    ):
        if ymask is not None:
            assert xmask is not None
            y = torch.where(ymask[:, None, :], y, 0)
            x = torch.where(xmask[:, None, :], x, 0)
        return input[
            roi_batch_ind[:, None, None, None, None, None],
            torch.arange(channels, device=input.device)[None, :, None, None, None, None],
            y[:, None, :, None, :, None],  # prev [K, PH, IY]
            x[:, None, None, :, None, :],  # prev [K, PW, IX]
        ]  # [K, C, PH, PW, IY, IX]
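
    # Illustrative shape walk-through (hypothetical sizes, not from the source):
    # with K=2 rois, C=3 channels, PH=PW=7 pooled bins, and IY=IX=2 sampling
    # points, the advanced indices above broadcast as
    #   roi_batch_ind[:, None, None, None, None, None]    -> [2, 1, 1, 1, 1, 1]
    #   arange(channels)[None, :, None, None, None, None] -> [1, 3, 1, 1, 1, 1]
    #   y[:, None, :, None, :, None]                      -> [2, 1, 7, 1, 2, 1]
    #   x[:, None, None, :, None, :]                      -> [2, 1, 1, 7, 1, 2]
    # so the gather yields a [2, 3, 7, 7, 2, 2] tensor of corner values.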

    v1 = masked_index(y_low, x_low)
    v2 = masked_index(y_low, x_high)
    v3 = masked_index(y_high, x_low)
    v4 = masked_index(y_high, x_high)

    # all weights are broadcast up-front to [K, C, PH, PW, IY, IX]
    def outer_prod(y, x):
        return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

    w1 = outer_prod(hy, hx)
    w2 = outer_prod(hy, lx)
    w3 = outer_prod(ly, hx)
    w4 = outer_prod(ly, lx)

    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
    return val


# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
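# For example (illustrative), under torch.autocast("cuda") a float16 CUDA
# tensor is upcast to float32 here, while float64 tensors pass through, so
# the interpolation arithmetic runs in at least single precision.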
def maybe_cast(tensor):
    if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
        return tensor.float()
    else:
        return tensor


# This is a slow but pure Python and differentiable implementation of
# roi_align.  It is potentially a good basis for Inductor compilation
# (though it has not been benchmarked); today it is used solely because
# its backward pass can be implemented deterministically, which is needed
# for the PT2 benchmark suite.
#
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    orig_dtype = input.dtype

    input = maybe_cast(input)
    rois = maybe_cast(rois)

    _, _, height, width = input.size()

    ph = torch.arange(pooled_height, device=input.device)  # [PH]
    pw = torch.arange(pooled_width, device=input.device)  # [PW]

    # input: [N, C, H, W]
    # rois: [K, 5]

    roi_batch_ind = rois[:, 0].int()  # [K]
    offset = 0.5 if aligned else 0.0
    roi_start_w = rois[:, 1] * spatial_scale - offset  # [K]
    roi_start_h = rois[:, 2] * spatial_scale - offset  # [K]
    roi_end_w = rois[:, 3] * spatial_scale - offset  # [K]
    roi_end_h = rois[:, 4] * spatial_scale - offset  # [K]

    roi_width = roi_end_w - roi_start_w  # [K]
    roi_height = roi_end_h - roi_start_h  # [K]
    if not aligned:
        roi_width = torch.clamp(roi_width, min=1.0)  # [K]
        roi_height = torch.clamp(roi_height, min=1.0)  # [K]

    bin_size_h = roi_height / pooled_height  # [K]
    bin_size_w = roi_width / pooled_width  # [K]
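
    # e.g. (illustrative numbers): a roi spanning 14x14 input pixels pooled to
    # pooled_height = pooled_width = 7 gives bin_size_h = bin_size_w = 2.0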

    exact_sampling = sampling_ratio > 0

    roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height)  # scalar or [K]
    roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width)  # scalar or [K]

    """
    iy, ix = dims(2)
    """

    if exact_sampling:
        count = max(roi_bin_grid_h * roi_bin_grid_w, 1)  # scalar
        iy = torch.arange(roi_bin_grid_h, device=input.device)  # [IY]
        ix = torch.arange(roi_bin_grid_w, device=input.device)  # [IX]
        ymask = None
        xmask = None
    else:
        count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)  # [K]
        # When doing adaptive sampling, the number of samples we need to do
        # is data-dependent based on how big the ROIs are.  This is a bit
        # awkward because first-class dims can't actually handle this.
        # So instead, we inefficiently suppose that we needed to sample ALL
        # the points and mask out things that turned out to be unnecessary
        iy = torch.arange(height, device=input.device)  # [IY]
        ix = torch.arange(width, device=input.device)  # [IX]
        ymask = iy[None, :] < roi_bin_grid_h[:, None]  # [K, IY]
        xmask = ix[None, :] < roi_bin_grid_w[:, None]  # [K, IX]
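        # Illustrative example (hypothetical numbers): with height = 4 and
        # roi_bin_grid_h = [2., 3.] for K=2 rois, iy = [0, 1, 2, 3] and
        # ymask = [[True, True, False, False],
        #          [True, True, True,  False]],
        # i.e. roi 0 contributes only its first two sample rows per bin.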

    def from_K(t):
        return t[:, None, None]

    y = (
        from_K(roi_start_h)
        + ph[None, :, None] * from_K(bin_size_h)
        + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
    )  # [K, PH, IY]
    x = (
        from_K(roi_start_w)
        + pw[None, :, None] * from_K(bin_size_w)
        + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
    )  # [K, PW, IX]
    val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]

    # Mask out samples that weren't actually adaptively needed
    if not exact_sampling:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    output = val.sum((-1, -2))  # remove IY, IX ~> [K, C, PH, PW]
    if isinstance(count, torch.Tensor):
        output /= count[:, None, None, None]
    else:
        output /= count

    output = output.to(orig_dtype)

    return output
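

# Usage sketch (hypothetical values): with deterministic algorithms enabled,
# the public roi_align below routes CUDA/MPS inputs to this implementation:
#
#   torch.use_deterministic_algorithms(True)
#   feat = torch.rand(1, 256, 32, 32, device="cuda", requires_grad=True)
#   rois = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]], device="cuda")
#   out = roi_align(feat, rois, output_size=(7, 7), spatial_scale=1.0)
#   out.sum().backward()  # backward through _roi_align is deterministic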


@torch.fx.wrap
def roi_align(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
    sampling_ratio: int = -1,
    aligned: bool = False,
) -> Tensor:
    """
    Performs the Region of Interest (RoI) Align operator with average pooling, as described in Mask R-CNN.

    Args:
        input (Tensor[N, C, H, W]): The input tensor, i.e. a batch with ``N`` elements. Each element
            contains ``C`` feature maps of dimensions ``H x W``.
            If the tensor is quantized, we expect a batch size of ``N == 1``.
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
            format where the regions will be taken from.
            The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            If a single Tensor is passed, then the first column should
            contain the index of the corresponding element in the batch, i.e. a number in ``[0, N - 1]``.
            If a list of Tensors is passed, then each Tensor will correspond to the boxes for an element i
            in the batch.
        output_size (int or Tuple[int, int]): the size of the output (in bins or pixels) after the pooling
            is performed, as (height, width).
        spatial_scale (float): a scaling factor that maps the box coordinates to
            the input coordinates. For example, if your boxes are defined on the scale
            of a 224x224 image and your input is a 112x112 feature map (resulting from a 0.5x scaling of
            the original image), you'll want to set this to 0.5. Default: 1.0
        sampling_ratio (int): number of sampling points in the interpolation grid
            used to compute the output value of each pooled output bin. If > 0,
            then exactly ``sampling_ratio x sampling_ratio`` sampling points per bin are used. If
            <= 0, then an adaptive number of grid points is used (computed as
            ``ceil(roi_width / output_width)``, and likewise for height). Default: -1
        aligned (bool): If False, use the legacy implementation.
            If True, pixel shift the box coordinates by -0.5 for a better alignment with the two
            neighboring pixel indices. This version is used in Detectron2.

    Returns:
        Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
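
    Example::

        >>> # a minimal sketch; shapes and values are illustrative
        >>> feat = torch.rand(1, 3, 16, 16)
        >>> rois = torch.tensor([[0.0, 2.0, 2.0, 10.0, 10.0]])  # [batch_idx, x1, y1, x2, y2]
        >>> pooled = roi_align(feat, rois, output_size=(4, 4))
        >>> pooled.shape
        torch.Size([1, 3, 4, 4])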
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_align)
    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    if not torch.jit.is_scripting():
        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
    _assert_has_ops()
    return torch.ops.torchvision.roi_align(
        input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
    )


class RoIAlign(nn.Module):
    """
    See :func:`roi_align`.
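
    Example::

        >>> # a minimal sketch; shapes and values are illustrative
        >>> pool = RoIAlign(output_size=(4, 4), spatial_scale=1.0, sampling_ratio=2)
        >>> feat = torch.rand(1, 3, 16, 16)
        >>> rois = torch.tensor([[0.0, 2.0, 2.0, 10.0, 10.0]])
        >>> pool(feat, rois).shape
        torch.Size([1, 3, 4, 4])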
    """

    def __init__(
        self,
        output_size: BroadcastingList2[int],
        spatial_scale: float,
        sampling_ratio: int,
        aligned: bool = False,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

    def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
        return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"output_size={self.output_size}"
            f", spatial_scale={self.spatial_scale}"
            f", sampling_ratio={self.sampling_ratio}"
            f", aligned={self.aligned}"
            f")"
        )
        return s