"""Bbox assignment and sampling utilities (``sampling.py``)."""
import numpy as np
import torch

from .geometry import bbox_overlaps


def random_choice(gallery, num):
    """Randomly select ``num`` elements from the gallery.

    It seems that PyTorch's implementation is slower than numpy, so numpy is
    used to randomly permute the candidate indices.

    Args:
        gallery (list | ndarray | Tensor): Elements to sample from.
        num (int): Number of elements to select; must not exceed
            ``len(gallery)``.

    Returns:
        ndarray | Tensor: ``num`` elements taken from distinct positions of
        ``gallery``. A list input is returned as an ndarray; a tensor input
        is indexed on its own device and returned as a tensor.
    """
    assert len(gallery) >= num
    if isinstance(gallery, list):
        gallery = np.array(gallery)
    cands = np.arange(len(gallery))
    np.random.shuffle(cands)
    rand_inds = cands[:num]
    if not isinstance(gallery, np.ndarray):
        # tensors must be indexed with a LongTensor on the same device
        rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
    return gallery[rand_inds]


def bbox_assign(proposals,
                gt_bboxes,
                gt_bboxes_ignore=None,
                gt_labels=None,
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=.0,
                crowd_thr=-1):
    """Assign a corresponding gt bbox or background to each proposal/anchor.

    Each proposal will be assigned with `-1`, `0`, or a positive integer.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    If `gt_bboxes_ignore` is non-empty and `crowd_thr` is positive, every
    proposal whose iof (intersection over foreground) with some box in
    `gt_bboxes_ignore` is above `crowd_thr` gets all its overlaps set to -1
    before assignment, which normally leaves it as -1 (don't care).

    Args:
        proposals (Tensor): Proposals or RPN anchors, shape (n, 4).
        gt_bboxes (Tensor): Ground truth bboxes, shape (k, 4).
        gt_bboxes_ignore (Tensor, optional): shape (m, 4).
        gt_labels (Tensor, optional): shape (k, ).
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
        min_pos_iou (float): Minimum iou for a bbox to be considered as a
            positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
            it is usually set as pos_iou_thr
        crowd_thr (float): IoF threshold for ignoring bboxes. A non-positive
            value disables ignoring.

    Returns:
        tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), or
        (assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps)
        when `gt_labels` is given; each of shape (n, ).

    Raises:
        ValueError: If there are no gt bboxes or no proposals.
    """
    # calculate overlaps between the proposals and the gt boxes
    overlaps = bbox_overlaps(proposals, gt_bboxes)
    if overlaps.numel() == 0:
        raise ValueError('No gt bbox or proposals')

    # ignore proposals according to crowd bboxes
    if (crowd_thr > 0 and gt_bboxes_ignore is not None
            and gt_bboxes_ignore.numel() > 0):
        crowd_overlaps = bbox_overlaps(proposals, gt_bboxes_ignore, mode='iof')
        crowd_max_overlaps, _ = crowd_overlaps.max(dim=1)
        # -1 is below any IoU, so these rows are skipped by the
        # "max_overlaps >= 0" negative test in the assignment step
        overlaps[crowd_max_overlaps > crowd_thr, :] = -1

    return bbox_assign_wrt_overlaps(overlaps, gt_labels, pos_iou_thr,
                                    neg_iou_thr, min_pos_iou)


