Commit 57d34592 authored by Jiangmiao Pang's avatar Jiangmiao Pang Committed by Kai Chen
Browse files

Add loss evaluator (#678)

* Fix license in setup.py

* Add code for loss evaluator

* Configs support loss evaluator

* Fix a little bug

* Fix flake8

* return revised bbox to reg

* return revised bbox to reg

* revision according to comments

* fix flake8
parent a99dbae7
...@@ -27,7 +27,10 @@ model = dict( ...@@ -27,7 +27,10 @@ model = dict(
anchor_strides=[4, 8, 16, 32, 64], anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0], target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0], target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True), use_sigmoid_cls=True,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
bbox_roi_extractor=dict( bbox_roi_extractor=dict(
type='SingleRoIExtractor', type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2), roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
...@@ -45,7 +48,10 @@ model = dict( ...@@ -45,7 +48,10 @@ model = dict(
target_means=[0., 0., 0., 0.], target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2], target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False, reg_class_agnostic=False,
norm_cfg=norm_cfg), norm_cfg=norm_cfg,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
mask_roi_extractor=dict( mask_roi_extractor=dict(
type='SingleRoIExtractor', type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2), roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
...@@ -57,7 +63,9 @@ model = dict( ...@@ -57,7 +63,9 @@ model = dict(
in_channels=256, in_channels=256,
conv_out_channels=256, conv_out_channels=256,
num_classes=81, num_classes=81,
norm_cfg=norm_cfg)) norm_cfg=norm_cfg,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
# model training and testing settings # model training and testing settings
train_cfg = dict( train_cfg = dict(
rpn=dict( rpn=dict(
...@@ -75,7 +83,6 @@ train_cfg = dict( ...@@ -75,7 +83,6 @@ train_cfg = dict(
add_gt_as_proposals=False), add_gt_as_proposals=False),
allowed_border=0, allowed_border=0,
pos_weight=-1, pos_weight=-1,
smoothl1_beta=1 / 9.0,
debug=False), debug=False),
rpn_proposal=dict( rpn_proposal=dict(
nms_across_levels=False, nms_across_levels=False,
......
...@@ -5,14 +5,16 @@ from .anchor_heads import * # noqa: F401,F403 ...@@ -5,14 +5,16 @@ from .anchor_heads import * # noqa: F401,F403
from .shared_heads import * # noqa: F401,F403 from .shared_heads import * # noqa: F401,F403
from .bbox_heads import * # noqa: F401,F403 from .bbox_heads import * # noqa: F401,F403
from .mask_heads import * # noqa: F401,F403 from .mask_heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403 from .detectors import * # noqa: F401,F403
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS, from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
DETECTORS) LOSSES, DETECTORS)
from .builder import (build_backbone, build_neck, build_roi_extractor, from .builder import (build_backbone, build_neck, build_roi_extractor,
build_shared_head, build_head, build_detector) build_shared_head, build_head, build_loss,
build_detector)
__all__ = [ __all__ = [
'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor', 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
'build_shared_head', 'build_head', 'build_detector' 'build_shared_head', 'build_head', 'build_loss', 'build_detector'
] ]
...@@ -6,9 +6,8 @@ import torch.nn as nn ...@@ -6,9 +6,8 @@ import torch.nn as nn
from mmcv.cnn import normal_init from mmcv.cnn import normal_init
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox, from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, weighted_cross_entropy, weighted_smoothl1, multi_apply, multiclass_nms)
weighted_binary_cross_entropy, from ..builder import build_loss
weighted_sigmoid_focal_loss, multiclass_nms)
from ..registry import HEADS from ..registry import HEADS
...@@ -25,9 +24,8 @@ class AnchorHead(nn.Module): ...@@ -25,9 +24,8 @@ class AnchorHead(nn.Module):
anchor_base_sizes (Iterable): Anchor base sizes. anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets. target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets. target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for loss_cls (dict): Config of classification loss.
classification. (softmax by default) loss_bbox (dict): Config of localization loss.
cls_focal_loss (bool): Whether to use focal loss for classification.
""" # noqa: W605 """ # noqa: W605
def __init__(self, def __init__(self,
...@@ -40,8 +38,12 @@ class AnchorHead(nn.Module): ...@@ -40,8 +38,12 @@ class AnchorHead(nn.Module):
anchor_base_sizes=None, anchor_base_sizes=None,
target_means=(.0, .0, .0, .0), target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0), target_stds=(1.0, 1.0, 1.0, 1.0),
use_sigmoid_cls=False, loss_cls=dict(
cls_focal_loss=False): type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)):
super(AnchorHead, self).__init__() super(AnchorHead, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.num_classes = num_classes self.num_classes = num_classes
...@@ -53,8 +55,15 @@ class AnchorHead(nn.Module): ...@@ -53,8 +55,15 @@ class AnchorHead(nn.Module):
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.target_means = target_means self.target_means = target_means
self.target_stds = target_stds self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
self.cls_focal_loss = cls_focal_loss self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
self.sampling = loss_cls['type'] not in ['FocalLoss']
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1
else:
self.cls_out_channels = num_classes
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.anchor_generators = [] self.anchor_generators = []
for anchor_base in self.anchor_base_sizes: for anchor_base in self.anchor_base_sizes:
...@@ -62,11 +71,6 @@ class AnchorHead(nn.Module): ...@@ -62,11 +71,6 @@ class AnchorHead(nn.Module):
AnchorGenerator(anchor_base, anchor_scales, anchor_ratios)) AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales) self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
if self.use_sigmoid_cls:
self.cls_out_channels = self.num_classes - 1
else:
self.cls_out_channels = self.num_classes
self._init_layers() self._init_layers()
def _init_layers(self): def _init_layers(self):
...@@ -130,40 +134,20 @@ class AnchorHead(nn.Module): ...@@ -130,40 +134,20 @@ class AnchorHead(nn.Module):
# classification loss # classification loss
labels = labels.reshape(-1) labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1) label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3, 1).reshape( cls_score = cls_score.permute(0, 2, 3,
-1, self.cls_out_channels) 1).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls: loss_cls = self.loss_cls(
if self.cls_focal_loss: cls_score, labels, label_weights, avg_factor=num_total_samples)
cls_criterion = weighted_sigmoid_focal_loss
else:
cls_criterion = weighted_binary_cross_entropy
else:
if self.cls_focal_loss:
raise NotImplementedError
else:
cls_criterion = weighted_cross_entropy
if self.cls_focal_loss:
loss_cls = cls_criterion(
cls_score,
labels,
label_weights,
gamma=cfg.gamma,
alpha=cfg.alpha,
avg_factor=num_total_samples)
else:
loss_cls = cls_criterion(
cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss # regression loss
bbox_targets = bbox_targets.reshape(-1, 4) bbox_targets = bbox_targets.reshape(-1, 4)
bbox_weights = bbox_weights.reshape(-1, 4) bbox_weights = bbox_weights.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
loss_reg = weighted_smoothl1( loss_bbox = self.loss_bbox(
bbox_pred, bbox_pred,
bbox_targets, bbox_targets,
bbox_weights, bbox_weights,
beta=cfg.smoothl1_beta,
avg_factor=num_total_samples) avg_factor=num_total_samples)
return loss_cls, loss_reg return loss_cls, loss_bbox
def loss(self, def loss(self,
cls_scores, cls_scores,
...@@ -178,7 +162,6 @@ class AnchorHead(nn.Module): ...@@ -178,7 +162,6 @@ class AnchorHead(nn.Module):
anchor_list, valid_flag_list = self.get_anchors( anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas) featmap_sizes, img_metas)
sampling = False if self.cls_focal_loss else True
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target( cls_reg_targets = anchor_target(
anchor_list, anchor_list,
...@@ -191,15 +174,14 @@ class AnchorHead(nn.Module): ...@@ -191,15 +174,14 @@ class AnchorHead(nn.Module):
gt_bboxes_ignore_list=gt_bboxes_ignore, gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels, gt_labels_list=gt_labels,
label_channels=label_channels, label_channels=label_channels,
sampling=sampling) sampling=self.sampling)
if cls_reg_targets is None: if cls_reg_targets is None:
return None return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = ( num_total_samples = (
num_total_pos num_total_pos + num_total_neg if self.sampling else num_total_pos)
if self.cls_focal_loss else num_total_pos + num_total_neg) losses_cls, losses_bbox = multi_apply(
losses_cls, losses_reg = multi_apply(
self.loss_single, self.loss_single,
cls_scores, cls_scores,
bbox_preds, bbox_preds,
...@@ -209,7 +191,7 @@ class AnchorHead(nn.Module): ...@@ -209,7 +191,7 @@ class AnchorHead(nn.Module):
bbox_weights_list, bbox_weights_list,
num_total_samples=num_total_samples, num_total_samples=num_total_samples,
cfg=cfg) cfg=cfg)
return dict(loss_cls=losses_cls, loss_reg=losses_reg) return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg, def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
rescale=False): rescale=False):
...@@ -251,8 +233,8 @@ class AnchorHead(nn.Module): ...@@ -251,8 +233,8 @@ class AnchorHead(nn.Module):
for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds, for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
mlvl_anchors): mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:] assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(1, 2, 0).reshape( cls_score = cls_score.permute(1, 2,
-1, self.cls_out_channels) 0).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls: if self.use_sigmoid_cls:
scores = cls_score.sigmoid() scores = cls_score.sigmoid()
else: else:
...@@ -279,6 +261,7 @@ class AnchorHead(nn.Module): ...@@ -279,6 +261,7 @@ class AnchorHead(nn.Module):
if self.use_sigmoid_cls: if self.use_sigmoid_cls:
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1) mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
det_bboxes, det_labels = multiclass_nms( det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img) cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels return det_bboxes, det_labels
...@@ -28,12 +28,7 @@ class RetinaHead(AnchorHead): ...@@ -28,12 +28,7 @@ class RetinaHead(AnchorHead):
[2**(i / scales_per_octave) for i in range(scales_per_octave)]) [2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale anchor_scales = octave_scales * octave_base_scale
super(RetinaHead, self).__init__( super(RetinaHead, self).__init__(
num_classes, num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)
in_channels,
anchor_scales=anchor_scales,
use_sigmoid_cls=True,
cls_focal_loss=True,
**kwargs)
def _init_layers(self): def _init_layers(self):
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
......
...@@ -50,7 +50,7 @@ class RPNHead(AnchorHead): ...@@ -50,7 +50,7 @@ class RPNHead(AnchorHead):
cfg, cfg,
gt_bboxes_ignore=gt_bboxes_ignore) gt_bboxes_ignore=gt_bboxes_ignore)
return dict( return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_reg=losses['loss_reg']) loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])
def get_bboxes_single(self, def get_bboxes_single(self,
cls_scores, cls_scores,
......
...@@ -10,6 +10,7 @@ from .anchor_head import AnchorHead ...@@ -10,6 +10,7 @@ from .anchor_head import AnchorHead
from ..registry import HEADS from ..registry import HEADS
# TODO: add loss evaluator for SSD
@HEADS.register_module @HEADS.register_module
class SSDHead(AnchorHead): class SSDHead(AnchorHead):
...@@ -122,13 +123,13 @@ class SSDHead(AnchorHead): ...@@ -122,13 +123,13 @@ class SSDHead(AnchorHead):
loss_cls_neg = topk_loss_cls_neg.sum() loss_cls_neg = topk_loss_cls_neg.sum()
loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
loss_reg = weighted_smoothl1( loss_bbox = weighted_smoothl1(
bbox_pred, bbox_pred,
bbox_targets, bbox_targets,
bbox_weights, bbox_weights,
beta=cfg.smoothl1_beta, beta=cfg.smoothl1_beta,
avg_factor=num_total_samples) avg_factor=num_total_samples)
return loss_cls[None], loss_reg return loss_cls[None], loss_bbox
def loss(self, def loss(self,
cls_scores, cls_scores,
...@@ -167,18 +168,18 @@ class SSDHead(AnchorHead): ...@@ -167,18 +168,18 @@ class SSDHead(AnchorHead):
num_images, -1, self.cls_out_channels) for s in cls_scores num_images, -1, self.cls_out_channels) for s in cls_scores
], 1) ], 1)
all_labels = torch.cat(labels_list, -1).view(num_images, -1) all_labels = torch.cat(labels_list, -1).view(num_images, -1)
all_label_weights = torch.cat(label_weights_list, -1).view( all_label_weights = torch.cat(label_weights_list,
num_images, -1) -1).view(num_images, -1)
all_bbox_preds = torch.cat([ all_bbox_preds = torch.cat([
b.permute(0, 2, 3, 1).reshape(num_images, -1, 4) b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
for b in bbox_preds for b in bbox_preds
], -2) ], -2)
all_bbox_targets = torch.cat(bbox_targets_list, -2).view( all_bbox_targets = torch.cat(bbox_targets_list,
num_images, -1, 4) -2).view(num_images, -1, 4)
all_bbox_weights = torch.cat(bbox_weights_list, -2).view( all_bbox_weights = torch.cat(bbox_weights_list,
num_images, -1, 4) -2).view(num_images, -1, 4)
losses_cls, losses_reg = multi_apply( losses_cls, losses_bbox = multi_apply(
self.loss_single, self.loss_single,
all_cls_scores, all_cls_scores,
all_bbox_preds, all_bbox_preds,
...@@ -188,4 +189,4 @@ class SSDHead(AnchorHead): ...@@ -188,4 +189,4 @@ class SSDHead(AnchorHead):
all_bbox_weights, all_bbox_weights,
num_total_samples=num_total_pos, num_total_samples=num_total_pos,
cfg=cfg) cfg=cfg)
return dict(loss_cls=losses_cls, loss_reg=losses_reg) return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
...@@ -2,8 +2,8 @@ import torch ...@@ -2,8 +2,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target, from mmdet.core import delta2bbox, multiclass_nms, bbox_target, accuracy
weighted_cross_entropy, weighted_smoothl1, accuracy) from ..builder import build_loss
from ..registry import HEADS from ..registry import HEADS
...@@ -21,7 +21,13 @@ class BBoxHead(nn.Module): ...@@ -21,7 +21,13 @@ class BBoxHead(nn.Module):
num_classes=81, num_classes=81,
target_means=[0., 0., 0., 0.], target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2], target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False): reg_class_agnostic=False,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
super(BBoxHead, self).__init__() super(BBoxHead, self).__init__()
assert with_cls or with_reg assert with_cls or with_reg
self.with_avg_pool = with_avg_pool self.with_avg_pool = with_avg_pool
...@@ -34,6 +40,9 @@ class BBoxHead(nn.Module): ...@@ -34,6 +40,9 @@ class BBoxHead(nn.Module):
self.target_stds = target_stds self.target_stds = target_stds
self.reg_class_agnostic = reg_class_agnostic self.reg_class_agnostic = reg_class_agnostic
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
in_channels = self.in_channels in_channels = self.in_channels
if self.with_avg_pool: if self.with_avg_pool:
self.avg_pool = nn.AvgPool2d(roi_feat_size) self.avg_pool = nn.AvgPool2d(roi_feat_size)
...@@ -90,7 +99,7 @@ class BBoxHead(nn.Module): ...@@ -90,7 +99,7 @@ class BBoxHead(nn.Module):
reduce=True): reduce=True):
losses = dict() losses = dict()
if cls_score is not None: if cls_score is not None:
losses['loss_cls'] = weighted_cross_entropy( losses['loss_cls'] = self.loss_cls(
cls_score, labels, label_weights, reduce=reduce) cls_score, labels, label_weights, reduce=reduce)
losses['acc'] = accuracy(cls_score, labels) losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None: if bbox_pred is not None:
...@@ -100,7 +109,7 @@ class BBoxHead(nn.Module): ...@@ -100,7 +109,7 @@ class BBoxHead(nn.Module):
else: else:
pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, pos_bbox_pred = bbox_pred.view(bbox_pred.size(0), -1,
4)[pos_inds, labels[pos_inds]] 4)[pos_inds, labels[pos_inds]]
losses['loss_reg'] = weighted_smoothl1( losses['loss_bbox'] = self.loss_bbox(
pos_bbox_pred, pos_bbox_pred,
bbox_targets[pos_inds], bbox_targets[pos_inds],
bbox_weights[pos_inds], bbox_weights[pos_inds],
...@@ -132,8 +141,9 @@ class BBoxHead(nn.Module): ...@@ -132,8 +141,9 @@ class BBoxHead(nn.Module):
if cfg is None: if cfg is None:
return bboxes, scores return bboxes, scores
else: else:
det_bboxes, det_labels = multiclass_nms( det_bboxes, det_labels = multiclass_nms(bboxes, scores,
bboxes, scores, cfg.score_thr, cfg.nms, cfg.max_per_img) cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels return det_bboxes, det_labels
......
...@@ -29,8 +29,8 @@ class ConvFCBBoxHead(BBoxHead): ...@@ -29,8 +29,8 @@ class ConvFCBBoxHead(BBoxHead):
*args, *args,
**kwargs): **kwargs):
super(ConvFCBBoxHead, self).__init__(*args, **kwargs) super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
assert (num_shared_convs + num_shared_fcs + num_cls_convs + num_cls_fcs assert (num_shared_convs + num_shared_fcs + num_cls_convs +
+ num_reg_convs + num_reg_fcs > 0) num_cls_fcs + num_reg_convs + num_reg_fcs > 0)
if num_cls_convs > 0 or num_reg_convs > 0: if num_cls_convs > 0 or num_reg_convs > 0:
assert num_shared_fcs == 0 assert num_shared_fcs == 0
if not self.with_cls: if not self.with_cls:
...@@ -76,8 +76,8 @@ class ConvFCBBoxHead(BBoxHead): ...@@ -76,8 +76,8 @@ class ConvFCBBoxHead(BBoxHead):
if self.with_cls: if self.with_cls:
self.fc_cls = nn.Linear(self.cls_last_dim, self.num_classes) self.fc_cls = nn.Linear(self.cls_last_dim, self.num_classes)
if self.with_reg: if self.with_reg:
out_dim_reg = (4 if self.reg_class_agnostic else out_dim_reg = (4 if self.reg_class_agnostic else 4 *
4 * self.num_classes) self.num_classes)
self.fc_reg = nn.Linear(self.reg_last_dim, out_dim_reg) self.fc_reg = nn.Linear(self.reg_last_dim, out_dim_reg)
def _add_conv_fc_branch(self, def _add_conv_fc_branch(self,
......
...@@ -2,7 +2,7 @@ import mmcv ...@@ -2,7 +2,7 @@ import mmcv
from torch import nn from torch import nn
from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS, from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS,
DETECTORS) LOSSES, DETECTORS)
def _build_module(cfg, registry, default_args): def _build_module(cfg, registry, default_args):
...@@ -52,5 +52,9 @@ def build_head(cfg): ...@@ -52,5 +52,9 @@ def build_head(cfg):
return build(cfg, HEADS) return build(cfg, HEADS)
def build_loss(cfg):
return build(cfg, LOSSES)
def build_detector(cfg, train_cfg=None, test_cfg=None): def build_detector(cfg, train_cfg=None, test_cfg=None):
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
from .cross_entropy_loss import CrossEntropyLoss
from .focal_loss import FocalLoss
from .smooth_l1_loss import SmoothL1Loss
__all__ = ['CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss']
import torch.nn as nn
from mmdet.core import (weighted_cross_entropy, weighted_binary_cross_entropy,
mask_cross_entropy)
from ..registry import LOSSES
@LOSSES.register_module
class CrossEntropyLoss(nn.Module):
def __init__(self, use_sigmoid=False, use_mask=False, loss_weight=1.0):
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.loss_weight = loss_weight
if self.use_sigmoid:
self.cls_criterion = weighted_binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = weighted_cross_entropy
def forward(self, cls_score, label, label_weight, *args, **kwargs):
loss_cls = self.loss_weight * self.cls_criterion(
cls_score, label, label_weight, *args, **kwargs)
return loss_cls
import torch.nn as nn
from mmdet.core import weighted_sigmoid_focal_loss
from ..registry import LOSSES
@LOSSES.register_module
class FocalLoss(nn.Module):
def __init__(self,
use_sigmoid=False,
loss_weight=1.0,
gamma=2.0,
alpha=0.25):
super(FocalLoss, self).__init__()
assert use_sigmoid is True, 'Only sigmoid focaloss supported now.'
self.use_sigmoid = use_sigmoid
self.loss_weight = loss_weight
self.gamma = gamma
self.alpha = alpha
self.cls_criterion = weighted_sigmoid_focal_loss
def forward(self, cls_score, label, label_weight, *args, **kwargs):
if self.use_sigmoid:
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
label_weight,
gamma=self.gamma,
alpha=self.alpha,
*args,
**kwargs)
else:
raise NotImplementedError
return loss_cls
import torch.nn as nn
from mmdet.core import weighted_smoothl1
from ..registry import LOSSES
@LOSSES.register_module
class SmoothL1Loss(nn.Module):
def __init__(self, beta=1.0, loss_weight=1.0):
super(SmoothL1Loss, self).__init__()
self.beta = beta
self.loss_weight = loss_weight
def forward(self, pred, target, weight, *args, **kwargs):
loss_bbox = self.loss_weight * weighted_smoothl1(
pred, target, weight, beta=self.beta, *args, **kwargs)
return loss_bbox
...@@ -4,9 +4,10 @@ import pycocotools.mask as mask_util ...@@ -4,9 +4,10 @@ import pycocotools.mask as mask_util
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..builder import build_loss
from ..registry import HEADS from ..registry import HEADS
from ..utils import ConvModule from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target from mmdet.core import mask_target
@HEADS.register_module @HEADS.register_module
...@@ -23,7 +24,9 @@ class FCNMaskHead(nn.Module): ...@@ -23,7 +24,9 @@ class FCNMaskHead(nn.Module):
num_classes=81, num_classes=81,
class_agnostic=False, class_agnostic=False,
conv_cfg=None, conv_cfg=None,
norm_cfg=None): norm_cfg=None,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)):
super(FCNMaskHead, self).__init__() super(FCNMaskHead, self).__init__()
if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']: if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
raise ValueError( raise ValueError(
...@@ -40,6 +43,7 @@ class FCNMaskHead(nn.Module): ...@@ -40,6 +43,7 @@ class FCNMaskHead(nn.Module):
self.class_agnostic = class_agnostic self.class_agnostic = class_agnostic
self.conv_cfg = conv_cfg self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg self.norm_cfg = norm_cfg
self.loss_mask = build_loss(loss_mask)
self.convs = nn.ModuleList() self.convs = nn.ModuleList()
for i in range(self.num_convs): for i in range(self.num_convs):
...@@ -106,10 +110,10 @@ class FCNMaskHead(nn.Module): ...@@ -106,10 +110,10 @@ class FCNMaskHead(nn.Module):
def loss(self, mask_pred, mask_targets, labels): def loss(self, mask_pred, mask_targets, labels):
loss = dict() loss = dict()
if self.class_agnostic: if self.class_agnostic:
loss_mask = mask_cross_entropy(mask_pred, mask_targets, loss_mask = self.loss_mask(mask_pred, mask_targets,
torch.zeros_like(labels)) torch.zeros_like(labels))
else: else:
loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels) loss_mask = self.loss_mask(mask_pred, mask_targets, labels)
loss['loss_mask'] = loss_mask loss['loss_mask'] = loss_mask
return loss return loss
......
...@@ -41,4 +41,5 @@ NECKS = Registry('neck') ...@@ -41,4 +41,5 @@ NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor') ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head') SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head') HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector') DETECTORS = Registry('detector')
...@@ -102,7 +102,7 @@ if __name__ == '__main__': ...@@ -102,7 +102,7 @@ if __name__ == '__main__':
'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
], ],
license='GPLv3', license='Apache License 2.0',
setup_requires=['pytest-runner'], setup_requires=['pytest-runner'],
tests_require=['pytest'], tests_require=['pytest'],
install_requires=[ install_requires=[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment