import functools

from typing import List, Union

import torch
import torch.fx
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops, _has_ops

from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


def lazy_compile(**compile_kwargs):
    """Lazily wrap a function with torch.compile on the first call

    This avoids eagerly importing dynamo.
    """

    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            compiled_fn = torch.compile(fn, **compile_kwargs)
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn
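
# Usage sketch (illustrative): the decorator defers torch.compile until the
# wrapped function is first called, then rebinds the module-level name to the
# compiled function so subsequent calls skip this hook entirely:
#
#   @lazy_compile(dynamic=True)
#   def fn(x):
#       return x + 1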


# NB: all inputs are tensors
def _bilinear_interpolate(
    input,  # [N, C, H, W]
    roi_batch_ind,  # [K]
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
    ymask,  # [K, IY]
    xmask,  # [K, IX]
):
    _, channels, height, width = input.size()

    # deal with sample points that fall outside the feature map boundary
    y = y.clamp(min=0)
    x = x.clamp(min=0)
    y_low = y.int()
    x_low = x.int()
    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    y = torch.where(y_low >= height - 1, y.to(input.dtype), y)

    x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
    x_low = torch.where(x_low >= width - 1, width - 1, x_low)
    x = torch.where(x_low >= width - 1, x.to(input.dtype), x)

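    # Fractional offsets of each sample point within its cell; (hy, hx) and
    # (ly, lx) become the bilinear weights of the low and high neighbors.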
    ly = y - y_low
    lx = x - x_low
    hy = 1.0 - ly
    hx = 1.0 - lx

    # do bilinear interpolation, but respect the masking!
    # TODO: It's possible the masking here is unnecessary if y and
    # x were clamped appropriately; hard to tell
    def masked_index(
        y,  # [K, PH, IY]
        x,  # [K, PW, IX]
    ):
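        # For adaptive sampling, zero out indices beyond each RoI's grid so the
        # gather below stays in bounds; those values are masked out again later.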
        if ymask is not None:
            assert xmask is not None
            y = torch.where(ymask[:, None, :], y, 0)
            x = torch.where(xmask[:, None, :], x, 0)
        return input[
            roi_batch_ind[:, None, None, None, None, None],
            torch.arange(channels, device=input.device)[None, :, None, None, None, None],
            y[:, None, :, None, :, None],  # prev [K, PH, IY]
            x[:, None, None, :, None, :],  # prev [K, PW, IX]
        ]  # [K, C, PH, PW, IY, IX]

    v1 = masked_index(y_low, x_low)
    v2 = masked_index(y_low, x_high)
    v3 = masked_index(y_high, x_low)
    v4 = masked_index(y_high, x_high)

    # all ws preemptively [K, C, PH, PW, IY, IX]
    def outer_prod(y, x):
        return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

    w1 = outer_prod(hy, hx)
    w2 = outer_prod(hy, lx)
    w3 = outer_prod(ly, hx)
    w4 = outer_prod(ly, lx)

    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
    return val


# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
    if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
        return tensor.float()
    else:
        return tensor


# This is a pure Python and differentiable implementation of roi_align.  When
# run in eager mode, it uses a lot of memory, but when compiled it has
# acceptable memory usage.  The main point of this implementation is that
# its backwards is deterministic.
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@lazy_compile(dynamic=True)
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    orig_dtype = input.dtype

    input = maybe_cast(input)
    rois = maybe_cast(rois)

    _, _, height, width = input.size()

    ph = torch.arange(pooled_height, device=input.device)  # [PH]
    pw = torch.arange(pooled_width, device=input.device)  # [PW]

    # input: [N, C, H, W]
    # rois: [K, 5]

    roi_batch_ind = rois[:, 0].int()  # [K]
    offset = 0.5 if aligned else 0.0
    roi_start_w = rois[:, 1] * spatial_scale - offset  # [K]
    roi_start_h = rois[:, 2] * spatial_scale - offset  # [K]
    roi_end_w = rois[:, 3] * spatial_scale - offset  # [K]
    roi_end_h = rois[:, 4] * spatial_scale - offset  # [K]

    roi_width = roi_end_w - roi_start_w  # [K]
    roi_height = roi_end_h - roi_start_h  # [K]
    if not aligned:
        roi_width = torch.clamp(roi_width, min=1.0)  # [K]
        roi_height = torch.clamp(roi_height, min=1.0)  # [K]

    bin_size_h = roi_height / pooled_height  # [K]
    bin_size_w = roi_width / pooled_width  # [K]

    exact_sampling = sampling_ratio > 0

    roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height)  # scalar or [K]
    roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width)  # scalar or [K]

    """
    iy, ix = dims(2)
    """

    if exact_sampling:
        count = max(roi_bin_grid_h * roi_bin_grid_w, 1)  # scalar
        iy = torch.arange(roi_bin_grid_h, device=input.device)  # [IY]
        ix = torch.arange(roi_bin_grid_w, device=input.device)  # [IX]
        ymask = None
        xmask = None
    else:
        count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)  # [K]
        # When doing adaptive sampling, the number of samples we need to do
        # is data-dependent based on how big the ROIs are.  This is a bit
        # awkward because first-class dims can't actually handle this.
        # So instead, we inefficiently suppose that we needed to sample ALL
        # the points and mask out things that turned out to be unnecessary
        iy = torch.arange(height, device=input.device)  # [IY]
        ix = torch.arange(width, device=input.device)  # [IX]
        ymask = iy[None, :] < roi_bin_grid_h[:, None]  # [K, IY]
        xmask = ix[None, :] < roi_bin_grid_w[:, None]  # [K, IX]

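    # Broadcast a per-RoI [K] quantity against the pooled (PH/PW) and sample
    # (IY/IX) dimensions added below.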
    def from_K(t):
        return t[:, None, None]

    y = (
        from_K(roi_start_h)
        + ph[None, :, None] * from_K(bin_size_h)
        + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
    )  # [K, PH, IY]
    x = (
        from_K(roi_start_w)
        + pw[None, :, None] * from_K(bin_size_w)
        + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
    )  # [K, PW, IX]
    val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]

    # Mask out samples that weren't actually adaptively needed
    if not exact_sampling:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    output = val.sum((-1, -2))  # remove IY, IX ~> [K, C, PH, PW]
    if isinstance(count, torch.Tensor):
        output /= count[:, None, None, None]
    else:
        output /= count

    output = output.to(orig_dtype)

    return output


@torch.fx.wrap
def roi_align(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
    sampling_ratio: int = -1,
    aligned: bool = False,
) -> Tensor:
    """
    Performs the Region of Interest (RoI) Align operator with average pooling, as described in Mask R-CNN.

    Args:
        input (Tensor[N, C, H, W]): The input tensor, i.e. a batch with ``N`` elements. Each element
            contains ``C`` feature maps of dimensions ``H x W``.
            If the tensor is quantized, we expect a batch size of ``N == 1``.
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
            format where the regions will be taken from.
            The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            If a single Tensor is passed, then the first column should
            contain the index of the corresponding element in the batch, i.e. a number in ``[0, N - 1]``.
            If a list of Tensors is passed, then each Tensor will correspond to the boxes for an element i
            in the batch.
        output_size (int or Tuple[int, int]): the size of the output (in bins or pixels) after the pooling
            is performed, as (height, width).
        spatial_scale (float): a scaling factor that maps the box coordinates to
            the input coordinates. For example, if your boxes are defined on the scale
            of a 224x224 image and your input is a 112x112 feature map (resulting from a 0.5x scaling of
            the original image), you'll want to set this to 0.5. Default: 1.0
        sampling_ratio (int): number of sampling points in the interpolation grid
            used to compute the output value of each pooled output bin. If > 0,
            then exactly ``sampling_ratio x sampling_ratio`` sampling points per bin are used. If
            <= 0, then an adaptive number of grid points are used (computed as
            ``ceil(roi_width / output_width)``, and likewise for height). Default: -1
        aligned (bool): If False, use the legacy implementation.
            If True, pixel shift the box coordinates by -0.5 for a better alignment with the two
            neighboring pixel indices. This version is used in Detectron2.

    Returns:
        Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
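
    Example (an illustrative sketch; shapes and values are arbitrary)::

        >>> import torch
        >>> from torchvision.ops import roi_align
        >>> feats = torch.rand(1, 3, 8, 8)  # [N, C, H, W]
        >>> # a single box covering the whole map; column 0 is the batch index
        >>> boxes = torch.tensor([[0.0, 0.0, 0.0, 7.0, 7.0]])
        >>> roi_align(feats, boxes, output_size=(2, 2)).shape
        torch.Size([1, 3, 2, 2])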
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_align)
    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    if not torch.jit.is_scripting():
        if (
            not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
        ) and is_compile_supported(input.device.type):
            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
    _assert_has_ops()
    return torch.ops.torchvision.roi_align(
        input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
    )


class RoIAlign(nn.Module):
    """
    See :func:`roi_align`.
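
    A minimal usage sketch (illustrative shapes)::

        >>> import torch
        >>> pool = RoIAlign(output_size=(2, 2), spatial_scale=1.0, sampling_ratio=2)
        >>> feats = torch.rand(1, 3, 8, 8)
        >>> boxes = torch.tensor([[0.0, 0.0, 0.0, 7.0, 7.0]])
        >>> pool(feats, boxes).shape
        torch.Size([1, 3, 2, 2])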
    """

    def __init__(
        self,
        output_size: BroadcastingList2[int],
        spatial_scale: float,
        sampling_ratio: int,
        aligned: bool = False,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

    def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
        return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"output_size={self.output_size}"
            f", spatial_scale={self.spatial_scale}"
            f", sampling_ratio={self.sampling_ratio}"
            f", aligned={self.aligned}"
            f")"
        )
        return s