# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch

from torch import nn, Tensor

import torchvision
from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area

from typing import Optional, List, Dict, Tuple, Union

# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
eellison's avatar
eellison committed
16
@torch.jit.unused
17
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
18
19
20
21
22
    first_result = unmerged_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros((levels.size(0), first_result.size(1),
                       first_result.size(2), first_result.size(3)),
                      dtype=dtype, device=device)
Francisco Massa's avatar
Francisco Massa committed
23
    for level in range(len(unmerged_results)):
24
        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
25
        index = index.expand(index.size(0),
Francisco Massa's avatar
Francisco Massa committed
26
27
28
29
                             unmerged_results[level].size(1),
                             unmerged_results[level].size(2),
                             unmerged_results[level].size(3))
        res = res.scatter(0, index, unmerged_results[level])
30
31
32
    return res


# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(
    k_min: int,
    k_max: int,
    canonical_scale: int = 224,
    canonical_level: int = 4,
    eps: float = 1e-6,
):
    """Construct a ``LevelMapper`` spanning FPN levels ``k_min`` to ``k_max``.

    All arguments are forwarded unchanged to the ``LevelMapper`` constructor.
    """
    mapper = LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)
    return mapper


class LevelMapper(object):
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper (Eqn. 1).

    Arguments:
        k_min (int): coarsest available level (after offsetting, maps to 0)
        k_max (int): finest available level
        canonical_scale (int): canonical box size from the FPN paper (224)
        canonical_level (int): level a canonical-scale box maps to (4)
        eps (float): small value to stabilize the floor at exact powers of two
    """

    def __init__(
        self,
        k_min: int,
        k_max: int,
        canonical_scale: int = 224,
        canonical_level: int = 4,
        eps: float = 1e-6,
    ):
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists: List[Tensor]) -> Tensor:
        """
        Arguments:
            boxlists (list[BoxList])
        """
        # Side length of a square box with the same area as each RoI.
        areas = [box_area(boxes_per_image) for boxes_per_image in boxlists]
        s = torch.cat(areas).sqrt()

        # Eqn.(1) in FPN paper; eps keeps boxes at exact power-of-two scales
        # from flooring down a level.
        lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        lvls = torch.clamp(lvls, min=self.k_min, max=self.k_max)
        # Shift so the coarsest level is index 0 into the pooled feature list.
        return (lvls.to(torch.int64) - self.k_min).to(torch.int64)


class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics present in the FPN paper.

    Arguments:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
        sampling_ratio (int): sampling ratio for ROIAlign

    Examples::

        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        >>> torch.Size([6, 5, 3, 3])

    """

    # Explicit annotations for attributes that start out as None, so that
    # TorchScript can refine their types after the lazy setup.
    __annotations__ = {
        'scales': Optional[List[float]],
        'map_levels': Optional[LevelMapper]
    }

    def __init__(
        self,
        featmap_names: List[str],
        output_size: Union[int, Tuple[int], List[int]],
        sampling_ratio: int,
    ):
        super(MultiScaleRoIAlign, self).__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
        # Both are inferred lazily from the first batch seen by forward().
        self.scales = None
        self.map_levels = None

    def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
        """Concatenate per-image boxes into the (K, 5) layout used by roi_align,
        where column 0 is the index of the image each box belongs to."""
        concat_boxes = torch.cat(boxes, dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        id_columns = []
        for img_idx, boxes_in_image in enumerate(boxes):
            id_columns.append(
                torch.full_like(boxes_in_image[:, :1], img_idx, dtype=dtype,
                                layout=torch.strided, device=device)
            )
        ids = torch.cat(id_columns, dim=0)
        return torch.cat([ids, concat_boxes], dim=1)

    def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
        """Infer the feature map's downsampling scale relative to the input image."""
        # assumption: the scale is of the form 2 ** (-k), with k integer
        feat_hw = feature.shape[-2:]
        candidates: List[float] = []
        for feat_dim, orig_dim in zip(feat_hw, original_size):
            raw_scale = float(feat_dim) / float(orig_dim)
            # Snap the measured ratio to the nearest power of two.
            snapped = 2 ** float(torch.tensor(raw_scale).log2().round())
            candidates.append(snapped)
        # Height and width must agree on the scale.
        assert candidates[0] == candidates[1]
        return candidates[0]

    def setup_scales(
        self,
        features: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> None:
        """Compute per-level scales and the RoI-to-level mapper from one batch."""
        assert len(image_shapes) != 0
        # Use the largest image in the batch as the reference input size.
        max_h = 0
        max_w = 0
        for shape in image_shapes:
            max_h = max(shape[0], max_h)
            max_w = max(shape[1], max_w)
        original_input_shape = (max_h, max_w)

        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.scales = scales
        self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))

    def forward(
        self,
        x: Dict[str, Tensor],
        boxes: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tensor:
        """
        Arguments:
            x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.
        Returns:
            result (Tensor)
        """
        # Keep only the configured feature maps, preserving the dict's order.
        x_filtered = []
        for name, feature in x.items():
            if name in self.featmap_names:
                x_filtered.append(feature)
        num_levels = len(x_filtered)

        rois = self.convert_to_roi_format(boxes)
        if self.scales is None:
            self.setup_scales(x_filtered, image_shapes)

        scales = self.scales
        assert scales is not None

        # A single level needs no RoI-to-level assignment.
        if num_levels == 1:
            return roi_align(
                x_filtered[0], rois,
                output_size=self.output_size,
                spatial_scale=scales[0],
                sampling_ratio=self.sampling_ratio
            )

        mapper = self.map_levels
        assert mapper is not None

        levels = mapper(boxes)

        num_rois = len(rois)
        num_channels = x_filtered[0].shape[1]
        dtype, device = x_filtered[0].dtype, x_filtered[0].device

        result = torch.zeros(
            (num_rois, num_channels,) + self.output_size,
            dtype=dtype,
            device=device,
        )

        tracing_results = []
        for level, (feature_map, scale) in enumerate(zip(x_filtered, scales)):
            idx_in_level = torch.where(levels == level)[0]
            pooled = roi_align(
                feature_map, rois[idx_in_level],
                output_size=self.output_size,
                spatial_scale=scale, sampling_ratio=self.sampling_ratio)

            if torchvision._is_tracing():
                # ONNX tracing cannot express the indexed copy in the else
                # branch; collect per-level results and merge them afterwards.
                tracing_results.append(pooled.to(dtype))
            else:
                # result and pooled may have different dtypes: x_filtered holds
                # outputs of different layers, and under autocast those can be
                # produced in different dtypes. The indexed assignment below is
                # in-place, and autocast only affects out-of-place ops, so we
                # must cast to result's dtype manually before copying.
                result[idx_in_level] = pooled.to(result.dtype)

        if torchvision._is_tracing():
            result = _onnx_merge_levels(levels, tracing_results)

        return result

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (f"{cls_name}(featmap_names={self.featmap_names}, "
                f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})")