Commit 8457bbab authored by Jon Crall's avatar Jon Crall Committed by Kai Chen
Browse files

Enhance AssignResult and SamplingResult (#1995)

* Enhance AssignResult and SamplingResult

Add runtime dependency on ubelt (pending approval)

Fix issue in SamplingResult.__init__

Add rng as attribute of RandomSampler

* fix linters

* remove ubelt

* Fix linters

* fix linters again
parent 78529eca
import torch
from mmdet.utils import util_mixins
class AssignResult(object):
class AssignResult(util_mixins.NiceRepr):
"""
Stores assignments between predicted and truth boxes.
......@@ -44,20 +46,25 @@ class AssignResult(object):
self.max_overlaps = max_overlaps
self.labels = labels
def add_gt_(self, gt_labels):
self_inds = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
self.gt_inds = torch.cat([self_inds, self.gt_inds])
# Was this a bug?
# self.max_overlaps = torch.cat(
# [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
# IIUC, It seems like the correct code should be:
self.max_overlaps = torch.cat(
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
@property
def num_preds(self):
"""
Return the number of predictions in this assignment
"""
return len(self.gt_inds)
if self.labels is not None:
self.labels = torch.cat([gt_labels, self.labels])
@property
def info(self):
"""
Returns a dictionary of info about the object
"""
return {
'num_gts': self.num_gts,
'num_preds': self.num_preds,
'gt_inds': self.gt_inds,
'max_overlaps': self.max_overlaps,
'labels': self.labels,
}
def __nice__(self):
"""
......@@ -81,12 +88,105 @@ class AssignResult(object):
parts.append('labels.shape={!r}'.format(tuple(self.labels.shape)))
return ', '.join(parts)
def __repr__(self):
nice = self.__nice__()
classname = self.__class__.__name__
return '<{}({}) at {}>'.format(classname, nice, hex(id(self)))
@classmethod
def random(cls, **kwargs):
"""
Create random AssignResult for tests or debugging.
Kwargs:
num_preds: number of predicted boxes
num_gts: number of true boxes
p_ignore (float): probability of a predicted box assinged to an
ignored truth
p_assigned (float): probability of a predicted box not being
assigned
p_use_label (float | bool): with labels or not
rng (None | int | numpy.random.RandomState): seed or state
Returns:
AssignResult :
Example:
>>> from mmdet.core.bbox.assigners.assign_result import * # NOQA
>>> self = AssignResult.random()
>>> print(self.info)
"""
from mmdet.core.bbox import demodata
rng = demodata.ensure_rng(kwargs.get('rng', None))
num_gts = kwargs.get('num_gts', None)
num_preds = kwargs.get('num_preds', None)
p_ignore = kwargs.get('p_ignore', 0.3)
p_assigned = kwargs.get('p_assigned', 0.7)
p_use_label = kwargs.get('p_use_label', 0.5)
num_classes = kwargs.get('p_use_label', 3)
if num_gts is None:
num_gts = rng.randint(0, 8)
if num_preds is None:
num_preds = rng.randint(0, 16)
def __str__(self):
classname = self.__class__.__name__
nice = self.__nice__()
return '<{}({})>'.format(classname, nice)
if num_gts == 0:
max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
gt_inds = torch.zeros(num_preds, dtype=torch.int64)
if p_use_label is True or p_use_label < rng.rand():
labels = torch.zeros(num_preds, dtype=torch.int64)
else:
labels = None
else:
import numpy as np
# Create an overlap for each predicted box
max_overlaps = torch.from_numpy(rng.rand(num_preds))
# Construct gt_inds for each predicted box
is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
# maximum number of assignments constraints
n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))
assigned_idxs = np.where(is_assigned)[0]
rng.shuffle(assigned_idxs)
assigned_idxs = assigned_idxs[0:n_assigned]
assigned_idxs.sort()
is_assigned[:] = 0
is_assigned[assigned_idxs] = True
is_ignore = torch.from_numpy(
rng.rand(num_preds) < p_ignore) & is_assigned
gt_inds = torch.zeros(num_preds, dtype=torch.int64)
true_idxs = np.arange(num_gts)
rng.shuffle(true_idxs)
true_idxs = torch.from_numpy(true_idxs)
gt_inds[is_assigned] = true_idxs[:n_assigned]
gt_inds = torch.from_numpy(
rng.randint(1, num_gts + 1, size=num_preds))
gt_inds[is_ignore] = -1
gt_inds[~is_assigned] = 0
max_overlaps[~is_assigned] = 0
if p_use_label is True or p_use_label < rng.rand():
if num_classes == 0:
labels = torch.zeros(num_preds, dtype=torch.int64)
else:
labels = torch.from_numpy(
rng.randint(1, num_classes + 1, size=num_preds))
labels[~is_assigned] = 0
else:
labels = None
self = cls(num_gts, gt_inds, max_overlaps, labels)
return self
def add_gt_(self, gt_labels):
self_inds = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
self.gt_inds = torch.cat([self_inds, self.gt_inds])
self.max_overlaps = torch.cat(
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
if self.labels is not None:
self.labels = torch.cat([gt_labels, self.labels])
......@@ -47,11 +47,30 @@ class BaseSampler(metaclass=ABCMeta):
Returns:
:obj:`SamplingResult`: Sampling result.
Example:
>>> from mmdet.core.bbox import RandomSampler
>>> from mmdet.core.bbox import AssignResult
>>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
>>> rng = ensure_rng(None)
>>> assign_result = AssignResult.random(rng=rng)
>>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
>>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
>>> gt_labels = None
>>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
>>> add_gt_as_proposals=False)
>>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
"""
if len(bboxes.shape) < 2:
bboxes = bboxes[None, :]
bboxes = bboxes[:, :4]
gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
if gt_labels is None:
raise ValueError(
'gt_labels must be given when add_gt_as_proposals is True')
bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
......@@ -74,5 +93,6 @@ class BaseSampler(metaclass=ABCMeta):
assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique()
return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
return sampling_result
......@@ -12,11 +12,12 @@ class RandomSampler(BaseSampler):
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
from mmdet.core.bbox import demodata
super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals)
self.rng = demodata.ensure_rng(kwargs.get('rng', None))
@staticmethod
def random_choice(gallery, num):
def random_choice(self, gallery, num):
"""Random select some elements from the gallery.
It seems that Pytorch's implementation is slower than numpy so we use
......@@ -26,7 +27,7 @@ class RandomSampler(BaseSampler):
if isinstance(gallery, list):
gallery = np.array(gallery)
cands = np.arange(len(gallery))
np.random.shuffle(cands)
self.rng.shuffle(cands)
rand_inds = cands[:num]
if not isinstance(gallery, np.ndarray):
rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
......
import torch
from mmdet.utils import util_mixins
class SamplingResult(object):
class SamplingResult(util_mixins.NiceRepr):
"""
Example:
>>> # xdoctest: +IGNORE_WANT
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random(rng=10)
>>> print('self = {}'.format(self))
self = <SamplingResult({
'neg_bboxes': torch.Size([12, 4]),
'neg_inds': tensor([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
'num_gts': 4,
'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
'pos_bboxes': torch.Size([0, 4]),
'pos_inds': tensor([], dtype=torch.int64),
'pos_is_gt': tensor([], dtype=torch.uint8)
})>
"""
def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
gt_flags):
......@@ -13,7 +31,17 @@ class SamplingResult(object):
self.num_gts = gt_bboxes.shape[0]
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
if gt_bboxes.numel() == 0:
# hack for index error case
assert self.pos_assigned_gt_inds.numel() == 0
self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.view(-1, 4)
self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
if assign_result.labels is not None:
self.pos_gt_labels = assign_result.labels[pos_inds]
else:
......@@ -22,3 +50,105 @@ class SamplingResult(object):
@property
def bboxes(self):
return torch.cat([self.pos_bboxes, self.neg_bboxes])
def to(self, device):
"""
Change the device of the data inplace.
Example:
>>> self = SamplingResult.random()
>>> print('self = {}'.format(self.to(None)))
>>> # xdoctest: +REQUIRES(--gpu)
>>> print('self = {}'.format(self.to(0)))
"""
_dict = self.__dict__
for key, value in _dict.items():
if isinstance(value, torch.Tensor):
_dict[key] = value.to(device)
return self
def __nice__(self):
data = self.info.copy()
data['pos_bboxes'] = data.pop('pos_bboxes').shape
data['neg_bboxes'] = data.pop('neg_bboxes').shape
parts = ['\'{}\': {!r}'.format(k, v) for k, v in sorted(data.items())]
body = ' ' + ',\n '.join(parts)
return '{\n' + body + '\n}'
@property
def info(self):
"""
Returns a dictionary of info about the object
"""
return {
'pos_inds': self.pos_inds,
'neg_inds': self.neg_inds,
'pos_bboxes': self.pos_bboxes,
'neg_bboxes': self.neg_bboxes,
'pos_is_gt': self.pos_is_gt,
'num_gts': self.num_gts,
'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
}
@classmethod
def random(cls, rng=None, **kwargs):
"""
Args:
rng (None | int | numpy.random.RandomState): seed or state
Kwargs:
num_preds: number of predicted boxes
num_gts: number of true boxes
p_ignore (float): probability of a predicted box assinged to an
ignored truth
p_assigned (float): probability of a predicted box not being
assigned
p_use_label (float | bool): with labels or not
Returns:
AssignResult :
Example:
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random()
>>> print(self.__dict__)
"""
from mmdet.core.bbox.samplers.random_sampler import RandomSampler
from mmdet.core.bbox.assigners.assign_result import AssignResult
from mmdet.core.bbox import demodata
rng = demodata.ensure_rng(rng)
# make probabalistic?
num = 32
pos_fraction = 0.5
neg_pos_ub = -1
assign_result = AssignResult.random(rng=rng, **kwargs)
# Note we could just compute an assignment
bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
if rng.rand() > 0.2:
# sometimes algorithms squeeze their data, be robust to that
gt_bboxes = gt_bboxes.squeeze()
bboxes = bboxes.squeeze()
if assign_result.labels is None:
gt_labels = None
else:
gt_labels = None # todo
if gt_labels is None:
add_gt_as_proposals = False
else:
add_gt_as_proposals = True # make probabalistic?
sampler = RandomSampler(
num,
pos_fraction,
neg_pos_ubo=neg_pos_ub,
add_gt_as_proposals=add_gt_as_proposals,
rng=rng)
self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
return self
# -*- coding: utf-8 -*-
"""
This module defines the :class:`NiceRepr` mixin class, which defines a
``__repr__`` and ``__str__`` method that only depend on a custom ``__nice__``
method, which you must define. This means you only have to overload one
function instead of two. Furthermore, if the object defines a ``__len__``
method, then the ``__nice__`` method defaults to something sensible, otherwise
it is treated as abstract and raises ``NotImplementedError``.
To use simply have your object inherit from :class:`NiceRepr`
(multi-inheritance should be ok).
This code was copied from the ubelt library: https://github.com/Erotemic/ubelt
Example:
>>> # Objects that define __nice__ have a default __str__ and __repr__
>>> class Student(NiceRepr):
... def __init__(self, name):
... self.name = name
... def __nice__(self):
... return self.name
>>> s1 = Student('Alice')
>>> s2 = Student('Bob')
>>> print('s1 = {}'.format(s1))
>>> print('s2 = {}'.format(s2))
s1 = <Student(Alice)>
s2 = <Student(Bob)>
Example:
>>> # Objects that define __len__ have a default __nice__
>>> class Group(NiceRepr):
... def __init__(self, data):
... self.data = data
... def __len__(self):
... return len(self.data)
>>> g = Group([1, 2, 3])
>>> print('g = {}'.format(g))
g = <Group(3)>
"""
import warnings
class NiceRepr(object):
"""
Inherit from this class and define ``__nice__`` to "nicely" print your
objects.
Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
If the inheriting class has a ``__len__``, method then the default
``__nice__`` method will return its length.
Example:
>>> class Foo(NiceRepr):
... def __nice__(self):
... return 'info'
>>> foo = Foo()
>>> assert str(foo) == '<Foo(info)>'
>>> assert repr(foo).startswith('<Foo(info) at ')
Example:
>>> class Bar(NiceRepr):
... pass
>>> bar = Bar()
>>> import pytest
>>> with pytest.warns(None) as record:
>>> assert 'object at' in str(bar)
>>> assert 'object at' in repr(bar)
Example:
>>> class Baz(NiceRepr):
... def __len__(self):
... return 5
>>> baz = Baz()
>>> assert str(baz) == '<Baz(5)>'
"""
def __nice__(self):
if hasattr(self, '__len__'):
# It is a common pattern for objects to use __len__ in __nice__
# As a convenience we define a default __nice__ for these objects
return str(len(self))
else:
# In all other cases force the subclass to overload __nice__
raise NotImplementedError(
'Define the __nice__ method for {!r}'.format(self.__class__))
def __repr__(self):
try:
nice = self.__nice__()
classname = self.__class__.__name__
return '<{0}({1}) at {2}>'.format(classname, nice, hex(id(self)))
except NotImplementedError as ex:
warnings.warn(str(ex), category=RuntimeWarning)
return object.__repr__(self)
def __str__(self):
try:
classname = self.__class__.__name__
nice = self.__nice__()
return '<{0}({1})>'.format(classname, nice)
except NotImplementedError as ex:
warnings.warn(str(ex), category=RuntimeWarning)
return object.__repr__(self)
......@@ -259,3 +259,19 @@ def test_approx_iou_assigner_with_empty_boxes_and_gt():
assign_result = self.assign(approxs, squares, approxs_per_octave,
gt_bboxes)
assert len(assign_result.gt_inds) == 0
def test_random_assign_result():
"""
Test random instantiation of assign result to catch corner cases
"""
from mmdet.core.bbox.assigners.assign_result import AssignResult
AssignResult.random()
AssignResult.random(num_gts=0, num_preds=0)
AssignResult.random(num_gts=0, num_preds=3)
AssignResult.random(num_gts=3, num_preds=3)
AssignResult.random(num_gts=0, num_preds=3)
AssignResult.random(num_gts=7, num_preds=7)
AssignResult.random(num_gts=7, num_preds=64)
AssignResult.random(num_gts=24, num_preds=3)
......@@ -233,3 +233,17 @@ def test_ohem_sampler_empty_pred():
assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds)
assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds)
def test_random_sample_result():
from mmdet.core.bbox.samplers.sampling_result import SamplingResult
SamplingResult.random(num_gts=0, num_preds=0)
SamplingResult.random(num_gts=0, num_preds=3)
SamplingResult.random(num_gts=3, num_preds=3)
SamplingResult.random(num_gts=0, num_preds=3)
SamplingResult.random(num_gts=7, num_preds=7)
SamplingResult.random(num_gts=7, num_preds=64)
SamplingResult.random(num_gts=24, num_preds=3)
for i in range(3):
SamplingResult.random(rng=i)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment