Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
092b97f6
Commit
092b97f6
authored
Dec 10, 2018
by
yhcao6
Browse files
add kwargs to sample_pos, sample_neg
parent
f9b31893
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
16 additions
and
21 deletions
+16
-21
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
+1
-1
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
+1
-1
mmdet/core/bbox/samplers/ohem_sampler.py
mmdet/core/bbox/samplers/ohem_sampler.py
+6
-6
mmdet/core/bbox/samplers/random_sampler.py
mmdet/core/bbox/samplers/random_sampler.py
+2
-2
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+6
-11
No files found.
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
View file @
092b97f6
...
@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
...
@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
class
InstanceBalancedPosSampler
(
RandomSampler
):
class
InstanceBalancedPosSampler
(
RandomSampler
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
pos_inds
=
pos_inds
.
squeeze
(
1
)
...
...
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
View file @
092b97f6
...
@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
...
@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
self
.
hard_thr
=
hard_thr
self
.
hard_thr
=
hard_thr
self
.
hard_fraction
=
hard_fraction
self
.
hard_fraction
=
hard_fraction
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
neg_inds
=
neg_inds
.
squeeze
(
1
)
...
...
mmdet/core/bbox/samplers/ohem_sampler.py
View file @
092b97f6
...
@@ -22,7 +22,7 @@ class OHEMSampler(BaseSampler):
...
@@ -22,7 +22,7 @@ class OHEMSampler(BaseSampler):
self
.
bbox_head
=
bbox_head
self
.
bbox_head
=
bbox_head
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
,
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
,
feats
=
None
):
feats
=
None
,
**
kwargs
):
"""Hard sample some positive samples."""
"""Hard sample some positive samples."""
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
if
pos_inds
.
numel
()
!=
0
:
...
@@ -35,7 +35,7 @@ class OHEMSampler(BaseSampler):
...
@@ -35,7 +35,7 @@ class OHEMSampler(BaseSampler):
bbox_feats
=
self
.
bbox_roi_extractor
(
bbox_feats
=
self
.
bbox_roi_extractor
(
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
loss_
all
=
self
.
bbox_head
.
loss
(
loss_
pos
=
self
.
bbox_head
.
loss
(
cls_score
=
cls_score
,
cls_score
=
cls_score
,
bbox_pred
=
None
,
bbox_pred
=
None
,
labels
=
assign_result
.
labels
[
pos_inds
],
labels
=
assign_result
.
labels
[
pos_inds
],
...
@@ -43,11 +43,11 @@ class OHEMSampler(BaseSampler):
...
@@ -43,11 +43,11 @@ class OHEMSampler(BaseSampler):
bbox_targets
=
None
,
bbox_targets
=
None
,
bbox_weights
=
None
,
bbox_weights
=
None
,
reduction
=
'none'
)[
'loss_cls'
]
reduction
=
'none'
)[
'loss_cls'
]
_
,
topk_loss_pos_inds
=
loss_
all
.
topk
(
num_expected
)
_
,
topk_loss_pos_inds
=
loss_
pos
.
topk
(
num_expected
)
return
pos_inds
[
topk_loss_pos_inds
]
return
pos_inds
[
topk_loss_pos_inds
]
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
,
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
,
feats
=
None
):
feats
=
None
,
**
kwargs
):
"""Hard sample some negative samples."""
"""Hard sample some negative samples."""
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
if
neg_inds
.
numel
()
!=
0
:
...
@@ -60,7 +60,7 @@ class OHEMSampler(BaseSampler):
...
@@ -60,7 +60,7 @@ class OHEMSampler(BaseSampler):
bbox_feats
=
self
.
bbox_roi_extractor
(
bbox_feats
=
self
.
bbox_roi_extractor
(
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
loss_
all
=
self
.
bbox_head
.
loss
(
loss_
neg
=
self
.
bbox_head
.
loss
(
cls_score
=
cls_score
,
cls_score
=
cls_score
,
bbox_pred
=
None
,
bbox_pred
=
None
,
labels
=
assign_result
.
labels
[
neg_inds
],
labels
=
assign_result
.
labels
[
neg_inds
],
...
@@ -68,5 +68,5 @@ class OHEMSampler(BaseSampler):
...
@@ -68,5 +68,5 @@ class OHEMSampler(BaseSampler):
bbox_targets
=
None
,
bbox_targets
=
None
,
bbox_weights
=
None
,
bbox_weights
=
None
,
reduction
=
'none'
)[
'loss_cls'
]
reduction
=
'none'
)[
'loss_cls'
]
_
,
topk_loss_neg_inds
=
loss_
all
.
topk
(
num_expected
)
_
,
topk_loss_neg_inds
=
loss_
neg
.
topk
(
num_expected
)
return
neg_inds
[
topk_loss_neg_inds
]
return
neg_inds
[
topk_loss_neg_inds
]
mmdet/core/bbox/samplers/random_sampler.py
View file @
092b97f6
...
@@ -34,7 +34,7 @@ class RandomSampler(BaseSampler):
...
@@ -34,7 +34,7 @@ class RandomSampler(BaseSampler):
rand_inds
=
torch
.
from_numpy
(
rand_inds
).
long
().
to
(
gallery
.
device
)
rand_inds
=
torch
.
from_numpy
(
rand_inds
).
long
().
to
(
gallery
.
device
)
return
gallery
[
rand_inds
]
return
gallery
[
rand_inds
]
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
"""Randomly sample some positive samples."""
"""Randomly sample some positive samples."""
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
if
pos_inds
.
numel
()
!=
0
:
...
@@ -44,7 +44,7 @@ class RandomSampler(BaseSampler):
...
@@ -44,7 +44,7 @@ class RandomSampler(BaseSampler):
else
:
else
:
return
self
.
random_choice
(
pos_inds
,
num_expected
)
return
self
.
random_choice
(
pos_inds
,
num_expected
)
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
"""Randomly sample some negative samples."""
"""Randomly sample some negative samples."""
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
if
neg_inds
.
numel
()
!=
0
:
...
...
mmdet/models/detectors/two_stage.py
View file @
092b97f6
...
@@ -114,17 +114,12 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -114,17 +114,12 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
assign_result
=
bbox_assigner
.
assign
(
assign_result
=
bbox_assigner
.
assign
(
proposal_list
[
i
],
gt_bboxes
[
i
],
gt_bboxes_ignore
[
i
],
proposal_list
[
i
],
gt_bboxes
[
i
],
gt_bboxes_ignore
[
i
],
gt_labels
[
i
])
gt_labels
[
i
])
if
self
.
train_cfg
.
rcnn
.
sampler
.
type
==
'OHEMSampler'
:
sampling_result
=
bbox_sampler
.
sample
(
sampling_result
=
bbox_sampler
.
sample
(
assign_result
,
assign_result
,
proposal_list
[
i
],
proposal_list
[
i
],
gt_bboxes
[
i
],
gt_bboxes
[
i
],
gt_labels
[
i
],
gt_labels
[
i
],
feats
=
[
xx
[
i
][
None
]
for
xx
in
x
])
feats
=
[
xx
[
i
][
None
]
for
xx
in
x
])
else
:
sampling_result
=
bbox_sampler
.
sample
(
assign_result
,
proposal_list
[
i
],
gt_bboxes
[
i
],
gt_labels
[
i
])
assign_results
.
append
(
assign_result
)
assign_results
.
append
(
assign_result
)
sampling_results
.
append
(
sampling_result
)
sampling_results
.
append
(
sampling_result
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment