Commit 57d34592 authored by Jiangmiao Pang's avatar Jiangmiao Pang Committed by Kai Chen
Browse files

Add loss evaluator (#678)

* Fix license in setup.py

* Add code for loss evaluator

* Configs support loss evaluator

* Fix a little bug

* Fix flake8

* return revised bbox to reg

* return revised bbox to reg

* revision according to comments

* fix flake8
parent a99dbae7
......@@ -27,7 +27,10 @@ model = dict(
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True),
use_sigmoid_cls=True,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
......@@ -45,7 +48,10 @@ model = dict(
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False,
norm_cfg=norm_cfg),
norm_cfg=norm_cfg,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
......@@ -57,7 +63,9 @@ model = dict(
in_channels=256,
conv_out_channels=256,
num_classes=81,
norm_cfg=norm_cfg))
norm_cfg=norm_cfg,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(
rpn=dict(
......@@ -75,7 +83,6 @@ train_cfg = dict(
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
smoothl1_beta=1 / 9.0,
debug=False),
rpn_proposal=dict(
nms_across_levels=False,
......
......@@ -5,14 +5,16 @@ from .anchor_heads import * # noqa: F401,F403
from .shared_heads import * # noqa: F401,F403
from .bbox_heads import * # noqa: F401,F403
from .mask_heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
DETECTORS)
LOSSES, DETECTORS)
from .builder import (build_backbone, build_neck, build_roi_extractor,
build_shared_head, build_head, build_detector)
build_shared_head, build_head, build_loss,
build_detector)
__all__ = [
'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS',
'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
'build_shared_head', 'build_head', 'build_detector'
'build_shared_head', 'build_head', 'build_loss', 'build_detector'
]
......@@ -6,9 +6,8 @@ import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, weighted_cross_entropy, weighted_smoothl1,
weighted_binary_cross_entropy,
weighted_sigmoid_focal_loss, multiclass_nms)
multi_apply, multiclass_nms)
from ..builder import build_loss
from ..registry import HEADS
......@@ -25,9 +24,8 @@ class AnchorHead(nn.Module):
anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for
classification. (softmax by default)
cls_focal_loss (bool): Whether to use focal loss for classification.
loss_cls (dict): Config of classification loss.
loss_bbox (dict): Config of localization loss.
""" # noqa: W605
def __init__(self,
......@@ -40,8 +38,12 @@ class AnchorHead(nn.Module):
anchor_base_sizes=None,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
use_sigmoid_cls=False,
cls_focal_loss=False):
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)):
super(AnchorHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
......@@ -53,8 +55,15 @@ class AnchorHead(nn.Module):
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
self.cls_focal_loss = cls_focal_loss
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
self.sampling = loss_cls['type'] not in ['FocalLoss']
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1
else:
self.cls_out_channels = num_classes
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
......@@ -62,11 +71,6 @@ class AnchorHead(nn.Module):
AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
if self.use_sigmoid_cls:
self.cls_out_channels = self.num_classes - 1
else:
self.cls_out_channels = self.num_classes
self._init_layers()
def _init_layers(self):
......@@ -130,40 +134,20 @@ class AnchorHead(nn.Module):
# classification loss
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3, 1).reshape(
-1, self.cls_out_channels)
if self.use_sigmoid_cls:
if self.cls_focal_loss:
cls_criterion = weighted_sigmoid_focal_loss
else:
cls_criterion = weighted_binary_cross_entropy
else:
if self.cls_focal_loss:
raise NotImplementedError
else:
cls_criterion = weighted_cross_entropy
if self.cls_focal_loss:
loss_cls = cls_criterion(
cls_score,
labels,
label_weights,
gamma=cfg.gamma,
alpha=cfg.alpha,
avg_factor=num_total_samples)
else:
loss_cls = cls_criterion(
cls_score, labels, label_weights, avg_factor=num_total_samples)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_targets = bbox_targets.reshape(-1, 4)
bbox_weights = bbox_weights.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
loss_reg = weighted_smoothl1(
loss_bbox = self.loss_bbox(
bbox_pred,
bbox_targets,
bbox_weights,
beta=cfg.smoothl1_beta,
avg_factor=num_total_samples)
return loss_cls, loss_reg
return loss_cls, loss_bbox
def loss(self,
cls_scores,
......@@ -178,7 +162,6 @@ class AnchorHead(nn.Module):
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas)
sampling = False if self.cls_focal_loss else True
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(
anchor_list,
......@@ -191,15 +174,14 @@ class AnchorHead(nn.Module):
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=sampling)
sampling=self.sampling)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (
num_total_pos
if self.cls_focal_loss else num_total_pos + num_total_neg)
losses_cls, losses_reg = multi_apply(
num_total_pos + num_total_neg if self.sampling else num_total_pos)
losses_cls, losses_bbox = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
......@@ -209,7 +191,7 @@ class AnchorHead(nn.Module):
bbox_weights_list,
num_total_samples=num_total_samples,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_reg=losses_reg)
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
rescale=False):
......@@ -251,8 +233,8 @@ class AnchorHead(nn.Module):
for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(1, 2, 0).reshape(
-1, self.cls_out_channels)
cls_score = cls_score.permute(1, 2,
0).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
......@@ -279,6 +261,7 @@ class AnchorHead(nn.Module):
if self.use_sigmoid_cls:
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
det_bboxes, det_labels = multiclass_nms(
mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
......@@ -28,12 +28,7 @@ class RetinaHead(AnchorHead):
[2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
super(RetinaHead, self).__init__(
num_classes,
in_channels,
anchor_scales=anchor_scales,
use_sigmoid_cls=True,
cls_focal_loss=True,
**kwargs)
num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
......
......@@ -50,7 +50,7 @@ class RPNHead(AnchorHead):
cfg,
gt_bboxes_ignore=gt_bboxes_ignore)
return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_reg=losses['loss_reg'])
loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])
def get_bboxes_single(self,
cls_scores,
......
......@@ -10,6 +10,7 @@ from .anchor_head import AnchorHead
from ..registry import HEADS
# TODO: add loss evaluator for SSD
@HEADS.register_module
class SSDHead(AnchorHead):
......@@ -122,13 +123,13 @@ class SSDHead(AnchorHead):
loss_cls_neg = topk_loss_cls_neg.sum()
loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
loss_reg = weighted_smoothl1(
loss_bbox = weighted_smoothl1(
bbox_pred,
bbox_targets,
bbox_weights,
beta=cfg.smoothl1_beta,
avg_factor=num_total_samples)
return loss_cls[None], loss_reg
return loss_cls[None], loss_bbox
def loss(self,
cls_scores,
......@@ -167,18 +168,18 @@ class SSDHead(AnchorHead):
num_images, -1, self.cls_out_channels) for s in cls_scores
], 1)
all_labels = torch.cat(labels_list, -1).view(num_images, -1)
all_label_weights = torch.cat(label_weights_list, -1).view(
num_images, -1)
all_label_weights = torch.cat(label_weights_list,
-1).view(num_images, -1)
all_bbox_preds = torch.cat([
b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
for b in bbox_preds
], -2)
all_bbox_targets = torch.cat(bbox_targets_list, -2).view(
num_images, -1, 4)
all_bbox_weights = torch.cat(bbox_weights_list, -2).view(
num_images, -1, 4)
all_bbox_targets = torch.cat(bbox_targets_list,
-2).view(num_images, -1, 4)
all_bbox_weights = torch.cat(bbox_weights_list,
-2).view(num_images, -1, 4)
losses_cls, losses_reg = multi_apply(
losses_cls, losses_bbox = multi_apply(
self.loss_single,
all_cls_scores,
all_bbox_preds,
......@@ -188,4 +189,4 @@ class SSDHead(AnchorHead):
all_bbox_weights,
num_total_samples=num_total_pos,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_reg=losses_reg)
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
......@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
weighted_cross_entropy, weighted_smoothl1, accuracy)
from mmdet.core import delta2bbox, multiclass_nms, bbox_target, accuracy
from ..builder import build_loss
from ..registry import HEADS
......@@ -21,7 +21,13 @@ class BBoxHead(nn.Module):
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False):
reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
super(BBoxHead, self).__init__()
assert with_cls or with_reg
self.with_avg_pool = with_avg_pool
......@@ -34,6 +40,9 @@ class BBoxHead(nn.Module):
self.target_stds = target_stds
self.reg_class_agnostic = reg_class_agnostic
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
in_channels = self.in_channels
if self.with_avg_pool:
self.avg_pool = nn.AvgPool2d(roi_feat_size)
......@@ -90,7 +99,7 @@ class BBoxHead(nn.Module):
reduce=True):
losses = dict()
if cls_score is not None:
losses['loss_cls'] = weighted_cross_entropy(
losses['loss_cls'] = self.loss_cls(
cls_score, labels, label_weights, reduce=reduce)
losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None:
......@@ -100,7 +109,7 @@ class BBoxHead(nn.Module):
else:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
4)[pos_inds, labels[pos_inds]]
losses['loss_reg'] = weighted_smoothl1(
losses['loss_bbox'] = self.loss_bbox(
pos_bbox_pred,
bbox_targets[pos_inds],
bbox_weights[pos_inds],
......@@ -132,8 +141,9 @@ class BBoxHead(nn.Module):
if cfg is None:
return bboxes, scores
else:
det_bboxes, det_labels = multiclass_nms(
bboxes, scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
det_bboxes, det_labels = multiclass_nms(bboxes, scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
......
......@@ -29,8 +29,8 @@ class ConvFCBBoxHead(BBoxHead):
*args,
**kwargs):
super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
assert (num_shared_convs + num_shared_fcs + num_cls_convs + num_cls_fcs
+ num_reg_convs + num_reg_fcs > 0)
assert (num_shared_convs + num_shared_fcs + num_cls_convs +
num_cls_fcs + num_reg_convs + num_reg_fcs > 0)
if num_cls_convs > 0 or num_reg_convs > 0:
assert num_shared_fcs == 0
if not self.with_cls:
......@@ -76,8 +76,8 @@ class ConvFCBBoxHead(BBoxHead):
if self.with_cls:
self.fc_cls = nn.Linear(self.cls_last_dim, self.num_classes)
if self.with_reg:
out_dim_reg = (4 if self.reg_class_agnostic else
4 * self.num_classes)
out_dim_reg = (4 if self.reg_class_agnostic else 4 *
self.num_classes)
self.fc_reg = nn.Linear(self.reg_last_dim, out_dim_reg)
def _add_conv_fc_branch(self,
......
......@@ -2,7 +2,7 @@ import mmcv
from torch import nn
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
DETECTORS)
LOSSES, DETECTORS)
def _build_module(cfg, registry, default_args):
......@@ -52,5 +52,9 @@ def build_head(cfg):
return build(cfg, HEADS)
def build_loss(cfg):
    """Build a loss module from a config dict using the LOSSES registry.

    Args:
        cfg (dict): Config with a ``type`` key naming a registered loss
            (e.g. ``CrossEntropyLoss``) plus its constructor kwargs.

    Returns:
        nn.Module: The instantiated loss module.
    """
    return build(cfg, LOSSES)
def build_detector(cfg, train_cfg=None, test_cfg=None):
    """Build a detector from a config dict using the DETECTORS registry.

    Args:
        cfg (dict): Config with a ``type`` key naming a registered detector.
        train_cfg (dict, optional): Training config forwarded to the
            detector constructor as a default argument.
        test_cfg (dict, optional): Testing config forwarded likewise.

    Returns:
        nn.Module: The instantiated detector.
    """
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
from .cross_entropy_loss import CrossEntropyLoss
from .focal_loss import FocalLoss
from .smooth_l1_loss import SmoothL1Loss
__all__ = ['CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss']
import torch.nn as nn
from mmdet.core import (weighted_cross_entropy, weighted_binary_cross_entropy,
mask_cross_entropy)
from ..registry import LOSSES
@LOSSES.register_module
class CrossEntropyLoss(nn.Module):
    """Cross-entropy classification loss.

    Delegates to one of three criteria from ``mmdet.core`` depending on the
    flags: binary (sigmoid) cross-entropy, mask cross-entropy, or plain
    softmax cross-entropy. The two flags are mutually exclusive.

    Args:
        use_sigmoid (bool): Use the per-class sigmoid criterion.
        use_mask (bool): Use the mask criterion (for mask heads).
        loss_weight (float): Scalar multiplier applied to the loss.
    """

    def __init__(self, use_sigmoid=False, use_mask=False, loss_weight=1.0):
        super(CrossEntropyLoss, self).__init__()
        # At most one of the two modes may be enabled.
        assert use_sigmoid is False or use_mask is False
        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.loss_weight = loss_weight
        # Resolve the underlying criterion once, at construction time.
        self.cls_criterion = (
            weighted_binary_cross_entropy if use_sigmoid else
            mask_cross_entropy if use_mask else weighted_cross_entropy)

    def forward(self, cls_score, label, label_weight, *args, **kwargs):
        """Return the weighted classification loss."""
        raw_loss = self.cls_criterion(cls_score, label, label_weight, *args,
                                      **kwargs)
        return self.loss_weight * raw_loss
import torch.nn as nn
from mmdet.core import weighted_sigmoid_focal_loss
from ..registry import LOSSES
@LOSSES.register_module
class FocalLoss(nn.Module):
    """Sigmoid focal loss for classification.

    Thin wrapper around ``weighted_sigmoid_focal_loss`` from ``mmdet.core``.
    Only the sigmoid variant is implemented; constructing with
    ``use_sigmoid=False`` raises immediately.

    Args:
        use_sigmoid (bool): Must be True (softmax focal loss unsupported).
        loss_weight (float): Scalar multiplier applied to the loss.
        gamma (float): Focusing parameter of the focal term.
        alpha (float): Class-balancing parameter.
    """

    def __init__(self,
                 use_sigmoid=False,
                 loss_weight=1.0,
                 gamma=2.0,
                 alpha=0.25):
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focaloss supported now.'
        self.use_sigmoid = use_sigmoid
        self.loss_weight = loss_weight
        self.gamma = gamma
        self.alpha = alpha
        self.cls_criterion = weighted_sigmoid_focal_loss

    def forward(self, cls_score, label, label_weight, *args, **kwargs):
        """Return the weighted focal loss."""
        # Guard clause: the non-sigmoid path is not implemented.
        if not self.use_sigmoid:
            raise NotImplementedError
        focal = self.cls_criterion(
            cls_score,
            label,
            label_weight,
            gamma=self.gamma,
            alpha=self.alpha,
            *args,
            **kwargs)
        return self.loss_weight * focal
import torch.nn as nn
from mmdet.core import weighted_smoothl1
from ..registry import LOSSES
@LOSSES.register_module
class SmoothL1Loss(nn.Module):
    """Smooth-L1 localization loss.

    Thin wrapper around ``weighted_smoothl1`` from ``mmdet.core`` that fixes
    ``beta`` at construction time and scales the result by ``loss_weight``.

    Args:
        beta (float): The ``beta`` parameter forwarded to the criterion.
        loss_weight (float): Scalar multiplier applied to the loss.
    """

    def __init__(self, beta=1.0, loss_weight=1.0):
        super(SmoothL1Loss, self).__init__()
        self.beta = beta
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight, *args, **kwargs):
        """Return the weighted smooth-L1 loss between pred and target."""
        raw = weighted_smoothl1(
            pred, target, weight, beta=self.beta, *args, **kwargs)
        return raw * self.loss_weight
......@@ -4,9 +4,10 @@ import pycocotools.mask as mask_util
import torch
import torch.nn as nn
from ..builder import build_loss
from ..registry import HEADS
from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target
from mmdet.core import mask_target
@HEADS.register_module
......@@ -23,7 +24,9 @@ class FCNMaskHead(nn.Module):
num_classes=81,
class_agnostic=False,
conv_cfg=None,
norm_cfg=None):
norm_cfg=None,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)):
super(FCNMaskHead, self).__init__()
if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
raise ValueError(
......@@ -40,6 +43,7 @@ class FCNMaskHead(nn.Module):
self.class_agnostic = class_agnostic
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.loss_mask = build_loss(loss_mask)
self.convs = nn.ModuleList()
for i in range(self.num_convs):
......@@ -106,10 +110,10 @@ class FCNMaskHead(nn.Module):
def loss(self, mask_pred, mask_targets, labels):
loss = dict()
if self.class_agnostic:
loss_mask = mask_cross_entropy(mask_pred, mask_targets,
torch.zeros_like(labels))
loss_mask = self.loss_mask(mask_pred, mask_targets,
torch.zeros_like(labels))
else:
loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
loss_mask = self.loss_mask(mask_pred, mask_targets, labels)
loss['loss_mask'] = loss_mask
return loss
......
......@@ -41,4 +41,5 @@ NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
......@@ -102,7 +102,7 @@ if __name__ == '__main__':
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
],
license='GPLv3',
license='Apache License 2.0',
setup_requires=['pytest-runner'],
tests_require=['pytest'],
install_requires=[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment