Unverified Commit c1f9fa0f authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #163 from yhcao6/ohem-sampler

Add OHEMSampler
parents 94beb922 b932030c
# model settings
model = dict(
    type='FasterRCNN',  # two-stage detector
    pretrained='modelzoo://resnet50',  # ImageNet-pretrained backbone weights
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # feed all four stage outputs to the neck
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],  # ResNet-50 stage output channels
        out_channels=256,
        num_outs=5),  # one extra level beyond the 4 backbone outputs (for RPN)
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_scales=[8],
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[4, 8, 16, 32, 64],  # one stride per FPN level
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
        use_sigmoid_cls=True),  # sigmoid objectness instead of 2-way softmax
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=81,  # 80 COCO classes + background
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False))
# model training and testing settings
train_cfg = dict(
    rpn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,  # anchors with IoU >= 0.7 become positives
            neg_iou_thr=0.3,  # anchors with IoU < 0.3 become negatives
            min_pos_iou=0.3,
            ignore_iof_thr=-1),  # -1 disables the ignore-region IoF check
        sampler=dict(
            type='RandomSampler',
            num=256,  # anchors sampled per image for the RPN loss
            pos_fraction=0.5,
            neg_pos_ub=-1,  # -1: no upper bound on the neg/pos ratio
            add_gt_as_proposals=False),
        allowed_border=0,
        pos_weight=-1,  # -1: use the default weight for positive samples
        smoothl1_beta=1 / 9.0,
        debug=False),
    rcnn=dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0.5,
            ignore_iof_thr=-1),
        sampler=dict(
            # OHEMSampler replaces RandomSampler for the RCNN stage: it keeps
            # the proposals with the highest classification loss (hard mining).
            type='OHEMSampler',
            num=512,  # RoIs sampled per image for the bbox head
            pos_fraction=0.25,
            neg_pos_ub=-1,
            add_gt_as_proposals=True),
        pos_weight=-1,
        debug=False))
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,  # run NMS per FPN level, not jointly
        nms_pre=2000,   # top-scoring proposals kept before NMS
        nms_post=2000,  # proposals kept after NMS
        max_num=2000,   # final number of proposals per image
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
    # soft-nms is also supported for rcnn testing
    # e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
# ImageNet mean/std in RGB order; images are converted BGR->RGB first
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=2,     # batch size per GPU
    workers_per_gpu=2,  # dataloader worker processes per GPU
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        img_scale=(1333, 800),  # (max long side, max short side)
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,  # pad images so H and W are multiples of 32 (FPN)
        flip_ratio=0.5,   # random horizontal flip for augmentation
        with_mask=False,
        with_crowd=True,
        with_label=True),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,  # no flipping during evaluation
        with_mask=False,
        with_crowd=True,
        with_label=True),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_label=False,  # test mode: no annotations loaded
        test_mode=True))
# optimizer
# lr=0.02 corresponds to the standard 8-GPU x 2-imgs/GPU schedule
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,       # linear warmup over the first 500 iterations
    warmup_ratio=1.0 / 3,   # start at lr/3
    step=[8, 11])           # decay lr at epochs 8 and 11
checkpoint_config = dict(interval=1)  # save a checkpoint every epoch
# yapf:disable
log_config = dict(
    interval=50,  # log every 50 iterations
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12  # the "1x" schedule
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/faster_rcnn_r50_fpn_1x'
load_from = None    # optional checkpoint to initialize from
resume_from = None  # optional checkpoint to resume training from
workflow = [('train', 1)]
...@@ -3,23 +3,23 @@ import mmcv ...@@ -3,23 +3,23 @@ import mmcv
from . import assigners, samplers from . import assigners, samplers
def build_assigner(cfg, default_args=None): def build_assigner(cfg, **kwargs):
if isinstance(cfg, assigners.BaseAssigner): if isinstance(cfg, assigners.BaseAssigner):
return cfg return cfg
elif isinstance(cfg, dict): elif isinstance(cfg, dict):
return mmcv.runner.obj_from_dict( return mmcv.runner.obj_from_dict(
cfg, assigners, default_args=default_args) cfg, assigners, default_args=kwargs)
else: else:
raise TypeError('Invalid type {} for building a sampler'.format( raise TypeError('Invalid type {} for building a sampler'.format(
type(cfg))) type(cfg)))
def build_sampler(cfg, default_args=None): def build_sampler(cfg, **kwargs):
if isinstance(cfg, samplers.BaseSampler): if isinstance(cfg, samplers.BaseSampler):
return cfg return cfg
elif isinstance(cfg, dict): elif isinstance(cfg, dict):
return mmcv.runner.obj_from_dict( return mmcv.runner.obj_from_dict(
cfg, samplers, default_args=default_args) cfg, samplers, default_args=kwargs)
else: else:
raise TypeError('Invalid type {} for building a sampler'.format( raise TypeError('Invalid type {} for building a sampler'.format(
type(cfg))) type(cfg)))
......
...@@ -4,10 +4,11 @@ from .random_sampler import RandomSampler ...@@ -4,10 +4,11 @@ from .random_sampler import RandomSampler
from .instance_balanced_pos_sampler import InstanceBalancedPosSampler from .instance_balanced_pos_sampler import InstanceBalancedPosSampler
from .iou_balanced_neg_sampler import IoUBalancedNegSampler from .iou_balanced_neg_sampler import IoUBalancedNegSampler
from .combined_sampler import CombinedSampler from .combined_sampler import CombinedSampler
from .ohem_sampler import OHEMSampler
from .sampling_result import SamplingResult from .sampling_result import SamplingResult
__all__ = [ __all__ = [
'BaseSampler', 'PseudoSampler', 'RandomSampler', 'BaseSampler', 'PseudoSampler', 'RandomSampler',
'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
'SamplingResult' 'OHEMSampler', 'SamplingResult'
] ]
...@@ -7,19 +7,33 @@ from .sampling_result import SamplingResult ...@@ -7,19 +7,33 @@ from .sampling_result import SamplingResult
class BaseSampler(metaclass=ABCMeta): class BaseSampler(metaclass=ABCMeta):
def __init__(self): def __init__(self,
num,
pos_fraction,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.pos_sampler = self self.pos_sampler = self
self.neg_sampler = self self.neg_sampler = self
@abstractmethod @abstractmethod
def _sample_pos(self, assign_result, num_expected): def _sample_pos(self, assign_result, num_expected, **kwargs):
pass pass
@abstractmethod @abstractmethod
def _sample_neg(self, assign_result, num_expected): def _sample_neg(self, assign_result, num_expected, **kwargs):
pass pass
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None): def sample(self,
assign_result,
bboxes,
gt_bboxes,
gt_labels=None,
**kwargs):
"""Sample positive and negative bboxes. """Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates, This is a simple implementation of bbox sampling given candidates,
...@@ -44,8 +58,8 @@ class BaseSampler(metaclass=ABCMeta): ...@@ -44,8 +58,8 @@ class BaseSampler(metaclass=ABCMeta):
gt_flags = torch.cat([gt_ones, gt_flags]) gt_flags = torch.cat([gt_ones, gt_flags])
num_expected_pos = int(self.num * self.pos_fraction) num_expected_pos = int(self.num * self.pos_fraction)
pos_inds = self.pos_sampler._sample_pos(assign_result, pos_inds = self.pos_sampler._sample_pos(
num_expected_pos) assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
# We found that sampled indices have duplicated items occasionally. # We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch) # (may be a bug of PyTorch)
pos_inds = pos_inds.unique() pos_inds = pos_inds.unique()
...@@ -56,8 +70,8 @@ class BaseSampler(metaclass=ABCMeta): ...@@ -56,8 +70,8 @@ class BaseSampler(metaclass=ABCMeta):
neg_upper_bound = int(self.neg_pos_ub * _pos) neg_upper_bound = int(self.neg_pos_ub * _pos)
if num_expected_neg > neg_upper_bound: if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound num_expected_neg = neg_upper_bound
neg_inds = self.neg_sampler._sample_neg(assign_result, neg_inds = self.neg_sampler._sample_neg(
num_expected_neg) assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique() neg_inds = neg_inds.unique()
return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
......
from .random_sampler import RandomSampler from .base_sampler import BaseSampler
from ..assign_sampling import build_sampler from ..assign_sampling import build_sampler
class CombinedSampler(RandomSampler): class CombinedSampler(BaseSampler):
def __init__(self, num, pos_fraction, pos_sampler, neg_sampler, **kwargs): def __init__(self, pos_sampler, neg_sampler, **kwargs):
super(CombinedSampler, self).__init__(num, pos_fraction, **kwargs) super(CombinedSampler, self).__init__(**kwargs)
default_args = dict(num=num, pos_fraction=pos_fraction) self.pos_sampler = build_sampler(pos_sampler, **kwargs)
default_args.update(kwargs) self.neg_sampler = build_sampler(neg_sampler, **kwargs)
self.pos_sampler = build_sampler(
pos_sampler, default_args=default_args) def _sample_pos(self, **kwargs):
self.neg_sampler = build_sampler( raise NotImplementedError
neg_sampler, default_args=default_args)
def _sample_neg(self, **kwargs):
raise NotImplementedError
...@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler ...@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
class InstanceBalancedPosSampler(RandomSampler): class InstanceBalancedPosSampler(RandomSampler):
def _sample_pos(self, assign_result, num_expected): def _sample_pos(self, assign_result, num_expected, **kwargs):
pos_inds = torch.nonzero(assign_result.gt_inds > 0) pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0: if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1) pos_inds = pos_inds.squeeze(1)
......
...@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler): ...@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
self.hard_thr = hard_thr self.hard_thr = hard_thr
self.hard_fraction = hard_fraction self.hard_fraction = hard_fraction
def _sample_neg(self, assign_result, num_expected): def _sample_neg(self, assign_result, num_expected, **kwargs):
neg_inds = torch.nonzero(assign_result.gt_inds == 0) neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0: if neg_inds.numel() != 0:
neg_inds = neg_inds.squeeze(1) neg_inds = neg_inds.squeeze(1)
......
import torch
from .base_sampler import BaseSampler
from ..transforms import bbox2roi
class OHEMSampler(BaseSampler):
    """Online Hard Example Mining sampler.

    Instead of sampling proposals at random, score every candidate with the
    detector's own bbox head (under ``torch.no_grad()``) and keep the ones
    with the highest classification loss, i.e. the hardest examples.

    Args:
        num (int): Total number of RoIs to sample per image.
        pos_fraction (float): Fraction of ``num`` reserved for positives.
        context: The detector that owns this sampler; its
            ``bbox_roi_extractor`` and ``bbox_head`` are borrowed to score
            proposals.
        neg_pos_ub (int): Upper bound on the neg/pos ratio (-1 = unlimited).
        add_gt_as_proposals (bool): Whether to add GT boxes as proposals.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 context,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
                                          add_gt_as_proposals)
        # Reuse the detector's modules so hard mining uses current weights.
        self.bbox_roi_extractor = context.bbox_roi_extractor
        self.bbox_head = context.bbox_head

    def hard_mining(self, inds, num_expected, bboxes, labels, feats):
        """Return the ``num_expected`` indices with the highest cls loss.

        Args:
            inds (Tensor): Candidate indices into the full proposal set.
            num_expected (int): Number of samples to keep.
            bboxes (Tensor): Boxes corresponding to ``inds``.
            labels (Tensor): Target labels for those boxes.
            feats (list[Tensor]): Multi-level image features.

        Returns:
            Tensor: Subset of ``inds`` with the largest per-box loss.
        """
        # Scoring only -- no gradients flow through the mining pass.
        with torch.no_grad():
            rois = bbox2roi([bboxes])
            bbox_feats = self.bbox_roi_extractor(
                feats[:self.bbox_roi_extractor.num_inputs], rois)
            cls_score, _ = self.bbox_head(bbox_feats)
            # reduce=False keeps a per-sample loss so we can rank candidates.
            loss = self.bbox_head.loss(
                cls_score=cls_score,
                bbox_pred=None,
                labels=labels,
                label_weights=cls_score.new_ones(cls_score.size(0)),
                bbox_targets=None,
                bbox_weights=None,
                reduce=False)['loss_cls']
            _, topk_loss_inds = loss.topk(num_expected)
        return inds[topk_loss_inds]

    def _sample_pos(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Sample up to ``num_expected`` hard positive proposals."""
        pos_inds = torch.nonzero(assign_result.gt_inds > 0)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            # Fewer candidates than requested: keep them all.
            return pos_inds
        else:
            return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
                                    assign_result.labels[pos_inds], feats)

    def _sample_neg(self,
                    assign_result,
                    num_expected,
                    bboxes=None,
                    feats=None,
                    **kwargs):
        """Sample up to ``num_expected`` hard negative proposals."""
        neg_inds = torch.nonzero(assign_result.gt_inds == 0)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        # Use .numel() for consistency with _sample_pos (was len(neg_inds)).
        if neg_inds.numel() <= num_expected:
            return neg_inds
        else:
            return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
                                    assign_result.labels[neg_inds], feats)
...@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult ...@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult
class PseudoSampler(BaseSampler): class PseudoSampler(BaseSampler):
def __init__(self): def __init__(self, **kwargs):
pass pass
def _sample_pos(self): def _sample_pos(self, **kwargs):
raise NotImplementedError raise NotImplementedError
def _sample_neg(self): def _sample_neg(self, **kwargs):
raise NotImplementedError raise NotImplementedError
def sample(self, assign_result, bboxes, gt_bboxes): def sample(self, assign_result, bboxes, gt_bboxes, **kwargs):
pos_inds = torch.nonzero( pos_inds = torch.nonzero(
assign_result.gt_inds > 0).squeeze(-1).unique() assign_result.gt_inds > 0).squeeze(-1).unique()
neg_inds = torch.nonzero( neg_inds = torch.nonzero(
......
...@@ -10,12 +10,10 @@ class RandomSampler(BaseSampler): ...@@ -10,12 +10,10 @@ class RandomSampler(BaseSampler):
num, num,
pos_fraction, pos_fraction,
neg_pos_ub=-1, neg_pos_ub=-1,
add_gt_as_proposals=True): add_gt_as_proposals=True,
super(RandomSampler, self).__init__() **kwargs):
self.num = num super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
self.pos_fraction = pos_fraction add_gt_as_proposals)
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
@staticmethod @staticmethod
def random_choice(gallery, num): def random_choice(gallery, num):
...@@ -34,7 +32,7 @@ class RandomSampler(BaseSampler): ...@@ -34,7 +32,7 @@ class RandomSampler(BaseSampler):
rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device) rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
return gallery[rand_inds] return gallery[rand_inds]
def _sample_pos(self, assign_result, num_expected): def _sample_pos(self, assign_result, num_expected, **kwargs):
"""Randomly sample some positive samples.""" """Randomly sample some positive samples."""
pos_inds = torch.nonzero(assign_result.gt_inds > 0) pos_inds = torch.nonzero(assign_result.gt_inds > 0)
if pos_inds.numel() != 0: if pos_inds.numel() != 0:
...@@ -44,7 +42,7 @@ class RandomSampler(BaseSampler): ...@@ -44,7 +42,7 @@ class RandomSampler(BaseSampler):
else: else:
return self.random_choice(pos_inds, num_expected) return self.random_choice(pos_inds, num_expected)
def _sample_neg(self, assign_result, num_expected): def _sample_neg(self, assign_result, num_expected, **kwargs):
"""Randomly sample some negative samples.""" """Randomly sample some negative samples."""
neg_inds = torch.nonzero(assign_result.gt_inds == 0) neg_inds = torch.nonzero(assign_result.gt_inds == 0)
if neg_inds.numel() != 0: if neg_inds.numel() != 0:
......
...@@ -10,11 +10,15 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None): ...@@ -10,11 +10,15 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
return torch.sum(raw * weight)[None] / avg_factor return torch.sum(raw * weight)[None] / avg_factor
def weighted_cross_entropy(pred, label, weight, avg_factor=None): def weighted_cross_entropy(pred, label, weight, avg_factor=None,
reduce=True):
if avg_factor is None: if avg_factor is None:
avg_factor = max(torch.sum(weight > 0).float().item(), 1.) avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
raw = F.cross_entropy(pred, label, reduction='none') raw = F.cross_entropy(pred, label, reduction='none')
return torch.sum(raw * weight)[None] / avg_factor if reduce:
return torch.sum(raw * weight)[None] / avg_factor
else:
return raw * weight / avg_factor
def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None): def weighted_binary_cross_entropy(pred, label, weight, avg_factor=None):
......
...@@ -79,11 +79,11 @@ class BBoxHead(nn.Module): ...@@ -79,11 +79,11 @@ class BBoxHead(nn.Module):
return cls_reg_targets return cls_reg_targets
def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets, def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets,
bbox_weights): bbox_weights, reduce=True):
losses = dict() losses = dict()
if cls_score is not None: if cls_score is not None:
losses['loss_cls'] = weighted_cross_entropy( losses['loss_cls'] = weighted_cross_entropy(
cls_score, labels, label_weights) cls_score, labels, label_weights, reduce=reduce)
losses['acc'] = accuracy(cls_score, labels) losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None: if bbox_pred is not None:
losses['loss_reg'] = weighted_smoothl1( losses['loss_reg'] = weighted_smoothl1(
......
...@@ -4,7 +4,7 @@ import torch.nn as nn ...@@ -4,7 +4,7 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder from .. import builder
from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply) from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...@@ -102,13 +102,22 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -102,13 +102,22 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# assign gts and sample proposals # assign gts and sample proposals
if self.with_bbox or self.with_mask: if self.with_bbox or self.with_mask:
assign_results, sampling_results = multi_apply( bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
assign_and_sample, bbox_sampler = build_sampler(
proposal_list, self.train_cfg.rcnn.sampler, context=self)
gt_bboxes, num_imgs = img.size(0)
gt_bboxes_ignore, sampling_results = []
gt_labels, for i in range(num_imgs):
cfg=self.train_cfg.rcnn) assign_result = bbox_assigner.assign(
proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
gt_labels[i])
sampling_result = bbox_sampler.sample(
assign_result,
proposal_list[i],
gt_bboxes[i],
gt_labels[i],
feats=[lvl_feat[i][None] for lvl_feat in x])
sampling_results.append(sampling_result)
# bbox head forward and loss # bbox head forward and loss
if self.with_bbox: if self.with_bbox:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment