# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn, Tensor

import torchvision
from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area

from typing import Optional, List, Dict, Tuple, Union


# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
@torch.jit.unused
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
    first_result = unmerged_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros((levels.size(0), first_result.size(1),
                       first_result.size(2), first_result.size(3)),
                      dtype=dtype, device=device)
    for level in range(len(unmerged_results)):
        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
        index = index.expand(index.size(0),
                             unmerged_results[level].size(1),
                             unmerged_results[level].size(2),
                             unmerged_results[level].size(3))
        res = res.scatter(0, index, unmerged_results[level])
    return res
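    # Example: with levels = tensor([0, 1, 0]), unmerged_results[0] holds the pooled
    # outputs of RoIs 0 and 2 and unmerged_results[1] the output of RoI 1; the scatter
    # above writes them to rows 0, 2 and 1 of ``res``, so the merged result is ordered
    # by RoI index again.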


# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(
    k_min: int,
    k_max: int,
    canonical_scale: int = 224,
    canonical_level: int = 4,
    eps: float = 1e-6,
):
    return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)


class LevelMapper(object):
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.

    Args:
        k_min (int): the finest (highest-resolution) level of the feature pyramid
        k_max (int): the coarsest (lowest-resolution) level of the feature pyramid
        canonical_scale (int): canonical box size, ``224`` in eq. 1 of the FPN paper
        canonical_level (int): pyramid level on which a box of size
            ``canonical_scale x canonical_scale`` should be pooled, ``k0 = 4`` in eq. 1
        eps (float): small constant for numerical stability
    """

    def __init__(
        self,
        k_min: int,
        k_max: int,
        canonical_scale: int = 224,
        canonical_level: int = 4,
        eps: float = 1e-6,
    ):
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists: List[Tensor]) -> Tensor:
        """
        Args:
            boxlists (List[Tensor]): boxes for each image, in ``(x1, y1, x2, y2)`` format
        """
        # Compute level ids
        s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)
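        # Worked example of eq. 1: a RoI of size 112 x 112 gives s = 112, so
        # target_lvl = floor(4 + log2(112 / 224)) = 3; with typical FPN values
        # k_min = 2 and k_max = 5 the returned index is 3 - 2 = 1, i.e. the
        # second feature map in the pyramid.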


class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics specified in eq. 1
    of the `Feature Pyramid Network paper <https://arxiv.org/abs/1612.03144>`_.
    The keyword-only parameters ``canonical_scale`` and ``canonical_level``
    correspond respectively to ``224`` and ``k0=4`` in eq. 1, and
    have the following meaning: ``canonical_level`` is the target level of the pyramid from
    which to pool a region of interest with ``w x h = canonical_scale x canonical_scale``.

    Args:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (int or Tuple[int, int] or List[int]): output size for the pooled region
        sampling_ratio (int): sampling ratio for ROIAlign
        canonical_scale (int, optional): canonical_scale for LevelMapper
        canonical_level (int, optional): canonical_level for LevelMapper

    Examples::

        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> from collections import OrderedDict
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        torch.Size([6, 5, 3, 3])

    """

    __annotations__ = {
        'scales': Optional[List[float]],
        'map_levels': Optional[LevelMapper]
    }
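    # The annotations above let TorchScript type the attributes that start out as
    # None and are only filled in lazily by setup_scales().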

    def __init__(
        self,
        featmap_names: List[str],
        output_size: Union[int, Tuple[int], List[int]],
        sampling_ratio: int,
        *,
        canonical_scale: int = 224,
        canonical_level: int = 4,
    ):
        super(MultiScaleRoIAlign, self).__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
        self.scales = None
        self.map_levels = None
        self.canonical_scale = canonical_scale
        self.canonical_level = canonical_level

    def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
        concat_boxes = torch.cat(boxes, dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = torch.cat(
            [
                torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois
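        # Example: two images with 2 and 1 boxes respectively yield a (3, 5) tensor
        # whose rows are (batch_index, x1, y1, x2, y2), with batch indices 0, 0, 1.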

    def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
        # assumption: the scale is of the form 2 ** (-k), with k integer
        size = feature.shape[-2:]
        possible_scales: List[float] = []
        for s1, s2 in zip(size, original_size):
            approx_scale = float(s1) / float(s2)
            scale = 2 ** float(torch.tensor(approx_scale).log2().round())
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]
        return possible_scales[0]
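        # Example: a 50 x 50 feature map computed from an 800 x 800 input gives
        # approx_scale = 0.0625, which rounds to 2 ** -4, i.e. a stride-16 level.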

    def setup_scales(
        self,
        features: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> None:
        assert len(image_shapes) != 0
        max_x = 0
        max_y = 0
        for shape in image_shapes:
            max_x = max(shape[0], max_x)
            max_y = max(shape[1], max_y)
        original_input_shape = (max_x, max_y)

        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.scales = scales
        self.map_levels = initLevelMapper(
            int(lvl_min),
            int(lvl_max),
            canonical_scale=self.canonical_scale,
            canonical_level=self.canonical_level,
        )
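        # Example: with FPN feature maps at strides 4, 8, 16 and 32, scales is
        # [0.25, 0.125, 0.0625, 0.03125], so lvl_min = 2 and lvl_max = 5 and the
        # mapper assigns each RoI to one of these four levels.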

    def forward(
        self,
        x: Dict[str, Tensor],
        boxes: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tensor:
        """
        Args:
            x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference. The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.
        Returns:
            result (Tensor)
        """
        x_filtered = []
        for k, v in x.items():
            if k in self.featmap_names:
                x_filtered.append(v)
        num_levels = len(x_filtered)
        rois = self.convert_to_roi_format(boxes)
        if self.scales is None:
            self.setup_scales(x_filtered, image_shapes)

        scales = self.scales
        assert scales is not None

        if num_levels == 1:
            return roi_align(
                x_filtered[0], rois,
                output_size=self.output_size,
                spatial_scale=scales[0],
                sampling_ratio=self.sampling_ratio
            )

        mapper = self.map_levels
        assert mapper is not None

        levels = mapper(boxes)

        num_rois = len(rois)
        num_channels = x_filtered[0].shape[1]

        dtype, device = x_filtered[0].dtype, x_filtered[0].device
        result = torch.zeros(
            (num_rois, num_channels,) + self.output_size,
            dtype=dtype,
            device=device,
        )

        tracing_results = []
        for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
            idx_in_level = torch.where(levels == level)[0]
            rois_per_level = rois[idx_in_level]

            result_idx_in_level = roi_align(
                per_level_feature, rois_per_level,
                output_size=self.output_size,
                spatial_scale=scale, sampling_ratio=self.sampling_ratio)

            if torchvision._is_tracing():
                tracing_results.append(result_idx_in_level.to(dtype))
            else:
                # result and result_idx_in_level's dtypes are based on dtypes of different
                # elements in x_filtered.  x_filtered contains tensors output by different
                # layers.  When autocast is active, it may choose different dtypes for
                # different layers' outputs.  Therefore, we defensively match result's dtype
                # before copying elements from result_idx_in_level in the following op.
                # We need to cast manually (can't rely on autocast to cast for us) because
                # the op acts on result in-place, and autocast only affects out-of-place ops.
                result[idx_in_level] = result_idx_in_level.to(result.dtype)

        if torchvision._is_tracing():
            result = _onnx_merge_levels(levels, tracing_results)

        return result

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
                f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})")