def bbox_assign_wrt_overlaps(overlaps,
                             gt_labels=None,
                             pos_iou_thr=0.5,
                             neg_iou_thr=0.5,
                             min_pos_iou=.0):
    """Assign a corresponding gt bbox or background to each proposal/anchor.

    Every proposal is assigned -1, 0, or a positive number: -1 means don't
    care, 0 means negative sample, and a positive number is the index
    (1-based) of the assigned gt. The assignment is done in the following
    steps, and the order matters (later steps overwrite earlier ones):

    1. assign every proposal to -1
    2. assign proposals whose max IoU with all gts lies in the negative range
       to 0: ``[0, neg_iou_thr)`` for a scalar threshold, or
       ``[neg_iou_thr[0], neg_iou_thr[1])`` for a pair
    3. for each proposal, if the IoU with its nearest gt >= pos_iou_thr,
       assign it to that gt
    4. for each gt bbox whose best IoU is positive and >= min_pos_iou, assign
       its nearest proposals (may be more than one) to itself. A gt that
       overlaps no proposal at all (best IoU 0) is skipped; otherwise it
       would claim every proposal with zero IoU, overriding step 3.

    Args:
        overlaps (Tensor): Overlaps between n proposals and k gt_bboxes,
            shape (n, k). Rows filled with -1 (e.g. crowd-ignored proposals)
            are never made negative by a scalar `neg_iou_thr`.
        gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes; a
            number (int or float) or a 2-element tuple/list ``(lo, hi)``.
        min_pos_iou (float): Minimum IoU for a bbox to be considered as a
            positive bbox. This argument only affects the 4th step.

    Returns:
        tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps) when
        `gt_labels` is None, otherwise (assigned_gt_inds, assigned_labels,
        argmax_overlaps, max_overlaps); each of shape (n, ).

    Raises:
        ValueError: If `overlaps` is empty.
        TypeError: If `neg_iou_thr` is neither a number nor a pair.
    """
    if overlaps.numel() == 0:
        raise ValueError('No gt bbox or proposals')
    assert overlaps.dim() == 2
    num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)

    # 1. assign -1 by default
    assigned_gt_inds = overlaps.new(num_bboxes).long().fill_(-1)

    # for each proposal: the best-matching gt and its IoU
    max_overlaps, argmax_overlaps = overlaps.max(dim=1)
    # for each gt: the best IoU over all proposals
    gt_max_overlaps, _ = overlaps.max(dim=0)

    # 2. assign negative: below the negative threshold
    if isinstance(neg_iou_thr, (int, float)):
        neg_mask = (max_overlaps >= 0) & (max_overlaps < neg_iou_thr)
    elif isinstance(neg_iou_thr, (tuple, list)):
        assert len(neg_iou_thr) == 2
        neg_mask = ((max_overlaps >= neg_iou_thr[0])
                    & (max_overlaps < neg_iou_thr[1]))
    else:
        raise TypeError('neg_iou_thr must be a number or a pair, got {}'
                        .format(type(neg_iou_thr)))
    assigned_gt_inds[neg_mask] = 0

    # 3. assign positive: above positive IoU threshold
    pos_mask = max_overlaps >= pos_iou_thr
    assigned_gt_inds[pos_mask] = argmax_overlaps[pos_mask] + 1

    # 4. assign fg: for each gt, proposals with highest IoU
    for i in range(num_gts):
        gt_max = gt_max_overlaps[i]
        if gt_max > 0 and gt_max >= min_pos_iou:
            assigned_gt_inds[overlaps[:, i] == gt_max] = i + 1

    if gt_labels is None:
        return assigned_gt_inds, argmax_overlaps, max_overlaps

    assigned_labels = assigned_gt_inds.new(num_bboxes).fill_(0)
    # a boolean mask handles zero, one, or many positives uniformly
    pos_mask = assigned_gt_inds > 0
    assigned_labels[pos_mask] = gt_labels[assigned_gt_inds[pos_mask] - 1]
    return assigned_gt_inds, assigned_labels, argmax_overlaps, max_overlaps


def bbox_sampling_pos(assigned_gt_inds, num_expected, balance_sampling=True):
    """Balance sampling for positive bboxes/anchors.

    Positives are the entries of ``assigned_gt_inds`` greater than 0. If
    there are at most ``num_expected`` of them, all are returned. Otherwise
    they are sampled uniformly when ``balance_sampling`` is False, or, when
    it is True:

    1. calculate average positive num for each gt: num_per_gt
    2. sample at most num_per_gt positives for each gt
    3. random sampling from the remaining positives if not enough were drawn
       (or subsample if too many were drawn)

    Args:
        assigned_gt_inds (Tensor): Per-bbox assignment, shape (n, ); values
            > 0 are 1-based gt indices.
        num_expected (int): Maximum number of positives to return.
        balance_sampling (bool): Whether to sample evenly around each gt.

    Returns:
        Tensor: 1-D indices into ``assigned_gt_inds`` of sampled positives.
    """
    # view(-1) keeps the result 1-D even when there are no positives
    pos_inds = torch.nonzero(assigned_gt_inds > 0).view(-1)
    if pos_inds.numel() <= num_expected:
        return pos_inds
    if not balance_sampling:
        return random_choice(pos_inds, num_expected)

    unique_gt_inds = torch.unique(assigned_gt_inds[pos_inds].cpu())
    num_gts = len(unique_gt_inds)
    num_per_gt = int(round(num_expected / float(num_gts)) + 1)
    sampled_inds = []
    for gt_ind in unique_gt_inds:
        inds = torch.nonzero(assigned_gt_inds == gt_ind.item()).view(-1)
        if inds.numel() == 0:
            continue
        if len(inds) > num_per_gt:
            inds = random_choice(inds, num_per_gt)
        sampled_inds.append(inds)
    sampled_inds = torch.cat(sampled_inds)

    if len(sampled_inds) < num_expected:
        num_extra = num_expected - len(sampled_inds)
        # Build the sets from Python ints: torch.Tensor hashes by object
        # identity, so a set difference of tensor elements removes nothing
        # and would re-add already-sampled indices as duplicates.
        extra_inds = np.array(
            sorted(set(pos_inds.tolist()) - set(sampled_inds.tolist())),
            dtype=np.int64)
        if len(extra_inds) > num_extra:
            extra_inds = random_choice(extra_inds, num_extra)
        extra_inds = torch.from_numpy(extra_inds).to(
            assigned_gt_inds.device).long()
        sampled_inds = torch.cat([sampled_inds, extra_inds])
    elif len(sampled_inds) > num_expected:
        sampled_inds = random_choice(sampled_inds, num_expected)
    return sampled_inds


def bbox_sampling_neg(assigned_gt_inds,
                      num_expected,
                      max_overlaps=None,
                      balance_thr=0,
                      hard_fraction=0.5):
    """Balance sampling for negative bboxes/anchors.

    Negatives are the entries of ``assigned_gt_inds`` equal to 0. If there
    are at most ``num_expected`` of them, all are returned. Otherwise, when
    ``balance_thr <= 0`` they are sampled uniformly; when it is positive they
    are split into 2 sets, hard (iou >= balance_thr) and easy
    (0 <= iou < balance_thr), and up to ``int(num_expected * hard_fraction)``
    hard ones are taken, the rest of the budget being filled with easy ones
    and finally with any remaining negatives.

    Args:
        assigned_gt_inds (Tensor): Per-bbox assignment, shape (n, ).
        num_expected (int): Maximum number of negatives to return.
        max_overlaps (Tensor, optional): Per-bbox max IoU with all gts,
            shape (n, ). Required when ``balance_thr > 0``.
        balance_thr (float): IoU threshold separating easy and hard
            negatives; non-positive disables balance sampling.
        hard_fraction (float): Fraction of the budget for hard negatives.

    Returns:
        Tensor: 1-D long indices into ``assigned_gt_inds`` of sampled
        negatives.
    """
    # view(-1) keeps the result 1-D even when there are no negatives
    neg_inds = torch.nonzero(assigned_gt_inds == 0).view(-1)
    if len(neg_inds) <= num_expected:
        return neg_inds
    if balance_thr <= 0:
        # uniform sampling among all negative samples
        return random_choice(neg_inds, num_expected)

    assert max_overlaps is not None
    max_overlaps = max_overlaps.cpu().numpy()
    # balance sampling for negative samples
    neg_set = set(neg_inds.cpu().numpy())
    easy_set = set(
        np.where(
            np.logical_and(max_overlaps >= 0,
                           max_overlaps < balance_thr))[0])
    hard_set = set(np.where(max_overlaps >= balance_thr)[0])
    easy_neg_inds = list(easy_set & neg_set)
    hard_neg_inds = list(hard_set & neg_set)

    # `np.int` was removed in NumPy 1.24, so use an explicit integer width
    num_expected_hard = int(num_expected * hard_fraction)
    if len(hard_neg_inds) > num_expected_hard:
        sampled_hard_inds = random_choice(hard_neg_inds, num_expected_hard)
    else:
        sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int64)
    num_expected_easy = num_expected - len(sampled_hard_inds)
    if len(easy_neg_inds) > num_expected_easy:
        sampled_easy_inds = random_choice(easy_neg_inds, num_expected_easy)
    else:
        sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int64)
    sampled_inds = np.concatenate((sampled_easy_inds, sampled_hard_inds))
    if len(sampled_inds) < num_expected:
        # top up from negatives that are neither easy nor hard / not drawn
        num_extra = num_expected - len(sampled_inds)
        extra_inds = np.array(list(neg_set - set(sampled_inds)),
                              dtype=np.int64)
        if len(extra_inds) > num_extra:
            extra_inds = random_choice(extra_inds, num_extra)
        sampled_inds = np.concatenate((sampled_inds, extra_inds))
    return torch.from_numpy(sampled_inds).long().to(assigned_gt_inds.device)


