Commit 1ec92ef4 authored by myownskyW7's avatar myownskyW7 Committed by Kai Chen
Browse files

rename use_focal_loss -> cls_focal_loss (#639)

* use_sigmoid_cls -> cls_sigmoid_loss; use_focal_loss -> cls_focal_loss

* fix flake8 error

* cls_sigmoid_loss - > use_sigmoid_cls
parent 64928acc
......@@ -25,9 +25,9 @@ class AnchorHead(nn.Module):
anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
(softmax by default)
use_focal_loss (bool): Whether to use focal loss for classification.
use_sigmoid_cls (bool): Whether to use sigmoid loss for
classification. (softmax by default)
cls_focal_loss (bool): Whether to use focal loss for classification.
""" # noqa: W605
def __init__(self,
......@@ -41,7 +41,7 @@ class AnchorHead(nn.Module):
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
use_sigmoid_cls=False,
use_focal_loss=False):
cls_focal_loss=False):
super(AnchorHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
......@@ -54,7 +54,7 @@ class AnchorHead(nn.Module):
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
self.use_focal_loss = use_focal_loss
self.cls_focal_loss = cls_focal_loss
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
......@@ -133,16 +133,16 @@ class AnchorHead(nn.Module):
cls_score = cls_score.permute(0, 2, 3, 1).reshape(
-1, self.cls_out_channels)
if self.use_sigmoid_cls:
if self.use_focal_loss:
if self.cls_focal_loss:
cls_criterion = weighted_sigmoid_focal_loss
else:
cls_criterion = weighted_binary_cross_entropy
else:
if self.use_focal_loss:
if self.cls_focal_loss:
raise NotImplementedError
else:
cls_criterion = weighted_cross_entropy
if self.use_focal_loss:
if self.cls_focal_loss:
loss_cls = cls_criterion(
cls_score,
labels,
......@@ -178,7 +178,7 @@ class AnchorHead(nn.Module):
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas)
sampling = False if self.use_focal_loss else True
sampling = False if self.cls_focal_loss else True
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(
anchor_list,
......@@ -196,7 +196,7 @@ class AnchorHead(nn.Module):
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (num_total_pos if self.use_focal_loss else
num_total_samples = (num_total_pos if self.cls_focal_loss else
num_total_pos + num_total_neg)
losses_cls, losses_reg = multi_apply(
self.loss_single,
......
......@@ -32,7 +32,7 @@ class RetinaHead(AnchorHead):
in_channels,
anchor_scales=anchor_scales,
use_sigmoid_cls=True,
use_focal_loss=True,
cls_focal_loss=True,
**kwargs)
def _init_layers(self):
......
......@@ -90,7 +90,7 @@ class SSDHead(AnchorHead):
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = False
self.use_focal_loss = False
self.cls_focal_loss = False
def init_weights(self):
for m in self.modules():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment