# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import Union

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area

from torch.jit.annotations import Optional, List, Dict, Tuple
import torchvision


# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an ONNX-traceable implementation that
# merges the per-level results into the right indices
@torch.jit.unused
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
    first_result = unmerged_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros((levels.size(0), first_result.size(1),
                       first_result.size(2), first_result.size(3)),
                      dtype=dtype, device=device)
    for level in range(len(unmerged_results)):
        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
        index = index.expand(index.size(0),
                             unmerged_results[level].size(1),
                             unmerged_results[level].size(2),
                             unmerged_results[level].size(3))
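        # scatter copies this level's pooled outputs back to their original
        # roi positions in the merged result; an ONNX-traceable alternative
        # to indexed assignment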
        res = res.scatter(0, index, unmerged_results[level])
    return res


# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(
    k_min: int,
    k_max: int,
    canonical_scale: int = 224,
    canonical_level: int = 4,
    eps: float = 1e-6,
):
    return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)


class LevelMapper(object):
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.

    Arguments:
        k_min (int): index of the finest-resolution FPN level (e.g. 2 for P2)
        k_max (int): index of the coarsest-resolution FPN level (e.g. 5 for P5)
        canonical_scale (int): canonical box size, in pixels (224 in the FPN paper)
        canonical_level (int): level to which a box of canonical_scale is mapped (4 in the FPN paper)
        eps (float): small value added to the computed level for numerical stability
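
    A small illustrative example (an assumed FPN setup with levels P2-P5;
    the returned indices are relative to k_min)::

        >>> mapper = initLevelMapper(k_min=2, k_max=5)
        >>> boxes = torch.tensor([[0., 0., 224., 224.],
        ...                       [0., 0., 112., 112.],
        ...                       [0., 0., 448., 448.]])
        >>> mapper([boxes])
        tensor([2, 1, 3])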
    """

    def __init__(
        self,
        k_min: int,
        k_max: int,
        canonical_scale: int = 224,
        canonical_level: int = 4,
        eps: float = 1e-6,
    ):
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists: List[Tensor]) -> Tensor:
        """
        Arguments:
            boxlists (List[Tensor]): boxes for each image, in (x1, y1, x2, y2) format
        """
        # Compute level ids
        s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
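        # target level: k = floor(lvl0 + log2(sqrt(box area) / s0)),
        # then clamped to [k_min, k_max]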
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)


class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics present in the FPN paper.

    Arguments:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (int or Tuple[int, int] or List[int]): output size for the pooled region
        sampling_ratio (int): sampling ratio for ROIAlign

    Examples::

        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        torch.Size([6, 5, 3, 3])

    """

    __annotations__ = {
        'scales': Optional[List[float]],
        'map_levels': Optional[LevelMapper]
    }

    def __init__(
        self,
        featmap_names: List[str],
        output_size: Union[int, Tuple[int], List[int]],
        sampling_ratio: int,
    ):
        super(MultiScaleRoIAlign, self).__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
        self.scales = None
        self.map_levels = None

    def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
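        # roi_align expects rois as a single (K, 5) tensor whose rows are
        # (batch_index, x1, y1, x2, y2), so prepend each box's image index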
        concat_boxes = torch.cat(boxes, dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = torch.cat(
            [
                torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

    def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
        # assumption: the scale is of the form 2 ** (-k), with k integer
        size = feature.shape[-2:]
        possible_scales = torch.jit.annotate(List[float], [])
        for s1, s2 in zip(size, original_size):
            approx_scale = float(s1) / float(s2)
            scale = 2 ** float(torch.tensor(approx_scale).log2().round())
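            # e.g. an 800x800 image mapped to a 200x200 feature map gives
            # approx_scale = 0.25, which rounds to 2 ** -2 = 0.25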
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]
        return possible_scales[0]

    def setup_scales(
        self,
        features: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> None:
        assert len(image_shapes) != 0
        max_x = 0
        max_y = 0
        for shape in image_shapes:
            max_x = max(shape[0], max_x)
            max_y = max(shape[1], max_y)
        original_input_shape = (max_x, max_y)

        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
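        # e.g. scales of [1/4, 1/8, 1/16, 1/32] give lvl_min = 2 and lvl_max = 5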
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.scales = scales
        self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))

    def forward(
        self,
        x: Dict[str, Tensor],
        boxes: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tensor:
        """
        Arguments:
            x (Dict[str, Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.
        Returns:
            result (Tensor)
        """
        x_filtered = []
        for k, v in x.items():
            if k in self.featmap_names:
                x_filtered.append(v)
        num_levels = len(x_filtered)
        rois = self.convert_to_roi_format(boxes)
        if self.scales is None:
            self.setup_scales(x_filtered, image_shapes)

        scales = self.scales
        assert scales is not None

        if num_levels == 1:
            return roi_align(
                x_filtered[0], rois,
                output_size=self.output_size,
                spatial_scale=scales[0],
                sampling_ratio=self.sampling_ratio
            )

        mapper = self.map_levels
        assert mapper is not None

        levels = mapper(boxes)

        num_rois = len(rois)
        num_channels = x_filtered[0].shape[1]

        dtype, device = x_filtered[0].dtype, x_filtered[0].device
        result = torch.zeros(
            (num_rois, num_channels,) + self.output_size,
            dtype=dtype,
            device=device,
        )

        tracing_results = []
        for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
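            # pool only the rois assigned to this level, using this level's
            # feature map and spatial scale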
            idx_in_level = torch.where(levels == level)[0]
            rois_per_level = rois[idx_in_level]

            result_idx_in_level = roi_align(
                per_level_feature, rois_per_level,
                output_size=self.output_size,
                spatial_scale=scale, sampling_ratio=self.sampling_ratio)

            if torchvision._is_tracing():
                tracing_results.append(result_idx_in_level.to(dtype))
            else:
                # result and result_idx_in_level's dtypes are based on dtypes of different
                # elements in x_filtered.  x_filtered contains tensors output by different
                # layers.  When autocast is active, it may choose different dtypes for
                # different layers' outputs.  Therefore, we defensively match result's dtype
                # before copying elements from result_idx_in_level in the following op.
                # We need to cast manually (can't rely on autocast to cast for us) because
                # the op acts on result in-place, and autocast only affects out-of-place ops.
                result[idx_in_level] = result_idx_in_level.to(result.dtype)

        if torchvision._is_tracing():
            result = _onnx_merge_levels(levels, tracing_results)

        return result

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
                f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})")