Commit 763153dc authored by yhcao6's avatar yhcao6
Browse files

Add OHEMSampler

parent 826a5613
# model settings
model = dict(
type='FasterRCNN',
pretrained='modelzoo://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_scales=[8],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0],
use_sigmoid_cls=True),
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=dict(
type='SharedFCBBoxHead',
num_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=81,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2],
reg_class_agnostic=False))
# model training and testing settings
train_cfg = dict(
rpn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False),
allowed_border=0,
pos_weight=-1,
smoothl1_beta=1 / 9.0,
debug=False),
rcnn=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.5,
min_pos_iou=0.5,
ignore_iof_thr=-1),
sampler=dict(
type='OHEMSampler',
num=512,
pos_fraction=0.25,
neg_pos_ub=-1,
add_gt_as_proposals=True),
pos_weight=-1,
debug=False))
test_cfg = dict(
rpn=dict(
nms_across_levels=False,
nms_pre=2000,
nms_post=2000,
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
imgs_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0.5,
with_mask=False,
with_crowd=True,
with_label=True),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_crowd=True,
with_label=True),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
img_scale=(1333, 800),
img_norm_cfg=img_norm_cfg,
size_divisor=32,
flip_ratio=0,
with_mask=False,
with_label=False,
test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/faster_rcnn_r50_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
......@@ -4,10 +4,11 @@ from .random_sampler import RandomSampler
from .instance_balanced_pos_sampler import InstanceBalancedPosSampler
from .iou_balanced_neg_sampler import IoUBalancedNegSampler
from .combined_sampler import CombinedSampler
from .ohem_sampler import OHEMSampler
from .sampling_result import SamplingResult
__all__ = [
'BaseSampler', 'PseudoSampler', 'RandomSampler',
'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
'SamplingResult'
'OHEMSampler', 'SamplingResult'
]
import torch
from .base_sampler import BaseSampler
from ..transforms import bbox2roi
from .sampling_result import SamplingResult
class OHEMSampler(BaseSampler):
def __init__(self,
num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,):
super(OHEMSampler, self).__init__()
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
def _sample_pos(self, assign_result, num_expected, loss_all):
"""Hard sample some positive samples."""
pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
if pos_inds.numel() <= num_expected:
return pos_inds
else:
_, topk_loss_pos_inds = loss_all[pos_inds].topk(num_expected)
return pos_inds[topk_loss_pos_inds]
def _sample_neg(self, assign_result, num_expected, loss_all):
"""Hard sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1)
if len(neg_inds) <= num_expected:
return neg_inds
else:
_, topk_loss_neg_inds = loss_all[neg_inds].topk(num_expected)
return neg_inds[topk_loss_neg_inds]
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None,
feats=None, bbox_roi_extractor=None, bbox_head=None):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
assigning results and ground truth bboxes.
Args:
assign_result (:obj:`AssignResult`): Bbox assigning results.
bboxes (Tensor): Boxes to be sampled from.
gt_bboxes (Tensor): Ground truth bboxes.
gt_labels (Tensor, optional): Class labels of ground truth bboxes.
Returns:
:obj:`SamplingResult`: Sampling result.
"""
bboxes = bboxes[:, :4]
gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
if self.add_gt_as_proposals:
bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
gt_flags = torch.cat([gt_ones, gt_flags])
# calculate loss of all samples used for hard mining
with torch.no_grad():
rois = bbox2roi([bboxes])
bbox_feats = bbox_roi_extractor(
feats[:bbox_roi_extractor.num_inputs], rois)
cls_score, _ = bbox_head(bbox_feats)
loss_all = bbox_head.loss(
cls_score=cls_score,
bbox_pred=None,
labels=assign_result.labels,
label_weights=cls_score.new_ones(cls_score.size(0)),
bbox_targets=None,
bbox_weights=None,
reduction='none')['loss_cls']
num_expected_pos = int(self.num * self.pos_fraction)
pos_inds = self._sample_pos(assign_result, num_expected_pos, loss_all)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()
num_sampled_pos = pos_inds.numel()
num_expected_neg = self.num - num_sampled_pos
if self.neg_pos_ub >= 0:
_pos = max(1, num_sampled_pos)
neg_upper_bound = int(self.neg_pos_ub * _pos)
if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound
neg_inds = self._sample_neg(assign_result, num_expected_neg, loss_all)
neg_inds = neg_inds.unique()
return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
......@@ -10,11 +10,15 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
return torch.sum(raw * weight)[None] / avg_factor
def weighted_cross_entropy(pred, label, weight, avg_factor=None):
def weighted_cross_entropy(pred, label, weight, avg_factor=None,
reduction='elementwise_sum'):
if avg_factor is None:
avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
raw = F.cross_entropy(pred, label, reduction='none')
return torch.sum(raw * weight)[None] / avg_factor
if reduction == 'elementwise_sum':
return torch.sum(raw * weight)[None] / avg_factor
elif reduction == 'none':
return raw * weight / avg_factor
def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None):
......
......@@ -79,11 +79,11 @@ class BBoxHead(nn.Module):
return cls_reg_targets
def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets,
bbox_weights):
bbox_weights, reduction='elementwise_sum'):
losses = dict()
if cls_score is not None:
losses['loss_cls'] = weighted_cross_entropy(
cls_score, labels, label_weights)
cls_score, labels, label_weights, reduction=reduction)
losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None:
losses['loss_reg'] = weighted_smoothl1(
......
......@@ -4,7 +4,7 @@ import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply)
from mmdet.core import (bbox2roi, bbox2result, build_assigner, build_sampler)
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
......@@ -102,13 +102,30 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# assign gts and sample proposals
if self.with_bbox or self.with_mask:
assign_results, sampling_results = multi_apply(
assign_and_sample,
proposal_list,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
cfg=self.train_cfg.rcnn)
bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
bbox_sampler = build_sampler(self.train_cfg.rcnn.sampler)
num_imgs = img.size(0)
assign_results = []
sampling_results = []
for i in range(num_imgs):
assign_result = bbox_assigner.assign(
proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
gt_labels[i])
if self.train_cfg.rcnn.sampler.type == 'OHEMSampler':
sampling_result = bbox_sampler.sample(
assign_result,
proposal_list[i],
gt_bboxes[i],
gt_labels[i],
[xx[i][None] for xx in x],
self.bbox_roi_extractor,
self.bbox_head)
else:
sampling_result = bbox_sampler.sample(
assign_result, proposal_list[i], gt_bboxes[i],
gt_labels[i])
assign_results.append(assign_result)
sampling_results.append(sampling_result)
# bbox head forward and loss
if self.with_bbox:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment