Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
95d44cc1
Unverified
Commit
95d44cc1
authored
Jan 16, 2019
by
Kai Chen
Committed by
GitHub
Jan 16, 2019
Browse files
Merge pull request #253 from hellock/registry
Use registry to manage modules
parents
e72a9fd5
e2594f17
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
104 additions
and
51 deletions
+104
-51
mmdet/models/__init__.py
mmdet/models/__init__.py
+12
-7
mmdet/models/anchor_heads/anchor_head.py
mmdet/models/anchor_heads/anchor_head.py
+3
-1
mmdet/models/anchor_heads/retina_head.py
mmdet/models/anchor_heads/retina_head.py
+2
-0
mmdet/models/anchor_heads/rpn_head.py
mmdet/models/anchor_heads/rpn_head.py
+2
-0
mmdet/models/anchor_heads/ssd_head.py
mmdet/models/anchor_heads/ssd_head.py
+2
-0
mmdet/models/backbones/resnet.py
mmdet/models/backbones/resnet.py
+3
-0
mmdet/models/backbones/resnext.py
mmdet/models/backbones/resnext.py
+2
-0
mmdet/models/backbones/ssd_vgg.py
mmdet/models/backbones/ssd_vgg.py
+2
-0
mmdet/models/bbox_heads/bbox_head.py
mmdet/models/bbox_heads/bbox_head.py
+10
-2
mmdet/models/bbox_heads/convfc_bbox_head.py
mmdet/models/bbox_heads/convfc_bbox_head.py
+3
-0
mmdet/models/builder.py
mmdet/models/builder.py
+32
-28
mmdet/models/detectors/cascade_rcnn.py
mmdet/models/detectors/cascade_rcnn.py
+2
-0
mmdet/models/detectors/fast_rcnn.py
mmdet/models/detectors/fast_rcnn.py
+2
-0
mmdet/models/detectors/faster_rcnn.py
mmdet/models/detectors/faster_rcnn.py
+10
-8
mmdet/models/detectors/mask_rcnn.py
mmdet/models/detectors/mask_rcnn.py
+2
-0
mmdet/models/detectors/retinanet.py
mmdet/models/detectors/retinanet.py
+2
-0
mmdet/models/detectors/rpn.py
mmdet/models/detectors/rpn.py
+3
-1
mmdet/models/detectors/single_stage.py
mmdet/models/detectors/single_stage.py
+3
-1
mmdet/models/detectors/two_stage.py
mmdet/models/detectors/two_stage.py
+5
-3
mmdet/models/mask_heads/fcn_mask_head.py
mmdet/models/mask_heads/fcn_mask_head.py
+2
-0
No files found.
mmdet/models/__init__.py
View file @
95d44cc1
from
.detectors
import
(
BaseDetector
,
TwoStageDetector
,
RPN
,
FastRCNN
,
from
.backbones
import
*
# noqa: F401,F403
FasterRCNN
,
MaskRCNN
)
from
.necks
import
*
# noqa: F401,F403
from
.builder
import
(
build_neck
,
build_anchor_head
,
build_roi_extractor
,
from
.roi_extractors
import
*
# noqa: F401,F403
build_bbox_head
,
build_mask_head
,
build_detector
)
from
.anchor_heads
import
*
# noqa: F401,F403
from
.bbox_heads
import
*
# noqa: F401,F403
from
.mask_heads
import
*
# noqa: F401,F403
from
.detectors
import
*
# noqa: F401,F403
from
.registry
import
BACKBONES
,
NECKS
,
ROI_EXTRACTORS
,
HEADS
,
DETECTORS
from
.builder
import
(
build_backbone
,
build_neck
,
build_roi_extractor
,
build_head
,
build_detector
)
__all__
=
[
__all__
=
[
'BaseDetector'
,
'TwoStageDetector'
,
'RPN'
,
'FastRCNN'
,
'FasterRCNN'
,
'BACKBONES'
,
'NECKS'
,
'ROI_EXTRACTORS'
,
'HEADS'
,
'DETECTORS'
,
'MaskRCNN'
,
'build_backbone'
,
'build_neck'
,
'build_anchor_head'
,
'build_backbone'
,
'build_neck'
,
'build_roi_extractor'
,
'build_head'
,
'build_roi_extractor'
,
'build_bbox_head'
,
'build_mask_head'
,
'build_detector'
'build_detector'
]
]
mmdet/models/anchor_heads/anchor_head.py
View file @
95d44cc1
...
@@ -3,14 +3,16 @@ from __future__ import division
...
@@ -3,14 +3,16 @@ from __future__ import division
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
mmcv.cnn
import
normal_init
from
mmdet.core
import
(
AnchorGenerator
,
anchor_target
,
delta2bbox
,
from
mmdet.core
import
(
AnchorGenerator
,
anchor_target
,
delta2bbox
,
multi_apply
,
weighted_cross_entropy
,
weighted_smoothl1
,
multi_apply
,
weighted_cross_entropy
,
weighted_smoothl1
,
weighted_binary_cross_entropy
,
weighted_binary_cross_entropy
,
weighted_sigmoid_focal_loss
,
multiclass_nms
)
weighted_sigmoid_focal_loss
,
multiclass_nms
)
from
..
utils
import
normal_init
from
..
registry
import
HEADS
@
HEADS
.
register_module
class
AnchorHead
(
nn
.
Module
):
class
AnchorHead
(
nn
.
Module
):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
...
...
mmdet/models/anchor_heads/retina_head.py
View file @
95d44cc1
...
@@ -3,9 +3,11 @@ import torch.nn as nn
...
@@ -3,9 +3,11 @@ import torch.nn as nn
from
mmcv.cnn
import
normal_init
from
mmcv.cnn
import
normal_init
from
.anchor_head
import
AnchorHead
from
.anchor_head
import
AnchorHead
from
..registry
import
HEADS
from
..utils
import
bias_init_with_prob
from
..utils
import
bias_init_with_prob
@
HEADS
.
register_module
class
RetinaHead
(
AnchorHead
):
class
RetinaHead
(
AnchorHead
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmdet/models/anchor_heads/rpn_head.py
View file @
95d44cc1
...
@@ -6,8 +6,10 @@ from mmcv.cnn import normal_init
...
@@ -6,8 +6,10 @@ from mmcv.cnn import normal_init
from
mmdet.core
import
delta2bbox
from
mmdet.core
import
delta2bbox
from
mmdet.ops
import
nms
from
mmdet.ops
import
nms
from
.anchor_head
import
AnchorHead
from
.anchor_head
import
AnchorHead
from
..registry
import
HEADS
@
HEADS
.
register_module
class
RPNHead
(
AnchorHead
):
class
RPNHead
(
AnchorHead
):
def
__init__
(
self
,
in_channels
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
**
kwargs
):
...
...
mmdet/models/anchor_heads/ssd_head.py
View file @
95d44cc1
...
@@ -7,8 +7,10 @@ from mmcv.cnn import xavier_init
...
@@ -7,8 +7,10 @@ from mmcv.cnn import xavier_init
from
mmdet.core
import
(
AnchorGenerator
,
anchor_target
,
weighted_smoothl1
,
from
mmdet.core
import
(
AnchorGenerator
,
anchor_target
,
weighted_smoothl1
,
multi_apply
)
multi_apply
)
from
.anchor_head
import
AnchorHead
from
.anchor_head
import
AnchorHead
from
..registry
import
HEADS
@
HEADS
.
register_module
class
SSDHead
(
AnchorHead
):
class
SSDHead
(
AnchorHead
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmdet/models/backbones/resnet.py
View file @
95d44cc1
...
@@ -7,6 +7,8 @@ from mmcv.cnn import constant_init, kaiming_init
...
@@ -7,6 +7,8 @@ from mmcv.cnn import constant_init, kaiming_init
from
mmcv.runner
import
load_checkpoint
from
mmcv.runner
import
load_checkpoint
from
..utils
import
build_norm_layer
from
..utils
import
build_norm_layer
from
..registry
import
BACKBONES
def
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
,
dilation
=
1
):
def
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
,
dilation
=
1
):
"3x3 convolution with padding"
"3x3 convolution with padding"
...
@@ -222,6 +224,7 @@ def make_res_layer(block,
...
@@ -222,6 +224,7 @@ def make_res_layer(block,
return
nn
.
Sequential
(
*
layers
)
return
nn
.
Sequential
(
*
layers
)
@
BACKBONES
.
register_module
class
ResNet
(
nn
.
Module
):
class
ResNet
(
nn
.
Module
):
"""ResNet backbone.
"""ResNet backbone.
...
...
mmdet/models/backbones/resnext.py
View file @
95d44cc1
...
@@ -4,6 +4,7 @@ import torch.nn as nn
...
@@ -4,6 +4,7 @@ import torch.nn as nn
from
.resnet
import
ResNet
from
.resnet
import
ResNet
from
.resnet
import
Bottleneck
as
_Bottleneck
from
.resnet
import
Bottleneck
as
_Bottleneck
from
..registry
import
BACKBONES
from
..utils
import
build_norm_layer
from
..utils
import
build_norm_layer
...
@@ -106,6 +107,7 @@ def make_res_layer(block,
...
@@ -106,6 +107,7 @@ def make_res_layer(block,
return
nn
.
Sequential
(
*
layers
)
return
nn
.
Sequential
(
*
layers
)
@
BACKBONES
.
register_module
class
ResNeXt
(
ResNet
):
class
ResNeXt
(
ResNet
):
"""ResNeXt backbone.
"""ResNeXt backbone.
...
...
mmdet/models/backbones/ssd_vgg.py
View file @
95d44cc1
...
@@ -6,8 +6,10 @@ import torch.nn.functional as F
...
@@ -6,8 +6,10 @@ import torch.nn.functional as F
from
mmcv.cnn
import
(
VGG
,
xavier_init
,
constant_init
,
kaiming_init
,
from
mmcv.cnn
import
(
VGG
,
xavier_init
,
constant_init
,
kaiming_init
,
normal_init
)
normal_init
)
from
mmcv.runner
import
load_checkpoint
from
mmcv.runner
import
load_checkpoint
from
..registry
import
BACKBONES
@
BACKBONES
.
register_module
class
SSDVGG
(
VGG
):
class
SSDVGG
(
VGG
):
extra_setting
=
{
extra_setting
=
{
300
:
(
256
,
'S'
,
512
,
128
,
'S'
,
256
,
128
,
256
,
128
,
256
),
300
:
(
256
,
'S'
,
512
,
128
,
'S'
,
256
,
128
,
256
,
128
,
256
),
...
...
mmdet/models/bbox_heads/bbox_head.py
View file @
95d44cc1
...
@@ -4,8 +4,10 @@ import torch.nn.functional as F
...
@@ -4,8 +4,10 @@ import torch.nn.functional as F
from
mmdet.core
import
(
delta2bbox
,
multiclass_nms
,
bbox_target
,
from
mmdet.core
import
(
delta2bbox
,
multiclass_nms
,
bbox_target
,
weighted_cross_entropy
,
weighted_smoothl1
,
accuracy
)
weighted_cross_entropy
,
weighted_smoothl1
,
accuracy
)
from
..registry
import
HEADS
@
HEADS
.
register_module
class
BBoxHead
(
nn
.
Module
):
class
BBoxHead
(
nn
.
Module
):
"""Simplest RoI head, with only two fc layers for classification and
"""Simplest RoI head, with only two fc layers for classification and
regression respectively"""
regression respectively"""
...
@@ -78,8 +80,14 @@ class BBoxHead(nn.Module):
...
@@ -78,8 +80,14 @@ class BBoxHead(nn.Module):
target_stds
=
self
.
target_stds
)
target_stds
=
self
.
target_stds
)
return
cls_reg_targets
return
cls_reg_targets
def
loss
(
self
,
cls_score
,
bbox_pred
,
labels
,
label_weights
,
bbox_targets
,
def
loss
(
self
,
bbox_weights
,
reduce
=
True
):
cls_score
,
bbox_pred
,
labels
,
label_weights
,
bbox_targets
,
bbox_weights
,
reduce
=
True
):
losses
=
dict
()
losses
=
dict
()
if
cls_score
is
not
None
:
if
cls_score
is
not
None
:
losses
[
'loss_cls'
]
=
weighted_cross_entropy
(
losses
[
'loss_cls'
]
=
weighted_cross_entropy
(
...
...
mmdet/models/bbox_heads/convfc_bbox_head.py
View file @
95d44cc1
import
torch.nn
as
nn
import
torch.nn
as
nn
from
.bbox_head
import
BBoxHead
from
.bbox_head
import
BBoxHead
from
..registry
import
HEADS
from
..utils
import
ConvModule
from
..utils
import
ConvModule
@
HEADS
.
register_module
class
ConvFCBBoxHead
(
BBoxHead
):
class
ConvFCBBoxHead
(
BBoxHead
):
"""More general bbox head, with shared conv and fc layers and two optional
"""More general bbox head, with shared conv and fc layers and two optional
separated branches.
separated branches.
...
@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead):
...
@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead):
return
cls_score
,
bbox_pred
return
cls_score
,
bbox_pred
@
HEADS
.
register_module
class
SharedFCBBoxHead
(
ConvFCBBoxHead
):
class
SharedFCBBoxHead
(
ConvFCBBoxHead
):
def
__init__
(
self
,
num_fcs
=
2
,
fc_out_channels
=
1024
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
num_fcs
=
2
,
fc_out_channels
=
1024
,
*
args
,
**
kwargs
):
...
...
mmdet/models/builder.py
View file @
95d44cc1
from
mmcv.runner
import
obj_from_dict
import
mmcv
from
torch
import
nn
from
torch
import
nn
from
.
import
(
backbones
,
necks
,
roi_extractors
,
anchor_heads
,
bbox_heads
,
from
.registry
import
BACKBONES
,
NECKS
,
ROI_EXTRACTORS
,
HEADS
,
DETECTORS
mask_heads
)
def
_build_module
(
cfg
,
registry
,
default_args
):
def
_build_module
(
cfg
,
parrent
=
None
,
default_args
=
None
):
assert
isinstance
(
cfg
,
dict
)
and
'type'
in
cfg
return
cfg
if
isinstance
(
cfg
,
nn
.
Module
)
else
obj_from_dict
(
assert
isinstance
(
default_args
,
dict
)
or
default_args
is
None
cfg
,
parrent
,
default_args
)
args
=
cfg
.
copy
()
obj_type
=
args
.
pop
(
'type'
)
if
mmcv
.
is_str
(
obj_type
):
def
build
(
cfg
,
parrent
=
None
,
default_args
=
None
):
if
obj_type
not
in
registry
.
module_dict
:
raise
KeyError
(
'{} is not in the {} registry'
.
format
(
obj_type
,
registry
.
name
))
obj_type
=
registry
.
module_dict
[
obj_type
]
elif
not
isinstance
(
obj_type
,
type
):
raise
TypeError
(
'type must be a str or valid type, but got {}'
.
format
(
type
(
obj_type
)))
if
default_args
is
not
None
:
for
name
,
value
in
default_args
.
items
():
args
.
setdefault
(
name
,
value
)
return
obj_type
(
**
args
)
def
build
(
cfg
,
registry
,
default_args
=
None
):
if
isinstance
(
cfg
,
list
):
if
isinstance
(
cfg
,
list
):
modules
=
[
_build_module
(
cfg_
,
parrent
,
default_args
)
for
cfg_
in
cfg
]
modules
=
[
_build_module
(
cfg_
,
registry
,
default_args
)
for
cfg_
in
cfg
]
return
nn
.
Sequential
(
*
modules
)
return
nn
.
Sequential
(
*
modules
)
else
:
else
:
return
_build_module
(
cfg
,
parrent
,
default_args
)
return
_build_module
(
cfg
,
registry
,
default_args
)
def
build_backbone
(
cfg
):
def
build_backbone
(
cfg
):
return
build
(
cfg
,
backbones
)
return
build
(
cfg
,
BACKBONES
)
def
build_neck
(
cfg
):
def
build_neck
(
cfg
):
return
build
(
cfg
,
necks
)
return
build
(
cfg
,
NECKS
)
def
build_anchor_head
(
cfg
):
return
build
(
cfg
,
anchor_heads
)
def
build_roi_extractor
(
cfg
):
def
build_roi_extractor
(
cfg
):
return
build
(
cfg
,
roi_extractors
)
return
build
(
cfg
,
ROI_EXTRACTORS
)
def
build_bbox_head
(
cfg
):
return
build
(
cfg
,
bbox_heads
)
def
build_
mask_
head
(
cfg
):
def
build_head
(
cfg
):
return
build
(
cfg
,
mask_heads
)
return
build
(
cfg
,
HEADS
)
def
build_detector
(
cfg
,
train_cfg
=
None
,
test_cfg
=
None
):
def
build_detector
(
cfg
,
train_cfg
=
None
,
test_cfg
=
None
):
from
.
import
detectors
return
build
(
cfg
,
DETECTORS
,
dict
(
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
))
return
build
(
cfg
,
detectors
,
dict
(
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
))
mmdet/models/detectors/cascade_rcnn.py
View file @
95d44cc1
...
@@ -6,10 +6,12 @@ import torch.nn as nn
...
@@ -6,10 +6,12 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
from
.test_mixins
import
RPNTestMixin
from
..
import
builder
from
..
import
builder
from
..registry
import
DETECTORS
from
mmdet.core
import
(
assign_and_sample
,
bbox2roi
,
bbox2result
,
multi_apply
,
from
mmdet.core
import
(
assign_and_sample
,
bbox2roi
,
bbox2result
,
multi_apply
,
merge_aug_masks
)
merge_aug_masks
)
@
DETECTORS
.
register_module
class
CascadeRCNN
(
BaseDetector
,
RPNTestMixin
):
class
CascadeRCNN
(
BaseDetector
,
RPNTestMixin
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmdet/models/detectors/fast_rcnn.py
View file @
95d44cc1
from
.two_stage
import
TwoStageDetector
from
.two_stage
import
TwoStageDetector
from
..registry
import
DETECTORS
@
DETECTORS
.
register_module
class
FastRCNN
(
TwoStageDetector
):
class
FastRCNN
(
TwoStageDetector
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmdet/models/detectors/faster_rcnn.py
View file @
95d44cc1
from
.two_stage
import
TwoStageDetector
from
.two_stage
import
TwoStageDetector
from
..registry
import
DETECTORS
@
DETECTORS
.
register_module
class
FasterRCNN
(
TwoStageDetector
):
class
FasterRCNN
(
TwoStageDetector
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector):
...
@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector):
test_cfg
,
test_cfg
,
pretrained
=
None
):
pretrained
=
None
):
super
(
FasterRCNN
,
self
).
__init__
(
super
(
FasterRCNN
,
self
).
__init__
(
backbone
=
backbone
,
backbone
=
backbone
,
neck
=
neck
,
neck
=
neck
,
rpn_head
=
rpn_head
,
rpn_head
=
rpn_head
,
bbox_roi_extractor
=
bbox_roi_extractor
,
bbox_roi_extractor
=
bbox_roi_extractor
,
bbox_head
=
bbox_head
,
bbox_head
=
bbox_head
,
train_cfg
=
train_cfg
,
train_cfg
=
train_cfg
,
test_cfg
=
test_cfg
,
test_cfg
=
test_cfg
,
pretrained
=
pretrained
)
pretrained
=
pretrained
)
mmdet/models/detectors/mask_rcnn.py
View file @
95d44cc1
from
.two_stage
import
TwoStageDetector
from
.two_stage
import
TwoStageDetector
from
..registry
import
DETECTORS
@
DETECTORS
.
register_module
class
MaskRCNN
(
TwoStageDetector
):
class
MaskRCNN
(
TwoStageDetector
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmdet/models/detectors/retinanet.py
View file @
95d44cc1
from
.single_stage
import
SingleStageDetector
from
.single_stage
import
SingleStageDetector
from
..registry
import
DETECTORS
@
DETECTORS
.
register_module
class
RetinaNet
(
SingleStageDetector
):
class
RetinaNet
(
SingleStageDetector
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmdet/models/detectors/rpn.py
View file @
95d44cc1
...
@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping
...
@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping
from
.base
import
BaseDetector
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
from
.test_mixins
import
RPNTestMixin
from
..
import
builder
from
..
import
builder
from
..registry
import
DETECTORS
@
DETECTORS
.
register_module
class
RPN
(
BaseDetector
,
RPNTestMixin
):
class
RPN
(
BaseDetector
,
RPNTestMixin
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin):
...
@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin):
super
(
RPN
,
self
).
__init__
()
super
(
RPN
,
self
).
__init__
()
self
.
backbone
=
builder
.
build_backbone
(
backbone
)
self
.
backbone
=
builder
.
build_backbone
(
backbone
)
self
.
neck
=
builder
.
build_neck
(
neck
)
if
neck
is
not
None
else
None
self
.
neck
=
builder
.
build_neck
(
neck
)
if
neck
is
not
None
else
None
self
.
rpn_head
=
builder
.
build_
anchor_
head
(
rpn_head
)
self
.
rpn_head
=
builder
.
build_head
(
rpn_head
)
self
.
train_cfg
=
train_cfg
self
.
train_cfg
=
train_cfg
self
.
test_cfg
=
test_cfg
self
.
test_cfg
=
test_cfg
self
.
init_weights
(
pretrained
=
pretrained
)
self
.
init_weights
(
pretrained
=
pretrained
)
...
...
mmdet/models/detectors/single_stage.py
View file @
95d44cc1
...
@@ -2,9 +2,11 @@ import torch.nn as nn
...
@@ -2,9 +2,11 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.base
import
BaseDetector
from
..
import
builder
from
..
import
builder
from
..registry
import
DETECTORS
from
mmdet.core
import
bbox2result
from
mmdet.core
import
bbox2result
@
DETECTORS
.
register_module
class
SingleStageDetector
(
BaseDetector
):
class
SingleStageDetector
(
BaseDetector
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector):
...
@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector):
self
.
backbone
=
builder
.
build_backbone
(
backbone
)
self
.
backbone
=
builder
.
build_backbone
(
backbone
)
if
neck
is
not
None
:
if
neck
is
not
None
:
self
.
neck
=
builder
.
build_neck
(
neck
)
self
.
neck
=
builder
.
build_neck
(
neck
)
self
.
bbox_head
=
builder
.
build_
anchor_
head
(
bbox_head
)
self
.
bbox_head
=
builder
.
build_head
(
bbox_head
)
self
.
train_cfg
=
train_cfg
self
.
train_cfg
=
train_cfg
self
.
test_cfg
=
test_cfg
self
.
test_cfg
=
test_cfg
self
.
init_weights
(
pretrained
=
pretrained
)
self
.
init_weights
(
pretrained
=
pretrained
)
...
...
mmdet/models/detectors/two_stage.py
View file @
95d44cc1
...
@@ -4,9 +4,11 @@ import torch.nn as nn
...
@@ -4,9 +4,11 @@ import torch.nn as nn
from
.base
import
BaseDetector
from
.base
import
BaseDetector
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
.test_mixins
import
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
from
..
import
builder
from
..
import
builder
from
..registry
import
DETECTORS
from
mmdet.core
import
bbox2roi
,
bbox2result
,
build_assigner
,
build_sampler
from
mmdet.core
import
bbox2roi
,
bbox2result
,
build_assigner
,
build_sampler
@
DETECTORS
.
register_module
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
class
TwoStageDetector
(
BaseDetector
,
RPNTestMixin
,
BBoxTestMixin
,
MaskTestMixin
):
MaskTestMixin
):
...
@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
...
@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
raise
NotImplementedError
raise
NotImplementedError
if
rpn_head
is
not
None
:
if
rpn_head
is
not
None
:
self
.
rpn_head
=
builder
.
build_
anchor_
head
(
rpn_head
)
self
.
rpn_head
=
builder
.
build_head
(
rpn_head
)
if
bbox_head
is
not
None
:
if
bbox_head
is
not
None
:
self
.
bbox_roi_extractor
=
builder
.
build_roi_extractor
(
self
.
bbox_roi_extractor
=
builder
.
build_roi_extractor
(
bbox_roi_extractor
)
bbox_roi_extractor
)
self
.
bbox_head
=
builder
.
build_
bbox_
head
(
bbox_head
)
self
.
bbox_head
=
builder
.
build_head
(
bbox_head
)
if
mask_head
is
not
None
:
if
mask_head
is
not
None
:
self
.
mask_roi_extractor
=
builder
.
build_roi_extractor
(
self
.
mask_roi_extractor
=
builder
.
build_roi_extractor
(
mask_roi_extractor
)
mask_roi_extractor
)
self
.
mask_head
=
builder
.
build_
mask_
head
(
mask_head
)
self
.
mask_head
=
builder
.
build_head
(
mask_head
)
self
.
train_cfg
=
train_cfg
self
.
train_cfg
=
train_cfg
self
.
test_cfg
=
test_cfg
self
.
test_cfg
=
test_cfg
...
...
mmdet/models/mask_heads/fcn_mask_head.py
View file @
95d44cc1
...
@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util
...
@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
..registry
import
HEADS
from
..utils
import
ConvModule
from
..utils
import
ConvModule
from
mmdet.core
import
mask_cross_entropy
,
mask_target
from
mmdet.core
import
mask_cross_entropy
,
mask_target
@
HEADS
.
register_module
class
FCNMaskHead
(
nn
.
Module
):
class
FCNMaskHead
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment