Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
bac11303
Commit
bac11303
authored
Oct 17, 2018
by
Kai Chen
Browse files
add BBoxAssigner and BBoxSampler
parent
f3768bcd
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
522 additions
and
468 deletions
+522
-468
configs/fast_mask_rcnn_r50_fpn_1x.py
configs/fast_mask_rcnn_r50_fpn_1x.py
+12
-10
configs/fast_rcnn_r50_fpn_1x.py
configs/fast_rcnn_r50_fpn_1x.py
+12
-10
configs/faster_rcnn_r50_fpn_1x.py
configs/faster_rcnn_r50_fpn_1x.py
+24
-19
configs/mask_rcnn_r50_fpn_1x.py
configs/mask_rcnn_r50_fpn_1x.py
+24
-19
configs/rpn_r50_fpn_1x.py
configs/rpn_r50_fpn_1x.py
+12
-9
mmdet/core/anchor/anchor_target.py
mmdet/core/anchor/anchor_target.py
+10
-17
mmdet/core/bbox/__init__.py
mmdet/core/bbox/__init__.py
+7
-8
mmdet/core/bbox/assignment.py
mmdet/core/bbox/assignment.py
+155
-0
mmdet/core/bbox/bbox_target.py
mmdet/core/bbox/bbox_target.py
+25
-25
mmdet/core/bbox/sampling.py
mmdet/core/bbox/sampling.py
+190
-306
mmdet/datasets/coco.py
mmdet/datasets/coco.py
+6
-6
mmdet/models/bbox_heads/bbox_head.py
mmdet/models/bbox_heads/bbox_head.py
+8
-4
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+32
-33
mmdet/models/mask_heads/fcn_mask_head.py
mmdet/models/mask_heads/fcn_mask_head.py
+5
-2
No files found.
configs/fast_mask_rcnn_r50_fpn_1x.py
View file @
bac11303
...
@@ -43,17 +43,19 @@ model = dict(
...
@@ -43,17 +43,19 @@ model = dict(
# model training and testing settings
# model training and testing settings
train_cfg
=
dict
(
train_cfg
=
dict
(
rcnn
=
dict
(
rcnn
=
dict
(
m
as
k_size
=
28
,
as
signer
=
dict
(
pos_iou_thr
=
0.5
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
crowd_thr
=
1.1
,
min_pos_iou
=
0.5
,
roi_batch_size
=
512
,
ignore_iof_thr
=-
1
),
add_gt_as_proposals
=
True
,
sampler
=
dict
(
num
=
512
,
pos_fraction
=
0.25
,
pos_fraction
=
0.25
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
pos_balance_sampling
=
False
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
512
,
neg_balance_thr
=
0
),
neg_balance_thr
=
0
,
mask_size
=
28
,
min_pos_iou
=
0.5
,
pos_weight
=-
1
,
pos_weight
=-
1
,
debug
=
False
))
debug
=
False
))
test_cfg
=
dict
(
test_cfg
=
dict
(
...
...
configs/fast_rcnn_r50_fpn_1x.py
View file @
bac11303
...
@@ -32,16 +32,18 @@ model = dict(
...
@@ -32,16 +32,18 @@ model = dict(
# model training and testing settings
# model training and testing settings
train_cfg
=
dict
(
train_cfg
=
dict
(
rcnn
=
dict
(
rcnn
=
dict
(
assigner
=
dict
(
pos_iou_thr
=
0.5
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
crowd_thr
=
1.1
,
min_pos_iou
=
0.5
,
roi_batch_size
=
512
,
ignore_iof_thr
=-
1
),
add_gt_as_proposals
=
True
,
sampler
=
dict
(
num
=
512
,
pos_fraction
=
0.25
,
pos_fraction
=
0.25
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
pos_balance_sampling
=
False
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
512
,
neg_balance_thr
=
0
),
neg_balance_thr
=
0
,
min_pos_iou
=
0.5
,
pos_weight
=-
1
,
pos_weight
=-
1
,
debug
=
False
))
debug
=
False
))
test_cfg
=
dict
(
rcnn
=
dict
(
score_thr
=
0.05
,
max_per_img
=
100
,
nms_thr
=
0.5
))
test_cfg
=
dict
(
rcnn
=
dict
(
score_thr
=
0.05
,
max_per_img
=
100
,
nms_thr
=
0.5
))
...
...
configs/faster_rcnn_r50_fpn_1x.py
View file @
bac11303
...
@@ -42,30 +42,35 @@ model = dict(
...
@@ -42,30 +42,35 @@ model = dict(
# model training and testing settings
# model training and testing settings
train_cfg
=
dict
(
train_cfg
=
dict
(
rpn
=
dict
(
rpn
=
dict
(
pos_fraction
=
0.5
,
assigner
=
dict
(
pos_balance_sampling
=
False
,
neg_pos_ub
=
256
,
allowed_border
=
0
,
crowd_thr
=
1.1
,
anchor_batch_size
=
256
,
pos_iou_thr
=
0.7
,
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
neg_iou_thr
=
0.3
,
neg_balance_thr
=
0
,
min_pos_iou
=
0.3
,
min_pos_iou
=
0.3
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
num
=
256
,
pos_fraction
=
0.5
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
False
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
allowed_border
=
0
,
pos_weight
=-
1
,
pos_weight
=-
1
,
smoothl1_beta
=
1
/
9.0
,
smoothl1_beta
=
1
/
9.0
,
debug
=
False
),
debug
=
False
),
rcnn
=
dict
(
rcnn
=
dict
(
assigner
=
dict
(
pos_iou_thr
=
0.5
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
crowd_thr
=
1.1
,
min_pos_iou
=
0.5
,
roi_batch_size
=
512
,
ignore_iof_thr
=-
1
),
add_gt_as_proposals
=
True
,
sampler
=
dict
(
num
=
512
,
pos_fraction
=
0.25
,
pos_fraction
=
0.25
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
pos_balance_sampling
=
False
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
512
,
neg_balance_thr
=
0
),
neg_balance_thr
=
0
,
min_pos_iou
=
0.5
,
pos_weight
=-
1
,
pos_weight
=-
1
,
debug
=
False
))
debug
=
False
))
test_cfg
=
dict
(
test_cfg
=
dict
(
...
...
configs/mask_rcnn_r50_fpn_1x.py
View file @
bac11303
...
@@ -53,31 +53,36 @@ model = dict(
...
@@ -53,31 +53,36 @@ model = dict(
# model training and testing settings
# model training and testing settings
train_cfg
=
dict
(
train_cfg
=
dict
(
rpn
=
dict
(
rpn
=
dict
(
pos_fraction
=
0.5
,
assigner
=
dict
(
pos_balance_sampling
=
False
,
neg_pos_ub
=
256
,
allowed_border
=
0
,
crowd_thr
=
1.1
,
anchor_batch_size
=
256
,
pos_iou_thr
=
0.7
,
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
neg_iou_thr
=
0.3
,
neg_balance_thr
=
0
,
min_pos_iou
=
0.3
,
min_pos_iou
=
0.3
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
num
=
256
,
pos_fraction
=
0.5
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
False
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
allowed_border
=
0
,
pos_weight
=-
1
,
pos_weight
=-
1
,
smoothl1_beta
=
1
/
9.0
,
smoothl1_beta
=
1
/
9.0
,
debug
=
False
),
debug
=
False
),
rcnn
=
dict
(
rcnn
=
dict
(
m
as
k_size
=
28
,
as
signer
=
dict
(
pos_iou_thr
=
0.5
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
crowd_thr
=
1.1
,
min_pos_iou
=
0.5
,
roi_batch_size
=
512
,
ignore_iof_thr
=-
1
),
add_gt_as_proposals
=
True
,
sampler
=
dict
(
num
=
512
,
pos_fraction
=
0.25
,
pos_fraction
=
0.25
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
pos_balance_sampling
=
False
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
512
,
neg_balance_thr
=
0
),
neg_balance_thr
=
0
,
mask_size
=
28
,
min_pos_iou
=
0.5
,
pos_weight
=-
1
,
pos_weight
=-
1
,
debug
=
False
))
debug
=
False
))
test_cfg
=
dict
(
test_cfg
=
dict
(
...
...
configs/rpn_r50_fpn_1x.py
View file @
bac11303
...
@@ -27,16 +27,19 @@ model = dict(
...
@@ -27,16 +27,19 @@ model = dict(
# model training and testing settings
# model training and testing settings
train_cfg
=
dict
(
train_cfg
=
dict
(
rpn
=
dict
(
rpn
=
dict
(
pos_fraction
=
0.5
,
assigner
=
dict
(
pos_balance_sampling
=
False
,
neg_pos_ub
=
256
,
allowed_border
=
0
,
crowd_thr
=
1.1
,
anchor_batch_size
=
256
,
pos_iou_thr
=
0.7
,
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
neg_iou_thr
=
0.3
,
neg_balance_thr
=
0
,
min_pos_iou
=
0.3
,
min_pos_iou
=
0.3
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
num
=
256
,
pos_fraction
=
0.5
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
False
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
allowed_border
=
0
,
pos_weight
=-
1
,
pos_weight
=-
1
,
smoothl1_beta
=
1
/
9.0
,
smoothl1_beta
=
1
/
9.0
,
debug
=
False
))
debug
=
False
))
...
...
mmdet/core/anchor/anchor_target.py
View file @
bac11303
import
torch
import
torch
from
..bbox
import
bbox_
assign
,
bbox2delta
,
bbox_sampling
from
..bbox
import
assign
_and_sample
,
bbox2delta
from
..utils
import
multi_apply
from
..utils
import
multi_apply
...
@@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
...
@@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
return
(
None
,
)
*
6
return
(
None
,
)
*
6
# assign gt and sample anchors
# assign gt and sample anchors
anchors
=
flat_anchors
[
inside_flags
,
:]
anchors
=
flat_anchors
[
inside_flags
,
:]
assigned_gt_inds
,
argmax_overlaps
,
max_overlaps
=
bbox_assign
(
_
,
sampling_result
=
assign_and_sample
(
anchors
,
gt_bboxes
,
None
,
None
,
cfg
)
anchors
,
gt_bboxes
,
pos_iou_thr
=
cfg
.
pos_iou_thr
,
neg_iou_thr
=
cfg
.
neg_iou_thr
,
min_pos_iou
=
cfg
.
min_pos_iou
)
pos_inds
,
neg_inds
=
bbox_sampling
(
assigned_gt_inds
,
cfg
.
anchor_batch_size
,
cfg
.
pos_fraction
,
cfg
.
neg_pos_ub
,
cfg
.
pos_balance_sampling
,
max_overlaps
,
cfg
.
neg_balance_thr
)
num_valid_anchors
=
anchors
.
shape
[
0
]
bbox_targets
=
torch
.
zeros_like
(
anchors
)
bbox_targets
=
torch
.
zeros_like
(
anchors
)
bbox_weights
=
torch
.
zeros_like
(
anchors
)
bbox_weights
=
torch
.
zeros_like
(
anchors
)
labels
=
torch
.
zeros_like
(
assigned_gt_inds
)
labels
=
anchors
.
new_zeros
((
num_valid_anchors
,
)
)
label_weights
=
torch
.
zeros_like
(
assigned_gt_inds
,
dtype
=
anchors
.
dtype
)
label_weights
=
anchors
.
new_zeros
((
num_valid_
anchors
,
)
)
pos_inds
=
sampling_result
.
pos_inds
neg_inds
=
sampling_result
.
neg_inds
if
len
(
pos_inds
)
>
0
:
if
len
(
pos_inds
)
>
0
:
pos_anchors
=
anchors
[
pos_inds
,
:]
pos_bbox_targets
=
bbox2delta
(
sampling_result
.
pos_bboxes
,
pos_gt_bbox
=
gt_bboxes
[
assigned_gt_inds
[
pos_inds
]
-
1
,
:]
sampling_result
.
pos_gt_bboxes
,
pos_bbox_targets
=
bbox2delta
(
pos_anchors
,
pos_gt_bbox
,
target_means
,
target_means
,
target_stds
)
target_stds
)
bbox_targets
[
pos_inds
,
:]
=
pos_bbox_targets
bbox_targets
[
pos_inds
,
:]
=
pos_bbox_targets
bbox_weights
[
pos_inds
,
:]
=
1.0
bbox_weights
[
pos_inds
,
:]
=
1.0
labels
[
pos_inds
]
=
1
labels
[
pos_inds
]
=
1
...
...
mmdet/core/bbox/__init__.py
View file @
bac11303
from
.geometry
import
bbox_overlaps
from
.geometry
import
bbox_overlaps
from
.
sampling
import
(
random_choice
,
bbox_assign
,
bbox_assign_wrt_overlaps
,
from
.
assignment
import
BBoxAssigner
,
AssignResult
bbox_sampling
,
bbox_sampling_pos
,
bbox_sampling_neg
,
from
.sampling
import
(
BBoxSampler
,
SamplingResult
,
assign_and_sample
,
sample_bboxes
)
random_choice
)
from
.transforms
import
(
bbox2delta
,
delta2bbox
,
bbox_flip
,
bbox_mapping
,
from
.transforms
import
(
bbox2delta
,
delta2bbox
,
bbox_flip
,
bbox_mapping
,
bbox_mapping_back
,
bbox2roi
,
roi2bbox
,
bbox2result
)
bbox_mapping_back
,
bbox2roi
,
roi2bbox
,
bbox2result
)
from
.bbox_target
import
bbox_target
from
.bbox_target
import
bbox_target
__all__
=
[
__all__
=
[
'bbox_overlaps'
,
'random_choice'
,
'bbox_assign'
,
'bbox_overlaps'
,
'BBoxAssigner'
,
'AssignResult'
,
'BBoxSampler'
,
'bbox_assign_wrt_overlaps'
,
'bbox_sampling'
,
'bbox_sampling_pos'
,
'SamplingResult'
,
'assign_and_sample'
,
'random_choice'
,
'bbox2delta'
,
'bbox_sampling_neg'
,
'sample_bboxes'
,
'bbox2delta'
,
'delta2bbox'
,
'delta2bbox'
,
'bbox_flip'
,
'bbox_mapping'
,
'bbox_mapping_back'
,
'bbox2roi'
,
'bbox_flip'
,
'bbox_mapping'
,
'bbox_mapping_back'
,
'bbox2roi'
,
'roi2bbox'
,
'roi2bbox'
,
'bbox2result'
,
'bbox_target'
'bbox2result'
,
'bbox_target'
]
]
mmdet/core/bbox/assignment.py
0 → 100644
View file @
bac11303
import
torch
from
.geometry
import
bbox_overlaps
class BBoxAssigner(object):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with `-1`, `0`, or a positive integer
    indicating the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
            A scalar `t` marks bboxes with `0 <= max_iou < t` as negative;
            a pair `(lo, hi)` marks bboxes with `lo <= max_iou < hi`.
        min_pos_iou (float): Minimum iou for a bbox to be considered as a
            positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
            it is usually set as pos_iou_thr
        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
            `gt_bboxes_ignore` is specified). Negative values mean not
            ignoring any bboxes.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 ignore_iof_thr=-1):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.ignore_iof_thr = ignore_iof_thr

    def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign gt to bboxes.

        This method assigns a gt bbox to every bbox (proposal/anchor), each
        bbox will be assigned with -1, 0, or a positive number. -1 means
        don't care, 0 means negative sample, positive number is the index
        (1-based) of assigned gt.
        The assignment is done in following steps, the order matters.

        1. assign every bbox to -1
        2. assign proposals whose iou with all gts < neg_iou_thr to 0
        3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
           assign it to that bbox
        4. for each gt bbox, assign its nearest proposals (may be more than
           one) to itself

        Args:
            bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
                Only the first 4 columns are used.
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.

        Raises:
            ValueError: If `bboxes` or `gt_bboxes` has no rows.
        """
        if bboxes.shape[0] == 0 or gt_bboxes.shape[0] == 0:
            raise ValueError('No gt or bboxes')
        bboxes = bboxes[:, :4]
        overlaps = bbox_overlaps(bboxes, gt_bboxes)

        has_ignore = (gt_bboxes_ignore is not None
                      and gt_bboxes_ignore.numel() > 0)
        if self.ignore_iof_thr > 0 and has_ignore:
            ignore_overlaps = bbox_overlaps(
                bboxes, gt_bboxes_ignore, mode='iof')
            overlaps = self._mask_ignored(overlaps, ignore_overlaps,
                                          self.ignore_iof_thr)

        return self.assign_wrt_overlaps(overlaps, gt_labels)

    @staticmethod
    def _mask_ignored(overlaps, ignore_overlaps, ignore_iof_thr):
        """Set the overlap rows of ignored bboxes to -1 (in place).

        A bbox is ignored when its largest value in `ignore_overlaps`
        exceeds `ignore_iof_thr`. A boolean row mask is used instead of
        `nonzero().squeeze()` indices: the squeezed index tensor is 1-D
        (or 0-D for a single match), so indexing it with `[:, 0]` fails.

        Args:
            overlaps (Tensor): Overlaps with the gt bboxes, shape (n, k).
            ignore_overlaps (Tensor): Overlaps with the ignored gt bboxes,
                shape (n, m).
            ignore_iof_thr (float): Threshold above which a bbox is ignored.

        Returns:
            Tensor: `overlaps`, with ignored rows filled with -1.
        """
        ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
        overlaps[ignore_max_overlaps > ignore_iof_thr, :] = -1
        return overlaps

    def assign_wrt_overlaps(self, overlaps, gt_labels=None):
        """Assign w.r.t. the overlaps of bboxes with gts.

        Args:
            overlaps (Tensor): Overlaps between n bboxes and k gt_bboxes,
                shape(n, k).
            gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.

        Raises:
            ValueError: If `overlaps` is empty.
        """
        if overlaps.numel() == 0:
            raise ValueError('No gt or proposals')

        num_bboxes, num_gts = overlaps.size(0), overlaps.size(1)

        # 1. assign -1 by default
        assigned_gt_inds = overlaps.new_full(
            (num_bboxes, ), -1, dtype=torch.long)

        # fails for anything but a 2-D overlaps tensor
        assert overlaps.size() == (num_bboxes, num_gts)

        # for each bbox, the max iou over all gts and the gt achieving it
        max_overlaps, argmax_overlaps = overlaps.max(dim=1)
        # for each gt, the max iou over all bboxes
        gt_max_overlaps, _ = overlaps.max(dim=0)

        # 2. assign negative: below the negative threshold. Rows masked
        # with -1 (ignored bboxes) are excluded by the `>= 0` test.
        neg_iou_thr = self.neg_iou_thr
        if isinstance(neg_iou_thr, (int, float)):
            # ints are accepted too; a float-only check would silently
            # skip this step for e.g. `neg_iou_thr=1`
            assigned_gt_inds[(max_overlaps >= 0)
                             & (max_overlaps < neg_iou_thr)] = 0
        elif isinstance(neg_iou_thr, (tuple, list)):
            assert len(neg_iou_thr) == 2
            assigned_gt_inds[(max_overlaps >= neg_iou_thr[0])
                             & (max_overlaps < neg_iou_thr[1])] = 0

        # 3. assign positive: above positive IoU threshold
        pos_inds = max_overlaps >= self.pos_iou_thr
        assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1

        # 4. assign fg: for each gt, proposals with highest IoU.
        # Later gts overwrite earlier ones on ties, so the loop order matters.
        for i in range(num_gts):
            if gt_max_overlaps[i] >= self.min_pos_iou:
                assigned_gt_inds[overlaps[:, i] == gt_max_overlaps[i]] = i + 1

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_zeros((num_bboxes, ))
            # boolean mask keeps the lookup independent of how many
            # positives there are (no shape-changing squeeze())
            pos_mask = assigned_gt_inds > 0
            if pos_mask.any():
                assigned_labels[pos_mask] = gt_labels[
                    assigned_gt_inds[pos_mask] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
class AssignResult(object):
    """Container for the outcome of assigning gt bboxes to bboxes.

    Args:
        num_gts (int): Number of gt bboxes the assignment was made against.
        gt_inds (Tensor): For each bbox, the assigned gt index as stored by
            the assigner (-1, 0, or a 1-based gt index).
        max_overlaps (Tensor): For each bbox, its max overlap over all gts.
        labels (Tensor, optional): For each bbox, the label of its assigned
            gt, or None when no labels were provided.
    """

    def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
        self.num_gts = num_gts
        self.gt_inds = gt_inds
        self.max_overlaps = max_overlaps
        self.labels = labels

    def add_gt_(self, gt_labels):
        """Prepend the gt bboxes themselves as assigned entries (in place).

        Each gt is assigned to itself (1-based index) with overlap 1, and
        these entries are placed in front of the existing ones. `labels` is
        extended only if it is not None.

        Args:
            gt_labels (Tensor): Labels of the gt bboxes, one per gt.

        Raises:
            ValueError: If the number of labels differs from `num_gts`;
                otherwise `gt_inds` and `max_overlaps` would end up with
                different lengths.
        """
        num_labels = len(gt_labels)
        if num_labels != self.num_gts:
            raise ValueError(
                'Expected {} gt labels, got {}'.format(
                    self.num_gts, num_labels))

        self_inds = torch.arange(
            1, num_labels + 1, dtype=torch.long, device=gt_labels.device)
        self.gt_inds = torch.cat([self_inds, self.gt_inds])

        self_overlaps = self.max_overlaps.new_ones(self.num_gts)
        self.max_overlaps = torch.cat([self_overlaps, self.max_overlaps])

        if self.labels is not None:
            self.labels = torch.cat([gt_labels, self.labels])
mmdet/core/bbox/bbox_target.py
View file @
bac11303
...
@@ -4,23 +4,23 @@ from .transforms import bbox2delta
...
@@ -4,23 +4,23 @@ from .transforms import bbox2delta
from
..utils
import
multi_apply
from
..utils
import
multi_apply
def
bbox_target
(
pos_
proposal
s_list
,
def
bbox_target
(
pos_
bboxe
s_list
,
neg_
proposal
s_list
,
neg_
bboxe
s_list
,
pos_gt_bboxes_list
,
pos_gt_bboxes_list
,
pos_gt_labels_list
,
pos_gt_labels_list
,
cfg
,
cfg
,
reg_
num_
classes
=
1
,
reg_classes
=
1
,
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
],
concat
=
True
):
concat
=
True
):
labels
,
label_weights
,
bbox_targets
,
bbox_weights
=
multi_apply
(
labels
,
label_weights
,
bbox_targets
,
bbox_weights
=
multi_apply
(
proposal
_target_single
,
bbox
_target_single
,
pos_
proposal
s_list
,
pos_
bboxe
s_list
,
neg_
proposal
s_list
,
neg_
bboxe
s_list
,
pos_gt_bboxes_list
,
pos_gt_bboxes_list
,
pos_gt_labels_list
,
pos_gt_labels_list
,
cfg
=
cfg
,
cfg
=
cfg
,
reg_
num_
classes
=
reg_
num_
classes
,
reg_classes
=
reg_classes
,
target_means
=
target_means
,
target_means
=
target_means
,
target_stds
=
target_stds
)
target_stds
=
target_stds
)
...
@@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list,
...
@@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list,
return
labels
,
label_weights
,
bbox_targets
,
bbox_weights
return
labels
,
label_weights
,
bbox_targets
,
bbox_weights
def
proposal
_target_single
(
pos_
proposal
s
,
def
bbox
_target_single
(
pos_
bboxe
s
,
neg_proposal
s
,
neg_bboxe
s
,
pos_gt_bboxes
,
pos_gt_bboxes
,
pos_gt_labels
,
pos_gt_labels
,
cfg
,
cfg
,
reg_num
_classes
=
1
,
reg
_classes
=
1
,
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
]):
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
]):
num_pos
=
pos_
proposal
s
.
size
(
0
)
num_pos
=
pos_
bboxe
s
.
size
(
0
)
num_neg
=
neg_
proposal
s
.
size
(
0
)
num_neg
=
neg_
bboxe
s
.
size
(
0
)
num_samples
=
num_pos
+
num_neg
num_samples
=
num_pos
+
num_neg
labels
=
pos_
proposal
s
.
new_zeros
(
num_samples
,
dtype
=
torch
.
long
)
labels
=
pos_
bboxe
s
.
new_zeros
(
num_samples
,
dtype
=
torch
.
long
)
label_weights
=
pos_
proposal
s
.
new_zeros
(
num_samples
)
label_weights
=
pos_
bboxe
s
.
new_zeros
(
num_samples
)
bbox_targets
=
pos_
proposal
s
.
new_zeros
(
num_samples
,
4
)
bbox_targets
=
pos_
bboxe
s
.
new_zeros
(
num_samples
,
4
)
bbox_weights
=
pos_
proposal
s
.
new_zeros
(
num_samples
,
4
)
bbox_weights
=
pos_
bboxe
s
.
new_zeros
(
num_samples
,
4
)
if
num_pos
>
0
:
if
num_pos
>
0
:
labels
[:
num_pos
]
=
pos_gt_labels
labels
[:
num_pos
]
=
pos_gt_labels
pos_weight
=
1.0
if
cfg
.
pos_weight
<=
0
else
cfg
.
pos_weight
pos_weight
=
1.0
if
cfg
.
pos_weight
<=
0
else
cfg
.
pos_weight
label_weights
[:
num_pos
]
=
pos_weight
label_weights
[:
num_pos
]
=
pos_weight
pos_bbox_targets
=
bbox2delta
(
pos_
proposal
s
,
pos_gt_bboxes
,
pos_bbox_targets
=
bbox2delta
(
pos_
bboxe
s
,
pos_gt_bboxes
,
target_means
,
target_means
,
target_stds
)
target_stds
)
bbox_targets
[:
num_pos
,
:]
=
pos_bbox_targets
bbox_targets
[:
num_pos
,
:]
=
pos_bbox_targets
bbox_weights
[:
num_pos
,
:]
=
1
bbox_weights
[:
num_pos
,
:]
=
1
if
num_neg
>
0
:
if
num_neg
>
0
:
label_weights
[
-
num_neg
:]
=
1.0
label_weights
[
-
num_neg
:]
=
1.0
if
reg_
num_
classes
>
1
:
if
reg_classes
>
1
:
bbox_targets
,
bbox_weights
=
expand_target
(
bbox_targets
,
bbox_weights
,
bbox_targets
,
bbox_weights
=
expand_target
(
bbox_targets
,
bbox_weights
,
labels
,
reg_
num_
classes
)
labels
,
reg_classes
)
return
labels
,
label_weights
,
bbox_targets
,
bbox_weights
return
labels
,
label_weights
,
bbox_targets
,
bbox_weights
...
...
mmdet/core/bbox/sampling.py
View file @
bac11303
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
.
geometry
import
bbox_overlaps
from
.
assignment
import
BBoxAssigner
def
random_choice
(
gallery
,
num
):
def
random_choice
(
gallery
,
num
):
...
@@ -21,158 +21,68 @@ def random_choice(gallery, num):
...
@@ -21,158 +21,68 @@ def random_choice(gallery, num):
return
gallery
[
rand_inds
]
return
gallery
[
rand_inds
]
def
bbox_assign
(
proposals
,
def
assign_and_sample
(
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
):
gt_bboxes
,
bbox_assigner
=
BBoxAssigner
(
**
cfg
.
assigner
)
gt_bboxes_ignore
=
None
,
bbox_sampler
=
BBoxSampler
(
**
cfg
.
sampler
)
gt_labels
=
None
,
assign_result
=
bbox_assigner
.
assign
(
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
,
pos_iou_thr
=
0.5
,
gt_labels
)
neg_iou_thr
=
0.5
,
sampling_result
=
bbox_sampler
.
sample
(
assign_result
,
bboxes
,
gt_bboxes
,
min_pos_iou
=
.
0
,
gt_labels
)
crowd_thr
=-
1
):
return
assign_result
,
sampling_result
"""Assign a corresponding gt bbox or background to each proposal/anchor.
Each proposal will be assigned with `-1`, `0`, or a positive integer.
- -1: don't care
class
BBoxSampler
(
object
):
- 0: negative sample, no assigned gt
"""Sample positive and negative bboxes given assigned results.
- positive integer: positive sample, index (1-based) of assigned gt
If `gt_bboxes_ignore` is specified, bboxes which have iof (intersection
over foreground) with `gt_bboxes_ignore` above `crowd_thr` will be ignored.
Args:
proposals (Tensor): Proposals or RPN anchors, shape (n, 4).
gt_bboxes (Tensor): Ground truth bboxes, shape (k, 4).
gt_bboxes_ignore (Tensor, optional): shape(m, 4).
gt_labels (Tensor, optional): shape (k, ).
pos_iou_thr (float): IoU threshold for positive bboxes.
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
min_pos_iou (float): Minimum iou for a bbox to be considered as a
positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
it is usually set as pos_iou_thr
crowd_thr (float): IoF threshold for ignoring bboxes. Negative value
for not ignoring any bboxes.
Returns:
tuple: (assigned_gt_inds, argmax_overlaps, max_overlaps), shape (n, )
"""
# calculate overlaps between the proposals and the gt boxes
overlaps
=
bbox_overlaps
(
proposals
,
gt_bboxes
)
if
overlaps
.
numel
()
==
0
:
raise
ValueError
(
'No gt bbox or proposals'
)
# ignore proposals according to crowd bboxes
if
(
crowd_thr
>
0
)
and
(
gt_bboxes_ignore
is
not
None
)
and
(
gt_bboxes_ignore
.
numel
()
>
0
):
crowd_overlaps
=
bbox_overlaps
(
proposals
,
gt_bboxes_ignore
,
mode
=
'iof'
)
crowd_max_overlaps
,
_
=
crowd_overlaps
.
max
(
dim
=
1
)
crowd_bboxes_inds
=
torch
.
nonzero
(
crowd_max_overlaps
>
crowd_thr
).
long
()
if
crowd_bboxes_inds
.
numel
()
>
0
:
overlaps
[
crowd_bboxes_inds
,
:]
=
-
1
return
bbox_assign_wrt_overlaps
(
overlaps
,
gt_labels
,
pos_iou_thr
,
neg_iou_thr
,
min_pos_iou
)
def
bbox_assign_wrt_overlaps
(
overlaps
,
gt_labels
=
None
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
min_pos_iou
=
.
0
):
"""Assign a corresponding gt bbox or background to each proposal/anchor.
This method assigns a gt bbox to every proposal; each proposal will be
assigned with -1, 0, or a positive number. -1 means don't care, 0 means
negative sample, positive number is the index (1-based) of assigned gt.
The assignment is done in following steps, the order matters:
1. assign every anchor to -1
2. assign proposals whose iou with all gts < neg_iou_thr to 0
3. for each anchor, if the iou with its nearest gt >= pos_iou_thr,
assign it to that bbox
4. for each gt bbox, assign its nearest proposals(may be more than one)
to itself
Args:
Args:
overlaps (Tensor): Overlaps between n proposals and k gt_bboxes,
pos_fraction (float): Positive sample fraction.
shape(n, k).
neg_pos_ub (float): Negative/Positive upper bound.
gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
pos_balance_sampling (bool): Whether to sample positive samples around
pos_iou_thr (float): IoU threshold for positive bboxes.
each gt bbox evenly.
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
neg_balance_thr (float, optional): IoU threshold for simple/hard
min_pos_iou (float): Minimum IoU for a bbox to be considered as a
negative balance sampling.
positive bbox. This argument only affects the 4th step.
neg_hard_fraction (float, optional): Fraction of hard negative samples
for negative balance sampling.
Returns:
tuple: (assigned_gt_inds, [assigned_labels], argmax_overlaps,
max_overlaps), shape (n, )
"""
"""
num_bboxes
,
num_gts
=
overlaps
.
size
(
0
),
overlaps
.
size
(
1
)
# 1. assign -1 by default
assigned_gt_inds
=
overlaps
.
new
(
num_bboxes
).
long
().
fill_
(
-
1
)
if
overlaps
.
numel
()
==
0
:
raise
ValueError
(
'No gt bbox or proposals'
)
assert
overlaps
.
size
()
==
(
num_bboxes
,
num_gts
)
# for each anchor, which gt best overlaps with it
# for each anchor, the max iou of all gts
max_overlaps
,
argmax_overlaps
=
overlaps
.
max
(
dim
=
1
)
# for each gt, which anchor best overlaps with it
# for each gt, the max iou of all proposals
gt_max_overlaps
,
gt_argmax_overlaps
=
overlaps
.
max
(
dim
=
0
)
# 2. assign negative: below
if
isinstance
(
neg_iou_thr
,
float
):
assigned_gt_inds
[(
max_overlaps
>=
0
)
&
(
max_overlaps
<
neg_iou_thr
)]
=
0
elif
isinstance
(
neg_iou_thr
,
tuple
):
assert
len
(
neg_iou_thr
)
==
2
assigned_gt_inds
[(
max_overlaps
>=
neg_iou_thr
[
0
])
&
(
max_overlaps
<
neg_iou_thr
[
1
])]
=
0
# 3. assign positive: above positive IoU threshold
def
__init__
(
self
,
pos_inds
=
max_overlaps
>=
pos_iou_thr
num
,
assigned_gt_inds
[
pos_inds
]
=
argmax_overlaps
[
pos_inds
]
+
1
pos_fraction
,
neg_pos_ub
=-
1
,
# 4. assign fg: for each gt, proposals with highest IoU
add_gt_as_proposals
=
True
,
for
i
in
range
(
num_gts
):
pos_balance_sampling
=
False
,
if
gt_max_overlaps
[
i
]
>=
min_pos_iou
:
neg_balance_thr
=
0
,
assigned_gt_inds
[
overlaps
[:,
i
]
==
gt_max_overlaps
[
i
]]
=
i
+
1
neg_hard_fraction
=
0.5
):
self
.
num
=
num
if
gt_labels
is
None
:
self
.
pos_fraction
=
pos_fraction
return
assigned_gt_inds
,
argmax_overlaps
,
max_overlaps
self
.
neg_pos_ub
=
neg_pos_ub
else
:
self
.
add_gt_as_proposals
=
add_gt_as_proposals
assigned_labels
=
assigned_gt_inds
.
new
(
num_bboxes
).
fill_
(
0
)
self
.
pos_balance_sampling
=
pos_balance_sampling
pos_inds
=
torch
.
nonzero
(
assigned_gt_inds
>
0
).
squeeze
()
self
.
neg_balance_thr
=
neg_balance_thr
if
pos_inds
.
numel
()
>
0
:
self
.
neg_hard_fraction
=
neg_hard_fraction
assigned_labels
[
pos_inds
]
=
gt_labels
[
assigned_gt_inds
[
pos_inds
]
-
1
]
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
return
assigned_gt_inds
,
assigned_labels
,
argmax_overlaps
,
max_overlaps
def
bbox_sampling_pos
(
assigned_gt_inds
,
num_expected
,
balance_sampling
=
True
):
"""Balance sampling for positive bboxes/anchors.
"""Balance sampling for positive bboxes/anchors.
1. calculate average positive num for each gt: num_per_gt
1. calculate average positive num for each gt: num_per_gt
2. sample at most num_per_gt positives for each gt
2. sample at most num_per_gt positives for each gt
3. random sampling from rest anchors if not enough fg
3. random sampling from rest anchors if not enough fg
"""
"""
pos_inds
=
torch
.
nonzero
(
assign
ed_
gt_inds
>
0
)
pos_inds
=
torch
.
nonzero
(
assign
_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
pos_inds
=
pos_inds
.
squeeze
(
1
)
if
pos_inds
.
numel
()
<=
num_expected
:
if
pos_inds
.
numel
()
<=
num_expected
:
return
pos_inds
return
pos_inds
elif
not
balance_sampling
:
elif
not
self
.
pos_
balance_sampling
:
return
random_choice
(
pos_inds
,
num_expected
)
return
random_choice
(
pos_inds
,
num_expected
)
else
:
else
:
unique_gt_inds
=
torch
.
unique
(
assigned_gt_inds
[
pos_inds
].
cpu
())
unique_gt_inds
=
torch
.
unique
(
assign_result
.
gt_inds
[
pos_inds
].
cpu
())
num_gts
=
len
(
unique_gt_inds
)
num_gts
=
len
(
unique_gt_inds
)
num_per_gt
=
int
(
round
(
num_expected
/
float
(
num_gts
))
+
1
)
num_per_gt
=
int
(
round
(
num_expected
/
float
(
num_gts
))
+
1
)
sampled_inds
=
[]
sampled_inds
=
[]
for
i
in
unique_gt_inds
:
for
i
in
unique_gt_inds
:
inds
=
torch
.
nonzero
(
assign
ed_
gt_inds
==
i
.
item
())
inds
=
torch
.
nonzero
(
assign
_result
.
gt_inds
==
i
.
item
())
if
inds
.
numel
()
!=
0
:
if
inds
.
numel
()
!=
0
:
inds
=
inds
.
squeeze
(
1
)
inds
=
inds
.
squeeze
(
1
)
else
:
else
:
...
@@ -188,56 +98,53 @@ def bbox_sampling_pos(assigned_gt_inds, num_expected, balance_sampling=True):
...
@@ -188,56 +98,53 @@ def bbox_sampling_pos(assigned_gt_inds, num_expected, balance_sampling=True):
if
len
(
extra_inds
)
>
num_extra
:
if
len
(
extra_inds
)
>
num_extra
:
extra_inds
=
random_choice
(
extra_inds
,
num_extra
)
extra_inds
=
random_choice
(
extra_inds
,
num_extra
)
extra_inds
=
torch
.
from_numpy
(
extra_inds
).
to
(
extra_inds
=
torch
.
from_numpy
(
extra_inds
).
to
(
assign
ed_
gt_inds
.
device
).
long
()
assign
_result
.
gt_inds
.
device
).
long
()
sampled_inds
=
torch
.
cat
([
sampled_inds
,
extra_inds
])
sampled_inds
=
torch
.
cat
([
sampled_inds
,
extra_inds
])
elif
len
(
sampled_inds
)
>
num_expected
:
elif
len
(
sampled_inds
)
>
num_expected
:
sampled_inds
=
random_choice
(
sampled_inds
,
num_expected
)
sampled_inds
=
random_choice
(
sampled_inds
,
num_expected
)
return
sampled_inds
return
sampled_inds
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
bbox_sampling_neg
(
assigned_gt_inds
,
num_expected
,
max_overlaps
=
None
,
balance_thr
=
0
,
hard_fraction
=
0.5
):
"""Balance sampling for negative bboxes/anchors.
"""Balance sampling for negative bboxes/anchors.
Negative samples are split into 2 sets: hard (balance_thr <= iou <
Negative samples are split into 2 sets: hard (balance_thr <= iou <
neg_iou_thr) and easy(iou < balance_thr). The sampling ratio is
controlled
neg_iou_thr) and easy
(iou < balance_thr). The sampling ratio is
by `hard_fraction`.
controlled
by `hard_fraction`.
"""
"""
neg_inds
=
torch
.
nonzero
(
assign
ed_
gt_inds
==
0
)
neg_inds
=
torch
.
nonzero
(
assign
_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
neg_inds
=
neg_inds
.
squeeze
(
1
)
if
len
(
neg_inds
)
<=
num_expected
:
if
len
(
neg_inds
)
<=
num_expected
:
return
neg_inds
return
neg_inds
elif
balance_thr
<=
0
:
elif
self
.
neg_
balance_thr
<=
0
:
# uniform sampling among all negative samples
# uniform sampling among all negative samples
return
random_choice
(
neg_inds
,
num_expected
)
return
random_choice
(
neg_inds
,
num_expected
)
else
:
else
:
assert
max_overlaps
is
not
None
max_overlaps
=
assign_result
.
max_overlaps
.
cpu
().
numpy
()
max_overlaps
=
max_overlaps
.
cpu
().
numpy
()
# balance sampling for negative samples
# balance sampling for negative samples
neg_set
=
set
(
neg_inds
.
cpu
().
numpy
())
neg_set
=
set
(
neg_inds
.
cpu
().
numpy
())
easy_set
=
set
(
easy_set
=
set
(
np
.
where
(
np
.
where
(
np
.
logical_and
(
max_overlaps
>=
0
,
np
.
logical_and
(
max_overlaps
>=
0
,
max_overlaps
<
balance_thr
))[
0
])
max_overlaps
<
self
.
neg_
balance_thr
))[
0
])
hard_set
=
set
(
np
.
where
(
max_overlaps
>=
balance_thr
)[
0
])
hard_set
=
set
(
np
.
where
(
max_overlaps
>=
self
.
neg_
balance_thr
)[
0
])
easy_neg_inds
=
list
(
easy_set
&
neg_set
)
easy_neg_inds
=
list
(
easy_set
&
neg_set
)
hard_neg_inds
=
list
(
hard_set
&
neg_set
)
hard_neg_inds
=
list
(
hard_set
&
neg_set
)
num_expected_hard
=
int
(
num_expected
*
hard_fraction
)
num_expected_hard
=
int
(
num_expected
*
self
.
neg_
hard_fraction
)
if
len
(
hard_neg_inds
)
>
num_expected_hard
:
if
len
(
hard_neg_inds
)
>
num_expected_hard
:
sampled_hard_inds
=
random_choice
(
hard_neg_inds
,
num_expected_hard
)
sampled_hard_inds
=
random_choice
(
hard_neg_inds
,
num_expected_hard
)
else
:
else
:
sampled_hard_inds
=
np
.
array
(
hard_neg_inds
,
dtype
=
np
.
int
)
sampled_hard_inds
=
np
.
array
(
hard_neg_inds
,
dtype
=
np
.
int
)
num_expected_easy
=
num_expected
-
len
(
sampled_hard_inds
)
num_expected_easy
=
num_expected
-
len
(
sampled_hard_inds
)
if
len
(
easy_neg_inds
)
>
num_expected_easy
:
if
len
(
easy_neg_inds
)
>
num_expected_easy
:
sampled_easy_inds
=
random_choice
(
easy_neg_inds
,
num_expected_easy
)
sampled_easy_inds
=
random_choice
(
easy_neg_inds
,
num_expected_easy
)
else
:
else
:
sampled_easy_inds
=
np
.
array
(
easy_neg_inds
,
dtype
=
np
.
int
)
sampled_easy_inds
=
np
.
array
(
easy_neg_inds
,
dtype
=
np
.
int
)
sampled_inds
=
np
.
concatenate
((
sampled_easy_inds
,
sampled_hard_inds
))
sampled_inds
=
np
.
concatenate
((
sampled_easy_inds
,
sampled_hard_inds
))
if
len
(
sampled_inds
)
<
num_expected
:
if
len
(
sampled_inds
)
<
num_expected
:
num_extra
=
num_expected
-
len
(
sampled_inds
)
num_extra
=
num_expected
-
len
(
sampled_inds
)
extra_inds
=
np
.
array
(
list
(
neg_set
-
set
(
sampled_inds
)))
extra_inds
=
np
.
array
(
list
(
neg_set
-
set
(
sampled_inds
)))
...
@@ -245,99 +152,76 @@ def bbox_sampling_neg(assigned_gt_inds,
...
@@ -245,99 +152,76 @@ def bbox_sampling_neg(assigned_gt_inds,
extra_inds
=
random_choice
(
extra_inds
,
num_extra
)
extra_inds
=
random_choice
(
extra_inds
,
num_extra
)
sampled_inds
=
np
.
concatenate
((
sampled_inds
,
extra_inds
))
sampled_inds
=
np
.
concatenate
((
sampled_inds
,
extra_inds
))
sampled_inds
=
torch
.
from_numpy
(
sampled_inds
).
long
().
to
(
sampled_inds
=
torch
.
from_numpy
(
sampled_inds
).
long
().
to
(
assign
ed_
gt_inds
.
device
)
assign
_result
.
gt_inds
.
device
)
return
sampled_inds
return
sampled_inds
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
=
None
):
def
bbox_sampling
(
assigned_gt_inds
,
num_expected
,
pos_fraction
,
neg_pos_ub
,
pos_balance_sampling
=
True
,
max_overlaps
=
None
,
neg_balance_thr
=
0
,
neg_hard_fraction
=
0.5
):
"""Sample positive and negative bboxes given assigned results.
Args:
assigned_gt_inds (Tensor): Assigned gt indices for each bbox.
num_expected (int): Expected total samples (pos and neg).
pos_fraction (float): Positive sample fraction.
neg_pos_ub (float): Negative/Positive upper bound.
pos_balance_sampling(bool): Whether to sample positive samples around
each gt bbox evenly.
max_overlaps (Tensor, optional): For each bbox, the max IoU of all gts.
Used for negative balance sampling only.
neg_balance_thr (float, optional): IoU threshold for simple/hard
negative balance sampling.
neg_hard_fraction (float, optional): Fraction of hard negative samples
for negative balance sampling.
Returns:
tuple[Tensor]: positive bbox indices, negative bbox indices.
"""
num_expected_pos
=
int
(
num_expected
*
pos_fraction
)
pos_inds
=
bbox_sampling_pos
(
assigned_gt_inds
,
num_expected_pos
,
pos_balance_sampling
)
# We found that sampled indices have duplicated items occasionally.
# (mab be a bug of PyTorch)
pos_inds
=
pos_inds
.
unique
()
num_sampled_pos
=
pos_inds
.
numel
()
num_neg_max
=
int
(
neg_pos_ub
*
num_sampled_pos
)
if
num_sampled_pos
>
0
else
int
(
neg_pos_ub
)
num_expected_neg
=
min
(
num_neg_max
,
num_expected
-
num_sampled_pos
)
neg_inds
=
bbox_sampling_neg
(
assigned_gt_inds
,
num_expected_neg
,
max_overlaps
,
neg_balance_thr
,
neg_hard_fraction
)
neg_inds
=
neg_inds
.
unique
()
return
pos_inds
,
neg_inds
def
sample_bboxes
(
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
):
"""Sample positive and negative bboxes.
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates
and
This is a simple implementation of bbox sampling given candidates
,
ground truth bboxes, which includes 3 step
s.
assigning results and ground truth bboxe
s.
1. Assign gt to each bbox.
1. Assign gt to each bbox.
2. Add gt bboxes to the sampling pool (optional).
2. Add gt bboxes to the sampling pool (optional).
3. Perform positive and negative sampling.
3. Perform positive and negative sampling.
Args:
Args:
assign_result (:obj:`AssignResult`): Bbox assigning results.
bboxes (Tensor): Boxes to be sampled from.
bboxes (Tensor): Boxes to be sampled from.
gt_bboxes (Tensor): Ground truth bboxes.
gt_bboxes (Tensor): Ground truth bboxes.
gt_bboxes_ignore (Tensor): Ignored ground truth bboxes. In MS COCO,
gt_labels (Tensor, optional): Class labels of ground truth bboxes.
`crowd` bboxes are considered as ignored.
gt_labels (Tensor): Class labels of ground truth bboxes.
cfg (dict): Sampling configs.
Returns:
Returns:
tuple[Tensor]: pos_bboxes, neg_bboxes, pos_assigned_gt_inds,
:obj:`SamplingResult`: Sampling result.
pos_gt_bboxes, pos_gt_labels
"""
"""
bboxes
=
bboxes
[:,
:
4
]
bboxes
=
bboxes
[:,
:
4
]
assigned_gt_inds
,
assigned_labels
,
argmax_overlaps
,
max_overlaps
=
\
bbox_assign
(
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
.
pos_iou_thr
,
cfg
.
neg_iou_thr
,
cfg
.
min_pos_iou
,
cfg
.
crowd_thr
)
if
cfg
.
add_gt_as_proposals
:
gt_flags
=
bboxes
.
new_zeros
((
bboxes
.
shape
[
0
],
),
dtype
=
torch
.
uint8
)
if
self
.
add_gt_as_proposals
:
bboxes
=
torch
.
cat
([
gt_bboxes
,
bboxes
],
dim
=
0
)
bboxes
=
torch
.
cat
([
gt_bboxes
,
bboxes
],
dim
=
0
)
gt_assign_self
=
torch
.
arange
(
assign_result
.
add_gt_
(
gt_labels
)
1
,
len
(
gt_labels
)
+
1
,
dtype
=
torch
.
long
,
device
=
bboxes
.
device
)
gt_flags
=
torch
.
cat
([
assigned_gt_inds
=
torch
.
cat
([
gt_assign_self
,
assigned_gt_inds
])
bboxes
.
new_ones
((
gt_bboxes
.
shape
[
0
],
),
dtype
=
torch
.
uint8
),
assigned_labels
=
torch
.
cat
([
gt_labels
,
assigned_labels
])
gt_flags
])
num_expected_pos
=
int
(
self
.
num
*
self
.
pos_fraction
)
pos_inds
=
self
.
_sample_pos
(
assign_result
,
num_expected_pos
)
# We found that sampled indices have duplicated items occasionally.
# (mab be a bug of PyTorch)
pos_inds
=
pos_inds
.
unique
()
num_sampled_pos
=
pos_inds
.
numel
()
num_expected_neg
=
self
.
num
-
num_sampled_pos
if
self
.
neg_pos_ub
>=
0
:
num_neg_max
=
int
(
self
.
neg_pos_ub
*
num_sampled_pos
)
if
num_sampled_pos
>
0
else
int
(
self
.
neg_pos_ub
)
num_expected_neg
=
min
(
num_neg_max
,
num_expected_neg
)
neg_inds
=
self
.
_sample_neg
(
assign_result
,
num_expected_neg
)
neg_inds
=
neg_inds
.
unique
()
pos_inds
,
neg_inds
=
bbox_sampling
(
return
SamplingResult
(
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
assigned_gt_inds
,
cfg
.
roi_batch_size
,
cfg
.
pos_fraction
,
cfg
.
neg_pos_ub
,
assign_result
,
gt_flags
)
cfg
.
pos_balance_sampling
,
max_overlaps
,
cfg
.
neg_balance_thr
)
pos_bboxes
=
bboxes
[
pos_inds
]
neg_bboxes
=
bboxes
[
neg_inds
]
pos_assigned_gt_inds
=
assigned_gt_inds
[
pos_inds
]
-
1
pos_gt_bboxes
=
gt_bboxes
[
pos_assigned_gt_inds
,
:]
pos_gt_labels
=
assigned_labels
[
pos_inds
]
return
(
pos_bboxes
,
neg_bboxes
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
class
SamplingResult
(
object
):
pos_gt_labels
)
def
__init__
(
self
,
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
assign_result
,
gt_flags
):
self
.
pos_inds
=
pos_inds
self
.
neg_inds
=
neg_inds
self
.
pos_bboxes
=
bboxes
[
pos_inds
]
self
.
neg_bboxes
=
bboxes
[
neg_inds
]
self
.
pos_is_gt
=
gt_flags
[
pos_inds
]
self
.
num_gts
=
gt_bboxes
.
shape
[
0
]
self
.
pos_assigned_gt_inds
=
assign_result
.
gt_inds
[
pos_inds
]
-
1
self
.
pos_gt_bboxes
=
gt_bboxes
[
self
.
pos_assigned_gt_inds
,
:]
if
assign_result
.
labels
is
not
None
:
self
.
pos_gt_labels
=
assign_result
.
labels
[
pos_inds
]
else
:
self
.
pos_gt_labels
=
None
@
property
def
bboxes
(
self
):
return
torch
.
cat
([
self
.
pos_bboxes
,
self
.
neg_bboxes
])
mmdet/datasets/coco.py
View file @
bac11303
...
@@ -215,7 +215,7 @@ class CocoDataset(Dataset):
...
@@ -215,7 +215,7 @@ class CocoDataset(Dataset):
'proposals should have shapes (n, 4) or (n, 5), '
'proposals should have shapes (n, 4) or (n, 5), '
'but found {}'
.
format
(
proposals
.
shape
))
'but found {}'
.
format
(
proposals
.
shape
))
if
proposals
.
shape
[
1
]
==
5
:
if
proposals
.
shape
[
1
]
==
5
:
scores
=
proposals
[:,
4
]
scores
=
proposals
[:,
4
,
None
]
proposals
=
proposals
[:,
:
4
]
proposals
=
proposals
[:,
:
4
]
else
:
else
:
scores
=
None
scores
=
None
...
@@ -237,8 +237,8 @@ class CocoDataset(Dataset):
...
@@ -237,8 +237,8 @@ class CocoDataset(Dataset):
if
self
.
proposals
is
not
None
:
if
self
.
proposals
is
not
None
:
proposals
=
self
.
bbox_transform
(
proposals
,
img_shape
,
proposals
=
self
.
bbox_transform
(
proposals
,
img_shape
,
scale_factor
,
flip
)
scale_factor
,
flip
)
proposals
=
np
.
hstack
(
[
proposals
,
scores
[:,
None
]
proposals
=
np
.
hstack
(
])
if
scores
is
not
None
else
proposals
[
proposals
,
scores
])
if
scores
is
not
None
else
proposals
gt_bboxes
=
self
.
bbox_transform
(
gt_bboxes
,
img_shape
,
scale_factor
,
gt_bboxes
=
self
.
bbox_transform
(
gt_bboxes
,
img_shape
,
scale_factor
,
flip
)
flip
)
gt_bboxes_ignore
=
self
.
bbox_transform
(
gt_bboxes_ignore
,
img_shape
,
gt_bboxes_ignore
=
self
.
bbox_transform
(
gt_bboxes_ignore
,
img_shape
,
...
@@ -295,14 +295,14 @@ class CocoDataset(Dataset):
...
@@ -295,14 +295,14 @@ class CocoDataset(Dataset):
flip
=
flip
)
flip
=
flip
)
if
proposal
is
not
None
:
if
proposal
is
not
None
:
if
proposal
.
shape
[
1
]
==
5
:
if
proposal
.
shape
[
1
]
==
5
:
score
=
proposal
[:,
4
]
score
=
proposal
[:,
4
,
None
]
proposal
=
proposal
[:,
:
4
]
proposal
=
proposal
[:,
:
4
]
else
:
else
:
score
=
None
score
=
None
_proposal
=
self
.
bbox_transform
(
proposal
,
img_shape
,
_proposal
=
self
.
bbox_transform
(
proposal
,
img_shape
,
scale_factor
,
flip
)
scale_factor
,
flip
)
_proposal
=
np
.
hstack
(
[
_proposal
,
score
[:,
None
]
_proposal
=
np
.
hstack
(
])
if
score
is
not
None
else
_proposal
[
_proposal
,
score
])
if
score
is
not
None
else
_proposal
_proposal
=
to_tensor
(
_proposal
)
_proposal
=
to_tensor
(
_proposal
)
else
:
else
:
_proposal
=
None
_proposal
=
None
...
...
mmdet/models/bbox_heads/bbox_head.py
View file @
bac11303
...
@@ -59,16 +59,20 @@ class BBoxHead(nn.Module):
...
@@ -59,16 +59,20 @@ class BBoxHead(nn.Module):
bbox_pred
=
self
.
fc_reg
(
x
)
if
self
.
with_reg
else
None
bbox_pred
=
self
.
fc_reg
(
x
)
if
self
.
with_reg
else
None
return
cls_score
,
bbox_pred
return
cls_score
,
bbox_pred
def
get_bbox_target
(
self
,
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
def
get_target
(
self
,
sampling_results
,
gt_bboxes
,
gt_labels
,
pos_gt_labels
,
rcnn_train_cfg
):
rcnn_train_cfg
):
reg_num_classes
=
1
if
self
.
reg_class_agnostic
else
self
.
num_classes
pos_proposals
=
[
res
.
pos_bboxes
for
res
in
sampling_results
]
neg_proposals
=
[
res
.
neg_bboxes
for
res
in
sampling_results
]
pos_gt_bboxes
=
[
res
.
pos_gt_bboxes
for
res
in
sampling_results
]
pos_gt_labels
=
[
res
.
pos_gt_labels
for
res
in
sampling_results
]
reg_classes
=
1
if
self
.
reg_class_agnostic
else
self
.
num_classes
cls_reg_targets
=
bbox_target
(
cls_reg_targets
=
bbox_target
(
pos_proposals
,
pos_proposals
,
neg_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_bboxes
,
pos_gt_labels
,
pos_gt_labels
,
rcnn_train_cfg
,
rcnn_train_cfg
,
reg_
num_
classes
,
reg_classes
,
target_means
=
self
.
target_means
,
target_means
=
self
.
target_means
,
target_stds
=
self
.
target_stds
)
target_stds
=
self
.
target_stds
)
return
cls_reg_targets
return
cls_reg_targets
...
...
mmdet/models/detectors/two_stage.py
View file @
bac11303
...
@@ -4,7 +4,7 @@ import torch.nn as nn
...
@@ -4,7 +4,7 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
..
import
builder
from
..
import
builder
from
mmdet.core
import
sample_bboxes
,
bbox2roi
,
bbox2result
,
multi_apply
from
mmdet.core
import
(
assign_and_sample
,
bbox2roi
,
bbox2result
,
multi_apply
)
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
...
@@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
gt_labels
,
gt_labels
,
gt_masks
=
None
,
gt_masks
=
None
,
proposals
=
None
):
proposals
=
None
):
losses
=
dict
()
x
=
self
.
extract_feat
(
img
)
x
=
self
.
extract_feat
(
img
)
losses
=
dict
()
# RPN forward and loss
if
self
.
with_rpn
:
if
self
.
with_rpn
:
rpn_outs
=
self
.
rpn_head
(
x
)
rpn_outs
=
self
.
rpn_head
(
x
)
rpn_loss_inputs
=
rpn_outs
+
(
gt_bboxes
,
img_meta
,
rpn_loss_inputs
=
rpn_outs
+
(
gt_bboxes
,
img_meta
,
...
@@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
else
:
else
:
proposal_list
=
proposals
proposal_list
=
proposals
if
self
.
with_bbox
:
# assign gts and sample proposals
(
pos_proposals
,
neg_proposals
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
if
self
.
with_bbox
or
self
.
with_mask
:
pos_gt_labels
)
=
multi_apply
(
assign_results
,
sampling_results
=
multi_apply
(
sample_bboxes
,
assign_and_sample
,
proposal_list
,
proposal_list
,
gt_bboxes
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_bboxes_ignore
,
gt_labels
,
gt_labels
,
cfg
=
self
.
train_cfg
.
rcnn
)
cfg
=
self
.
train_cfg
.
rcnn
)
(
labels
,
label_weights
,
bbox_targets
,
bbox_weights
)
=
self
.
bbox_head
.
get_bbox_target
(
# bbox head forward and loss
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
if
self
.
with_bbox
:
self
.
train_cfg
.
rcnn
)
rois
=
bbox2roi
([
res
.
bboxes
for
res
in
sampling_results
])
# TODO: a more flexible way to decide which feature maps to use
rois
=
bbox2roi
([
bbox_feats
=
self
.
bbox_roi_extractor
(
torch
.
cat
([
pos
,
neg
],
dim
=
0
)
for
pos
,
neg
in
zip
(
pos_proposals
,
neg_proposals
)
])
# TODO: a more flexible way to configurate feat maps
roi_feats
=
self
.
bbox_roi_extractor
(
x
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
x
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
bbox_pred
=
self
.
bbox_head
(
roi
_feats
)
cls_score
,
bbox_pred
=
self
.
bbox_head
(
bbox
_feats
)
loss_bbox
=
self
.
bbox_head
.
loss
(
cls_score
,
bbox_pred
,
labels
,
bbox_targets
=
self
.
bbox_head
.
get_target
(
label_weights
,
bbox_targets
,
sampling_results
,
gt_bboxes
,
gt_labels
,
self
.
train_cfg
.
rcnn
)
bbox_weights
)
loss_bbox
=
self
.
bbox_head
.
loss
(
cls_score
,
bbox_pred
,
*
bbox_targets
)
losses
.
update
(
loss_bbox
)
losses
.
update
(
loss_bbox
)
# mask head forward and loss
if
self
.
with_mask
:
if
self
.
with_mask
:
mask_targets
=
self
.
mask_head
.
get_mask_target
(
pos_rois
=
bbox2roi
([
res
.
pos_bboxes
for
res
in
sampling_results
])
pos_proposals
,
pos_assigned_gt_inds
,
gt_masks
,
self
.
train_cfg
.
rcnn
)
pos_rois
=
bbox2roi
(
pos_proposals
)
mask_feats
=
self
.
mask_roi_extractor
(
mask_feats
=
self
.
mask_roi_extractor
(
x
[:
self
.
mask_roi_extractor
.
num_inputs
],
pos_rois
)
x
[:
self
.
mask_roi_extractor
.
num_inputs
],
pos_rois
)
mask_pred
=
self
.
mask_head
(
mask_feats
)
mask_pred
=
self
.
mask_head
(
mask_feats
)
mask_targets
=
self
.
mask_head
.
get_target
(
sampling_results
,
gt_masks
,
self
.
train_cfg
.
rcnn
)
pos_labels
=
torch
.
cat
(
[
res
.
pos_gt_labels
for
res
in
sampling_results
])
loss_mask
=
self
.
mask_head
.
loss
(
mask_pred
,
mask_targets
,
loss_mask
=
self
.
mask_head
.
loss
(
mask_pred
,
mask_targets
,
torch
.
cat
(
pos_
gt_
labels
)
)
pos_labels
)
losses
.
update
(
loss_mask
)
losses
.
update
(
loss_mask
)
return
losses
return
losses
...
@@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
x
=
self
.
extract_feat
(
img
)
x
=
self
.
extract_feat
(
img
)
proposal_list
=
self
.
simple_test_rpn
(
proposal_list
=
self
.
simple_test_rpn
(
x
,
img_meta
,
x
,
img_meta
,
self
.
test_cfg
.
rpn
)
if
proposals
is
None
else
proposals
self
.
test_cfg
.
rpn
)
if
proposals
is
None
else
proposals
det_bboxes
,
det_labels
=
self
.
simple_test_bboxes
(
det_bboxes
,
det_labels
=
self
.
simple_test_bboxes
(
x
,
img_meta
,
proposal_list
,
self
.
test_cfg
.
rcnn
,
rescale
=
rescale
)
x
,
img_meta
,
proposal_list
,
self
.
test_cfg
.
rcnn
,
rescale
=
rescale
)
...
...
mmdet/models/mask_heads/fcn_mask_head.py
View file @
bac11303
...
@@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module):
...
@@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module):
mask_pred
=
self
.
conv_logits
(
x
)
mask_pred
=
self
.
conv_logits
(
x
)
return
mask_pred
return
mask_pred
def
get_mask_target
(
self
,
pos_proposals
,
pos_assigned_gt_inds
,
gt_masks
,
def
get_target
(
self
,
sampling_results
,
gt_masks
,
rcnn_train_cfg
):
rcnn_train_cfg
):
pos_proposals
=
[
res
.
pos_bboxes
for
res
in
sampling_results
]
pos_assigned_gt_inds
=
[
res
.
pos_assigned_gt_inds
for
res
in
sampling_results
]
mask_targets
=
mask_target
(
pos_proposals
,
pos_assigned_gt_inds
,
mask_targets
=
mask_target
(
pos_proposals
,
pos_assigned_gt_inds
,
gt_masks
,
rcnn_train_cfg
)
gt_masks
,
rcnn_train_cfg
)
return
mask_targets
return
mask_targets
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment