Unverified Commit 9df04d54 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Potential bug fix for GuidedAnchorHead (#754)

* code formatting for guided_anchor_head.py

* bug fix for using multi_apply
parent 726ebdc9
...@@ -36,11 +36,10 @@ class FeatureAdaption(nn.Module): ...@@ -36,11 +36,10 @@ class FeatureAdaption(nn.Module):
deformable_groups=4): deformable_groups=4):
super(FeatureAdaption, self).__init__() super(FeatureAdaption, self).__init__()
offset_channels = kernel_size * kernel_size * 2 offset_channels = kernel_size * kernel_size * 2
self.conv_offset = nn.Conv2d(2, self.conv_offset = nn.Conv2d(
deformable_groups * offset_channels, 2, deformable_groups * offset_channels, 1, bias=False)
1, self.conv_adaption = DeformConv(
bias=False) in_channels,
self.conv_adaption = DeformConv(in_channels,
out_channels, out_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
padding=(kernel_size - 1) // 2, padding=(kernel_size - 1) // 2,
...@@ -109,20 +108,23 @@ class GuidedAnchorHead(AnchorHead): ...@@ -109,20 +108,23 @@ class GuidedAnchorHead(AnchorHead):
target_stds=(1.0, 1.0, 1.0, 1.0), target_stds=(1.0, 1.0, 1.0, 1.0),
deformable_groups=4, deformable_groups=4,
loc_filter_thr=0.01, loc_filter_thr=0.01,
loss_loc=dict(type='FocalLoss', loss_loc=dict(
type='FocalLoss',
use_sigmoid=True, use_sigmoid=True,
gamma=2.0, gamma=2.0,
alpha=0.25, alpha=0.25,
loss_weight=1.0), loss_weight=1.0),
loss_shape=dict(type='IoULoss', loss_shape=dict(
type='IoULoss',
style='bounded', style='bounded',
beta=0.2, beta=0.2,
loss_weight=1.0), loss_weight=1.0),
loss_cls=dict(type='CrossEntropyLoss', loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True, use_sigmoid=True,
loss_weight=1.0), loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_bbox=dict(
loss_weight=1.0)): type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
super(AnchorHead, self).__init__() super(AnchorHead, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.num_classes = num_classes self.num_classes = num_classes
...@@ -258,8 +260,8 @@ class GuidedAnchorHead(AnchorHead): ...@@ -258,8 +260,8 @@ class GuidedAnchorHead(AnchorHead):
inside_flags_list.append(inside_flags) inside_flags_list.append(inside_flags)
# inside_flag for a position is true if any anchor in this # inside_flag for a position is true if any anchor in this
# position is true # position is true
inside_flags = (torch.stack(inside_flags_list, 0).sum(dim=0) > inside_flags = (
0) torch.stack(inside_flags_list, 0).sum(dim=0) > 0)
multi_level_flags.append(inside_flags) multi_level_flags.append(inside_flags)
inside_flag_list.append(multi_level_flags) inside_flag_list.append(multi_level_flags)
return approxs_list, inside_flag_list return approxs_list, inside_flag_list
...@@ -347,7 +349,8 @@ class GuidedAnchorHead(AnchorHead): ...@@ -347,7 +349,8 @@ class GuidedAnchorHead(AnchorHead):
-1, 2).detach()[mask] -1, 2).detach()[mask]
bbox_deltas = anchor_deltas.new_full(squares.size(), 0) bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
bbox_deltas[:, 2:] = anchor_deltas bbox_deltas[:, 2:] = anchor_deltas
guided_anchors = delta2bbox(squares, guided_anchors = delta2bbox(
squares,
bbox_deltas, bbox_deltas,
self.anchoring_means, self.anchoring_means,
self.anchoring_stds, self.anchoring_stds,
...@@ -368,12 +371,14 @@ class GuidedAnchorHead(AnchorHead): ...@@ -368,12 +371,14 @@ class GuidedAnchorHead(AnchorHead):
bbox_anchors_ = bbox_anchors[inds] bbox_anchors_ = bbox_anchors[inds]
bbox_gts_ = bbox_gts[inds] bbox_gts_ = bbox_gts[inds]
anchor_weights_ = anchor_weights[inds] anchor_weights_ = anchor_weights[inds]
pred_anchors_ = delta2bbox(bbox_anchors_, pred_anchors_ = delta2bbox(
bbox_anchors_,
bbox_deltas_, bbox_deltas_,
self.anchoring_means, self.anchoring_means,
self.anchoring_stds, self.anchoring_stds,
wh_ratio_clip=1e-6) wh_ratio_clip=1e-6)
loss_shape = self.loss_shape(pred_anchors_, loss_shape = self.loss_shape(
pred_anchors_,
bbox_gts_, bbox_gts_,
anchor_weights_, anchor_weights_,
avg_factor=anchor_total_num) avg_factor=anchor_total_num)
...@@ -381,7 +386,8 @@ class GuidedAnchorHead(AnchorHead): ...@@ -381,7 +386,8 @@ class GuidedAnchorHead(AnchorHead):
def loss_loc_single(self, loc_pred, loc_target, loc_weight, loc_avg_factor, def loss_loc_single(self, loc_pred, loc_target, loc_weight, loc_avg_factor,
cfg): cfg):
loss_loc = self.loss_loc(loc_pred.reshape(-1, 1), loss_loc = self.loss_loc(
loc_pred.reshape(-1, 1),
loc_target.reshape(-1, 1).long(), loc_target.reshape(-1, 1).long(),
loc_weight.reshape(-1, 1), loc_weight.reshape(-1, 1),
avg_factor=loc_avg_factor) avg_factor=loc_avg_factor)
...@@ -418,7 +424,8 @@ class GuidedAnchorHead(AnchorHead): ...@@ -418,7 +424,8 @@ class GuidedAnchorHead(AnchorHead):
# get shape targets # get shape targets
sampling = False if not hasattr(cfg, 'ga_sampler') else True sampling = False if not hasattr(cfg, 'ga_sampler') else True
shape_targets = ga_shape_target(approxs_list, shape_targets = ga_shape_target(
approxs_list,
inside_flag_list, inside_flag_list,
squares_list, squares_list,
gt_bboxes, gt_bboxes,
...@@ -430,13 +437,14 @@ class GuidedAnchorHead(AnchorHead): ...@@ -430,13 +437,14 @@ class GuidedAnchorHead(AnchorHead):
return None return None
(bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num, (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
anchor_bg_num) = shape_targets anchor_bg_num) = shape_targets
anchor_total_num = (anchor_fg_num if not sampling else anchor_fg_num + anchor_total_num = (
anchor_bg_num) anchor_fg_num if not sampling else anchor_fg_num + anchor_bg_num)
# get anchor targets # get anchor targets
sampling = False if self.cls_focal_loss else True sampling = False if self.cls_focal_loss else True
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(guided_anchors_list, cls_reg_targets = anchor_target(
guided_anchors_list,
inside_flag_list, inside_flag_list,
gt_bboxes, gt_bboxes,
img_metas, img_metas,
...@@ -451,8 +459,9 @@ class GuidedAnchorHead(AnchorHead): ...@@ -451,8 +459,9 @@ class GuidedAnchorHead(AnchorHead):
return None return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (num_total_pos if self.cls_focal_loss else num_total_samples = (
num_total_pos + num_total_neg) num_total_pos if self.cls_focal_loss else num_total_pos +
num_total_neg)
# get classification and bbox regression losses # get classification and bbox regression losses
losses_cls, losses_bbox = multi_apply( losses_cls, losses_bbox = multi_apply(
...@@ -467,21 +476,29 @@ class GuidedAnchorHead(AnchorHead): ...@@ -467,21 +476,29 @@ class GuidedAnchorHead(AnchorHead):
cfg=cfg) cfg=cfg)
# get anchor location loss # get anchor location loss
losses_loc, = multi_apply(self.loss_loc_single, losses_loc = []
loc_preds, for i in range(len(loc_preds)):
loc_targets, loss_loc = self.loss_loc_single(
loc_weights, loc_preds[i],
loc_targets[i],
loc_weights[i],
loc_avg_factor=loc_avg_factor, loc_avg_factor=loc_avg_factor,
cfg=cfg) cfg=cfg)
losses_loc.append(loss_loc)
# get anchor shape loss # get anchor shape loss
losses_shape, = multi_apply(self.loss_shape_single, losses_shape = []
shape_preds, for i in range(len(shape_preds)):
bbox_anchors_list, loss_shape = self.loss_shape_single(
bbox_gts_list, shape_preds[i],
anchor_weights_list, bbox_anchors_list[i],
bbox_gts_list[i],
anchor_weights_list[i],
anchor_total_num=anchor_total_num) anchor_total_num=anchor_total_num)
return dict(loss_cls=losses_cls, losses_shape.append(loss_shape)
return dict(
loss_cls=losses_cls,
loss_bbox=losses_bbox, loss_bbox=losses_bbox,
loss_shape=losses_shape, loss_shape=losses_shape,
loss_loc=losses_loc) loss_loc=losses_loc)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment