Commit f1d06cdc authored by Kai Chen
Browse files

support gt_bboxes_ignore for anchor heads

parent 801c8b19
......@@ -11,6 +11,7 @@ def anchor_target(anchor_list,
target_means,
target_stds,
cfg,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
sampling=True,
......@@ -41,6 +42,8 @@ def anchor_target(anchor_list,
valid_flag_list[i] = torch.cat(valid_flag_list[i])
# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
(all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
......@@ -49,6 +52,7 @@ def anchor_target(anchor_list,
anchor_list,
valid_flag_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
target_means=target_means,
......@@ -90,6 +94,7 @@ def images_to_levels(target, num_level_anchors):
def anchor_target_single(flat_anchors,
valid_flags,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
target_means,
......@@ -108,11 +113,11 @@ def anchor_target_single(flat_anchors,
if sampling:
assign_result, sampling_result = assign_and_sample(
anchors, gt_bboxes, None, None, cfg)
anchors, gt_bboxes, gt_bboxes_ignore, None, cfg)
else:
bbox_assigner = build_assigner(cfg.assigner)
assign_result = bbox_assigner.assign(anchors, gt_bboxes, None,
gt_labels)
assign_result = bbox_assigner.assign(anchors, gt_bboxes,
gt_bboxes_ignore, gt_labels)
bbox_sampler = PseudoSampler()
sampling_result = bbox_sampler.sample(assign_result, anchors,
gt_bboxes)
......
......@@ -169,8 +169,14 @@ class AnchorHead(nn.Module):
avg_factor=num_total_samples)
return loss_cls, loss_reg
def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
cfg):
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
......@@ -186,6 +192,7 @@ class AnchorHead(nn.Module):
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=sampling)
......
......@@ -34,9 +34,21 @@ class RPNHead(AnchorHead):
rpn_bbox_pred = self.rpn_reg(x)
return rpn_cls_score, rpn_bbox_pred
def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg):
losses = super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
None, img_metas, cfg)
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
img_metas,
cfg,
gt_bboxes_ignore=None):
losses = super(RPNHead, self).loss(
cls_scores,
bbox_preds,
gt_bboxes,
None,
img_metas,
cfg,
gt_bboxes_ignore=gt_bboxes_ignore)
return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_reg=losses['loss_reg'])
......
......@@ -130,8 +130,14 @@ class SSDHead(AnchorHead):
avg_factor=num_total_samples)
return loss_cls[None], loss_reg
def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
cfg):
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
......@@ -145,6 +151,7 @@ class SSDHead(AnchorHead):
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=1,
sampling=False,
......
......@@ -109,8 +109,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
img,
img_meta,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None):
x = self.extract_feat(img)
......@@ -121,7 +121,8 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
self.train_cfg.rpn)
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
rpn_losses = self.rpn_head.loss(
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
......
......@@ -38,7 +38,11 @@ class RPN(BaseDetector, RPNTestMixin):
x = self.neck(x)
return x
def forward_train(self, img, img_meta, gt_bboxes=None):
def forward_train(self,
img,
img_meta,
gt_bboxes=None,
gt_bboxes_ignore=None):
if self.train_cfg.rpn.get('debug', False):
self.rpn_head.debug_imgs = tensor2imgs(img)
......@@ -46,7 +50,8 @@ class RPN(BaseDetector, RPNTestMixin):
rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
losses = self.rpn_head.loss(*rpn_loss_inputs)
losses = self.rpn_head.loss(
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
def simple_test(self, img, img_meta, rescale=False):
......
......@@ -42,11 +42,17 @@ class SingleStageDetector(BaseDetector):
x = self.neck(x)
return x
def forward_train(self, img, img_metas, gt_bboxes, gt_labels):
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None):
x = self.extract_feat(img)
outs = self.bbox_head(x)
loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)
losses = self.bbox_head.loss(*loss_inputs)
losses = self.bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
def simple_test(self, img, img_meta, rescale=False):
......
......@@ -81,8 +81,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
img,
img_meta,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None):
x = self.extract_feat(img)
......@@ -94,7 +94,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
self.train_cfg.rpn)
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
rpn_losses = self.rpn_head.loss(
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
......@@ -108,6 +109,8 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
bbox_sampler = build_sampler(
self.train_cfg.rcnn.sampler, context=self)
num_imgs = img.size(0)
if gt_bboxes_ignore is None:
gt_bboxes_ignore = [None for _ in range(num_imgs)]
sampling_results = []
for i in range(num_imgs):
assign_result = bbox_assigner.assign(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment