Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
441015ea
"vscode:/vscode.git/clone" did not exist on "949b6c01e074c6f7712d7da37079218b3192b102"
Commit
441015ea
authored
Feb 06, 2019
by
Kai Chen
Browse files
Merge branch 'master' into pytorch-1.0
parents
2017c81e
3b6ae96d
Changes
76
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
175 additions
and
323 deletions
+175
-323
mmdet/models/bbox_heads/convfc_bbox_head.py
mmdet/models/bbox_heads/convfc_bbox_head.py
+3
-0
mmdet/models/builder.py
mmdet/models/builder.py
+32
-38
mmdet/models/detectors/cascade_rcnn.py
mmdet/models/detectors/cascade_rcnn.py
+6
-4
mmdet/models/detectors/fast_rcnn.py
mmdet/models/detectors/fast_rcnn.py
+2
-0
mmdet/models/detectors/faster_rcnn.py
mmdet/models/detectors/faster_rcnn.py
+10
-8
mmdet/models/detectors/mask_rcnn.py
mmdet/models/detectors/mask_rcnn.py
+2
-0
mmdet/models/detectors/retinanet.py
mmdet/models/detectors/retinanet.py
+2
-0
mmdet/models/detectors/rpn.py
mmdet/models/detectors/rpn.py
+3
-1
mmdet/models/detectors/single_stage.py
mmdet/models/detectors/single_stage.py
+4
-2
mmdet/models/detectors/test_mixins.py
mmdet/models/detectors/test_mixins.py
+1
-1
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+6
-4
mmdet/models/mask_heads/fcn_mask_head.py
mmdet/models/mask_heads/fcn_mask_head.py
+2
-0
mmdet/models/necks/fpn.py
mmdet/models/necks/fpn.py
+4
-1
mmdet/models/registry.py
mmdet/models/registry.py
+43
-0
mmdet/models/roi_extractors/single_level.py
mmdet/models/roi_extractors/single_level.py
+2
-0
mmdet/models/rpn_heads/__init__.py
mmdet/models/rpn_heads/__init__.py
+0
-3
mmdet/models/rpn_heads/rpn_head.py
mmdet/models/rpn_heads/rpn_head.py
+0
-250
mmdet/models/single_stage_heads/__init__.py
mmdet/models/single_stage_heads/__init__.py
+0
-3
mmdet/models/utils/conv_module.py
mmdet/models/utils/conv_module.py
+6
-1
mmdet/models/utils/norm.py
mmdet/models/utils/norm.py
+47
-7
No files found.
mmdet/models/bbox_heads/convfc_bbox_head.py
View file @
441015ea
import
torch.nn
as
nn
from
.bbox_head
import
BBoxHead
from
..registry
import
HEADS
from
..utils
import
ConvModule
@
HEADS
.
register_module
class
ConvFCBBoxHead
(
BBoxHead
):
"""More general bbox head, with shared conv and fc layers and two optional
separated branches.
...
...
@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead):
return
cls_score
,
bbox_pred
@
HEADS
.
register_module
class
SharedFCBBoxHead
(
ConvFCBBoxHead
):
def
__init__
(
self
,
num_fcs
=
2
,
fc_out_channels
=
1024
,
*
args
,
**
kwargs
):
...
...
mmdet/models/builder.py
View file @
441015ea
from
mmcv.runner
import
obj_from_dict
import
mmcv
from
torch
import
nn
from
.
import
(
backbones
,
necks
,
roi_extractors
,
rpn_heads
,
bbox_heads
,
mask_heads
,
single_stage_heads
)
__all__
=
[
'build_backbone'
,
'build_neck'
,
'build_rpn_head'
,
'build_roi_extractor'
,
'build_bbox_head'
,
'build_mask_head'
,
'build_single_stage_head'
,
'build_detector'
]
def
_build_module
(
cfg
,
parrent
=
None
,
default_args
=
None
):
return
cfg
if
isinstance
(
cfg
,
nn
.
Module
)
else
obj_from_dict
(
cfg
,
parrent
,
default_args
)
def
build
(
cfg
,
parrent
=
None
,
default_args
=
None
):
from
.registry
import
BACKBONES
,
NECKS
,
ROI_EXTRACTORS
,
HEADS
,
DETECTORS
def _build_module(cfg, registry, default_args):
    """Instantiate one module described by a config dict.

    Args:
        cfg (dict): Constructor keyword arguments plus a mandatory ``'type'``
            entry, which is either a name looked up in ``registry`` or a
            class used as-is. ``cfg`` itself is not modified.
        registry: Object exposing ``module_dict`` (name -> class) and
            ``name`` (used in the error message).
        default_args (dict | None): Extra keyword arguments applied only for
            keys that ``cfg`` does not already provide.

    Returns:
        The object produced by calling the resolved class with the merged
        keyword arguments.

    Raises:
        KeyError: If ``'type'`` is a string missing from the registry.
        TypeError: If ``'type'`` is neither a string nor a class.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None

    # Work on a shallow copy so popping 'type' leaves the caller's cfg intact.
    kwargs = cfg.copy()
    target = kwargs.pop('type')

    if mmcv.is_str(target):
        if target not in registry.module_dict:
            raise KeyError('{} is not in the {} registry'.format(
                target, registry.name))
        target = registry.module_dict[target]
    elif not isinstance(target, type):
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(target)))

    if default_args is not None:
        # Explicit cfg values win; defaults only fill the gaps.
        for key, fallback in default_args.items():
            kwargs.setdefault(key, fallback)

    return target(**kwargs)
def
build
(
cfg
,
registry
,
default_args
=
None
):
if
isinstance
(
cfg
,
list
):
modules
=
[
_build_module
(
cfg_
,
parrent
,
default_args
)
for
cfg_
in
cfg
]
modules
=
[
_build_module
(
cfg_
,
registry
,
default_args
)
for
cfg_
in
cfg
]
return
nn
.
Sequential
(
*
modules
)
else
:
return
_build_module
(
cfg
,
parrent
,
default_args
)
return
_build_module
(
cfg
,
registry
,
default_args
)
def
build_backbone
(
cfg
):
return
build
(
cfg
,
backbones
)
return
build
(
cfg
,
BACKBONES
)
def
build_neck
(
cfg
):
return
build
(
cfg
,
necks
)
def
build_rpn_head
(
cfg
):
return
build
(
cfg
,
rpn_heads
)
return
build
(
cfg
,
NECKS
)
def
build_roi_extractor
(
cfg
):
return
build
(
cfg
,
roi_extractors
)
def
build_bbox_head
(
cfg
):
return
build
(
cfg
,
bbox_heads
)
def
build_mask_head
(
cfg
):
return
build
(
cfg
,
mask_heads
)
return
build
(
cfg
,
ROI_EXTRACTORS
)
def
build_
single_stage_
head
(
cfg
):
return
build
(
cfg
,
single_stage_heads
)
def
build_head
(
cfg
):
return
build
(
cfg
,
HEADS
)
def
build_detector
(
cfg
,
train_cfg
=
None
,
test_cfg
=
None
):
from
.
import
detectors
return
build
(
cfg
,
detectors
,
dict
(
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
))
return
build
(
cfg
,
DETECTORS
,
dict
(
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
))
mmdet/models/detectors/cascade_rcnn.py
View file @
441015ea
...
...
@@ -6,10 +6,12 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
from
..
import
builder
from
..registry
import
DETECTORS
from
mmdet.core
import
(
assign_and_sample
,
bbox2roi
,
bbox2result
,
multi_apply
,
merge_aug_masks
)
@
DETECTORS
.
register_module
class
CascadeRCNN
(
BaseDetector
,
RPNTestMixin
):
def
__init__
(
self
,
...
...
@@ -37,7 +39,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
raise
NotImplementedError
if
rpn_head
is
not
None
:
self
.
rpn_head
=
builder
.
build_
rpn_
head
(
rpn_head
)
self
.
rpn_head
=
builder
.
build_head
(
rpn_head
)
if
bbox_head
is
not
None
:
self
.
bbox_roi_extractor
=
nn
.
ModuleList
()
...
...
@@ -52,7 +54,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
for
roi_extractor
,
head
in
zip
(
bbox_roi_extractor
,
bbox_head
):
self
.
bbox_roi_extractor
.
append
(
builder
.
build_roi_extractor
(
roi_extractor
))
self
.
bbox_head
.
append
(
builder
.
build_
bbox_
head
(
head
))
self
.
bbox_head
.
append
(
builder
.
build_head
(
head
))
if
mask_head
is
not
None
:
self
.
mask_roi_extractor
=
nn
.
ModuleList
()
...
...
@@ -67,7 +69,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
for
roi_extractor
,
head
in
zip
(
mask_roi_extractor
,
mask_head
):
self
.
mask_roi_extractor
.
append
(
builder
.
build_roi_extractor
(
roi_extractor
))
self
.
mask_head
.
append
(
builder
.
build_
mask_
head
(
head
))
self
.
mask_head
.
append
(
builder
.
build_head
(
head
))
self
.
train_cfg
=
train_cfg
self
.
test_cfg
=
test_cfg
...
...
@@ -123,7 +125,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
losses
.
update
(
rpn_losses
)
proposal_inputs
=
rpn_outs
+
(
img_meta
,
self
.
test_cfg
.
rpn
)
proposal_list
=
self
.
rpn_head
.
get_
proposal
s
(
*
proposal_inputs
)
proposal_list
=
self
.
rpn_head
.
get_
bboxe
s
(
*
proposal_inputs
)
else
:
proposal_list
=
proposals
...
...
mmdet/models/detectors/fast_rcnn.py
View file @
441015ea
from
.two_stage
import
TwoStageDetector
from
..registry
import
DETECTORS
@
DETECTORS
.
register_module
class
FastRCNN
(
TwoStageDetector
):
def
__init__
(
self
,
...
...
mmdet/models/detectors/faster_rcnn.py
View file @
441015ea
from
.two_stage
import
TwoStageDetector
from
..registry
import
DETECTORS
@
DETECTORS
.
register_module
class
FasterRCNN
(
TwoStageDetector
):
def
__init__
(
self
,
...
...
@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector):
test_cfg
,
pretrained
=
None
):
super
(
FasterRCNN
,
self
).
__init__
(
backbone
=
backbone
,
neck
=
neck
,
rpn_head
=
rpn_head
,
bbox_roi_extractor
=
bbox_roi_extractor
,
bbox_head
=
bbox_head
,
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
,
pretrained
=
pretrained
)
backbone
=
backbone
,
neck
=
neck
,
rpn_head
=
rpn_head
,
bbox_roi_extractor
=
bbox_roi_extractor
,
bbox_head
=
bbox_head
,
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
,
pretrained
=
pretrained
)
mmdet/models/detectors/mask_rcnn.py
View file @
441015ea
from
.two_stage
import
TwoStageDetector
from
..registry
import
DETECTORS
@
DETECTORS
.
register_module
class
MaskRCNN
(
TwoStageDetector
):
def
__init__
(
self
,
...
...
mmdet/models/detectors/retinanet.py
View file @
441015ea
from
.single_stage
import
SingleStageDetector
from
..registry
import
DETECTORS
@
DETECTORS
.
register_module
class
RetinaNet
(
SingleStageDetector
):
def
__init__
(
self
,
...
...
mmdet/models/detectors/rpn.py
View file @
441015ea
...
...
@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
from
..
import
builder
from
..registry
import
DETECTORS
@
DETECTORS
.
register_module
class
RPN
(
BaseDetector
,
RPNTestMixin
):
def
__init__
(
self
,
...
...
@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin):
super
(
RPN
,
self
).
__init__
()
self
.
backbone
=
builder
.
build_backbone
(
backbone
)
self
.
neck
=
builder
.
build_neck
(
neck
)
if
neck
is
not
None
else
None
self
.
rpn_head
=
builder
.
build_
rpn_
head
(
rpn_head
)
self
.
rpn_head
=
builder
.
build_head
(
rpn_head
)
self
.
train_cfg
=
train_cfg
self
.
test_cfg
=
test_cfg
self
.
init_weights
(
pretrained
=
pretrained
)
...
...
mmdet/models/detectors/single_stage.py
View file @
441015ea
...
...
@@ -2,9 +2,11 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
..
import
builder
from
..registry
import
DETECTORS
from
mmdet.core
import
bbox2result
@
DETECTORS
.
register_module
class
SingleStageDetector
(
BaseDetector
):
def
__init__
(
self
,
...
...
@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector):
self
.
backbone
=
builder
.
build_backbone
(
backbone
)
if
neck
is
not
None
:
self
.
neck
=
builder
.
build_neck
(
neck
)
self
.
bbox_head
=
builder
.
build_
single_stage_
head
(
bbox_head
)
self
.
bbox_head
=
builder
.
build_head
(
bbox_head
)
self
.
train_cfg
=
train_cfg
self
.
test_cfg
=
test_cfg
self
.
init_weights
(
pretrained
=
pretrained
)
...
...
@@ -51,7 +53,7 @@ class SingleStageDetector(BaseDetector):
x
=
self
.
extract_feat
(
img
)
outs
=
self
.
bbox_head
(
x
)
bbox_inputs
=
outs
+
(
img_meta
,
self
.
test_cfg
,
rescale
)
bbox_list
=
self
.
bbox_head
.
get_
det_
bboxes
(
*
bbox_inputs
)
bbox_list
=
self
.
bbox_head
.
get_bboxes
(
*
bbox_inputs
)
bbox_results
=
[
bbox2result
(
det_bboxes
,
det_labels
,
self
.
bbox_head
.
num_classes
)
for
det_bboxes
,
det_labels
in
bbox_list
...
...
mmdet/models/detectors/test_mixins.py
View file @
441015ea
...
...
@@ -7,7 +7,7 @@ class RPNTestMixin(object):
def
simple_test_rpn
(
self
,
x
,
img_meta
,
rpn_test_cfg
):
rpn_outs
=
self
.
rpn_head
(
x
)
proposal_inputs
=
rpn_outs
+
(
img_meta
,
rpn_test_cfg
)
proposal_list
=
self
.
rpn_head
.
get_
proposal
s
(
*
proposal_inputs
)
proposal_list
=
self
.
rpn_head
.
get_
bboxe
s
(
*
proposal_inputs
)
return
proposal_list
def
aug_test_rpn
(
self
,
feats
,
img_metas
,
rpn_test_cfg
):
...
...
mmdet/models/detectors/two_stage.py
View file @
441015ea
...
...
@@ -4,9 +4,11 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
..
import
builder
from
..registry
import
DETECTORS
from
mmdet.core
import
bbox2roi
,
bbox2result
,
build_assigner
,
build_sampler
@
DETECTORS
.
register_module
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
):
...
...
@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
raise
NotImplementedError
if
rpn_head
is
not
None
:
self
.
rpn_head
=
builder
.
build_
rpn_
head
(
rpn_head
)
self
.
rpn_head
=
builder
.
build_head
(
rpn_head
)
if
bbox_head
is
not
None
:
self
.
bbox_roi_extractor
=
builder
.
build_roi_extractor
(
bbox_roi_extractor
)
self
.
bbox_head
=
builder
.
build_
bbox_
head
(
bbox_head
)
self
.
bbox_head
=
builder
.
build_head
(
bbox_head
)
if
mask_head
is
not
None
:
self
.
mask_roi_extractor
=
builder
.
build_roi_extractor
(
mask_roi_extractor
)
self
.
mask_head
=
builder
.
build_
mask_
head
(
mask_head
)
self
.
mask_head
=
builder
.
build_head
(
mask_head
)
self
.
train_cfg
=
train_cfg
self
.
test_cfg
=
test_cfg
...
...
@@ -96,7 +98,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
losses
.
update
(
rpn_losses
)
proposal_inputs
=
rpn_outs
+
(
img_meta
,
self
.
test_cfg
.
rpn
)
proposal_list
=
self
.
rpn_head
.
get_
proposal
s
(
*
proposal_inputs
)
proposal_list
=
self
.
rpn_head
.
get_
bboxe
s
(
*
proposal_inputs
)
else
:
proposal_list
=
proposals
...
...
mmdet/models/mask_heads/fcn_mask_head.py
View file @
441015ea
...
...
@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util
import
torch
import
torch.nn
as
nn
from
..registry
import
HEADS
from
..utils
import
ConvModule
from
mmdet.core
import
mask_cross_entropy
,
mask_target
@
HEADS
.
register_module
class
FCNMaskHead
(
nn
.
Module
):
def
__init__
(
self
,
...
...
mmdet/models/necks/fpn.py
View file @
441015ea
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
xavier_init
from
..utils
import
ConvModule
from
..
utils
import
xavier_init
from
..
registry
import
NECKS
@
NECKS
.
register_module
class
FPN
(
nn
.
Module
):
def
__init__
(
self
,
...
...
mmdet/models/registry.py
0 → 100644
View file @
441015ea
import
torch.nn
as
nn
class Registry(object):
    """A named mapping from class names to ``nn.Module`` subclasses.

    Classes are added with :meth:`register_module`, which returns the class
    unchanged and can therefore be used as a class decorator. Each class is
    stored under its ``__name__``.

    Args:
        name (str): Human-readable registry name, used in error messages.
    """

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        return '{}(name={!r}, items={})'.format(
            self.__class__.__name__, self._name, sorted(self._module_dict))

    @property
    def name(self):
        """str: Name given to the registry at construction."""
        return self._name

    @property
    def module_dict(self):
        """dict: The live mapping of registered class name -> class."""
        return self._module_dict

    def _register_module(self, module_class):
        """Register a module class under its ``__name__``.

        Args:
            module_class (type): Subclass of ``nn.Module`` to be registered.

        Raises:
            TypeError: If ``module_class`` is not a class, or is a class
                that does not derive from ``nn.Module``.
            KeyError: If a class with the same name is already registered.
        """
        # Check "is a class" first: issubclass() on a non-class raises its
        # own generic TypeError that does not mention the registry contract.
        if not isinstance(module_class, type):
            raise TypeError(
                'module must be a class, but got {}'.format(
                    type(module_class)))
        if not issubclass(module_class, nn.Module):
            # Report the class itself; type(module_class) would always be
            # <class 'type'> here and say nothing about the offender.
            raise TypeError(
                'module must be a child of nn.Module, but got {}'.format(
                    module_class))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        """Register ``cls`` and return it, so this works as a decorator."""
        self._register_module(cls)
        return cls
BACKBONES
=
Registry
(
'backbone'
)
NECKS
=
Registry
(
'neck'
)
ROI_EXTRACTORS
=
Registry
(
'roi_extractor'
)
HEADS
=
Registry
(
'head'
)
DETECTORS
=
Registry
(
'detector'
)
mmdet/models/roi_extractors/single_level.py
View file @
441015ea
...
...
@@ -4,8 +4,10 @@ import torch
import
torch.nn
as
nn
from
mmdet
import
ops
from
..registry
import
ROI_EXTRACTORS
@
ROI_EXTRACTORS
.
register_module
class
SingleRoIExtractor
(
nn
.
Module
):
"""Extract RoI features from a single level feature map.
...
...
mmdet/models/rpn_heads/__init__.py
deleted
100644 → 0
View file @
2017c81e
from
.rpn_head
import
RPNHead
__all__
=
[
'RPNHead'
]
mmdet/models/rpn_heads/rpn_head.py
deleted
100644 → 0
View file @
2017c81e
from
__future__
import
division
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmdet.core
import
(
AnchorGenerator
,
anchor_target
,
delta2bbox
,
multi_apply
,
weighted_cross_entropy
,
weighted_smoothl1
,
weighted_binary_cross_entropy
)
from
mmdet.ops
import
nms
from
..utils
import
normal_init
class RPNHead(nn.Module):
    r"""Network head of RPN.

                                  / - rpn_cls (1x1 conv)
    input - rpn_conv (3x3 conv) -
                                  \ - rpn_reg (1x1 conv)

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels for the RPN feature map.
        anchor_scales (Iterable): Anchor scales.
        anchor_ratios (Iterable): Anchor aspect ratios.
        anchor_strides (Iterable): Anchor strides.
        anchor_base_sizes (Iterable): Anchor base sizes. Defaults to a copy
            of ``anchor_strides`` when None.
        target_means (Iterable): Mean values of regression targets.
        target_stds (Iterable): Std values of regression targets.
        use_sigmoid_cls (bool): Whether to use sigmoid loss for
            classification. (softmax by default)
    """

    def __init__(self,
                 in_channels,
                 feat_channels=256,
                 anchor_scales=[8, 16, 32],
                 anchor_ratios=[0.5, 1.0, 2.0],
                 anchor_strides=[4, 8, 16, 32, 64],
                 anchor_base_sizes=None,
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0),
                 use_sigmoid_cls=False):
        super(RPNHead, self).__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        # Store copies: the list defaults in the signature are shared across
        # calls, so keeping a reference would let one instance's in-place
        # edit leak into every later instance.
        self.anchor_scales = list(anchor_scales)
        self.anchor_ratios = list(anchor_ratios)
        self.anchor_strides = list(anchor_strides)
        self.anchor_base_sizes = list(
            anchor_strides if anchor_base_sizes is None else anchor_base_sizes)
        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = use_sigmoid_cls

        # One anchor generator per base size (i.e. per feature level).
        self.anchor_generators = [
            AnchorGenerator(anchor_base, anchor_scales, anchor_ratios)
            for anchor_base in self.anchor_base_sizes
        ]
        self.rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
        # Sigmoid: one score per anchor; softmax: two scores per anchor.
        out_channels = (self.num_anchors
                        if self.use_sigmoid_cls else self.num_anchors * 2)
        self.rpn_cls = nn.Conv2d(feat_channels, out_channels, 1)
        self.rpn_reg = nn.Conv2d(feat_channels, self.num_anchors * 4, 1)
        self.debug_imgs = None

    def init_weights(self):
        """Initialise the three conv layers with ``normal_init(std=0.01)``."""
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):
        """Run the head on one feature map.

        Args:
            x (Tensor): Input accepted by ``rpn_conv``.

        Returns:
            tuple: ``(rpn_cls_score, rpn_bbox_pred)``, the outputs of
            ``rpn_cls`` and ``rpn_reg`` on the shared ReLU-activated feature.
        """
        rpn_feat = self.relu(self.rpn_conv(x))
        rpn_cls_score = self.rpn_cls(rpn_feat)
        rpn_bbox_pred = self.rpn_reg(rpn_feat)
        return rpn_cls_score, rpn_bbox_pred

    def forward(self, feats):
        """Apply :meth:`forward_single` to every element of ``feats`` via
        ``multi_apply`` and return its result."""
        return multi_apply(self.forward_single, feats)

    def get_anchors(self, featmap_sizes, img_metas):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info; ``'pad_shape'`` is read
                from each entry.

        Returns:
            tuple: anchors of each image, valid flags of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = [
            self.anchor_generators[i].grid_anchors(featmap_sizes[i],
                                                   self.anchor_strides[i])
            for i in range(num_levels)
        ]
        # NOTE: every image shares the same inner list object, as before.
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level anchors
        valid_flag_list = []
        for img_meta in img_metas:
            h, w, _ = img_meta['pad_shape']
            multi_level_flags = []
            for i in range(num_levels):
                anchor_stride = self.anchor_strides[i]
                feat_h, feat_w = featmap_sizes[i]
                # Clamp to the feature map so padding never marks
                # out-of-range cells as valid.
                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                flags = self.anchor_generators[i].valid_flags(
                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                multi_level_flags.append(flags)
            valid_flag_list.append(multi_level_flags)

        return anchor_list, valid_flag_list

    def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        """Compute classification and regression loss for one level.

        ``rpn_cls_score`` and ``rpn_bbox_pred`` are permuted with
        ``(0, 2, 3, 1)`` (so they must be 4-D) and flattened before being
        passed to the loss functions; both losses use ``num_total_samples``
        as ``avg_factor``. ``cfg.smoothl1_beta`` is read for the regression
        loss.

        Returns:
            tuple: ``(loss_cls, loss_reg)``.
        """
        # classification loss
        labels = labels.contiguous().view(-1)
        label_weights = label_weights.contiguous().view(-1)
        if self.use_sigmoid_cls:
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                  1).contiguous().view(-1)
            criterion = weighted_binary_cross_entropy
        else:
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                  1).contiguous().view(-1, 2)
            criterion = weighted_cross_entropy
        loss_cls = criterion(
            rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.contiguous().view(-1, 4)
        bbox_weights = bbox_weights.contiguous().view(-1, 4)
        rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
            -1, 4)
        loss_reg = weighted_smoothl1(
            rpn_bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls, loss_reg

    def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
        """Compute RPN losses over all levels.

        Args:
            rpn_cls_scores (list[Tensor]): Per-level classification outputs.
            rpn_bbox_preds (list[Tensor]): Per-level regression outputs.
            gt_bboxes: Ground-truth boxes, forwarded to ``anchor_target``.
            img_shapes (list[dict]): Despite the name, this is forwarded to
                :meth:`get_anchors` as ``img_metas`` (which reads
                ``'pad_shape'``) and to ``anchor_target``.
            cfg: Config forwarded to ``anchor_target`` and
                :meth:`loss_single`.

        Returns:
            dict | None: ``loss_rpn_cls`` and ``loss_rpn_reg``, or None when
            ``anchor_target`` returns None.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_shapes)
        cls_reg_targets = anchor_target(
            anchor_list, valid_flag_list, gt_bboxes, img_shapes,
            self.target_means, self.target_stds, cfg)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        losses_cls, losses_reg = multi_apply(
            self.loss_single,
            rpn_cls_scores,
            rpn_bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_pos + num_total_neg,
            cfg=cfg)
        return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)

    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
        """Turn per-level RPN outputs into one proposal tensor per image.

        Args:
            rpn_cls_scores (list[Tensor]): Per-level classification outputs,
                indexed by image along the first dimension.
            rpn_bbox_preds (list[Tensor]): Per-level regression outputs,
                indexed the same way.
            img_meta (list[dict]): One entry per image; ``'img_shape'`` is
                read from each.
            cfg: Config forwarded to :meth:`_get_proposals_single`.

        Returns:
            list: One result of :meth:`_get_proposals_single` per image.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
        # Anchors depend only on the feature-map size, so build them once.
        mlvl_anchors = [
            self.anchor_generators[idx].grid_anchors(size,
                                                     self.anchor_strides[idx])
            for idx, size in enumerate(featmap_sizes)
        ]
        proposal_list = []
        for img_id, meta in enumerate(img_meta):
            # detach(): proposals are not back-propagated through.
            rpn_cls_score_list = [
                level_scores[img_id].detach()
                for level_scores in rpn_cls_scores
            ]
            rpn_bbox_pred_list = [
                level_preds[img_id].detach() for level_preds in rpn_bbox_preds
            ]
            assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
            proposals = self._get_proposals_single(
                rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
                meta['img_shape'], cfg)
            proposal_list.append(proposals)
        return proposal_list

    def _get_proposals_single(self, rpn_cls_scores, rpn_bbox_preds,
                              mlvl_anchors, img_shape, cfg):
        """Generate proposals for a single image.

        For each level: score the anchors, optionally keep the top
        ``cfg.nms_pre``, decode boxes with ``delta2bbox``, drop boxes whose
        width or height is below ``cfg.min_bbox_size``, run ``nms`` with
        ``cfg.nms_thr`` and keep the first ``cfg.nms_post``. The per-level
        results are concatenated and reduced to at most ``cfg.max_num``
        rows, either by a second NMS (``cfg.nms_across_levels``) or by
        score.

        Returns:
            Tensor: Rows of 4 box coordinates followed by the score.
        """
        mlvl_proposals = []
        for rpn_cls_score, rpn_bbox_pred, anchors in zip(
                rpn_cls_scores, rpn_bbox_preds, mlvl_anchors):
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.permute(1, 2,
                                                      0).contiguous().view(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.permute(
                    1, 2, 0).contiguous().view(-1, 2)
                # Column 1 of the softmax is used as the proposal score.
                scores = F.softmax(rpn_cls_score, dim=1)[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).contiguous().view(
                -1, 4)
            _, order = scores.sort(0, descending=True)
            if cfg.nms_pre > 0:
                order = order[:cfg.nms_pre]
                rpn_bbox_pred = rpn_bbox_pred[order, :]
                anchors = anchors[order, :]
                scores = scores[order]
            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                   self.target_stds, img_shape)
            w = proposals[:, 2] - proposals[:, 0] + 1
            h = proposals[:, 3] - proposals[:, 1] + 1
            # squeeze(-1), not squeeze(): with exactly one surviving box a
            # bare squeeze() yields a 0-dim index, which would collapse
            # `proposals` to 1-D and break the cat/nms below.
            valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                       (h >= cfg.min_bbox_size)).squeeze(-1)
            proposals = proposals[valid_inds, :]
            scores = scores[valid_inds]
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            # Column 4 holds the score appended above.
            scores = proposals[:, 4]
            _, order = scores.sort(0, descending=True)
            num = min(cfg.max_num, proposals.shape[0])
            order = order[:num]
            proposals = proposals[order, :]
        return proposals
mmdet/models/single_stage_heads/__init__.py
deleted
100644 → 0
View file @
2017c81e
from
.retina_head
import
RetinaHead
__all__
=
[
'RetinaHead'
]
mmdet/models/utils/conv_module.py
View file @
441015ea
...
...
@@ -53,7 +53,8 @@ class ConvModule(nn.Module):
if
self
.
with_norm
:
norm_channels
=
out_channels
if
self
.
activate_last
else
in_channels
self
.
norm
=
build_norm_layer
(
normalize
,
norm_channels
)
self
.
norm_name
,
norm
=
build_norm_layer
(
normalize
,
norm_channels
)
self
.
add_module
(
self
.
norm_name
,
norm
)
if
self
.
with_activatation
:
assert
activation
in
[
'relu'
],
'Only ReLU supported.'
...
...
@@ -63,6 +64,10 @@ class ConvModule(nn.Module):
# Default using msra init
self
.
init_weights
()
@
property
def
norm
(
self
):
return
getattr
(
self
,
self
.
norm_name
)
def
init_weights
(
self
):
nonlinearity
=
'relu'
if
self
.
activation
is
None
else
self
.
activation
kaiming_init
(
self
.
conv
,
nonlinearity
=
nonlinearity
)
...
...
mmdet/models/utils/norm.py
View file @
441015ea
import
torch.nn
as
nn
norm_cfg
=
{
'BN'
:
nn
.
BatchNorm2d
,
'SyncBN'
:
None
,
'GN'
:
None
}
norm_cfg
=
{
# format: layer_type: (abbreviation, module)
'BN'
:
(
'bn'
,
nn
.
BatchNorm2d
),
'SyncBN'
:
(
'bn'
,
None
),
'GN'
:
(
'gn'
,
nn
.
GroupNorm
),
# and potentially 'SN'
}
def
build_norm_layer
(
cfg
,
num_features
):
def
build_norm_layer
(
cfg
,
num_features
,
postfix
=
''
):
""" Build normalization layer
Args:
cfg (dict): cfg should contain:
type (str): identify norm layer type.
layer args: args needed to instantiate a norm layer.
frozen (bool): [optional] whether stop gradient updates
of norm layer, it is helpful to set frozen mode
in backbone's norms.
num_features (int): number of channels from input
postfix (int, str): appended to the norm abbreviation to
create named layer.
Returns:
name (str): abbreviation + postfix
layer (nn.Module): created norm layer
"""
assert
isinstance
(
cfg
,
dict
)
and
'type'
in
cfg
cfg_
=
cfg
.
copy
()
cfg_
.
setdefault
(
'eps'
,
1e-5
)
layer_type
=
cfg_
.
pop
(
'type'
)
layer_type
=
cfg_
.
pop
(
'type'
)
if
layer_type
not
in
norm_cfg
:
raise
KeyError
(
'Unrecognized norm type {}'
.
format
(
layer_type
))
elif
norm_cfg
[
layer_type
]
is
None
:
raise
NotImplementedError
else
:
abbr
,
norm_layer
=
norm_cfg
[
layer_type
]
if
norm_layer
is
None
:
raise
NotImplementedError
assert
isinstance
(
postfix
,
(
int
,
str
))
name
=
abbr
+
str
(
postfix
)
frozen
=
cfg_
.
pop
(
'frozen'
,
False
)
cfg_
.
setdefault
(
'eps'
,
1e-5
)
if
layer_type
!=
'GN'
:
layer
=
norm_layer
(
num_features
,
**
cfg_
)
else
:
assert
'num_groups'
in
cfg_
layer
=
norm_layer
(
num_channels
=
num_features
,
**
cfg_
)
if
frozen
:
for
param
in
layer
.
parameters
():
param
.
requires_grad
=
False
return
n
orm_cfg
[
layer_type
](
num_features
,
**
cfg_
)
return
n
ame
,
layer
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment