Commit 2df1e0a0 authored by Kai Chen

bug fix for using softmax

parent 70700512
...@@ -12,7 +12,7 @@ def anchor_target(anchor_list, ...@@ -12,7 +12,7 @@ def anchor_target(anchor_list,
target_stds, target_stds,
cfg, cfg,
gt_labels_list=None, gt_labels_list=None,
cls_out_channels=1, label_channels=1,
sampling=True, sampling=True,
unmap_outputs=True): unmap_outputs=True):
"""Compute regression and classification targets for anchors. """Compute regression and classification targets for anchors.
...@@ -54,7 +54,7 @@ def anchor_target(anchor_list, ...@@ -54,7 +54,7 @@ def anchor_target(anchor_list,
target_means=target_means, target_means=target_means,
target_stds=target_stds, target_stds=target_stds,
cfg=cfg, cfg=cfg,
cls_out_channels=cls_out_channels, label_channels=label_channels,
sampling=sampling, sampling=sampling,
unmap_outputs=unmap_outputs) unmap_outputs=unmap_outputs)
# no valid anchors # no valid anchors
...@@ -95,7 +95,7 @@ def anchor_target_single(flat_anchors, ...@@ -95,7 +95,7 @@ def anchor_target_single(flat_anchors,
target_means, target_means,
target_stds, target_stds,
cfg, cfg,
cls_out_channels=1, label_channels=1,
sampling=True, sampling=True,
unmap_outputs=True): unmap_outputs=True):
inside_flags = anchor_inside_flags(flat_anchors, valid_flags, inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
...@@ -147,9 +147,9 @@ def anchor_target_single(flat_anchors, ...@@ -147,9 +147,9 @@ def anchor_target_single(flat_anchors,
num_total_anchors = flat_anchors.size(0) num_total_anchors = flat_anchors.size(0)
labels = unmap(labels, num_total_anchors, inside_flags) labels = unmap(labels, num_total_anchors, inside_flags)
label_weights = unmap(label_weights, num_total_anchors, inside_flags) label_weights = unmap(label_weights, num_total_anchors, inside_flags)
if cls_out_channels > 1: if label_channels > 1:
labels, label_weights = expand_binary_labels(labels, label_weights, labels, label_weights = expand_binary_labels(
cls_out_channels) labels, label_weights, label_channels)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags) bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
...@@ -157,14 +157,14 @@ def anchor_target_single(flat_anchors, ...@@ -157,14 +157,14 @@ def anchor_target_single(flat_anchors,
neg_inds) neg_inds)
def expand_binary_labels(labels, label_weights, cls_out_channels): def expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full( bin_labels = labels.new_full(
(labels.size(0), cls_out_channels), 0, dtype=torch.float32) (labels.size(0), label_channels), 0, dtype=torch.float32)
inds = torch.nonzero(labels >= 1).squeeze() inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0: if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1 bin_labels[inds, labels[inds] - 1] = 1
bin_label_weights = label_weights.view(-1, 1).expand( bin_label_weights = label_weights.view(-1, 1).expand(
label_weights.size(0), cls_out_channels) label_weights.size(0), label_channels)
return bin_labels, bin_label_weights return bin_labels, bin_label_weights
......
...@@ -14,13 +14,9 @@ from ..utils import normal_init ...@@ -14,13 +14,9 @@ from ..utils import normal_init
class AnchorHead(nn.Module): class AnchorHead(nn.Module):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.). """Anchor-based head (RPN, RetinaNet, SSD, etc.).
/ - conv_cls (1x1 conv)
input - rpn_conv (3x3 conv) -
\ - conv_reg (1x1 conv)
Args: Args:
in_channels (int): Number of channels in the input feature map. in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels for the RPN feature map. feat_channels (int): Number of channels of the feature map.
anchor_scales (Iterable): Anchor scales. anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios. anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides. anchor_strides (Iterable): Anchor strides.
...@@ -29,6 +25,7 @@ class AnchorHead(nn.Module): ...@@ -29,6 +25,7 @@ class AnchorHead(nn.Module):
target_stds (Iterable): Std values of regression targets. target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification. use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
(softmax by default) (softmax by default)
use_focal_loss (bool): Whether to use focal loss for classification.
""" # noqa: W605 """ # noqa: W605
def __init__(self, def __init__(self,
...@@ -80,9 +77,9 @@ class AnchorHead(nn.Module): ...@@ -80,9 +77,9 @@ class AnchorHead(nn.Module):
normal_init(self.conv_reg, std=0.01) normal_init(self.conv_reg, std=0.01)
def forward_single(self, x): def forward_single(self, x):
rpn_cls_score = self.conv_cls(x) cls_score = self.conv_cls(x)
rpn_bbox_pred = self.conv_reg(x) bbox_pred = self.conv_reg(x)
return rpn_cls_score, rpn_bbox_pred return cls_score, bbox_pred
def forward(self, feats): def forward(self, feats):
return multi_apply(self.forward_single, feats) return multi_apply(self.forward_single, feats)
...@@ -129,10 +126,13 @@ class AnchorHead(nn.Module): ...@@ -129,10 +126,13 @@ class AnchorHead(nn.Module):
def loss_single(self, cls_score, bbox_pred, labels, label_weights, def loss_single(self, cls_score, bbox_pred, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples, cfg): bbox_targets, bbox_weights, num_total_samples, cfg):
# classification loss # classification loss
labels = labels.contiguous().view(-1, self.cls_out_channels) if self.use_sigmoid_cls:
label_weights = label_weights.contiguous().view( labels = labels.reshape(-1, self.cls_out_channels)
-1, self.cls_out_channels) label_weights = label_weights.reshape(-1, self.cls_out_channels)
cls_score = cls_score.permute(0, 2, 3, 1).contiguous().view( else:
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3, 1).reshape(
-1, self.cls_out_channels) -1, self.cls_out_channels)
if self.use_sigmoid_cls: if self.use_sigmoid_cls:
if self.use_focal_loss: if self.use_focal_loss:
...@@ -156,9 +156,9 @@ class AnchorHead(nn.Module): ...@@ -156,9 +156,9 @@ class AnchorHead(nn.Module):
loss_cls = cls_criterion( loss_cls = cls_criterion(
cls_score, labels, label_weights, avg_factor=num_total_samples) cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss # regression loss
bbox_targets = bbox_targets.contiguous().view(-1, 4) bbox_targets = bbox_targets.reshape(-1, 4)
bbox_weights = bbox_weights.contiguous().view(-1, 4) bbox_weights = bbox_weights.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).contiguous().view(-1, 4) bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
loss_reg = weighted_smoothl1( loss_reg = weighted_smoothl1(
bbox_pred, bbox_pred,
bbox_targets, bbox_targets,
...@@ -175,6 +175,7 @@ class AnchorHead(nn.Module): ...@@ -175,6 +175,7 @@ class AnchorHead(nn.Module):
anchor_list, valid_flag_list = self.get_anchors( anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas) featmap_sizes, img_metas)
sampling = False if self.use_focal_loss else True sampling = False if self.use_focal_loss else True
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target( cls_reg_targets = anchor_target(
anchor_list, anchor_list,
valid_flag_list, valid_flag_list,
...@@ -184,7 +185,7 @@ class AnchorHead(nn.Module): ...@@ -184,7 +185,7 @@ class AnchorHead(nn.Module):
self.target_stds, self.target_stds,
cfg, cfg,
gt_labels_list=gt_labels, gt_labels_list=gt_labels,
cls_out_channels=self.cls_out_channels, label_channels=label_channels,
sampling=sampling) sampling=sampling)
if cls_reg_targets is None: if cls_reg_targets is None:
return None return None
...@@ -202,7 +203,7 @@ class AnchorHead(nn.Module): ...@@ -202,7 +203,7 @@ class AnchorHead(nn.Module):
bbox_weights_list, bbox_weights_list,
num_total_samples=num_total_samples, num_total_samples=num_total_samples,
cfg=cfg) cfg=cfg)
return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg) return dict(loss_cls=losses_cls, loss_reg=losses_reg)
def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg, def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
rescale=False): rescale=False):
......
...@@ -33,8 +33,10 @@ class RPNHead(AnchorHead): ...@@ -33,8 +33,10 @@ class RPNHead(AnchorHead):
return rpn_cls_score, rpn_bbox_pred return rpn_cls_score, rpn_bbox_pred
def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg): def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg):
return super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes, losses = super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
None, img_metas, cfg) None, img_metas, cfg)
return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_reg=losses['loss_reg'])
def get_bboxes_single(self, def get_bboxes_single(self,
cls_scores, cls_scores,
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.