Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
c1f9fa0f
Unverified
Commit
c1f9fa0f
authored
Dec 12, 2018
by
Kai Chen
Committed by
GitHub
Dec 12, 2018
Browse files
Merge pull request #163 from yhcao6/ohem-sampler
Add OHEMSampler
parents
94beb922
b932030c
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
301 additions
and
49 deletions
+301
-49
configs/faster_rcnn_ohem_r50_fpn_1x.py
configs/faster_rcnn_ohem_r50_fpn_1x.py
+156
-0
mmdet/core/bbox/assign_sampling.py
mmdet/core/bbox/assign_sampling.py
+4
-4
mmdet/core/bbox/samplers/__init__.py
mmdet/core/bbox/samplers/__init__.py
+2
-1
mmdet/core/bbox/samplers/base_sampler.py
mmdet/core/bbox/samplers/base_sampler.py
+22
-8
mmdet/core/bbox/samplers/combined_sampler.py
mmdet/core/bbox/samplers/combined_sampler.py
+12
-10
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
+1
-1
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
+1
-1
mmdet/core/bbox/samplers/ohem_sampler.py
mmdet/core/bbox/samplers/ohem_sampler.py
+68
-0
mmdet/core/bbox/samplers/pseudo_sampler.py
mmdet/core/bbox/samplers/pseudo_sampler.py
+4
-4
mmdet/core/bbox/samplers/random_sampler.py
mmdet/core/bbox/samplers/random_sampler.py
+6
-8
mmdet/core/loss/losses.py
mmdet/core/loss/losses.py
+6
-2
mmdet/models/bbox_heads/bbox_head.py
mmdet/models/bbox_heads/bbox_head.py
+2
-2
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+17
-8
No files found.
configs/faster_rcnn_ohem_r50_fpn_1x.py
0 → 100644
View file @
c1f9fa0f
# model settings
model
=
dict
(
type
=
'FasterRCNN'
,
pretrained
=
'modelzoo://resnet50'
,
backbone
=
dict
(
type
=
'ResNet'
,
depth
=
50
,
num_stages
=
4
,
out_indices
=
(
0
,
1
,
2
,
3
),
frozen_stages
=
1
,
style
=
'pytorch'
),
neck
=
dict
(
type
=
'FPN'
,
in_channels
=
[
256
,
512
,
1024
,
2048
],
out_channels
=
256
,
num_outs
=
5
),
rpn_head
=
dict
(
type
=
'RPNHead'
,
in_channels
=
256
,
feat_channels
=
256
,
anchor_scales
=
[
8
],
anchor_ratios
=
[
0.5
,
1.0
,
2.0
],
anchor_strides
=
[
4
,
8
,
16
,
32
,
64
],
target_means
=
[.
0
,
.
0
,
.
0
,
.
0
],
target_stds
=
[
1.0
,
1.0
,
1.0
,
1.0
],
use_sigmoid_cls
=
True
),
bbox_roi_extractor
=
dict
(
type
=
'SingleRoIExtractor'
,
roi_layer
=
dict
(
type
=
'RoIAlign'
,
out_size
=
7
,
sample_num
=
2
),
out_channels
=
256
,
featmap_strides
=
[
4
,
8
,
16
,
32
]),
bbox_head
=
dict
(
type
=
'SharedFCBBoxHead'
,
num_fcs
=
2
,
in_channels
=
256
,
fc_out_channels
=
1024
,
roi_feat_size
=
7
,
num_classes
=
81
,
target_means
=
[
0.
,
0.
,
0.
,
0.
],
target_stds
=
[
0.1
,
0.1
,
0.2
,
0.2
],
reg_class_agnostic
=
False
))
# model training and testing settings
train_cfg
=
dict
(
rpn
=
dict
(
assigner
=
dict
(
type
=
'MaxIoUAssigner'
,
pos_iou_thr
=
0.7
,
neg_iou_thr
=
0.3
,
min_pos_iou
=
0.3
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
type
=
'RandomSampler'
,
num
=
256
,
pos_fraction
=
0.5
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
False
),
allowed_border
=
0
,
pos_weight
=-
1
,
smoothl1_beta
=
1
/
9.0
,
debug
=
False
),
rcnn
=
dict
(
assigner
=
dict
(
type
=
'MaxIoUAssigner'
,
pos_iou_thr
=
0.5
,
neg_iou_thr
=
0.5
,
min_pos_iou
=
0.5
,
ignore_iof_thr
=-
1
),
sampler
=
dict
(
type
=
'OHEMSampler'
,
num
=
512
,
pos_fraction
=
0.25
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
),
pos_weight
=-
1
,
debug
=
False
))
test_cfg
=
dict
(
rpn
=
dict
(
nms_across_levels
=
False
,
nms_pre
=
2000
,
nms_post
=
2000
,
max_num
=
2000
,
nms_thr
=
0.7
,
min_bbox_size
=
0
),
rcnn
=
dict
(
score_thr
=
0.05
,
nms
=
dict
(
type
=
'nms'
,
iou_thr
=
0.5
),
max_per_img
=
100
)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type
=
'CocoDataset'
data_root
=
'data/coco/'
img_norm_cfg
=
dict
(
mean
=
[
123.675
,
116.28
,
103.53
],
std
=
[
58.395
,
57.12
,
57.375
],
to_rgb
=
True
)
data
=
dict
(
imgs_per_gpu
=
2
,
workers_per_gpu
=
2
,
train
=
dict
(
type
=
dataset_type
,
ann_file
=
data_root
+
'annotations/instances_train2017.json'
,
img_prefix
=
data_root
+
'train2017/'
,
img_scale
=
(
1333
,
800
),
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
,
flip_ratio
=
0.5
,
with_mask
=
False
,
with_crowd
=
True
,
with_label
=
True
),
val
=
dict
(
type
=
dataset_type
,
ann_file
=
data_root
+
'annotations/instances_val2017.json'
,
img_prefix
=
data_root
+
'val2017/'
,
img_scale
=
(
1333
,
800
),
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
,
flip_ratio
=
0
,
with_mask
=
False
,
with_crowd
=
True
,
with_label
=
True
),
test
=
dict
(
type
=
dataset_type
,
ann_file
=
data_root
+
'annotations/instances_val2017.json'
,
img_prefix
=
data_root
+
'val2017/'
,
img_scale
=
(
1333
,
800
),
img_norm_cfg
=
img_norm_cfg
,
size_divisor
=
32
,
flip_ratio
=
0
,
with_mask
=
False
,
with_label
=
False
,
test_mode
=
True
))
# optimizer
optimizer
=
dict
(
type
=
'SGD'
,
lr
=
0.02
,
momentum
=
0.9
,
weight_decay
=
0.0001
)
optimizer_config
=
dict
(
grad_clip
=
dict
(
max_norm
=
35
,
norm_type
=
2
))
# learning policy
lr_config
=
dict
(
policy
=
'step'
,
warmup
=
'linear'
,
warmup_iters
=
500
,
warmup_ratio
=
1.0
/
3
,
step
=
[
8
,
11
])
checkpoint_config
=
dict
(
interval
=
1
)
# yapf:disable
log_config
=
dict
(
interval
=
50
,
hooks
=
[
dict
(
type
=
'TextLoggerHook'
),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs
=
12
dist_params
=
dict
(
backend
=
'nccl'
)
log_level
=
'INFO'
work_dir
=
'./work_dirs/faster_rcnn_r50_fpn_1x'
load_from
=
None
resume_from
=
None
workflow
=
[(
'train'
,
1
)]
mmdet/core/bbox/assign_sampling.py
View file @
c1f9fa0f
...
@@ -3,23 +3,23 @@ import mmcv
...
@@ -3,23 +3,23 @@ import mmcv
from
.
import
assigners
,
samplers
from
.
import
assigners
,
samplers
def
build_assigner
(
cfg
,
default_args
=
None
):
def
build_assigner
(
cfg
,
**
kwargs
):
if
isinstance
(
cfg
,
assigners
.
BaseAssigner
):
if
isinstance
(
cfg
,
assigners
.
BaseAssigner
):
return
cfg
return
cfg
elif
isinstance
(
cfg
,
dict
):
elif
isinstance
(
cfg
,
dict
):
return
mmcv
.
runner
.
obj_from_dict
(
return
mmcv
.
runner
.
obj_from_dict
(
cfg
,
assigners
,
default_args
=
default_
args
)
cfg
,
assigners
,
default_args
=
kw
args
)
else
:
else
:
raise
TypeError
(
'Invalid type {} for building a sampler'
.
format
(
raise
TypeError
(
'Invalid type {} for building a sampler'
.
format
(
type
(
cfg
)))
type
(
cfg
)))
def
build_sampler
(
cfg
,
default_args
=
None
):
def
build_sampler
(
cfg
,
**
kwargs
):
if
isinstance
(
cfg
,
samplers
.
BaseSampler
):
if
isinstance
(
cfg
,
samplers
.
BaseSampler
):
return
cfg
return
cfg
elif
isinstance
(
cfg
,
dict
):
elif
isinstance
(
cfg
,
dict
):
return
mmcv
.
runner
.
obj_from_dict
(
return
mmcv
.
runner
.
obj_from_dict
(
cfg
,
samplers
,
default_args
=
default_
args
)
cfg
,
samplers
,
default_args
=
kw
args
)
else
:
else
:
raise
TypeError
(
'Invalid type {} for building a sampler'
.
format
(
raise
TypeError
(
'Invalid type {} for building a sampler'
.
format
(
type
(
cfg
)))
type
(
cfg
)))
...
...
mmdet/core/bbox/samplers/__init__.py
View file @
c1f9fa0f
...
@@ -4,10 +4,11 @@ from .random_sampler import RandomSampler
...
@@ -4,10 +4,11 @@ from .random_sampler import RandomSampler
from
.instance_balanced_pos_sampler
import
InstanceBalancedPosSampler
from
.instance_balanced_pos_sampler
import
InstanceBalancedPosSampler
from
.iou_balanced_neg_sampler
import
IoUBalancedNegSampler
from
.iou_balanced_neg_sampler
import
IoUBalancedNegSampler
from
.combined_sampler
import
CombinedSampler
from
.combined_sampler
import
CombinedSampler
from
.ohem_sampler
import
OHEMSampler
from
.sampling_result
import
SamplingResult
from
.sampling_result
import
SamplingResult
__all__
=
[
__all__
=
[
'BaseSampler'
,
'PseudoSampler'
,
'RandomSampler'
,
'BaseSampler'
,
'PseudoSampler'
,
'RandomSampler'
,
'InstanceBalancedPosSampler'
,
'IoUBalancedNegSampler'
,
'CombinedSampler'
,
'InstanceBalancedPosSampler'
,
'IoUBalancedNegSampler'
,
'CombinedSampler'
,
'SamplingResult'
'OHEMSampler'
,
'SamplingResult'
]
]
mmdet/core/bbox/samplers/base_sampler.py
View file @
c1f9fa0f
...
@@ -7,19 +7,33 @@ from .sampling_result import SamplingResult
...
@@ -7,19 +7,33 @@ from .sampling_result import SamplingResult
class
BaseSampler
(
metaclass
=
ABCMeta
):
class
BaseSampler
(
metaclass
=
ABCMeta
):
def
__init__
(
self
):
def
__init__
(
self
,
num
,
pos_fraction
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
**
kwargs
):
self
.
num
=
num
self
.
pos_fraction
=
pos_fraction
self
.
neg_pos_ub
=
neg_pos_ub
self
.
add_gt_as_proposals
=
add_gt_as_proposals
self
.
pos_sampler
=
self
self
.
pos_sampler
=
self
self
.
neg_sampler
=
self
self
.
neg_sampler
=
self
@
abstractmethod
@
abstractmethod
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
pass
pass
@
abstractmethod
@
abstractmethod
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
pass
pass
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
=
None
):
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
gt_labels
=
None
,
**
kwargs
):
"""Sample positive and negative bboxes.
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
This is a simple implementation of bbox sampling given candidates,
...
@@ -44,8 +58,8 @@ class BaseSampler(metaclass=ABCMeta):
...
@@ -44,8 +58,8 @@ class BaseSampler(metaclass=ABCMeta):
gt_flags
=
torch
.
cat
([
gt_ones
,
gt_flags
])
gt_flags
=
torch
.
cat
([
gt_ones
,
gt_flags
])
num_expected_pos
=
int
(
self
.
num
*
self
.
pos_fraction
)
num_expected_pos
=
int
(
self
.
num
*
self
.
pos_fraction
)
pos_inds
=
self
.
pos_sampler
.
_sample_pos
(
assign_result
,
pos_inds
=
self
.
pos_sampler
.
_sample_pos
(
num_expected_po
s
)
assign_result
,
num_expected_pos
,
bboxes
=
bboxes
,
**
kwarg
s
)
# We found that sampled indices have duplicated items occasionally.
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
# (may be a bug of PyTorch)
pos_inds
=
pos_inds
.
unique
()
pos_inds
=
pos_inds
.
unique
()
...
@@ -56,8 +70,8 @@ class BaseSampler(metaclass=ABCMeta):
...
@@ -56,8 +70,8 @@ class BaseSampler(metaclass=ABCMeta):
neg_upper_bound
=
int
(
self
.
neg_pos_ub
*
_pos
)
neg_upper_bound
=
int
(
self
.
neg_pos_ub
*
_pos
)
if
num_expected_neg
>
neg_upper_bound
:
if
num_expected_neg
>
neg_upper_bound
:
num_expected_neg
=
neg_upper_bound
num_expected_neg
=
neg_upper_bound
neg_inds
=
self
.
neg_sampler
.
_sample_neg
(
assign_result
,
neg_inds
=
self
.
neg_sampler
.
_sample_neg
(
num_expected_neg
)
assign_result
,
num_expected_neg
,
bboxes
=
bboxes
,
**
kwargs
)
neg_inds
=
neg_inds
.
unique
()
neg_inds
=
neg_inds
.
unique
()
return
SamplingResult
(
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
return
SamplingResult
(
pos_inds
,
neg_inds
,
bboxes
,
gt_bboxes
,
...
...
mmdet/core/bbox/samplers/combined_sampler.py
View file @
c1f9fa0f
from
.
random
_sampler
import
Random
Sampler
from
.
base
_sampler
import
Base
Sampler
from
..assign_sampling
import
build_sampler
from
..assign_sampling
import
build_sampler
class
CombinedSampler
(
Random
Sampler
):
class
CombinedSampler
(
Base
Sampler
):
def
__init__
(
self
,
num
,
pos_fraction
,
pos_sampler
,
neg_sampler
,
**
kwargs
):
def
__init__
(
self
,
pos_sampler
,
neg_sampler
,
**
kwargs
):
super
(
CombinedSampler
,
self
).
__init__
(
num
,
pos_fraction
,
**
kwargs
)
super
(
CombinedSampler
,
self
).
__init__
(
**
kwargs
)
default_args
=
dict
(
num
=
num
,
pos_fraction
=
pos_fraction
)
self
.
pos_sampler
=
build_sampler
(
pos_sampler
,
**
kwargs
)
default_args
.
update
(
kwargs
)
self
.
neg_sampler
=
build_sampler
(
neg_sampler
,
**
kwargs
)
self
.
pos_sampler
=
build_sampler
(
pos_sampler
,
default_args
=
default_args
)
def
_sample_pos
(
self
,
**
kwargs
):
self
.
neg_sampler
=
build_sampler
(
raise
NotImplementedError
neg_sampler
,
default_args
=
default_args
)
def
_sample_neg
(
self
,
**
kwargs
):
raise
NotImplementedError
mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
View file @
c1f9fa0f
...
@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
...
@@ -6,7 +6,7 @@ from .random_sampler import RandomSampler
class
InstanceBalancedPosSampler
(
RandomSampler
):
class
InstanceBalancedPosSampler
(
RandomSampler
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
pos_inds
=
pos_inds
.
squeeze
(
1
)
...
...
mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
View file @
c1f9fa0f
...
@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
...
@@ -19,7 +19,7 @@ class IoUBalancedNegSampler(RandomSampler):
self
.
hard_thr
=
hard_thr
self
.
hard_thr
=
hard_thr
self
.
hard_fraction
=
hard_fraction
self
.
hard_fraction
=
hard_fraction
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
neg_inds
=
neg_inds
.
squeeze
(
1
)
...
...
mmdet/core/bbox/samplers/ohem_sampler.py
0 → 100644
View file @
c1f9fa0f
import
torch
from
.base_sampler
import
BaseSampler
from
..transforms
import
bbox2roi
class
OHEMSampler
(
BaseSampler
):
def
__init__
(
self
,
num
,
pos_fraction
,
context
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
,
**
kwargs
):
super
(
OHEMSampler
,
self
).
__init__
(
num
,
pos_fraction
,
neg_pos_ub
,
add_gt_as_proposals
)
self
.
bbox_roi_extractor
=
context
.
bbox_roi_extractor
self
.
bbox_head
=
context
.
bbox_head
def
hard_mining
(
self
,
inds
,
num_expected
,
bboxes
,
labels
,
feats
):
with
torch
.
no_grad
():
rois
=
bbox2roi
([
bboxes
])
bbox_feats
=
self
.
bbox_roi_extractor
(
feats
[:
self
.
bbox_roi_extractor
.
num_inputs
],
rois
)
cls_score
,
_
=
self
.
bbox_head
(
bbox_feats
)
loss
=
self
.
bbox_head
.
loss
(
cls_score
=
cls_score
,
bbox_pred
=
None
,
labels
=
labels
,
label_weights
=
cls_score
.
new_ones
(
cls_score
.
size
(
0
)),
bbox_targets
=
None
,
bbox_weights
=
None
,
reduce
=
False
)[
'loss_cls'
]
_
,
topk_loss_inds
=
loss
.
topk
(
num_expected
)
return
inds
[
topk_loss_inds
]
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
,
feats
=
None
,
**
kwargs
):
# Sample some hard positive samples
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
pos_inds
=
pos_inds
.
squeeze
(
1
)
if
pos_inds
.
numel
()
<=
num_expected
:
return
pos_inds
else
:
return
self
.
hard_mining
(
pos_inds
,
num_expected
,
bboxes
[
pos_inds
],
assign_result
.
labels
[
pos_inds
],
feats
)
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
bboxes
=
None
,
feats
=
None
,
**
kwargs
):
# Sample some hard negative samples
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
neg_inds
=
neg_inds
.
squeeze
(
1
)
if
len
(
neg_inds
)
<=
num_expected
:
return
neg_inds
else
:
return
self
.
hard_mining
(
neg_inds
,
num_expected
,
bboxes
[
neg_inds
],
assign_result
.
labels
[
neg_inds
],
feats
)
mmdet/core/bbox/samplers/pseudo_sampler.py
View file @
c1f9fa0f
...
@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult
...
@@ -6,16 +6,16 @@ from .sampling_result import SamplingResult
class
PseudoSampler
(
BaseSampler
):
class
PseudoSampler
(
BaseSampler
):
def
__init__
(
self
):
def
__init__
(
self
,
**
kwargs
):
pass
pass
def
_sample_pos
(
self
):
def
_sample_pos
(
self
,
**
kwargs
):
raise
NotImplementedError
raise
NotImplementedError
def
_sample_neg
(
self
):
def
_sample_neg
(
self
,
**
kwargs
):
raise
NotImplementedError
raise
NotImplementedError
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
):
def
sample
(
self
,
assign_result
,
bboxes
,
gt_bboxes
,
**
kwargs
):
pos_inds
=
torch
.
nonzero
(
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
).
squeeze
(
-
1
).
unique
()
assign_result
.
gt_inds
>
0
).
squeeze
(
-
1
).
unique
()
neg_inds
=
torch
.
nonzero
(
neg_inds
=
torch
.
nonzero
(
...
...
mmdet/core/bbox/samplers/random_sampler.py
View file @
c1f9fa0f
...
@@ -10,12 +10,10 @@ class RandomSampler(BaseSampler):
...
@@ -10,12 +10,10 @@ class RandomSampler(BaseSampler):
num
,
num
,
pos_fraction
,
pos_fraction
,
neg_pos_ub
=-
1
,
neg_pos_ub
=-
1
,
add_gt_as_proposals
=
True
):
add_gt_as_proposals
=
True
,
super
(
RandomSampler
,
self
).
__init__
()
**
kwargs
):
self
.
num
=
num
super
(
RandomSampler
,
self
).
__init__
(
num
,
pos_fraction
,
neg_pos_ub
,
self
.
pos_fraction
=
pos_fraction
add_gt_as_proposals
)
self
.
neg_pos_ub
=
neg_pos_ub
self
.
add_gt_as_proposals
=
add_gt_as_proposals
@
staticmethod
@
staticmethod
def
random_choice
(
gallery
,
num
):
def
random_choice
(
gallery
,
num
):
...
@@ -34,7 +32,7 @@ class RandomSampler(BaseSampler):
...
@@ -34,7 +32,7 @@ class RandomSampler(BaseSampler):
rand_inds
=
torch
.
from_numpy
(
rand_inds
).
long
().
to
(
gallery
.
device
)
rand_inds
=
torch
.
from_numpy
(
rand_inds
).
long
().
to
(
gallery
.
device
)
return
gallery
[
rand_inds
]
return
gallery
[
rand_inds
]
def
_sample_pos
(
self
,
assign_result
,
num_expected
):
def
_sample_pos
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
"""Randomly sample some positive samples."""
"""Randomly sample some positive samples."""
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
pos_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
>
0
)
if
pos_inds
.
numel
()
!=
0
:
if
pos_inds
.
numel
()
!=
0
:
...
@@ -44,7 +42,7 @@ class RandomSampler(BaseSampler):
...
@@ -44,7 +42,7 @@ class RandomSampler(BaseSampler):
else
:
else
:
return
self
.
random_choice
(
pos_inds
,
num_expected
)
return
self
.
random_choice
(
pos_inds
,
num_expected
)
def
_sample_neg
(
self
,
assign_result
,
num_expected
):
def
_sample_neg
(
self
,
assign_result
,
num_expected
,
**
kwargs
):
"""Randomly sample some negative samples."""
"""Randomly sample some negative samples."""
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
neg_inds
=
torch
.
nonzero
(
assign_result
.
gt_inds
==
0
)
if
neg_inds
.
numel
()
!=
0
:
if
neg_inds
.
numel
()
!=
0
:
...
...
mmdet/core/loss/losses.py
View file @
c1f9fa0f
...
@@ -10,11 +10,15 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
...
@@ -10,11 +10,15 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
def
weighted_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
):
def
weighted_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
,
reduce
=
True
):
if
avg_factor
is
None
:
if
avg_factor
is
None
:
avg_factor
=
max
(
torch
.
sum
(
weight
>
0
).
float
().
item
(),
1.
)
avg_factor
=
max
(
torch
.
sum
(
weight
>
0
).
float
().
item
(),
1.
)
raw
=
F
.
cross_entropy
(
pred
,
label
,
reduction
=
'none'
)
raw
=
F
.
cross_entropy
(
pred
,
label
,
reduction
=
'none'
)
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
if
reduce
:
return
torch
.
sum
(
raw
*
weight
)[
None
]
/
avg_factor
else
:
return
raw
*
weight
/
avg_factor
def
weighted_binary_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
):
def
weighted_binary_cross_entropy
(
pred
,
label
,
weight
,
avg_factor
=
None
):
...
...
mmdet/models/bbox_heads/bbox_head.py
View file @
c1f9fa0f
...
@@ -79,11 +79,11 @@ class BBoxHead(nn.Module):
...
@@ -79,11 +79,11 @@ class BBoxHead(nn.Module):
return
cls_reg_targets
return
cls_reg_targets
def
loss
(
self
,
cls_score
,
bbox_pred
,
labels
,
label_weights
,
bbox_targets
,
def
loss
(
self
,
cls_score
,
bbox_pred
,
labels
,
label_weights
,
bbox_targets
,
bbox_weights
):
bbox_weights
,
reduce
=
True
):
losses
=
dict
()
losses
=
dict
()
if
cls_score
is
not
None
:
if
cls_score
is
not
None
:
losses
[
'loss_cls'
]
=
weighted_cross_entropy
(
losses
[
'loss_cls'
]
=
weighted_cross_entropy
(
cls_score
,
labels
,
label_weights
)
cls_score
,
labels
,
label_weights
,
reduce
=
reduce
)
losses
[
'acc'
]
=
accuracy
(
cls_score
,
labels
)
losses
[
'acc'
]
=
accuracy
(
cls_score
,
labels
)
if
bbox_pred
is
not
None
:
if
bbox_pred
is
not
None
:
losses
[
'loss_reg'
]
=
weighted_smoothl1
(
losses
[
'loss_reg'
]
=
weighted_smoothl1
(
...
...
mmdet/models/detectors/two_stage.py
View file @
c1f9fa0f
...
@@ -4,7 +4,7 @@ import torch.nn as nn
...
@@ -4,7 +4,7 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
..
import
builder
from
..
import
builder
from
mmdet.core
import
(
assign_and_sample
,
bbox2roi
,
bbox2result
,
multi_apply
)
from
mmdet.core
import
bbox2roi
,
bbox2result
,
build_assigner
,
build_sampler
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
...
@@ -102,13 +102,22 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -102,13 +102,22 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
# assign gts and sample proposals
# assign gts and sample proposals
if
self
.
with_bbox
or
self
.
with_mask
:
if
self
.
with_bbox
or
self
.
with_mask
:
assign_results
,
sampling_results
=
multi_apply
(
bbox_assigner
=
build_assigner
(
self
.
train_cfg
.
rcnn
.
assigner
)
assign_and_sample
,
bbox_sampler
=
build_sampler
(
proposal_list
,
self
.
train_cfg
.
rcnn
.
sampler
,
context
=
self
)
gt_bboxes
,
num_imgs
=
img
.
size
(
0
)
gt_bboxes_ignore
,
sampling_results
=
[]
gt_labels
,
for
i
in
range
(
num_imgs
):
cfg
=
self
.
train_cfg
.
rcnn
)
assign_result
=
bbox_assigner
.
assign
(
proposal_list
[
i
],
gt_bboxes
[
i
],
gt_bboxes_ignore
[
i
],
gt_labels
[
i
])
sampling_result
=
bbox_sampler
.
sample
(
assign_result
,
proposal_list
[
i
],
gt_bboxes
[
i
],
gt_labels
[
i
],
feats
=
[
lvl_feat
[
i
][
None
]
for
lvl_feat
in
x
])
sampling_results
.
append
(
sampling_result
)
# bbox head forward and loss
# bbox head forward and loss
if
self
.
with_bbox
:
if
self
.
with_bbox
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment