Unverified commit a68db4fa, authored by Francisco Massa, committed by GitHub
Browse files

Remove discretization size from maskrcnn and keypointrcnn (#929)

Those were not free parameters; they can be inferred from the size of the output feature map.
parent 5608b06a
...@@ -35,7 +35,6 @@ class KeypointRCNN(FasterRCNN): ...@@ -35,7 +35,6 @@ class KeypointRCNN(FasterRCNN):
bbox_reg_weights=None, bbox_reg_weights=None,
# keypoint parameters # keypoint parameters
keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None, keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None,
keypoint_discretization_size=56,
num_keypoints=17): num_keypoints=17):
assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))) assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
...@@ -84,7 +83,6 @@ class KeypointRCNN(FasterRCNN): ...@@ -84,7 +83,6 @@ class KeypointRCNN(FasterRCNN):
self.roi_heads.keypoint_roi_pool = keypoint_roi_pool self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
self.roi_heads.keypoint_head = keypoint_head self.roi_heads.keypoint_head = keypoint_head
self.roi_heads.keypoint_predictor = keypoint_predictor self.roi_heads.keypoint_predictor = keypoint_predictor
self.roi_heads.keypoint_discretization_size = keypoint_discretization_size
class KeypointRCNNHeads(nn.Sequential): class KeypointRCNNHeads(nn.Sequential):
......
...@@ -36,8 +36,7 @@ class MaskRCNN(FasterRCNN): ...@@ -36,8 +36,7 @@ class MaskRCNN(FasterRCNN):
box_batch_size_per_image=512, box_positive_fraction=0.25, box_batch_size_per_image=512, box_positive_fraction=0.25,
bbox_reg_weights=None, bbox_reg_weights=None,
# Mask parameters # Mask parameters
mask_roi_pool=None, mask_head=None, mask_predictor=None, mask_roi_pool=None, mask_head=None, mask_predictor=None):
mask_discretization_size=28):
assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))) assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None)))
...@@ -84,7 +83,6 @@ class MaskRCNN(FasterRCNN): ...@@ -84,7 +83,6 @@ class MaskRCNN(FasterRCNN):
self.roi_heads.mask_roi_pool = mask_roi_pool self.roi_heads.mask_roi_pool = mask_roi_pool
self.roi_heads.mask_head = mask_head self.roi_heads.mask_head = mask_head
self.roi_heads.mask_predictor = mask_predictor self.roi_heads.mask_predictor = mask_predictor
self.roi_heads.mask_discretization_size = mask_discretization_size
class MaskRCNNHeads(nn.Sequential): class MaskRCNNHeads(nn.Sequential):
......
...@@ -90,7 +90,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M): ...@@ -90,7 +90,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
return roi_align(gt_masks, rois, (M, M), 1)[:, 0] return roi_align(gt_masks, rois, (M, M), 1)[:, 0]
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs, discretization_size): def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
""" """
Arguments: Arguments:
proposals (list[BoxList]) proposals (list[BoxList])
...@@ -101,6 +101,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs ...@@ -101,6 +101,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
mask_loss (Tensor): scalar tensor containing the loss mask_loss (Tensor): scalar tensor containing the loss
""" """
discretization_size = mask_logits.shape[-1]
labels = [l[idxs] for l, idxs in zip(gt_labels, mask_matched_idxs)] labels = [l[idxs] for l, idxs in zip(gt_labels, mask_matched_idxs)]
mask_targets = [ mask_targets = [
project_masks_on_boxes(m, p, i, discretization_size) project_masks_on_boxes(m, p, i, discretization_size)
...@@ -203,7 +204,10 @@ def heatmaps_to_keypoints(maps, rois): ...@@ -203,7 +204,10 @@ def heatmaps_to_keypoints(maps, rois):
return xy_preds.permute(0, 2, 1), end_scores return xy_preds.permute(0, 2, 1), end_scores
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs, discretization_size): def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
N, K, H, W = keypoint_logits.shape
assert H == W
discretization_size = H
heatmaps = [] heatmaps = []
valid = [] valid = []
for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs): for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
...@@ -223,7 +227,6 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched ...@@ -223,7 +227,6 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
if keypoint_targets.numel() == 0 or len(valid) == 0: if keypoint_targets.numel() == 0 or len(valid) == 0:
return keypoint_logits.sum() * 0 return keypoint_logits.sum() * 0
N, K, H, W = keypoint_logits.shape
keypoint_logits = keypoint_logits.view(N * K, H * W) keypoint_logits = keypoint_logits.view(N * K, H * W)
keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid]) keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
...@@ -331,7 +334,6 @@ class RoIHeads(torch.nn.Module): ...@@ -331,7 +334,6 @@ class RoIHeads(torch.nn.Module):
mask_roi_pool=None, mask_roi_pool=None,
mask_head=None, mask_head=None,
mask_predictor=None, mask_predictor=None,
mask_discretization_size=None,
keypoint_roi_pool=None, keypoint_roi_pool=None,
keypoint_head=None, keypoint_head=None,
keypoint_predictor=None, keypoint_predictor=None,
...@@ -364,7 +366,6 @@ class RoIHeads(torch.nn.Module): ...@@ -364,7 +366,6 @@ class RoIHeads(torch.nn.Module):
self.mask_roi_pool = mask_roi_pool self.mask_roi_pool = mask_roi_pool
self.mask_head = mask_head self.mask_head = mask_head
self.mask_predictor = mask_predictor self.mask_predictor = mask_predictor
self.mask_discretization_size = mask_discretization_size
self.keypoint_roi_pool = keypoint_roi_pool self.keypoint_roi_pool = keypoint_roi_pool
self.keypoint_head = keypoint_head self.keypoint_head = keypoint_head
...@@ -378,8 +379,6 @@ class RoIHeads(torch.nn.Module): ...@@ -378,8 +379,6 @@ class RoIHeads(torch.nn.Module):
return False return False
if self.mask_predictor is None: if self.mask_predictor is None:
return False return False
if self.mask_discretization_size is None:
return False
return True return True
@property @property
...@@ -390,8 +389,6 @@ class RoIHeads(torch.nn.Module): ...@@ -390,8 +389,6 @@ class RoIHeads(torch.nn.Module):
return False return False
if self.keypoint_predictor is None: if self.keypoint_predictor is None:
return False return False
if self.keypoint_discretization_size is None:
return False
return True return True
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels): def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
...@@ -570,7 +567,7 @@ class RoIHeads(torch.nn.Module): ...@@ -570,7 +567,7 @@ class RoIHeads(torch.nn.Module):
gt_labels = [t["labels"] for t in targets] gt_labels = [t["labels"] for t in targets]
loss_mask = maskrcnn_loss( loss_mask = maskrcnn_loss(
mask_logits, mask_proposals, mask_logits, mask_proposals,
gt_masks, gt_labels, pos_matched_idxs, self.mask_discretization_size) gt_masks, gt_labels, pos_matched_idxs)
loss_mask = dict(loss_mask=loss_mask) loss_mask = dict(loss_mask=loss_mask)
else: else:
labels = [r["labels"] for r in result] labels = [r["labels"] for r in result]
...@@ -601,7 +598,7 @@ class RoIHeads(torch.nn.Module): ...@@ -601,7 +598,7 @@ class RoIHeads(torch.nn.Module):
gt_keypoints = [t["keypoints"] for t in targets] gt_keypoints = [t["keypoints"] for t in targets]
loss_keypoint = keypointrcnn_loss( loss_keypoint = keypointrcnn_loss(
keypoint_logits, keypoint_proposals, keypoint_logits, keypoint_proposals,
gt_keypoints, pos_matched_idxs, self.keypoint_discretization_size) gt_keypoints, pos_matched_idxs)
loss_keypoint = dict(loss_keypoint=loss_keypoint) loss_keypoint = dict(loss_keypoint=loss_keypoint)
else: else:
keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals) keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment