ModelZoo / SOLOv2-pytorch · Commits · 64e310d5

Unverified commit 64e310d5, authored by Kai Chen on Nov 27, 2018, committed via GitHub on Nov 27, 2018.

Merge pull request #123 from hellock/single-stage

Add SingleStageDetector and RetinaNet

Parents: 52a34b5d, 9ace2eee

Showing 15 changed files with 587 additions and 35 deletions (+587, −35).
- .travis.yml (+0, −1)
- MODEL_ZOO.md (+4, −4)
- README.md (+3, −0)
- configs/retinanet_r50_fpn_1x.py (+116, −0)
- mmdet/core/anchor/anchor_target.py (+74, −21)
- mmdet/models/builder.py (+7, −2)
- mmdet/models/detectors/__init__.py (+4, −2)
- mmdet/models/detectors/retinanet.py (+14, −0)
- mmdet/models/detectors/single_stage.py (+62, −0)
- mmdet/models/rpn_heads/rpn_head.py (+2, −2)
- mmdet/models/single_stage_heads/__init__.py (+3, −0)
- mmdet/models/single_stage_heads/retina_head.py (+287, −0)
- mmdet/models/utils/__init__.py (+3, −2)
- mmdet/models/utils/weight_init.py (+7, −0)
- setup.py (+1, −1)
.travis.yml

```diff
@@ -5,7 +5,6 @@ install:
   - pip install flake8
 python:
-  - "2.7"
   - "3.5"
   - "3.6"
```
MODEL_ZOO.md

```diff
@@ -71,13 +71,13 @@ We released RPN, Faster R-CNN and Mask R-CNN models in the first version. More m
 | R-50-FPN | pytorch | Mask | 1x | 5.3 | 0.50 | 10.6 | 36.8 | 34.1 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/fast_mask_rcnn_r50_fpn_1x_20181010-e030a38f.pth) \| [result](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/results/fast_mask_rcnn_r50_fpn_1x_20181010_results.pkl.json) |
 | R-50-FPN | pytorch | Mask | 2x | 5.3 | 0.50 | 10.6 | 37.9 | 34.8 | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/fast_mask_rcnn_r50_fpn_2x_20181010-5048cb03.pth) \| [result](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/results/fast_mask_rcnn_r50_fpn_2x_20181010_results.pkl.json) |

 ### RetinaNet
-(coming soon)
 | Backbone | Style | Lr schd | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
 |:--------:|:-------:|:-------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
-| R-50-FPN | caffe   | 1x      |          |                     |                |        |          |
+| R-50-FPN | caffe   | 1x      | 6.7      | 0.468               | 9.4            | 35.8   | -        |
-| R-50-FPN | pytorch | 1x      |          |                     |                |        |          |
+| R-50-FPN | pytorch | 1x      | 6.9      | 0.496               | 9.1            | 35.6   | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r50_fpn_1x_20181125-3d3c2142.pth) |
-| R-50-FPN | pytorch | 2x      |          |                     |                |        |          |
+| R-50-FPN | pytorch | 2x      | 6.9      | 0.496               | 9.1            | 36.5   | [model](https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/retinanet_r50_fpn_2x_20181125-e0dbec97.pth) |

 ### Cascade R-CNN
```
README.md

```diff
@@ -36,6 +36,9 @@ This project is released under the [Apache 2.0 license](LICENSE).
 ## Updates

+v0.5.4 (27/11/2018)
+- Add SingleStageDetector and RetinaNet.
+
 v0.5.3 (26/11/2018)
 - Add Cascade R-CNN and Cascade Mask R-CNN.
 - Add support for Soft-NMS in config files.
```
configs/retinanet_r50_fpn_1x.py (new file, mode 100644)

```python
# model settings
model = dict(
    type='RetinaNet',
    pretrained='modelzoo://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs=True,
        num_outs=5),
    bbox_head=dict(
        type='RetinaHead',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        octave_base_scale=4,
        scales_per_octave=3,
        anchor_ratios=[0.5, 1.0, 2.0],
        anchor_strides=[8, 16, 32, 64, 128],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0]))
# training and testing settings
train_cfg = dict(
    assigner=dict(
        pos_iou_thr=0.5,
        neg_iou_thr=0.4,
        min_pos_iou=0,
        ignore_iof_thr=-1),
    smoothl1_beta=0.11,
    gamma=2.0,
    alpha=0.25,
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
test_cfg = dict(
    nms_pre=1000,
    min_bbox_size=0,
    score_thr=0.05,
    nms=dict(type='nms', iou_thr=0.5),
    max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0.5,
        with_mask=False,
        with_crowd=False,
        with_label=True),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_crowd=False,
        with_label=True),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(1333, 800),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_crowd=False,
        with_label=False,
        test_mode=True))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/retinanet_r50_fpn_1x'
load_from = None
resume_from = None
workflow = [('train', 1)]
```
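As a sanity check on the anchor settings above, the per-level anchor scales implied by `octave_base_scale=4` and `scales_per_octave=3` can be computed directly. This is a standalone sketch of the scale formula `s * 2^(i/n)` documented in the head, not mmdetection code:

```python
# Anchor scales implied by the RetinaHead config above:
# scale_i = octave_base_scale * 2 ** (i / scales_per_octave), i in [0, n-1].
octave_base_scale = 4
scales_per_octave = 3
anchor_ratios = [0.5, 1.0, 2.0]

scales = [octave_base_scale * 2 ** (i / scales_per_octave)
          for i in range(scales_per_octave)]
print([round(s, 2) for s in scales])  # [4.0, 5.04, 6.35]

# Each FPN level places len(ratios) * scales_per_octave anchors per cell;
# absolute anchor sizes are scale * stride for that level.
num_anchors_per_cell = len(anchor_ratios) * scales_per_octave
print(num_anchors_per_cell)  # 9
```

With strides [8, 16, 32, 64, 128], this yields anchors covering roughly 32 px up to 813 px across the pyramid, matching the RetinaNet paper's design.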
mmdet/core/anchor/anchor_target.py

```diff
 import torch

-from ..bbox import assign_and_sample, bbox2delta
+from ..bbox import assign_and_sample, BBoxAssigner, SamplingResult, bbox2delta
 from ..utils import multi_apply


-def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
-                  target_means, target_stds, cfg):
+def anchor_target(anchor_list,
+                  valid_flag_list,
+                  gt_bboxes_list,
+                  img_metas,
+                  target_means,
+                  target_stds,
+                  cfg,
+                  gt_labels_list=None,
+                  cls_out_channels=1,
+                  sampling=True):
     """Compute regression and classification targets for anchors.

     Args:
@@ -32,28 +40,34 @@ def anchor_target(anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
         valid_flag_list[i] = torch.cat(valid_flag_list[i])

     # compute targets for each image
-    means_replicas = [target_means for _ in range(num_imgs)]
-    stds_replicas = [target_stds for _ in range(num_imgs)]
-    cfg_replicas = [cfg for _ in range(num_imgs)]
+    if gt_labels_list is None:
+        gt_labels_list = [None for _ in range(num_imgs)]
     (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
      pos_inds_list, neg_inds_list) = multi_apply(
         anchor_target_single,
-        anchor_list, valid_flag_list, gt_bboxes_list,
-        img_metas, means_replicas, stds_replicas, cfg_replicas)
+        anchor_list,
+        valid_flag_list,
+        gt_bboxes_list,
+        gt_labels_list,
+        img_metas,
+        target_means=target_means,
+        target_stds=target_stds,
+        cfg=cfg,
+        cls_out_channels=cls_out_channels,
+        sampling=sampling)
     # no valid anchors
     if any([labels is None for labels in all_labels]):
         return None
     # sampled anchors of all images
-    num_total_samples = sum([
-        max(pos_inds.numel() + neg_inds.numel(), 1)
-        for pos_inds, neg_inds in zip(pos_inds_list, neg_inds_list)
-    ])
+    num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+    num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
     # split targets to a list w.r.t. multiple levels
     labels_list = images_to_levels(all_labels, num_level_anchors)
     label_weights_list = images_to_levels(all_label_weights, num_level_anchors)
     bbox_targets_list = images_to_levels(all_bbox_targets, num_level_anchors)
     bbox_weights_list = images_to_levels(all_bbox_weights, num_level_anchors)
     return (labels_list, label_weights_list, bbox_targets_list,
-            bbox_weights_list, num_total_samples)
+            bbox_weights_list, num_total_pos, num_total_neg)


 def images_to_levels(target, num_level_anchors):
@@ -71,8 +85,16 @@ def images_to_levels(target, num_level_anchors):
     return level_targets


-def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
-                         target_means, target_stds, cfg):
+def anchor_target_single(flat_anchors,
+                         valid_flags,
+                         gt_bboxes,
+                         gt_labels,
+                         img_meta,
+                         target_means,
+                         target_stds,
+                         cfg,
+                         cls_out_channels=1,
+                         sampling=True):
     inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                        img_meta['img_shape'][:2],
                                        cfg.allowed_border)
@@ -80,13 +102,27 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
         return (None, ) * 6
     # assign gt and sample anchors
     anchors = flat_anchors[inside_flags, :]
-    _, sampling_result = assign_and_sample(anchors, gt_bboxes, None, None, cfg)
+    if sampling:
+        assign_result, sampling_result = assign_and_sample(
+            anchors, gt_bboxes, None, None, cfg)
+    else:
+        bbox_assigner = BBoxAssigner(**cfg.assigner)
+        assign_result = bbox_assigner.assign(anchors, gt_bboxes, None,
+                                             gt_labels)
+        pos_inds = torch.nonzero(
+            assign_result.gt_inds > 0).squeeze(-1).unique()
+        neg_inds = torch.nonzero(
+            assign_result.gt_inds == 0).squeeze(-1).unique()
+        gt_flags = anchors.new_zeros(anchors.shape[0], dtype=torch.uint8)
+        sampling_result = SamplingResult(pos_inds, neg_inds, anchors,
+                                         gt_bboxes, assign_result, gt_flags)

     num_valid_anchors = anchors.shape[0]
     bbox_targets = torch.zeros_like(anchors)
     bbox_weights = torch.zeros_like(anchors)
-    labels = anchors.new_zeros((num_valid_anchors, ))
-    label_weights = anchors.new_zeros((num_valid_anchors, ))
+    labels = anchors.new_zeros(num_valid_anchors, dtype=torch.long)
+    label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

     pos_inds = sampling_result.pos_inds
     neg_inds = sampling_result.neg_inds
@@ -96,7 +132,10 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
                                       target_means, target_stds)
         bbox_targets[pos_inds, :] = pos_bbox_targets
         bbox_weights[pos_inds, :] = 1.0
-        labels[pos_inds] = 1
+        if gt_labels is None:
+            labels[pos_inds] = 1
+        else:
+            labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
         if cfg.pos_weight <= 0:
             label_weights[pos_inds] = 1.0
         else:
@@ -108,6 +147,9 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
         num_total_anchors = flat_anchors.size(0)
         labels = unmap(labels, num_total_anchors, inside_flags)
         label_weights = unmap(label_weights, num_total_anchors, inside_flags)
+        if cls_out_channels > 1:
+            labels, label_weights = expand_binary_labels(
+                labels, label_weights, cls_out_channels)
         bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
         bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
@@ -115,6 +157,17 @@ def anchor_target_single(flat_anchors, valid_flags, gt_bboxes, img_meta,
             neg_inds)


+def expand_binary_labels(labels, label_weights, cls_out_channels):
+    bin_labels = labels.new_full(
+        (labels.size(0), cls_out_channels), 0, dtype=torch.float32)
+    inds = torch.nonzero(labels >= 1).squeeze()
+    if inds.numel() > 0:
+        bin_labels[inds, labels[inds] - 1] = 1
+    bin_label_weights = label_weights.view(-1, 1).expand(
+        label_weights.size(0), cls_out_channels)
+    return bin_labels, bin_label_weights
+
+
 def anchor_inside_flags(flat_anchors, valid_flags, img_shape,
                         allowed_border=0):
     img_h, img_w = img_shape[:2]
```
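The new `expand_binary_labels` helper converts integer class labels (1-based, 0 meaning background) into per-class binary targets, as needed by the sigmoid-based focal loss where each foreground class gets its own output channel. The same mapping in a dependency-free NumPy sketch (function name and shapes are illustrative, not mmdet API):

```python
import numpy as np

def expand_binary_labels_np(labels, label_weights, cls_out_channels):
    """Mirror of expand_binary_labels: label k (1-based, 0 = background)
    becomes a one-hot row over cls_out_channels foreground classes."""
    bin_labels = np.zeros((labels.size, cls_out_channels), dtype=np.float32)
    inds = np.nonzero(labels >= 1)[0]          # foreground anchors only
    bin_labels[inds, labels[inds] - 1] = 1     # shift to 0-based class index
    bin_label_weights = np.repeat(
        label_weights.reshape(-1, 1), cls_out_channels, axis=1)
    return bin_labels, bin_label_weights

labels = np.array([0, 2, 1])            # background, class 2, class 1
weights = np.array([1.0, 1.0, 0.5])
bl, blw = expand_binary_labels_np(labels, weights, cls_out_channels=3)
print(bl.tolist())  # [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
```

The background row stays all-zero, so its loss contribution comes purely from the per-class sigmoid pushing every channel toward 0.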
mmdet/models/builder.py

```diff
@@ -2,11 +2,12 @@ from mmcv.runner import obj_from_dict
 from torch import nn

-from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
-               mask_heads)
+from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
+               mask_heads, single_stage_heads)

 __all__ = [
     'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
-    'build_bbox_head', 'build_mask_head', 'build_detector'
+    'build_bbox_head', 'build_mask_head', 'build_single_stage_head',
+    'build_detector'
 ]
@@ -47,6 +48,10 @@ def build_mask_head(cfg):
     return build(cfg, mask_heads)


+def build_single_stage_head(cfg):
+    return build(cfg, single_stage_heads)
+
+
 def build_detector(cfg, train_cfg=None, test_cfg=None):
     from . import detectors
     return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
```
mmdet/models/detectors/__init__.py

```diff
 from .base import BaseDetector
+from .single_stage import SingleStageDetector
 from .two_stage import TwoStageDetector
 from .rpn import RPN
 from .fast_rcnn import FastRCNN
 from .faster_rcnn import FasterRCNN
 from .mask_rcnn import MaskRCNN
 from .cascade_rcnn import CascadeRCNN
+from .retinanet import RetinaNet

 __all__ = [
-    'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN',
-    'MaskRCNN', 'CascadeRCNN'
+    'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
+    'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'RetinaNet'
 ]
```
mmdet/models/detectors/retinanet.py (new file, mode 100644)

```python
from .single_stage import SingleStageDetector


class RetinaNet(SingleStageDetector):

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(RetinaNet, self).__init__(backbone, neck, bbox_head, train_cfg,
                                        test_cfg, pretrained)
```
mmdet/models/detectors/single_stage.py (new file, mode 100644)

```python
import torch.nn as nn

from .base import BaseDetector
from .. import builder
from mmdet.core import bbox2result


class SingleStageDetector(BaseDetector):

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        self.bbox_head = builder.build_single_stage_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        super(SingleStageDetector, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        self.bbox_head.init_weights()

    def extract_feat(self, img):
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels):
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)
        losses = self.bbox_head.loss(*loss_inputs)
        return losses

    def simple_test(self, img, img_meta, rescale=False):
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
        bbox_list = self.bbox_head.get_det_bboxes(*bbox_inputs)
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
        ]
        return bbox_results[0]

    def aug_test(self, imgs, img_metas, rescale=False):
        raise NotImplementedError
```
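The detector is a thin composition: backbone, optional neck, then a head that owns both the loss and the box decoding. The control flow of `forward_train` can be illustrated with stand-in components; every class below is a hypothetical stub, not mmdet code:

```python
class StubBackbone:
    def __call__(self, img):
        return ['C3', 'C4', 'C5']            # multi-level feature maps

class StubNeck:
    def __call__(self, feats):
        return [f + '-fpn' for f in feats]   # e.g. FPN outputs

class StubHead:
    def __call__(self, feats):
        # per-level classification and regression outputs
        return (['cls@' + f for f in feats], ['reg@' + f for f in feats])

    def loss(self, cls_outs, reg_outs, gt, cfg):
        return {'loss_cls': len(cls_outs), 'loss_reg': len(reg_outs)}

def forward_train(backbone, neck, head, img, gt, train_cfg):
    # same shape as SingleStageDetector.forward_train: extract, predict, loss
    feats = backbone(img)
    if neck is not None:
        feats = neck(feats)
    cls_outs, reg_outs = head(feats)
    return head.loss(cls_outs, reg_outs, gt, train_cfg)

losses = forward_train(StubBackbone(), StubNeck(), StubHead(), 'img', 'gt', {})
print(losses)  # {'loss_cls': 3, 'loss_reg': 3}
```

Because the loss lives in the head, new single-stage detectors like RetinaNet only need to subclass and pass their own head config, as `retinanet.py` above shows.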
mmdet/models/rpn_heads/rpn_head.py

```diff
@@ -160,7 +160,7 @@ class RPNHead(nn.Module):
         if cls_reg_targets is None:
             return None
         (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
-         num_total_samples) = cls_reg_targets
+         num_total_pos, num_total_neg) = cls_reg_targets
         losses_cls, losses_reg = multi_apply(
             self.loss_single,
             rpn_cls_scores,
@@ -169,7 +169,7 @@ class RPNHead(nn.Module):
             label_weights_list,
             bbox_targets_list,
             bbox_weights_list,
-            num_total_samples=num_total_samples,
+            num_total_samples=num_total_pos + num_total_neg,
             cfg=cfg)
         return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
```
mmdet/models/single_stage_heads/__init__.py (new file, mode 100644)

```python
from .retina_head import RetinaHead

__all__ = ['RetinaHead']
```
mmdet/models/single_stage_heads/retina_head.py (new file, mode 100644)

```python
from __future__ import division

import numpy as np
import torch
import torch.nn as nn

from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
                        delta2bbox, weighted_smoothl1,
                        weighted_sigmoid_focal_loss, multiclass_nms)
from ..utils import normal_init, bias_init_with_prob


class RetinaHead(nn.Module):
    """Head of RetinaNet.

          / cls_convs - retina_cls (3x3 conv)
    input -
          \ reg_convs - retina_reg (3x3 conv)

    Args:
        in_channels (int): Number of channels in the input feature map.
        num_classes (int): Class number (including background).
        stacked_convs (int): Number of convolutional layers added for cls and
            reg branch.
        feat_channels (int): Number of channels for the RPN feature map.
        scales_per_octave (int): Number of anchor scales per octave.
        octave_base_scale (int): Base octave scale. Anchor scales are computed
            as `s*2^(i/n)`, for i in [0, n-1], where s is `octave_base_scale`
            and n is `scales_per_octave`.
        anchor_ratios (Iterable): Anchor aspect ratios.
        anchor_strides (Iterable): Anchor strides.
        target_means (Iterable): Mean values of regression targets.
        target_stds (Iterable): Std values of regression targets.
    """  # noqa: W605

    def __init__(self,
                 in_channels,
                 num_classes,
                 stacked_convs=4,
                 feat_channels=256,
                 octave_base_scale=4,
                 scales_per_octave=3,
                 anchor_ratios=[0.5, 1.0, 2.0],
                 anchor_strides=[8, 16, 32, 64, 128],
                 anchor_base_sizes=None,
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0)):
        super(RetinaHead, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.octave_base_scale = octave_base_scale
        self.scales_per_octave = scales_per_octave
        self.anchor_ratios = anchor_ratios
        self.anchor_strides = anchor_strides
        self.anchor_base_sizes = list(
            anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
        self.target_means = target_means
        self.target_stds = target_stds

        self.anchor_generators = []
        for anchor_base in self.anchor_base_sizes:
            octave_scales = np.array(
                [2**(i / scales_per_octave) for i in range(scales_per_octave)])
            anchor_scales = octave_scales * octave_base_scale
            self.anchor_generators.append(
                AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
        self.relu = nn.ReLU(inplace=True)
        self.num_anchors = int(len(self.anchor_ratios) * self.scales_per_octave)
        self.cls_out_channels = self.num_classes - 1
        self.bbox_pred_dim = 4

        self.stacked_convs = stacked_convs
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = in_channels if i == 0 else feat_channels
            self.cls_convs.append(
                nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1))
            self.reg_convs.append(
                nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1))
        self.retina_cls = nn.Conv2d(
            feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            stride=1,
            padding=1)
        self.retina_reg = nn.Conv2d(
            feat_channels,
            self.num_anchors * self.bbox_pred_dim,
            3,
            stride=1,
            padding=1)
        self.debug_imgs = None

    def init_weights(self):
        for m in self.cls_convs:
            normal_init(m, std=0.01)
        for m in self.reg_convs:
            normal_init(m, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
        normal_init(self.retina_reg, std=0.01)

    def forward_single(self, x):
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = self.relu(cls_conv(cls_feat))
        for reg_conv in self.reg_convs:
            reg_feat = self.relu(reg_conv(reg_feat))
        cls_score = self.retina_cls(cls_feat)
        bbox_pred = self.retina_reg(reg_feat)
        return cls_score, bbox_pred

    def forward(self, feats):
        return multi_apply(self.forward_single, feats)

    def get_anchors(self, featmap_sizes, img_metas):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.

        Returns:
            tuple: anchors of each image, valid flags of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = []
        for i in range(num_levels):
            anchors = self.anchor_generators[i].grid_anchors(
                featmap_sizes[i], self.anchor_strides[i])
            multi_level_anchors.append(anchors)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level anchors
        valid_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = []
            for i in range(num_levels):
                anchor_stride = self.anchor_strides[i]
                feat_h, feat_w = featmap_sizes[i]
                h, w, _ = img_meta['pad_shape']
                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                flags = self.anchor_generators[i].valid_flags(
                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                multi_level_flags.append(flags)
            valid_flag_list.append(multi_level_flags)
        return anchor_list, valid_flag_list

    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_pos_samples, cfg):
        # classification loss
        labels = labels.contiguous().view(-1, self.cls_out_channels)
        label_weights = label_weights.contiguous().view(
            -1, self.cls_out_channels)
        cls_score = cls_score.permute(0, 2, 3, 1).contiguous().view(
            -1, self.cls_out_channels)
        loss_cls = weighted_sigmoid_focal_loss(
            cls_score,
            labels,
            label_weights,
            cfg.gamma,
            cfg.alpha,
            avg_factor=num_pos_samples)
        # regression loss
        bbox_targets = bbox_targets.contiguous().view(-1, 4)
        bbox_weights = bbox_weights.contiguous().view(-1, 4)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).contiguous().view(-1, 4)
        loss_reg = weighted_smoothl1(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_pos_samples)
        return loss_cls, loss_reg

    def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
             cfg):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas)
        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_labels_list=gt_labels,
            cls_out_channels=self.cls_out_channels,
            sampling=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
        losses_cls, losses_reg = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_pos_samples=num_total_pos,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_reg=losses_reg)

    def get_det_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
                       rescale=False):
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)

        mlvl_anchors = [
            self.anchor_generators[i].grid_anchors(cls_scores[i].size()[-2:],
                                                   self.anchor_strides[i])
            for i in range(num_levels)
        ]
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            results = self._get_det_bboxes_single(cls_score_list,
                                                  bbox_pred_list,
                                                  mlvl_anchors, img_shape,
                                                  scale_factor, cfg, rescale)
            result_list.append(results)
        return result_list

    def _get_det_bboxes_single(self,
                               cls_scores,
                               bbox_preds,
                               mlvl_anchors,
                               img_shape,
                               scale_factor,
                               cfg,
                               rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        mlvl_proposals = []
        mlvl_scores = []
        for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
                                                 mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            cls_score = cls_score.permute(1, 2, 0).contiguous().view(
                -1, self.cls_out_channels)
            scores = cls_score.sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
            proposals = delta2bbox(anchors, bbox_pred, self.target_means,
                                   self.target_stds, img_shape)
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                maxscores, _ = scores.max(dim=1)
                _, topk_inds = maxscores.topk(cfg.nms_pre)
                proposals = proposals[topk_inds, :]
                scores = scores[topk_inds, :]
            mlvl_proposals.append(proposals)
            mlvl_scores.append(scores)
        mlvl_proposals = torch.cat(mlvl_proposals)
        if rescale:
            mlvl_proposals /= scale_factor
        mlvl_scores = torch.cat(mlvl_scores)
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
                                                cfg.score_thr, cfg.nms,
                                                cfg.max_per_img)
        return det_bboxes, det_labels
```
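In `_get_det_bboxes_single`, each pyramid level is pre-filtered to the `nms_pre` highest-scoring anchors using the per-anchor maximum class score, before a zero "background" column is prepended so the score layout matches what `multiclass_nms` expects. The top-k step in isolation, as a NumPy stand-in with an illustrative function name:

```python
import numpy as np

def topk_prefilter(scores, proposals, nms_pre):
    """scores: (N, C) per-class sigmoid scores; proposals: (N, 4) boxes."""
    if nms_pre > 0 and scores.shape[0] > nms_pre:
        max_scores = scores.max(axis=1)                # per-anchor best class
        topk_inds = np.argsort(-max_scores)[:nms_pre]  # highest scores first
        proposals = proposals[topk_inds, :]
        scores = scores[topk_inds, :]
    return scores, proposals

scores = np.array([[0.1, 0.2], [0.9, 0.1], [0.3, 0.4]])
boxes = np.arange(12, dtype=float).reshape(3, 4)
s, p = topk_prefilter(scores, boxes, nms_pre=2)
print(s.shape, p.shape)  # (2, 2) (2, 4)
# kept anchors: index 1 (max 0.9) and index 2 (max 0.4)
```

This keeps at most `nms_pre * num_levels` candidates entering NMS regardless of input resolution, which bounds the cost of the decoding step.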
mmdet/models/utils/__init__.py

```diff
 from .conv_module import ConvModule
 from .norm import build_norm_layer
-from .weight_init import xavier_init, normal_init, uniform_init, kaiming_init
+from .weight_init import (xavier_init, normal_init, uniform_init, kaiming_init,
+                          bias_init_with_prob)

 __all__ = [
     'ConvModule', 'build_norm_layer', 'xavier_init', 'normal_init',
-    'uniform_init', 'kaiming_init'
+    'uniform_init', 'kaiming_init', 'bias_init_with_prob'
 ]
```
mmdet/models/utils/weight_init.py

```diff
+import numpy as np
 import torch.nn as nn
@@ -37,3 +38,9 @@ def kaiming_init(module,
         module.weight, mode=mode, nonlinearity=nonlinearity)
     if hasattr(module, 'bias'):
         nn.init.constant_(module.bias, bias)
+
+
+def bias_init_with_prob(prior_prob):
+    """Initialize conv/fc bias value according to a given probability."""
+    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+    return bias_init
```
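`bias_init_with_prob` implements the prior-probability initialization from the RetinaNet paper: the classification bias is set so that the sigmoid output starts near `prior_prob` (0.01 in `RetinaHead.init_weights`), which keeps the focal loss stable early in training when nearly all anchors are background. A quick numeric check using only the standard library:

```python
import math

def bias_init_with_prob(prior_prob):
    # b = -log((1 - p) / p), chosen so that sigmoid(b) == p
    return float(-math.log((1 - prior_prob) / prior_prob))

b = bias_init_with_prob(0.01)
print(round(b, 3))            # -4.595
sigmoid_b = 1 / (1 + math.exp(-b))
print(round(sigmoid_b, 4))    # 0.01
```

So every class logit starts strongly negative, and the initial classification loss is dominated by a small, uniform background term rather than a huge sum over easy negatives.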
setup.py

```diff
@@ -12,7 +12,7 @@ def readme():

 MAJOR = 0
 MINOR = 5
-PATCH = 3
+PATCH = 4
 SUFFIX = ''
 SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX)
```