Unverified commit e75b4973, authored by Monica Alfaro, committed by GitHub

Train Faster R-CNN with negative samples (#1911)



* modified FasterRCNN to accept negative samples

* remove debug lines

* Change torch.zeros_like to torch.zeros

* Add unit tests

* take the `device` into account
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent d45a77d4
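
For context before the code: a "negative sample" is an image with no annotated objects, i.e. a target whose boxes tensor is empty along dimension 0. Before this change, such targets broke the training forward pass. Below is a minimal sketch of building one and training on it; the two-class head and the exact field shapes are illustrative, not taken from this commit.

import torch
import torchvision

# A background-only (negative) target: no boxes, and every per-box
# field is empty along the same dimension.
boxes = torch.zeros((0, 4), dtype=torch.float32)
negative_target = {
    "boxes": boxes,
    "labels": torch.zeros((0,), dtype=torch.int64),
    "image_id": torch.tensor([0]),
    "area": torch.zeros((0,), dtype=torch.float32),
    "iscrowd": torch.zeros((0,), dtype=torch.int64),
}

# A freshly constructed model is in training mode, so the forward call
# with targets returns the loss dict instead of detections.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=False, num_classes=2)
images = [torch.rand(3, 100, 100)]
loss_dict = model(images, [negative_target])  # works once this change is applied

The unit tests added by this commit follow.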
import torch
import torchvision.models
from torchvision.ops import MultiScaleRoIAlign

from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead

import unittest


class Tester(unittest.TestCase):

    def test_targets_to_anchors(self):
        boxes = torch.zeros((0, 4), dtype=torch.float32)
        negative_target = {"boxes": boxes,
                           "labels": torch.zeros((1, 1), dtype=torch.int64),
                           "image_id": 4,
                           "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
                           "iscrowd": torch.zeros((0,), dtype=torch.int64)}

        anchors = [torch.randint(-50, 50, (3, 4), dtype=torch.float32)]
        targets = [negative_target]

        anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
        rpn_anchor_generator = AnchorGenerator(
            anchor_sizes, aspect_ratios
        )
        rpn_head = RPNHead(4, rpn_anchor_generator.num_anchors_per_location()[0])

        head = RegionProposalNetwork(
            rpn_anchor_generator, rpn_head,
            0.5, 0.3,
            256, 0.5,
            2000, 2000, 0.7)

        labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)

        self.assertEqual(labels[0].sum(), 0)
        self.assertEqual(labels[0].shape, torch.Size([anchors[0].shape[0]]))
        self.assertEqual(labels[0].dtype, torch.float32)

        self.assertEqual(matched_gt_boxes[0].sum(), 0)
        self.assertEqual(matched_gt_boxes[0].shape, anchors[0].shape)
        self.assertEqual(matched_gt_boxes[0].dtype, torch.float32)

    def test_assign_targets_to_proposals(self):
        proposals = [torch.randint(-50, 50, (20, 4), dtype=torch.float32)]
        gt_boxes = [torch.zeros((0, 4), dtype=torch.float32)]
        gt_labels = [torch.tensor([[0]], dtype=torch.int64)]

        box_roi_pool = MultiScaleRoIAlign(
            featmap_names=['0', '1', '2', '3'],
            output_size=7,
            sampling_ratio=2)

        resolution = box_roi_pool.output_size[0]
        representation_size = 1024
        box_head = TwoMLPHead(
            4 * resolution ** 2,
            representation_size)

        representation_size = 1024
        box_predictor = FastRCNNPredictor(
            representation_size,
            2)

        roi_heads = RoIHeads(
            # Box
            box_roi_pool, box_head, box_predictor,
            0.5, 0.5,
            512, 0.25,
            None,
            0.05, 0.5, 100)

        matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)

        self.assertEqual(matched_idxs[0].sum(), 0)
        self.assertEqual(matched_idxs[0].shape, torch.Size([proposals[0].shape[0]]))
        self.assertEqual(matched_idxs[0].dtype, torch.int64)

        self.assertEqual(labels[0].sum(), 0)
        self.assertEqual(labels[0].shape, torch.Size([proposals[0].shape[0]]))
        self.assertEqual(labels[0].dtype, torch.int64)

    def test_forward_negative_sample(self):
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)

        images = [torch.rand((3, 100, 100), dtype=torch.float32)]
        boxes = torch.zeros((0, 4), dtype=torch.float32)
        negative_target = {"boxes": boxes,
                           "labels": torch.zeros((1, 1), dtype=torch.int64),
                           "image_id": 4,
                           "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
                           "iscrowd": torch.zeros((0,), dtype=torch.int64)}
        targets = [negative_target]

        loss_dict = model(images, targets)

        self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
        self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))


if __name__ == '__main__':
    unittest.main()
torchvision/models/detection/roi_heads.py
@@ -574,22 +574,33 @@ class RoIHeads(torch.nn.Module):
         matched_idxs = []
         labels = []
         for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
-            # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
-            match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
-            matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
-            clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
-            labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
-            labels_in_image = labels_in_image.to(dtype=torch.int64)
-            # Label background (below the low threshold)
-            bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
-            labels_in_image[bg_inds] = torch.tensor(0)
-            # Label ignore proposals (between low and high thresholds)
-            ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
-            labels_in_image[ignore_inds] = torch.tensor(-1)  # -1 is ignored by sampler
+            if gt_boxes_in_image.numel() == 0:
+                # Background image
+                device = proposals_in_image.device
+                clamped_matched_idxs_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+                labels_in_image = torch.zeros(
+                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+                )
+            else:
+                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+                labels_in_image = labels_in_image.to(dtype=torch.int64)
+                # Label background (below the low threshold)
+                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_in_image[bg_inds] = torch.tensor(0)
+                # Label ignore proposals (between low and high thresholds)
+                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_in_image[ignore_inds] = torch.tensor(-1)  # -1 is ignored by sampler
             matched_idxs.append(clamped_matched_idxs_in_image)
             labels.append(labels_in_image)
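
The guard is needed because, with zero ground-truth boxes, the IoU matrix handed to the Matcher is empty along dimension 0, and there is nothing to take a per-proposal maximum over. A quick standalone illustration of that failure mode (not part of the diff):

import torch
from torchvision.ops import box_iou

gt_boxes = torch.zeros((0, 4))    # negative sample: no ground truth
proposals = torch.rand(5, 4) * 50

iou = box_iou(gt_boxes, proposals)
print(iou.shape)                  # torch.Size([0, 5])
# Matching takes iou.max(dim=0), which fails on a zero-sized dimension:
# iou.max(dim=0) -> RuntimeError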
@@ -635,6 +646,8 @@ class RoIHeads(torch.nn.Module):
         self.check_targets(targets)
         assert targets is not None
         dtype = proposals[0].dtype
+        device = proposals[0].device
+
         gt_boxes = [t["boxes"].to(dtype) for t in targets]
         gt_labels = [t["labels"] for t in targets]
@@ -652,7 +665,11 @@ class RoIHeads(torch.nn.Module):
             proposals[img_id] = proposals[img_id][img_sampled_inds]
             labels[img_id] = labels[img_id][img_sampled_inds]
             matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
-            matched_gt_boxes.append(gt_boxes[img_id][matched_idxs[img_id]])
+
+            gt_boxes_in_image = gt_boxes[img_id]
+            if gt_boxes_in_image.numel() == 0:
+                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])

         regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
         return proposals, matched_idxs, labels, regression_targets
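
The dummy (1, 4) box exists only so that indexing and box encoding stay well-defined: for a background image every matched index is 0, so all sampled proposals resolve to the single zero box, and the resulting regression targets are never used because those proposals all carry the background label. A standalone sketch of the indexing step:

import torch

dtype, device = torch.float32, torch.device("cpu")
gt_boxes_in_image = torch.zeros((0, 4), dtype=dtype, device=device)  # no GT
matched_idxs_in_image = torch.zeros((4,), dtype=torch.int64)         # all point at index 0

# Without the substitution, gt_boxes_in_image[matched_idxs_in_image]
# would index into an empty tensor and raise; with it, the lookup works.
if gt_boxes_in_image.numel() == 0:
    gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)

matched_gt_boxes = gt_boxes_in_image[matched_idxs_in_image]
print(matched_gt_boxes.shape)  # torch.Size([4, 4])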
torchvision/models/detection/rpn.py
@@ -336,24 +336,31 @@ class RegionProposalNetwork(torch.nn.Module):
         matched_gt_boxes = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             gt_boxes = targets_per_image["boxes"]
-            match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
-            matched_idxs = self.proposal_matcher(match_quality_matrix)
-            # get the targets corresponding GT for each proposal
-            # NB: need to clamp the indices because we can have a single
-            # GT in the image, and matched_idxs can be -2, which goes
-            # out of bounds
-            matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
-            labels_per_image = matched_idxs >= 0
-            labels_per_image = labels_per_image.to(dtype=torch.float32)
-            # Background (negative examples)
-            bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
-            labels_per_image[bg_indices] = torch.tensor(0.0)
-            # discard indices that are between thresholds
-            inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
-            labels_per_image[inds_to_discard] = torch.tensor(-1.0)
+            if gt_boxes.numel() == 0:
+                # Background image (negative example)
+                device = anchors_per_image.device
+                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
+                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
+            else:
+                match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
+                matched_idxs = self.proposal_matcher(match_quality_matrix)
+                # get the targets corresponding GT for each proposal
+                # NB: need to clamp the indices because we can have a single
+                # GT in the image, and matched_idxs can be -2, which goes
+                # out of bounds
+                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
+                labels_per_image = matched_idxs >= 0
+                labels_per_image = labels_per_image.to(dtype=torch.float32)
+                # Background (negative examples)
+                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
+                labels_per_image[bg_indices] = torch.tensor(0.0)
+                # discard indices that are between thresholds
+                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
+                labels_per_image[inds_to_discard] = torch.tensor(-1.0)
             labels.append(labels_per_image)
             matched_gt_boxes.append(matched_gt_boxes_per_image)
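
One downstream consequence, and the property that test_forward_negative_sample asserts: with all-zero anchor labels there are no positive samples, so the box-regression loss, which is computed only over positives, collapses to zero. A simplified sketch of that mechanism; the exact loss formulation and normalization in torchvision may differ:

import torch
import torch.nn.functional as F

labels = torch.zeros(12)                # background image: no positive anchors
pred_bbox_deltas = torch.randn(12, 4)
regression_targets = torch.zeros(12, 4)

pos_inds = torch.where(labels >= 1)[0]  # empty for a negative sample
box_loss = F.smooth_l1_loss(
    pred_bbox_deltas[pos_inds],
    regression_targets[pos_inds],
    reduction="sum",
) / labels.numel()                      # sum over an empty selection is 0
print(box_loss)                         # tensor(0.)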