Commit 1ec92ef4 authored by myownskyW7's avatar myownskyW7 Committed by Kai Chen
Browse files

rename use_focal_loss -> cls_focal_loss (#639)

* use_sigmoid_cls -> cls_sigmoid_loss; use_focal_loss -> cls_focal_loss

* fix flake8 error

* cls_sigmoid_loss -> use_sigmoid_cls
parent 64928acc
...@@ -25,9 +25,9 @@ class AnchorHead(nn.Module): ...@@ -25,9 +25,9 @@ class AnchorHead(nn.Module):
anchor_base_sizes (Iterable): Anchor base sizes. anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets. target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets. target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification. use_sigmoid_cls (bool): Whether to use sigmoid loss for
(softmax by default) classification. (softmax by default)
use_focal_loss (bool): Whether to use focal loss for classification. cls_focal_loss (bool): Whether to use focal loss for classification.
""" # noqa: W605 """ # noqa: W605
def __init__(self, def __init__(self,
...@@ -41,7 +41,7 @@ class AnchorHead(nn.Module): ...@@ -41,7 +41,7 @@ class AnchorHead(nn.Module):
target_means=(.0, .0, .0, .0), target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0), target_stds=(1.0, 1.0, 1.0, 1.0),
use_sigmoid_cls=False, use_sigmoid_cls=False,
use_focal_loss=False): cls_focal_loss=False):
super(AnchorHead, self).__init__() super(AnchorHead, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.num_classes = num_classes self.num_classes = num_classes
...@@ -54,7 +54,7 @@ class AnchorHead(nn.Module): ...@@ -54,7 +54,7 @@ class AnchorHead(nn.Module):
self.target_means = target_means self.target_means = target_means
self.target_stds = target_stds self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls self.use_sigmoid_cls = use_sigmoid_cls
self.use_focal_loss = use_focal_loss self.cls_focal_loss = cls_focal_loss
self.anchor_generators = [] self.anchor_generators = []
for anchor_base in self.anchor_base_sizes: for anchor_base in self.anchor_base_sizes:
...@@ -133,16 +133,16 @@ class AnchorHead(nn.Module): ...@@ -133,16 +133,16 @@ class AnchorHead(nn.Module):
cls_score = cls_score.permute(0, 2, 3, 1).reshape( cls_score = cls_score.permute(0, 2, 3, 1).reshape(
-1, self.cls_out_channels) -1, self.cls_out_channels)
if self.use_sigmoid_cls: if self.use_sigmoid_cls:
if self.use_focal_loss: if self.cls_focal_loss:
cls_criterion = weighted_sigmoid_focal_loss cls_criterion = weighted_sigmoid_focal_loss
else: else:
cls_criterion = weighted_binary_cross_entropy cls_criterion = weighted_binary_cross_entropy
else: else:
if self.use_focal_loss: if self.cls_focal_loss:
raise NotImplementedError raise NotImplementedError
else: else:
cls_criterion = weighted_cross_entropy cls_criterion = weighted_cross_entropy
if self.use_focal_loss: if self.cls_focal_loss:
loss_cls = cls_criterion( loss_cls = cls_criterion(
cls_score, cls_score,
labels, labels,
...@@ -178,7 +178,7 @@ class AnchorHead(nn.Module): ...@@ -178,7 +178,7 @@ class AnchorHead(nn.Module):
anchor_list, valid_flag_list = self.get_anchors( anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas) featmap_sizes, img_metas)
sampling = False if self.use_focal_loss else True sampling = False if self.cls_focal_loss else True
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target( cls_reg_targets = anchor_target(
anchor_list, anchor_list,
...@@ -196,7 +196,7 @@ class AnchorHead(nn.Module): ...@@ -196,7 +196,7 @@ class AnchorHead(nn.Module):
return None return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (num_total_pos if self.use_focal_loss else num_total_samples = (num_total_pos if self.cls_focal_loss else
num_total_pos + num_total_neg) num_total_pos + num_total_neg)
losses_cls, losses_reg = multi_apply( losses_cls, losses_reg = multi_apply(
self.loss_single, self.loss_single,
......
...@@ -32,7 +32,7 @@ class RetinaHead(AnchorHead): ...@@ -32,7 +32,7 @@ class RetinaHead(AnchorHead):
in_channels, in_channels,
anchor_scales=anchor_scales, anchor_scales=anchor_scales,
use_sigmoid_cls=True, use_sigmoid_cls=True,
use_focal_loss=True, cls_focal_loss=True,
**kwargs) **kwargs)
def _init_layers(self): def _init_layers(self):
......
...@@ -90,7 +90,7 @@ class SSDHead(AnchorHead): ...@@ -90,7 +90,7 @@ class SSDHead(AnchorHead):
self.target_means = target_means self.target_means = target_means
self.target_stds = target_stds self.target_stds = target_stds
self.use_sigmoid_cls = False self.use_sigmoid_cls = False
self.use_focal_loss = False self.cls_focal_loss = False
def init_weights(self): def init_weights(self):
for m in self.modules(): for m in self.modules():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment