import torch
from torch import nn, Tensor

import torchvision
from torchvision.ops import roi_align
from torchvision.ops.boxes import box_area

from typing import Optional, List, Dict, Tuple, Union


# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
@torch.jit.unused
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
    first_result = unmerged_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros((levels.size(0), first_result.size(1),
                       first_result.size(2), first_result.size(3)),
                      dtype=dtype, device=device)
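    # For each level, scatter that level's pooled outputs back into the rows of
    # `res` that belong to its RoIs (given by `levels`); scatter is traceable by
    # ONNX, unlike indexed assignment into `res`.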
    for level in range(len(unmerged_results)):
        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
        index = index.expand(index.size(0),
                             unmerged_results[level].size(1),
                             unmerged_results[level].size(2),
                             unmerged_results[level].size(3))
        res = res.scatter(0, index, unmerged_results[level])
    return res


# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(
    k_min: int,
    k_max: int,
    canonical_scale: int = 224,
    canonical_level: int = 4,
    eps: float = 1e-6,
):
    return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)


class LevelMapper(object):
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.

    Args:
        k_min (int): the finest (lowest) FPN level index boxes can be mapped to
        k_max (int): the coarsest (highest) FPN level index boxes can be mapped to
        canonical_scale (int): canonical box size, 224 in the FPN paper
        canonical_level (int): level onto which a canonical_scale x canonical_scale box is mapped, k0 = 4 in the FPN paper
        eps (float): small value added before flooring, for numerical stability
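
    Example (an illustrative sketch of the mapping in Eqn. 1 of the FPN paper)::

        >>> mapper = initLevelMapper(k_min=2, k_max=5)
        >>> boxes = torch.tensor([[0., 0., 112., 112.]])  # a single 112 x 112 box
        >>> mapper([boxes])  # floor(4 + log2(112 / 224)) = 3, reported relative to k_min
        tensor([1])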
    """

    def __init__(
        self,
        k_min: int,
        k_max: int,
        canonical_scale: int = 224,
        canonical_level: int = 4,
        eps: float = 1e-6,
    ):
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists: List[Tensor]) -> Tensor:
        """
        Args:
            boxlists (List[Tensor]): boxes for each image, in (x1, y1, x2, y2) format
        """
        # Compute level ids
        s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
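        # e.g. with canonical_scale=224 and canonical_level=4, a 224 x 224 box maps
        # to level 4, and halving the box side lowers the target level by one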
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)


class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics specified in eq. 1
    of the `Feature Pyramid Network paper <https://arxiv.org/abs/1612.03144>`_.
    The keyword-only parameters ``canonical_scale`` and ``canonical_level``
    correspond respectively to ``224`` and ``k0=4`` in eq. 1, and
    have the following meaning: ``canonical_level`` is the target level of the pyramid from
    which to pool a region of interest with ``w x h = canonical_scale x canonical_scale``.

    Args:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
        sampling_ratio (int): sampling ratio for ROIAlign
        canonical_scale (int, optional): canonical_scale for LevelMapper
        canonical_level (int, optional): canonical_level for LevelMapper

    Examples::

        >>> from collections import OrderedDict
        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        torch.Size([6, 5, 3, 3])

    """

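    # annotate attributes that start out as None so that torch.jit.script
    # knows their eventual Optional types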
    __annotations__ = {
        'scales': Optional[List[float]],
        'map_levels': Optional[LevelMapper]
    }

    def __init__(
        self,
        featmap_names: List[str],
        output_size: Union[int, Tuple[int], List[int]],
        sampling_ratio: int,
        *,
        canonical_scale: int = 224,
        canonical_level: int = 4,
    ):
        super(MultiScaleRoIAlign, self).__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
        self.scales = None
        self.map_levels = None
        self.canonical_scale = canonical_scale
        self.canonical_level = canonical_level

    def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
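        # roi_align expects rois as a Tensor[K, 5] of (batch_index, x1, y1, x2, y2),
        # so prepend to each box the index of the image it belongs to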
        concat_boxes = torch.cat(boxes, dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = torch.cat(
            [
                torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

    def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
        # assumption: the scale is of the form 2 ** (-k), with k integer
        size = feature.shape[-2:]
        possible_scales: List[float] = []
        for s1, s2 in zip(size, original_size):
            approx_scale = float(s1) / float(s2)
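            # snap the observed ratio to the nearest power of two,
            # e.g. 200 / 800 = 0.25 -> 2 ** -2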
            scale = 2 ** float(torch.tensor(approx_scale).log2().round())
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]
        return possible_scales[0]

    def setup_scales(
        self,
        features: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> None:
        assert len(image_shapes) != 0
        max_x = 0
        max_y = 0
        for shape in image_shapes:
            max_x = max(shape[0], max_x)
            max_y = max(shape[1], max_y)
        original_input_shape = (max_x, max_y)

        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
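        # e.g. typical FPN scales [1/4, 1/8, 1/16, 1/32] give lvl_min = 2 and lvl_max = 5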
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.scales = scales
        self.map_levels = initLevelMapper(
            int(lvl_min),
            int(lvl_max),
            canonical_scale=self.canonical_scale,
            canonical_level=self.canonical_level,
        )

    def forward(
        self,
        x: Dict[str, Tensor],
        boxes: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tensor:
        """
        Args:
            x (OrderedDict[Tensor]): feature maps for each level. They are all assumed
                to have the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference. The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.
        Returns:
            result (Tensor)
        """
        x_filtered = []
        for k, v in x.items():
            if k in self.featmap_names:
                x_filtered.append(v)
        num_levels = len(x_filtered)
        rois = self.convert_to_roi_format(boxes)
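        # the scales and the level mapper are inferred lazily from the first batch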
        if self.scales is None:
            self.setup_scales(x_filtered, image_shapes)

        scales = self.scales
        assert scales is not None

        if num_levels == 1:
            return roi_align(
                x_filtered[0], rois,
                output_size=self.output_size,
                spatial_scale=scales[0],
                sampling_ratio=self.sampling_ratio
            )

        mapper = self.map_levels
        assert mapper is not None

        levels = mapper(boxes)

        num_rois = len(rois)
        num_channels = x_filtered[0].shape[1]

        dtype, device = x_filtered[0].dtype, x_filtered[0].device
        result = torch.zeros(
            (num_rois, num_channels,) + self.output_size,
            dtype=dtype,
            device=device,
        )

        tracing_results = []
        for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
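            # select the RoIs that the level mapper assigned to this feature level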
            idx_in_level = torch.where(levels == level)[0]
            rois_per_level = rois[idx_in_level]

            result_idx_in_level = roi_align(
                per_level_feature, rois_per_level,
                output_size=self.output_size,
                spatial_scale=scale, sampling_ratio=self.sampling_ratio)

            if torchvision._is_tracing():
                tracing_results.append(result_idx_in_level.to(dtype))
            else:
                # result and result_idx_in_level's dtypes are based on dtypes of different
                # elements in x_filtered.  x_filtered contains tensors output by different
                # layers.  When autocast is active, it may choose different dtypes for
                # different layers' outputs.  Therefore, we defensively match result's dtype
                # before copying elements from result_idx_in_level in the following op.
                # We need to cast manually (can't rely on autocast to cast for us) because
                # the op acts on result in-place, and autocast only affects out-of-place ops.
                result[idx_in_level] = result_idx_in_level.to(result.dtype)

        if torchvision._is_tracing():
            result = _onnx_merge_levels(levels, tracing_results)

        return result

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
                f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})")