Commit 9df04d54 (unverified), authored by Kai Chen and committed by GitHub
Browse files

Potential bug fix for GuidedAnchorHead (#754)

* code formatting for guided_anchor_head.py

* bug fix for using multi_apply
parent 726ebdc9
...@@ -36,15 +36,14 @@ class FeatureAdaption(nn.Module): ...@@ -36,15 +36,14 @@ class FeatureAdaption(nn.Module):
deformable_groups=4): deformable_groups=4):
super(FeatureAdaption, self).__init__() super(FeatureAdaption, self).__init__()
offset_channels = kernel_size * kernel_size * 2 offset_channels = kernel_size * kernel_size * 2
self.conv_offset = nn.Conv2d(2, self.conv_offset = nn.Conv2d(
deformable_groups * offset_channels, 2, deformable_groups * offset_channels, 1, bias=False)
1, self.conv_adaption = DeformConv(
bias=False) in_channels,
self.conv_adaption = DeformConv(in_channels, out_channels,
out_channels, kernel_size=kernel_size,
kernel_size=kernel_size, padding=(kernel_size - 1) // 2,
padding=(kernel_size - 1) // 2, deformable_groups=deformable_groups)
deformable_groups=deformable_groups)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
def init_weights(self): def init_weights(self):
...@@ -109,20 +108,23 @@ class GuidedAnchorHead(AnchorHead): ...@@ -109,20 +108,23 @@ class GuidedAnchorHead(AnchorHead):
target_stds=(1.0, 1.0, 1.0, 1.0), target_stds=(1.0, 1.0, 1.0, 1.0),
deformable_groups=4, deformable_groups=4,
loc_filter_thr=0.01, loc_filter_thr=0.01,
loss_loc=dict(type='FocalLoss', loss_loc=dict(
use_sigmoid=True, type='FocalLoss',
gamma=2.0, use_sigmoid=True,
alpha=0.25, gamma=2.0,
loss_weight=1.0), alpha=0.25,
loss_shape=dict(type='IoULoss', loss_weight=1.0),
style='bounded', loss_shape=dict(
beta=0.2, type='IoULoss',
loss_weight=1.0), style='bounded',
loss_cls=dict(type='CrossEntropyLoss', beta=0.2,
use_sigmoid=True, loss_weight=1.0),
loss_weight=1.0), loss_cls=dict(
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, type='CrossEntropyLoss',
loss_weight=1.0)): use_sigmoid=True,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
super(AnchorHead, self).__init__() super(AnchorHead, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.num_classes = num_classes self.num_classes = num_classes
...@@ -258,8 +260,8 @@ class GuidedAnchorHead(AnchorHead): ...@@ -258,8 +260,8 @@ class GuidedAnchorHead(AnchorHead):
inside_flags_list.append(inside_flags) inside_flags_list.append(inside_flags)
# inside_flag for a position is true if any anchor in this # inside_flag for a position is true if any anchor in this
# position is true # position is true
inside_flags = (torch.stack(inside_flags_list, 0).sum(dim=0) > inside_flags = (
0) torch.stack(inside_flags_list, 0).sum(dim=0) > 0)
multi_level_flags.append(inside_flags) multi_level_flags.append(inside_flags)
inside_flag_list.append(multi_level_flags) inside_flag_list.append(multi_level_flags)
return approxs_list, inside_flag_list return approxs_list, inside_flag_list
...@@ -347,11 +349,12 @@ class GuidedAnchorHead(AnchorHead): ...@@ -347,11 +349,12 @@ class GuidedAnchorHead(AnchorHead):
-1, 2).detach()[mask] -1, 2).detach()[mask]
bbox_deltas = anchor_deltas.new_full(squares.size(), 0) bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
bbox_deltas[:, 2:] = anchor_deltas bbox_deltas[:, 2:] = anchor_deltas
guided_anchors = delta2bbox(squares, guided_anchors = delta2bbox(
bbox_deltas, squares,
self.anchoring_means, bbox_deltas,
self.anchoring_stds, self.anchoring_means,
wh_ratio_clip=1e-6) self.anchoring_stds,
wh_ratio_clip=1e-6)
return guided_anchors, mask return guided_anchors, mask
def loss_shape_single(self, shape_pred, bbox_anchors, bbox_gts, def loss_shape_single(self, shape_pred, bbox_anchors, bbox_gts,
...@@ -368,23 +371,26 @@ class GuidedAnchorHead(AnchorHead): ...@@ -368,23 +371,26 @@ class GuidedAnchorHead(AnchorHead):
bbox_anchors_ = bbox_anchors[inds] bbox_anchors_ = bbox_anchors[inds]
bbox_gts_ = bbox_gts[inds] bbox_gts_ = bbox_gts[inds]
anchor_weights_ = anchor_weights[inds] anchor_weights_ = anchor_weights[inds]
pred_anchors_ = delta2bbox(bbox_anchors_, pred_anchors_ = delta2bbox(
bbox_deltas_, bbox_anchors_,
self.anchoring_means, bbox_deltas_,
self.anchoring_stds, self.anchoring_means,
wh_ratio_clip=1e-6) self.anchoring_stds,
loss_shape = self.loss_shape(pred_anchors_, wh_ratio_clip=1e-6)
bbox_gts_, loss_shape = self.loss_shape(
anchor_weights_, pred_anchors_,
avg_factor=anchor_total_num) bbox_gts_,
anchor_weights_,
avg_factor=anchor_total_num)
return loss_shape return loss_shape
def loss_loc_single(self, loc_pred, loc_target, loc_weight, loc_avg_factor, def loss_loc_single(self, loc_pred, loc_target, loc_weight, loc_avg_factor,
cfg): cfg):
loss_loc = self.loss_loc(loc_pred.reshape(-1, 1), loss_loc = self.loss_loc(
loc_target.reshape(-1, 1).long(), loc_pred.reshape(-1, 1),
loc_weight.reshape(-1, 1), loc_target.reshape(-1, 1).long(),
avg_factor=loc_avg_factor) loc_weight.reshape(-1, 1),
avg_factor=loc_avg_factor)
return loss_loc return loss_loc
def loss(self, def loss(self,
...@@ -418,41 +424,44 @@ class GuidedAnchorHead(AnchorHead): ...@@ -418,41 +424,44 @@ class GuidedAnchorHead(AnchorHead):
# get shape targets # get shape targets
sampling = False if not hasattr(cfg, 'ga_sampler') else True sampling = False if not hasattr(cfg, 'ga_sampler') else True
shape_targets = ga_shape_target(approxs_list, shape_targets = ga_shape_target(
inside_flag_list, approxs_list,
squares_list, inside_flag_list,
gt_bboxes, squares_list,
img_metas, gt_bboxes,
self.approxs_per_octave, img_metas,
cfg, self.approxs_per_octave,
sampling=sampling) cfg,
sampling=sampling)
if shape_targets is None: if shape_targets is None:
return None return None
(bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num, (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
anchor_bg_num) = shape_targets anchor_bg_num) = shape_targets
anchor_total_num = (anchor_fg_num if not sampling else anchor_fg_num + anchor_total_num = (
anchor_bg_num) anchor_fg_num if not sampling else anchor_fg_num + anchor_bg_num)
# get anchor targets # get anchor targets
sampling = False if self.cls_focal_loss else True sampling = False if self.cls_focal_loss else True
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(guided_anchors_list, cls_reg_targets = anchor_target(
inside_flag_list, guided_anchors_list,
gt_bboxes, inside_flag_list,
img_metas, gt_bboxes,
self.target_means, img_metas,
self.target_stds, self.target_means,
cfg, self.target_stds,
gt_bboxes_ignore_list=gt_bboxes_ignore, cfg,
gt_labels_list=gt_labels, gt_bboxes_ignore_list=gt_bboxes_ignore,
label_channels=label_channels, gt_labels_list=gt_labels,
sampling=sampling) label_channels=label_channels,
sampling=sampling)
if cls_reg_targets is None: if cls_reg_targets is None:
return None return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (num_total_pos if self.cls_focal_loss else num_total_samples = (
num_total_pos + num_total_neg) num_total_pos if self.cls_focal_loss else num_total_pos +
num_total_neg)
# get classification and bbox regression losses # get classification and bbox regression losses
losses_cls, losses_bbox = multi_apply( losses_cls, losses_bbox = multi_apply(
...@@ -467,24 +476,32 @@ class GuidedAnchorHead(AnchorHead): ...@@ -467,24 +476,32 @@ class GuidedAnchorHead(AnchorHead):
cfg=cfg) cfg=cfg)
# get anchor location loss # get anchor location loss
losses_loc, = multi_apply(self.loss_loc_single, losses_loc = []
loc_preds, for i in range(len(loc_preds)):
loc_targets, loss_loc = self.loss_loc_single(
loc_weights, loc_preds[i],
loc_avg_factor=loc_avg_factor, loc_targets[i],
cfg=cfg) loc_weights[i],
loc_avg_factor=loc_avg_factor,
cfg=cfg)
losses_loc.append(loss_loc)
# get anchor shape loss # get anchor shape loss
losses_shape, = multi_apply(self.loss_shape_single, losses_shape = []
shape_preds, for i in range(len(shape_preds)):
bbox_anchors_list, loss_shape = self.loss_shape_single(
bbox_gts_list, shape_preds[i],
anchor_weights_list, bbox_anchors_list[i],
anchor_total_num=anchor_total_num) bbox_gts_list[i],
return dict(loss_cls=losses_cls, anchor_weights_list[i],
loss_bbox=losses_bbox, anchor_total_num=anchor_total_num)
loss_shape=losses_shape, losses_shape.append(loss_shape)
loss_loc=losses_loc)
return dict(
loss_cls=losses_cls,
loss_bbox=losses_bbox,
loss_shape=losses_shape,
loss_loc=losses_loc)
def get_bboxes(self, def get_bboxes(self,
cls_scores, cls_scores,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment