"packaging/vscode:/vscode.git/clone" did not exist on "aa5b31d85756d916dd3ff464391935e17a002ca8"
Commit 8457bbab authored by Jon Crall's avatar Jon Crall Committed by Kai Chen
Browse files

Enhance AssignResult and SamplingResult (#1995)

* Enhance AssignResult and SamplingResult

Add runtime dependency on ubelt (pending approval)

Fix issue in SamplingResult.__init__

Add rng as attribute of RandomSampler

* fix linters

* remove ubelt

* Fix linters

* fix linters again
parent 78529eca
import torch import torch
from mmdet.utils import util_mixins
class AssignResult(object):
class AssignResult(util_mixins.NiceRepr):
""" """
Stores assignments between predicted and truth boxes. Stores assignments between predicted and truth boxes.
...@@ -44,20 +46,25 @@ class AssignResult(object): ...@@ -44,20 +46,25 @@ class AssignResult(object):
self.max_overlaps = max_overlaps self.max_overlaps = max_overlaps
self.labels = labels self.labels = labels
def add_gt_(self, gt_labels): @property
self_inds = torch.arange( def num_preds(self):
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) """
self.gt_inds = torch.cat([self_inds, self.gt_inds]) Return the number of predictions in this assignment
"""
# Was this a bug? return len(self.gt_inds)
# self.max_overlaps = torch.cat(
# [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
# IIUC, It seems like the correct code should be:
self.max_overlaps = torch.cat(
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
if self.labels is not None: @property
self.labels = torch.cat([gt_labels, self.labels]) def info(self):
"""
Returns a dictionary of info about the object
"""
return {
'num_gts': self.num_gts,
'num_preds': self.num_preds,
'gt_inds': self.gt_inds,
'max_overlaps': self.max_overlaps,
'labels': self.labels,
}
def __nice__(self): def __nice__(self):
""" """
...@@ -81,12 +88,105 @@ class AssignResult(object): ...@@ -81,12 +88,105 @@ class AssignResult(object):
parts.append('labels.shape={!r}'.format(tuple(self.labels.shape))) parts.append('labels.shape={!r}'.format(tuple(self.labels.shape)))
return ', '.join(parts) return ', '.join(parts)
def __repr__(self): @classmethod
nice = self.__nice__() def random(cls, **kwargs):
classname = self.__class__.__name__ """
return '<{}({}) at {}>'.format(classname, nice, hex(id(self))) Create random AssignResult for tests or debugging.
Kwargs:
num_preds: number of predicted boxes
num_gts: number of true boxes
p_ignore (float): probability of a predicted box assinged to an
ignored truth
p_assigned (float): probability of a predicted box not being
assigned
p_use_label (float | bool): with labels or not
rng (None | int | numpy.random.RandomState): seed or state
Returns:
AssignResult :
Example:
>>> from mmdet.core.bbox.assigners.assign_result import * # NOQA
>>> self = AssignResult.random()
>>> print(self.info)
"""
from mmdet.core.bbox import demodata
rng = demodata.ensure_rng(kwargs.get('rng', None))
num_gts = kwargs.get('num_gts', None)
num_preds = kwargs.get('num_preds', None)
p_ignore = kwargs.get('p_ignore', 0.3)
p_assigned = kwargs.get('p_assigned', 0.7)
p_use_label = kwargs.get('p_use_label', 0.5)
num_classes = kwargs.get('p_use_label', 3)
if num_gts is None:
num_gts = rng.randint(0, 8)
if num_preds is None:
num_preds = rng.randint(0, 16)
def __str__(self): if num_gts == 0:
classname = self.__class__.__name__ max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
nice = self.__nice__() gt_inds = torch.zeros(num_preds, dtype=torch.int64)
return '<{}({})>'.format(classname, nice) if p_use_label is True or p_use_label < rng.rand():
labels = torch.zeros(num_preds, dtype=torch.int64)
else:
labels = None
else:
import numpy as np
# Create an overlap for each predicted box
max_overlaps = torch.from_numpy(rng.rand(num_preds))
# Construct gt_inds for each predicted box
is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
# maximum number of assignments constraints
n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))
assigned_idxs = np.where(is_assigned)[0]
rng.shuffle(assigned_idxs)
assigned_idxs = assigned_idxs[0:n_assigned]
assigned_idxs.sort()
is_assigned[:] = 0
is_assigned[assigned_idxs] = True
is_ignore = torch.from_numpy(
rng.rand(num_preds) < p_ignore) & is_assigned
gt_inds = torch.zeros(num_preds, dtype=torch.int64)
true_idxs = np.arange(num_gts)
rng.shuffle(true_idxs)
true_idxs = torch.from_numpy(true_idxs)
gt_inds[is_assigned] = true_idxs[:n_assigned]
gt_inds = torch.from_numpy(
rng.randint(1, num_gts + 1, size=num_preds))
gt_inds[is_ignore] = -1
gt_inds[~is_assigned] = 0
max_overlaps[~is_assigned] = 0
if p_use_label is True or p_use_label < rng.rand():
if num_classes == 0:
labels = torch.zeros(num_preds, dtype=torch.int64)
else:
labels = torch.from_numpy(
rng.randint(1, num_classes + 1, size=num_preds))
labels[~is_assigned] = 0
else:
labels = None
self = cls(num_gts, gt_inds, max_overlaps, labels)
return self
def add_gt_(self, gt_labels):
self_inds = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
self.gt_inds = torch.cat([self_inds, self.gt_inds])
self.max_overlaps = torch.cat(
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
if self.labels is not None:
self.labels = torch.cat([gt_labels, self.labels])
...@@ -47,11 +47,30 @@ class BaseSampler(metaclass=ABCMeta): ...@@ -47,11 +47,30 @@ class BaseSampler(metaclass=ABCMeta):
Returns: Returns:
:obj:`SamplingResult`: Sampling result. :obj:`SamplingResult`: Sampling result.
Example:
>>> from mmdet.core.bbox import RandomSampler
>>> from mmdet.core.bbox import AssignResult
>>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
>>> rng = ensure_rng(None)
>>> assign_result = AssignResult.random(rng=rng)
>>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
>>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
>>> gt_labels = None
>>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
>>> add_gt_as_proposals=False)
>>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
""" """
if len(bboxes.shape) < 2:
bboxes = bboxes[None, :]
bboxes = bboxes[:, :4] bboxes = bboxes[:, :4]
gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8) gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
if self.add_gt_as_proposals and len(gt_bboxes) > 0: if self.add_gt_as_proposals and len(gt_bboxes) > 0:
if gt_labels is None:
raise ValueError(
'gt_labels must be given when add_gt_as_proposals is True')
bboxes = torch.cat([gt_bboxes, bboxes], dim=0) bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
assign_result.add_gt_(gt_labels) assign_result.add_gt_(gt_labels)
gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
...@@ -74,5 +93,6 @@ class BaseSampler(metaclass=ABCMeta): ...@@ -74,5 +93,6 @@ class BaseSampler(metaclass=ABCMeta):
assign_result, num_expected_neg, bboxes=bboxes, **kwargs) assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique() neg_inds = neg_inds.unique()
return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags) assign_result, gt_flags)
return sampling_result
...@@ -12,11 +12,12 @@ class RandomSampler(BaseSampler): ...@@ -12,11 +12,12 @@ class RandomSampler(BaseSampler):
neg_pos_ub=-1, neg_pos_ub=-1,
add_gt_as_proposals=True, add_gt_as_proposals=True,
**kwargs): **kwargs):
from mmdet.core.bbox import demodata
super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub, super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals) add_gt_as_proposals)
self.rng = demodata.ensure_rng(kwargs.get('rng', None))
@staticmethod def random_choice(self, gallery, num):
def random_choice(gallery, num):
"""Random select some elements from the gallery. """Random select some elements from the gallery.
It seems that Pytorch's implementation is slower than numpy so we use It seems that Pytorch's implementation is slower than numpy so we use
...@@ -26,7 +27,7 @@ class RandomSampler(BaseSampler): ...@@ -26,7 +27,7 @@ class RandomSampler(BaseSampler):
if isinstance(gallery, list): if isinstance(gallery, list):
gallery = np.array(gallery) gallery = np.array(gallery)
cands = np.arange(len(gallery)) cands = np.arange(len(gallery))
np.random.shuffle(cands) self.rng.shuffle(cands)
rand_inds = cands[:num] rand_inds = cands[:num]
if not isinstance(gallery, np.ndarray): if not isinstance(gallery, np.ndarray):
rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device) rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
......
import torch import torch
from mmdet.utils import util_mixins
class SamplingResult(object):
class SamplingResult(util_mixins.NiceRepr):
"""
Example:
>>> # xdoctest: +IGNORE_WANT
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random(rng=10)
>>> print('self = {}'.format(self))
self = <SamplingResult({
'neg_bboxes': torch.Size([12, 4]),
'neg_inds': tensor([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
'num_gts': 4,
'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
'pos_bboxes': torch.Size([0, 4]),
'pos_inds': tensor([], dtype=torch.int64),
'pos_is_gt': tensor([], dtype=torch.uint8)
})>
"""
def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
gt_flags): gt_flags):
...@@ -13,7 +31,17 @@ class SamplingResult(object): ...@@ -13,7 +31,17 @@ class SamplingResult(object):
self.num_gts = gt_bboxes.shape[0] self.num_gts = gt_bboxes.shape[0]
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
if gt_bboxes.numel() == 0:
# hack for index error case
assert self.pos_assigned_gt_inds.numel() == 0
self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.view(-1, 4)
self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
if assign_result.labels is not None: if assign_result.labels is not None:
self.pos_gt_labels = assign_result.labels[pos_inds] self.pos_gt_labels = assign_result.labels[pos_inds]
else: else:
...@@ -22,3 +50,105 @@ class SamplingResult(object): ...@@ -22,3 +50,105 @@ class SamplingResult(object):
@property @property
def bboxes(self): def bboxes(self):
return torch.cat([self.pos_bboxes, self.neg_bboxes]) return torch.cat([self.pos_bboxes, self.neg_bboxes])
def to(self, device):
"""
Change the device of the data inplace.
Example:
>>> self = SamplingResult.random()
>>> print('self = {}'.format(self.to(None)))
>>> # xdoctest: +REQUIRES(--gpu)
>>> print('self = {}'.format(self.to(0)))
"""
_dict = self.__dict__
for key, value in _dict.items():
if isinstance(value, torch.Tensor):
_dict[key] = value.to(device)
return self
def __nice__(self):
data = self.info.copy()
data['pos_bboxes'] = data.pop('pos_bboxes').shape
data['neg_bboxes'] = data.pop('neg_bboxes').shape
parts = ['\'{}\': {!r}'.format(k, v) for k, v in sorted(data.items())]
body = ' ' + ',\n '.join(parts)
return '{\n' + body + '\n}'
@property
def info(self):
"""
Returns a dictionary of info about the object
"""
return {
'pos_inds': self.pos_inds,
'neg_inds': self.neg_inds,
'pos_bboxes': self.pos_bboxes,
'neg_bboxes': self.neg_bboxes,
'pos_is_gt': self.pos_is_gt,
'num_gts': self.num_gts,
'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
}
@classmethod
def random(cls, rng=None, **kwargs):
"""
Args:
rng (None | int | numpy.random.RandomState): seed or state
Kwargs:
num_preds: number of predicted boxes
num_gts: number of true boxes
p_ignore (float): probability of a predicted box assinged to an
ignored truth
p_assigned (float): probability of a predicted box not being
assigned
p_use_label (float | bool): with labels or not
Returns:
AssignResult :
Example:
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random()
>>> print(self.__dict__)
"""
from mmdet.core.bbox.samplers.random_sampler import RandomSampler
from mmdet.core.bbox.assigners.assign_result import AssignResult
from mmdet.core.bbox import demodata
rng = demodata.ensure_rng(rng)
# make probabalistic?
num = 32
pos_fraction = 0.5
neg_pos_ub = -1
assign_result = AssignResult.random(rng=rng, **kwargs)
# Note we could just compute an assignment
bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
if rng.rand() > 0.2:
# sometimes algorithms squeeze their data, be robust to that
gt_bboxes = gt_bboxes.squeeze()
bboxes = bboxes.squeeze()
if assign_result.labels is None:
gt_labels = None
else:
gt_labels = None # todo
if gt_labels is None:
add_gt_as_proposals = False
else:
add_gt_as_proposals = True # make probabalistic?
sampler = RandomSampler(
num,
pos_fraction,
neg_pos_ubo=neg_pos_ub,
add_gt_as_proposals=add_gt_as_proposals,
rng=rng)
self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
return self
# -*- coding: utf-8 -*-
"""
This module defines the :class:`NiceRepr` mixin class, which defines a
``__repr__`` and ``__str__`` method that only depend on a custom ``__nice__``
method, which you must define. This means you only have to overload one
function instead of two. Furthermore, if the object defines a ``__len__``
method, then the ``__nice__`` method defaults to something sensible, otherwise
it is treated as abstract and raises ``NotImplementedError``.
To use simply have your object inherit from :class:`NiceRepr`
(multi-inheritance should be ok).
This code was copied from the ubelt library: https://github.com/Erotemic/ubelt
Example:
>>> # Objects that define __nice__ have a default __str__ and __repr__
>>> class Student(NiceRepr):
... def __init__(self, name):
... self.name = name
... def __nice__(self):
... return self.name
>>> s1 = Student('Alice')
>>> s2 = Student('Bob')
>>> print('s1 = {}'.format(s1))
>>> print('s2 = {}'.format(s2))
s1 = <Student(Alice)>
s2 = <Student(Bob)>
Example:
>>> # Objects that define __len__ have a default __nice__
>>> class Group(NiceRepr):
... def __init__(self, data):
... self.data = data
... def __len__(self):
... return len(self.data)
>>> g = Group([1, 2, 3])
>>> print('g = {}'.format(g))
g = <Group(3)>
"""
import warnings
class NiceRepr(object):
"""
Inherit from this class and define ``__nice__`` to "nicely" print your
objects.
Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
If the inheriting class has a ``__len__``, method then the default
``__nice__`` method will return its length.
Example:
>>> class Foo(NiceRepr):
... def __nice__(self):
... return 'info'
>>> foo = Foo()
>>> assert str(foo) == '<Foo(info)>'
>>> assert repr(foo).startswith('<Foo(info) at ')
Example:
>>> class Bar(NiceRepr):
... pass
>>> bar = Bar()
>>> import pytest
>>> with pytest.warns(None) as record:
>>> assert 'object at' in str(bar)
>>> assert 'object at' in repr(bar)
Example:
>>> class Baz(NiceRepr):
... def __len__(self):
... return 5
>>> baz = Baz()
>>> assert str(baz) == '<Baz(5)>'
"""
def __nice__(self):
if hasattr(self, '__len__'):
# It is a common pattern for objects to use __len__ in __nice__
# As a convenience we define a default __nice__ for these objects
return str(len(self))
else:
# In all other cases force the subclass to overload __nice__
raise NotImplementedError(
'Define the __nice__ method for {!r}'.format(self.__class__))
def __repr__(self):
try:
nice = self.__nice__()
classname = self.__class__.__name__
return '<{0}({1}) at {2}>'.format(classname, nice, hex(id(self)))
except NotImplementedError as ex:
warnings.warn(str(ex), category=RuntimeWarning)
return object.__repr__(self)
def __str__(self):
try:
classname = self.__class__.__name__
nice = self.__nice__()
return '<{0}({1})>'.format(classname, nice)
except NotImplementedError as ex:
warnings.warn(str(ex), category=RuntimeWarning)
return object.__repr__(self)
...@@ -259,3 +259,19 @@ def test_approx_iou_assigner_with_empty_boxes_and_gt(): ...@@ -259,3 +259,19 @@ def test_approx_iou_assigner_with_empty_boxes_and_gt():
assign_result = self.assign(approxs, squares, approxs_per_octave, assign_result = self.assign(approxs, squares, approxs_per_octave,
gt_bboxes) gt_bboxes)
assert len(assign_result.gt_inds) == 0 assert len(assign_result.gt_inds) == 0
def test_random_assign_result():
"""
Test random instantiation of assign result to catch corner cases
"""
from mmdet.core.bbox.assigners.assign_result import AssignResult
AssignResult.random()
AssignResult.random(num_gts=0, num_preds=0)
AssignResult.random(num_gts=0, num_preds=3)
AssignResult.random(num_gts=3, num_preds=3)
AssignResult.random(num_gts=0, num_preds=3)
AssignResult.random(num_gts=7, num_preds=7)
AssignResult.random(num_gts=7, num_preds=64)
AssignResult.random(num_gts=24, num_preds=3)
...@@ -233,3 +233,17 @@ def test_ohem_sampler_empty_pred(): ...@@ -233,3 +233,17 @@ def test_ohem_sampler_empty_pred():
assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds) assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds)
assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds) assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds)
def test_random_sample_result():
from mmdet.core.bbox.samplers.sampling_result import SamplingResult
SamplingResult.random(num_gts=0, num_preds=0)
SamplingResult.random(num_gts=0, num_preds=3)
SamplingResult.random(num_gts=3, num_preds=3)
SamplingResult.random(num_gts=0, num_preds=3)
SamplingResult.random(num_gts=7, num_preds=7)
SamplingResult.random(num_gts=7, num_preds=64)
SamplingResult.random(num_gts=24, num_preds=3)
for i in range(3):
SamplingResult.random(rng=i)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment