from __future__ import division

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn, Tensor

from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area

from torch.jit.annotations import Optional, List, Dict, Tuple
import torchvision

# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
@torch.jit.unused
def _onnx_merge_levels(levels, unmerged_results):
    # type: (Tensor, List[Tensor]) -> Tensor
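    # Allocate an output with one row per RoI, then scatter each level's pooled
    # features back to the row indices of the RoIs assigned to that level, so
    # the merged result keeps the original RoI ordering.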
    first_result = unmerged_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros((levels.size(0), first_result.size(1),
                       first_result.size(2), first_result.size(3)),
                      dtype=dtype, device=device)
    for level in range(len(unmerged_results)):
        index = (levels == level).nonzero().view(-1, 1, 1, 1)
        index = index.expand(index.size(0),
                             unmerged_results[level].size(1),
                             unmerged_results[level].size(2),
                             unmerged_results[level].size(3))
        res = res.scatter(0, index, unmerged_results[level])
    return res


# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
    # type: (int, int, int, int, float)
    return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)


@torch.jit.script
class LevelMapper(object):
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.

    Arguments:
        k_min (int)
        k_max (int)
        canonical_scale (int)
        canonical_level (int)
        eps (float)
    """

    def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
        # type: (int, int, int, int, float)
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists):
        # type: (List[Tensor])
        """
        Arguments:
            boxlists (list[BoxList])
        """
        # Compute level ids
        s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
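        # target_lvl = floor(lvl0 + log2(sqrt(box area) / s0)), then clamped to
        # the [k_min, k_max] range of available feature levels below.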
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)


class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics present in the FPN paper.

    Arguments:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
        sampling_ratio (int): sampling ratio for ROIAlign

    Examples::

        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        torch.Size([6, 5, 3, 3])

    """

    __annotations__ = {
        'scales': Optional[List[float]],
        'map_levels': Optional[LevelMapper]
    }

    def __init__(self, featmap_names, output_size, sampling_ratio):
        super(MultiScaleRoIAlign, self).__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
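        # scales and map_levels are inferred lazily from the feature maps and
        # image sizes on the first forward pass (see setup_scales below)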
        self.scales = None
        self.map_levels = None

    def convert_to_roi_format(self, boxes):
        # type: (List[Tensor])
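        # roi_align expects RoIs as a single (K, 5) tensor whose rows are
        # (batch_index, x1, y1, x2, y2); build the batch-index column here.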
        concat_boxes = torch.cat(boxes, dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = torch.cat(
            [
                torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

    def infer_scale(self, feature, original_size):
        # type: (Tensor, List[int])
        # assumption: the scale is of the form 2 ** (-k), with k integer
        size = feature.shape[-2:]
        possible_scales = torch.jit.annotate(List[float], [])
        for s1, s2 in zip(size, original_size):
            approx_scale = float(s1) / s2
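            # snap the ratio to the closest power of two, per the assumption above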
            scale = 2 ** float(torch.tensor(approx_scale).log2().round())
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]
        return possible_scales[0]

    def setup_scales(self, features, image_shapes):
        # type: (List[Tensor], List[Tuple[int, int]])
        assert len(image_shapes) != 0
        max_x = 0
        max_y = 0
        for shape in image_shapes:
            max_x = max(shape[0], max_x)
            max_y = max(shape[1], max_y)
        original_input_shape = (max_x, max_y)

        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.scales = scales
        self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))

    def forward(self, x, boxes, image_shapes):
        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]])
        """
        Arguments:
            x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.
        Returns:
            result (Tensor)
        """
        x_filtered = []
        for k, v in x.items():
            if k in self.featmap_names:
                x_filtered.append(v)
        num_levels = len(x_filtered)
        rois = self.convert_to_roi_format(boxes)
        if self.scales is None:
            self.setup_scales(x_filtered, image_shapes)

        scales = self.scales
        assert scales is not None

        if num_levels == 1:
            return roi_align(
                x_filtered[0], rois,
                output_size=self.output_size,
                spatial_scale=scales[0],
                sampling_ratio=self.sampling_ratio
            )

        mapper = self.map_levels
        assert mapper is not None

        levels = mapper(boxes)

        num_rois = len(rois)
        num_channels = x_filtered[0].shape[1]

        dtype, device = x_filtered[0].dtype, x_filtered[0].device
        result = torch.zeros(
            (num_rois, num_channels,) + self.output_size,
            dtype=dtype,
            device=device,
        )

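        # pool each RoI from the single feature level it was assigned to and
        # write the output back at the RoI's original position in `result`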
        tracing_results = []
        for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
            idx_in_level = torch.nonzero(levels == level).squeeze(1)
            rois_per_level = rois[idx_in_level]

            result_idx_in_level = roi_align(
                per_level_feature, rois_per_level,
                output_size=self.output_size,
                spatial_scale=scale, sampling_ratio=self.sampling_ratio)

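            # ONNX tracing does not support copying result_idx_in_level into
            # result[idx_in_level], so collect the per-level outputs and merge
            # them afterwards with _onnx_merge_levels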
            if torchvision._is_tracing():
                tracing_results.append(result_idx_in_level.to(dtype))
            else:
                result[idx_in_level] = result_idx_in_level

        if torchvision._is_tracing():
            result = _onnx_merge_levels(levels, tracing_results)

        return result