"gallery/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "655ebdbcd71251ff6bbac89c4183f537db9aae2d"
Unverified commit 4ab46e5f authored by Vasilis Vryniotis, committed by GitHub

Support for images with no annotations in RetinaNet (#3032)

* Enable support for images without annotations

* Ensuring gradient propagates to RegressionHead.

* Rewriting losses to remove branching.

* Fix the seed on DeformConv autocast test.
parent 9e71fdaf
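For context, a "negative sample" here is an image whose target dict carries zero boxes. A minimal sketch of the kind of input this commit enables, constructed by hand (the test below uses a `_make_empty_sample` helper instead; `pretrained_backbone=False` is only to avoid a weight download in this sketch):

    import torch
    import torchvision

    model = torchvision.models.detection.retinanet_resnet50_fpn(
        num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)
    model.train()

    images = [torch.rand(3, 100, 100)]
    targets = [{
        'boxes': torch.zeros((0, 4), dtype=torch.float32),   # no annotations
        'labels': torch.zeros((0,), dtype=torch.int64),
    }]

    loss_dict = model(images, targets)
    # The regression loss is exactly zero (there is no foreground), while the
    # classification loss still penalizes background predictions.
    print(loss_dict['bbox_regression'])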
test/test_models_detection_negative_samples.py
@@ -128,6 +128,15 @@ class Tester(unittest.TestCase):
         self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
         self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))
 
+    def test_forward_negative_sample_retinanet(self):
+        model = torchvision.models.detection.retinanet_resnet50_fpn(
+            num_classes=2, min_size=100, max_size=100)
+
+        images, targets = self._make_empty_sample()
+        loss_dict = model(images, targets)
+
+        self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.))
+
 
 if __name__ == '__main__':
     unittest.main()
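`_make_empty_sample` is defined earlier in this test file and does not appear in the diff; a plausible sketch of such a helper, with the exact body assumed rather than quoted:

    def _make_empty_sample(self):
        # One random image paired with an empty target: zero boxes, zero labels.
        images = [torch.rand((3, 100, 100), dtype=torch.float32)]
        targets = [{
            'boxes': torch.zeros((0, 4), dtype=torch.float32),
            'labels': torch.zeros((0,), dtype=torch.int64),
        }]
        return images, targets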
test/test_ops.py
+from common_utils import set_rng_seed
 import math
 import unittest
...
@@ -655,6 +656,7 @@ class DeformConvTester(OpTester, unittest.TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
     def test_autocast(self):
+        set_rng_seed(0)
         for dtype in (torch.float, torch.half):
             with torch.cuda.amp.autocast():
                 self._test_forward(torch.device("cuda"), False, dtype=dtype)
...
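Pinning the RNG seed makes the half-vs-float comparison in `test_autocast` deterministic, since `_test_forward` draws random inputs. `set_rng_seed` comes from the test suite's `common_utils`; a sketch of what such a helper typically does (exact body assumed):

    import random
    import torch

    def set_rng_seed(seed):
        # Seed every RNG the tests draw from so failures are reproducible.
        torch.manual_seed(seed)
        random.seed(seed)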
torchvision/models/detection/retinanet.py
@@ -107,21 +107,16 @@ class RetinaNetClassificationHead(nn.Module):
             # determine only the foreground
             foreground_idxs_per_image = matched_idxs_per_image >= 0
             num_foreground = foreground_idxs_per_image.sum()
 
-            # no matched_idxs means there were no annotations in this image
-            # TODO: enable support for images without annotations that works on distributed
-            if False:  # matched_idxs_per_image.numel() == 0:
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0])
-            else:
-                # create the target classification
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                gt_classes_target[
-                    foreground_idxs_per_image,
-                    targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
-                ] = 1.0
-
-                # find indices for which anchors should be ignored
-                valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
+            # create the target classification
+            gt_classes_target = torch.zeros_like(cls_logits_per_image)
+            gt_classes_target[
+                foreground_idxs_per_image,
+                targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
+            ] = 1.0
+
+            # find indices for which anchors should be ignored
+            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
 
             # compute the classification loss
             losses.append(sigmoid_focal_loss(
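The deleted `if False:` guard was never needed: advanced indexing with an all-False boolean mask assigns nothing, so when an image has no annotations `gt_classes_target` simply stays all zeros and the focal loss treats every anchor as background. A standalone demonstration of that behavior:

    import torch

    cls_logits = torch.randn(6, 91)                          # 6 anchors, 91 classes
    matched_idxs = torch.full((6,), -1, dtype=torch.int64)   # all anchors are background
    foreground = matched_idxs >= 0                           # all-False mask
    labels = torch.zeros((0,), dtype=torch.int64)            # image has no annotations

    gt_classes_target = torch.zeros_like(cls_logits)
    gt_classes_target[foreground, labels[matched_idxs[foreground]]] = 1.0

    # Nothing was assigned: the target is still all zeros, no branch required.
    assert gt_classes_target.sum() == 0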
@@ -191,23 +186,12 @@ class RetinaNetRegressionHead(nn.Module):
         for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
                 zip(targets, bbox_regression, anchors, matched_idxs):
-            # no matched_idxs means there were no annotations in this image
-            # TODO enable support for images without annotations with distributed support
-            # if matched_idxs_per_image.numel() == 0:
-            #     continue
-
-            # get the targets corresponding GT for each proposal
-            # NB: need to clamp the indices because we can have a single
-            # GT in the image, and matched_idxs can be -2, which goes
-            # out of bounds
-            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
-
             # determine only the foreground indices, ignore the rest
-            foreground_idxs_per_image = matched_idxs_per_image >= 0
-            num_foreground = foreground_idxs_per_image.sum()
+            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
+            num_foreground = foreground_idxs_per_image.numel()
 
             # select only the foreground boxes
-            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
+            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]]
             bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
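Two things change here. First, `torch.where(...)[0]` turns the foreground selection into an integer index tensor, so `num_foreground` comes from `numel()` rather than a boolean `sum()`. Second, gathering `targets_per_image['boxes']` only at genuine foreground matches removes the old `clamp(min=0)` workaround, and it stays in-bounds even when there are no boxes at all, because the empty index tensor selects nothing. With zero foreground rows, a sum-reduced loss still evaluates to a differentiable zero, which is what lets the gradient reach the RegressionHead (important under distributed training, where every parameter must receive a gradient). A small illustration; the sum-over-`max(1, num_foreground)` normalization mirrors the head's loss but is a sketch, not a quotation:

    import torch
    import torch.nn.functional as F

    bbox_regression = torch.randn(6, 4, requires_grad=True)   # head output, 6 anchors
    matched_idxs = torch.full((6,), -1, dtype=torch.int64)    # no annotations

    foreground_idxs = torch.where(matched_idxs >= 0)[0]       # empty index tensor
    num_foreground = foreground_idxs.numel()                  # 0

    pred = bbox_regression[foreground_idxs, :]                # shape (0, 4)
    target = torch.zeros_like(pred)                           # stand-in regression targets

    loss = F.l1_loss(pred, target, reduction='sum') / max(1, num_foreground)
    loss.backward()   # a zero loss that still propagates (zero) gradients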
...
@@ -403,7 +387,7 @@ class RetinaNet(nn.Module):
         matched_idxs = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             if targets_per_image['boxes'].numel() == 0:
-                matched_idxs.append(torch.empty((0,), dtype=torch.int32))
+                matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64))
                 continue
 
             match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
...
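This one-line fix is what makes the head rewrites above possible: previously an image without boxes contributed an empty `int32` tensor, which the per-anchor bookkeeping in both heads could not index consistently; now it contributes one `-1` entry per anchor (as `int64`, since the tensor is used for indexing), explicitly marking every anchor as background. A quick check of the resulting invariants, with a hypothetical anchor count:

    import torch

    num_anchors = 5  # hypothetical
    matched_idxs = torch.full((num_anchors,), -1, dtype=torch.int64)

    assert matched_idxs.shape == (num_anchors,)   # one entry per anchor
    assert (matched_idxs >= 0).sum() == 0         # no foreground anchors
    assert (matched_idxs != -2).all()             # none flagged BETWEEN_THRESHOLDS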