# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn, Tensor

from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area

from torch.jit.annotations import Optional, List, Dict, Tuple
import torchvision


# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
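# A rough sketch of the merge: if levels is tensor([0, 1, 0]) and
# unmerged_results holds the pooled outputs for levels 0 and 1, then
# result[0] and result[2] are filled from unmerged_results[0] and result[1]
# from unmerged_results[1], restoring the original RoI ordering.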
@torch.jit.unused
def _onnx_merge_levels(levels, unmerged_results):
    # type: (Tensor, List[Tensor]) -> Tensor
    first_result = unmerged_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros((levels.size(0), first_result.size(1),
                       first_result.size(2), first_result.size(3)),
                      dtype=dtype, device=device)
    for level in range(len(unmerged_results)):
        index = (levels == level).nonzero().view(-1, 1, 1, 1)
        index = index.expand(index.size(0),
                             unmerged_results[level].size(1),
                             unmerged_results[level].size(2),
                             unmerged_results[level].size(3))
        res = res.scatter(0, index, unmerged_results[level])
    return res


# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
    # type: (int, int, int, int, float) -> LevelMapper
    return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)


class LevelMapper(object):
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.

    Arguments:
        k_min (int): index of the finest-resolution FPN level
        k_max (int): index of the coarsest-resolution FPN level
        canonical_scale (int): canonical box size (224 in the FPN paper)
        canonical_level (int): level to which a box of canonical_scale maps (4 in the FPN paper)
        eps (float): small constant guarding the level computation against numerical issues
    """

    def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
        # type: (int, int, int, int, float) -> None
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists):
        # type: (List[Tensor]) -> Tensor
        """
        Arguments:
            boxlists (list[Tensor])
        """
        # Compute level ids
        s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
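        # As a worked example: a box whose sqrt(area) equals the canonical
        # scale (s == s0 == 224) is assigned the canonical level lvl0 (4),
        # while a box with a quarter of that area (s == 112) lands one level
        # lower, before clamping to [k_min, k_max]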
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)


class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics present in the FPN paper.

    Arguments:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
        sampling_ratio (int): sampling ratio for ROIAlign

    Examples::

        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        torch.Size([6, 5, 3, 3])

    """

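    # scales and map_levels are filled in lazily by setup_scales(); the
    # explicit annotations tell TorchScript their Optional types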
    __annotations__ = {
        'scales': Optional[List[float]],
        'map_levels': Optional[LevelMapper]
    }

    def __init__(self, featmap_names, output_size, sampling_ratio):
        super(MultiScaleRoIAlign, self).__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
        self.scales = None
        self.map_levels = None

    def convert_to_roi_format(self, boxes):
        # type: (List[Tensor]) -> Tensor
        concat_boxes = torch.cat(boxes, dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = torch.cat(
            [
                torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
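        # the resulting rois tensor has shape [K, 5] in (batch_index, x1, y1, x2, y2)
        # format, which is what roi_align expects when given a single Tensor of boxes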
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

    def infer_scale(self, feature, original_size):
        # type: (Tensor, List[int]) -> float
        # assumption: the scale is of the form 2 ** (-k), with k integer
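        # e.g. a 64x64 feature map computed from a 512x512 image gives
        # approx_scale = 0.125, which is already an exact power of two (2 ** -3)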
        size = feature.shape[-2:]
        possible_scales = torch.jit.annotate(List[float], [])
        for s1, s2 in zip(size, original_size):
            approx_scale = float(s1) / float(s2)
            scale = 2 ** float(torch.tensor(approx_scale).log2().round())
            possible_scales.append(scale)
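        # both spatial dimensions are expected to be downscaled by the same factor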
        assert possible_scales[0] == possible_scales[1]
        return possible_scales[0]

    def setup_scales(self, features, image_shapes):
        # type: (List[Tensor], List[Tuple[int, int]]) -> None
        assert len(image_shapes) != 0
        max_x = 0
        max_y = 0
        for shape in image_shapes:
            max_x = max(shape[0], max_x)
            max_y = max(shape[1], max_y)
        original_input_shape = (max_x, max_y)

        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
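        # e.g. scales of [1/4, 1/8, 1/16, 1/32] yield lvl_min = 2 and lvl_max = 5,
        # i.e. the mapper covers FPN levels P2..P5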
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.scales = scales
        self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))

    def forward(self, x, boxes, image_shapes):
        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> Tensor
        """
        Arguments:
            x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.
        Returns:
            result (Tensor)
        """
        x_filtered = []
        for k, v in x.items():
            if k in self.featmap_names:
                x_filtered.append(v)
        num_levels = len(x_filtered)
        rois = self.convert_to_roi_format(boxes)
        if self.scales is None:
            self.setup_scales(x_filtered, image_shapes)

        scales = self.scales
        assert scales is not None

        if num_levels == 1:
            return roi_align(
                x_filtered[0], rois,
                output_size=self.output_size,
                spatial_scale=scales[0],
                sampling_ratio=self.sampling_ratio
            )

        mapper = self.map_levels
        assert mapper is not None

        levels = mapper(boxes)

        num_rois = len(rois)
        num_channels = x_filtered[0].shape[1]

        dtype, device = x_filtered[0].dtype, x_filtered[0].device
        result = torch.zeros(
            (num_rois, num_channels,) + self.output_size,
            dtype=dtype,
            device=device,
        )

        tracing_results = []
        for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
            idx_in_level = torch.nonzero(levels == level).squeeze(1)
            rois_per_level = rois[idx_in_level]

            result_idx_in_level = roi_align(
                per_level_feature, rois_per_level,
                output_size=self.output_size,
                spatial_scale=scale, sampling_ratio=self.sampling_ratio)

            if torchvision._is_tracing():
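                # indexed assignment into result is not supported by ONNX tracing,
                # so collect the per-level outputs here and merge them with
                # _onnx_merge_levels once every level has been processed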
                tracing_results.append(result_idx_in_level.to(dtype))
            else:
                result[idx_in_level] = result_idx_in_level

        if torchvision._is_tracing():
            result = _onnx_merge_levels(levels, tracing_results)

        return result