Commit 32afa752 authored by libuyu's avatar libuyu Committed by Kai Chen
Browse files

Code for "Gradient Harmonized Single-stage Detector" (#706)

* finish ghm in newest version with AP 37.0

* add ghm config file

* reformat for PEP8

* reformat for flake8

* add documents for GHM and tensorize the params

* improve the docs

* add readme and update configs
parent b581e19f
# Gradient Harmonized Single-stage Detector
## Introduction
```
@inproceedings{li2019gradient,
title={Gradient Harmonized Single-stage Detector},
author={Li, Buyu and Liu, Yu and Wang, Xiaogang},
booktitle={AAAI Conference on Artificial Intelligence},
year={2019}
}
```
## Results and Models
To be benchmarked.
\ No newline at end of file
# model settings
# Detector of type 'RetinaNet' with a depth-50 'ResNet' backbone, an 'FPN'
# neck and a 'RetinaHead'. What distinguishes this config is the head's loss
# pair: 'GHMC' for classification and 'GHMR' for box regression.
model = dict(
type='RetinaNet',
pretrained='modelzoo://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
# one entry per backbone output listed in out_indices above
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs=True,
num_outs=5),
bbox_head=dict(
type='RetinaHead',
# NOTE(review): presumably 80 COCO categories + 1 background - confirm
num_classes=81,
in_channels=256,
stacked_convs=4,
feat_channels=256,
octave_base_scale=4,
scales_per_octave=3,
anchor_ratios=[0.5, 1.0, 2.0],
# five strides, matching num_outs=5 of the neck
anchor_strides=[8, 16, 32, 64, 128],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
# Keys other than 'type' mirror the GHMC constructor arguments
# (bins, momentum, use_sigmoid, loss_weight).
loss_cls=dict(
type='GHMC',
bins=30,
momentum=0.75,
use_sigmoid=True,
loss_weight=1.0),
# Keys other than 'type' mirror the GHMR constructor arguments
# (mu, bins, momentum, loss_weight).
loss_bbox=dict(
type='GHMR',
mu=0.02,
bins=10,
momentum=0.7,
loss_weight=10.0)))
# training and testing settings
# Only an assigner is configured for training; there is no sampler entry.
# NOTE(review): presumably this relies on the anchor head turning sampling
# off when loss_cls['type'] is 'GHMC' - confirm against AnchorHead.
train_cfg = dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
# Post-processing limits applied at test time (score threshold, NMS and
# per-image detection cap).
test_cfg = dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100)
# dataset settings
# All three splits use 'CocoDataset' rooted at data/coco/ and share the same
# image scale, normalization and size divisor.
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
# train split: train2017 annotations/images, the only split with flipping
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5,
with_mask=False,
with_crowd=False,
with_label=True),
# val split: val2017 annotations/images, no flipping
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=False,
with_label=True),
# test split: same val2017 files as val, but with_label=False and
# test_mode=True
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=False,
with_label=False,
test_mode=True))
# optimizer
# SGD with momentum and weight decay; gradients are clipped to an L2 norm
# (norm_type=2) of at most 35.
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
# Step schedule with a linear warmup over the first 500 iterations starting
# from 1/3 of the base rate.
# NOTE(review): step=[8, 11] is presumably expressed in epochs (total_epochs
# is 12 below) - confirm against the runner's LR hook.
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
# Text logging every 50 intervals; the Tensorboard hook is left disabled.
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
# device indices 0..7
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ghm'
# no checkpoint is loaded or resumed by default
load_from = None
resume_from = None
workflow = [('train', 1)]
...@@ -57,7 +57,7 @@ class AnchorHead(nn.Module): ...@@ -57,7 +57,7 @@ class AnchorHead(nn.Module):
self.target_stds = target_stds self.target_stds = target_stds
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
self.sampling = loss_cls['type'] not in ['FocalLoss'] self.sampling = loss_cls['type'] not in ['FocalLoss', 'GHMC']
if self.use_sigmoid_cls: if self.use_sigmoid_cls:
self.cls_out_channels = num_classes - 1 self.cls_out_channels = num_classes - 1
else: else:
......
from .cross_entropy_loss import CrossEntropyLoss from .cross_entropy_loss import CrossEntropyLoss
from .focal_loss import FocalLoss from .focal_loss import FocalLoss
from .smooth_l1_loss import SmoothL1Loss from .smooth_l1_loss import SmoothL1Loss
from .ghm_loss import GHMC, GHMR
from .balanced_l1_loss import BalancedL1Loss from .balanced_l1_loss import BalancedL1Loss
from .iou_loss import IoULoss from .iou_loss import IoULoss
__all__ = [ __all__ = [
'CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss', 'BalancedL1Loss', 'CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss', 'BalancedL1Loss',
'IoULoss' 'IoULoss', 'GHMC', 'GHMR'
] ]
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import LOSSES
def _expand_binary_labels(labels, label_weights, label_channels):
    """Turn integer class labels into per-channel binary targets.

    Args:
        labels (Tensor): 1-D integer labels. Entries ``>= 1`` are treated
            as positives and mapped to channel ``label - 1``; every other
            entry yields an all-zero row.
        label_weights (Tensor): One weight per entry of ``labels``.
        label_channels (int): Number of output channels per sample.

    Returns:
        tuple[Tensor, Tensor]: The binary label matrix of shape
        ``(labels.size(0), label_channels)`` (same dtype as ``labels``) and
        the weights expanded to ``(label_weights.size(0), label_channels)``.
    """
    num_samples = labels.size(0)
    one_hot = labels.new_full((num_samples, label_channels), 0)

    # Positions holding a foreground label.
    pos_inds = torch.nonzero(labels >= 1).squeeze()
    if pos_inds.numel() > 0:
        one_hot[pos_inds, labels[pos_inds] - 1] = 1

    # Repeat each sample's weight across all channels (a view, not a copy).
    weight_column = label_weights.view(-1, 1)
    expanded_weights = weight_column.expand(
        label_weights.size(0), label_channels)
    return one_hot, expanded_weights
@LOSSES.register_module
class GHMC(nn.Module):
    """GHM Classification Loss.

    Details of the theorem can be viewed in the paper
    "Gradient Harmonized Single-stage Detector".
    https://arxiv.org/abs/1811.05181

    The bin edges and (when ``momentum > 0``) the running per-bin counts are
    registered as module buffers, so the loss can be constructed without a
    GPU and follows the module through ``.to()`` / ``.cuda()``.

    Args:
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average.
        use_sigmoid (bool): Can only be true for BCE based loss now.
        loss_weight (float): The weight of the total GHM-C loss.

    Raises:
        NotImplementedError: If ``use_sigmoid`` is False.
    """

    def __init__(
            self,
            bins=10,
            momentum=0,
            use_sigmoid=True,
            loss_weight=1.0):
        super(GHMC, self).__init__()
        # Reject unsupported configurations before allocating any state.
        self.use_sigmoid = use_sigmoid
        if not self.use_sigmoid:
            raise NotImplementedError
        self.bins = bins
        self.momentum = momentum
        edges = torch.arange(bins + 1).float() / bins
        # Nudge the last edge so that g == 1 falls inside the final bin
        # (bins are half-open intervals [edges[i], edges[i + 1])).
        edges[-1] += 1e-6
        self.register_buffer('edges', edges)
        if momentum > 0:
            # Moving average of the number of samples seen in each bin.
            self.register_buffer('acc_sum', torch.zeros(bins))
        self.loss_weight = loss_weight

    def forward(self, pred, target, label_weight, *args, **kwargs):
        """Calculate the GHM-C loss.

        Args:
            pred (float tensor of size [batch_num, class_num]):
                The direct prediction of classification fc layer.
            target (float tensor of size [batch_num, class_num]):
                Binary class target for each sample. If its number of
                dimensions differs from ``pred``, it is expanded together
                with ``label_weight`` by ``_expand_binary_labels``.
            label_weight (float tensor of size [batch_num, class_num]):
                the value is 1 if the sample is valid and 0 if ignored.
            *args, **kwargs: Accepted and ignored.

        Returns:
            The gradient harmonized loss.
        """
        # the target should be binary class label
        if pred.dim() != target.dim():
            target, label_weight = _expand_binary_labels(
                target, label_weight, pred.size(-1))
        target, label_weight = target.float(), label_weight.float()

        # Keep working when the loss module itself was never moved to the
        # device of its inputs (e.g. standalone use with CUDA tensors).
        if self.edges.device != pred.device:
            self.to(pred.device)
        edges = self.edges
        mmt = self.momentum
        weights = torch.zeros_like(pred)

        # gradient length
        g = torch.abs(pred.sigmoid().detach() - target)

        valid = label_weight > 0
        tot = max(valid.float().sum().item(), 1.0)
        n = 0  # n valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                if mmt > 0:
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
                n += 1
        if n > 0:
            # Average over the non-empty bins only.
            weights = weights / n

        loss = F.binary_cross_entropy_with_logits(
            pred, target, weights, reduction='sum') / tot
        return loss * self.loss_weight
@LOSSES.register_module
class GHMR(nn.Module):
    """GHM Regression Loss.

    Details of the theorem can be viewed in the paper
    "Gradient Harmonized Single-stage Detector"
    https://arxiv.org/abs/1811.05181

    The bin edges and (when ``momentum > 0``) the running per-bin counts are
    registered as module buffers, so the loss can be constructed without a
    GPU and follows the module through ``.to()`` / ``.cuda()``.

    Args:
        mu (float): The parameter for the Authentic Smooth L1 loss.
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average.
        loss_weight (float): The weight of the total GHM-R loss.
    """

    def __init__(
            self,
            mu=0.02,
            bins=10,
            momentum=0,
            loss_weight=1.0):
        super(GHMR, self).__init__()
        self.mu = mu
        self.bins = bins
        edges = torch.arange(bins + 1).float() / bins
        # Large upper bound so the last half-open bin also captures g == 1.
        edges[-1] = 1e3
        self.register_buffer('edges', edges)
        self.momentum = momentum
        if momentum > 0:
            # Moving average of the number of samples seen in each bin.
            self.register_buffer('acc_sum', torch.zeros(bins))
        self.loss_weight = loss_weight

    def forward(self, pred, target, label_weight, avg_factor=None):
        """Calculate the GHM-R loss.

        Args:
            pred (float tensor of size [batch_num, 4 (* class_num)]):
                The prediction of box regression layer. Channel number can be 4
                or 4 * class_num depending on whether it is class-agnostic.
            target (float tensor of size [batch_num, 4 (* class_num)]):
                The target regression values with the same size of pred.
            label_weight (float tensor of size [batch_num, 4 (* class_num)]):
                The weight of each sample, 0 if ignored.
            avg_factor: Accepted for interface compatibility; not used by
                this implementation.

        Returns:
            The gradient harmonized loss.
        """
        mu = self.mu
        # Keep working when the loss module itself was never moved to the
        # device of its inputs (e.g. standalone use with CUDA tensors).
        if self.edges.device != pred.device:
            self.to(pred.device)
        edges = self.edges
        mmt = self.momentum

        # ASL1 loss
        diff = pred - target
        loss = torch.sqrt(diff * diff + mu * mu) - mu

        # gradient length
        g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
        weights = torch.zeros_like(g)

        valid = label_weight > 0
        tot = max(label_weight.float().sum().item(), 1.0)
        n = 0  # n: valid bins
        for i in range(self.bins):
            inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
            num_in_bin = inds.sum().item()
            if num_in_bin > 0:
                n += 1
                if mmt > 0:
                    self.acc_sum[i] = mmt * self.acc_sum[i] \
                        + (1 - mmt) * num_in_bin
                    weights[inds] = tot / self.acc_sum[i]
                else:
                    weights[inds] = tot / num_in_bin
        if n > 0:
            # Average over the non-empty bins only.
            weights /= n

        loss = loss * weights
        loss = loss.sum() / tot
        return loss * self.loss_weight
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment