Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
f9b31893
Commit
f9b31893
authored
Dec 10, 2018
by
yhcao6
Browse files
refactor
parent
763153dc
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
53 additions
and
77 deletions
+53
-77
mmdet/core/bbox/samplers/base_sampler.py
mmdet/core/bbox/samplers/base_sampler.py
+7
-5
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
+1
-1
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
+1
-1
mmdet/core/bbox/samplers/ohem_sampler.py
mmdet/core/bbox/samplers/ohem_sampler.py
+37
-64
mmdet/core/bbox/samplers/random_sampler.py
mmdet/core/bbox/samplers/random_sampler.py
+2
-2
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+5
-4
No files found.
mmdet/core/bbox/samplers/base_sampler.py
View file @
f9b31893
...
...
@@ -12,14 +12,15 @@ class BaseSampler(metaclass=ABCMeta):
self
.
neg_sampler
=
self
@
abstractmethod
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
pass
@
abstractmethod
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
pass
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
=
None
):
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
=
None
,
**
kwargs
):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
...
...
@@ -44,8 +45,9 @@ class BaseSampler(metaclass=ABCMeta):
gt_flags
=
torch
.
cat
([
gt_ones
,
gt_flags
])
num_expected_pos
=
int
(
self
.
num
*
self
.
pos_fraction
)
kwargs
.
update
(
dict
(
bboxes
=
bboxes
))
pos_inds
=
self
.
pos_sampler
.
_sample_pos
(
assign_result
,
num_expected_pos
)
num_expected_pos
,
**
kwargs
)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds
=
pos_inds
.
unique
()
...
...
@@ -57,7 +59,7 @@ class BaseSampler(metaclass=ABCMeta):
if
num_expected_neg
>
neg_upper_bound
:
num_expected_neg
=
neg_upper_bound
neg_inds
=
self
.
neg_sampler
.
_sample_neg
(
assign_result
,
num_expected_neg
)
num_expected_neg
,
**
kwargs
)
neg_inds
=
neg_inds
.
unique
()
return
SamplingResult
(
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
...
...
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
View file @
f9b31893
...
...
@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
class
InstanceBalancedPosSampler
(
RandomSampler
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
):
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
...
...
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
View file @
f9b31893
...
...
@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
self
.
hard_thr
=
hard_thr
self
.
hard_fraction
=
hard_fraction
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
):
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
...
...
mmdet/core/bbox/samplers/ohem_sampler.py
View file @
f9b31893
...
...
@@ -2,7 +2,6 @@ import torch
from
.base_sampler
import
BaseSampler
from
..transforms
import
bbox2roi
from
.sampling_result
import
SamplingResult
class
OHEMSampler
(
BaseSampler
):
...
...
@@ -11,14 +10,19 @@ class OHEMSampler(BaseSampler):
num
,
pos_fraction
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,):
add_gt_as_proposals
=
True
,
bbox_roi_extractor
=
None
,
bbox_head
=
None
):
super
(
OHEMSampler
,
self
).
__init__
()
self
.
num
=
num
self
.
pos_fraction
=
pos_fraction
self
.
neg_pos_ub
=
neg_pos_ub
self
.
add_gt_as_proposals
=
add_gt_as_proposals
self
.
bbox_roi_extractor
=
bbox_roi_extractor
self
.
bbox_head
=
bbox_head
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
loss_all
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
,
feats
=
None
):
"""Hard sample some positive samples."""
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
...
...
@@ -26,10 +30,24 @@ class OHEMSampler(BaseSampler):
if
pos_inds
.
numel
()
<=
num_expected
:
return
pos_inds
else
:
_
,
topk_loss_pos_inds
=
loss_all
[
pos_inds
].
topk
(
num_expected
)
with
torch
.
no_grad
():
rois
=
bbox2roi
([
bboxes
[
pos_inds
]])
bbox_feats
=
self
.
bbox_roi_extractor
(
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
loss_all
=
self
.
bbox_head
.
loss
(
cls_score
=
cls_score
,
bbox_pred
=
None
,
labels
=
assign_result
.
labels
[
pos_inds
],
label_weights
=
cls_score
.
new_ones
(
cls_score
.
size
(
0
)),
bbox_targets
=
None
,
bbox_weights
=
None
,
reduction
=
'none'
)[
'loss_cls'
]
_
,
topk_loss_pos_inds
=
loss_all
.
topk
(
num_expected
)
return
pos_inds
[
topk_loss_pos_inds
]
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
loss_all
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
,
feats
=
None
):
"""Hard sample some negative samples."""
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
...
...
@@ -37,63 +55,18 @@ class OHEMSampler(BaseSampler):
if
len
(
neg_inds
)
<=
num_expected
:
return
neg_inds
else
:
_
,
topk_loss_neg_inds
=
loss_all
[
neg_inds
].
topk
(
num_expected
)
with
torch
.
no_grad
():
rois
=
bbox2roi
([
bboxes
[
neg_inds
]])
bbox_feats
=
self
.
bbox_roi_extractor
(
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
loss_all
=
self
.
bbox_head
.
loss
(
cls_score
=
cls_score
,
bbox_pred
=
None
,
labels
=
assign_result
.
labels
[
neg_inds
],
label_weights
=
cls_score
.
new_ones
(
cls_score
.
size
(
0
)),
bbox_targets
=
None
,
bbox_weights
=
None
,
reduction
=
'none'
)[
'loss_cls'
]
_
,
topk_loss_neg_inds
=
loss_all
.
topk
(
num_expected
)
return
neg_inds
[
topk_loss_neg_inds
]
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
=
None
,
feats
=
None
,
bbox_roi_extractor
=
None
,
bbox_head
=
None
):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
assigning results and ground truth bboxes.
Args:
assign_result (:obj:`AssignResult`): Bbox assigning results.
bboxes (Tensor): Boxes to be sampled from.
gt_bboxes (Tensor): Ground truth bboxes.
gt_labels (Tensor, optional): Class labels of ground truth bboxes.
Returns:
:obj:`SamplingResult`: Sampling result.
"""
bboxes
=
bboxes
[:,
:
4
]
gt_flags
=
bboxes
.
new_zeros
((
bboxes
.
shape
[
0
],
),
dtype
=
torch
.
uint8
)
if
self
.
add_gt_as_proposals
:
bboxes
=
torch
.
cat
([
gt_bboxes
,
bboxes
],
dim
=
0
)
assign_result
.
add_gt_
(
gt_labels
)
gt_ones
=
bboxes
.
new_ones
(
gt_bboxes
.
shape
[
0
],
dtype
=
torch
.
uint8
)
gt_flags
=
torch
.
cat
([
gt_ones
,
gt_flags
])
# calculate loss of all samples used for hard mining
with
torch
.
no_grad
():
rois
=
bbox2roi
([
bboxes
])
bbox_feats
=
bbox_roi_extractor
(
feats
[:
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
_
=
bbox_head
(
bbox_feats
)
loss_all
=
bbox_head
.
loss
(
cls_score
=
cls_score
,
bbox_pred
=
None
,
labels
=
assign_result
.
labels
,
label_weights
=
cls_score
.
new_ones
(
cls_score
.
size
(
0
)),
bbox_targets
=
None
,
bbox_weights
=
None
,
reduction
=
'none'
)[
'loss_cls'
]
num_expected_pos
=
int
(
self
.
num
*
self
.
pos_fraction
)
pos_inds
=
self
.
_sample_pos
(
assign_result
,
num_expected_pos
,
loss_all
)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds
=
pos_inds
.
unique
()
num_sampled_pos
=
pos_inds
.
numel
()
num_expected_neg
=
self
.
num
-
num_sampled_pos
if
self
.
neg_pos_ub
>=
0
:
_pos
=
max
(
1
,
num_sampled_pos
)
neg_upper_bound
=
int
(
self
.
neg_pos_ub
*
_pos
)
if
num_expected_neg
>
neg_upper_bound
:
num_expected_neg
=
neg_upper_bound
neg_inds
=
self
.
_sample_neg
(
assign_result
,
num_expected_neg
,
loss_all
)
neg_inds
=
neg_inds
.
unique
()
return
SamplingResult
(
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
assign_result
,
gt_flags
)
mmdet/core/bbox/samplers/random_sampler.py
View file @
f9b31893
...
...
@@ -34,7 +34,7 @@ class RandomSampler(BaseSampler):
rand_inds
=
torch
.
from_numpy
(
rand_inds
).
long
().
to
(
gallery
.
device
)
return
gallery
[
rand_inds
]
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
):
"""Randomly sample some positive samples."""
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
...
...
@@ -44,7 +44,7 @@ class RandomSampler(BaseSampler):
else
:
return
self
.
random_choice
(
pos_inds
,
num_expected
)
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
):
"""Randomly sample some negative samples."""
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
...
...
mmdet/models/detectors/two_stage.py
View file @
f9b31893
...
...
@@ -103,7 +103,10 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# assign gts and sample proposals
if
self
.
with_bbox
or
self
.
with_mask
:
bbox_assigner
=
build_assigner
(
self
.
train_cfg
.
rcnn
.
assigner
)
bbox_sampler
=
build_sampler
(
self
.
train_cfg
.
rcnn
.
sampler
)
bbox_sampler
=
build_sampler
(
self
.
train_cfg
.
rcnn
.
sampler
,
dict
(
bbox_roi_extractor
=
self
.
bbox_roi_extractor
,
bbox_head
=
self
.
bbox_head
))
num_imgs
=
img
.
size
(
0
)
assign_results
=
[]
sampling_results
=
[]
...
...
@@ -117,9 +120,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
proposal_list
[
i
],
gt_bboxes
[
i
],
gt_labels
[
i
],
[
xx
[
i
][
None
]
for
xx
in
x
],
self
.
bbox_roi_extractor
,
self
.
bbox_head
)
feats
=
[
xx
[
i
][
None
]
for
xx
in
x
])
else
:
sampling_result
=
bbox_sampler
.
sample
(
assign_result
,
proposal_list
[
i
],
gt_bboxes
[
i
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment