Commit 62255259 authored by Kai Chen's avatar Kai Chen
Browse files

use registry to manage modules

parent 2df1e0a0
from .detectors import (BaseDetector, TwoStageDetector, RPN, FastRCNN,
FasterRCNN, MaskRCNN)
from .builder import (build_neck, build_anchor_head, build_roi_extractor,
build_bbox_head, build_mask_head, build_detector)
from .backbones import *
from .necks import *
from .roi_extractors import *
from .anchor_heads import *
from .bbox_heads import *
from .mask_heads import *
from .detectors import *
from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
from .builder import (build_backbone, build_neck, build_roi_extractor,
build_head, build_detector)
__all__ = [
'BaseDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN',
'MaskRCNN', 'build_backbone', 'build_neck', 'build_anchor_head',
'build_roi_extractor', 'build_bbox_head', 'build_mask_head',
'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'HEADS', 'DETECTORS',
'build_backbone', 'build_neck', 'build_roi_extractor', 'build_head',
'build_detector'
]
......@@ -3,14 +3,16 @@ from __future__ import division
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, weighted_cross_entropy, weighted_smoothl1,
weighted_binary_cross_entropy,
weighted_sigmoid_focal_loss, multiclass_nms)
from ..utils import normal_init
from ..registry import HEADS
@HEADS.register_module
class AnchorHead(nn.Module):
"""Anchor-based head (RPN, RetinaNet, SSD, etc.).
......
......@@ -3,9 +3,11 @@ import torch.nn as nn
from mmcv.cnn import normal_init
from .anchor_head import AnchorHead
from ..registry import HEADS
from ..utils import bias_init_with_prob
@HEADS.register_module
class RetinaHead(AnchorHead):
def __init__(self,
......
......@@ -6,8 +6,10 @@ from mmcv.cnn import normal_init
from mmdet.core import delta2bbox
from mmdet.ops import nms
from .anchor_head import AnchorHead
from ..registry import HEADS
@HEADS.register_module
class RPNHead(AnchorHead):
def __init__(self, in_channels, **kwargs):
......
......@@ -7,8 +7,10 @@ from mmcv.cnn import xavier_init
from mmdet.core import (AnchorGenerator, anchor_target, weighted_smoothl1,
multi_apply)
from .anchor_head import AnchorHead
from ..registry import HEADS
@HEADS.register_module
class SSDHead(AnchorHead):
def __init__(self,
......@@ -144,7 +146,6 @@ class SSDHead(AnchorHead):
self.target_stds,
cfg,
gt_labels_list=gt_labels,
cls_out_channels=self.cls_out_channels,
sampling=False,
unmap_outputs=False)
if cls_reg_targets is None:
......
......@@ -6,6 +6,8 @@ import torch.utils.checkpoint as cp
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from ..registry import BACKBONES
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
"3x3 convolution with padding"
......@@ -182,6 +184,7 @@ def make_res_layer(block,
return nn.Sequential(*layers)
@BACKBONES.register_module
class ResNet(nn.Module):
"""ResNet backbone.
......
......@@ -4,6 +4,7 @@ import torch.nn as nn
from .resnet import ResNet
from .resnet import Bottleneck as _Bottleneck
from ..registry import BACKBONES
class Bottleneck(_Bottleneck):
......@@ -92,6 +93,7 @@ def make_res_layer(block,
return nn.Sequential(*layers)
@BACKBONES.register_module
class ResNeXt(ResNet):
"""ResNeXt backbone.
......
......@@ -6,8 +6,10 @@ import torch.nn.functional as F
from mmcv.cnn import (VGG, xavier_init, constant_init, kaiming_init,
normal_init)
from mmcv.runner import load_checkpoint
from ..registry import BACKBONES
@BACKBONES.register_module
class SSDVGG(VGG):
extra_setting = {
300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
......
......@@ -4,8 +4,10 @@ import torch.nn.functional as F
from mmdet.core import (delta2bbox, multiclass_nms, bbox_target,
weighted_cross_entropy, weighted_smoothl1, accuracy)
from ..registry import HEADS
@HEADS.register_module
class BBoxHead(nn.Module):
"""Simplest RoI head, with only two fc layers for classification and
regression respectively"""
......@@ -78,8 +80,14 @@ class BBoxHead(nn.Module):
target_stds=self.target_stds)
return cls_reg_targets
def loss(self, cls_score, bbox_pred, labels, label_weights, bbox_targets,
bbox_weights, reduce=True):
def loss(self,
cls_score,
bbox_pred,
labels,
label_weights,
bbox_targets,
bbox_weights,
reduce=True):
losses = dict()
if cls_score is not None:
losses['loss_cls'] = weighted_cross_entropy(
......
import torch.nn as nn
from .bbox_head import BBoxHead
from ..registry import HEADS
from ..utils import ConvModule
@HEADS.register_module
class ConvFCBBoxHead(BBoxHead):
"""More general bbox head, with shared conv and fc layers and two optional
separated branches.
......@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead):
return cls_score, bbox_pred
@HEADS.register_module
class SharedFCBBoxHead(ConvFCBBoxHead):
def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs):
......
from mmcv.runner import obj_from_dict
import mmcv
from torch import nn
from . import (backbones, necks, roi_extractors, anchor_heads, bbox_heads,
mask_heads)
def _build_module(cfg, parrent=None, default_args=None):
return cfg if isinstance(cfg, nn.Module) else obj_from_dict(
cfg, parrent, default_args)
def build(cfg, parrent=None, default_args=None):
from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
def _build_module(cfg, registry, default_args):
assert isinstance(cfg, dict) and 'type' in cfg
assert isinstance(default_args, dict) or default_args is None
args = cfg.copy()
obj_type = args.pop('type')
if mmcv.is_str(obj_type):
if obj_type not in registry.module_dict:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
obj_type = registry.module_dict[obj_type]
elif not isinstance(obj_type, type):
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_type(**args)
def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [_build_module(cfg_, parrent, default_args) for cfg_ in cfg]
modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
return nn.Sequential(*modules)
else:
return _build_module(cfg, parrent, default_args)
return _build_module(cfg, registry, default_args)
def build_backbone(cfg):
return build(cfg, backbones)
return build(cfg, BACKBONES)
def build_neck(cfg):
return build(cfg, necks)
def build_anchor_head(cfg):
return build(cfg, anchor_heads)
return build(cfg, NECKS)
def build_roi_extractor(cfg):
return build(cfg, roi_extractors)
def build_bbox_head(cfg):
return build(cfg, bbox_heads)
return build(cfg, ROI_EXTRACTORS)
def build_mask_head(cfg):
return build(cfg, mask_heads)
def build_head(cfg):
return build(cfg, HEADS)
def build_detector(cfg, train_cfg=None, test_cfg=None):
from . import detectors
return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
......@@ -6,10 +6,12 @@ import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin
from .. import builder
from ..registry import DETECTORS
from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply,
merge_aug_masks)
@DETECTORS.register_module
class CascadeRCNN(BaseDetector, RPNTestMixin):
def __init__(self,
......
from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FastRCNN(TwoStageDetector):
def __init__(self,
......
from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FasterRCNN(TwoStageDetector):
def __init__(self,
......@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector):
test_cfg,
pretrained=None):
super(FasterRCNN, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class MaskRCNN(TwoStageDetector):
def __init__(self,
......
from .single_stage import SingleStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class RetinaNet(SingleStageDetector):
def __init__(self,
......
......@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping
from .base import BaseDetector
from .test_mixins import RPNTestMixin
from .. import builder
from ..registry import DETECTORS
@DETECTORS.register_module
class RPN(BaseDetector, RPNTestMixin):
def __init__(self,
......@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin):
super(RPN, self).__init__()
self.backbone = builder.build_backbone(backbone)
self.neck = builder.build_neck(neck) if neck is not None else None
self.rpn_head = builder.build_anchor_head(rpn_head)
self.rpn_head = builder.build_head(rpn_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
......
......@@ -2,9 +2,11 @@ import torch.nn as nn
from .base import BaseDetector
from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2result
@DETECTORS.register_module
class SingleStageDetector(BaseDetector):
def __init__(self,
......@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector):
self.backbone = builder.build_backbone(backbone)
if neck is not None:
self.neck = builder.build_neck(neck)
self.bbox_head = builder.build_anchor_head(bbox_head)
self.bbox_head = builder.build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
......
......@@ -4,9 +4,11 @@ import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
MaskTestMixin):
......@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
raise NotImplementedError
if rpn_head is not None:
self.rpn_head = builder.build_anchor_head(rpn_head)
self.rpn_head = builder.build_head(rpn_head)
if bbox_head is not None:
self.bbox_roi_extractor = builder.build_roi_extractor(
bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head)
self.bbox_head = builder.build_head(bbox_head)
if mask_head is not None:
self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head)
self.mask_head = builder.build_head(mask_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......
......@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util
import torch
import torch.nn as nn
from ..registry import HEADS
from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target
@HEADS.register_module
class FCNMaskHead(nn.Module):
def __init__(self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment