from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import torchvision
from torch import nn, Tensor
from torchvision.ops import boxes as box_ops, roi_align

from . import _utils as det_utils


def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.

    Args:
        class_logits (Tensor)
        box_regression (Tensor)
        labels (list[Tensor])
        regression_targets (list[Tensor])

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

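    # smooth L1 loss on the box deltas of the positive samples only, taken from the
    # regression output slice that corresponds to each sample's ground-truth class;
    # the summed loss is normalized below by the total number of sampled proposals
    # (labels.numel()), not just the positives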
    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    box_loss = box_loss / labels.numel()

    return classification_loss, box_loss


def maskrcnn_inference(x, labels):
    # type: (Tensor, List[Tensor]) -> List[Tensor]
    """
    From the results of the CNN, post process the masks
    by taking, for each box, the mask corresponding to its predicted
    label (the masks are of fixed size and directly output
    by the CNN) and return one tensor of mask probabilities per image.

    Args:
        x (Tensor): the mask logits
        labels (list[Tensor]): the predicted labels, one tensor per image,
            used to select the mask channel for each box

    Returns:
        results (list[Tensor]): one tensor per image with the mask
            probabilities for the predicted class of each box
    """
    mask_prob = x.sigmoid()

    # select masks corresponding to the predicted classes
    num_masks = x.shape[0]
    boxes_per_image = [label.shape[0] for label in labels]
    labels = torch.cat(labels)
    index = torch.arange(num_masks, device=labels.device)
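    # advanced indexing: row i picks the i-th box, labels[i] picks the channel of its
    # predicted class, and [:, None] restores a singleton channel dimension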
    mask_prob = mask_prob[index, labels][:, None]
    mask_prob = mask_prob.split(boxes_per_image, dim=0)

    return mask_prob


def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
    # type: (Tensor, Tensor, Tensor, int) -> Tensor
    """
    Given segmentation masks and the bounding boxes corresponding
    to the location of the masks in the image, this function
    crops and resizes the masks in the position defined by the
    boxes. This prepares the masks for them to be fed to the
    loss computation as the targets.
    """
    matched_idxs = matched_idxs.to(boxes)
    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
    gt_masks = gt_masks[:, None].to(rois)
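    # roi_align expects rois as (batch_index, x1, y1, x2, y2); here the "batch" index is
    # the matched ground-truth index, so each box is cropped from its own matched mask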
    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]


def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """
    Args:
        mask_logits (Tensor)
        proposals (list[Tensor])
        gt_masks (list[Tensor])
        gt_labels (list[Tensor])
        mask_matched_idxs (list[Tensor])

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """

    discretization_size = mask_logits.shape[-1]
    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
    mask_targets = [
        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
    ]

    labels = torch.cat(labels, dim=0)
    mask_targets = torch.cat(mask_targets, dim=0)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if mask_targets.numel() == 0:
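        # multiplying the logits by 0 returns a zero loss that is still connected to the
        # autograd graph, so backward() keeps working for this step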
        return mask_logits.sum() * 0

    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
    )
    return mask_loss


def keypoints_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])

    offset_x = offset_x[:, None]
    offset_y = offset_y[:, None]
    scale_x = scale_x[:, None]
    scale_y = scale_y[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]

    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = (x - offset_x) * scale_x
    x = x.floor().long()
    y = (y - offset_y) * scale_y
    y = y.floor().long()

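    # keypoints lying exactly on the right/bottom edge of a roi would otherwise be scaled
    # to coordinate heatmap_size, so clamp them onto the last heatmap cell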
    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    lin_ind = y * heatmap_size + x
    heatmaps = lin_ind * valid

    return heatmaps, valid


def _onnx_heatmaps_to_keypoints(
    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)

    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height

    roi_map = F.interpolate(
        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
    )[:, 0]

    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

    x_int = pos % w
    y_int = (pos - x_int) // w

    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
        dtype=torch.float32
    )
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
        dtype=torch.float32
    )

    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
    xy_preds_i = torch.stack(
        [
            xy_preds_i_0.to(dtype=torch.float32),
            xy_preds_i_1.to(dtype=torch.float32),
            xy_preds_i_2.to(dtype=torch.float32),
        ],
        0,
    )

    # TODO: simplify when indexing without rank will be supported by ONNX
    base = num_keypoints * num_keypoints + num_keypoints + 1
    ind = torch.arange(num_keypoints)
    ind = ind.to(dtype=torch.int64) * base
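    # after the two index_selects below the tensor has shape (K, K, K); element k of ind
    # (k * (K*K + K + 1)) picks the flattened diagonal entry (k, y_int[k], x_int[k])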
    end_scores_i = (
        roi_map.index_select(1, y_int.to(dtype=torch.int64))
        .index_select(2, x_int.to(dtype=torch.int64))
        .view(-1)
        .index_select(0, ind.to(dtype=torch.int64))
    )

    return xy_preds_i, end_scores_i


@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(
    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)

    for i in range(int(rois.size(0))):
        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
        )
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat(
            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
        )
    return xy_preds, end_scores


def heatmaps_to_keypoints(maps, rois):
    """Extract predicted keypoint locations from heatmaps.

    Returns a tensor of shape (#rois, #keypoints, 3) holding (x, y, visibility) for each
    keypoint, together with a tensor of shape (#rois, #keypoints) with the heatmap score
    at each predicted location.
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap_labels by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    widths = rois[:, 2] - rois[:, 0]
    heights = rois[:, 3] - rois[:, 1]
    widths = widths.clamp(min=1)
    heights = heights.clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_keypoints = maps.shape[1]

    if torchvision._is_tracing():
        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
            maps,
            rois,
            widths_ceil,
            heights_ceil,
            widths,
            heights,
            offset_x,
            offset_y,
            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
        )
        return xy_preds.permute(0, 2, 1), end_scores

    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
    for i in range(len(rois)):
        roi_map_width = int(widths_ceil[i].item())
        roi_map_height = int(heights_ceil[i].item())
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
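        # upsample the fixed-size heatmap to the (ceil'd) roi size so that the argmax below
        # can be mapped back to image coordinates at roughly pixel resolution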
        roi_map = F.interpolate(
            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
        )[:, 0]
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = roi_map.shape[2]
        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

        x_int = pos % w
        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
        # assert (roi_map_probs[k, y_int, x_int] ==
        #         roi_map_probs[k, :, :].max())
        x = (x_int.float() + 0.5) * width_correction
        y = (y_int.float() + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]

    return xy_preds.permute(0, 2, 1), end_scores


def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    N, K, H, W = keypoint_logits.shape
    if H != W:
        raise ValueError(
            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    heatmaps = []
    valid = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
        kp = gt_kp_in_image[midx]
        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    keypoint_targets = torch.cat(heatmaps, dim=0)
    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
    valid = torch.where(valid)[0]

    # torch.mean (in cross_entropy) doesn't
    # accept empty tensors, so handle it separately
    if keypoint_targets.numel() == 0 or len(valid) == 0:
        return keypoint_logits.sum() * 0

    keypoint_logits = keypoint_logits.view(N * K, H * W)
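    # each keypoint is a classification over the H*W heatmap cells; only keypoints that are
    # visible and fall inside their proposal (the valid indices) contribute to the loss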

    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
    return keypoint_loss


def keypointrcnn_inference(x, boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    kp_probs = []
    kp_scores = []

    boxes_per_image = [box.size(0) for box in boxes]
    x2 = x.split(boxes_per_image, dim=0)

    for xx, bb in zip(x2, boxes):
        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
        kp_probs.append(kp_prob)
        kp_scores.append(scores)

    return kp_probs, kp_scores


def _onnx_expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half = w_half.to(dtype=torch.float32) * scale
    h_half = h_half.to(dtype=torch.float32) * scale

    boxes_exp0 = x_c - w_half
    boxes_exp1 = y_c - h_half
    boxes_exp2 = x_c + w_half
    boxes_exp3 = y_c + h_half
    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
    return boxes_exp


# the next two functions should be merged inside Masker
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    if torchvision._is_tracing():
        return _onnx_expand_boxes(boxes, scale)
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half *= scale
    h_half *= scale

    boxes_exp = torch.zeros_like(boxes)
    boxes_exp[:, 0] = x_c - w_half
    boxes_exp[:, 2] = x_c + w_half
    boxes_exp[:, 1] = y_c - h_half
    boxes_exp[:, 3] = y_c + h_half
    return boxes_exp


@torch.jit.unused
def expand_masks_tracing_scale(M, padding):
    # type: (int, int) -> float
    return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)


def expand_masks(mask, padding):
    # type: (Tensor, int) -> Tuple[Tensor, float]
    M = mask.shape[-1]
    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
        scale = expand_masks_tracing_scale(M, padding)
    else:
        scale = float(M + 2 * padding) / M
    padded_mask = F.pad(mask, (padding,) * 4)
    return padded_mask, scale


def paste_mask_in_image(mask, box, im_h, im_w):
    # type: (Tensor, Tensor, int, int) -> Tensor
    TO_REMOVE = 1
    w = int(box[2] - box[0] + TO_REMOVE)
    h = int(box[3] - box[1] + TO_REMOVE)
    w = max(w, 1)
    h = max(h, 1)

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, -1, -1))

    # Resize mask
    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

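    # paste the resized mask into the full-size image canvas, clipping the box to the
    # image bounds and offsetting into the mask accordingly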
    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
    return im_mask


def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    w = box[2] - box[0] + one
    h = box[3] - box[1] + one
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize mask
    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]

    # TODO : replace below with a dynamic padding when support is added in ONNX

    # pad y
    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
    # pad x
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
    return im_mask


@torch.jit._script_if_tracing
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    res_append = torch.zeros(0, im_h, im_w)
    for i in range(masks.size(0)):
        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
        mask_res = mask_res.unsqueeze(0)
        res_append = torch.cat((res_append, mask_res))
    return res_append


def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
    masks, scale = expand_masks(masks, padding=padding)
    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
    im_h, im_w = img_shape

    if torchvision._is_tracing():
        return _onnx_paste_masks_in_image_loop(
            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
        )[:, None]
    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
    if len(res) > 0:
        ret = torch.stack(res, dim=0)[:, None]
    else:
        ret = masks.new_empty((0, 1, im_h, im_w))
    return ret


class RoIHeads(nn.Module):
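    # TorchScript cannot infer the types of these attributes (they are custom classes),
    # so they are declared explicitly for scripting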
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
        # Mask
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
    ):
        super().__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        if bbox_reg_weights is None:
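            # weights applied to the (dx, dy, dw, dh) regression targets in the box coder;
            # (10, 10, 5, 5) are the values commonly used for the second-stage box head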
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

        self.mask_roi_pool = mask_roi_pool
        self.mask_head = mask_head
        self.mask_predictor = mask_predictor

        self.keypoint_roi_pool = keypoint_roi_pool
        self.keypoint_head = keypoint_head
        self.keypoint_predictor = keypoint_predictor

    def has_mask(self):
        if self.mask_roi_pool is None:
            return False
        if self.mask_head is None:
            return False
        if self.mask_predictor is None:
            return False
        return True

    def has_keypoint(self):
        if self.keypoint_roi_pool is None:
            return False
        if self.keypoint_head is None:
            return False
        if self.keypoint_predictor is None:
            return False
        return True

    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):

            if gt_boxes_in_image.numel() == 0:
                # Background image
                device = proposals_in_image.device
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
                labels_in_image = labels_in_image.to(dtype=torch.int64)

                # Label background (below the low threshold)
                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_in_image[bg_inds] = 0

                # Label ignore proposals (between low and high thresholds)
                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels

    def subsample(self, labels):
        # type: (List[Tensor]) -> List[Tensor]
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_inds = []
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
            sampled_inds.append(img_sampled_inds)
        return sampled_inds

    def add_gt_proposals(self, proposals, gt_boxes):
        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]

        return proposals

    def check_targets(self, targets):
        # type: (Optional[List[Dict[str, Tensor]]]) -> None
        if targets is None:
            raise ValueError("targets should not be None")
        if not all(["boxes" in t for t in targets]):
            raise ValueError("Every element of targets should have a boxes key")
        if not all(["labels" in t for t in targets]):
            raise ValueError("Every element of targets should have a labels key")
        if self.has_mask():
            if not all(["masks" in t for t in targets]):
                raise ValueError("Every element of targets should have a masks key")

    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        self.check_targets(targets)
        if targets is None:
            raise ValueError("targets should not be None")
        dtype = proposals[0].dtype
        device = proposals[0].device

        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["labels"] for t in targets]

        # append ground-truth bboxes to proposals
        proposals = self.add_gt_proposals(proposals, gt_boxes)
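        # (adding the ground-truth boxes guarantees proposals with high IoU against the
        # targets, so positive training samples exist even when the RPN is still weak)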

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        num_images = len(proposals)
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            gt_boxes_in_image = gt_boxes[img_id]
            if gt_boxes_in_image.numel() == 0:
                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        return proposals, matched_idxs, labels, regression_targets

    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        pred_scores = F.softmax(class_logits, -1)

        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
        pred_scores_list = pred_scores.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device)
            labels = labels.view(1, -1).expand_as(scores)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores[:, 1:]
            labels = labels[:, 1:]

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)

            # remove low scoring boxes
            inds = torch.where(scores > self.score_thresh)[0]
            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)

        return all_boxes, all_scores, all_labels

    def forward(
        self,
        features,  # type: Dict[str, Tensor]
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Args:
            features (Dict[str, Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
        """
        if targets is not None:
            for t in targets:
                # TODO: https://github.com/pytorch/pytorch/issues/26731
                floating_point_types = (torch.float, torch.double, torch.half)
                if t["boxes"].dtype not in floating_point_types:
                    raise TypeError(f"target boxes must be of float type, instead got {t['boxes'].dtype}")
                if t["labels"].dtype != torch.int64:
                    raise TypeError(f"target labels must be of int64 type, instead got {t['labels'].dtype}")
                if self.has_keypoint():
                    if t["keypoints"].dtype != torch.float32:
                        raise TypeError(f"target keypoints must be of float type, instead got {t['keypoints'].dtype}")

        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
        else:
            labels = None
            regression_targets = None
            matched_idxs = None

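        # box branch: pool per-proposal features, run the box head, then predict per-class
        # scores and box regression deltas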
        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        if self.training:
            if labels is None:
                raise ValueError("labels cannot be None")
            if regression_targets is None:
                raise ValueError("regression_targets cannot be None")
            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
        else:
            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )

        if self.has_mask():
            mask_proposals = [p["boxes"] for p in result]
            if self.training:
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")

                # during training, only focus on positive boxes
                num_images = len(proposals)
                mask_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    mask_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            if self.mask_roi_pool is not None:
                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
                mask_features = self.mask_head(mask_features)
                mask_logits = self.mask_predictor(mask_features)
            else:
                raise Exception("Expected mask_roi_pool to be not None")

            loss_mask = {}
            if self.training:
                if targets is None or pos_matched_idxs is None or mask_logits is None:
                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")

                gt_masks = [t["masks"] for t in targets]
                gt_labels = [t["labels"] for t in targets]
                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
                loss_mask = {"loss_mask": rcnn_loss_mask}
            else:
                labels = [r["labels"] for r in result]
                masks_probs = maskrcnn_inference(mask_logits, labels)
                for mask_prob, r in zip(masks_probs, result):
                    r["masks"] = mask_prob

            losses.update(loss_mask)

        # keep none checks in if conditional so torchscript will conditionally
        # compile each branch
        if (
            self.keypoint_roi_pool is not None
            and self.keypoint_head is not None
            and self.keypoint_predictor is not None
        ):
            keypoint_proposals = [p["boxes"] for p in result]
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                keypoint_proposals = []
                pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")

                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    keypoint_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
            keypoint_features = self.keypoint_head(keypoint_features)
            keypoint_logits = self.keypoint_predictor(keypoint_features)

            loss_keypoint = {}
            if self.training:
                if targets is None or pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")

                gt_keypoints = [t["keypoints"] for t in targets]
                rcnn_loss_keypoint = keypointrcnn_loss(
                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
                )
                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
            else:
                if keypoint_logits is None or keypoint_proposals is None:
                    raise ValueError(
                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                    )

                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                    r["keypoints"] = keypoint_prob
                    r["keypoints_scores"] = kps
            losses.update(loss_keypoint)

        return result, losses