def bbox_sampling(assigned_gt_inds,
                  num_expected,
                  pos_fraction,
                  neg_pos_ub,
                  pos_balance_sampling=True,
                  max_overlaps=None,
                  neg_balance_thr=0,
                  neg_hard_fraction=0.5):
    """Sample positive and negative bboxes given assigned results.

    Up to ``int(num_expected * pos_fraction)`` positives are sampled first.
    The negative budget is the smaller of the remaining slots
    (``num_expected - num_sampled_pos``) and ``int(neg_pos_ub *
    num_sampled_pos)``, or ``int(neg_pos_ub)`` when no positive is sampled.

    Args:
        assigned_gt_inds (Tensor): Assigned gt indices for each bbox.
        num_expected (int): Expected total samples (pos and neg).
        pos_fraction (float): Positive sample fraction.
        neg_pos_ub (float): Negative/Positive upper bound.
        pos_balance_sampling(bool): Whether to sample positive samples around
            each gt bbox evenly.
        max_overlaps (Tensor, optional): For each bbox, the max IoU of all gts.
            Used for negative balance sampling only.
        neg_balance_thr (float, optional): IoU threshold for simple/hard
            negative balance sampling.
        neg_hard_fraction (float, optional): Fraction of hard negative samples
            for negative balance sampling.

    Returns:
        tuple[Tensor]: positive bbox indices, negative bbox indices.
    """
    num_expected_pos = int(num_expected * pos_fraction)
    pos_inds = bbox_sampling_pos(assigned_gt_inds, num_expected_pos,
                                 pos_balance_sampling)
    # Sampled indices have been observed to contain duplicates occasionally;
    # deduplicate before counting so the negative budget is not skewed.
    pos_inds = pos_inds.unique()

    num_sampled_pos = pos_inds.numel()
    if num_sampled_pos > 0:
        num_neg_max = int(neg_pos_ub * num_sampled_pos)
    else:
        num_neg_max = int(neg_pos_ub)
    num_expected_neg = min(num_neg_max, num_expected - num_sampled_pos)

    neg_inds = bbox_sampling_neg(assigned_gt_inds, num_expected_neg,
                                 max_overlaps, neg_balance_thr,
                                 neg_hard_fraction)
    neg_inds = neg_inds.unique()
    return pos_inds, neg_inds


def sample_bboxes(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
    """Sample positive and negative bboxes.

    This is a simple implementation of bbox sampling given candidates and
    ground truth bboxes, which includes 3 steps.

    1. Assign gt to each bbox.
    2. Add gt bboxes to the sampling pool (optional).
    3. Perform positive and negative sampling.

    Args:
        bboxes (Tensor): Boxes to be sampled from; only the first 4 columns
            are used.
        gt_bboxes (Tensor): Ground truth bboxes.
        gt_bboxes_ignore (Tensor): Ignored ground truth bboxes. In MS COCO,
            `crowd` bboxes are considered as ignored.
        gt_labels (Tensor): Class labels of ground truth bboxes.
        cfg (dict): Sampling configs; must support attribute access
            (``cfg.pos_iou_thr``, ``cfg.add_gt_as_proposals``, ...).

    Returns:
        tuple[Tensor]: pos_bboxes, neg_bboxes, pos_assigned_gt_inds,
            pos_gt_bboxes, pos_gt_labels
    """
    bboxes = bboxes[:, :4]
    assigned_gt_inds, assigned_labels, _, max_overlaps = \
        bbox_assign(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels,
                    cfg.pos_iou_thr, cfg.neg_iou_thr, cfg.min_pos_iou,
                    cfg.crowd_thr)

    if cfg.add_gt_as_proposals:
        num_gts = len(gt_labels)
        bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
        gt_assign_self = torch.arange(
            1, num_gts + 1, dtype=torch.long, device=bboxes.device)
        assigned_gt_inds = torch.cat([gt_assign_self, assigned_gt_inds])
        assigned_labels = torch.cat([gt_labels, assigned_labels])
        # Keep max_overlaps aligned with the extended bbox list (each gt
        # overlaps itself with IoU 1); otherwise negative balance sampling
        # would read IoUs shifted by num_gts positions.
        max_overlaps = torch.cat([max_overlaps.new_ones(num_gts), max_overlaps])

    pos_inds, neg_inds = bbox_sampling(
        assigned_gt_inds, cfg.roi_batch_size, cfg.pos_fraction, cfg.neg_pos_ub,
        cfg.pos_balance_sampling, max_overlaps, cfg.neg_balance_thr)

    pos_bboxes = bboxes[pos_inds]
    neg_bboxes = bboxes[neg_inds]
    pos_assigned_gt_inds = assigned_gt_inds[pos_inds] - 1
    pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds, :]
    pos_gt_labels = assigned_labels[pos_inds]

    return (pos_bboxes, neg_bboxes, pos_assigned_gt_inds, pos_gt_bboxes,
            pos_gt_labels)