ModelZoo / SOLOv2-pytorch · Commits · 70700512

Commit 70700512, authored Jan 13, 2019 by Kai Chen
Parent: 1b9f9b88

    use AnchorHead to unify rpn head and single stage heads

Showing 14 changed files with 312 additions and 524 deletions (+312, -524).
Changed files:

  mmdet/models/__init__.py                    +2    -2
  mmdet/models/anchor_heads/__init__.py       +6    -0
  mmdet/models/anchor_heads/anchor_head.py    +110  -120
  mmdet/models/anchor_heads/retina_head.py    +68   -0
  mmdet/models/anchor_heads/rpn_head.py       +88   -0
  mmdet/models/anchor_heads/ssd_head.py       +26   -127
  mmdet/models/builder.py                     +4    -14
  mmdet/models/detectors/cascade_rcnn.py      +2    -2
  mmdet/models/detectors/rpn.py               +1    -1
  mmdet/models/detectors/single_stage.py      +2    -2
  mmdet/models/detectors/test_mixins.py       +1    -1
  mmdet/models/detectors/two_stage.py         +2    -2
  mmdet/models/rpn_heads/__init__.py          +0    -3
  mmdet/models/rpn_heads/rpn_head.py          +0    -250
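The net effect of the commit is that the RPN head and the single-stage heads now share one anchor-based base class instead of living in separate packages. A minimal sketch of the resulting hierarchy, condensed from the diffs below (method bodies elided, responsibilities noted in comments):

import torch.nn as nn

# Sketch only: AnchorHead owns anchor generation, target assignment, the loss,
# and box decoding; subclasses mostly just build their own layers and forward pass.
class AnchorHead(nn.Module):
    def _init_layers(self): ...        # default: 1x1 conv_cls / conv_reg
    def forward_single(self, x): ...
    def loss(self, *args, **kwargs): ...       # shared anchor_target + loss_single
    def get_bboxes(self, *args, **kwargs): ... # shared multi-level decode + NMS

class RPNHead(AnchorHead): ...     # 2-class head; overrides get_bboxes_single for proposals
class RetinaHead(AnchorHead): ...  # stacked 3x3 convs, sigmoid cls + focal loss
class SSDHead(AnchorHead): ...     # per-level convs; keeps its own SSD loss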
mmdet/models/__init__.py  (view file @ 70700512)

 from .detectors import (BaseDetector, TwoStageDetector, RPN, FastRCNN,
                         FasterRCNN, MaskRCNN)
-from .builder import (build_neck, build_rpn_head, build_roi_extractor,
+from .builder import (build_neck, build_anchor_head, build_roi_extractor,
                       build_bbox_head, build_mask_head, build_detector)

 __all__ = [
     'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN',
-    'MaskRCNN', 'build_backbone', 'build_neck', 'build_rpn_head',
+    'MaskRCNN', 'build_backbone', 'build_neck', 'build_anchor_head',
     'build_roi_extractor', 'build_bbox_head', 'build_mask_head',
     'build_detector'
 ]
mmdet/models/single_stage_heads/__init__.py → mmdet/models/anchor_heads/__init__.py  (view file @ 70700512)

+from .anchor_head import AnchorHead
+from .rpn_head import RPNHead
 from .retina_head import RetinaHead
 from .ssd_head import SSDHead

-__all__ = ['RetinaHead', 'SSDHead']
+__all__ = ['AnchorHead', 'RPNHead', 'RetinaHead', 'SSDHead']
mmdet/models/single_stage_heads/retina_head.py → mmdet/models/anchor_heads/anchor_head.py  (view file @ 70700512)

@@ -4,114 +4,85 @@ import numpy as np
 import torch
 import torch.nn as nn

-from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
-                        delta2bbox, weighted_smoothl1,
+from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
+                        multi_apply, weighted_cross_entropy, weighted_smoothl1,
+                        weighted_binary_cross_entropy,
                         weighted_sigmoid_focal_loss, multiclass_nms)
-from ..utils import normal_init, bias_init_with_prob
+from ..utils import normal_init


-class RetinaHead(nn.Module):
-    """Head of RetinaNet.
+class AnchorHead(nn.Module):
+    """Anchor-based head (RPN, RetinaNet, SSD, etc.).

-              / cls_convs - retina_cls (3x3 conv)
-    input -
-              \ reg_convs - retina_reg (3x3 conv)
+                                  / - conv_cls (1x1 conv)
+    input - rpn_conv (3x3 conv) -
+                                  \ - conv_reg (1x1 conv)

     Args:
         in_channels (int): Number of channels in the input feature map.
         num_classes (int): Class number (including background).
-        stacked_convs (int): Number of convolutional layers added for cls and
-            reg branch.
         feat_channels (int): Number of channels for the RPN feature map.
-        scales_per_octave (int): Number of anchor scales per octave.
-        octave_base_scale (int): Base octave scale. Anchor scales are computed
-            as `s*2^(i/n)`, for i in [0, n-1], where s is `octave_base_scale`
-            and n is `scales_per_octave`.
+        anchor_scales (Iterable): Anchor scales.
         anchor_ratios (Iterable): Anchor aspect ratios.
         anchor_strides (Iterable): Anchor strides.
         anchor_base_sizes (Iterable): Anchor base sizes.
         target_means (Iterable): Mean values of regression targets.
         target_stds (Iterable): Std values of regression targets.
+        use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
+            (softmax by default)
     """  # noqa: W605

     def __init__(self,
-                 in_channels,
                  num_classes,
-                 stacked_convs=4,
+                 in_channels,
                  feat_channels=256,
-                 octave_base_scale=4,
-                 scales_per_octave=3,
+                 anchor_scales=[8, 16, 32],
                  anchor_ratios=[0.5, 1.0, 2.0],
-                 anchor_strides=[8, 16, 32, 64, 128],
+                 anchor_strides=[4, 8, 16, 32, 64],
                  anchor_base_sizes=None,
                  target_means=(.0, .0, .0, .0),
-                 target_stds=(1.0, 1.0, 1.0, 1.0)):
-        super(RetinaHead, self).__init__()
+                 target_stds=(1.0, 1.0, 1.0, 1.0),
+                 use_sigmoid_cls=False,
+                 use_focal_loss=False):
+        super(AnchorHead, self).__init__()
         self.in_channels = in_channels
         self.num_classes = num_classes
-        self.octave_base_scale = octave_base_scale
-        self.scales_per_octave = scales_per_octave
+        self.feat_channels = feat_channels
+        self.anchor_scales = anchor_scales
         self.anchor_ratios = anchor_ratios
         self.anchor_strides = anchor_strides
         self.anchor_base_sizes = list(
             anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
         self.target_means = target_means
         self.target_stds = target_stds
+        self.use_sigmoid_cls = use_sigmoid_cls
+        self.use_focal_loss = use_focal_loss

         self.anchor_generators = []
         for anchor_base in self.anchor_base_sizes:
-            octave_scales = np.array(
-                [2**(i / scales_per_octave) for i in range(scales_per_octave)])
-            anchor_scales = octave_scales * octave_base_scale
             self.anchor_generators.append(
                 AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
-        self.relu = nn.ReLU(inplace=True)
-        self.num_anchors = int(
-            len(self.anchor_ratios) * self.scales_per_octave)
-        self.cls_out_channels = self.num_classes - 1
-        self.bbox_pred_dim = 4
-        self.stacked_convs = stacked_convs
-        self.cls_convs = nn.ModuleList()
-        self.reg_convs = nn.ModuleList()
-        for i in range(self.stacked_convs):
-            chn = in_channels if i == 0 else feat_channels
-            self.cls_convs.append(
-                nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1))
-            self.reg_convs.append(
-                nn.Conv2d(chn, feat_channels, 3, stride=1, padding=1))
-        self.retina_cls = nn.Conv2d(
-            feat_channels, self.num_anchors * self.cls_out_channels, 3,
-            stride=1, padding=1)
-        self.retina_reg = nn.Conv2d(
-            feat_channels, self.num_anchors * self.bbox_pred_dim, 3,
-            stride=1, padding=1)
-        self.debug_imgs = None
+        self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
+        if self.use_sigmoid_cls:
+            self.cls_out_channels = self.num_classes - 1
+        else:
+            self.cls_out_channels = self.num_classes
+
+        self._init_layers()
+
+    def _init_layers(self):
+        self.conv_cls = nn.Conv2d(self.feat_channels,
+                                  self.num_anchors * self.cls_out_channels, 1)
+        self.conv_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)

     def init_weights(self):
-        for m in self.cls_convs:
-            normal_init(m, std=0.01)
-        for m in self.reg_convs:
-            normal_init(m, std=0.01)
-        bias_cls = bias_init_with_prob(0.01)
-        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
-        normal_init(self.retina_reg, std=0.01)
+        normal_init(self.conv_cls, std=0.01)
+        normal_init(self.conv_reg, std=0.01)

     def forward_single(self, x):
-        cls_feat = x
-        reg_feat = x
-        for cls_conv in self.cls_convs:
-            cls_feat = self.relu(cls_conv(cls_feat))
-        for reg_conv in self.reg_convs:
-            reg_feat = self.relu(reg_conv(reg_feat))
-        cls_score = self.retina_cls(cls_feat)
-        bbox_pred = self.retina_reg(reg_feat)
-        return cls_score, bbox_pred
+        rpn_cls_score = self.conv_cls(x)
+        rpn_bbox_pred = self.conv_reg(x)
+        return rpn_cls_score, rpn_bbox_pred

     def forward(self, feats):
         return multi_apply(self.forward_single, feats)

@@ -156,20 +127,34 @@ class RetinaHead(nn.Module):
         return anchor_list, valid_flag_list

     def loss_single(self, cls_score, bbox_pred, labels, label_weights,
-                    bbox_targets, bbox_weights, num_pos_samples, cfg):
+                    bbox_targets, bbox_weights, num_total_samples, cfg):
         # classification loss
         labels = labels.contiguous().view(-1, self.cls_out_channels)
         label_weights = label_weights.contiguous().view(
             -1, self.cls_out_channels)
         cls_score = cls_score.permute(0, 2, 3, 1).contiguous().view(
             -1, self.cls_out_channels)
-        loss_cls = weighted_sigmoid_focal_loss(
-            cls_score,
-            labels,
-            label_weights,
-            cfg.gamma,
-            cfg.alpha,
-            avg_factor=num_pos_samples)
+        if self.use_sigmoid_cls:
+            if self.use_focal_loss:
+                cls_criterion = weighted_sigmoid_focal_loss
+            else:
+                cls_criterion = weighted_binary_cross_entropy
+        else:
+            if self.use_focal_loss:
+                raise NotImplementedError
+            else:
+                cls_criterion = weighted_cross_entropy
+        if self.use_focal_loss:
+            loss_cls = cls_criterion(
+                cls_score,
+                labels,
+                label_weights,
+                gamma=cfg.gamma,
+                alpha=cfg.alpha,
+                avg_factor=num_total_samples)
+        else:
+            loss_cls = cls_criterion(
+                cls_score, labels, label_weights,
+                avg_factor=num_total_samples)
         # regression loss
         bbox_targets = bbox_targets.contiguous().view(-1, 4)
         bbox_weights = bbox_weights.contiguous().view(-1, 4)

@@ -179,7 +164,7 @@ class RetinaHead(nn.Module):
             bbox_targets,
             bbox_weights,
             beta=cfg.smoothl1_beta,
-            avg_factor=num_pos_samples)
+            avg_factor=num_total_samples)
         return loss_cls, loss_reg

     def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,

@@ -189,6 +174,7 @@ class RetinaHead(nn.Module):
         anchor_list, valid_flag_list = self.get_anchors(
             featmap_sizes, img_metas)
+        sampling = False if self.use_focal_loss else True
         cls_reg_targets = anchor_target(
             anchor_list,
             valid_flag_list,

@@ -199,12 +185,13 @@ class RetinaHead(nn.Module):
             cfg,
             gt_labels_list=gt_labels,
             cls_out_channels=self.cls_out_channels,
-            sampling=False)
+            sampling=sampling)
         if cls_reg_targets is None:
             return None
         (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
          num_total_pos, num_total_neg) = cls_reg_targets
+        num_total_samples = (num_total_pos if self.use_focal_loss else
+                             num_total_pos + num_total_neg)
         losses_cls, losses_reg = multi_apply(
             self.loss_single,
             cls_scores,

@@ -213,16 +200,12 @@ class RetinaHead(nn.Module):
             label_weights_list,
             bbox_targets_list,
             bbox_weights_list,
-            num_pos_samples=num_total_pos,
+            num_total_samples=num_total_samples,
             cfg=cfg)
-        return dict(loss_cls=losses_cls, loss_reg=losses_reg)
+        return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)

-    def get_det_bboxes(self,
-                       cls_scores,
-                       bbox_preds,
-                       img_metas,
-                       cfg,
-                       rescale=False):
+    def get_bboxes(self,
+                   cls_scores,
+                   bbox_preds,
+                   img_metas,
+                   cfg,
+                   rescale=False):
         assert len(cls_scores) == len(bbox_preds)
         num_levels = len(cls_scores)

@@ -231,7 +214,6 @@ class RetinaHead(nn.Module):
                 self.anchor_strides[i])
             for i in range(num_levels)
         ]
-
         result_list = []
         for img_id in range(len(img_metas)):
             cls_score_list = [

@@ -242,46 +224,54 @@ class RetinaHead(nn.Module):
             ]
             img_shape = img_metas[img_id]['img_shape']
             scale_factor = img_metas[img_id]['scale_factor']
-            results = self._get_det_bboxes_single(cls_score_list,
-                                                  bbox_pred_list, mlvl_anchors,
-                                                  img_shape, scale_factor, cfg,
-                                                  rescale)
-            result_list.append(results)
+            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
+                                               mlvl_anchors, img_shape,
+                                               scale_factor, cfg, rescale)
+            result_list.append(proposals)
         return result_list

-    def _get_det_bboxes_single(self,
-                               cls_scores,
-                               bbox_preds,
-                               mlvl_anchors,
-                               img_shape,
-                               scale_factor,
-                               cfg,
-                               rescale=False):
+    def get_bboxes_single(self,
+                          cls_scores,
+                          bbox_preds,
+                          mlvl_anchors,
+                          img_shape,
+                          scale_factor,
+                          cfg,
+                          rescale=False):
         assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
-        mlvl_proposals = []
+        mlvl_bboxes = []
         mlvl_scores = []
         for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
                                                  mlvl_anchors):
             assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
-            cls_score = cls_score.permute(1, 2, 0).contiguous().view(
-                -1, self.cls_out_channels)
-            scores = cls_score.sigmoid()
-            bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
-            proposals = delta2bbox(anchors, bbox_pred, self.target_means,
-                                   self.target_stds, img_shape)
-            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
-                maxscores, _ = scores.max(dim=1)
-                _, topk_inds = maxscores.topk(cfg.nms_pre)
-                proposals = proposals[topk_inds, :]
-                scores = scores[topk_inds, :]
-            mlvl_proposals.append(proposals)
+            cls_score = cls_score.permute(1, 2, 0).reshape(
+                -1, self.cls_out_channels)
+            if self.use_sigmoid_cls:
+                scores = cls_score.sigmoid()
+            else:
+                scores = cls_score.softmax(-1)
+            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+            nms_pre = cfg.get('nms_pre', -1)
+            if nms_pre > 0 and scores.shape[0] > nms_pre:
+                if self.use_sigmoid_cls:
+                    max_scores, _ = scores.max(dim=1)
+                else:
+                    max_scores, _ = scores[:, 1:].max(dim=1)
+                _, topk_inds = max_scores.topk(nms_pre)
+                anchors = anchors[topk_inds, :]
+                bbox_pred = bbox_pred[topk_inds, :]
+                scores = scores[topk_inds, :]
+            bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
+                                self.target_stds, img_shape)
+            mlvl_bboxes.append(bboxes)
             mlvl_scores.append(scores)
-        mlvl_proposals = torch.cat(mlvl_proposals)
+        mlvl_bboxes = torch.cat(mlvl_bboxes)
         if rescale:
-            mlvl_proposals /= scale_factor
+            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
         mlvl_scores = torch.cat(mlvl_scores)
-        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
-        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
-        det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
-                                                cfg.score_thr, cfg.nms,
-                                                cfg.max_per_img)
+        if self.use_sigmoid_cls:
+            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+            mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
+        det_bboxes, det_labels = multiclass_nms(
+            mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
         return det_bboxes, det_labels
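One consequence of the new flags that is easy to miss when reading the diff above: the number of classification channels per anchor, and the loss criterion, now both depend on use_sigmoid_cls and use_focal_loss. A small hedged sketch of the three configurations this commit sets up; the class counts are illustrative values, not taken from any config in this commit:

# Illustration only: how AnchorHead picks its classification settings.
# num_classes=81 (80 classes + background) is a hypothetical COCO-style value.
def cls_setup(num_classes, use_sigmoid_cls, use_focal_loss):
    # mirrors AnchorHead.__init__ and loss_single as changed in this commit
    cls_out_channels = num_classes - 1 if use_sigmoid_cls else num_classes
    if use_sigmoid_cls:
        criterion = ('weighted_sigmoid_focal_loss' if use_focal_loss
                     else 'weighted_binary_cross_entropy')
    else:
        # focal loss with softmax raises NotImplementedError in this commit
        criterion = 'weighted_cross_entropy'
    return cls_out_channels, criterion

print(cls_setup(81, True, True))    # RetinaHead-style: (80, focal loss)
print(cls_setup(2, False, False))   # RPNHead-style:    (2, softmax cross-entropy)
print(cls_setup(81, False, False))  # SSDHead-style:    (81, but SSDHead keeps its own loss)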
mmdet/models/anchor_heads/retina_head.py  (new file, 0 → 100644, view file @ 70700512)

import numpy as np
import torch.nn as nn
from mmcv.cnn import normal_init

from .anchor_head import AnchorHead
from ..utils import bias_init_with_prob


class RetinaHead(AnchorHead):

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 octave_base_scale=4,
                 scales_per_octave=3,
                 **kwargs):
        self.stacked_convs = stacked_convs
        self.octave_base_scale = octave_base_scale
        self.scales_per_octave = scales_per_octave
        octave_scales = np.array(
            [2**(i / scales_per_octave) for i in range(scales_per_octave)])
        anchor_scales = octave_scales * octave_base_scale
        super(RetinaHead, self).__init__(
            num_classes,
            in_channels,
            anchor_scales=anchor_scales,
            use_sigmoid_cls=True,
            use_focal_loss=True,
            **kwargs)

    def _init_layers(self):
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1))
            self.reg_convs.append(
                nn.Conv2d(chn, self.feat_channels, 3, stride=1, padding=1))
        self.retina_cls = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        self.retina_reg = nn.Conv2d(
            self.feat_channels, self.num_anchors * 4, 3, padding=1)

    def init_weights(self):
        for m in self.cls_convs:
            normal_init(m, std=0.01)
        for m in self.reg_convs:
            normal_init(m, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
        normal_init(self.retina_reg, std=0.01)

    def forward_single(self, x):
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = self.relu(cls_conv(cls_feat))
        for reg_conv in self.reg_convs:
            reg_feat = self.relu(reg_conv(reg_feat))
        cls_score = self.retina_cls(cls_feat)
        bbox_pred = self.retina_reg(reg_feat)
        return cls_score, bbox_pred
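For readers checking what the new RetinaHead actually feeds into AnchorHead: the anchor scales are derived from the octave parameters rather than passed directly. A small worked example with the defaults above (the strides and any other settings come in through **kwargs, so they are not shown here):

import numpy as np

# Defaults from RetinaHead.__init__ above.
octave_base_scale, scales_per_octave = 4, 3
octave_scales = np.array(
    [2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
print(anchor_scales)  # approximately [4.0, 5.04, 6.35]

With 3 scales and the default 3 aspect ratios, AnchorHead's num_anchors works out to 9 anchors per feature-map location, which is what the retina_cls and retina_reg output channels are sized against.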
mmdet/models/anchor_heads/rpn_head.py  (new file, 0 → 100644, view file @ 70700512)

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init

from mmdet.core import delta2bbox
from mmdet.ops import nms
from .anchor_head import AnchorHead


class RPNHead(AnchorHead):

    def __init__(self, in_channels, **kwargs):
        super(RPNHead, self).__init__(2, in_channels, **kwargs)

    def _init_layers(self):
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        self.rpn_cls = nn.Conv2d(self.feat_channels,
                                 self.num_anchors * self.cls_out_channels, 1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)

    def init_weights(self):
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        rpn_cls_score = self.rpn_cls(x)
        rpn_bbox_pred = self.rpn_reg(x)
        return rpn_cls_score, rpn_bbox_pred

    def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg):
        return super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
                                         None, img_metas, cfg)

    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        mlvl_proposals = []
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            anchors = mlvl_anchors[idx]
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                scores = rpn_cls_score.softmax(dim=1)[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                _, topk_inds = scores.topk(cfg.nms_pre)
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
                scores = scores[topk_inds]
            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                   self.target_stds, img_shape)
            if cfg.min_bbox_size > 0:
                w = proposals[:, 2] - proposals[:, 0] + 1
                h = proposals[:, 3] - proposals[:, 1] + 1
                valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                           (h >= cfg.min_bbox_size)).squeeze()
                proposals = proposals[valid_inds, :]
                scores = scores[valid_inds]
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            scores = proposals[:, 4]
            num = min(cfg.max_num, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]
        return proposals
mmdet/models/single_stage_heads/ssd_head.py → mmdet/models/anchor_heads/ssd_head.py  (view file @ 70700512)

-from __future__ import division
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import xavier_init

-from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
-                        delta2bbox, weighted_smoothl1, multiclass_nms)
+from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
+                        multi_apply)
+from .anchor_head import AnchorHead


-class SSDHead(nn.Module):
+class SSDHead(AnchorHead):

     def __init__(self,
                  input_size=300,
-                 in_channels=(512, 1024, 512, 256, 256, 256),
                  num_classes=81,
+                 in_channels=(512, 1024, 512, 256, 256, 256),
                  anchor_strides=(8, 16, 32, 64, 100, 300),
                  basesize_ratio_range=(0.1, 0.9),
                  anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                  target_means=(.0, .0, .0, .0),
                  target_stds=(1.0, 1.0, 1.0, 1.0)):
-        super(SSDHead, self).__init__()
+        super(AnchorHead, self).__init__()
         # construct head
         self.input_size = input_size
-        self.in_channels = in_channels
         self.num_classes = num_classes
+        self.in_channels = in_channels
         self.cls_out_channels = num_classes
         num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
         reg_convs = []
         cls_convs = []
         for i in range(len(in_channels)):

@@ -88,6 +87,8 @@ class SSDHead(nn.Module):
         self.target_means = target_means
         self.target_stds = target_stds
+        self.use_sigmoid_cls = False
+        self.use_focal_loss = False

     def init_weights(self):
         for m in self.modules():

@@ -103,68 +104,28 @@ class SSDHead(nn.Module):
             bbox_preds.append(reg_conv(feat))
         return cls_scores, bbox_preds

-    def get_anchors(self, featmap_sizes, img_metas):
-        """Get anchors according to feature map sizes.
-
-        Args:
-            featmap_sizes (list[tuple]): Multi-level feature map sizes.
-            img_metas (list[dict]): Image meta info.
-
-        Returns:
-            tuple: anchors of each image, valid flags of each image
-        """
-        num_imgs = len(img_metas)
-        num_levels = len(featmap_sizes)
-        # since feature map sizes of all images are the same, we only compute
-        # anchors for one time
-        multi_level_anchors = []
-        for i in range(num_levels):
-            anchors = self.anchor_generators[i].grid_anchors(
-                featmap_sizes[i], self.anchor_strides[i])
-            multi_level_anchors.append(anchors)
-        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
-        # for each image, we compute valid flags of multi level anchors
-        valid_flag_list = []
-        for img_id, img_meta in enumerate(img_metas):
-            multi_level_flags = []
-            for i in range(num_levels):
-                anchor_stride = self.anchor_strides[i]
-                feat_h, feat_w = featmap_sizes[i]
-                h, w, _ = img_meta['pad_shape']
-                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
-                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
-                flags = self.anchor_generators[i].valid_flags(
-                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
-                multi_level_flags.append(flags)
-            valid_flag_list.append(multi_level_flags)
-        return anchor_list, valid_flag_list
-
     def loss_single(self, cls_score, bbox_pred, labels, label_weights,
-                    bbox_targets, bbox_weights, num_pos_samples, cfg):
+                    bbox_targets, bbox_weights, num_total_samples, cfg):
         loss_cls_all = F.cross_entropy(
             cls_score, labels, reduction='none') * label_weights
-        pos_label_inds = (labels > 0).nonzero().view(-1)
-        neg_label_inds = (labels == 0).nonzero().view(-1)
-        num_sample_pos = pos_label_inds.size(0)
-        num_sample_neg = cfg.neg_pos_ratio * num_sample_pos
-        if num_sample_neg > neg_label_inds.size(0):
-            num_sample_neg = neg_label_inds.size(0)
-        topk_loss_cls_neg, topk_loss_cls_neg_inds = \
-            loss_cls_all[neg_label_inds].topk(num_sample_neg)
-        loss_cls_pos = loss_cls_all[pos_label_inds].sum()
+        pos_inds = (labels > 0).nonzero().view(-1)
+        neg_inds = (labels == 0).nonzero().view(-1)
+        num_pos_samples = pos_inds.size(0)
+        num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
+        if num_neg_samples > neg_inds.size(0):
+            num_neg_samples = neg_inds.size(0)
+        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
+        loss_cls_pos = loss_cls_all[pos_inds].sum()
         loss_cls_neg = topk_loss_cls_neg.sum()
-        loss_cls = (loss_cls_pos + loss_cls_neg) / num_pos_samples
+        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
         loss_reg = weighted_smoothl1(
             bbox_pred,
             bbox_targets,
             bbox_weights,
             beta=cfg.smoothl1_beta,
-            avg_factor=num_pos_samples)
+            avg_factor=num_total_samples)
         return loss_cls[None], loss_reg

     def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,

@@ -193,14 +154,14 @@ class SSDHead(nn.Module):
         num_images = len(img_metas)
         all_cls_scores = torch.cat([
-            s.permute(0, 2, 3, 1).contiguous().view(
-                num_images, -1, self.cls_out_channels) for s in cls_scores
+            s.permute(0, 2, 3, 1).reshape(
+                num_images, -1, self.cls_out_channels) for s in cls_scores
         ], 1)
         all_labels = torch.cat(labels_list, -1).view(num_images, -1)
         all_label_weights = torch.cat(label_weights_list, -1).view(
             num_images, -1)
         all_bbox_preds = torch.cat([
-            b.permute(0, 2, 3, 1).contiguous().view(num_images, -1, 4)
+            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
             for b in bbox_preds
         ], -2)
         all_bbox_targets = torch.cat(bbox_targets_list, -2).view(

@@ -216,68 +177,6 @@ class SSDHead(nn.Module):
             all_label_weights,
             all_bbox_targets,
             all_bbox_weights,
-            num_pos_samples=num_total_pos,
+            num_total_samples=num_total_pos,
             cfg=cfg)
         return dict(loss_cls=losses_cls, loss_reg=losses_reg)
-
-    def get_det_bboxes(self, cls_scores, bbox_preds, img_metas, cfg,
-                       rescale=False):
-        assert len(cls_scores) == len(bbox_preds)
-        num_levels = len(cls_scores)
-        mlvl_anchors = [
-            self.anchor_generators[i].grid_anchors(cls_scores[i].size()[-2:],
-                                                   self.anchor_strides[i])
-            for i in range(num_levels)
-        ]
-        result_list = []
-        for img_id in range(len(img_metas)):
-            cls_score_list = [
-                cls_scores[i][img_id].detach() for i in range(num_levels)
-            ]
-            bbox_pred_list = [
-                bbox_preds[i][img_id].detach() for i in range(num_levels)
-            ]
-            img_shape = img_metas[img_id]['img_shape']
-            scale_factor = img_metas[img_id]['scale_factor']
-            results = self._get_det_bboxes_single(cls_score_list,
-                                                  bbox_pred_list, mlvl_anchors,
-                                                  img_shape, scale_factor, cfg,
-                                                  rescale)
-            result_list.append(results)
-        return result_list
-
-    def _get_det_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors,
-                               img_shape, scale_factor, cfg, rescale=False):
-        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
-        mlvl_proposals = []
-        mlvl_scores = []
-        for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
-                                                 mlvl_anchors):
-            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
-            cls_score = cls_score.permute(1, 2, 0).contiguous().view(
-                -1, self.cls_out_channels)
-            scores = cls_score.softmax(-1)
-            bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
-            proposals = delta2bbox(anchors, bbox_pred, self.target_means,
-                                   self.target_stds, img_shape)
-            mlvl_proposals.append(proposals)
-            mlvl_scores.append(scores)
-        mlvl_proposals = torch.cat(mlvl_proposals)
-        if rescale:
-            mlvl_proposals /= mlvl_proposals.new_tensor(scale_factor)
-        mlvl_scores = torch.cat(mlvl_scores)
-        det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
-                                                cfg.score_thr, cfg.nms,
-                                                cfg.max_per_img)
-        return det_bboxes, det_labels
mmdet/models/builder.py  (view file @ 70700512)

 from mmcv.runner import obj_from_dict
 from torch import nn

-from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
-               mask_heads, single_stage_heads)
-
-__all__ = [
-    'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
-    'build_bbox_head', 'build_mask_head', 'build_single_stage_head',
-    'build_detector'
-]
+from . import (backbones, necks, roi_extractors, anchor_heads, bbox_heads,
+               mask_heads)


 def _build_module(cfg, parrent=None, default_args=None):

@@ -32,8 +26,8 @@ def build_neck(cfg):
     return build(cfg, necks)


-def build_rpn_head(cfg):
-    return build(cfg, rpn_heads)
+def build_anchor_head(cfg):
+    return build(cfg, anchor_heads)


 def build_roi_extractor(cfg):

@@ -48,10 +42,6 @@ def build_mask_head(cfg):
     return build(cfg, mask_heads)


-def build_single_stage_head(cfg):
-    return build(cfg, single_stage_heads)
-
-
 def build_detector(cfg, train_cfg=None, test_cfg=None):
     from . import detectors
     return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
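With build_rpn_head and build_single_stage_head collapsed into build_anchor_head, every head class is now resolved from the same anchor_heads module. A minimal, hedged sketch of how a head config could be built through this builder; the config keys and values below are made up for illustration and are not taken from any config file in this commit:

# Illustration only; the keys below are hypothetical example values.
from mmdet.models import builder

rpn_head_cfg = dict(
    type='RPNHead',               # class name looked up in mmdet.models.anchor_heads
    in_channels=256,
    feat_channels=256,
    anchor_scales=[8],
    anchor_ratios=[0.5, 1.0, 2.0],
    anchor_strides=[4, 8, 16, 32, 64])

rpn_head = builder.build_anchor_head(rpn_head_cfg)  # returns an RPNHead instance

The same call would build a RetinaHead or SSDHead if the config's type named one of those classes instead.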
mmdet/models/detectors/cascade_rcnn.py  (view file @ 70700512)

@@ -37,7 +37,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             raise NotImplementedError

         if rpn_head is not None:
-            self.rpn_head = builder.build_rpn_head(rpn_head)
+            self.rpn_head = builder.build_anchor_head(rpn_head)

         if bbox_head is not None:
             self.bbox_roi_extractor = nn.ModuleList()

@@ -123,7 +123,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
             losses.update(rpn_losses)

             proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
-            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
         else:
             proposal_list = proposals
mmdet/models/detectors/rpn.py  (view file @ 70700512)

@@ -18,7 +18,7 @@ class RPN(BaseDetector, RPNTestMixin):
         super(RPN, self).__init__()
         self.backbone = builder.build_backbone(backbone)
         self.neck = builder.build_neck(neck) if neck is not None else None
-        self.rpn_head = builder.build_rpn_head(rpn_head)
+        self.rpn_head = builder.build_anchor_head(rpn_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
         self.init_weights(pretrained=pretrained)
mmdet/models/detectors/single_stage.py  (view file @ 70700512)

@@ -18,7 +18,7 @@ class SingleStageDetector(BaseDetector):
         self.backbone = builder.build_backbone(backbone)
         if neck is not None:
             self.neck = builder.build_neck(neck)
-        self.bbox_head = builder.build_single_stage_head(bbox_head)
+        self.bbox_head = builder.build_anchor_head(bbox_head)
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
         self.init_weights(pretrained=pretrained)

@@ -51,7 +51,7 @@ class SingleStageDetector(BaseDetector):
         x = self.extract_feat(img)
         outs = self.bbox_head(x)
         bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
-        bbox_list = self.bbox_head.get_det_bboxes(*bbox_inputs)
+        bbox_list = self.bbox_head.get_bboxes(*bbox_inputs)
         bbox_results = [
             bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
             for det_bboxes, det_labels in bbox_list
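Because every head now exposes get_bboxes() instead of head-specific get_proposals() / get_det_bboxes() methods, the single-stage detector above and the RPN-based detectors below call the head through the same pattern. A hedged sketch of that shared call flow, with the argument names standing in for the usual mmdet inputs:

# Sketch of the shared inference path after this commit (not code from the diff).
# `head` may be an RPNHead, RetinaHead, or any other AnchorHead subclass;
# `feats`, `img_meta`, and `test_cfg` are assumed to be the usual mmdet inputs.
def head_inference(head, feats, img_meta, test_cfg, rescale=False):
    outs = head(feats)                              # (cls_scores, bbox_preds)
    bbox_inputs = outs + (img_meta, test_cfg, rescale)
    # RPNHead returns a proposal list; other heads return (det_bboxes, det_labels) per image.
    return head.get_bboxes(*bbox_inputs)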
mmdet/models/detectors/test_mixins.py  (view file @ 70700512)

@@ -7,7 +7,7 @@ class RPNTestMixin(object):
     def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
         rpn_outs = self.rpn_head(x)
         proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
-        proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+        proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
         return proposal_list

     def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
mmdet/models/detectors/two_stage.py  (view file @ 70700512)

@@ -30,7 +30,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             raise NotImplementedError

         if rpn_head is not None:
-            self.rpn_head = builder.build_rpn_head(rpn_head)
+            self.rpn_head = builder.build_anchor_head(rpn_head)

         if bbox_head is not None:
             self.bbox_roi_extractor = builder.build_roi_extractor(

@@ -96,7 +96,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
             losses.update(rpn_losses)

             proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
-            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
         else:
             proposal_list = proposals
mmdet/models/rpn_heads/__init__.py  (deleted, 100644 → 0, view file @ 1b9f9b88)

from .rpn_head import RPNHead

__all__ = ['RPNHead']
mmdet/models/rpn_heads/rpn_head.py  (deleted, 100644 → 0, view file @ 1b9f9b88)

from __future__ import division

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
                        multi_apply, weighted_cross_entropy, weighted_smoothl1,
                        weighted_binary_cross_entropy)
from mmdet.ops import nms
from ..utils import normal_init


class RPNHead(nn.Module):
    """Network head of RPN.

                                  / - rpn_cls (1x1 conv)
    input - rpn_conv (3x3 conv) -
                                  \ - rpn_reg (1x1 conv)

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels for the RPN feature map.
        anchor_scales (Iterable): Anchor scales.
        anchor_ratios (Iterable): Anchor aspect ratios.
        anchor_strides (Iterable): Anchor strides.
        anchor_base_sizes (Iterable): Anchor base sizes.
        target_means (Iterable): Mean values of regression targets.
        target_stds (Iterable): Std values of regression targets.
        use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
            (softmax by default)
    """  # noqa: W605

    def __init__(self,
                 in_channels,
                 feat_channels=256,
                 anchor_scales=[8, 16, 32],
                 anchor_ratios=[0.5, 1.0, 2.0],
                 anchor_strides=[4, 8, 16, 32, 64],
                 anchor_base_sizes=None,
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0),
                 use_sigmoid_cls=False):
        super(RPNHead, self).__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.anchor_scales = anchor_scales
        self.anchor_ratios = anchor_ratios
        self.anchor_strides = anchor_strides
        self.anchor_base_sizes = list(
            anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = use_sigmoid_cls

        self.anchor_generators = []
        for anchor_base in self.anchor_base_sizes:
            self.anchor_generators.append(
                AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
        self.rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
        out_channels = (self.num_anchors
                        if self.use_sigmoid_cls else self.num_anchors * 2)
        self.rpn_cls = nn.Conv2d(feat_channels, out_channels, 1)
        self.rpn_reg = nn.Conv2d(feat_channels, self.num_anchors * 4, 1)
        self.debug_imgs = None

    def init_weights(self):
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):
        rpn_feat = self.relu(self.rpn_conv(x))
        rpn_cls_score = self.rpn_cls(rpn_feat)
        rpn_bbox_pred = self.rpn_reg(rpn_feat)
        return rpn_cls_score, rpn_bbox_pred

    def forward(self, feats):
        return multi_apply(self.forward_single, feats)

    def get_anchors(self, featmap_sizes, img_metas):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.

        Returns:
            tuple: anchors of each image, valid flags of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = []
        for i in range(num_levels):
            anchors = self.anchor_generators[i].grid_anchors(
                featmap_sizes[i], self.anchor_strides[i])
            multi_level_anchors.append(anchors)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level anchors
        valid_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = []
            for i in range(num_levels):
                anchor_stride = self.anchor_strides[i]
                feat_h, feat_w = featmap_sizes[i]
                h, w, _ = img_meta['pad_shape']
                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                flags = self.anchor_generators[i].valid_flags(
                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                multi_level_flags.append(flags)
            valid_flag_list.append(multi_level_flags)

        return anchor_list, valid_flag_list

    def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        # classification loss
        labels = labels.contiguous().view(-1)
        label_weights = label_weights.contiguous().view(-1)
        if self.use_sigmoid_cls:
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                  1).contiguous().view(-1)
            criterion = weighted_binary_cross_entropy
        else:
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                  1).contiguous().view(-1, 2)
            criterion = weighted_cross_entropy
        loss_cls = criterion(
            rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.contiguous().view(-1, 4)
        bbox_weights = bbox_weights.contiguous().view(-1, 4)
        rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
            -1, 4)
        loss_reg = weighted_smoothl1(
            rpn_bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls, loss_reg

    def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
        featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_shapes)
        cls_reg_targets = anchor_target(
            anchor_list, valid_flag_list, gt_bboxes, img_shapes,
            self.target_means, self.target_stds, cfg)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        losses_cls, losses_reg = multi_apply(
            self.loss_single,
            rpn_cls_scores,
            rpn_bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_pos + num_total_neg,
            cfg=cfg)
        return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)

    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
        num_imgs = len(img_meta)
        featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
        mlvl_anchors = [
            self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
                                                     self.anchor_strides[idx])
            for idx in range(len(featmap_sizes))
        ]
        proposal_list = []
        for img_id in range(num_imgs):
            rpn_cls_score_list = [
                rpn_cls_scores[idx][img_id].detach()
                for idx in range(len(rpn_cls_scores))
            ]
            rpn_bbox_pred_list = [
                rpn_bbox_preds[idx][img_id].detach()
                for idx in range(len(rpn_bbox_preds))
            ]
            assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
            proposals = self._get_proposals_single(
                rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
                img_meta[img_id]['img_shape'], cfg)
            proposal_list.append(proposals)
        return proposal_list

    def _get_proposals_single(self, rpn_cls_scores, rpn_bbox_preds,
                              mlvl_anchors, img_shape, cfg):
        mlvl_proposals = []
        for idx in range(len(rpn_cls_scores)):
            rpn_cls_score = rpn_cls_scores[idx]
            rpn_bbox_pred = rpn_bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            anchors = mlvl_anchors[idx]
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.permute(
                    1, 2, 0).contiguous().view(-1)
                rpn_cls_prob = rpn_cls_score.sigmoid()
                scores = rpn_cls_prob
            else:
                rpn_cls_score = rpn_cls_score.permute(
                    1, 2, 0).contiguous().view(-1, 2)
                rpn_cls_prob = F.softmax(rpn_cls_score, dim=1)
                scores = rpn_cls_prob[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).contiguous().view(
                -1, 4)
            _, order = scores.sort(0, descending=True)
            if cfg.nms_pre > 0:
                order = order[:cfg.nms_pre]
                rpn_bbox_pred = rpn_bbox_pred[order, :]
                anchors = anchors[order, :]
                scores = scores[order]
            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                   self.target_stds, img_shape)
            w = proposals[:, 2] - proposals[:, 0] + 1
            h = proposals[:, 3] - proposals[:, 1] + 1
            valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                       (h >= cfg.min_bbox_size)).squeeze()
            proposals = proposals[valid_inds, :]
            scores = scores[valid_inds]
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            scores = proposals[:, 4]
            _, order = scores.sort(0, descending=True)
            num = min(cfg.max_num, proposals.shape[0])
            order = order[:num]
            proposals = proposals[order, :]
        return proposals