import math
import torch
from torch import nn, Tensor

from typing import List, Optional
from .image_list import ImageList


class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module supports computing anchors at multiple sizes and aspect ratios
    per feature map. This module assumes aspect ratio = height / width for
    each anchor.

    sizes and aspect_ratios should have the same number of elements, and that
    number should correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements, and
    AnchorGenerator will output a set of len(sizes[i]) * len(aspect_ratios[i])
    anchors per spatial location for feature map i.

    Args:
        sizes (Tuple[Tuple[int]]): anchor sizes, one tuple per feature map.
        aspect_ratios (Tuple[Tuple[float]]): anchor aspect ratios (height / width),
            one tuple per feature map.
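
    Example (an illustrative sketch of the expected anchor count)::

        >>> anchor_gen = AnchorGenerator(sizes=((32, 64),), aspect_ratios=((0.5, 1.0, 2.0),))
        >>> anchor_gen.num_anchors_per_location()  # len(sizes[0]) * len(aspect_ratios[0])
        [6]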
    """

    __annotations__ = {
        "cell_anchors": List[torch.Tensor],
    }

    def __init__(
        self,
        sizes=((128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    ):
        super().__init__()

        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        assert len(sizes) == len(aspect_ratios)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        self.cell_anchors = [self.generate_anchors(size, aspect_ratio)
                             for size, aspect_ratio in zip(sizes, aspect_ratios)]

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
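    # Worked example (illustrative): generate_anchors([32], [1.0, 2.0]) returns
    #     tensor([[-16., -16.,  16.,  16.],
    #             [-11., -23.,  11.,  23.]])
    # i.e. a 32x32 square and a roughly 23x45 (w x h) box, both zero-centered.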
    def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: torch.dtype = torch.float32,
                         device: torch.device = torch.device("cpu")):
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
        self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device)
                             for cell_anchor in self.cell_anchors]

    def num_anchors_per_location(self):
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
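    # Worked example (illustrative): for one feature map with grid size (2, 2),
    # stride (16, 16) and a single base anchor [-8., -8., 8., 8.], the result is
    # a 4x4 tensor of 16x16 anchors centered at (0, 0), (16, 0), (0, 16), (16, 16).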
    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
        anchors = []
        cell_anchors = self.cell_anchors
        assert cell_anchors is not None

        if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
            raise ValueError("Anchors should be Tuple[Tuple[int]] because each feature "
                             "map could potentially have different sizes and aspect ratios. "
                             "There needs to be a match between the number of "
                             "feature maps passed and the number of sizes / aspect ratios specified.")

        for size, stride, base_anchors in zip(
            grid_sizes, strides, cell_anchors
        ):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = torch.arange(
                0, grid_width, dtype=torch.float32, device=device
            ) * stride_width
            shifts_y = torch.arange(
                0, grid_height, dtype=torch.float32, device=device
            ) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
            )

        return anchors

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
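        # Stride of each feature map w.r.t. the input image, in (height, width) order.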
        strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
                    torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
        anchors: List[List[torch.Tensor]] = []
        for _ in range(len(image_list.image_sizes)):
            anchors_in_image = list(anchors_over_all_feature_maps)
            anchors.append(anchors_in_image)
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        return anchors


class DefaultBoxGenerator(nn.Module):
    """
    This module generates the default boxes of SSD for a set of feature maps and image sizes.

    Args:
        aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
        min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        scales (List[float], optional): The scales of the default boxes. If not provided it will be estimated using
            the ``min_ratio`` and ``max_ratio`` parameters.
        steps (List[int], optional): A hyper-parameter that affects the tiling of default boxes. If not provided
            it will be estimated from the data.
        clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
            is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
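
    Example (an illustrative sketch of the expected box count per location)::

        >>> dbox_gen = DefaultBoxGenerator(aspect_ratios=[[2], [2, 3]])
        >>> dbox_gen.num_anchors_per_location()  # 2 + 2 * len(aspect_ratios[k])
        [4, 6]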
    """

    def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_ratio: float = 0.9,
                 scales: Optional[List[float]] = None, steps: Optional[List[int]] = None, clip: bool = True):
        super().__init__()
        if steps is not None:
            assert len(aspect_ratios) == len(steps)
        self.aspect_ratios = aspect_ratios
        self.steps = steps
        self.clip = clip
        num_outputs = len(aspect_ratios)

        # Estimation of default boxes scales
        if scales is None:
            if num_outputs > 1:
                range_ratio = max_ratio - min_ratio
                self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
                self.scales.append(1.0)
            else:
                self.scales = [min_ratio, max_ratio]
        else:
            self.scales = scales
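        # Illustrative note: with the defaults (min_ratio=0.15, max_ratio=0.9) and
        # 6 feature maps, the estimated scales are [0.15, 0.3, 0.45, 0.6, 0.75, 0.9]
        # plus a trailing 1.0 that is only used to compute s'_k for the last map.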

        self._wh_pairs = []
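        # Worked example (illustrative): for k=0 with scales = [0.15, 0.3, ...] and
        # aspect_ratios = [[2], ...], the pairs are (0.15, 0.15),
        # (sqrt(0.15 * 0.3), sqrt(0.15 * 0.3)) ~= (0.212, 0.212),
        # (0.15 * sqrt(2), 0.15 / sqrt(2)) ~= (0.212, 0.106) and (0.106, 0.212).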
        for k in range(num_outputs):
            # Adding the 2 default width-height pairs for aspect ratio 1, using scales s_k and s'_k
            s_k = self.scales[k]
            s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
            wh_pairs = [(s_k, s_k), (s_prime_k, s_prime_k)]

            # Adding 2 pairs for each aspect ratio of the feature map k
            for ar in self.aspect_ratios[k]:
                sq_ar = math.sqrt(ar)
                w = self.scales[k] * sq_ar
                h = self.scales[k] / sq_ar
                wh_pairs.extend([(w, h), (h, w)])

            self._wh_pairs.append(wh_pairs)

    def num_anchors_per_location(self):
        # Estimate the number of anchors based on aspect ratios: 2 default boxes + 2 * ratios of the feature map.
        return [2 + 2 * len(r) for r in self.aspect_ratios]

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'aspect_ratios={aspect_ratios}'
        s += ', clip={clip}'
        s += ', scales={scales}'
        s += ', steps={steps}'
        s += ')'
        return s.format(**self.__dict__)

    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device

        # Default Boxes calculation based on page 6 of SSD paper
        default_boxes: List[List[float]] = []
        for k, f_k in enumerate(grid_sizes):
            # Now add the default boxes for each width-height pair
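            # Centers are placed on a regular grid in relative [0, 1] coordinates as
            # (index + 0.5) / grid_size; `steps`, when provided, overrides the grid
            # resolution inferred from the feature map shape.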
            for j in range(f_k[0]):
                if self.steps is not None:
                    # image_size is (height, width): the y grid follows the image height
                    y_f_k = image_size[0] / self.steps[k]
                else:
                    y_f_k = float(f_k[0])
                cy = (j + 0.5) / y_f_k
                for i in range(f_k[1]):
                    if self.steps is not None:
                        x_f_k = image_size[1] / self.steps[k]
                    else:
                        x_f_k = float(f_k[1])
                    cx = (i + 0.5) / x_f_k
                    default_boxes.extend([[cx, cy, w, h] for w, h in self._wh_pairs[k]])

        dboxes = []
        for _ in image_list.image_sizes:
            dboxes_in_image = torch.tensor(default_boxes, dtype=dtype, device=device)
            if self.clip:
                dboxes_in_image.clamp_(min=0, max=1)
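            # Convert (cx, cy, w, h) to (x1, y1, x2, y2), then scale the relative
            # [0, 1] coordinates to absolute pixel coordinates.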
            dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:],
                                         dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1)
            dboxes_in_image[:, 0::2] *= image_size[1]
            dboxes_in_image[:, 1::2] *= image_size[0]
            dboxes.append(dboxes_in_image)
        return dboxes