from typing import Optional, List, Dict, Tuple, Union

import torch
import torchvision
from torch import nn, Tensor
from torchvision.ops.boxes import box_area

from .roi_align import roi_align


# copying result_idx_in_level to a specific index in result[]
# is not supported by ONNX tracing yet.
# _onnx_merge_levels() is an implementation supported by ONNX
# that merges the levels to the right indices
@torch.jit.unused
def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor:
    first_result = unmerged_results[0]
    dtype, device = first_result.dtype, first_result.device
    res = torch.zeros(
        (levels.size(0), first_result.size(1), first_result.size(2), first_result.size(3)), dtype=dtype, device=device
    )
    for level in range(len(unmerged_results)):
        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
        index = index.expand(
            index.size(0),
            unmerged_results[level].size(1),
            unmerged_results[level].size(2),
            unmerged_results[level].size(3),
        )
        res = res.scatter(0, index, unmerged_results[level])
    return res


# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744
def initLevelMapper(
    k_min: int,
    k_max: int,
    canonical_scale: int = 224,
    canonical_level: int = 4,
    eps: float = 1e-6,
):
    return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)


class LevelMapper:
    """Determine which FPN level each RoI in a set of RoIs should map to based
    on the heuristic in the FPN paper.

    Args:
        k_min (int): index of the finest (highest-resolution) FPN level available for pooling
        k_max (int): index of the coarsest (lowest-resolution) FPN level available for pooling
        canonical_scale (int): canonical box size, ``224`` in the FPN paper
        canonical_level (int): target level for a box of ``canonical_scale``, ``k0 = 4`` in the FPN paper
        eps (float): small value added before flooring to guard against numerical issues
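
    For example, a ``112 x 112`` box (half the canonical scale per side) maps one
    level below ``canonical_level`` and a ``448 x 448`` box one level above it;
    a minimal sketch of the resulting indices (level ``k - k_min``)::

        >>> mapper = LevelMapper(k_min=2, k_max=5)
        >>> boxes = [torch.tensor([[0., 0., 112., 112.], [0., 0., 448., 448.]])]
        >>> mapper(boxes)
        tensor([1, 3])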
    """

    def __init__(
        self,
        k_min: int,
        k_max: int,
        canonical_scale: int = 224,
        canonical_level: int = 4,
        eps: float = 1e-6,
    ):
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists: List[Tensor]) -> Tensor:
        """
        Args:
            boxlists (List[Tensor]): per-image boxes in ``(x1, y1, x2, y2)`` format
        """
        # Compute level ids
        s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists]))

        # Eqn.(1) in FPN paper
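        # k = floor(lvl0 + log2(sqrt(area) / s0) + eps), clamped to [k_min, k_max] below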
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)


class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.

    It infers the scale of the pooling via the heuristics specified in eq. 1
    of the `Feature Pyramid Network paper <https://arxiv.org/abs/1612.03144>`_.
    The keyword-only parameters ``canonical_scale`` and ``canonical_level``
    correspond respectively to ``224`` and ``k0=4`` in eq. 1, and
    have the following meaning: ``canonical_level`` is the target level of the pyramid from
    which to pool a region of interest with ``w x h = canonical_scale x canonical_scale``.

    Args:
        featmap_names (List[str]): the names of the feature maps that will be used
            for the pooling.
        output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region
        sampling_ratio (int): sampling ratio for ROIAlign
        canonical_scale (int, optional): canonical_scale for LevelMapper
        canonical_level (int, optional): canonical_level for LevelMapper

    Examples::

        >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2)
        >>> from collections import OrderedDict
        >>> i = OrderedDict()
        >>> i['feat1'] = torch.rand(1, 5, 64, 64)
        >>> i['feat2'] = torch.rand(1, 5, 32, 32)  # this feature won't be used in the pooling
        >>> i['feat3'] = torch.rand(1, 5, 16, 16)
        >>> # create some random bounding boxes
        >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2]
        >>> # original image size, before computing the feature maps
        >>> image_sizes = [(512, 512)]
        >>> output = m(i, [boxes], image_sizes)
        >>> print(output.shape)
        torch.Size([6, 5, 3, 3])

    """

    __annotations__ = {"scales": Optional[List[float]], "map_levels": Optional[LevelMapper]}

    def __init__(
        self,
        featmap_names: List[str],
        output_size: Union[int, Tuple[int], List[int]],
        sampling_ratio: int,
        *,
        canonical_scale: int = 224,
        canonical_level: int = 4,
    ):
        super().__init__()
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.output_size = tuple(output_size)
        self.scales = None
        self.map_levels = None
        self.canonical_scale = canonical_scale
        self.canonical_level = canonical_level

    def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
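        """Convert a per-image list of ``Tensor[N, 4]`` boxes into the single
        ``Tensor[K, 5]`` layout expected by ``roi_align``, where the first column
        holds the index of the image each box belongs to.
        """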
        concat_boxes = torch.cat(boxes, dim=0)
        device, dtype = concat_boxes.device, concat_boxes.dtype
        ids = torch.cat(
            [
                torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device)
                for i, b in enumerate(boxes)
            ],
            dim=0,
        )
        rois = torch.cat([ids, concat_boxes], dim=1)
        return rois

    def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
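        """Infer the scale factor relating the spatial size of ``feature`` to the
        original image size, assuming it is a power of two.
        """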
        # assumption: the scale is of the form 2 ** (-k), with k integer
        size = feature.shape[-2:]
        possible_scales: List[float] = []
        for s1, s2 in zip(size, original_size):
            approx_scale = float(s1) / float(s2)
            scale = 2 ** float(torch.tensor(approx_scale).log2().round())
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]
        return possible_scales[0]

    def setup_scales(
        self,
        features: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> None:
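        """Compute the spatial scale of each feature map relative to the largest
        original image size in the batch and build the ``LevelMapper`` used to
        assign RoIs to levels. Called lazily on the first forward pass.
        """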
        assert len(image_shapes) != 0
        max_x = 0
        max_y = 0
        for shape in image_shapes:
            max_x = max(shape[0], max_x)
            max_y = max(shape[1], max_y)
        original_input_shape = (max_x, max_y)

        scales = [self.infer_scale(feat, original_input_shape) for feat in features]
        # get the levels in the feature map by leveraging the fact that the network always
        # downsamples by a factor of 2 at each level.
        lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
        lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
        self.scales = scales
        self.map_levels = initLevelMapper(
            int(lvl_min),
            int(lvl_max),
            canonical_scale=self.canonical_scale,
            canonical_level=self.canonical_level,
        )

    def forward(
        self,
        x: Dict[str, Tensor],
        boxes: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> Tensor:
        """
        Args:
            x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
                all the same number of channels, but they can have different sizes.
            boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
                (x1, y1, x2, y2) format and in the image reference size, not the feature map
                reference. The coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            image_shapes (List[Tuple[height, width]]): the sizes of each image before they
                have been fed to a CNN to obtain feature maps. This allows us to infer the
                scale factor for each one of the levels to be pooled.
        Returns:
            result (Tensor[K, C, output_size[0], output_size[1]]): pooled features for the K boxes
        """
        x_filtered = []
        for k, v in x.items():
            if k in self.featmap_names:
                x_filtered.append(v)
        num_levels = len(x_filtered)
        rois = self.convert_to_roi_format(boxes)
        if self.scales is None:
            self.setup_scales(x_filtered, image_shapes)

        scales = self.scales
        assert scales is not None

        if num_levels == 1:
            return roi_align(
                x_filtered[0],
                rois,
                output_size=self.output_size,
                spatial_scale=scales[0],
                sampling_ratio=self.sampling_ratio,
            )

        mapper = self.map_levels
        assert mapper is not None

        levels = mapper(boxes)

        num_rois = len(rois)
        num_channels = x_filtered[0].shape[1]

        dtype, device = x_filtered[0].dtype, x_filtered[0].device
        result = torch.zeros(
            (
                num_rois,
                num_channels,
            )
            + self.output_size,
            dtype=dtype,
            device=device,
        )

        tracing_results = []
        for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
            idx_in_level = torch.where(levels == level)[0]
            rois_per_level = rois[idx_in_level]

            result_idx_in_level = roi_align(
                per_level_feature,
                rois_per_level,
                output_size=self.output_size,
                spatial_scale=scale,
                sampling_ratio=self.sampling_ratio,
            )

            if torchvision._is_tracing():
                tracing_results.append(result_idx_in_level.to(dtype))
            else:
                # result and result_idx_in_level's dtypes are based on dtypes of different
                # elements in x_filtered.  x_filtered contains tensors output by different
                # layers.  When autocast is active, it may choose different dtypes for
                # different layers' outputs.  Therefore, we defensively match result's dtype
                # before copying elements from result_idx_in_level in the following op.
                # We need to cast manually (can't rely on autocast to cast for us) because
                # the op acts on result in-place, and autocast only affects out-of-place ops.
                result[idx_in_level] = result_idx_in_level.to(result.dtype)

        if torchvision._is_tracing():
            result = _onnx_merge_levels(levels, tracing_results)

        return result

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "
            f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})"
        )