Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
bac11303
Commit
bac11303
authored
Oct 17, 2018
by
Kai Chen
Browse files
add BBoxAssigner and BBoxSampler
parent
f3768bcd
Changes
14
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
522 additions
and
468 deletions
+522
-468
configs/fast_mask_rcnn_r50_fpn_1x.py
configs/fast_mask_rcnn_r50_fpn_1x.py
+12
-10
configs/fast_rcnn_r50_fpn_1x.py
configs/fast_rcnn_r50_fpn_1x.py
+12
-10
configs/faster_rcnn_r50_fpn_1x.py
configs/faster_rcnn_r50_fpn_1x.py
+24
-19
configs/mask_rcnn_r50_fpn_1x.py
configs/mask_rcnn_r50_fpn_1x.py
+24
-19
configs/rpn_r50_fpn_1x.py
configs/rpn_r50_fpn_1x.py
+12
-9
mmdet/core/anchor/anchor_target.py
mmdet/core/anchor/anchor_target.py
+10
-17
mmdet/core/bbox/__init__.py
mmdet/core/bbox/__init__.py
+7
-8
mmdet/core/bbox/assignment.py
mmdet/core/bbox/assignment.py
+155
-0
mmdet/core/bbox/bbox_target.py
mmdet/core/bbox/bbox_target.py
+25
-25
mmdet/core/bbox/sampling.py
mmdet/core/bbox/sampling.py
+190
-306
mmdet/datasets/coco.py
mmdet/datasets/coco.py
+6
-6
mmdet/models/bbox_heads/bbox_head.py
mmdet/models/bbox_heads/bbox_head.py
+8
-4
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+32
-33
mmdet/models/mask_heads/fcn_mask_head.py
mmdet/models/mask_heads/fcn_mask_head.py
+5
-2
No files found.
configs/fast_mask_rcnn_r50_fpn_1x.py
View file @
bac11303
...
...
@@ -43,17 +43,19 @@ model = dict(
# model training and testing settings
train_cfg
=
dict
(
rcnn
=
dict
(
assigner
=
dict
(
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
min_pos_iou
=
0.5
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
num
=
512
,
pos_fraction
=
0.25
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
mask_size
=
28
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
crowd_thr
=
1.1
,
roi_batch_size
=
512
,
add_gt_as_proposals
=
True
,
pos_fraction
=
0.25
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
512
,
neg_balance_thr
=
0
,
min_pos_iou
=
0.5
,
pos_weight
=-
1
,
debug
=
False
))
test_cfg
=
dict
(
...
...
configs/fast_rcnn_r50_fpn_1x.py
View file @
bac11303
...
...
@@ -32,16 +32,18 @@ model = dict(
# model training and testing settings
train_cfg
=
dict
(
rcnn
=
dict
(
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
crowd_thr
=
1.1
,
roi_batch_size
=
512
,
add_gt_as_proposals
=
True
,
pos_fraction
=
0.25
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
512
,
neg_balance_thr
=
0
,
min_pos_iou
=
0.5
,
assigner
=
dict
(
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
min_pos_iou
=
0.5
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
num
=
512
,
pos_fraction
=
0.25
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
pos_weight
=-
1
,
debug
=
False
))
test_cfg
=
dict
(
rcnn
=
dict
(
score_thr
=
0.05
,
max_per_img
=
100
,
nms_thr
=
0.5
))
...
...
configs/faster_rcnn_r50_fpn_1x.py
View file @
bac11303
...
...
@@ -42,30 +42,35 @@ model = dict(
# model training and testing settings
train_cfg
=
dict
(
rpn
=
dict
(
pos_fraction
=
0.5
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
256
,
assigner
=
dict
(
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
min_pos_iou
=
0.3
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
num
=
256
,
pos_fraction
=
0.5
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
False
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
allowed_border
=
0
,
crowd_thr
=
1.1
,
anchor_batch_size
=
256
,
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
neg_balance_thr
=
0
,
min_pos_iou
=
0.3
,
pos_weight
=-
1
,
smoothl1_beta
=
1
/
9.0
,
debug
=
False
),
rcnn
=
dict
(
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
crowd_thr
=
1.1
,
roi_batch_size
=
512
,
add_gt_as_proposals
=
True
,
pos_fraction
=
0.25
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
512
,
neg_balance_thr
=
0
,
min_pos_iou
=
0.5
,
assigner
=
dict
(
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
min_pos_iou
=
0.5
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
num
=
512
,
pos_fraction
=
0.25
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
pos_weight
=-
1
,
debug
=
False
))
test_cfg
=
dict
(
...
...
configs/mask_rcnn_r50_fpn_1x.py
View file @
bac11303
...
...
@@ -53,31 +53,36 @@ model = dict(
# model training and testing settings
train_cfg
=
dict
(
rpn
=
dict
(
pos_fraction
=
0.5
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
256
,
assigner
=
dict
(
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
min_pos_iou
=
0.3
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
num
=
256
,
pos_fraction
=
0.5
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
False
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
allowed_border
=
0
,
crowd_thr
=
1.1
,
anchor_batch_size
=
256
,
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
neg_balance_thr
=
0
,
min_pos_iou
=
0.3
,
pos_weight
=-
1
,
smoothl1_beta
=
1
/
9.0
,
debug
=
False
),
rcnn
=
dict
(
assigner
=
dict
(
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
min_pos_iou
=
0.5
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
num
=
512
,
pos_fraction
=
0.25
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
mask_size
=
28
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
crowd_thr
=
1.1
,
roi_batch_size
=
512
,
add_gt_as_proposals
=
True
,
pos_fraction
=
0.25
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
512
,
neg_balance_thr
=
0
,
min_pos_iou
=
0.5
,
pos_weight
=-
1
,
debug
=
False
))
test_cfg
=
dict
(
...
...
configs/rpn_r50_fpn_1x.py
View file @
bac11303
...
...
@@ -27,16 +27,19 @@ model = dict(
# model training and testing settings
train_cfg
=
dict
(
rpn
=
dict
(
pos_fraction
=
0.5
,
pos_balance_sampling
=
False
,
neg_pos_ub
=
256
,
assigner
=
dict
(
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
min_pos_iou
=
0.3
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
num
=
256
,
pos_fraction
=
0.5
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
False
,
pos_balance_sampling
=
False
,
neg_balance_thr
=
0
),
allowed_border
=
0
,
crowd_thr
=
1.1
,
anchor_batch_size
=
256
,
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
neg_balance_thr
=
0
,
min_pos_iou
=
0.3
,
pos_weight
=-
1
,
smoothl1_beta
=
1
/
9.0
,
debug
=
False
))
...
...
mmdet/core/anchor/anchor_target.py
View file @
bac11303
import
torch
from
..bbox
import
bbox_
assign
,
bbox2delta
,
bbox_sampling
from
..bbox
import
assign
_and_sample
,
bbox2delta
from
..utils
import
multi_apply
...
...
@@ -80,27 +80,20 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
return
(
None
,
)
*
6
# assign gt and sample anchors
anchors
=
flat_anchors
[
inside_flags
,
:]
assigned_gt_inds
,
argmax_overlaps
,
max_overlaps
=
bbox_assign
(
anchors
,
gt_bboxes
,
pos_iou_thr
=
cfg
.
pos_iou_thr
,
neg_iou_thr
=
cfg
.
neg_iou_thr
,
min_pos_iou
=
cfg
.
min_pos_iou
)
pos_inds
,
neg_inds
=
bbox_sampling
(
assigned_gt_inds
,
cfg
.
anchor_batch_size
,
cfg
.
pos_fraction
,
cfg
.
neg_pos_ub
,
cfg
.
pos_balance_sampling
,
max_overlaps
,
cfg
.
neg_balance_thr
)
_
,
sampling_result
=
assign_and_sample
(
anchors
,
gt_bboxes
,
None
,
None
,
cfg
)
num_valid_anchors
=
anchors
.
shape
[
0
]
bbox_targets
=
torch
.
zeros_like
(
anchors
)
bbox_weights
=
torch
.
zeros_like
(
anchors
)
labels
=
torch
.
zeros_like
(
assigned_gt_inds
)
label_weights
=
torch
.
zeros_like
(
assigned_gt_inds
,
dtype
=
anchors
.
dtype
)
labels
=
anchors
.
new_zeros
((
num_valid_anchors
,
)
)
label_weights
=
anchors
.
new_zeros
((
num_valid_
anchors
,
)
)
pos_inds
=
sampling_result
.
pos_inds
neg_inds
=
sampling_result
.
neg_inds
if
len
(
pos_inds
)
>
0
:
pos_anchors
=
anchors
[
pos_inds
,
:]
pos_gt_bbox
=
gt_bboxes
[
assigned_gt_inds
[
pos_inds
]
-
1
,
:]
pos_bbox_targets
=
bbox2delta
(
pos_anchors
,
pos_gt_bbox
,
target_means
,
target_stds
)
pos_bbox_targets
=
bbox2delta
(
sampling_result
.
pos_bboxes
,
sampling_result
.
pos_gt_bboxes
,
target_means
,
target_stds
)
bbox_targets
[
pos_inds
,
:]
=
pos_bbox_targets
bbox_weights
[
pos_inds
,
:]
=
1.0
labels
[
pos_inds
]
=
1
...
...
mmdet/core/bbox/__init__.py
View file @
bac11303
from
.geometry
import
bbox_overlaps
from
.
sampling
import
(
random_choice
,
bbox_assign
,
bbox_assign_wrt_overlaps
,
bbox_sampling
,
bbox_sampling_pos
,
bbox_sampling_neg
,
sample_bboxes
)
from
.
assignment
import
BBoxAssigner
,
AssignResult
from
.sampling
import
(
BBoxSampler
,
SamplingResult
,
assign_and_sample
,
random_choice
)
from
.transforms
import
(
bbox2delta
,
delta2bbox
,
bbox_flip
,
bbox_mapping
,
bbox_mapping_back
,
bbox2roi
,
roi2bbox
,
bbox2result
)
from
.bbox_target
import
bbox_target
__all__
=
[
'bbox_overlaps'
,
'random_choice'
,
'bbox_assign'
,
'bbox_assign_wrt_overlaps'
,
'bbox_sampling'
,
'bbox_sampling_pos'
,
'bbox_sampling_neg'
,
'sample_bboxes'
,
'bbox2delta'
,
'delta2bbox'
,
'bbox_flip'
,
'bbox_mapping'
,
'bbox_mapping_back'
,
'bbox2roi'
,
'roi2bbox'
,
'bbox2result'
,
'bbox_target'
'bbox_overlaps'
,
'BBoxAssigner'
,
'AssignResult'
,
'BBoxSampler'
,
'SamplingResult'
,
'assign_and_sample'
,
'random_choice'
,
'bbox2delta'
,
'delta2bbox'
,
'bbox_flip'
,
'bbox_mapping'
,
'bbox_mapping_back'
,
'bbox2roi'
,
'roi2bbox'
,
'bbox2result'
,
'bbox_target'
]
mmdet/core/bbox/assignment.py
0 → 100644
View file @
bac11303
import
torch
from
.geometry
import
bbox_overlaps
class
BBoxAssigner
(
object
):
"""Assign a corresponding gt bbox or background to each bbox.
Each proposals will be assigned with `-1`, `0`, or a positive integer
indicating the ground truth index.
- -1: don't care
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt
Args:
pos_iou_thr (float): IoU threshold for positive bboxes.
neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
min_pos_iou (float): Minimum iou for a bbox to be considered as a
positive bbox. For RPN, it is usually set as 0.3, for Fast R-CNN,
it is usually set as pos_iou_thr
ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
`gt_bboxes_ignore` is specified). Negative values mean not
ignoring any bboxes.
"""
def
__init__
(
self
,
pos_iou_thr
,
neg_iou_thr
,
min_pos_iou
=
.
0
,
ignore_iof_thr
=-
1
):
self
.
pos_iou_thr
=
pos_iou_thr
self
.
neg_iou_thr
=
neg_iou_thr
self
.
min_pos_iou
=
min_pos_iou
self
.
ignore_iof_thr
=
ignore_iof_thr
def
assign
(
self
,
bboxes
,
gt_bboxes
,
gt_bboxes_ignore
=
None
,
gt_labels
=
None
):
"""Assign gt to bboxes.
This method assign a gt bbox to every bbox (proposal/anchor), each bbox
will be assigned with -1, 0, or a positive number. -1 means don't care,
0 means negative sample, positive number is the index (1-based) of
assigned gt.
The assignment is done in following steps, the order matters.
1. assign every bbox to -1
2. assign proposals whose iou with all gts < neg_iou_thr to 0
3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
assign it to that bbox
4. for each gt bbox, assign its nearest proposals (may be more than
one) to itself
Args:
bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
labelled as `ignored`, e.g., crowd boxes in COCO.
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
Returns:
:obj:`AssignResult`: The assign result.
"""
if
bboxes
.
shape
[
0
]
==
0
or
gt_bboxes
.
shape
[
0
]
==
0
:
raise
ValueError
(
'No gt or bboxes'
)
bboxes
=
bboxes
[:,
:
4
]
overlaps
=
bbox_overlaps
(
bboxes
,
gt_bboxes
)
if
(
self
.
ignore_iof_thr
>
0
)
and
(
gt_bboxes_ignore
is
not
None
)
and
(
gt_bboxes_ignore
.
numel
()
>
0
):
ignore_overlaps
=
bbox_overlaps
(
bboxes
,
gt_bboxes_ignore
,
mode
=
'iof'
)
ignore_max_overlaps
,
_
=
ignore_overlaps
.
max
(
dim
=
1
)
ignore_bboxes_inds
=
torch
.
nonzero
(
ignore_max_overlaps
>
self
.
ignore_iof_thr
).
squeeze
()
if
ignore_bboxes_inds
.
numel
()
>
0
:
overlaps
[
ignore_bboxes_inds
[:,
0
],
:]
=
-
1
assign_result
=
self
.
assign_wrt_overlaps
(
overlaps
,
gt_labels
)
return
assign_result
def
assign_wrt_overlaps
(
self
,
overlaps
,
gt_labels
=
None
):
"""Assign w.r.t. the overlaps of bboxes with gts.
Args:
overlaps (Tensor): Overlaps between n bboxes and k gt_bboxes,
shape(n, k).
gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
Returns:
:obj:`AssignResult`: The assign result.
"""
if
overlaps
.
numel
()
==
0
:
raise
ValueError
(
'No gt or proposals'
)
num_bboxes
,
num_gts
=
overlaps
.
size
(
0
),
overlaps
.
size
(
1
)
# 1. assign -1 by default
assigned_gt_inds
=
overlaps
.
new_full
(
(
num_bboxes
,
),
-
1
,
dtype
=
torch
.
long
)
assert
overlaps
.
size
()
==
(
num_bboxes
,
num_gts
)
# for each anchor, which gt best overlaps with it
# for each anchor, the max iou of all gts
max_overlaps
,
argmax_overlaps
=
overlaps
.
max
(
dim
=
1
)
# for each gt, which anchor best overlaps with it
# for each gt, the max iou of all proposals
gt_max_overlaps
,
gt_argmax_overlaps
=
overlaps
.
max
(
dim
=
0
)
# 2. assign negative: below
if
isinstance
(
self
.
neg_iou_thr
,
float
):
assigned_gt_inds
[(
max_overlaps
>=
0
)
&
(
max_overlaps
<
self
.
neg_iou_thr
)]
=
0
elif
isinstance
(
self
.
neg_iou_thr
,
tuple
):
assert
len
(
self
.
neg_iou_thr
)
==
2
assigned_gt_inds
[(
max_overlaps
>=
self
.
neg_iou_thr
[
0
])
&
(
max_overlaps
<
self
.
neg_iou_thr
[
1
])]
=
0
# 3. assign positive: above positive IoU threshold
pos_inds
=
max_overlaps
>=
self
.
pos_iou_thr
assigned_gt_inds
[
pos_inds
]
=
argmax_overlaps
[
pos_inds
]
+
1
# 4. assign fg: for each gt, proposals with highest IoU
for
i
in
range
(
num_gts
):
if
gt_max_overlaps
[
i
]
>=
self
.
min_pos_iou
:
assigned_gt_inds
[
overlaps
[:,
i
]
==
gt_max_overlaps
[
i
]]
=
i
+
1
if
gt_labels
is
not
None
:
assigned_labels
=
assigned_gt_inds
.
new_zeros
((
num_bboxes
,
))
pos_inds
=
torch
.
nonzero
(
assigned_gt_inds
>
0
).
squeeze
()
if
pos_inds
.
numel
()
>
0
:
assigned_labels
[
pos_inds
]
=
gt_labels
[
assigned_gt_inds
[
pos_inds
]
-
1
]
else
:
assigned_labels
=
None
return
AssignResult
(
num_gts
,
assigned_gt_inds
,
max_overlaps
,
labels
=
assigned_labels
)
class
AssignResult
(
object
):
def
__init__
(
self
,
num_gts
,
gt_inds
,
max_overlaps
,
labels
=
None
):
self
.
num_gts
=
num_gts
self
.
gt_inds
=
gt_inds
self
.
max_overlaps
=
max_overlaps
self
.
labels
=
labels
def
add_gt_
(
self
,
gt_labels
):
self_inds
=
torch
.
arange
(
1
,
len
(
gt_labels
)
+
1
,
dtype
=
torch
.
long
,
device
=
gt_labels
.
device
)
self
.
gt_inds
=
torch
.
cat
([
self_inds
,
self
.
gt_inds
])
self
.
max_overlaps
=
torch
.
cat
(
[
self
.
max_overlaps
.
new_ones
(
self
.
num_gts
),
self
.
max_overlaps
])
if
self
.
labels
is
not
None
:
self
.
labels
=
torch
.
cat
([
gt_labels
,
self
.
labels
])
mmdet/core/bbox/bbox_target.py
View file @
bac11303
...
...
@@ -4,23 +4,23 @@ from .transforms import bbox2delta
from
..utils
import
multi_apply
def
bbox_target
(
pos_
proposal
s_list
,
neg_
proposal
s_list
,
def
bbox_target
(
pos_
bboxe
s_list
,
neg_
bboxe
s_list
,
pos_gt_bboxes_list
,
pos_gt_labels_list
,
cfg
,
reg_
num_
classes
=
1
,
reg_classes
=
1
,
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
],
concat
=
True
):
labels
,
label_weights
,
bbox_targets
,
bbox_weights
=
multi_apply
(
proposal
_target_single
,
pos_
proposal
s_list
,
neg_
proposal
s_list
,
bbox
_target_single
,
pos_
bboxe
s_list
,
neg_
bboxe
s_list
,
pos_gt_bboxes_list
,
pos_gt_labels_list
,
cfg
=
cfg
,
reg_
num_
classes
=
reg_
num_
classes
,
reg_classes
=
reg_classes
,
target_means
=
target_means
,
target_stds
=
target_stds
)
...
...
@@ -32,34 +32,34 @@ def bbox_target(pos_proposals_list,
return
labels
,
label_weights
,
bbox_targets
,
bbox_weights
def
proposal
_target_single
(
pos_
proposal
s
,
neg_proposal
s
,
pos_gt_bboxes
,
pos_gt_labels
,
cfg
,
reg_num
_classes
=
1
,
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
]):
num_pos
=
pos_
proposal
s
.
size
(
0
)
num_neg
=
neg_
proposal
s
.
size
(
0
)
def
bbox
_target_single
(
pos_
bboxe
s
,
neg_bboxe
s
,
pos_gt_bboxes
,
pos_gt_labels
,
cfg
,
reg
_classes
=
1
,
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
]):
num_pos
=
pos_
bboxe
s
.
size
(
0
)
num_neg
=
neg_
bboxe
s
.
size
(
0
)
num_samples
=
num_pos
+
num_neg
labels
=
pos_
proposal
s
.
new_zeros
(
num_samples
,
dtype
=
torch
.
long
)
label_weights
=
pos_
proposal
s
.
new_zeros
(
num_samples
)
bbox_targets
=
pos_
proposal
s
.
new_zeros
(
num_samples
,
4
)
bbox_weights
=
pos_
proposal
s
.
new_zeros
(
num_samples
,
4
)
labels
=
pos_
bboxe
s
.
new_zeros
(
num_samples
,
dtype
=
torch
.
long
)
label_weights
=
pos_
bboxe
s
.
new_zeros
(
num_samples
)
bbox_targets
=
pos_
bboxe
s
.
new_zeros
(
num_samples
,
4
)
bbox_weights
=
pos_
bboxe
s
.
new_zeros
(
num_samples
,
4
)
if
num_pos
>
0
:
labels
[:
num_pos
]
=
pos_gt_labels
pos_weight
=
1.0
if
cfg
.
pos_weight
<=
0
else
cfg
.
pos_weight
label_weights
[:
num_pos
]
=
pos_weight
pos_bbox_targets
=
bbox2delta
(
pos_
proposal
s
,
pos_gt_bboxes
,
target_means
,
target_stds
)
pos_bbox_targets
=
bbox2delta
(
pos_
bboxe
s
,
pos_gt_bboxes
,
target_means
,
target_stds
)
bbox_targets
[:
num_pos
,
:]
=
pos_bbox_targets
bbox_weights
[:
num_pos
,
:]
=
1
if
num_neg
>
0
:
label_weights
[
-
num_neg
:]
=
1.0
if
reg_
num_
classes
>
1
:
if
reg_classes
>
1
:
bbox_targets
,
bbox_weights
=
expand_target
(
bbox_targets
,
bbox_weights
,
labels
,
reg_
num_
classes
)
labels
,
reg_classes
)
return
labels
,
label_weights
,
bbox_targets
,
bbox_weights
...
...
mmdet/core/bbox/sampling.py
View file @
bac11303
This diff is collapsed.
Click to expand it.
mmdet/datasets/coco.py
View file @
bac11303
...
...
@@ -215,7 +215,7 @@ class CocoDataset(Dataset):
'proposals should have shapes (n, 4) or (n, 5), '
'but found {}'
.
format
(
proposals
.
shape
))
if
proposals
.
shape
[
1
]
==
5
:
scores
=
proposals
[:,
4
]
scores
=
proposals
[:,
4
,
None
]
proposals
=
proposals
[:,
:
4
]
else
:
scores
=
None
...
...
@@ -237,8 +237,8 @@ class CocoDataset(Dataset):
if
self
.
proposals
is
not
None
:
proposals
=
self
.
bbox_transform
(
proposals
,
img_shape
,
scale_factor
,
flip
)
proposals
=
np
.
hstack
(
[
proposals
,
scores
[:,
None
]
])
if
scores
is
not
None
else
proposals
proposals
=
np
.
hstack
(
[
proposals
,
scores
])
if
scores
is
not
None
else
proposals
gt_bboxes
=
self
.
bbox_transform
(
gt_bboxes
,
img_shape
,
scale_factor
,
flip
)
gt_bboxes_ignore
=
self
.
bbox_transform
(
gt_bboxes_ignore
,
img_shape
,
...
...
@@ -295,14 +295,14 @@ class CocoDataset(Dataset):
flip
=
flip
)
if
proposal
is
not
None
:
if
proposal
.
shape
[
1
]
==
5
:
score
=
proposal
[:,
4
]
score
=
proposal
[:,
4
,
None
]
proposal
=
proposal
[:,
:
4
]
else
:
score
=
None
_proposal
=
self
.
bbox_transform
(
proposal
,
img_shape
,
scale_factor
,
flip
)
_proposal
=
np
.
hstack
(
[
_proposal
,
score
[:,
None
]
])
if
score
is
not
None
else
_proposal
_proposal
=
np
.
hstack
(
[
_proposal
,
score
])
if
score
is
not
None
else
_proposal
_proposal
=
to_tensor
(
_proposal
)
else
:
_proposal
=
None
...
...
mmdet/models/bbox_heads/bbox_head.py
View file @
bac11303
...
...
@@ -59,16 +59,20 @@ class BBoxHead(nn.Module):
bbox_pred
=
self
.
fc_reg
(
x
)
if
self
.
with_reg
else
None
return
cls_score
,
bbox_pred
def
get_bbox_target
(
self
,
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
rcnn_train_cfg
):
reg_num_classes
=
1
if
self
.
reg_class_agnostic
else
self
.
num_classes
def
get_target
(
self
,
sampling_results
,
gt_bboxes
,
gt_labels
,
rcnn_train_cfg
):
pos_proposals
=
[
res
.
pos_bboxes
for
res
in
sampling_results
]
neg_proposals
=
[
res
.
neg_bboxes
for
res
in
sampling_results
]
pos_gt_bboxes
=
[
res
.
pos_gt_bboxes
for
res
in
sampling_results
]
pos_gt_labels
=
[
res
.
pos_gt_labels
for
res
in
sampling_results
]
reg_classes
=
1
if
self
.
reg_class_agnostic
else
self
.
num_classes
cls_reg_targets
=
bbox_target
(
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
rcnn_train_cfg
,
reg_
num_
classes
,
reg_classes
,
target_means
=
self
.
target_means
,
target_stds
=
self
.
target_stds
)
return
cls_reg_targets
...
...
mmdet/models/detectors/two_stage.py
View file @
bac11303
...
...
@@ -4,7 +4,7 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
..
import
builder
from
mmdet.core
import
sample_bboxes
,
bbox2roi
,
bbox2result
,
multi_apply
from
mmdet.core
import
(
assign_and_sample
,
bbox2roi
,
bbox2result
,
multi_apply
)
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
...
...
@@ -80,10 +80,11 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
gt_labels
,
gt_masks
=
None
,
proposals
=
None
):
losses
=
dict
()
x
=
self
.
extract_feat
(
img
)
losses
=
dict
()
# RPN forward and loss
if
self
.
with_rpn
:
rpn_outs
=
self
.
rpn_head
(
x
)
rpn_loss_inputs
=
rpn_outs
+
(
gt_bboxes
,
img_meta
,
...
...
@@ -96,44 +97,43 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
else
:
proposal_list
=
proposals
# assign gts and sample proposals
if
self
.
with_bbox
or
self
.
with_mask
:
assign_results
,
sampling_results
=
multi_apply
(
assign_and_sample
,
proposal_list
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
=
self
.
train_cfg
.
rcnn
)
# bbox head forward and loss
if
self
.
with_bbox
:
(
pos_proposals
,
neg_proposals
,
pos_assigned_gt_inds
,
pos_gt_bboxes
,
pos_gt_labels
)
=
multi_apply
(
sample_bboxes
,
proposal_list
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
,
cfg
=
self
.
train_cfg
.
rcnn
)
(
labels
,
label_weights
,
bbox_targets
,
bbox_weights
)
=
self
.
bbox_head
.
get_bbox_target
(
pos_proposals
,
neg_proposals
,
pos_gt_bboxes
,
pos_gt_labels
,
self
.
train_cfg
.
rcnn
)
rois
=
bbox2roi
([
torch
.
cat
([
pos
,
neg
],
dim
=
0
)
for
pos
,
neg
in
zip
(
pos_proposals
,
neg_proposals
)
])
# TODO: a more flexible way to configurate feat maps
roi_feats
=
self
.
bbox_roi_extractor
(
rois
=
bbox2roi
([
res
.
bboxes
for
res
in
sampling_results
])
# TODO: a more flexible way to decide which feature maps to use
bbox_feats
=
self
.
bbox_roi_extractor
(
x
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
bbox_pred
=
self
.
bbox_head
(
roi
_feats
)
cls_score
,
bbox_pred
=
self
.
bbox_head
(
bbox
_feats
)
loss_bbox
=
self
.
bbox_head
.
loss
(
cls_score
,
bbox_pred
,
labels
,
label_weights
,
bbox_targets
,
bbox_weights
)
bbox_targets
=
self
.
bbox_head
.
get_target
(
sampling_results
,
gt_bboxes
,
gt_labels
,
self
.
train_cfg
.
rcnn
)
loss_bbox
=
self
.
bbox_head
.
loss
(
cls_score
,
bbox_pred
,
*
bbox_targets
)
losses
.
update
(
loss_bbox
)
# mask head forward and loss
if
self
.
with_mask
:
mask_targets
=
self
.
mask_head
.
get_mask_target
(
pos_proposals
,
pos_assigned_gt_inds
,
gt_masks
,
self
.
train_cfg
.
rcnn
)
pos_rois
=
bbox2roi
(
pos_proposals
)
pos_rois
=
bbox2roi
([
res
.
pos_bboxes
for
res
in
sampling_results
])
mask_feats
=
self
.
mask_roi_extractor
(
x
[:
self
.
mask_roi_extractor
.
num_inputs
],
pos_rois
)
mask_pred
=
self
.
mask_head
(
mask_feats
)
mask_targets
=
self
.
mask_head
.
get_target
(
sampling_results
,
gt_masks
,
self
.
train_cfg
.
rcnn
)
pos_labels
=
torch
.
cat
(
[
res
.
pos_gt_labels
for
res
in
sampling_results
])
loss_mask
=
self
.
mask_head
.
loss
(
mask_pred
,
mask_targets
,
torch
.
cat
(
pos_
gt_
labels
)
)
pos_labels
)
losses
.
update
(
loss_mask
)
return
losses
...
...
@@ -145,8 +145,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
x
=
self
.
extract_feat
(
img
)
proposal_list
=
self
.
simple_test_rpn
(
x
,
img_meta
,
self
.
test_cfg
.
rpn
)
if
proposals
is
None
else
proposals
x
,
img_meta
,
self
.
test_cfg
.
rpn
)
if
proposals
is
None
else
proposals
det_bboxes
,
det_labels
=
self
.
simple_test_bboxes
(
x
,
img_meta
,
proposal_list
,
self
.
test_cfg
.
rcnn
,
rescale
=
rescale
)
...
...
mmdet/models/mask_heads/fcn_mask_head.py
View file @
bac11303
...
...
@@ -86,8 +86,11 @@ class FCNMaskHead(nn.Module):
mask_pred
=
self
.
conv_logits
(
x
)
return
mask_pred
def
get_mask_target
(
self
,
pos_proposals
,
pos_assigned_gt_inds
,
gt_masks
,
rcnn_train_cfg
):
def
get_target
(
self
,
sampling_results
,
gt_masks
,
rcnn_train_cfg
):
pos_proposals
=
[
res
.
pos_bboxes
for
res
in
sampling_results
]
pos_assigned_gt_inds
=
[
res
.
pos_assigned_gt_inds
for
res
in
sampling_results
]
mask_targets
=
mask_target
(
pos_proposals
,
pos_assigned_gt_inds
,
gt_masks
,
rcnn_train_cfg
)
return
mask_targets
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment