Commit 2df1e0a0 authored by Kai Chen
Browse files

bug fix for using softmax

parent 70700512
......@@ -12,7 +12,7 @@ def anchor_target(anchor_list,
target_stds,
cfg,
gt_labels_list=None,
cls_out_channels=1,
label_channels=1,
sampling=True,
unmap_outputs=True):
"""Compute regression and classification targets for anchors.
......@@ -54,7 +54,7 @@ def anchor_target(anchor_list,
target_means=target_means,
target_stds=target_stds,
cfg=cfg,
cls_out_channels=cls_out_channels,
label_channels=label_channels,
sampling=sampling,
unmap_outputs=unmap_outputs)
# no valid anchors
......@@ -95,7 +95,7 @@ def anchor_target_single(flat_anchors,
target_means,
target_stds,
cfg,
cls_out_channels=1,
label_channels=1,
sampling=True,
unmap_outputs=True):
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
......@@ -147,9 +147,9 @@ def anchor_target_single(flat_anchors,
num_total_anchors = flat_anchors.size(0)
labels = unmap(labels, num_total_anchors, inside_flags)
label_weights = unmap(label_weights, num_total_anchors, inside_flags)
if cls_out_channels > 1:
labels, label_weights = expand_binary_labels(labels, label_weights,
cls_out_channels)
if label_channels > 1:
labels, label_weights = expand_binary_labels(
labels, label_weights, label_channels)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
......@@ -157,14 +157,14 @@ def anchor_target_single(flat_anchors,
neg_inds)
def expand_binary_labels(labels, label_weights, cls_out_channels):
def expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full(
(labels.size(0), cls_out_channels), 0, dtype=torch.float32)
(labels.size(0), label_channels), 0, dtype=torch.float32)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
bin_label_weights = label_weights.view(-1, 1).expand(
label_weights.size(0), cls_out_channels)
label_weights.size(0), label_channels)
return bin_labels, bin_label_weights
......
......@@ -14,13 +14,9 @@ from ..utils import normal_init
class AnchorHead(nn.Module):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
/ - conv_cls (1x1 conv)
input - rpn_conv (3x3 conv) -
\ - conv_reg (1x1 conv)
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels for the RPN feature map.
feat_channels (int): Number of channels of the feature map.
anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
......@@ -29,6 +25,7 @@ class AnchorHead(nn.Module):
target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
(softmax by default)
use_focal_loss (bool): Whether to use focal loss for classification.
""" # noqa: W605
def __init__(self,
......@@ -80,9 +77,9 @@ class AnchorHead(nn.Module):
normal_init(self.conv_reg, std=0.01)
def forward_single(self, x):
rpn_cls_score = self.conv_cls(x)
rpn_bbox_pred = self.conv_reg(x)
return rpn_cls_score, rpn_bbox_pred
cls_score = self.conv_cls(x)
bbox_pred = self.conv_reg(x)
return cls_score, bbox_pred
def forward(self, feats):
return multi_apply(self.forward_single, feats)
......@@ -129,10 +126,13 @@ class AnchorHead(nn.Module):
def loss_single(self, cls_score, bbox_pred, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples, cfg):
# classification loss
labels = labels.contiguous().view(-1, self.cls_out_channels)
label_weights = label_weights.contiguous().view(
-1, self.cls_out_channels)
cls_score = cls_score.permute(0, 2, 3, 1).contiguous().view(
if self.use_sigmoid_cls:
labels = labels.reshape(-1, self.cls_out_channels)
label_weights = label_weights.reshape(-1, self.cls_out_channels)
else:
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3, 1).reshape(
-1, self.cls_out_channels)
if self.use_sigmoid_cls:
if self.use_focal_loss:
......@@ -156,9 +156,9 @@ class AnchorHead(nn.Module):
loss_cls = cls_criterion(
cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_targets = bbox_targets.contiguous().view(-1, 4)
bbox_weights = bbox_weights.contiguous().view(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).contiguous().view(-1, 4)
bbox_targets = bbox_targets.reshape(-1, 4)
bbox_weights = bbox_weights.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
loss_reg = weighted_smoothl1(
bbox_pred,
bbox_targets,
......@@ -175,6 +175,7 @@ class AnchorHead(nn.Module):
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas)
sampling = False if self.use_focal_loss else True
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(
anchor_list,
valid_flag_list,
......@@ -184,7 +185,7 @@ class AnchorHead(nn.Module):
self.target_stds,
cfg,
gt_labels_list=gt_labels,
cls_out_channels=self.cls_out_channels,
label_channels=label_channels,
sampling=sampling)
if cls_reg_targets is None:
return None
......@@ -202,7 +203,7 @@ class AnchorHead(nn.Module):
bbox_weights_list,
num_total_samples=num_total_samples,
cfg=cfg)
return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
return dict(loss_cls=losses_cls, loss_reg=losses_reg)
def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
rescale=False):
......
......@@ -33,8 +33,10 @@ class RPNHead(AnchorHead):
return rpn_cls_score, rpn_bbox_pred
def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg):
return super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
None, img_metas, cfg)
losses = super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
None, img_metas, cfg)
return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_reg=losses['loss_reg'])
def get_bboxes_single(self,
cls_scores,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment