Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
6a363603
Commit
6a363603
authored
Dec 11, 2018
by
yhcao6
Browse files
reuse hard mining code
parent
09c3bc4c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
33 deletions
+25
-33
mmdet/core/bbox/samplers/ohem_sampler.py
mmdet/core/bbox/samplers/ohem_sampler.py
+24
-32
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+1
-1
No files found.
mmdet/core/bbox/samplers/ohem_sampler.py
View file @
6a363603
...
@@ -21,34 +21,39 @@ class OHEMSampler(BaseSampler):
...
@@ -21,34 +21,39 @@ class OHEMSampler(BaseSampler):
self
.
bbox_roi_extractor
=
bbox_roi_extractor
self
.
bbox_roi_extractor
=
bbox_roi_extractor
self
.
bbox_head
=
bbox_head
self
.
bbox_head
=
bbox_head
def
hard_mining
(
self
,
gallery
,
assign_result
,
num_expected
,
bboxes
,
feats
):
# hard mining from the gallery.
with
torch
.
no_grad
():
rois
=
bbox2roi
([
bboxes
[
gallery
]])
bbox_feats
=
self
.
bbox_roi_extractor
(
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
loss
=
self
.
bbox_head
.
loss
(
cls_score
=
cls_score
,
bbox_pred
=
None
,
labels
=
assign_result
.
labels
[
gallery
],
label_weights
=
cls_score
.
new_ones
(
cls_score
.
size
(
0
)),
bbox_targets
=
None
,
bbox_weights
=
None
,
reduce
=
False
)[
'loss_cls'
]
_
,
topk_loss_inds
=
loss
.
topk
(
num_expected
)
return
gallery
[
topk_loss_inds
]
def
_sample_pos
(
self
,
def
_sample_pos
(
self
,
assign_result
,
assign_result
,
num_expected
,
num_expected
,
bboxes
=
None
,
bboxes
=
None
,
feats
=
None
,
feats
=
None
,
**
kwargs
):
**
kwargs
):
"""
Hard sample some positive samples
."""
#
Hard sample some positive samples
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
pos_inds
=
pos_inds
.
squeeze
(
1
)
if
pos_inds
.
numel
()
<=
num_expected
:
if
pos_inds
.
numel
()
<=
num_expected
:
return
pos_inds
return
pos_inds
else
:
else
:
with
torch
.
no_grad
():
return
self
.
hard_mining
(
pos_inds
,
assign_result
,
num_expected
,
rois
=
bbox2roi
([
bboxes
[
pos_inds
]])
bboxes
,
feats
)
bbox_feats
=
self
.
bbox_roi_extractor
(
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
loss_pos
=
self
.
bbox_head
.
loss
(
cls_score
=
cls_score
,
bbox_pred
=
None
,
labels
=
assign_result
.
labels
[
pos_inds
],
label_weights
=
cls_score
.
new_ones
(
cls_score
.
size
(
0
)),
bbox_targets
=
None
,
bbox_weights
=
None
,
reduce
=
False
)[
'loss_cls'
]
_
,
topk_loss_pos_inds
=
loss_pos
.
topk
(
num_expected
)
return
pos_inds
[
topk_loss_pos_inds
]
def
_sample_neg
(
self
,
def
_sample_neg
(
self
,
assign_result
,
assign_result
,
...
@@ -56,25 +61,12 @@ class OHEMSampler(BaseSampler):
...
@@ -56,25 +61,12 @@ class OHEMSampler(BaseSampler):
bboxes
=
None
,
bboxes
=
None
,
feats
=
None
,
feats
=
None
,
**
kwargs
):
**
kwargs
):
"""
Hard sample some negative samples
."""
#
Hard sample some negative samples
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
neg_inds
=
neg_inds
.
squeeze
(
1
)
if
len
(
neg_inds
)
<=
num_expected
:
if
len
(
neg_inds
)
<=
num_expected
:
return
neg_inds
return
neg_inds
else
:
else
:
with
torch
.
no_grad
():
return
self
.
hard_mining
(
neg_inds
,
assign_result
,
num_expected
,
rois
=
bbox2roi
([
bboxes
[
neg_inds
]])
bboxes
,
feats
)
bbox_feats
=
self
.
bbox_roi_extractor
(
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
loss_neg
=
self
.
bbox_head
.
loss
(
cls_score
=
cls_score
,
bbox_pred
=
None
,
labels
=
assign_result
.
labels
[
neg_inds
],
label_weights
=
cls_score
.
new_ones
(
cls_score
.
size
(
0
)),
bbox_targets
=
None
,
bbox_weights
=
None
,
reduce
=
False
)[
'loss_cls'
]
_
,
topk_loss_neg_inds
=
loss_neg
.
topk
(
num_expected
)
return
neg_inds
[
topk_loss_neg_inds
]
mmdet/models/detectors/two_stage.py
View file @
6a363603
...
@@ -4,7 +4,7 @@ import torch.nn as nn
...
@@ -4,7 +4,7 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
..
import
builder
from
..
import
builder
from
mmdet.core
import
(
bbox2roi
,
bbox2result
,
build_assigner
,
build_sampler
)
from
mmdet.core
import
bbox2roi
,
bbox2result
,
build_assigner
,
build_sampler
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment