Unverified commit 11d903ec, authored by Hu Ye and committed by GitHub
Browse files

fix bug when the target is empty in FCOS (#5267)



* fix bug when the target is empty

* Add unittest for empty instance training
Co-authored-by: Zhiqiang Wang <zhiqwang@foxmail.com>
Co-authored-by: Joao Gomes <joaopsgomes@gmail.com>
parent 579f5f5f
......@@ -143,6 +143,17 @@ class TestModelsDetectionNegativeSamples:
assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))
def test_forward_negative_sample_fcos(self):
    """Check that FCOS reports zero box losses for the sample from ``_make_empty_sample``.

    Builds a two-class FCOS ResNet-50 FPN detector limited to 100x100 inputs
    (no pretrained backbone), runs it on the images/targets returned by
    ``self._make_empty_sample()`` and asserts that both the box-regression
    and the centerness losses equal ``0.0``.
    """
    # NOTE(review): presumably the helper yields targets with no ground-truth
    # boxes (the commit adds this test for "empty instance training") — confirm
    # against the helper's definition.
    build_fcos = torchvision.models.detection.fcos_resnet50_fpn
    detector = build_fcos(num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)

    sample_images, sample_targets = self._make_empty_sample()
    losses = detector(sample_images, sample_targets)

    # Both box-related losses are expected to be exactly zero for this sample.
    expected_zero = torch.tensor(0.0)
    for loss_name in ("bbox_regression", "bbox_ctrness"):
        assert_equal(losses[loss_name], expected_zero)
def test_forward_negative_sample_ssd(self):
model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False)
......
......@@ -59,9 +59,13 @@ class FCOSHead(nn.Module):
all_gt_classes_targets = []
all_gt_boxes_targets = []
for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
if len(targets_per_image["labels"]) == 0:
gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
else:
gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
gt_classes_targets[matched_idxs_per_image < 0] = -1 # background
gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
all_gt_classes_targets.append(gt_classes_targets)
all_gt_boxes_targets.append(gt_boxes_targets)
......@@ -95,13 +99,14 @@ class FCOSHead(nn.Module):
]
bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
if len(bbox_reg_targets) == 0:
bbox_reg_targets.new_zeros(len(bbox_reg_targets))
left_right = bbox_reg_targets[:, :, [0, 2]]
top_bottom = bbox_reg_targets[:, :, [1, 3]]
gt_ctrness_targets = torch.sqrt(
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
* (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
)
gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
else:
left_right = bbox_reg_targets[:, :, [0, 2]]
top_bottom = bbox_reg_targets[:, :, [1, 3]]
gt_ctrness_targets = torch.sqrt(
(left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
* (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
)
pred_centerness = bbox_ctrness.squeeze(dim=2)
loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment