Unverified Commit 95d44cc1 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #253 from hellock/registry

Use registry to manage modules
parents e72a9fd5 e2594f17
from .detectors import (BaseDetector, TwoStageDetector, RPN, FastRCNN,
FasterRCNN, MaskRCNN)
from .builder import (build_neck, build_anchor_head, build_roi_extractor,
build_bbox_head, build_mask_head, build_detector)
from .backbones import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .roi_extractors import * # noqa: F401,F403
from .anchor_heads import * # noqa: F401,F403
from .bbox_heads import * # noqa: F401,F403
from .mask_heads import * # noqa: F401,F403
from .detectors import * # noqa: F401,F403
from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
from .builder import (build_backbone, build_neck, build_roi_extractor,
build_head, build_detector)
__all__ = [
'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN',
'MaskRCNN', 'build_backbone', 'build_neck', 'build_anchor_head',
'build_roi_extractor', 'build_bbox_head', 'build_mask_head',
'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'HEADS', 'DETECTORS',
'build_backbone', 'build_neck', 'build_roi_extractor', 'build_head',
'build_detector'
]
......@@ -3,14 +3,16 @@ from __future__ import division
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, weighted_cross_entropy, weighted_smoothl1,
weighted_binary_cross_entropy,
weighted_sigmoid_focal_loss, multiclass_nms)
from ..utils import normal_init
from ..registry import HEADS
@HEADS.register_module
class AnchorHead(nn.Module):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
......
......@@ -3,9 +3,11 @@ import torch.nn as nn
from mmcv.cnn import normal_init
from .anchor_head import AnchorHead
from ..registry import HEADS
from ..utils import bias_init_with_prob
@HEADS.register_module
class RetinaHead(AnchorHead):
def __init__(self,
......
......@@ -6,8 +6,10 @@ from mmcv.cnn import normal_init
from mmdet.core import delta2bbox
from mmdet.ops import nms
from .anchor_head import AnchorHead
from ..registry import HEADS
@HEADS.register_module
class RPNHead(AnchorHead):
def __init__(self, in_channels, **kwargs):
......
......@@ -7,8 +7,10 @@ from mmcv.cnn import xavier_init
from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
multi_apply)
from .anchor_head import AnchorHead
from ..registry import HEADS
@HEADS.register_module
class SSDHead(AnchorHead):
def __init__(self,
......
......@@ -7,6 +7,8 @@ from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from ..utils import build_norm_layer
from ..registry import BACKBONES
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
"3x3 convolution with padding"
......@@ -222,6 +224,7 @@ def make_res_layer(block,
return nn.Sequential(*layers)
@BACKBONES.register_module
class ResNet(nn.Module):
"""ResNet backbone.
......
......@@ -4,6 +4,7 @@ import torch.nn as nn
from .resnet import ResNet
from .resnet import Bottleneck as _Bottleneck
from ..registry import BACKBONES
from ..utils import build_norm_layer
......@@ -106,6 +107,7 @@ def make_res_layer(block,
return nn.Sequential(*layers)
@BACKBONES.register_module
class ResNeXt(ResNet):
"""ResNeXt backbone.
......
......@@ -6,8 +6,10 @@ import torch.nn.functional as F
from mmcv.cnn import (VGG, xavier_init, constant_init, kaiming_init,
normal_init)
from mmcv.runner import load_checkpoint
from ..registry import BACKBONES
@BACKBONES.register_module
class SSDVGG(VGG):
extra_setting = {
300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
......
......@@ -4,8 +4,10 @@ import torch.nn.functional as F
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
weighted_cross_entropy, weighted_smoothl1, accuracy)
from ..registry import HEADS
@HEADS.register_module
class BBoxHead(nn.Module):
"""Simplest RoI head, with only two fc layers for classification and
regression respectively"""
......@@ -78,8 +80,14 @@ class BBoxHead(nn.Module):
target_stds=self.target_stds)
return cls_reg_targets
def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets,
bbox_weights, reduce=True):
def loss(self,
cls_score,
bbox_pred,
labels,
label_weights,
bbox_targets,
bbox_weights,
reduce=True):
losses = dict()
if cls_score is not None:
losses['loss_cls'] = weighted_cross_entropy(
......
import torch.nn as nn
from .bbox_head import BBoxHead
from ..registry import HEADS
from ..utils import ConvModule
@HEADS.register_module
class ConvFCBBoxHead(BBoxHead):
"""More general bbox head, with shared conv and fc layers and two optional
separated branches.
......@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead):
return cls_score, bbox_pred
@HEADS.register_module
class SharedFCBBoxHead(ConvFCBBoxHead):
def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs):
......
from mmcv.runner import obj_from_dict
import mmcv
from torch import nn
from . import (backbones, necks, roi_extractors, anchor_heads, bbox_heads,
mask_heads)
def _build_module(cfg, parrent=None, default_args=None):
return cfg if isinstance(cfg, nn.Module) else obj_from_dict(
cfg, parrent, default_args)
def build(cfg, parrent=None, default_args=None):
from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
def _build_module(cfg, registry, default_args):
assert isinstance(cfg, dict) and 'type' in cfg
assert isinstance(default_args, dict) or default_args is None
args = cfg.copy()
obj_type = args.pop('type')
if mmcv.is_str(obj_type):
if obj_type not in registry.module_dict:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
obj_type = registry.module_dict[obj_type]
elif not isinstance(obj_type, type):
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_type(**args)
def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [_build_module(cfg_, parrent, default_args) for cfg_ in cfg]
modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
return nn.Sequential(*modules)
else:
return _build_module(cfg, parrent, default_args)
return _build_module(cfg, registry, default_args)
def build_backbone(cfg):
return build(cfg, backbones)
return build(cfg, BACKBONES)
def build_neck(cfg):
return build(cfg, necks)
def build_anchor_head(cfg):
return build(cfg, anchor_heads)
return build(cfg, NECKS)
def build_roi_extractor(cfg):
return build(cfg, roi_extractors)
def build_bbox_head(cfg):
return build(cfg, bbox_heads)
return build(cfg, ROI_EXTRACTORS)
def build_mask_head(cfg):
return build(cfg, mask_heads)
def build_head(cfg):
return build(cfg, HEADS)
def build_detector(cfg, train_cfg=None, test_cfg=None):
from . import detectors
return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
......@@ -6,10 +6,12 @@ import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin
from .. import builder
from ..registry import DETECTORS
from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply,
merge_aug_masks)
@DETECTORS.register_module
class CascadeRCNN(BaseDetector, RPNTestMixin):
def __init__(self,
......
from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FastRCNN(TwoStageDetector):
def __init__(self,
......
from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FasterRCNN(TwoStageDetector):
def __init__(self,
......@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector):
test_cfg,
pretrained=None):
super(FasterRCNN, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class MaskRCNN(TwoStageDetector):
def __init__(self,
......
from .single_stage import SingleStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class RetinaNet(SingleStageDetector):
def __init__(self,
......
......@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping
from .base import BaseDetector
from .test_mixins import RPNTestMixin
from .. import builder
from ..registry import DETECTORS
@DETECTORS.register_module
class RPN(BaseDetector, RPNTestMixin):
def __init__(self,
......@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin):
super(RPN, self).__init__()
self.backbone = builder.build_backbone(backbone)
self.neck = builder.build_neck(neck) if neck is not None else None
self.rpn_head = builder.build_anchor_head(rpn_head)
self.rpn_head = builder.build_head(rpn_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
......
......@@ -2,9 +2,11 @@ import torch.nn as nn
from .base import BaseDetector
from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2result
@DETECTORS.register_module
class SingleStageDetector(BaseDetector):
def __init__(self,
......@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector):
self.backbone = builder.build_backbone(backbone)
if neck is not None:
self.neck = builder.build_neck(neck)
self.bbox_head = builder.build_anchor_head(bbox_head)
self.bbox_head = builder.build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
......
......@@ -4,9 +4,11 @@ import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
MaskTestMixin):
......@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
raise NotImplementedError
if rpn_head is not None:
self.rpn_head = builder.build_anchor_head(rpn_head)
self.rpn_head = builder.build_head(rpn_head)
if bbox_head is not None:
self.bbox_roi_extractor = builder.build_roi_extractor(
bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head)
self.bbox_head = builder.build_head(bbox_head)
if mask_head is not None:
self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head)
self.mask_head = builder.build_head(mask_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......
......@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util
import torch
import torch.nn as nn
from ..registry import HEADS
from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target
@HEADS.register_module
class FCNMaskHead(nn.Module):
def __init__(self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment