You need to sign in or sign up before continuing.
Unverified commit a68db4fa, authored by Francisco Massa, committed by GitHub
Browse files

Remove discretization size from maskrcnn and keypointrcnn (#929)

The discretization sizes were not free parameters; they can be inferred from the size of the output feature map.
parent 5608b06a
...@@ -35,7 +35,6 @@ class KeypointRCNN(FasterRCNN): ...@@ -35,7 +35,6 @@ class KeypointRCNN(FasterRCNN):
bbox_reg_weights=None, bbox_reg_weights=None,
# keypoint parameters # keypoint parameters
keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None, keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None,
keypoint_discretization_size=56,
num_keypoints=17): num_keypoints=17):
assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))) assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
...@@ -84,7 +83,6 @@ class KeypointRCNN(FasterRCNN): ...@@ -84,7 +83,6 @@ class KeypointRCNN(FasterRCNN):
self.roi_heads.keypoint_roi_pool = keypoint_roi_pool self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
self.roi_heads.keypoint_head = keypoint_head self.roi_heads.keypoint_head = keypoint_head
self.roi_heads.keypoint_predictor = keypoint_predictor self.roi_heads.keypoint_predictor = keypoint_predictor
self.roi_heads.keypoint_discretization_size = keypoint_discretization_size
class KeypointRCNNHeads(nn.Sequential): class KeypointRCNNHeads(nn.Sequential):
......
...@@ -36,8 +36,7 @@ class MaskRCNN(FasterRCNN): ...@@ -36,8 +36,7 @@ class MaskRCNN(FasterRCNN):
box_batch_size_per_image=512, box_positive_fraction=0.25, box_batch_size_per_image=512, box_positive_fraction=0.25,
bbox_reg_weights=None, bbox_reg_weights=None,
# Mask parameters # Mask parameters
mask_roi_pool=None, mask_head=None, mask_predictor=None, mask_roi_pool=None, mask_head=None, mask_predictor=None):
mask_discretization_size=28):
assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))) assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None)))
...@@ -84,7 +83,6 @@ class MaskRCNN(FasterRCNN): ...@@ -84,7 +83,6 @@ class MaskRCNN(FasterRCNN):
self.roi_heads.mask_roi_pool = mask_roi_pool self.roi_heads.mask_roi_pool = mask_roi_pool
self.roi_heads.mask_head = mask_head self.roi_heads.mask_head = mask_head
self.roi_heads.mask_predictor = mask_predictor self.roi_heads.mask_predictor = mask_predictor
self.roi_heads.mask_discretization_size = mask_discretization_size
class MaskRCNNHeads(nn.Sequential): class MaskRCNNHeads(nn.Sequential):
......
...@@ -90,7 +90,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M): ...@@ -90,7 +90,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
return roi_align(gt_masks, rois, (M, M), 1)[:, 0] return roi_align(gt_masks, rois, (M, M), 1)[:, 0]
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs, discretization_size): def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
""" """
Arguments: Arguments:
proposals (list[BoxList]) proposals (list[BoxList])
...@@ -101,6 +101,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs ...@@ -101,6 +101,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
mask_loss (Tensor): scalar tensor containing the loss mask_loss (Tensor): scalar tensor containing the loss
""" """
discretization_size = mask_logits.shape[-1]
labels = [l[idxs] for l, idxs in zip(gt_labels, mask_matched_idxs)] labels = [l[idxs] for l, idxs in zip(gt_labels, mask_matched_idxs)]
mask_targets = [ mask_targets = [
project_masks_on_boxes(m, p, i, discretization_size) project_masks_on_boxes(m, p, i, discretization_size)
...@@ -203,7 +204,10 @@ def heatmaps_to_keypoints(maps, rois): ...@@ -203,7 +204,10 @@ def heatmaps_to_keypoints(maps, rois):
return xy_preds.permute(0, 2, 1), end_scores return xy_preds.permute(0, 2, 1), end_scores
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs, discretization_size): def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
N, K, H, W = keypoint_logits.shape
assert H == W
discretization_size = H
heatmaps = [] heatmaps = []
valid = [] valid = []
for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs): for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
...@@ -223,7 +227,6 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched ...@@ -223,7 +227,6 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
if keypoint_targets.numel() == 0 or len(valid) == 0: if keypoint_targets.numel() == 0 or len(valid) == 0:
return keypoint_logits.sum() * 0 return keypoint_logits.sum() * 0
N, K, H, W = keypoint_logits.shape
keypoint_logits = keypoint_logits.view(N * K, H * W) keypoint_logits = keypoint_logits.view(N * K, H * W)
keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid]) keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
...@@ -331,7 +334,6 @@ class RoIHeads(torch.nn.Module): ...@@ -331,7 +334,6 @@ class RoIHeads(torch.nn.Module):
mask_roi_pool=None, mask_roi_pool=None,
mask_head=None, mask_head=None,
mask_predictor=None, mask_predictor=None,
mask_discretization_size=None,
keypoint_roi_pool=None, keypoint_roi_pool=None,
keypoint_head=None, keypoint_head=None,
keypoint_predictor=None, keypoint_predictor=None,
...@@ -364,7 +366,6 @@ class RoIHeads(torch.nn.Module): ...@@ -364,7 +366,6 @@ class RoIHeads(torch.nn.Module):
self.mask_roi_pool = mask_roi_pool self.mask_roi_pool = mask_roi_pool
self.mask_head = mask_head self.mask_head = mask_head
self.mask_predictor = mask_predictor self.mask_predictor = mask_predictor
self.mask_discretization_size = mask_discretization_size
self.keypoint_roi_pool = keypoint_roi_pool self.keypoint_roi_pool = keypoint_roi_pool
self.keypoint_head = keypoint_head self.keypoint_head = keypoint_head
...@@ -378,8 +379,6 @@ class RoIHeads(torch.nn.Module): ...@@ -378,8 +379,6 @@ class RoIHeads(torch.nn.Module):
return False return False
if self.mask_predictor is None: if self.mask_predictor is None:
return False return False
if self.mask_discretization_size is None:
return False
return True return True
@property @property
...@@ -390,8 +389,6 @@ class RoIHeads(torch.nn.Module): ...@@ -390,8 +389,6 @@ class RoIHeads(torch.nn.Module):
return False return False
if self.keypoint_predictor is None: if self.keypoint_predictor is None:
return False return False
if self.keypoint_discretization_size is None:
return False
return True return True
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels): def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
...@@ -570,7 +567,7 @@ class RoIHeads(torch.nn.Module): ...@@ -570,7 +567,7 @@ class RoIHeads(torch.nn.Module):
gt_labels = [t["labels"] for t in targets] gt_labels = [t["labels"] for t in targets]
loss_mask = maskrcnn_loss( loss_mask = maskrcnn_loss(
mask_logits, mask_proposals, mask_logits, mask_proposals,
gt_masks, gt_labels, pos_matched_idxs, self.mask_discretization_size) gt_masks, gt_labels, pos_matched_idxs)
loss_mask = dict(loss_mask=loss_mask) loss_mask = dict(loss_mask=loss_mask)
else: else:
labels = [r["labels"] for r in result] labels = [r["labels"] for r in result]
...@@ -601,7 +598,7 @@ class RoIHeads(torch.nn.Module): ...@@ -601,7 +598,7 @@ class RoIHeads(torch.nn.Module):
gt_keypoints = [t["keypoints"] for t in targets] gt_keypoints = [t["keypoints"] for t in targets]
loss_keypoint = keypointrcnn_loss( loss_keypoint = keypointrcnn_loss(
keypoint_logits, keypoint_proposals, keypoint_logits, keypoint_proposals,
gt_keypoints, pos_matched_idxs, self.keypoint_discretization_size) gt_keypoints, pos_matched_idxs)
loss_keypoint = dict(loss_keypoint=loss_keypoint) loss_keypoint = dict(loss_keypoint=loss_keypoint)
else: else:
keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals) keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment