Commit 441015ea authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'master' into pytorch-1.0

parents 2017c81e 3b6ae96d
import torch.nn as nn import torch.nn as nn
from .bbox_head import BBoxHead from .bbox_head import BBoxHead
from ..registry import HEADS
from ..utils import ConvModule from ..utils import ConvModule
@HEADS.register_module
class ConvFCBBoxHead(BBoxHead): class ConvFCBBoxHead(BBoxHead):
"""More general bbox head, with shared conv and fc layers and two optional """More general bbox head, with shared conv and fc layers and two optional
separated branches. separated branches.
...@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead): ...@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead):
return cls_score, bbox_pred return cls_score, bbox_pred
@HEADS.register_module
class SharedFCBBoxHead(ConvFCBBoxHead): class SharedFCBBoxHead(ConvFCBBoxHead):
def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs): def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs):
......
from mmcv.runner import obj_from_dict import mmcv
from torch import nn from torch import nn
from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads, from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
mask_heads, single_stage_heads)
__all__ = [ def _build_module(cfg, registry, default_args):
'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor', assert isinstance(cfg, dict) and 'type' in cfg
'build_bbox_head', 'build_mask_head', 'build_single_stage_head', assert isinstance(default_args, dict) or default_args is None
'build_detector' args = cfg.copy()
] obj_type = args.pop('type')
if mmcv.is_str(obj_type):
if obj_type not in registry.module_dict:
def _build_module(cfg, parrent=None, default_args=None): raise KeyError('{} is not in the {} registry'.format(
return cfg if isinstance(cfg, nn.Module) else obj_from_dict( obj_type, registry.name))
cfg, parrent, default_args) obj_type = registry.module_dict[obj_type]
elif not isinstance(obj_type, type):
raise TypeError('type must be a str or valid type, but got {}'.format(
def build(cfg, parrent=None, default_args=None): type(obj_type)))
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_type(**args)
def build(cfg, registry, default_args=None):
if isinstance(cfg, list): if isinstance(cfg, list):
modules = [_build_module(cfg_, parrent, default_args) for cfg_ in cfg] modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
return nn.Sequential(*modules) return nn.Sequential(*modules)
else: else:
return _build_module(cfg, parrent, default_args) return _build_module(cfg, registry, default_args)
def build_backbone(cfg): def build_backbone(cfg):
return build(cfg, backbones) return build(cfg, BACKBONES)
def build_neck(cfg): def build_neck(cfg):
return build(cfg, necks) return build(cfg, NECKS)
def build_rpn_head(cfg):
return build(cfg, rpn_heads)
def build_roi_extractor(cfg): def build_roi_extractor(cfg):
return build(cfg, roi_extractors) return build(cfg, ROI_EXTRACTORS)
def build_bbox_head(cfg):
return build(cfg, bbox_heads)
def build_mask_head(cfg):
return build(cfg, mask_heads)
def build_single_stage_head(cfg): def build_head(cfg):
return build(cfg, single_stage_heads) return build(cfg, HEADS)
def build_detector(cfg, train_cfg=None, test_cfg=None): def build_detector(cfg, train_cfg=None, test_cfg=None):
from . import detectors return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
...@@ -6,10 +6,12 @@ import torch.nn as nn ...@@ -6,10 +6,12 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin from .test_mixins import RPNTestMixin
from .. import builder from .. import builder
from ..registry import DETECTORS
from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply, from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply,
merge_aug_masks) merge_aug_masks)
@DETECTORS.register_module
class CascadeRCNN(BaseDetector, RPNTestMixin): class CascadeRCNN(BaseDetector, RPNTestMixin):
def __init__(self, def __init__(self,
...@@ -37,7 +39,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -37,7 +39,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
raise NotImplementedError raise NotImplementedError
if rpn_head is not None: if rpn_head is not None:
self.rpn_head = builder.build_rpn_head(rpn_head) self.rpn_head = builder.build_head(rpn_head)
if bbox_head is not None: if bbox_head is not None:
self.bbox_roi_extractor = nn.ModuleList() self.bbox_roi_extractor = nn.ModuleList()
...@@ -52,7 +54,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -52,7 +54,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
for roi_extractor, head in zip(bbox_roi_extractor, bbox_head): for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
self.bbox_roi_extractor.append( self.bbox_roi_extractor.append(
builder.build_roi_extractor(roi_extractor)) builder.build_roi_extractor(roi_extractor))
self.bbox_head.append(builder.build_bbox_head(head)) self.bbox_head.append(builder.build_head(head))
if mask_head is not None: if mask_head is not None:
self.mask_roi_extractor = nn.ModuleList() self.mask_roi_extractor = nn.ModuleList()
...@@ -67,7 +69,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -67,7 +69,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
for roi_extractor, head in zip(mask_roi_extractor, mask_head): for roi_extractor, head in zip(mask_roi_extractor, mask_head):
self.mask_roi_extractor.append( self.mask_roi_extractor.append(
builder.build_roi_extractor(roi_extractor)) builder.build_roi_extractor(roi_extractor))
self.mask_head.append(builder.build_mask_head(head)) self.mask_head.append(builder.build_head(head))
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
...@@ -123,7 +125,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin): ...@@ -123,7 +125,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
losses.update(rpn_losses) losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn) proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs) proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
else: else:
proposal_list = proposals proposal_list = proposals
......
from .two_stage import TwoStageDetector from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FastRCNN(TwoStageDetector): class FastRCNN(TwoStageDetector):
def __init__(self, def __init__(self,
......
from .two_stage import TwoStageDetector from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FasterRCNN(TwoStageDetector): class FasterRCNN(TwoStageDetector):
def __init__(self, def __init__(self,
...@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector): ...@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector):
test_cfg, test_cfg,
pretrained=None): pretrained=None):
super(FasterRCNN, self).__init__( super(FasterRCNN, self).__init__(
backbone=backbone, backbone=backbone,
neck=neck, neck=neck,
rpn_head=rpn_head, rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor, bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head, bbox_head=bbox_head,
train_cfg=train_cfg, train_cfg=train_cfg,
test_cfg=test_cfg, test_cfg=test_cfg,
pretrained=pretrained) pretrained=pretrained)
from .two_stage import TwoStageDetector from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class MaskRCNN(TwoStageDetector): class MaskRCNN(TwoStageDetector):
def __init__(self, def __init__(self,
......
from .single_stage import SingleStageDetector from .single_stage import SingleStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class RetinaNet(SingleStageDetector): class RetinaNet(SingleStageDetector):
def __init__(self, def __init__(self,
......
...@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping ...@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin from .test_mixins import RPNTestMixin
from .. import builder from .. import builder
from ..registry import DETECTORS
@DETECTORS.register_module
class RPN(BaseDetector, RPNTestMixin): class RPN(BaseDetector, RPNTestMixin):
def __init__(self, def __init__(self,
...@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin): ...@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin):
super(RPN, self).__init__() super(RPN, self).__init__()
self.backbone = builder.build_backbone(backbone) self.backbone = builder.build_backbone(backbone)
self.neck = builder.build_neck(neck) if neck is not None else None self.neck = builder.build_neck(neck) if neck is not None else None
self.rpn_head = builder.build_rpn_head(rpn_head) self.rpn_head = builder.build_head(rpn_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained) self.init_weights(pretrained=pretrained)
......
...@@ -2,9 +2,11 @@ import torch.nn as nn ...@@ -2,9 +2,11 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .. import builder from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2result from mmdet.core import bbox2result
@DETECTORS.register_module
class SingleStageDetector(BaseDetector): class SingleStageDetector(BaseDetector):
def __init__(self, def __init__(self,
...@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector): ...@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector):
self.backbone = builder.build_backbone(backbone) self.backbone = builder.build_backbone(backbone)
if neck is not None: if neck is not None:
self.neck = builder.build_neck(neck) self.neck = builder.build_neck(neck)
self.bbox_head = builder.build_single_stage_head(bbox_head) self.bbox_head = builder.build_head(bbox_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained) self.init_weights(pretrained=pretrained)
...@@ -51,7 +53,7 @@ class SingleStageDetector(BaseDetector): ...@@ -51,7 +53,7 @@ class SingleStageDetector(BaseDetector):
x = self.extract_feat(img) x = self.extract_feat(img)
outs = self.bbox_head(x) outs = self.bbox_head(x)
bbox_inputs = outs + (img_meta, self.test_cfg, rescale) bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
bbox_list = self.bbox_head.get_det_bboxes(*bbox_inputs) bbox_list = self.bbox_head.get_bboxes(*bbox_inputs)
bbox_results = [ bbox_results = [
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
for det_bboxes, det_labels in bbox_list for det_bboxes, det_labels in bbox_list
......
...@@ -7,7 +7,7 @@ class RPNTestMixin(object): ...@@ -7,7 +7,7 @@ class RPNTestMixin(object):
def simple_test_rpn(self, x, img_meta, rpn_test_cfg): def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
rpn_outs = self.rpn_head(x) rpn_outs = self.rpn_head(x)
proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg) proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs) proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
return proposal_list return proposal_list
def aug_test_rpn(self, feats, img_metas, rpn_test_cfg): def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
......
...@@ -4,9 +4,11 @@ import torch.nn as nn ...@@ -4,9 +4,11 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
MaskTestMixin): MaskTestMixin):
...@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
raise NotImplementedError raise NotImplementedError
if rpn_head is not None: if rpn_head is not None:
self.rpn_head = builder.build_rpn_head(rpn_head) self.rpn_head = builder.build_head(rpn_head)
if bbox_head is not None: if bbox_head is not None:
self.bbox_roi_extractor = builder.build_roi_extractor( self.bbox_roi_extractor = builder.build_roi_extractor(
bbox_roi_extractor) bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head) self.bbox_head = builder.build_head(bbox_head)
if mask_head is not None: if mask_head is not None:
self.mask_roi_extractor = builder.build_roi_extractor( self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor) mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head) self.mask_head = builder.build_head(mask_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
...@@ -96,7 +98,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -96,7 +98,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
losses.update(rpn_losses) losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn) proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs) proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
else: else:
proposal_list = proposals proposal_list = proposals
......
...@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util ...@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..registry import HEADS
from ..utils import ConvModule from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target from mmdet.core import mask_cross_entropy, mask_target
@HEADS.register_module
class FCNMaskHead(nn.Module): class FCNMaskHead(nn.Module):
def __init__(self, def __init__(self,
......
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import xavier_init
from ..utils import ConvModule from ..utils import ConvModule
from ..utils import xavier_init from ..registry import NECKS
@NECKS.register_module
class FPN(nn.Module): class FPN(nn.Module):
def __init__(self, def __init__(self,
......
import torch.nn as nn
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def _register_module(self, module_class):
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not issubclass(module_class, nn.Module):
raise TypeError(
'module must be a child of nn.Module, but got {}'.format(
type(module_class)))
module_name = module_class.__name__
if module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class
def register_module(self, cls):
self._register_module(cls)
return cls
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
HEADS = Registry('head')
DETECTORS = Registry('detector')
...@@ -4,8 +4,10 @@ import torch ...@@ -4,8 +4,10 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmdet import ops from mmdet import ops
from ..registry import ROI_EXTRACTORS
@ROI_EXTRACTORS.register_module
class SingleRoIExtractor(nn.Module): class SingleRoIExtractor(nn.Module):
"""Extract RoI features from a single level feature map. """Extract RoI features from a single level feature map.
......
from .rpn_head import RPNHead
__all__ = ['RPNHead']
This diff is collapsed.
from .retina_head import RetinaHead
__all__ = ['RetinaHead']
...@@ -53,7 +53,8 @@ class ConvModule(nn.Module): ...@@ -53,7 +53,8 @@ class ConvModule(nn.Module):
if self.with_norm: if self.with_norm:
norm_channels = out_channels if self.activate_last else in_channels norm_channels = out_channels if self.activate_last else in_channels
self.norm = build_norm_layer(normalize, norm_channels) self.norm_name, norm = build_norm_layer(normalize, norm_channels)
self.add_module(self.norm_name, norm)
if self.with_activatation: if self.with_activatation:
assert activation in ['relu'], 'Only ReLU supported.' assert activation in ['relu'], 'Only ReLU supported.'
...@@ -63,6 +64,10 @@ class ConvModule(nn.Module): ...@@ -63,6 +64,10 @@ class ConvModule(nn.Module):
# Default using msra init # Default using msra init
self.init_weights() self.init_weights()
@property
def norm(self):
return getattr(self, self.norm_name)
def init_weights(self): def init_weights(self):
nonlinearity = 'relu' if self.activation is None else self.activation nonlinearity = 'relu' if self.activation is None else self.activation
kaiming_init(self.conv, nonlinearity=nonlinearity) kaiming_init(self.conv, nonlinearity=nonlinearity)
......
import torch.nn as nn import torch.nn as nn
norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': None}
norm_cfg = {
# format: layer_type: (abbreviation, module)
'BN': ('bn', nn.BatchNorm2d),
'SyncBN': ('bn', None),
'GN': ('gn', nn.GroupNorm),
# and potentially 'SN'
}
def build_norm_layer(cfg, num_features):
def build_norm_layer(cfg, num_features, postfix=''):
""" Build normalization layer
Args:
cfg (dict): cfg should contain:
type (str): identify norm layer type.
layer args: args needed to instantiate a norm layer.
frozen (bool): [optional] whether stop gradient updates
of norm layer, it is helpful to set frozen mode
in backbone's norms.
num_features (int): number of channels from input
postfix (int, str): appended into norm abbreation to
create named layer.
Returns:
name (str): abbreation + postfix
layer (nn.Module): created norm layer
"""
assert isinstance(cfg, dict) and 'type' in cfg assert isinstance(cfg, dict) and 'type' in cfg
cfg_ = cfg.copy() cfg_ = cfg.copy()
cfg_.setdefault('eps', 1e-5)
layer_type = cfg_.pop('type')
layer_type = cfg_.pop('type')
if layer_type not in norm_cfg: if layer_type not in norm_cfg:
raise KeyError('Unrecognized norm type {}'.format(layer_type)) raise KeyError('Unrecognized norm type {}'.format(layer_type))
elif norm_cfg[layer_type] is None: else:
raise NotImplementedError abbr, norm_layer = norm_cfg[layer_type]
if norm_layer is None:
raise NotImplementedError
assert isinstance(postfix, (int, str))
name = abbr + str(postfix)
frozen = cfg_.pop('frozen', False)
cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN':
layer = norm_layer(num_features, **cfg_)
else:
assert 'num_groups' in cfg_
layer = norm_layer(num_channels=num_features, **cfg_)
if frozen:
for param in layer.parameters():
param.requires_grad = False
return norm_cfg[layer_type](num_features, **cfg_) return name, layer
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment