ModelZoo / SOLOv2-pytorch · Commits

Commit e3c1b855, authored Dec 06, 2018 by Kai Chen

    refactoring for sampler and assigner

Parent: 65a2e5ea
Changes: 32
Showing 12 changed files with 332 additions and 247 deletions (+332 / -247)
mmdet/core/bbox/assigners/assign_result.py                  +19   -0
mmdet/core/bbox/assigners/base_assigner.py                  +8    -0
mmdet/core/bbox/assigners/max_iou_assigner.py               +4    -20
mmdet/core/bbox/samplers/__init__.py                        +13   -0
mmdet/core/bbox/samplers/base_sampler.py                    +64   -0
mmdet/core/bbox/samplers/combined_sampler.py                +16   -0
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py   +41   -0
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py        +62   -0
mmdet/core/bbox/samplers/pseudo_sampler.py                  +26   -0
mmdet/core/bbox/samplers/random_sampler.py                  +55   -0
mmdet/core/bbox/samplers/sampling_result.py                 +24   -0
mmdet/core/bbox/sampling.py                                 +0    -227
mmdet/core/bbox/assigners/assign_result.py (new file, mode 100644)

import torch


class AssignResult(object):

    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
        self.num_gts = num_gts
        self.gt_inds = gt_inds
        self.max_overlaps = max_overlaps
        self.labels = labels

    def add_gt_(self, gt_labels):
        self_inds = torch.arange(
            1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
        self.gt_inds = torch.cat([self_inds, self.gt_inds])
        self.max_overlaps = torch.cat(
            [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
        if self.labels is not None:
            self.labels = torch.cat([gt_labels, self.labels])
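A quick illustration of what add_gt_() does when ground-truth boxes are later prepended to the proposal list. The tensors below are toy values, not from the repository; the only assumption is that the module is importable at the path shown above.

import torch
from mmdet.core.bbox.assigners.assign_result import AssignResult

# hypothetical assignment of 3 proposals against 2 GT boxes
# (0 = background, k > 0 = assigned to GT k)
result = AssignResult(
    num_gts=2,
    gt_inds=torch.tensor([0, 2, 1]),
    max_overlaps=torch.tensor([0.1, 0.8, 0.7]),
    labels=None)

# GT labels; with labels=None only their length and device are used here
result.add_gt_(torch.tensor([3, 5]))
# GT boxes become proposals assigned to themselves with overlap 1.0:
# result.gt_inds      -> tensor([1, 2, 0, 2, 1])
# result.max_overlaps -> tensor([1.0000, 1.0000, 0.1000, 0.8000, 0.7000])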
mmdet/core/bbox/assigners/base_assigner.py (new file, mode 100644)

from abc import ABCMeta, abstractmethod


class BaseAssigner(metaclass=ABCMeta):

    @abstractmethod
    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        pass
mmdet/core/bbox/assignment.py → mmdet/core/bbox/assigners/max_iou_assigner.py (renamed, +4 / -20)

 import torch
-from .geometry import bbox_overlaps
+from .base_assigner import BaseAssigner
+from .assign_result import AssignResult
+from ..geometry import bbox_overlaps


-class BBoxAssigner(object):
+class MaxIoUAssigner(BaseAssigner):
     """Assign a corresponding gt bbox or background to each bbox.

     Each proposals will be assigned with `-1`, `0`, or a positive integer
     ...
     ...
@@ -135,21 +137,3 @@ class BBoxAssigner(object):
         return AssignResult(
             num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
-
-
-class AssignResult(object):
-
-    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
-        self.num_gts = num_gts
-        self.gt_inds = gt_inds
-        self.max_overlaps = max_overlaps
-        self.labels = labels
-
-    def add_gt_(self, gt_labels):
-        self_inds = torch.arange(
-            1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
-        self.gt_inds = torch.cat([self_inds, self.gt_inds])
-        self.max_overlaps = torch.cat(
-            [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
-        if self.labels is not None:
-            self.labels = torch.cat([gt_labels, self.labels])
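The constructor of MaxIoUAssigner is elided on this page, so the sketch below relies only on the assign() signature inherited from BaseAssigner; the threshold keyword names (pos_iou_thr, neg_iou_thr) are assumptions borrowed from typical mmdetection configs rather than something this hunk shows.

import torch
from mmdet.core.bbox.assigners.max_iou_assigner import MaxIoUAssigner

# hypothetical candidate boxes and ground truth in (x1, y1, x2, y2) format
bboxes = torch.tensor([[0., 0., 10., 10.],
                       [5., 5., 15., 15.],
                       [20., 20., 30., 30.]])
gt_bboxes = torch.tensor([[0., 0., 12., 12.],
                          [19., 19., 31., 31.]])
gt_labels = torch.tensor([1, 2])

# keyword names below are assumed, not confirmed by this diff
assigner = MaxIoUAssigner(pos_iou_thr=0.5, neg_iou_thr=0.4)
assign_result = assigner.assign(bboxes, gt_bboxes, gt_labels=gt_labels)
# assign_result.gt_inds holds -1 (ignore), 0 (background) or a 1-based GT index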
mmdet/core/bbox/samplers/__init__.py (new file, mode 100644)

from .base_sampler import BaseSampler
from .pseudo_sampler import PseudoSampler
from .random_sampler import RandomSampler
from .instance_balanced_pos_sampler import InstanceBalancedPosSampler
from .iou_balanced_neg_sampler import IoUBalancedNegSampler
from .combined_sampler import CombinedSampler
from .sampling_result import SamplingResult

__all__ = [
    'BaseSampler', 'PseudoSampler', 'RandomSampler',
    'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
    'SamplingResult'
]
mmdet/core/bbox/samplers/base_sampler.py (new file, mode 100644)

from abc import ABCMeta, abstractmethod

import torch

from .sampling_result import SamplingResult


class BaseSampler(metaclass=ABCMeta):

    def __init__(self):
        self.pos_sampler = self
        self.neg_sampler = self

    @abstractmethod
    def _sample_pos(self, assign_result, num_expected):
        pass

    @abstractmethod
    def _sample_neg(self, assign_result, num_expected):
        pass

    def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None):
        """Sample positive and negative bboxes.

        This is a simple implementation of bbox sampling given candidates,
        assigning results and ground truth bboxes.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            bboxes (Tensor): Boxes to be sampled from.
            gt_bboxes (Tensor): Ground truth bboxes.
            gt_labels (Tensor, optional): Class labels of ground truth bboxes.

        Returns:
            :obj:`SamplingResult`: Sampling result.
        """
        bboxes = bboxes[:, :4]

        gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
        if self.add_gt_as_proposals:
            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
            gt_flags = torch.cat([gt_ones, gt_flags])

        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self.pos_sampler._sample_pos(assign_result,
                                                num_expected_pos)
        # We found that sampled indices have duplicated items occasionally.
        # (may be a bug of PyTorch)
        pos_inds = pos_inds.unique()
        num_sampled_pos = pos_inds.numel()
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            _pos = max(1, num_sampled_pos)
            neg_upper_bound = int(self.neg_pos_ub * _pos)
            if num_expected_neg > neg_upper_bound:
                num_expected_neg = neg_upper_bound
        neg_inds = self.neg_sampler._sample_neg(assign_result,
                                                num_expected_neg)
        neg_inds = neg_inds.unique()

        return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                              assign_result, gt_flags)
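BaseSampler.sample() reads self.num, self.pos_fraction, self.neg_pos_ub and self.add_gt_as_proposals, which concrete subclasses such as RandomSampler set in their constructors. A minimal end-to-end sketch with made-up tensors, assuming the samplers package is importable as declared in samplers/__init__.py:

import torch
from mmdet.core.bbox.assigners.assign_result import AssignResult
from mmdet.core.bbox.samplers import RandomSampler

# fake assignment of 5 proposals against 2 GT boxes
assign_result = AssignResult(
    num_gts=2,
    gt_inds=torch.tensor([1, 0, 2, 0, 0]),
    max_overlaps=torch.tensor([0.9, 0.2, 0.7, 0.1, 0.0]),
    labels=torch.tensor([3, 0, 7, 0, 0]))

bboxes = torch.rand(5, 4) * 100      # candidate boxes (x1, y1, x2, y2)
gt_bboxes = torch.rand(2, 4) * 100
gt_labels = torch.tensor([3, 7])

sampler = RandomSampler(num=4, pos_fraction=0.5, add_gt_as_proposals=True)
sampling_result = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
# sampling_result.pos_bboxes / neg_bboxes hold the sampled boxes;
# sampling_result.pos_is_gt flags positives that were injected GT boxes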
mmdet/core/bbox/samplers/combined_sampler.py (new file, mode 100644)

from mmcv.runner import obj_from_dict

from .random_sampler import RandomSampler
from ..assign_sampling import build_sampler


class CombinedSampler(RandomSampler):

    def __init__(self, num, pos_fraction, pos_sampler, neg_sampler, **kwargs):
        super(CombinedSampler, self).__init__(num, pos_fraction, **kwargs)
        default_args = dict(num=num, pos_fraction=pos_fraction)
        default_args.update(kwargs)
        self.pos_sampler = build_sampler(
            pos_sampler, default_args=default_args)
        self.neg_sampler = build_sampler(
            neg_sampler, default_args=default_args)
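CombinedSampler pairs one strategy for positives with another for negatives; both sub-samplers are created through build_sampler from ..assign_sampling, a module that is not shown on this page. A hedged sketch of the intended construction, assuming build_sampler consumes mmdetection-style dict(type=...) configs (the exact schema is not visible here):

from mmdet.core.bbox.samplers import CombinedSampler

sampler = CombinedSampler(
    num=512,
    pos_fraction=0.25,
    # the dict schema below is an assumption about build_sampler's input
    pos_sampler=dict(type='InstanceBalancedPosSampler'),
    neg_sampler=dict(type='IoUBalancedNegSampler', hard_thr=0.1,
                     hard_fraction=0.5))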
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py (new file, mode 100644)

import numpy as np
import torch

from .random_sampler import RandomSampler


class InstanceBalancedPosSampler(RandomSampler):

    def _sample_pos(self, assign_result, num_expected):
        pos_inds = torch.nonzero(assign_result.gt_inds > 0)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            return pos_inds
        else:
            unique_gt_inds = assign_result.gt_inds[pos_inds].unique()
            num_gts = len(unique_gt_inds)
            num_per_gt = int(round(num_expected / float(num_gts)) + 1)
            sampled_inds = []
            for i in unique_gt_inds:
                inds = torch.nonzero(assign_result.gt_inds == i.item())
                if inds.numel() != 0:
                    inds = inds.squeeze(1)
                else:
                    continue
                if len(inds) > num_per_gt:
                    inds = self.random_choice(inds, num_per_gt)
                sampled_inds.append(inds)
            sampled_inds = torch.cat(sampled_inds)
            if len(sampled_inds) < num_expected:
                num_extra = num_expected - len(sampled_inds)
                extra_inds = np.array(
                    list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
                if len(extra_inds) > num_extra:
                    extra_inds = self.random_choice(extra_inds, num_extra)
                extra_inds = torch.from_numpy(extra_inds).to(
                    assign_result.gt_inds.device).long()
                sampled_inds = torch.cat([sampled_inds, extra_inds])
            elif len(sampled_inds) > num_expected:
                sampled_inds = self.random_choice(sampled_inds, num_expected)
            return sampled_inds
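As a worked example of the quota logic above (numbers are hypothetical): with num_expected = 6 and 3 distinct GT instances among the positives, num_per_gt = int(round(6 / 3.0) + 1) = 3, so at most 3 positives are drawn per instance; if the union falls short of 6 it is topped up from the remaining positives, and if it overshoots it is randomly trimmed back to exactly num_expected.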
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py (new file, mode 100644)

import numpy as np
import torch

from .random_sampler import RandomSampler


class IoUBalancedNegSampler(RandomSampler):

    def __init__(self,
                 num,
                 pos_fraction,
                 hard_thr=0.1,
                 hard_fraction=0.5,
                 **kwargs):
        super(IoUBalancedNegSampler, self).__init__(num, pos_fraction,
                                                    **kwargs)
        assert hard_thr > 0
        assert 0 < hard_fraction < 1
        self.hard_thr = hard_thr
        self.hard_fraction = hard_fraction

    def _sample_neg(self, assign_result, num_expected):
        neg_inds = torch.nonzero(assign_result.gt_inds == 0)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds
        else:
            max_overlaps = assign_result.max_overlaps.cpu().numpy()
            # balance sampling for negative samples
            neg_set = set(neg_inds.cpu().numpy())
            easy_set = set(
                np.where(
                    np.logical_and(max_overlaps >= 0,
                                   max_overlaps < self.hard_thr))[0])
            hard_set = set(np.where(max_overlaps >= self.hard_thr)[0])
            easy_neg_inds = list(easy_set & neg_set)
            hard_neg_inds = list(hard_set & neg_set)

            num_expected_hard = int(num_expected * self.hard_fraction)
            if len(hard_neg_inds) > num_expected_hard:
                sampled_hard_inds = self.random_choice(hard_neg_inds,
                                                       num_expected_hard)
            else:
                sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int)
            num_expected_easy = num_expected - len(sampled_hard_inds)
            if len(easy_neg_inds) > num_expected_easy:
                sampled_easy_inds = self.random_choice(easy_neg_inds,
                                                       num_expected_easy)
            else:
                sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int)
            sampled_inds = np.concatenate((sampled_easy_inds,
                                           sampled_hard_inds))
            if len(sampled_inds) < num_expected:
                num_extra = num_expected - len(sampled_inds)
                extra_inds = np.array(list(neg_set - set(sampled_inds)))
                if len(extra_inds) > num_extra:
                    extra_inds = self.random_choice(extra_inds, num_extra)
                sampled_inds = np.concatenate((sampled_inds, extra_inds))
            sampled_inds = torch.from_numpy(sampled_inds).long().to(
                assign_result.gt_inds.device)
            return sampled_inds
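This sampler splits the negatives by IoU: those with max overlap in [0, hard_thr) count as easy and those with overlap >= hard_thr as hard, and hard_fraction of the negative budget is reserved for the hard bin before the easy bin fills the rest. A minimal construction sketch with illustrative numbers (note that np.array(..., dtype=np.int) relies on the np.int alias, which newer NumPy releases have since removed):

from mmdet.core.bbox.samplers import IoUBalancedNegSampler

# reserve half of the sampled negatives for the hard range (IoU >= 0.1)
sampler = IoUBalancedNegSampler(num=256, pos_fraction=0.25,
                                hard_thr=0.1, hard_fraction=0.5)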
mmdet/core/bbox/samplers/pseudo_sampler.py (new file, mode 100644)

import torch

from .base_sampler import BaseSampler
from .sampling_result import SamplingResult


class PseudoSampler(BaseSampler):

    def __init__(self):
        pass

    def _sample_pos(self):
        raise NotImplementedError

    def _sample_neg(self):
        raise NotImplementedError

    def sample(self, assign_result, bboxes, gt_bboxes):
        pos_inds = torch.nonzero(
            assign_result.gt_inds > 0).squeeze(-1).unique()
        neg_inds = torch.nonzero(
            assign_result.gt_inds == 0).squeeze(-1).unique()
        gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8)
        sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                                         assign_result, gt_flags)
        return sampling_result
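PseudoSampler is the degenerate case: it keeps every assigned positive and negative without applying any budget, which is why _sample_pos and _sample_neg simply raise NotImplementedError. A small self-contained sketch with toy tensors (not from the repository):

import torch
from mmdet.core.bbox.assigners.assign_result import AssignResult
from mmdet.core.bbox.samplers import PseudoSampler

# toy assignment of 4 proposals against 1 GT box
assign_result = AssignResult(
    num_gts=1,
    gt_inds=torch.tensor([1, 0, 0, 1]),
    max_overlaps=torch.tensor([0.8, 0.1, 0.0, 0.6]),
    labels=torch.tensor([2, 0, 0, 2]))
bboxes = torch.rand(4, 4) * 100
gt_bboxes = torch.rand(1, 4) * 100

sampler = PseudoSampler()
sampling_result = sampler.sample(assign_result, bboxes, gt_bboxes)
# both positives and both negatives are kept; no sampling budget is applied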
mmdet/core/bbox/samplers/random_sampler.py (new file, mode 100644)

import numpy as np
import torch

from .base_sampler import BaseSampler


class RandomSampler(BaseSampler):

    def __init__(self, num, pos_fraction, neg_pos_ub=-1,
                 add_gt_as_proposals=True):
        super(RandomSampler, self).__init__()
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals

    @staticmethod
    def random_choice(gallery, num):
        """Random select some elements from the gallery.

        It seems that Pytorch's implementation is slower than numpy so we use
        numpy to randperm the indices.
        """
        assert len(gallery) >= num
        if isinstance(gallery, list):
            gallery = np.array(gallery)
        cands = np.arange(len(gallery))
        np.random.shuffle(cands)
        rand_inds = cands[:num]
        if not isinstance(gallery, np.ndarray):
            rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
        return gallery[rand_inds]

    def _sample_pos(self, assign_result, num_expected):
        """Randomly sample some positive samples."""
        pos_inds = torch.nonzero(assign_result.gt_inds > 0)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            return pos_inds
        else:
            return self.random_choice(pos_inds, num_expected)

    def _sample_neg(self, assign_result, num_expected):
        """Randomly sample some negative samples."""
        neg_inds = torch.nonzero(assign_result.gt_inds == 0)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds
        else:
            return self.random_choice(neg_inds, num_expected)
mmdet/core/bbox/samplers/sampling_result.py (new file, mode 100644)

import torch


class SamplingResult(object):

    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
                 gt_flags):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds
        self.pos_bboxes = bboxes[pos_inds]
        self.neg_bboxes = bboxes[neg_inds]
        self.pos_is_gt = gt_flags[pos_inds]

        self.num_gts = gt_bboxes.shape[0]
        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
        self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
        if assign_result.labels is not None:
            self.pos_gt_labels = assign_result.labels[pos_inds]
        else:
            self.pos_gt_labels = None

    @property
    def bboxes(self):
        return torch.cat([self.pos_bboxes, self.neg_bboxes])
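SamplingResult bundles everything a downstream head needs from one sampling pass: the sampled boxes themselves (pos_bboxes, neg_bboxes, or both via the bboxes property), the 0-based GT index and matched GT box and label for each positive (pos_assigned_gt_inds, pos_gt_bboxes, pos_gt_labels), and pos_is_gt, which marks positives that were injected ground-truth boxes when add_gt_as_proposals is enabled.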
mmdet/core/bbox/sampling.py (deleted, mode 100644 → 0)

import numpy as np
import torch

from .assignment import BBoxAssigner


def random_choice(gallery, num):
    """Random select some elements from the gallery.

    It seems that Pytorch's implementation is slower than numpy so we use numpy
    to randperm the indices.
    """
    assert len(gallery) >= num
    if isinstance(gallery, list):
        gallery = np.array(gallery)
    cands = np.arange(len(gallery))
    np.random.shuffle(cands)
    rand_inds = cands[:num]
    if not isinstance(gallery, np.ndarray):
        rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
    return gallery[rand_inds]


def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
    bbox_assigner = BBoxAssigner(**cfg.assigner)
    bbox_sampler = BBoxSampler(**cfg.sampler)
    assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore,
                                         gt_labels)
    sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes,
                                          gt_labels)
    return assign_result, sampling_result


class BBoxSampler(object):
    """Sample positive and negative bboxes given assigned results.

    Args:
        pos_fraction (float): Positive sample fraction.
        neg_pos_ub (float): Negative/Positive upper bound.
        pos_balance_sampling (bool): Whether to sample positive samples around
            each gt bbox evenly.
        neg_balance_thr (float, optional): IoU threshold for simple/hard
            negative balance sampling.
        neg_hard_fraction (float, optional): Fraction of hard negative samples
            for negative balance sampling.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 pos_balance_sampling=False,
                 neg_balance_thr=0,
                 neg_hard_fraction=0.5):
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        self.pos_balance_sampling = pos_balance_sampling
        self.neg_balance_thr = neg_balance_thr
        self.neg_hard_fraction = neg_hard_fraction

    def _sample_pos(self, assign_result, num_expected):
        """Balance sampling for positive bboxes/anchors.

        1. calculate average positive num for each gt: num_per_gt
        2. sample at most num_per_gt positives for each gt
        3. random sampling from rest anchors if not enough fg
        """
        pos_inds = torch.nonzero(assign_result.gt_inds > 0)
        if pos_inds.numel() != 0:
            pos_inds = pos_inds.squeeze(1)
        if pos_inds.numel() <= num_expected:
            return pos_inds
        elif not self.pos_balance_sampling:
            return random_choice(pos_inds, num_expected)
        else:
            unique_gt_inds = torch.unique(
                assign_result.gt_inds[pos_inds].cpu())
            num_gts = len(unique_gt_inds)
            num_per_gt = int(round(num_expected / float(num_gts)) + 1)
            sampled_inds = []
            for i in unique_gt_inds:
                inds = torch.nonzero(assign_result.gt_inds == i.item())
                if inds.numel() != 0:
                    inds = inds.squeeze(1)
                else:
                    continue
                if len(inds) > num_per_gt:
                    inds = random_choice(inds, num_per_gt)
                sampled_inds.append(inds)
            sampled_inds = torch.cat(sampled_inds)
            if len(sampled_inds) < num_expected:
                num_extra = num_expected - len(sampled_inds)
                extra_inds = np.array(
                    list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
                if len(extra_inds) > num_extra:
                    extra_inds = random_choice(extra_inds, num_extra)
                extra_inds = torch.from_numpy(extra_inds).to(
                    assign_result.gt_inds.device).long()
                sampled_inds = torch.cat([sampled_inds, extra_inds])
            elif len(sampled_inds) > num_expected:
                sampled_inds = random_choice(sampled_inds, num_expected)
            return sampled_inds

    def _sample_neg(self, assign_result, num_expected):
        """Balance sampling for negative bboxes/anchors.

        Negative samples are split into 2 set: hard (balance_thr <= iou <
        neg_iou_thr) and easy (iou < balance_thr). The sampling ratio is
        controlled by `hard_fraction`.
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0)
        if neg_inds.numel() != 0:
            neg_inds = neg_inds.squeeze(1)
        if len(neg_inds) <= num_expected:
            return neg_inds
        elif self.neg_balance_thr <= 0:
            # uniform sampling among all negative samples
            return random_choice(neg_inds, num_expected)
        else:
            max_overlaps = assign_result.max_overlaps.cpu().numpy()
            # balance sampling for negative samples
            neg_set = set(neg_inds.cpu().numpy())
            easy_set = set(
                np.where(
                    np.logical_and(max_overlaps >= 0,
                                   max_overlaps < self.neg_balance_thr))[0])
            hard_set = set(np.where(max_overlaps >= self.neg_balance_thr)[0])
            easy_neg_inds = list(easy_set & neg_set)
            hard_neg_inds = list(hard_set & neg_set)

            num_expected_hard = int(num_expected * self.neg_hard_fraction)
            if len(hard_neg_inds) > num_expected_hard:
                sampled_hard_inds = random_choice(hard_neg_inds,
                                                  num_expected_hard)
            else:
                sampled_hard_inds = np.array(hard_neg_inds, dtype=np.int)
            num_expected_easy = num_expected - len(sampled_hard_inds)
            if len(easy_neg_inds) > num_expected_easy:
                sampled_easy_inds = random_choice(easy_neg_inds,
                                                  num_expected_easy)
            else:
                sampled_easy_inds = np.array(easy_neg_inds, dtype=np.int)
            sampled_inds = np.concatenate((sampled_easy_inds,
                                           sampled_hard_inds))
            if len(sampled_inds) < num_expected:
                num_extra = num_expected - len(sampled_inds)
                extra_inds = np.array(list(neg_set - set(sampled_inds)))
                if len(extra_inds) > num_extra:
                    extra_inds = random_choice(extra_inds, num_extra)
                sampled_inds = np.concatenate((sampled_inds, extra_inds))
            sampled_inds = torch.from_numpy(sampled_inds).long().to(
                assign_result.gt_inds.device)
            return sampled_inds

    def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None):
        """Sample positive and negative bboxes.

        This is a simple implementation of bbox sampling given candidates,
        assigning results and ground truth bboxes.

        1. Assign gt to each bbox.
        2. Add gt bboxes to the sampling pool (optional).
        3. Perform positive and negative sampling.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            bboxes (Tensor): Boxes to be sampled from.
            gt_bboxes (Tensor): Ground truth bboxes.
            gt_labels (Tensor, optional): Class labels of ground truth bboxes.

        Returns:
            :obj:`SamplingResult`: Sampling result.
        """
        bboxes = bboxes[:, :4]

        gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
        if self.add_gt_as_proposals:
            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_flags = torch.cat([
                bboxes.new_ones((gt_bboxes.shape[0], ), dtype=torch.uint8),
                gt_flags
            ])

        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self._sample_pos(assign_result, num_expected_pos)
        # We found that sampled indices have duplicated items occasionally.
        # (mab be a bug of PyTorch)
        pos_inds = pos_inds.unique()
        num_sampled_pos = pos_inds.numel()
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            num_neg_max = int(
                self.neg_pos_ub *
                num_sampled_pos) if num_sampled_pos > 0 else int(
                    self.neg_pos_ub)
            num_expected_neg = min(num_neg_max, num_expected_neg)
        neg_inds = self._sample_neg(assign_result, num_expected_neg)
        neg_inds = neg_inds.unique()

        return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                              assign_result, gt_flags)


class SamplingResult(object):

    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
                 gt_flags):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds
        self.pos_bboxes = bboxes[pos_inds]
        self.neg_bboxes = bboxes[neg_inds]
        self.pos_is_gt = gt_flags[pos_inds]

        self.num_gts = gt_bboxes.shape[0]
        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
        self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
        if assign_result.labels is not None:
            self.pos_gt_labels = assign_result.labels[pos_inds]
        else:
            self.pos_gt_labels = None

    @property
    def bboxes(self):
        return torch.cat([self.pos_bboxes, self.neg_bboxes])
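The deleted module contains exactly the functionality that the new samplers/ package splits apart: random_choice becomes RandomSampler.random_choice, the pos_balance_sampling branch becomes InstanceBalancedPosSampler, the neg_balance_thr branch becomes IoUBalancedNegSampler, and SamplingResult moves to samplers/sampling_result.py. Callers of the old assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg) presumably switch to building an assigner and a sampler explicitly, roughly along these lines. This is a sketch under the assumption that build_assigner and build_sampler helpers live in mmdet.core.bbox.assign_sampling; this page references build_sampler there but shows neither helper.

# hypothetical replacement for the deleted assign_and_sample() helper
from mmdet.core.bbox.assign_sampling import build_assigner, build_sampler  # assumed helpers

bbox_assigner = build_assigner(cfg.assigner)
bbox_sampler = build_sampler(cfg.sampler)
assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore,
                                     gt_labels)
sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes,
                                      gt_labels)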
(Diff listing is paginated: pages 1 and 2.)