Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
83644927
Commit
83644927
authored
Dec 11, 2018
by
yhcao6
Browse files
resort formal parameters order of hard mining
parent
d2cde908
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
8 deletions
+8
-8
mmdet/core/bbox/samplers/ohem_sampler.py
mmdet/core/bbox/samplers/ohem_sampler.py
+8
-8
No files found.
mmdet/core/bbox/samplers/ohem_sampler.py
View file @
83644927
...
@@ -21,23 +21,23 @@ class OHEMSampler(BaseSampler):
...
@@ -21,23 +21,23 @@ class OHEMSampler(BaseSampler):
self
.
bbox_roi_extractor
=
bbox_roi_extractor
self
.
bbox_roi_extractor
=
bbox_roi_extractor
self
.
bbox_head
=
bbox_head
self
.
bbox_head
=
bbox_head
def
hard_mining
(
self
,
gallery
,
assign_result
,
num_expected
,
bboxes
,
feats
):
def
hard_mining
(
self
,
inds
,
num_expected
,
bboxes
,
labels
,
feats
):
# hard mining from the gallery.
# hard mining from the gallery.
with
torch
.
no_grad
():
with
torch
.
no_grad
():
rois
=
bbox2roi
([
bboxes
[
gallery
]
])
rois
=
bbox2roi
([
bboxes
])
bbox_feats
=
self
.
bbox_roi_extractor
(
bbox_feats
=
self
.
bbox_roi_extractor
(
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
loss
=
self
.
bbox_head
.
loss
(
loss
=
self
.
bbox_head
.
loss
(
cls_score
=
cls_score
,
cls_score
=
cls_score
,
bbox_pred
=
None
,
bbox_pred
=
None
,
labels
=
assign_result
.
labels
[
gallery
]
,
labels
=
labels
,
label_weights
=
cls_score
.
new_ones
(
cls_score
.
size
(
0
)),
label_weights
=
cls_score
.
new_ones
(
cls_score
.
size
(
0
)),
bbox_targets
=
None
,
bbox_targets
=
None
,
bbox_weights
=
None
,
bbox_weights
=
None
,
reduce
=
False
)[
'loss_cls'
]
reduce
=
False
)[
'loss_cls'
]
_
,
topk_loss_inds
=
loss
.
topk
(
num_expected
)
_
,
topk_loss_inds
=
loss
.
topk
(
num_expected
)
return
gallery
[
topk_loss_inds
]
return
inds
[
topk_loss_inds
]
def
_sample_pos
(
self
,
def
_sample_pos
(
self
,
assign_result
,
assign_result
,
...
@@ -52,8 +52,8 @@ class OHEMSampler(BaseSampler):
...
@@ -52,8 +52,8 @@ class OHEMSampler(BaseSampler):
if
pos_inds
.
numel
()
<=
num_expected
:
if
pos_inds
.
numel
()
<=
num_expected
:
return
pos_inds
return
pos_inds
else
:
else
:
return
self
.
hard_mining
(
pos_inds
,
assign_result
,
num_expected
,
return
self
.
hard_mining
(
pos_inds
,
num_expected
,
bboxes
[
pos_inds
]
,
bboxes
,
feats
)
assign_result
.
labels
[
pos_inds
]
,
feats
)
def
_sample_neg
(
self
,
def
_sample_neg
(
self
,
assign_result
,
assign_result
,
...
@@ -68,5 +68,5 @@ class OHEMSampler(BaseSampler):
...
@@ -68,5 +68,5 @@ class OHEMSampler(BaseSampler):
if
len
(
neg_inds
)
<=
num_expected
:
if
len
(
neg_inds
)
<=
num_expected
:
return
neg_inds
return
neg_inds
else
:
else
:
return
self
.
hard_mining
(
neg_inds
,
assign_result
,
num_expected
,
return
self
.
hard_mining
(
neg_inds
,
num_expected
,
bboxes
[
neg_inds
]
,
bboxes
,
feats
)
assign_result
.
labels
[
neg_inds
]
,
feats
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment