"docs/vscode:/vscode.git/clone" did not exist on "d1f74a3e4d0340564dd6dcf6061b3a1cacf1f3c7"
Unverified Commit 95d44cc1 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #253 from hellock/registry

Use registry to manage modules
parents e72a9fd5 e2594f17
from .detectors import (BaseDetector, TwoStageDetector, RPN, FastRCNN, from .backbones import * # noqa: F401,F403
FasterRCNN, MaskRCNN) from .necks import * # noqa: F401,F403
from .builder import (build_neck, build_anchor_head, build_roi_extractor, from .roi_extractors import * # noqa: F401,F403
build_bbox_head, build_mask_head, build_detector) from .anchor_heads import * # noqa: F401,F403
from .bbox_heads import * # noqa: F401,F403
from .mask_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
from .builder import (build_backbone, build_neck, build_roi_extractor,
build_head, build_detector)
__all__ = [ __all__ = [
'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN', 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'HEADS', 'DETECTORS',
'MaskRCNN', 'build_backbone', 'build_neck', 'build_anchor_head', 'build_backbone', 'build_neck', 'build_roi_extractor', 'build_head',
'build_roi_extractor', 'build_bbox_head', 'build_mask_head',
'build_detector' 'build_detector'
] ]
...@@ -3,14 +3,16 @@ from __future__ import division ...@@ -3,14 +3,16 @@ from __future__ import division
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox, from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, weighted_cross_entropy, weighted_smoothl1, multi_apply, weighted_cross_entropy, weighted_smoothl1,
weighted_binary_cross_entropy, weighted_binary_cross_entropy,
weighted_sigmoid_focal_loss, multiclass_nms) weighted_sigmoid_focal_loss, multiclass_nms)
from ..utils import normal_init from ..registry import HEADS
@HEADS.register_module
class AnchorHead(nn.Module): class AnchorHead(nn.Module):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.). """Anchor-based head (RPN, RetinaNet, SSD, etc.).
......
...@@ -3,9 +3,11 @@ import torch.nn as nn ...@@ -3,9 +3,11 @@ import torch.nn as nn
from mmcv.cnn import normal_init from mmcv.cnn import normal_init
from .anchor_head import AnchorHead from .anchor_head import AnchorHead
from ..registry import HEADS
from ..utils import bias_init_with_prob from ..utils import bias_init_with_prob
@HEADS.register_module
class RetinaHead(AnchorHead): class RetinaHead(AnchorHead):
def __init__(self, def __init__(self,
......
...@@ -6,8 +6,10 @@ from mmcv.cnn import normal_init ...@@ -6,8 +6,10 @@ from mmcv.cnn import normal_init
from mmdet.core import delta2bbox from mmdet.core import delta2bbox
from mmdet.ops import nms from mmdet.ops import nms
from .anchor_head import AnchorHead from .anchor_head import AnchorHead
from ..registry import HEADS
@HEADS.register_module
class RPNHead(AnchorHead): class RPNHead(AnchorHead):
def __init__(self, in_channels, **kwargs): def __init__(self, in_channels, **kwargs):
......
...@@ -7,8 +7,10 @@ from mmcv.cnn import xavier_init ...@@ -7,8 +7,10 @@ from mmcv.cnn import xavier_init
from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1, from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
multi_apply) multi_apply)
from .anchor_head import AnchorHead from .anchor_head import AnchorHead
from ..registry import HEADS
@HEADS.register_module
class SSDHead(AnchorHead): class SSDHead(AnchorHead):
def __init__(self, def __init__(self,
......
...@@ -7,6 +7,8 @@ from mmcv.cnn import constant_init, kaiming_init ...@@ -7,6 +7,8 @@ from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
from ..utils import build_norm_layer from ..utils import build_norm_layer
from ..registry import BACKBONES
def conv3x3(in_planes, out_planes, stride=1, dilation=1): def conv3x3(in_planes, out_planes, stride=1, dilation=1):
"3x3 convolution with padding" "3x3 convolution with padding"
...@@ -222,6 +224,7 @@ def make_res_layer(block, ...@@ -222,6 +224,7 @@ def make_res_layer(block,
return nn.Sequential(*layers) return nn.Sequential(*layers)
@BACKBONES.register_module
class ResNet(nn.Module): class ResNet(nn.Module):
"""ResNet backbone. """ResNet backbone.
......
...@@ -4,6 +4,7 @@ import torch.nn as nn ...@@ -4,6 +4,7 @@ import torch.nn as nn
from .resnet import ResNet from .resnet import ResNet
from .resnet import Bottleneck as _Bottleneck from .resnet import Bottleneck as _Bottleneck
from ..registry import BACKBONES
from ..utils import build_norm_layer from ..utils import build_norm_layer
...@@ -106,6 +107,7 @@ def make_res_layer(block, ...@@ -106,6 +107,7 @@ def make_res_layer(block,
return nn.Sequential(*layers) return nn.Sequential(*layers)
@BACKBONES.register_module
class ResNeXt(ResNet): class ResNeXt(ResNet):
"""ResNeXt backbone. """ResNeXt backbone.
......
...@@ -6,8 +6,10 @@ import torch.nn.functional as F ...@@ -6,8 +6,10 @@ import torch.nn.functional as F
from mmcv.cnn import (VGG, xavier_init, constant_init, kaiming_init, from mmcv.cnn import (VGG, xavier_init, constant_init, kaiming_init,
normal_init) normal_init)
from mmcv.runner import load_checkpoint from mmcv.runner import load_checkpoint
from ..registry import BACKBONES
@BACKBONES.register_module
class SSDVGG(VGG): class SSDVGG(VGG):
extra_setting = { extra_setting = {
300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256), 300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
......
...@@ -4,8 +4,10 @@ import torch.nn.functional as F ...@@ -4,8 +4,10 @@ import torch.nn.functional as F
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target, from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
weighted_cross_entropy, weighted_smoothl1, accuracy) weighted_cross_entropy, weighted_smoothl1, accuracy)
from ..registry import HEADS
@HEADS.register_module
class BBoxHead(nn.Module): class BBoxHead(nn.Module):
"""Simplest RoI head, with only two fc layers for classification and """Simplest RoI head, with only two fc layers for classification and
regression respectively""" regression respectively"""
...@@ -78,8 +80,14 @@ class BBoxHead(nn.Module): ...@@ -78,8 +80,14 @@ class BBoxHead(nn.Module):
target_stds=self.target_stds) target_stds=self.target_stds)
return cls_reg_targets return cls_reg_targets
def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets, def loss(self,
bbox_weights, reduce=True): cls_score,
bbox_pred,
labels,
label_weights,
bbox_targets,
bbox_weights,
reduce=True):
losses = dict() losses = dict()
if cls_score is not None: if cls_score is not None:
losses['loss_cls'] = weighted_cross_entropy( losses['loss_cls'] = weighted_cross_entropy(
......
import torch.nn as nn import torch.nn as nn
from .bbox_head import BBoxHead from .bbox_head import BBoxHead
from ..registry import HEADS
from ..utils import ConvModule from ..utils import ConvModule
@HEADS.register_module
class ConvFCBBoxHead(BBoxHead): class ConvFCBBoxHead(BBoxHead):
"""More general bbox head, with shared conv and fc layers and two optional """More general bbox head, with shared conv and fc layers and two optional
separated branches. separated branches.
...@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead): ...@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead):
return cls_score, bbox_pred return cls_score, bbox_pred
@HEADS.register_module
class SharedFCBBoxHead(ConvFCBBoxHead): class SharedFCBBoxHead(ConvFCBBoxHead):
def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs): def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs):
......
from mmcv.runner import obj_from_dict import mmcv
from torch import nn from torch import nn
from . import (backbones, necks, roi_extractors, anchor_heads, bbox_heads, from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
mask_heads)
def _build_module(cfg, registry, default_args):
def _build_module(cfg, parrent=None, default_args=None): assert isinstance(cfg, dict) and 'type' in cfg
return cfg if isinstance(cfg, nn.Module) else obj_from_dict( assert isinstance(default_args, dict) or default_args is None
cfg, parrent, default_args) args = cfg.copy()
obj_type = args.pop('type')
if mmcv.is_str(obj_type):
def build(cfg, parrent=None, default_args=None): if obj_type not in registry.module_dict:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
obj_type = registry.module_dict[obj_type]
elif not isinstance(obj_type, type):
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_type(**args)
def build(cfg, registry, default_args=None):
if isinstance(cfg, list): if isinstance(cfg, list):
modules = [_build_module(cfg_, parrent, default_args) for cfg_ in cfg] modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
return nn.Sequential(*modules) return nn.Sequential(*modules)
else: else:
return _build_module(cfg, parrent, default_args) return _build_module(cfg, registry, default_args)
def build_backbone(cfg): def build_backbone(cfg):
return build(cfg, backbones) return build(cfg, BACKBONES)
def build_neck(cfg): def build_neck(cfg):
return build(cfg, necks) return build(cfg, NECKS)
def build_anchor_head(cfg):
return build(cfg, anchor_heads)
def build_roi_extractor(cfg): def build_roi_extractor(cfg):
return build(cfg, roi_extractors) return build(cfg, ROI_EXTRACTORS)
def build_bbox_head(cfg):
return build(cfg, bbox_heads)
def build_mask_head(cfg): def build_head(cfg):
return build(cfg, mask_heads) return build(cfg, HEADS)
def build_detector(cfg, train_cfg=None, test_cfg=None): def build_detector(cfg, train_cfg=None, test_cfg=None):
from . import detectors return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
...@@ -6,10 +6,12 @@ import torch.nn as nn ...@@ -6,10 +6,12 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin from .test_mixins import RPNTestMixin
from .. import builder from .. import builder
from ..registry import DETECTORS
from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply, from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply,
merge_aug_masks) merge_aug_masks)
@DETECTORS.register_module
class CascadeRCNN(BaseDetector, RPNTestMixin): class CascadeRCNN(BaseDetector, RPNTestMixin):
def __init__(self, def __init__(self,
......
from .two_stage import TwoStageDetector from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FastRCNN(TwoStageDetector): class FastRCNN(TwoStageDetector):
def __init__(self, def __init__(self,
......
from .two_stage import TwoStageDetector from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FasterRCNN(TwoStageDetector): class FasterRCNN(TwoStageDetector):
def __init__(self, def __init__(self,
...@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector): ...@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector):
test_cfg, test_cfg,
pretrained=None): pretrained=None):
super(FasterRCNN, self).__init__( super(FasterRCNN, self).__init__(
backbone=backbone, backbone=backbone,
neck=neck, neck=neck,
rpn_head=rpn_head, rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor, bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head, bbox_head=bbox_head,
train_cfg=train_cfg, train_cfg=train_cfg,
test_cfg=test_cfg, test_cfg=test_cfg,
pretrained=pretrained) pretrained=pretrained)
from .two_stage import TwoStageDetector from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class MaskRCNN(TwoStageDetector): class MaskRCNN(TwoStageDetector):
def __init__(self, def __init__(self,
......
from .single_stage import SingleStageDetector from .single_stage import SingleStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class RetinaNet(SingleStageDetector): class RetinaNet(SingleStageDetector):
def __init__(self, def __init__(self,
......
...@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping ...@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin from .test_mixins import RPNTestMixin
from .. import builder from .. import builder
from ..registry import DETECTORS
@DETECTORS.register_module
class RPN(BaseDetector, RPNTestMixin): class RPN(BaseDetector, RPNTestMixin):
def __init__(self, def __init__(self,
...@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin): ...@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin):
super(RPN, self).__init__() super(RPN, self).__init__()
self.backbone = builder.build_backbone(backbone) self.backbone = builder.build_backbone(backbone)
self.neck = builder.build_neck(neck) if neck is not None else None self.neck = builder.build_neck(neck) if neck is not None else None
self.rpn_head = builder.build_anchor_head(rpn_head) self.rpn_head = builder.build_head(rpn_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained) self.init_weights(pretrained=pretrained)
......
...@@ -2,9 +2,11 @@ import torch.nn as nn ...@@ -2,9 +2,11 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .. import builder from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2result from mmdet.core import bbox2result
@DETECTORS.register_module
class SingleStageDetector(BaseDetector): class SingleStageDetector(BaseDetector):
def __init__(self, def __init__(self,
...@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector): ...@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector):
self.backbone = builder.build_backbone(backbone) self.backbone = builder.build_backbone(backbone)
if neck is not None: if neck is not None:
self.neck = builder.build_neck(neck) self.neck = builder.build_neck(neck)
self.bbox_head = builder.build_anchor_head(bbox_head) self.bbox_head = builder.build_head(bbox_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained) self.init_weights(pretrained=pretrained)
......
...@@ -4,9 +4,11 @@ import torch.nn as nn ...@@ -4,9 +4,11 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
MaskTestMixin): MaskTestMixin):
...@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
raise NotImplementedError raise NotImplementedError
if rpn_head is not None: if rpn_head is not None:
self.rpn_head = builder.build_anchor_head(rpn_head) self.rpn_head = builder.build_head(rpn_head)
if bbox_head is not None: if bbox_head is not None:
self.bbox_roi_extractor = builder.build_roi_extractor( self.bbox_roi_extractor = builder.build_roi_extractor(
bbox_roi_extractor) bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head) self.bbox_head = builder.build_head(bbox_head)
if mask_head is not None: if mask_head is not None:
self.mask_roi_extractor = builder.build_roi_extractor( self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor) mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head) self.mask_head = builder.build_head(mask_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
......
...@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util ...@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..registry import HEADS
from ..utils import ConvModule from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target from mmdet.core import mask_cross_entropy, mask_target
@HEADS.register_module
class FCNMaskHead(nn.Module): class FCNMaskHead(nn.Module):
def __init__(self, def __init__(self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment