# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn, Tensor

from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area

from torch.jit.annotations import Optional, List, Dict, Tuple
import torchvision


# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
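# (Each entry of unmerged_results holds the RoIs pooled from one FPN level;
# the scatter below writes them back at their positions in the original
# RoI ordering given by `levels`.)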
@torch.jit.unused
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
    first_result = unmerged_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros((levels.size(0), first_result.size(1),
                       first_result.size(2), first_result.size(3)),
                      dtype=dtype, device=device)
    for level in range(len(unmerged_results)):
        index = (levels == level).nonzero().view(-1, 1, 1, 1)
        index = index.expand(index.size(0),
                             unmerged_results[level].size(1),
                             unmerged_results[level].size(2),
                             unmerged_results[level].size(3))
        res = res.scatter(0, index, unmerged_results[level])
    return res


# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(
    k_min: int,
    k_max: int,
    canonical_scale: int = 224,
    canonical_level: int = 4,
    eps: float = 1e-6,
):
    return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)


class LevelMapper(object):
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.

    Arguments:
        k_min (int): index of the finest-resolution (highest-resolution) FPN level
        k_max (int): index of the coarsest-resolution (lowest-resolution) FPN level
        canonical_scale (int): canonical box size s0 from the FPN paper (default 224)
        canonical_level (int): FPN level lvl0 onto which a box of canonical_scale is mapped (default 4)
        eps (float): small value added to the level computation for numerical stability
    """

    def __init__(
        self,
        k_min: int,
        k_max: int,
        canonical_scale: int = 224,
        canonical_level: int = 4,
        eps: float = 1e-6,
    ):
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists: List[Tensor]) -> Tensor:
        """
        Arguments:
            boxlists (list[BoxList])
        """
        # Compute level ids
        s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
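        # Worked example with the defaults (s0=224, lvl0=4): a box of area
        # 112 * 112 gives s = 112, so 4 + log2(112 / 224) = 3; the result is
        # then clamped to [k_min, k_max] and offset by k_min in the return below.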
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)


class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics present in the FPN paper.

    Arguments:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
        sampling_ratio (int): sampling ratio for ROIAlign

    Examples::

        >>> from collections import OrderedDict
        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        torch.Size([6, 5, 3, 3])

    """

    __annotations__ = {
        'scales': Optional[List[float]],
        'map_levels': Optional[LevelMapper]
    }

    def __init__(
        self,
        featmap_names: List[str],
        output_size: List[int],
        sampling_ratio: int,
    ):
        super(MultiScaleRoIAlign, self).__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
        self.scales = None
        self.map_levels = None

    def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
        concat_boxes = torch.cat(boxes, dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = torch.cat(
            [
                torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
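        # each row of rois is (batch_index, x1, y1, x2, y2), the K x 5 format
        # that torchvision.ops.roi_align expects when rois is a single Tensor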
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

    def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
        # assumption: the scale is of the form 2 ** (-k), with k integer
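        # e.g. an 800 x 1216 image that produced a 200 x 304 feature map gives
        # approx_scale = 0.25 on both dims, which rounds to a scale of 2 ** -2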
        size = feature.shape[-2:]
        possible_scales = torch.jit.annotate(List[float], [])
        for s1, s2 in zip(size, original_size):
            approx_scale = float(s1) / float(s2)
            scale = 2 ** float(torch.tensor(approx_scale).log2().round())
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]
        return possible_scales[0]

    def setup_scales(
        self,
        features: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> None:
        assert len(image_shapes) != 0
        max_x = 0
        max_y = 0
        for shape in image_shapes:
            max_x = max(shape[0], max_x)
            max_y = max(shape[1], max_y)
        original_input_shape = (max_x, max_y)

        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
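        # e.g. FPN feature maps with scales [1/4, 1/8, 1/16, 1/32] give
        # lvl_min = 2 and lvl_max = 5, so RoIs are assigned to levels 2 through 5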
        self.scales = scales
        self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))

    def forward(
        self,
        x: Dict[str, Tensor],
        boxes: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tensor:
        """
        Arguments:
            x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.
        Returns:
            result (Tensor)
        """
        x_filtered = []
        for k, v in x.items():
            if k in self.featmap_names:
                x_filtered.append(v)
        num_levels = len(x_filtered)
        rois = self.convert_to_roi_format(boxes)
        if self.scales is None:
            self.setup_scales(x_filtered, image_shapes)

        scales = self.scales
        assert scales is not None

        if num_levels == 1:
            return roi_align(
                x_filtered[0], rois,
                output_size=self.output_size,
                spatial_scale=scales[0],
                sampling_ratio=self.sampling_ratio
            )

        mapper = self.map_levels
        assert mapper is not None

        levels = mapper(boxes)

        num_rois = len(rois)
        num_channels = x_filtered[0].shape[1]

        dtype, device = x_filtered[0].dtype, x_filtered[0].device
        result = torch.zeros(
            (num_rois, num_channels,) + self.output_size,
            dtype=dtype,
            device=device,
        )

        tracing_results = []
        for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
            idx_in_level = torch.nonzero(levels == level).squeeze(1)
            rois_per_level = rois[idx_in_level]

            result_idx_in_level = roi_align(
                per_level_feature, rois_per_level,
                output_size=self.output_size,
                spatial_scale=scale, sampling_ratio=self.sampling_ratio)

            if torchvision._is_tracing():
                tracing_results.append(result_idx_in_level.to(dtype))
            else:
                # result and result_idx_in_level's dtypes are based on dtypes of different
                # elements in x_filtered.  x_filtered contains tensors output by different
                # layers.  When autocast is active, it may choose different dtypes for
                # different layers' outputs.  Therefore, we defensively match result's dtype
                # before copying elements from result_idx_in_level in the following op.
                # We need to cast manually (can't rely on autocast to cast for us) because
                # the op acts on result in-place, and autocast only affects out-of-place ops.
                result[idx_in_level] = result_idx_in_level.to(result.dtype)

        if torchvision._is_tracing():
            result = _onnx_merge_levels(levels, tracing_results)

        return result