roi_heads.py 33.1 KB
Newer Older
1
from typing import Optional, List, Dict, Tuple
2

3
import torch
4
import torch.nn.functional as F
5
import torchvision
eellison's avatar
eellison committed
6
from torch import nn, Tensor
7
8
9
10
11
12
13
from torchvision.ops import boxes as box_ops
from torchvision.ops import roi_align

from . import _utils as det_utils


def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.

    Args:
        class_logits (Tensor): classification logits, one row per sampled proposal
        box_regression (Tensor): per-class box deltas, one row per sampled
            proposal; reshaped below to (N, num_box_classes, 4)
        labels (list[Tensor]): per-image class labels of the sampled proposals,
            concatenated along dim 0; labels > 0 are treated as foreground
        regression_targets (list[Tensor]): per-image box regression targets,
            concatenated along dim 0

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # Only foreground proposals contribute to the box loss. For each of them,
    # pick the 4 regression outputs of its ground-truth class via advanced indexing.
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N = class_logits.shape[0]
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    # Normalize by the number of all sampled proposals (foreground and
    # background), not only by the number of positives.
    box_loss = box_loss / labels.numel()

    return classification_loss, box_loss


def maskrcnn_inference(x, labels):
    # type: (Tensor, List[Tensor]) -> List[Tensor]
    """
    From the results of the CNN, post process the masks
    by taking the mask corresponding to the predicted class of
    each box (masks are of fixed size and directly output by the CNN)
    and return them grouped per image.

    Args:
        x (Tensor): the mask logits, indexed as x[mask_idx, class_idx]
        labels (list[Tensor]): predicted class label of each box, one
            tensor per image

    Returns:
        results (list[Tensor]): one tensor per image with the sigmoid
            probabilities of the selected class mask of each of its boxes,
            with a singleton dimension inserted at dim 1
    """
    mask_prob = x.sigmoid()

    # select masks corresponding to the predicted classes
    num_masks = x.shape[0]
    boxes_per_image = [label.shape[0] for label in labels]
    labels = torch.cat(labels)
    index = torch.arange(num_masks, device=labels.device)
    mask_prob = mask_prob[index, labels][:, None]
    # regroup the flat batch of masks per image
    mask_prob = mask_prob.split(boxes_per_image, dim=0)

    return mask_prob
81
82
83


def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
    # type: (Tensor, Tensor, Tensor, int) -> Tensor
    """
    Given segmentation masks and the bounding boxes corresponding
    to the location of the masks in the image, this function
    crops and resizes the masks in the position defined by the
    boxes. This prepares the masks for them to be fed to the
    loss computation as the targets.

    Args:
        gt_masks (Tensor): ground-truth masks of one image, stacked along dim 0
        boxes (Tensor): proposal boxes (x1, y1, x2, y2), one row per proposal
        matched_idxs (Tensor): for each proposal, the index into gt_masks of
            the mask it is cropped from
        M (int): side length of the square output crops

    Returns:
        Tensor: one M x M crop per proposal
    """
    # Each roi row is (matched mask index, box); the leading column selects
    # which gt mask (along dim 0 of gt_masks) the box samples from.
    matched_idxs = matched_idxs.to(boxes)
    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
    gt_masks = gt_masks[:, None].to(rois)
    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
96
97


98
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """
    Computes the mask loss for Mask R-CNN.

    Args:
        mask_logits (Tensor): predicted mask logits, indexed as
            mask_logits[proposal_idx, class_idx]; its last dimension gives
            the resolution the target masks are resampled to
        proposals (list[Tensor]): per-image proposal boxes
        gt_masks (list[Tensor]): per-image ground-truth masks
        gt_labels (list[Tensor]): per-image ground-truth labels
        mask_matched_idxs (list[Tensor]): per-image index of the ground-truth
            instance matched to each proposal

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """

    discretization_size = mask_logits.shape[-1]
    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
    mask_targets = [
        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
    ]

    labels = torch.cat(labels, dim=0)
    mask_targets = torch.cat(mask_targets, dim=0)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately: return a zero
    # that is still computed from mask_logits.
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    # For each proposal, use only the mask logits of its matched gt class.
    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
    )
    return mask_loss


def keypoints_to_heatmap(keypoints, rois, heatmap_size):
    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
    """
    Maps keypoint coordinates to flat cell indices of a
    heatmap_size x heatmap_size grid laid over each roi.

    Args:
        keypoints (Tensor): keypoints per roi; the last dimension holds
            (x, y, visibility), visibility > 0 meaning visible
        rois (Tensor): boxes (x1, y1, x2, y2), one row per roi
        heatmap_size (int): side length of the square heatmap

    Returns:
        heatmaps (Tensor): int64 index y * heatmap_size + x of each keypoint,
            set to 0 where the keypoint is not valid
        valid (Tensor): int64 flag, 1 where the keypoint is visible and
            falls inside the heatmap
    """
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]
    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])

    # add a trailing dim so per-roi values broadcast over the keypoints
    offset_x = offset_x[:, None]
    offset_y = offset_y[:, None]
    scale_x = scale_x[:, None]
    scale_y = scale_y[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]

    # Keypoints exactly on the right/bottom edge would floor to heatmap_size
    # (out of range); remember them so they can be put in the last cell.
    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = (x - offset_x) * scale_x
    x = x.floor().long()
    y = (y - offset_y) * scale_y
    y = y.floor().long()

    x[x_boundary_inds] = heatmap_size - 1
    y[y_boundary_inds] = heatmap_size - 1

    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    lin_ind = y * heatmap_size + x
    heatmaps = lin_ind * valid

    return heatmaps, valid


166
167
168
def _onnx_heatmaps_to_keypoints(
    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
    """Tracing/ONNX variant of one iteration of the loop in heatmaps_to_keypoints.

    Resizes one roi's keypoint heatmaps to the roi size, takes the arg-max
    location of every keypoint and maps it back to image coordinates.

    Args:
        maps (Tensor): heatmaps of all rois; only maps.size(1) (the number of
            keypoints) is read from it
        maps_i (Tensor): heatmaps of this roi, shape (#keypoints, H, W)
        roi_map_width, roi_map_height (Tensor): target size of the resized
            heatmaps (converted with int())
        widths_i, heights_i (Tensor): un-rounded roi width and height
        offset_x_i, offset_y_i (Tensor): top-left corner of the roi

    Returns:
        xy_preds_i (Tensor): shape (3, #keypoints), rows (x, y, 1)
        end_scores_i (Tensor): shape (#keypoints,), the resized heatmap value
            at each predicted location

    NOTE: the explicit float32 casts, scalar tensors and index_select calls
    appear to be deliberate for ONNX export (see the TODO below); keep the
    op sequence unchanged.
    """
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)

    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height

    roi_map = F.interpolate(
        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
    )[:, 0]

    # arg-max over each flattened map, then split into (column, row)
    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

    x_int = pos % w
    y_int = (pos - x_int) // w

    # cell centers (d + 0.5) scaled back from the resized map to the roi size
    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
        dtype=torch.float32
    )
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
        dtype=torch.float32
    )

    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
    xy_preds_i = torch.stack(
        [
            xy_preds_i_0.to(dtype=torch.float32),
            xy_preds_i_1.to(dtype=torch.float32),
            xy_preds_i_2.to(dtype=torch.float32),
        ],
        0,
    )

    # TODO: simplify when indexing without rank will be supported by ONNX
    # Gathers roi_map[k, y_int[k], x_int[k]] without advanced indexing: the two
    # index_selects build a (K, K, K) tensor, and flattened position
    # k * (K*K + K + 1) is exactly its element [k, k, k].
    base = num_keypoints * num_keypoints + num_keypoints + 1
    ind = torch.arange(num_keypoints)
    ind = ind.to(dtype=torch.int64) * base
    end_scores_i = (
        roi_map.index_select(1, y_int.to(dtype=torch.int64))
        .index_select(2, x_int.to(dtype=torch.int64))
        .view(-1)
        .index_select(0, ind.to(dtype=torch.int64))
    )

    return xy_preds_i, end_scores_i


217
@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(
    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
    """Runs _onnx_heatmaps_to_keypoints for every roi and stacks the results.

    Used by heatmaps_to_keypoints when tracing. The decorator scripts this
    function while tracing, presumably so the loop over rois is kept dynamic
    instead of being unrolled for the traced roi count.

    Returns:
        xy_preds (Tensor): shape (#rois, 3, #keypoints)
        end_scores (Tensor): shape (#rois, #keypoints)
    """
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)

    for i in range(int(rois.size(0))):
        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
        )
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat(
            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
        )
    return xy_preds, end_scores


235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
def heatmaps_to_keypoints(maps, rois):
    """Extract predicted keypoint locations from heatmaps.

    Args:
        maps (Tensor): keypoint heatmaps, shape (#rois, #keypoints, H, W)
        rois (Tensor): boxes (x1, y1, x2, y2) the heatmaps belong to,
            one row per roi

    Returns:
        xy_preds (Tensor): shape (#rois, #keypoints, 3) holding (x, y, 1) in
            image coordinates for every keypoint
        end_scores (Tensor): shape (#rois, #keypoints), the value of the heatmap
            (resized to the roi) at each predicted location
    """
    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
    # consistency with keypoints_to_heatmap by using the conversion from
    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
    # continuous coordinate.
    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    widths = rois[:, 2] - rois[:, 0]
    heights = rois[:, 3] - rois[:, 1]
    widths = widths.clamp(min=1)
    heights = heights.clamp(min=1)
    widths_ceil = widths.ceil()
    heights_ceil = heights.ceil()

    num_keypoints = maps.shape[1]

    if torchvision._is_tracing():
        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
            maps,
            rois,
            widths_ceil,
            heights_ceil,
            widths,
            heights,
            offset_x,
            offset_y,
            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
        )
        return xy_preds.permute(0, 2, 1), end_scores

    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
    for i in range(len(rois)):
        # resize this roi's heatmaps to the (rounded-up) roi size
        roi_map_width = int(widths_ceil[i].item())
        roi_map_height = int(heights_ceil[i].item())
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
        roi_map = F.interpolate(
            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
        )[:, 0]

        # arg-max over each flattened map, then split into (column, row)
        w = roi_map.shape[2]
        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

        x_int = pos % w
        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
        x = (x_int.float() + 0.5) * width_correction
        y = (y_int.float() + 0.5) * height_correction
        xy_preds[i, 0, :] = x + offset_x[i]
        xy_preds[i, 1, :] = y + offset_y[i]
        xy_preds[i, 2, :] = 1
        end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]

    return xy_preds.permute(0, 2, 1), end_scores


299
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
    # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
    """
    Computes the keypoint loss for Keypoint R-CNN.

    Each valid ground-truth keypoint becomes one target cell of the H x W
    heatmap; the loss is the cross-entropy over the H * W cells of the
    predicted logits, restricted to valid keypoints.

    Args:
        keypoint_logits (Tensor): shape (N, K, H, W), H must equal W
        proposals (list[Tensor]): per-image proposal boxes
        gt_keypoints (list[Tensor]): per-image ground-truth keypoints
        keypoint_matched_idxs (list[Tensor]): per-image index of the
            ground-truth instance matched to each proposal

    Returns:
        Tensor: scalar loss

    Raises:
        ValueError: if the heatmap is not square (H != W)
    """
    N, K, H, W = keypoint_logits.shape
    if H != W:
        raise ValueError(
            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
        )
    discretization_size = H
    heatmaps = []
    valid = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
        kp = gt_kp_in_image[midx]
        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    keypoint_targets = torch.cat(heatmaps, dim=0)
    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
    valid = torch.where(valid)[0]

    # The mean-reduced cross_entropy below is not meaningful over an empty
    # selection, so return a zero that is still computed from keypoint_logits.
    if keypoint_targets.numel() == 0 or len(valid) == 0:
        return keypoint_logits.sum() * 0

    # one row of H * W logits per (proposal, keypoint) pair
    keypoint_logits = keypoint_logits.view(N * K, H * W)

    keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
    return keypoint_loss


def keypointrcnn_inference(x, boxes):
    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
    """
    Converts batched keypoint heatmaps into per-image keypoint predictions.

    Args:
        x (Tensor): heatmaps of all boxes in the batch, concatenated along dim 0
        boxes (list[Tensor]): per-image boxes; their counts split x per image

    Returns:
        kp_probs (list[Tensor]): per image, keypoint locations from
            heatmaps_to_keypoints
        kp_scores (list[Tensor]): per image, the corresponding keypoint scores
    """
    kp_probs = []
    kp_scores = []

    boxes_per_image = [box.size(0) for box in boxes]
    x2 = x.split(boxes_per_image, dim=0)

    for xx, bb in zip(x2, boxes):
        kp_prob, scores = heatmaps_to_keypoints(xx, bb)
        kp_probs.append(kp_prob)
        kp_scores.append(scores)

    return kp_probs, kp_scores


346
def _onnx_expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    """Tracing/ONNX variant of expand_boxes.

    Scales every (x1, y1, x2, y2) box about its center by ``scale``. It builds
    the result with torch.stack rather than in-place slice assignment.
    """
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half = w_half.to(dtype=torch.float32) * scale
    h_half = h_half.to(dtype=torch.float32) * scale

    boxes_exp0 = x_c - w_half
    boxes_exp1 = y_c - h_half
    boxes_exp2 = x_c + w_half
    boxes_exp3 = y_c + h_half
    boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
    return boxes_exp


364
365
# expand_boxes and expand_masks should be merged inside Masker,
# but are kept here for the moment while we need them
# temporarily for pasting masks into images.
def expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    """
    Scales every (x1, y1, x2, y2) box about its center by ``scale``.

    While tracing, the stack-based _onnx_expand_boxes is used instead.
    """
    if torchvision._is_tracing():
        return _onnx_expand_boxes(boxes, scale)
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half *= scale
    h_half *= scale

    boxes_exp = torch.zeros_like(boxes)
    boxes_exp[:, 0] = x_c - w_half
    boxes_exp[:, 2] = x_c + w_half
    boxes_exp[:, 1] = y_c - h_half
    boxes_exp[:, 3] = y_c + h_half
    return boxes_exp


eellison's avatar
eellison committed
387
388
389
390
391
392
@torch.jit.unused
def expand_masks_tracing_scale(M, padding):
    # type: (int, int) -> float
    # Tracing-only: returns the (M + 2 * padding) / M ratio as a float32 tensor.
    padded_side = torch.tensor(M + 2 * padding).to(torch.float32)
    original_side = torch.tensor(M).to(torch.float32)
    return padded_side / original_side


393
def expand_masks(mask, padding):
    # type: (Tensor, int) -> Tuple[Tensor, float]
    """
    Zero-pads the last two dimensions of ``mask`` by ``padding`` on every side.

    Returns:
        padded_mask (Tensor)
        scale (float): (M + 2 * padding) / M with M = mask.shape[-1]; while
            tracing it comes from expand_masks_tracing_scale instead
    """
    M = mask.shape[-1]
    if torch._C._get_tracing_state():  # could not import is_tracing(), not sure why
        scale = expand_masks_tracing_scale(M, padding)
    else:
        scale = float(M + 2 * padding) / M
    padded_mask = F.pad(mask, (padding,) * 4)
    return padded_mask, scale


def paste_mask_in_image(mask, box, im_h, im_w):
    # type: (Tensor, Tensor, int, int) -> Tensor
    """
    Resizes ``mask`` to ``box`` and pastes it into an im_h x im_w canvas.

    Args:
        mask (Tensor): 2D mask
        box (Tensor): integer box (x1, y1, x2, y2), inclusive of x2/y2
        im_h, im_w (int): canvas size

    Returns:
        Tensor: (im_h, im_w) canvas, zero outside the visible part of the box
    """
    TO_REMOVE = 1
    w = int(box[2] - box[0] + TO_REMOVE)
    h = int(box[3] - box[1] + TO_REMOVE)
    w = max(w, 1)
    h = max(h, 1)

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, -1, -1))

    # Resize mask
    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
    x_0 = max(box[0], 0)
    x_1 = min(box[2] + 1, im_w)
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    # A box entirely outside the canvas has no visible region; slicing would
    # otherwise produce mismatched (or negative-bound) shapes.
    if x_1 <= x_0 or y_1 <= y_0:
        return im_mask

    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
    return im_mask


429
430
431
432
def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    """Tracing/ONNX variant of paste_mask_in_image.

    Uses tensor max/min over concatenations instead of Python max/min, and
    explicit zero-padding instead of in-place slice assignment.

    Args:
        mask (Tensor): 2D mask
        box (Tensor): integer box (x1, y1, x2, y2)
        im_h, im_w (Tensor): canvas size as 0-dim integer tensors (as passed
            by paste_masks_in_image)

    Returns:
        Tensor: float32 (im_h, im_w) canvas; the zero padding is created
        without a device argument (default device)
    """
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    w = box[2] - box[0] + one
    h = box[3] - box[1] + one
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize mask
    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    # part of the resized mask that lies inside the canvas
    unpadded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]

    # TODO : replace below with a dynamic padding when support is added in ONNX

    # pad y
    zeros_y0 = torch.zeros(y_0, unpadded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpadded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0, unpadded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
    # pad x
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
    return im_mask


465
@torch.jit._script_if_tracing
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    """Runs _onnx_paste_mask_in_image for every mask and stacks the results.

    Used by paste_masks_in_image while tracing; returns a tensor of shape
    (#masks, im_h, im_w). ``masks[i][0]`` is the mask pasted for ``boxes[i]``.
    """
    res_append = torch.zeros(0, im_h, im_w)
    for i in range(masks.size(0)):
        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
        mask_res = mask_res.unsqueeze(0)
        res_append = torch.cat((res_append, mask_res))
    return res_append


475
def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
    """
    Pastes every mask into a full-size canvas at the location of its box.

    The masks are zero-padded by ``padding`` and the boxes enlarged by the
    same relative amount so the padded masks still line up with the boxes.

    Args:
        masks (Tensor): one mask per box; element ``m[0]`` of each is pasted
        boxes (Tensor): boxes (x1, y1, x2, y2) in image coordinates
        img_shape (tuple[int, int]): (height, width) of the canvas
        padding (int): zero padding added around every mask

    Returns:
        Tensor: shape (#masks, 1, height, width)
    """
    masks, scale = expand_masks(masks, padding=padding)
    boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
    im_h, im_w = img_shape

    if torchvision._is_tracing():
        return _onnx_paste_masks_in_image_loop(
            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
        )[:, None]
    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
    if len(res) > 0:
        ret = torch.stack(res, dim=0)[:, None]
    else:
        # no masks: keep the (0, 1, H, W) output shape
        ret = masks.new_empty((0, 1, im_h, im_w))
    return ret
491
492


493
class RoIHeads(nn.Module):
eellison's avatar
eellison committed
494
    __annotations__ = {
495
496
497
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
eellison's avatar
eellison committed
498
499
    }

500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
        # Mask
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
    ):
        """
        Builds the RoI heads.

        The box branch (roi pool, head, predictor) is required. The mask and
        keypoint branches are optional and enabled only when all three of
        their components are given (see has_mask / has_keypoint).
        bbox_reg_weights defaults to (10.0, 10.0, 5.0, 5.0) when None.
        """
        super().__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        if bbox_reg_weights is None:
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

        self.mask_roi_pool = mask_roi_pool
        self.mask_head = mask_head
        self.mask_predictor = mask_predictor

        self.keypoint_roi_pool = keypoint_roi_pool
        self.keypoint_head = keypoint_head
        self.keypoint_predictor = keypoint_predictor

    def has_mask(self):
        """Return True only when the mask roi pool, head and predictor are all set."""
        return (
            self.mask_roi_pool is not None
            and self.mask_head is not None
            and self.mask_predictor is not None
        )

    def has_keypoint(self):
        """Return True only when the keypoint roi pool, head and predictor are all set."""
        return (
            self.keypoint_roi_pool is not None
            and self.keypoint_head is not None
            and self.keypoint_predictor is not None
        )

    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        """
        Matches every proposal to a ground-truth box and derives its label.

        Returns:
            matched_idxs (list[Tensor]): per image, index of the matched gt box
                for each proposal, clamped to >= 0 (so unmatched proposals point
                at gt box 0; images without gt boxes get all zeros)
            labels (list[Tensor]): per image, int64 label per proposal: the
                matched gt label, 0 for background (below the low threshold or
                no gt boxes), or -1 for proposals between the thresholds, which
                the sampler ignores
        """
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):

            if gt_boxes_in_image.numel() == 0:
                # Background image
                device = proposals_in_image.device
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                #  set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
                labels_in_image = labels_in_image.to(dtype=torch.int64)

                # Label background (below the low threshold)
                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_in_image[bg_inds] = 0

                # Label ignore proposals (between low and high thresholds)
                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels

    def subsample(self, labels):
        # type: (List[Tensor]) -> List[Tensor]
        """
        Samples positive and negative proposals with ``self.fg_bg_sampler``.

        Returns:
            list[Tensor]: per image, indices of the proposals selected as
            either positive or negative
        """
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_inds = []
        for pos_inds_img, neg_inds_img in zip(sampled_pos_inds, sampled_neg_inds):
            # the sampler yields per-image masks; keep proposals picked by either
            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
            sampled_inds.append(img_sampled_inds)
        return sampled_inds

    def add_gt_proposals(self, proposals, gt_boxes):
        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
        """Appends each image's ground-truth boxes to its proposals (along dim 0)."""
        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]

        return proposals

    def check_targets(self, targets):
        # type: (Optional[List[Dict[str, Tensor]]]) -> None
        """
        Validates the training targets.

        Raises:
            ValueError: if ``targets`` is None, if any target lacks a "boxes"
                or "labels" entry, or (when the mask branch is enabled) a
                "masks" entry.
        """
        if targets is None:
            raise ValueError("targets should not be None")
        if not all(["boxes" in t for t in targets]):
            raise ValueError("Every element of targets should have a boxes key")
        if not all(["labels" in t for t in targets]):
            raise ValueError("Every element of targets should have a labels key")
        if self.has_mask():
            if not all(["masks" in t for t in targets]):
                raise ValueError("Every element of targets should have a masks key")
630

631
632
633
634
635
    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        """
        Selects the proposals used to train the box head.

        Appends the gt boxes to the proposals, matches every proposal to a gt
        box, samples positives/negatives and encodes regression targets
        against the matched gt boxes.

        Returns:
            proposals, matched_idxs, labels, regression_targets: per-image
            lists restricted to the sampled proposals
        """
        self.check_targets(targets)
        # repeated check, presumably so TorchScript can refine the Optional type
        if targets is None:
            raise ValueError("targets should not be None")
        dtype = proposals[0].dtype
        device = proposals[0].device

        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["labels"] for t in targets]

        # append ground-truth bboxes to proposals
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        num_images = len(proposals)
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            gt_boxes_in_image = gt_boxes[img_id]
            if gt_boxes_in_image.numel() == 0:
                # images without gt: use one dummy zero box so that indexing
                # with the (all-zero) matched_idxs stays valid
                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        return proposals, matched_idxs, labels, regression_targets

669
670
671
672
673
674
675
    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
        """
        Turns box-head outputs into per-image detections.

        For each image: decode the per-class boxes, clip them to the image,
        drop the background class (index 0), treat every remaining
        (box, class) pair as a separate candidate, keep candidates scoring
        above ``score_thresh``, drop boxes smaller than 1e-2, run per-class
        NMS with ``nms_thresh`` and keep at most ``detections_per_img``.

        Returns:
            boxes, scores, labels (list[Tensor]): one entry per image
        """
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        pred_scores = F.softmax(class_logits, -1)

        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
        pred_scores_list = pred_scores.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device)
            labels = labels.view(1, -1).expand_as(scores)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores[:, 1:]
            labels = labels[:, 1:]

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)

            # remove low scoring boxes
            inds = torch.where(scores > self.score_thresh)[0]
            boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)

        return all_boxes, all_scores, all_labels

728
729
730
731
732
733
734
    def forward(
        self,
        features,  # type: Dict[str, Tensor]
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Run the box head and, when configured, the mask and keypoint heads.

        In training mode the heads only produce losses and ``result`` is
        returned empty. Otherwise ``result`` holds one detection dict per
        image and ``losses`` is returned empty.

        Args:
            features (Dict[str, Tensor]): feature maps handed to the RoI pooling modules
            proposals (List[Tensor[N, 4]]): proposal boxes, one tensor per image
            image_shapes (List[Tuple[H, W]]): size of each image
            targets (List[Dict], optional): per-image ground truth; validated when
                given and consumed by the training branches

        Returns:
            result (List[Dict[str, Tensor]]): empty when training. Otherwise, per image,
                "boxes", "labels" and "scores"; plus "masks" when ``has_mask()``;
                plus "keypoints" and "keypoints_scores" when the keypoint
                modules are set.
            losses (Dict[str, Tensor]): empty when not training. Otherwise
                "loss_classifier" and "loss_box_reg"; plus "loss_mask" when
                ``has_mask()``; plus "loss_keypoint" when the keypoint modules
                are set.

        Raises:
            TypeError: if a target's boxes are not floating point, its labels
                are not int64, or (with keypoints enabled) its keypoints are
                not float32.
            ValueError: if intermediate values required in training are None.
        """
        if targets is not None:
            # TODO: https://github.com/pytorch/pytorch/issues/26731
            # Kept as a local (not module-level) constant; it does not change
            # per target, so it is built once outside the loop.
            floating_point_types = (torch.float, torch.double, torch.half)
            for t in targets:
                if t["boxes"].dtype not in floating_point_types:
                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
                # f-string so the offending dtype is actually reported.
                if t["labels"].dtype != torch.int64:
                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
                if self.has_keypoint():
                    if t["keypoints"].dtype != torch.float32:
                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")

        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
        else:
            labels = None
            regression_targets = None
            matched_idxs = None

        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        if self.training:
            # Explicit None checks refine the Optional types for TorchScript.
            if labels is None:
                raise ValueError("labels cannot be None")
            if regression_targets is None:
                raise ValueError("regression_targets cannot be None")
            loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
        else:
            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )

        if self.has_mask():
            # In eval mode, masks are predicted for the post-processed detections.
            mask_proposals = [p["boxes"] for p in result]
            if self.training:
                if matched_idxs is None:
                    raise ValueError("if in trainning, matched_idxs should not be None")

                # during training, only focus on positive boxes
                num_images = len(proposals)
                mask_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    mask_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            if self.mask_roi_pool is not None:
                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
                mask_features = self.mask_head(mask_features)
                mask_logits = self.mask_predictor(mask_features)
            else:
                raise Exception("Expected mask_roi_pool to be not None")

            loss_mask = {}
            if self.training:
                if targets is None or pos_matched_idxs is None or mask_logits is None:
                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")

                gt_masks = [t["masks"] for t in targets]
                gt_labels = [t["labels"] for t in targets]
                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
                loss_mask = {"loss_mask": rcnn_loss_mask}
            else:
                labels = [r["labels"] for r in result]
                masks_probs = maskrcnn_inference(mask_logits, labels)
                for mask_prob, r in zip(masks_probs, result):
                    r["masks"] = mask_prob

            losses.update(loss_mask)

        # keep none checks in if conditional so torchscript will conditionally
        # compile each branch
        if (
            self.keypoint_roi_pool is not None
            and self.keypoint_head is not None
            and self.keypoint_predictor is not None
        ):
            keypoint_proposals = [p["boxes"] for p in result]
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                keypoint_proposals = []
                pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in trainning, matched_idxs should not be None")

                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    keypoint_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
            keypoint_features = self.keypoint_head(keypoint_features)
            keypoint_logits = self.keypoint_predictor(keypoint_features)

            loss_keypoint = {}
            if self.training:
                if targets is None or pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")

                gt_keypoints = [t["keypoints"] for t in targets]
                rcnn_loss_keypoint = keypointrcnn_loss(
                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
                )
                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
            else:
                if keypoint_logits is None or keypoint_proposals is None:
                    raise ValueError(
                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                    )

                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                    r["keypoints"] = keypoint_prob
                    r["keypoints_scores"] = kps
            losses.update(loss_keypoint)

        return result, losses