Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
c6fde230
Commit
c6fde230
authored
Dec 11, 2018
by
pangjm
Browse files
Merge branch 'master' of github.com:open-mmlab/mmdetection
Conflicts: tools/train.py
parents
e74519bb
826a5613
Changes
64
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
473 additions
and
290 deletions
+473
-290
configs/rpn_r50_fpn_1x.py
configs/rpn_r50_fpn_1x.py
+3
-3
demo/coco_test_12510.jpg
demo/coco_test_12510.jpg
+0
-0
mmdet/core/anchor/anchor_target.py
mmdet/core/anchor/anchor_target.py
+70
-21
mmdet/core/bbox/__init__.py
mmdet/core/bbox/__init__.py
+11
-7
mmdet/core/bbox/assign_sampling.py
mmdet/core/bbox/assign_sampling.py
+35
-0
mmdet/core/bbox/assigners/__init__.py
mmdet/core/bbox/assigners/__init__.py
+5
-0
mmdet/core/bbox/assigners/assign_result.py
mmdet/core/bbox/assigners/assign_result.py
+19
-0
mmdet/core/bbox/assigners/base_assigner.py
mmdet/core/bbox/assigners/base_assigner.py
+8
-0
mmdet/core/bbox/assigners/max_iou_assigner.py
mmdet/core/bbox/assigners/max_iou_assigner.py
+15
-23
mmdet/core/bbox/samplers/__init__.py
mmdet/core/bbox/samplers/__init__.py
+13
-0
mmdet/core/bbox/samplers/base_sampler.py
mmdet/core/bbox/samplers/base_sampler.py
+64
-0
mmdet/core/bbox/samplers/combined_sampler.py
mmdet/core/bbox/samplers/combined_sampler.py
+14
-0
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
+41
-0
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
+62
-0
mmdet/core/bbox/samplers/pseudo_sampler.py
mmdet/core/bbox/samplers/pseudo_sampler.py
+26
-0
mmdet/core/bbox/samplers/random_sampler.py
mmdet/core/bbox/samplers/random_sampler.py
+55
-0
mmdet/core/bbox/samplers/sampling_result.py
mmdet/core/bbox/samplers/sampling_result.py
+24
-0
mmdet/core/bbox/sampling.py
mmdet/core/bbox/sampling.py
+0
-227
mmdet/core/post_processing/bbox_nms.py
mmdet/core/post_processing/bbox_nms.py
+7
-6
mmdet/core/post_processing/merge_augs.py
mmdet/core/post_processing/merge_augs.py
+1
-3
No files found.
configs/rpn_r50_fpn_1x.py
View file @
c6fde230
...
@@ -28,17 +28,17 @@ model = dict(
...
@@ -28,17 +28,17 @@ model = dict(
train_cfg
=
dict
(
train_cfg
=
dict
(
rpn
=
dict
(
rpn
=
dict
(
assigner
=
dict
(
assigner
=
dict
(
type
=
'MaxIoUAssigner'
,
pos_iou_thr
=
0.7
,
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
neg_iou_thr
=
0.3
,
min_pos_iou
=
0.3
,
min_pos_iou
=
0.3
,
ignore_iof_thr
=-
1
),
ignore_iof_thr
=-
1
),
sampler
=
dict
(
sampler
=
dict
(
type
=
'RandomSampler'
,
num
=
256
,
num
=
256
,
pos_fraction
=
0.5
,
pos_fraction
=
0.5
,
neg_pos_ub
=-
1
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
False
,
add_gt_as_proposals
=
False
),
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
allowed_border
=
0
,
allowed_border
=
0
,
pos_weight
=-
1
,
pos_weight
=-
1
,
smoothl1_beta
=
1
/
9.0
,
smoothl1_beta
=
1
/
9.0
,
...
...
demo/coco_test_12510.jpg
0 → 100644
View file @
c6fde230
179 KB
mmdet/core/anchor/anchor_target.py
View file @
c6fde230
import
torch
import
torch
from
..bbox
import
assign_and_sample
,
bbox2delta
from
..bbox
import
assign_and_sample
,
build_assigner
,
PseudoSampler
,
bbox2delta
from
..utils
import
multi_apply
from
..utils
import
multi_apply
def
anchor_target
(
anchor_list
,
valid_flag_list
,
gt_bboxes_list
,
img_metas
,
def
anchor_target
(
anchor_list
,
target_means
,
target_stds
,
cfg
):
valid_flag_list
,
gt_bboxes_list
,
img_metas
,
target_means
,
target_stds
,
cfg
,
gt_labels_list
=
None
,
cls_out_channels
=
1
,
sampling
=
True
):
"""Compute regression and classification targets for anchors.
"""Compute regression and classification targets for anchors.
Args:
Args:
...
@@ -32,28 +40,34 @@ def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
...
@@ -32,28 +40,34 @@ def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
valid_flag_list
[
i
]
=
torch
.
cat
(
valid_flag_list
[
i
])
valid_flag_list
[
i
]
=
torch
.
cat
(
valid_flag_list
[
i
])
# compute targets for each image
# compute targets for each image
means_replicas
=
[
target_means
for
_
in
range
(
num_imgs
)]
if
gt_labels_list
is
None
:
stds_replicas
=
[
target_stds
for
_
in
range
(
num_imgs
)]
gt_labels_list
=
[
None
for
_
in
range
(
num_imgs
)]
cfg_replicas
=
[
cfg
for
_
in
range
(
num_imgs
)]
(
all_labels
,
all_label_weights
,
all_bbox_targets
,
all_bbox_weights
,
(
all_labels
,
all_label_weights
,
all_bbox_targets
,
pos_inds_list
,
neg_inds_list
)
=
multi_apply
(
all_bbox_weights
,
pos_inds_list
,
neg_inds_list
)
=
multi_apply
(
anchor_target_single
,
anchor_target_single
,
anchor_list
,
valid_flag_list
,
gt_bboxes_list
,
anchor_list
,
img_metas
,
means_replicas
,
stds_replicas
,
cfg_replicas
)
valid_flag_list
,
gt_bboxes_list
,
gt_labels_list
,
img_metas
,
target_means
=
target_means
,
target_stds
=
target_stds
,
cfg
=
cfg
,
cls_out_channels
=
cls_out_channels
,
sampling
=
sampling
)
# no valid anchors
# no valid anchors
if
any
([
labels
is
None
for
labels
in
all_labels
]):
if
any
([
labels
is
None
for
labels
in
all_labels
]):
return
None
return
None
# sampled anchors of all images
# sampled anchors of all images
num_total_samples
=
sum
([
num_total_pos
=
sum
([
max
(
inds
.
numel
(),
1
)
for
inds
in
pos_inds_list
])
max
(
pos_inds
.
numel
()
+
neg_inds
.
numel
(),
1
)
num_total_neg
=
sum
([
max
(
inds
.
numel
(),
1
)
for
inds
in
neg_inds_list
])
for
pos_inds
,
neg_inds
in
zip
(
pos_inds_list
,
neg_inds_list
)
])
# split targets to a list w.r.t. multiple levels
# split targets to a list w.r.t. multiple levels
labels_list
=
images_to_levels
(
all_labels
,
num_level_anchors
)
labels_list
=
images_to_levels
(
all_labels
,
num_level_anchors
)
label_weights_list
=
images_to_levels
(
all_label_weights
,
num_level_anchors
)
label_weights_list
=
images_to_levels
(
all_label_weights
,
num_level_anchors
)
bbox_targets_list
=
images_to_levels
(
all_bbox_targets
,
num_level_anchors
)
bbox_targets_list
=
images_to_levels
(
all_bbox_targets
,
num_level_anchors
)
bbox_weights_list
=
images_to_levels
(
all_bbox_weights
,
num_level_anchors
)
bbox_weights_list
=
images_to_levels
(
all_bbox_weights
,
num_level_anchors
)
return
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
return
(
labels_list
,
label_weights_list
,
bbox_targets_list
,
bbox_weights_list
,
num_total_
samples
)
bbox_weights_list
,
num_total_
pos
,
num_total_neg
)
def
images_to_levels
(
target
,
num_level_anchors
):
def
images_to_levels
(
target
,
num_level_anchors
):
...
@@ -71,8 +85,16 @@ def images_to_levels(target, num_level_anchors):
...
@@ -71,8 +85,16 @@ def images_to_levels(target, num_level_anchors):
return
level_targets
return
level_targets
def
anchor_target_single
(
flat_anchors
,
valid_flags
,
gt_bboxes
,
img_meta
,
def
anchor_target_single
(
flat_anchors
,
target_means
,
target_stds
,
cfg
):
valid_flags
,
gt_bboxes
,
gt_labels
,
img_meta
,
target_means
,
target_stds
,
cfg
,
cls_out_channels
=
1
,
sampling
=
True
):
inside_flags
=
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
inside_flags
=
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
img_meta
[
'img_shape'
][:
2
],
img_meta
[
'img_shape'
][:
2
],
cfg
.
allowed_border
)
cfg
.
allowed_border
)
...
@@ -80,13 +102,23 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
...
@@ -80,13 +102,23 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
return
(
None
,
)
*
6
return
(
None
,
)
*
6
# assign gt and sample anchors
# assign gt and sample anchors
anchors
=
flat_anchors
[
inside_flags
,
:]
anchors
=
flat_anchors
[
inside_flags
,
:]
_
,
sampling_result
=
assign_and_sample
(
anchors
,
gt_bboxes
,
None
,
None
,
cfg
)
if
sampling
:
assign_result
,
sampling_result
=
assign_and_sample
(
anchors
,
gt_bboxes
,
None
,
None
,
cfg
)
else
:
bbox_assigner
=
build_assigner
(
cfg
.
assigner
)
assign_result
=
bbox_assigner
.
assign
(
anchors
,
gt_bboxes
,
None
,
gt_labels
)
bbox_sampler
=
PseudoSampler
()
sampling_result
=
bbox_sampler
.
sample
(
assign_result
,
anchors
,
gt_bboxes
)
num_valid_anchors
=
anchors
.
shape
[
0
]
num_valid_anchors
=
anchors
.
shape
[
0
]
bbox_targets
=
torch
.
zeros_like
(
anchors
)
bbox_targets
=
torch
.
zeros_like
(
anchors
)
bbox_weights
=
torch
.
zeros_like
(
anchors
)
bbox_weights
=
torch
.
zeros_like
(
anchors
)
labels
=
anchors
.
new_zeros
(
(
num_valid_anchors
,
)
)
labels
=
anchors
.
new_zeros
(
num_valid_anchors
,
dtype
=
torch
.
long
)
label_weights
=
anchors
.
new_zeros
(
(
num_valid_anchors
,
)
)
label_weights
=
anchors
.
new_zeros
(
num_valid_anchors
,
dtype
=
torch
.
float
)
pos_inds
=
sampling_result
.
pos_inds
pos_inds
=
sampling_result
.
pos_inds
neg_inds
=
sampling_result
.
neg_inds
neg_inds
=
sampling_result
.
neg_inds
...
@@ -96,7 +128,10 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
...
@@ -96,7 +128,10 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
target_means
,
target_stds
)
target_means
,
target_stds
)
bbox_targets
[
pos_inds
,
:]
=
pos_bbox_targets
bbox_targets
[
pos_inds
,
:]
=
pos_bbox_targets
bbox_weights
[
pos_inds
,
:]
=
1.0
bbox_weights
[
pos_inds
,
:]
=
1.0
labels
[
pos_inds
]
=
1
if
gt_labels
is
None
:
labels
[
pos_inds
]
=
1
else
:
labels
[
pos_inds
]
=
gt_labels
[
sampling_result
.
pos_assigned_gt_inds
]
if
cfg
.
pos_weight
<=
0
:
if
cfg
.
pos_weight
<=
0
:
label_weights
[
pos_inds
]
=
1.0
label_weights
[
pos_inds
]
=
1.0
else
:
else
:
...
@@ -108,6 +143,9 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
...
@@ -108,6 +143,9 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
num_total_anchors
=
flat_anchors
.
size
(
0
)
num_total_anchors
=
flat_anchors
.
size
(
0
)
labels
=
unmap
(
labels
,
num_total_anchors
,
inside_flags
)
labels
=
unmap
(
labels
,
num_total_anchors
,
inside_flags
)
label_weights
=
unmap
(
label_weights
,
num_total_anchors
,
inside_flags
)
label_weights
=
unmap
(
label_weights
,
num_total_anchors
,
inside_flags
)
if
cls_out_channels
>
1
:
labels
,
label_weights
=
expand_binary_labels
(
labels
,
label_weights
,
cls_out_channels
)
bbox_targets
=
unmap
(
bbox_targets
,
num_total_anchors
,
inside_flags
)
bbox_targets
=
unmap
(
bbox_targets
,
num_total_anchors
,
inside_flags
)
bbox_weights
=
unmap
(
bbox_weights
,
num_total_anchors
,
inside_flags
)
bbox_weights
=
unmap
(
bbox_weights
,
num_total_anchors
,
inside_flags
)
...
@@ -115,6 +153,17 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
...
@@ -115,6 +153,17 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
neg_inds
)
neg_inds
)
def
expand_binary_labels
(
labels
,
label_weights
,
cls_out_channels
):
bin_labels
=
labels
.
new_full
(
(
labels
.
size
(
0
),
cls_out_channels
),
0
,
dtype
=
torch
.
float32
)
inds
=
torch
.
nonzero
(
labels
>=
1
).
squeeze
()
if
inds
.
numel
()
>
0
:
bin_labels
[
inds
,
labels
[
inds
]
-
1
]
=
1
bin_label_weights
=
label_weights
.
view
(
-
1
,
1
).
expand
(
label_weights
.
size
(
0
),
cls_out_channels
)
return
bin_labels
,
bin_label_weights
def
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
img_shape
,
def
anchor_inside_flags
(
flat_anchors
,
valid_flags
,
img_shape
,
allowed_border
=
0
):
allowed_border
=
0
):
img_h
,
img_w
=
img_shape
[:
2
]
img_h
,
img_w
=
img_shape
[:
2
]
...
...
mmdet/core/bbox/__init__.py
View file @
c6fde230
from
.geometry
import
bbox_overlaps
from
.geometry
import
bbox_overlaps
from
.assignment
import
BBoxAssigner
,
AssignResult
from
.assigners
import
BaseAssigner
,
MaxIoUAssigner
,
AssignResult
from
.sampling
import
(
BBoxSampler
,
SamplingResult
,
assign_and_sample
,
from
.samplers
import
(
BaseSampler
,
PseudoSampler
,
RandomSampler
,
random_choice
)
InstanceBalancedPosSampler
,
IoUBalancedNegSampler
,
CombinedSampler
,
SamplingResult
)
from
.assign_sampling
import
build_assigner
,
build_sampler
,
assign_and_sample
from
.transforms
import
(
bbox2delta
,
delta2bbox
,
bbox_flip
,
bbox_mapping
,
from
.transforms
import
(
bbox2delta
,
delta2bbox
,
bbox_flip
,
bbox_mapping
,
bbox_mapping_back
,
bbox2roi
,
roi2bbox
,
bbox2result
)
bbox_mapping_back
,
bbox2roi
,
roi2bbox
,
bbox2result
)
from
.bbox_target
import
bbox_target
from
.bbox_target
import
bbox_target
__all__
=
[
__all__
=
[
'bbox_overlaps'
,
'BBoxAssigner'
,
'AssignResult'
,
'BBoxSampler'
,
'bbox_overlaps'
,
'BaseAssigner'
,
'MaxIoUAssigner'
,
'AssignResult'
,
'SamplingResult'
,
'assign_and_sample'
,
'random_choice'
,
'bbox2delta'
,
'BaseSampler'
,
'PseudoSampler'
,
'RandomSampler'
,
'delta2bbox'
,
'bbox_flip'
,
'bbox_mapping'
,
'bbox_mapping_back'
,
'bbox2roi'
,
'InstanceBalancedPosSampler'
,
'IoUBalancedNegSampler'
,
'CombinedSampler'
,
'roi2bbox'
,
'bbox2result'
,
'bbox_target'
'SamplingResult'
,
'build_assigner'
,
'build_sampler'
,
'assign_and_sample'
,
'bbox2delta'
,
'delta2bbox'
,
'bbox_flip'
,
'bbox_mapping'
,
'bbox_mapping_back'
,
'bbox2roi'
,
'roi2bbox'
,
'bbox2result'
,
'bbox_target'
]
]
mmdet/core/bbox/assign_sampling.py
0 → 100644
View file @
c6fde230
import
mmcv
from
.
import
assigners
,
samplers
def
build_assigner
(
cfg
,
default_args
=
None
):
if
isinstance
(
cfg
,
assigners
.
BaseAssigner
):
return
cfg
elif
isinstance
(
cfg
,
dict
):
return
mmcv
.
runner
.
obj_from_dict
(
cfg
,
assigners
,
default_args
=
default_args
)
else
:
raise
TypeError
(
'Invalid type {} for building a sampler'
.
format
(
type
(
cfg
)))
def
build_sampler
(
cfg
,
default_args
=
None
):
if
isinstance
(
cfg
,
samplers
.
BaseSampler
):
return
cfg
elif
isinstance
(
cfg
,
dict
):
return
mmcv
.
runner
.
obj_from_dict
(
cfg
,
samplers
,
default_args
=
default_args
)
else
:
raise
TypeError
(
'Invalid type {} for building a sampler'
.
format
(
type
(
cfg
)))
def
assign_and_sample
(
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
):
bbox_assigner
=
build_assigner
(
cfg
.
assigner
)
bbox_sampler
=
build_sampler
(
cfg
.
sampler
)
assign_result
=
bbox_assigner
.
assign
(
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
)
sampling_result
=
bbox_sampler
.
sample
(
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
)
return
assign_result
,
sampling_result
mmdet/core/bbox/assigners/__init__.py
0 → 100644
View file @
c6fde230
from
.base_assigner
import
BaseAssigner
from
.max_iou_assigner
import
MaxIoUAssigner
from
.assign_result
import
AssignResult
__all__
=
[
'BaseAssigner'
,
'MaxIoUAssigner'
,
'AssignResult'
]
mmdet/core/bbox/assigners/assign_result.py
0 → 100644
View file @
c6fde230
import
torch
class
AssignResult
(
object
):
def
__init__
(
self
,
num_gts
,
gt_inds
,
max_overlaps
,
labels
=
None
):
self
.
num_gts
=
num_gts
self
.
gt_inds
=
gt_inds
self
.
max_overlaps
=
max_overlaps
self
.
labels
=
labels
def
add_gt_
(
self
,
gt_labels
):
self_inds
=
torch
.
arange
(
1
,
len
(
gt_labels
)
+
1
,
dtype
=
torch
.
long
,
device
=
gt_labels
.
device
)
self
.
gt_inds
=
torch
.
cat
([
self_inds
,
self
.
gt_inds
])
self
.
max_overlaps
=
torch
.
cat
(
[
self
.
max_overlaps
.
new_ones
(
self
.
num_gts
),
self
.
max_overlaps
])
if
self
.
labels
is
not
None
:
self
.
labels
=
torch
.
cat
([
gt_labels
,
self
.
labels
])
mmdet/core/bbox/assigners/base_assigner.py
0 → 100644
View file @
c6fde230
from
abc
import
ABCMeta
,
abstractmethod
class
BaseAssigner
(
metaclass
=
ABCMeta
):
@
abstractmethod
def
assign
(
self
,
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
=
None
,
gt_labels
=
None
):
pass
mmdet/core/bbox/assign
ment
.py
→
mmdet/core/bbox/assign
ers/max_iou_assigner
.py
View file @
c6fde230
import
torch
import
torch
from
.geometry
import
bbox_overlaps
from
.base_assigner
import
BaseAssigner
from
.assign_result
import
AssignResult
from
..geometry
import
bbox_overlaps
class
BBox
Assigner
(
object
):
class
MaxIoU
Assigner
(
BaseAssigner
):
"""Assign a corresponding gt bbox or background to each bbox.
"""Assign a corresponding gt bbox or background to each bbox.
Each proposals will be assigned with `-1`, `0`, or a positive integer
Each proposals will be assigned with `-1`, `0`, or a positive integer
...
@@ -17,8 +19,10 @@ class BBoxAssigner(object):
...
@@ -17,8 +19,10 @@ class BBoxAssigner(object):
pos_iou_thr (float): IoU threshold for positive bboxes.
pos_iou_thr (float): IoU threshold for positive bboxes.
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
min_pos_iou (float): Minimum iou for a bbox to be considered as a
min_pos_iou (float): Minimum iou for a bbox to be considered as a
positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
positive bbox. Positive samples can have smaller IoU than
it is usually set as pos_iou_thr
pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
gt_max_assign_all (bool): Whether to assign all bboxes with the same
highest overlap with some gt to that gt.
ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
`gt_bboxes_ignore` is specified). Negative values mean not
`gt_bboxes_ignore` is specified). Negative values mean not
ignoring any bboxes.
ignoring any bboxes.
...
@@ -28,10 +32,12 @@ class BBoxAssigner(object):
...
@@ -28,10 +32,12 @@ class BBoxAssigner(object):
pos_iou_thr
,
pos_iou_thr
,
neg_iou_thr
,
neg_iou_thr
,
min_pos_iou
=
.
0
,
min_pos_iou
=
.
0
,
gt_max_assign_all
=
True
,
ignore_iof_thr
=-
1
):
ignore_iof_thr
=-
1
):
self
.
pos_iou_thr
=
pos_iou_thr
self
.
pos_iou_thr
=
pos_iou_thr
self
.
neg_iou_thr
=
neg_iou_thr
self
.
neg_iou_thr
=
neg_iou_thr
self
.
min_pos_iou
=
min_pos_iou
self
.
min_pos_iou
=
min_pos_iou
self
.
gt_max_assign_all
=
gt_max_assign_all
self
.
ignore_iof_thr
=
ignore_iof_thr
self
.
ignore_iof_thr
=
ignore_iof_thr
def
assign
(
self
,
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
=
None
,
gt_labels
=
None
):
def
assign
(
self
,
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
=
None
,
gt_labels
=
None
):
...
@@ -122,7 +128,11 @@ class BBoxAssigner(object):
...
@@ -122,7 +128,11 @@ class BBoxAssigner(object):
# 4. assign fg: for each gt, proposals with highest IoU
# 4. assign fg: for each gt, proposals with highest IoU
for
i
in
range
(
num_gts
):
for
i
in
range
(
num_gts
):
if
gt_max_overlaps
[
i
]
>=
self
.
min_pos_iou
:
if
gt_max_overlaps
[
i
]
>=
self
.
min_pos_iou
:
assigned_gt_inds
[
overlaps
[:,
i
]
==
gt_max_overlaps
[
i
]]
=
i
+
1
if
self
.
gt_max_assign_all
:
max_iou_inds
=
overlaps
[:,
i
]
==
gt_max_overlaps
[
i
]
assigned_gt_inds
[
max_iou_inds
]
=
i
+
1
else
:
assigned_gt_inds
[
gt_argmax_overlaps
[
i
]]
=
i
+
1
if
gt_labels
is
not
None
:
if
gt_labels
is
not
None
:
assigned_labels
=
assigned_gt_inds
.
new_zeros
((
num_bboxes
,
))
assigned_labels
=
assigned_gt_inds
.
new_zeros
((
num_bboxes
,
))
...
@@ -135,21 +145,3 @@ class BBoxAssigner(object):
...
@@ -135,21 +145,3 @@ class BBoxAssigner(object):
return
AssignResult
(
return
AssignResult
(
num_gts
,
assigned_gt_inds
,
max_overlaps
,
labels
=
assigned_labels
)
num_gts
,
assigned_gt_inds
,
max_overlaps
,
labels
=
assigned_labels
)
class
AssignResult
(
object
):
def
__init__
(
self
,
num_gts
,
gt_inds
,
max_overlaps
,
labels
=
None
):
self
.
num_gts
=
num_gts
self
.
gt_inds
=
gt_inds
self
.
max_overlaps
=
max_overlaps
self
.
labels
=
labels
def
add_gt_
(
self
,
gt_labels
):
self_inds
=
torch
.
arange
(
1
,
len
(
gt_labels
)
+
1
,
dtype
=
torch
.
long
,
device
=
gt_labels
.
device
)
self
.
gt_inds
=
torch
.
cat
([
self_inds
,
self
.
gt_inds
])
self
.
max_overlaps
=
torch
.
cat
(
[
self
.
max_overlaps
.
new_ones
(
self
.
num_gts
),
self
.
max_overlaps
])
if
self
.
labels
is
not
None
:
self
.
labels
=
torch
.
cat
([
gt_labels
,
self
.
labels
])
mmdet/core/bbox/samplers/__init__.py
0 → 100644
View file @
c6fde230
from
.base_sampler
import
BaseSampler
from
.pseudo_sampler
import
PseudoSampler
from
.random_sampler
import
RandomSampler
from
.instance_balanced_pos_sampler
import
InstanceBalancedPosSampler
from
.iou_balanced_neg_sampler
import
IoUBalancedNegSampler
from
.combined_sampler
import
CombinedSampler
from
.sampling_result
import
SamplingResult
__all__
=
[
'BaseSampler'
,
'PseudoSampler'
,
'RandomSampler'
,
'InstanceBalancedPosSampler'
,
'IoUBalancedNegSampler'
,
'CombinedSampler'
,
'SamplingResult'
]
mmdet/core/bbox/samplers/base_sampler.py
0 → 100644
View file @
c6fde230
from
abc
import
ABCMeta
,
abstractmethod
import
torch
from
.sampling_result
import
SamplingResult
class
BaseSampler
(
metaclass
=
ABCMeta
):
def
__init__
(
self
):
self
.
pos_sampler
=
self
self
.
neg_sampler
=
self
@
abstractmethod
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
pass
@
abstractmethod
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
pass
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
=
None
):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
assigning results and ground truth bboxes.
Args:
assign_result (:obj:`AssignResult`): Bbox assigning results.
bboxes (Tensor): Boxes to be sampled from.
gt_bboxes (Tensor): Ground truth bboxes.
gt_labels (Tensor, optional): Class labels of ground truth bboxes.
Returns:
:obj:`SamplingResult`: Sampling result.
"""
bboxes
=
bboxes
[:,
:
4
]
gt_flags
=
bboxes
.
new_zeros
((
bboxes
.
shape
[
0
],
),
dtype
=
torch
.
uint8
)
if
self
.
add_gt_as_proposals
:
bboxes
=
torch
.
cat
([
gt_bboxes
,
bboxes
],
dim
=
0
)
assign_result
.
add_gt_
(
gt_labels
)
gt_ones
=
bboxes
.
new_ones
(
gt_bboxes
.
shape
[
0
],
dtype
=
torch
.
uint8
)
gt_flags
=
torch
.
cat
([
gt_ones
,
gt_flags
])
num_expected_pos
=
int
(
self
.
num
*
self
.
pos_fraction
)
pos_inds
=
self
.
pos_sampler
.
_sample_pos
(
assign_result
,
num_expected_pos
)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds
=
pos_inds
.
unique
()
num_sampled_pos
=
pos_inds
.
numel
()
num_expected_neg
=
self
.
num
-
num_sampled_pos
if
self
.
neg_pos_ub
>=
0
:
_pos
=
max
(
1
,
num_sampled_pos
)
neg_upper_bound
=
int
(
self
.
neg_pos_ub
*
_pos
)
if
num_expected_neg
>
neg_upper_bound
:
num_expected_neg
=
neg_upper_bound
neg_inds
=
self
.
neg_sampler
.
_sample_neg
(
assign_result
,
num_expected_neg
)
neg_inds
=
neg_inds
.
unique
()
return
SamplingResult
(
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
assign_result
,
gt_flags
)
mmdet/core/bbox/samplers/combined_sampler.py
0 → 100644
View file @
c6fde230
from
.random_sampler
import
RandomSampler
from
..assign_sampling
import
build_sampler
class
CombinedSampler
(
RandomSampler
):
def
__init__
(
self
,
num
,
pos_fraction
,
pos_sampler
,
neg_sampler
,
**
kwargs
):
super
(
CombinedSampler
,
self
).
__init__
(
num
,
pos_fraction
,
**
kwargs
)
default_args
=
dict
(
num
=
num
,
pos_fraction
=
pos_fraction
)
default_args
.
update
(
kwargs
)
self
.
pos_sampler
=
build_sampler
(
pos_sampler
,
default_args
=
default_args
)
self
.
neg_sampler
=
build_sampler
(
neg_sampler
,
default_args
=
default_args
)
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
0 → 100644
View file @
c6fde230
import
numpy
as
np
import
torch
from
.random_sampler
import
RandomSampler
class
InstanceBalancedPosSampler
(
RandomSampler
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
if
pos_inds
.
numel
()
<=
num_expected
:
return
pos_inds
else
:
unique_gt_inds
=
assign_result
.
gt_inds
[
pos_inds
].
unique
()
num_gts
=
len
(
unique_gt_inds
)
num_per_gt
=
int
(
round
(
num_expected
/
float
(
num_gts
))
+
1
)
sampled_inds
=
[]
for
i
in
unique_gt_inds
:
inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
i
.
item
())
if
inds
.
numel
()
!=
0
:
inds
=
inds
.
squeeze
(
1
)
else
:
continue
if
len
(
inds
)
>
num_per_gt
:
inds
=
self
.
random_choice
(
inds
,
num_per_gt
)
sampled_inds
.
append
(
inds
)
sampled_inds
=
torch
.
cat
(
sampled_inds
)
if
len
(
sampled_inds
)
<
num_expected
:
num_extra
=
num_expected
-
len
(
sampled_inds
)
extra_inds
=
np
.
array
(
list
(
set
(
pos_inds
.
cpu
())
-
set
(
sampled_inds
.
cpu
())))
if
len
(
extra_inds
)
>
num_extra
:
extra_inds
=
self
.
random_choice
(
extra_inds
,
num_extra
)
extra_inds
=
torch
.
from_numpy
(
extra_inds
).
to
(
assign_result
.
gt_inds
.
device
).
long
()
sampled_inds
=
torch
.
cat
([
sampled_inds
,
extra_inds
])
elif
len
(
sampled_inds
)
>
num_expected
:
sampled_inds
=
self
.
random_choice
(
sampled_inds
,
num_expected
)
return
sampled_inds
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
0 → 100644
View file @
c6fde230
import
numpy
as
np
import
torch
from
.random_sampler
import
RandomSampler
class
IoUBalancedNegSampler
(
RandomSampler
):
def
__init__
(
self
,
num
,
pos_fraction
,
hard_thr
=
0.1
,
hard_fraction
=
0.5
,
**
kwargs
):
super
(
IoUBalancedNegSampler
,
self
).
__init__
(
num
,
pos_fraction
,
**
kwargs
)
assert
hard_thr
>
0
assert
0
<
hard_fraction
<
1
self
.
hard_thr
=
hard_thr
self
.
hard_fraction
=
hard_fraction
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
if
len
(
neg_inds
)
<=
num_expected
:
return
neg_inds
else
:
max_overlaps
=
assign_result
.
max_overlaps
.
cpu
().
numpy
()
# balance sampling for negative samples
neg_set
=
set
(
neg_inds
.
cpu
().
numpy
())
easy_set
=
set
(
np
.
where
(
np
.
logical_and
(
max_overlaps
>=
0
,
max_overlaps
<
self
.
hard_thr
))[
0
])
hard_set
=
set
(
np
.
where
(
max_overlaps
>=
self
.
hard_thr
)[
0
])
easy_neg_inds
=
list
(
easy_set
&
neg_set
)
hard_neg_inds
=
list
(
hard_set
&
neg_set
)
num_expected_hard
=
int
(
num_expected
*
self
.
hard_fraction
)
if
len
(
hard_neg_inds
)
>
num_expected_hard
:
sampled_hard_inds
=
self
.
random_choice
(
hard_neg_inds
,
num_expected_hard
)
else
:
sampled_hard_inds
=
np
.
array
(
hard_neg_inds
,
dtype
=
np
.
int
)
num_expected_easy
=
num_expected
-
len
(
sampled_hard_inds
)
if
len
(
easy_neg_inds
)
>
num_expected_easy
:
sampled_easy_inds
=
self
.
random_choice
(
easy_neg_inds
,
num_expected_easy
)
else
:
sampled_easy_inds
=
np
.
array
(
easy_neg_inds
,
dtype
=
np
.
int
)
sampled_inds
=
np
.
concatenate
((
sampled_easy_inds
,
sampled_hard_inds
))
if
len
(
sampled_inds
)
<
num_expected
:
num_extra
=
num_expected
-
len
(
sampled_inds
)
extra_inds
=
np
.
array
(
list
(
neg_set
-
set
(
sampled_inds
)))
if
len
(
extra_inds
)
>
num_extra
:
extra_inds
=
self
.
random_choice
(
extra_inds
,
num_extra
)
sampled_inds
=
np
.
concatenate
((
sampled_inds
,
extra_inds
))
sampled_inds
=
torch
.
from_numpy
(
sampled_inds
).
long
().
to
(
assign_result
.
gt_inds
.
device
)
return
sampled_inds
mmdet/core/bbox/samplers/pseudo_sampler.py
0 → 100644
View file @
c6fde230
import
torch
from
.base_sampler
import
BaseSampler
from
.sampling_result
import
SamplingResult
class
PseudoSampler
(
BaseSampler
):
def
__init__
(
self
):
pass
def
_sample_pos
(
self
):
raise
NotImplementedError
def
_sample_neg
(
self
):
raise
NotImplementedError
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
):
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
).
squeeze
(
-
1
).
unique
()
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
).
squeeze
(
-
1
).
unique
()
gt_flags
=
bboxes
.
new_zeros
(
bboxes
.
shape
[
0
],
dtype
=
torch
.
uint8
)
sampling_result
=
SamplingResult
(
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
assign_result
,
gt_flags
)
return
sampling_result
mmdet/core/bbox/samplers/random_sampler.py
0 → 100644
View file @
c6fde230
import
numpy
as
np
import
torch
from
.base_sampler
import
BaseSampler
class
RandomSampler
(
BaseSampler
):
def
__init__
(
self
,
num
,
pos_fraction
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
):
super
(
RandomSampler
,
self
).
__init__
()
self
.
num
=
num
self
.
pos_fraction
=
pos_fraction
self
.
neg_pos_ub
=
neg_pos_ub
self
.
add_gt_as_proposals
=
add_gt_as_proposals
@
staticmethod
def
random_choice
(
gallery
,
num
):
"""Random select some elements from the gallery.
It seems that Pytorch's implementation is slower than numpy so we use
numpy to randperm the indices.
"""
assert
len
(
gallery
)
>=
num
if
isinstance
(
gallery
,
list
):
gallery
=
np
.
array
(
gallery
)
cands
=
np
.
arange
(
len
(
gallery
))
np
.
random
.
shuffle
(
cands
)
rand_inds
=
cands
[:
num
]
if
not
isinstance
(
gallery
,
np
.
ndarray
):
rand_inds
=
torch
.
from_numpy
(
rand_inds
).
long
().
to
(
gallery
.
device
)
return
gallery
[
rand_inds
]
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
"""Randomly sample some positive samples."""
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
if
pos_inds
.
numel
()
<=
num_expected
:
return
pos_inds
else
:
return
self
.
random_choice
(
pos_inds
,
num_expected
)
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
"""Randomly sample some negative samples."""
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
if
len
(
neg_inds
)
<=
num_expected
:
return
neg_inds
else
:
return
self
.
random_choice
(
neg_inds
,
num_expected
)
mmdet/core/bbox/samplers/sampling_result.py
0 → 100644
View file @
c6fde230
import
torch
class
SamplingResult
(
object
):
def
__init__
(
self
,
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
assign_result
,
gt_flags
):
self
.
pos_inds
=
pos_inds
self
.
neg_inds
=
neg_inds
self
.
pos_bboxes
=
bboxes
[
pos_inds
]
self
.
neg_bboxes
=
bboxes
[
neg_inds
]
self
.
pos_is_gt
=
gt_flags
[
pos_inds
]
self
.
num_gts
=
gt_bboxes
.
shape
[
0
]
self
.
pos_assigned_gt_inds
=
assign_result
.
gt_inds
[
pos_inds
]
-
1
self
.
pos_gt_bboxes
=
gt_bboxes
[
self
.
pos_assigned_gt_inds
,
:]
if
assign_result
.
labels
is
not
None
:
self
.
pos_gt_labels
=
assign_result
.
labels
[
pos_inds
]
else
:
self
.
pos_gt_labels
=
None
@
property
def
bboxes
(
self
):
return
torch
.
cat
([
self
.
pos_bboxes
,
self
.
neg_bboxes
])
mmdet/core/bbox/sampling.py
deleted
100644 → 0
View file @
e74519bb
import
numpy
as
np
import
torch
from
.assignment
import
BBoxAssigner
def
random_choice
(
gallery
,
num
):
"""Random select some elements from the gallery.
It seems that Pytorch's implementation is slower than numpy so we use numpy
to randperm the indices.
"""
assert
len
(
gallery
)
>=
num
if
isinstance
(
gallery
,
list
):
gallery
=
np
.
array
(
gallery
)
cands
=
np
.
arange
(
len
(
gallery
))
np
.
random
.
shuffle
(
cands
)
rand_inds
=
cands
[:
num
]
if
not
isinstance
(
gallery
,
np
.
ndarray
):
rand_inds
=
torch
.
from_numpy
(
rand_inds
).
long
().
to
(
gallery
.
device
)
return
gallery
[
rand_inds
]
def
assign_and_sample
(
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
):
bbox_assigner
=
BBoxAssigner
(
**
cfg
.
assigner
)
bbox_sampler
=
BBoxSampler
(
**
cfg
.
sampler
)
assign_result
=
bbox_assigner
.
assign
(
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
)
sampling_result
=
bbox_sampler
.
sample
(
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
)
return
assign_result
,
sampling_result
class
BBoxSampler
(
object
):
"""Sample positive and negative bboxes given assigned results.
Args:
pos_fraction (float): Positive sample fraction.
neg_pos_ub (float): Negative/Positive upper bound.
pos_balance_sampling (bool): Whether to sample positive samples around
each gt bbox evenly.
neg_balance_thr (float, optional): IoU threshold for simple/hard
negative balance sampling.
neg_hard_fraction (float, optional): Fraction of hard negative samples
for negative balance sampling.
"""
def
__init__
(
self
,
num
,
pos_fraction
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
,
neg_hard_fraction
=
0.5
):
self
.
num
=
num
self
.
pos_fraction
=
pos_fraction
self
.
neg_pos_ub
=
neg_pos_ub
self
.
add_gt_as_proposals
=
add_gt_as_proposals
self
.
pos_balance_sampling
=
pos_balance_sampling
self
.
neg_balance_thr
=
neg_balance_thr
self
.
neg_hard_fraction
=
neg_hard_fraction
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
"""Balance sampling for positive bboxes/anchors.
1. calculate average positive num for each gt: num_per_gt
2. sample at most num_per_gt positives for each gt
3. random sampling from rest anchors if not enough fg
"""
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
if
pos_inds
.
numel
()
<=
num_expected
:
return
pos_inds
elif
not
self
.
pos_balance_sampling
:
return
random_choice
(
pos_inds
,
num_expected
)
else
:
unique_gt_inds
=
torch
.
unique
(
assign_result
.
gt_inds
[
pos_inds
].
cpu
())
num_gts
=
len
(
unique_gt_inds
)
num_per_gt
=
int
(
round
(
num_expected
/
float
(
num_gts
))
+
1
)
sampled_inds
=
[]
for
i
in
unique_gt_inds
:
inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
i
.
item
())
if
inds
.
numel
()
!=
0
:
inds
=
inds
.
squeeze
(
1
)
else
:
continue
if
len
(
inds
)
>
num_per_gt
:
inds
=
random_choice
(
inds
,
num_per_gt
)
sampled_inds
.
append
(
inds
)
sampled_inds
=
torch
.
cat
(
sampled_inds
)
if
len
(
sampled_inds
)
<
num_expected
:
num_extra
=
num_expected
-
len
(
sampled_inds
)
extra_inds
=
np
.
array
(
list
(
set
(
pos_inds
.
cpu
())
-
set
(
sampled_inds
.
cpu
())))
if
len
(
extra_inds
)
>
num_extra
:
extra_inds
=
random_choice
(
extra_inds
,
num_extra
)
extra_inds
=
torch
.
from_numpy
(
extra_inds
).
to
(
assign_result
.
gt_inds
.
device
).
long
()
sampled_inds
=
torch
.
cat
([
sampled_inds
,
extra_inds
])
elif
len
(
sampled_inds
)
>
num_expected
:
sampled_inds
=
random_choice
(
sampled_inds
,
num_expected
)
return
sampled_inds
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
"""Balance sampling for negative bboxes/anchors.
Negative samples are split into 2 set: hard (balance_thr <= iou <
neg_iou_thr) and easy (iou < balance_thr). The sampling ratio is
controlled by `hard_fraction`.
"""
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
if
len
(
neg_inds
)
<=
num_expected
:
return
neg_inds
elif
self
.
neg_balance_thr
<=
0
:
# uniform sampling among all negative samples
return
random_choice
(
neg_inds
,
num_expected
)
else
:
max_overlaps
=
assign_result
.
max_overlaps
.
cpu
().
numpy
()
# balance sampling for negative samples
neg_set
=
set
(
neg_inds
.
cpu
().
numpy
())
easy_set
=
set
(
np
.
where
(
np
.
logical_and
(
max_overlaps
>=
0
,
max_overlaps
<
self
.
neg_balance_thr
))[
0
])
hard_set
=
set
(
np
.
where
(
max_overlaps
>=
self
.
neg_balance_thr
)[
0
])
easy_neg_inds
=
list
(
easy_set
&
neg_set
)
hard_neg_inds
=
list
(
hard_set
&
neg_set
)
num_expected_hard
=
int
(
num_expected
*
self
.
neg_hard_fraction
)
if
len
(
hard_neg_inds
)
>
num_expected_hard
:
sampled_hard_inds
=
random_choice
(
hard_neg_inds
,
num_expected_hard
)
else
:
sampled_hard_inds
=
np
.
array
(
hard_neg_inds
,
dtype
=
np
.
int
)
num_expected_easy
=
num_expected
-
len
(
sampled_hard_inds
)
if
len
(
easy_neg_inds
)
>
num_expected_easy
:
sampled_easy_inds
=
random_choice
(
easy_neg_inds
,
num_expected_easy
)
else
:
sampled_easy_inds
=
np
.
array
(
easy_neg_inds
,
dtype
=
np
.
int
)
sampled_inds
=
np
.
concatenate
((
sampled_easy_inds
,
sampled_hard_inds
))
if
len
(
sampled_inds
)
<
num_expected
:
num_extra
=
num_expected
-
len
(
sampled_inds
)
extra_inds
=
np
.
array
(
list
(
neg_set
-
set
(
sampled_inds
)))
if
len
(
extra_inds
)
>
num_extra
:
extra_inds
=
random_choice
(
extra_inds
,
num_extra
)
sampled_inds
=
np
.
concatenate
((
sampled_inds
,
extra_inds
))
sampled_inds
=
torch
.
from_numpy
(
sampled_inds
).
long
().
to
(
assign_result
.
gt_inds
.
device
)
return
sampled_inds
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
=
None
):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
assigning results and ground truth bboxes.
1. Assign gt to each bbox.
2. Add gt bboxes to the sampling pool (optional).
3. Perform positive and negative sampling.
Args:
assign_result (:obj:`AssignResult`): Bbox assigning results.
bboxes (Tensor): Boxes to be sampled from.
gt_bboxes (Tensor): Ground truth bboxes.
gt_labels (Tensor, optional): Class labels of ground truth bboxes.
Returns:
:obj:`SamplingResult`: Sampling result.
"""
bboxes
=
bboxes
[:,
:
4
]
gt_flags
=
bboxes
.
new_zeros
((
bboxes
.
shape
[
0
],
),
dtype
=
torch
.
uint8
)
if
self
.
add_gt_as_proposals
:
bboxes
=
torch
.
cat
([
gt_bboxes
,
bboxes
],
dim
=
0
)
assign_result
.
add_gt_
(
gt_labels
)
gt_flags
=
torch
.
cat
([
bboxes
.
new_ones
((
gt_bboxes
.
shape
[
0
],
),
dtype
=
torch
.
uint8
),
gt_flags
])
num_expected_pos
=
int
(
self
.
num
*
self
.
pos_fraction
)
pos_inds
=
self
.
_sample_pos
(
assign_result
,
num_expected_pos
)
# We found that sampled indices have duplicated items occasionally.
# (mab be a bug of PyTorch)
pos_inds
=
pos_inds
.
unique
()
num_sampled_pos
=
pos_inds
.
numel
()
num_expected_neg
=
self
.
num
-
num_sampled_pos
if
self
.
neg_pos_ub
>=
0
:
num_neg_max
=
int
(
self
.
neg_pos_ub
*
num_sampled_pos
)
if
num_sampled_pos
>
0
else
int
(
self
.
neg_pos_ub
)
num_expected_neg
=
min
(
num_neg_max
,
num_expected_neg
)
neg_inds
=
self
.
_sample_neg
(
assign_result
,
num_expected_neg
)
neg_inds
=
neg_inds
.
unique
()
return
SamplingResult
(
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
assign_result
,
gt_flags
)
class
SamplingResult
(
object
):
def
__init__
(
self
,
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
assign_result
,
gt_flags
):
self
.
pos_inds
=
pos_inds
self
.
neg_inds
=
neg_inds
self
.
pos_bboxes
=
bboxes
[
pos_inds
]
self
.
neg_bboxes
=
bboxes
[
neg_inds
]
self
.
pos_is_gt
=
gt_flags
[
pos_inds
]
self
.
num_gts
=
gt_bboxes
.
shape
[
0
]
self
.
pos_assigned_gt_inds
=
assign_result
.
gt_inds
[
pos_inds
]
-
1
self
.
pos_gt_bboxes
=
gt_bboxes
[
self
.
pos_assigned_gt_inds
,
:]
if
assign_result
.
labels
is
not
None
:
self
.
pos_gt_labels
=
assign_result
.
labels
[
pos_inds
]
else
:
self
.
pos_gt_labels
=
None
@
property
def
bboxes
(
self
):
return
torch
.
cat
([
self
.
pos_bboxes
,
self
.
neg_bboxes
])
mmdet/core/post_processing/bbox_nms.py
View file @
c6fde230
import
torch
import
torch
from
mmdet.ops
import
nms
from
mmdet.ops
.nms
import
nms
_wrapper
def
multiclass_nms
(
multi_bboxes
,
multi_scores
,
score_thr
,
nms_
thr
,
max_num
=-
1
):
def
multiclass_nms
(
multi_bboxes
,
multi_scores
,
score_thr
,
nms_
cfg
,
max_num
=-
1
):
"""NMS for multi-class bboxes.
"""NMS for multi-class bboxes.
Args:
Args:
...
@@ -21,6 +21,9 @@ def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
...
@@ -21,6 +21,9 @@ def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
"""
"""
num_classes
=
multi_scores
.
shape
[
1
]
num_classes
=
multi_scores
.
shape
[
1
]
bboxes
,
labels
=
[],
[]
bboxes
,
labels
=
[],
[]
nms_cfg_
=
nms_cfg
.
copy
()
nms_type
=
nms_cfg_
.
pop
(
'type'
,
'nms'
)
nms_op
=
getattr
(
nms_wrapper
,
nms_type
)
for
i
in
range
(
1
,
num_classes
):
for
i
in
range
(
1
,
num_classes
):
cls_inds
=
multi_scores
[:,
i
]
>
score_thr
cls_inds
=
multi_scores
[:,
i
]
>
score_thr
if
not
cls_inds
.
any
():
if
not
cls_inds
.
any
():
...
@@ -32,11 +35,9 @@ def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
...
@@ -32,11 +35,9 @@ def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
_bboxes
=
multi_bboxes
[
cls_inds
,
i
*
4
:(
i
+
1
)
*
4
]
_bboxes
=
multi_bboxes
[
cls_inds
,
i
*
4
:(
i
+
1
)
*
4
]
_scores
=
multi_scores
[
cls_inds
,
i
]
_scores
=
multi_scores
[
cls_inds
,
i
]
cls_dets
=
torch
.
cat
([
_bboxes
,
_scores
[:,
None
]],
dim
=
1
)
cls_dets
=
torch
.
cat
([
_bboxes
,
_scores
[:,
None
]],
dim
=
1
)
# perform nms
cls_dets
,
_
=
nms_op
(
cls_dets
,
**
nms_cfg_
)
nms_keep
=
nms
(
cls_dets
,
nms_thr
)
cls_dets
=
cls_dets
[
nms_keep
,
:]
cls_labels
=
multi_bboxes
.
new_full
(
cls_labels
=
multi_bboxes
.
new_full
(
(
len
(
nms_keep
)
,
),
i
-
1
,
dtype
=
torch
.
long
)
(
cls_dets
.
shape
[
0
]
,
),
i
-
1
,
dtype
=
torch
.
long
)
bboxes
.
append
(
cls_dets
)
bboxes
.
append
(
cls_dets
)
labels
.
append
(
cls_labels
)
labels
.
append
(
cls_labels
)
if
bboxes
:
if
bboxes
:
...
...
mmdet/core/post_processing/merge_augs.py
View file @
c6fde230
...
@@ -29,9 +29,7 @@ def merge_aug_proposals(aug_proposals, img_metas, rpn_test_cfg):
...
@@ -29,9 +29,7 @@ def merge_aug_proposals(aug_proposals, img_metas, rpn_test_cfg):
scale_factor
,
flip
)
scale_factor
,
flip
)
recovered_proposals
.
append
(
_proposals
)
recovered_proposals
.
append
(
_proposals
)
aug_proposals
=
torch
.
cat
(
recovered_proposals
,
dim
=
0
)
aug_proposals
=
torch
.
cat
(
recovered_proposals
,
dim
=
0
)
nms_keep
=
nms
(
aug_proposals
,
rpn_test_cfg
.
nms_thr
,
merged_proposals
,
_
=
nms
(
aug_proposals
,
rpn_test_cfg
.
nms_thr
)
aug_proposals
.
get_device
())
merged_proposals
=
aug_proposals
[
nms_keep
,
:]
scores
=
merged_proposals
[:,
4
]
scores
=
merged_proposals
[:,
4
]
_
,
order
=
scores
.
sort
(
0
,
descending
=
True
)
_
,
order
=
scores
.
sort
(
0
,
descending
=
True
)
num
=
min
(
rpn_test_cfg
.
max_num
,
merged_proposals
.
shape
[
0
])
num
=
min
(
rpn_test_cfg
.
max_num
,
merged_proposals
.
shape
[
0
])
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment