Commit db14d74b authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'master' of github.com:open-mmlab/mmdetection into dcn_cpp_extension

parents c1e0884f b7aa30c2
from .two_stage import TwoStageDetector from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FasterRCNN(TwoStageDetector): class FasterRCNN(TwoStageDetector):
def __init__(self, def __init__(self,
......
from .two_stage import TwoStageDetector from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class MaskRCNN(TwoStageDetector): class MaskRCNN(TwoStageDetector):
def __init__(self, def __init__(self,
......
from .single_stage import SingleStageDetector from .single_stage import SingleStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class RetinaNet(SingleStageDetector): class RetinaNet(SingleStageDetector):
def __init__(self, def __init__(self,
......
...@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping ...@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin from .test_mixins import RPNTestMixin
from .. import builder from .. import builder
from ..registry import DETECTORS
@DETECTORS.register_module
class RPN(BaseDetector, RPNTestMixin): class RPN(BaseDetector, RPNTestMixin):
def __init__(self, def __init__(self,
...@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin): ...@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin):
super(RPN, self).__init__() super(RPN, self).__init__()
self.backbone = builder.build_backbone(backbone) self.backbone = builder.build_backbone(backbone)
self.neck = builder.build_neck(neck) if neck is not None else None self.neck = builder.build_neck(neck) if neck is not None else None
self.rpn_head = builder.build_anchor_head(rpn_head) self.rpn_head = builder.build_head(rpn_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained) self.init_weights(pretrained=pretrained)
......
...@@ -2,9 +2,11 @@ import torch.nn as nn ...@@ -2,9 +2,11 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .. import builder from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2result from mmdet.core import bbox2result
@DETECTORS.register_module
class SingleStageDetector(BaseDetector): class SingleStageDetector(BaseDetector):
def __init__(self, def __init__(self,
...@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector): ...@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector):
self.backbone = builder.build_backbone(backbone) self.backbone = builder.build_backbone(backbone)
if neck is not None: if neck is not None:
self.neck = builder.build_neck(neck) self.neck = builder.build_neck(neck)
self.bbox_head = builder.build_anchor_head(bbox_head) self.bbox_head = builder.build_head(bbox_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained) self.init_weights(pretrained=pretrained)
......
...@@ -4,9 +4,11 @@ import torch.nn as nn ...@@ -4,9 +4,11 @@ import torch.nn as nn
from .base import BaseDetector from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
MaskTestMixin): MaskTestMixin):
...@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin, ...@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
raise NotImplementedError raise NotImplementedError
if rpn_head is not None: if rpn_head is not None:
self.rpn_head = builder.build_anchor_head(rpn_head) self.rpn_head = builder.build_head(rpn_head)
if bbox_head is not None: if bbox_head is not None:
self.bbox_roi_extractor = builder.build_roi_extractor( self.bbox_roi_extractor = builder.build_roi_extractor(
bbox_roi_extractor) bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head) self.bbox_head = builder.build_head(bbox_head)
if mask_head is not None: if mask_head is not None:
self.mask_roi_extractor = builder.build_roi_extractor( self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor) mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head) self.mask_head = builder.build_head(mask_head)
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
......
...@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util ...@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..registry import HEADS
from ..utils import ConvModule from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target from mmdet.core import mask_cross_entropy, mask_target
@HEADS.register_module
class FCNMaskHead(nn.Module): class FCNMaskHead(nn.Module):
def __init__(self, def __init__(self,
......
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..utils import ConvModule
from mmcv.cnn import xavier_init from mmcv.cnn import xavier_init
from ..utils import ConvModule
from ..registry import NECKS
@NECKS.register_module
class FPN(nn.Module): class FPN(nn.Module):
def __init__(self, def __init__(self,
......
import torch.nn as nn
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def _register_module(self, module_class):
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not issubclass(module_class, nn.Module):
raise TypeError(
'module must be a child of nn.Module, but got {}'.format(
type(module_class)))
module_name = module_class.__name__
if module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class
def register_module(self, cls):
self._register_module(cls)
return cls
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
HEADS = Registry('head')
DETECTORS = Registry('detector')
...@@ -4,8 +4,10 @@ import torch ...@@ -4,8 +4,10 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmdet import ops from mmdet import ops
from ..registry import ROI_EXTRACTORS
@ROI_EXTRACTORS.register_module
class SingleRoIExtractor(nn.Module): class SingleRoIExtractor(nn.Module):
"""Extract RoI features from a single level feature map. """Extract RoI features from a single level feature map.
......
...@@ -5,14 +5,14 @@ from setuptools import find_packages, setup ...@@ -5,14 +5,14 @@ from setuptools import find_packages, setup
def readme(): def readme():
with open('README.md') as f: with open('README.md', encoding='utf-8') as f:
content = f.read() content = f.read()
return content return content
MAJOR = 0 MAJOR = 0
MINOR = 5 MINOR = 5
PATCH = 5 PATCH = 6
SUFFIX = '' SUFFIX = ''
SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX) SHORT_VERSION = '{}.{}.{}{}'.format(MAJOR, MINOR, PATCH, SUFFIX)
...@@ -89,7 +89,7 @@ if __name__ == '__main__': ...@@ -89,7 +89,7 @@ if __name__ == '__main__':
long_description=readme(), long_description=readme(),
keywords='computer vision, object detection', keywords='computer vision, object detection',
url='https://github.com/open-mmlab/mmdetection', url='https://github.com/open-mmlab/mmdetection',
packages=find_packages(exclude=('configs', 'tools', 'demo',)), packages=find_packages(exclude=('configs', 'tools', 'demo')),
package_data={'mmdet.ops': ['*/*.so']}, package_data={'mmdet.ops': ['*/*.so']},
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', 'Development Status :: 4 - Beta',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment