Unverified Commit a68db4fa authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Remove discretization size from maskrcnn and keypointrcnn (#929)

Those were not free parameters, and can be inferred via the size of the output feature map
parent 5608b06a
......@@ -35,7 +35,6 @@ class KeypointRCNN(FasterRCNN):
bbox_reg_weights=None,
# keypoint parameters
keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None,
keypoint_discretization_size=56,
num_keypoints=17):
assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
......@@ -84,7 +83,6 @@ class KeypointRCNN(FasterRCNN):
self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
self.roi_heads.keypoint_head = keypoint_head
self.roi_heads.keypoint_predictor = keypoint_predictor
self.roi_heads.keypoint_discretization_size = keypoint_discretization_size
class KeypointRCNNHeads(nn.Sequential):
......
......@@ -36,8 +36,7 @@ class MaskRCNN(FasterRCNN):
box_batch_size_per_image=512, box_positive_fraction=0.25,
bbox_reg_weights=None,
# Mask parameters
mask_roi_pool=None, mask_head=None, mask_predictor=None,
mask_discretization_size=28):
mask_roi_pool=None, mask_head=None, mask_predictor=None):
assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None)))
......@@ -84,7 +83,6 @@ class MaskRCNN(FasterRCNN):
self.roi_heads.mask_roi_pool = mask_roi_pool
self.roi_heads.mask_head = mask_head
self.roi_heads.mask_predictor = mask_predictor
self.roi_heads.mask_discretization_size = mask_discretization_size
class MaskRCNNHeads(nn.Sequential):
......
......@@ -90,7 +90,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
return roi_align(gt_masks, rois, (M, M), 1)[:, 0]
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs, discretization_size):
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
"""
Arguments:
proposals (list[BoxList])
......@@ -101,6 +101,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
mask_loss (Tensor): scalar tensor containing the loss
"""
discretization_size = mask_logits.shape[-1]
labels = [l[idxs] for l, idxs in zip(gt_labels, mask_matched_idxs)]
mask_targets = [
project_masks_on_boxes(m, p, i, discretization_size)
......@@ -203,7 +204,10 @@ def heatmaps_to_keypoints(maps, rois):
return xy_preds.permute(0, 2, 1), end_scores
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs, discretization_size):
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
N, K, H, W = keypoint_logits.shape
assert H == W
discretization_size = H
heatmaps = []
valid = []
for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
......@@ -223,7 +227,6 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
if keypoint_targets.numel() == 0 or len(valid) == 0:
return keypoint_logits.sum() * 0
N, K, H, W = keypoint_logits.shape
keypoint_logits = keypoint_logits.view(N * K, H * W)
keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
......@@ -331,7 +334,6 @@ class RoIHeads(torch.nn.Module):
mask_roi_pool=None,
mask_head=None,
mask_predictor=None,
mask_discretization_size=None,
keypoint_roi_pool=None,
keypoint_head=None,
keypoint_predictor=None,
......@@ -364,7 +366,6 @@ class RoIHeads(torch.nn.Module):
self.mask_roi_pool = mask_roi_pool
self.mask_head = mask_head
self.mask_predictor = mask_predictor
self.mask_discretization_size = mask_discretization_size
self.keypoint_roi_pool = keypoint_roi_pool
self.keypoint_head = keypoint_head
......@@ -378,8 +379,6 @@ class RoIHeads(torch.nn.Module):
return False
if self.mask_predictor is None:
return False
if self.mask_discretization_size is None:
return False
return True
@property
......@@ -390,8 +389,6 @@ class RoIHeads(torch.nn.Module):
return False
if self.keypoint_predictor is None:
return False
if self.keypoint_discretization_size is None:
return False
return True
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
......@@ -570,7 +567,7 @@ class RoIHeads(torch.nn.Module):
gt_labels = [t["labels"] for t in targets]
loss_mask = maskrcnn_loss(
mask_logits, mask_proposals,
gt_masks, gt_labels, pos_matched_idxs, self.mask_discretization_size)
gt_masks, gt_labels, pos_matched_idxs)
loss_mask = dict(loss_mask=loss_mask)
else:
labels = [r["labels"] for r in result]
......@@ -601,7 +598,7 @@ class RoIHeads(torch.nn.Module):
gt_keypoints = [t["keypoints"] for t in targets]
loss_keypoint = keypointrcnn_loss(
keypoint_logits, keypoint_proposals,
gt_keypoints, pos_matched_idxs, self.keypoint_discretization_size)
gt_keypoints, pos_matched_idxs)
loss_keypoint = dict(loss_keypoint=loss_keypoint)
else:
keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment