from typing import Optional, List, Dict, Tuple, Union

import torch
import torchvision
from torch import nn, Tensor
from torchvision.ops.boxes import box_area

from ..utils import _log_api_usage_once
from .roi_align import roi_align


# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
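# For illustration (assumed example values, not from the original file): if levels is
# tensor([1, 0, 1]) and unmerged_results holds the pooled outputs for level 0 (one RoI)
# and level 1 (two RoIs), the scatter below writes res[1] from level 0 and
# res[0], res[2] from level 1, restoring the original RoI ordering.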
@torch.jit.unused
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
    first_result = unmerged_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros(
        (levels.size(0), first_result.size(1), first_result.size(2), first_result.size(3)), dtype=dtype, device=device
    )
    for level in range(len(unmerged_results)):
        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
        index = index.expand(
            index.size(0),
            unmerged_results[level].size(1),
            unmerged_results[level].size(2),
            unmerged_results[level].size(3),
        )
        res = res.scatter(0, index, unmerged_results[level])
    return res


# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(
    k_min: int,
    k_max: int,
    canonical_scale: int = 224,
    canonical_level: int = 4,
    eps: float = 1e-6,
):
    return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)


class LevelMapper:
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.

    Args:
        k_min (int): finest FPN level that RoIs can be assigned to
        k_max (int): coarsest FPN level that RoIs can be assigned to
        canonical_scale (int): canonical box scale from eq. 1 of the FPN paper (224)
        canonical_level (int): canonical level k0 from eq. 1 of the FPN paper (4)
        eps (float): small value added to the computed level before flooring, for numerical stability
    """

    def __init__(
        self,
        k_min: int,
        k_max: int,
        canonical_scale: int = 224,
        canonical_level: int = 4,
        eps: float = 1e-6,
    ):
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists: List[Tensor]) -> Tensor:
        """
        Args:
            boxlists (List[Tensor[N, 4]]): boxes for each image, in (x1, y1, x2, y2) format
        """
        # Compute level ids
        s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
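        # Worked example with the default s0=224, lvl0=4 (illustrative, not from the
        # original code): a 224x224 box gives 4 + log2(1) = 4, while a 112x112 box
        # gives 4 + log2(0.5) = 3, i.e. it pools from one level finer. The value is
        # then clamped to [k_min, k_max] and shifted so that k_min maps to index 0.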
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)


class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics specified in eq. 1
    of the `Feature Pyramid Network paper <https://arxiv.org/abs/1612.03144>`_.
    The keyword-only parameters ``canonical_scale`` and ``canonical_level``
    correspond respectively to ``224`` and ``k0=4`` in eq. 1, and
    have the following meaning: ``canonical_level`` is the target level of the pyramid from
    which to pool a region of interest with ``w x h = canonical_scale x canonical_scale``.

    Args:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
        sampling_ratio (int): sampling ratio for ROIAlign
        canonical_scale (int, optional): canonical_scale for LevelMapper
        canonical_level (int, optional): canonical_level for LevelMapper

    Examples::

        >>> from collections import OrderedDict
        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        torch.Size([6, 5, 3, 3])

    """

    __annotations__ = {"scales": Optional[List[float]], "map_levels": Optional[LevelMapper]}

    def __init__(
        self,
        featmap_names: List[str],
        output_size: Union[int, Tuple[int], List[int]],
        sampling_ratio: int,
        *,
        canonical_scale: int = 224,
        canonical_level: int = 4,
    ):
        super().__init__()
        _log_api_usage_once(self)
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
        self.scales = None
        self.map_levels = None
        self.canonical_scale = canonical_scale
        self.canonical_level = canonical_level

    def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
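        # Packs the per-image box tensors into the [K, 5] layout expected by roi_align:
        # column 0 is the index of the image within the batch, columns 1-4 are
        # (x1, y1, x2, y2). Illustrative example: two images with 2 and 1 boxes produce
        # rows with image ids 0, 0, 1.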
        concat_boxes = torch.cat(boxes, dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = torch.cat(
            [
                torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

    def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
        # assumption: the scale is of the form 2 ** (-k), with k integer
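        # Worked example (illustrative, not from the original file): a 64x64 feature
        # map computed from a 512x512 image gives approx_scale = 0.125, whose log2
        # rounds to -3, so the inferred scale is 2 ** -3 = 0.125.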
        size = feature.shape[-2:]
        possible_scales: List[float] = []
        for s1, s2 in zip(size, original_size):
            approx_scale = float(s1) / float(s2)
            scale = 2 ** float(torch.tensor(approx_scale).log2().round())
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]
        return possible_scales[0]

    def setup_scales(
        self,
        features: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> None:
        assert len(image_shapes) != 0
        max_x = 0
        max_y = 0
        for shape in image_shapes:
            max_x = max(shape[0], max_x)
            max_y = max(shape[1], max_y)
        original_input_shape = (max_x, max_y)

        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
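        # For example (assuming a typical FPN setup, not from the original file):
        # scales of [1/4, 1/8, 1/16, 1/32] yield lvl_min = 2 and lvl_max = 5,
        # i.e. pyramid levels P2 through P5.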
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.scales = scales
        self.map_levels = initLevelMapper(
            int(lvl_min),
            int(lvl_max),
            canonical_scale=self.canonical_scale,
            canonical_level=self.canonical_level,
        )

    def forward(
        self,
        x: Dict[str, Tensor],
        boxes: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tensor:
        """
        Args:
            x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference. The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.
        Returns:
            result (Tensor)
        """
        x_filtered = []
        for k, v in x.items():
            if k in self.featmap_names:
                x_filtered.append(v)
        num_levels = len(x_filtered)
        rois = self.convert_to_roi_format(boxes)
        if self.scales is None:
            self.setup_scales(x_filtered, image_shapes)

        scales = self.scales
        assert scales is not None

        if num_levels == 1:
            return roi_align(
                x_filtered[0],
                rois,
                output_size=self.output_size,
                spatial_scale=scales[0],
                sampling_ratio=self.sampling_ratio,
            )

        mapper = self.map_levels
        assert mapper is not None

        levels = mapper(boxes)

        num_rois = len(rois)
        num_channels = x_filtered[0].shape[1]

        dtype, device = x_filtered[0].dtype, x_filtered[0].device
        result = torch.zeros(
            (
                num_rois,
                num_channels,
            )
            + self.output_size,
            dtype=dtype,
            device=device,
        )

        tracing_results = []
        for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
            idx_in_level = torch.where(levels == level)[0]
            rois_per_level = rois[idx_in_level]

            result_idx_in_level = roi_align(
                per_level_feature,
                rois_per_level,
                output_size=self.output_size,
                spatial_scale=scale,
                sampling_ratio=self.sampling_ratio,
            )

            if torchvision._is_tracing():
                tracing_results.append(result_idx_in_level.to(dtype))
            else:
                # result and result_idx_in_level's dtypes are based on dtypes of different
                # elements in x_filtered.  x_filtered contains tensors output by different
                # layers.  When autocast is active, it may choose different dtypes for
                # different layers' outputs.  Therefore, we defensively match result's dtype
                # before copying elements from result_idx_in_level in the following op.
                # We need to cast manually (can't rely on autocast to cast for us) because
                # the op acts on result in-place, and autocast only affects out-of-place ops.
                result[idx_in_level] = result_idx_in_level.to(result.dtype)

        if torchvision._is_tracing():
            result = _onnx_merge_levels(levels, tracing_results)

        return result

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
            f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})"
        )