Commit 441015ea authored by Kai Chen

Merge branch 'master' into pytorch-1.0

parents 2017c81e 3b6ae96d
import torch.nn as nn

from .bbox_head import BBoxHead
+from ..registry import HEADS
from ..utils import ConvModule


+@HEADS.register_module
class ConvFCBBoxHead(BBoxHead):
    """More general bbox head, with shared conv and fc layers and two optional
    separated branches.
@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead):
        return cls_score, bbox_pred


+@HEADS.register_module
class SharedFCBBoxHead(ConvFCBBoxHead):

    def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs):
...

-from mmcv.runner import obj_from_dict
+import mmcv
from torch import nn

-from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
-               mask_heads, single_stage_heads)
-
-__all__ = [
-    'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
-    'build_bbox_head', 'build_mask_head', 'build_single_stage_head',
-    'build_detector'
-]
+from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS


-def _build_module(cfg, parrent=None, default_args=None):
-    return cfg if isinstance(cfg, nn.Module) else obj_from_dict(
-        cfg, parrent, default_args)
+def _build_module(cfg, registry, default_args):
+    assert isinstance(cfg, dict) and 'type' in cfg
+    assert isinstance(default_args, dict) or default_args is None
+    args = cfg.copy()
+    obj_type = args.pop('type')
+    if mmcv.is_str(obj_type):
+        if obj_type not in registry.module_dict:
+            raise KeyError('{} is not in the {} registry'.format(
+                obj_type, registry.name))
+        obj_type = registry.module_dict[obj_type]
+    elif not isinstance(obj_type, type):
+        raise TypeError('type must be a str or valid type, but got {}'.format(
+            type(obj_type)))
+    if default_args is not None:
+        for name, value in default_args.items():
+            args.setdefault(name, value)
+    return obj_type(**args)


-def build(cfg, parrent=None, default_args=None):
+def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
-        modules = [_build_module(cfg_, parrent, default_args) for cfg_ in cfg]
+        modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
        return nn.Sequential(*modules)
    else:
-        return _build_module(cfg, parrent, default_args)
+        return _build_module(cfg, registry, default_args)


def build_backbone(cfg):
-    return build(cfg, backbones)
+    return build(cfg, BACKBONES)


def build_neck(cfg):
-    return build(cfg, necks)
+    return build(cfg, NECKS)


-def build_rpn_head(cfg):
-    return build(cfg, rpn_heads)


def build_roi_extractor(cfg):
-    return build(cfg, roi_extractors)
+    return build(cfg, ROI_EXTRACTORS)


-def build_bbox_head(cfg):
-    return build(cfg, bbox_heads)
-
-
-def build_mask_head(cfg):
-    return build(cfg, mask_heads)
-
-
-def build_single_stage_head(cfg):
-    return build(cfg, single_stage_heads)
+def build_head(cfg):
+    return build(cfg, HEADS)


def build_detector(cfg, train_cfg=None, test_cfg=None):
-    from . import detectors
-    return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
+    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

@@ -6,10 +6,12 @@ import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin
from .. import builder
+from ..registry import DETECTORS
from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply,
                        merge_aug_masks)


+@DETECTORS.register_module
class CascadeRCNN(BaseDetector, RPNTestMixin):

    def __init__(self,
@@ -37,7 +39,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
            raise NotImplementedError

        if rpn_head is not None:
-            self.rpn_head = builder.build_rpn_head(rpn_head)
+            self.rpn_head = builder.build_head(rpn_head)

        if bbox_head is not None:
            self.bbox_roi_extractor = nn.ModuleList()
@@ -52,7 +54,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
            for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
                self.bbox_roi_extractor.append(
                    builder.build_roi_extractor(roi_extractor))
-                self.bbox_head.append(builder.build_bbox_head(head))
+                self.bbox_head.append(builder.build_head(head))

        if mask_head is not None:
            self.mask_roi_extractor = nn.ModuleList()
@@ -67,7 +69,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
            for roi_extractor, head in zip(mask_roi_extractor, mask_head):
                self.mask_roi_extractor.append(
                    builder.build_roi_extractor(roi_extractor))
-                self.mask_head.append(builder.build_mask_head(head))
+                self.mask_head.append(builder.build_head(head))

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
@@ -123,7 +125,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
            losses.update(rpn_losses)

            proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
-            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals
...

from .two_stage import TwoStageDetector
+from ..registry import DETECTORS


+@DETECTORS.register_module
class FastRCNN(TwoStageDetector):

    def __init__(self,
...

from .two_stage import TwoStageDetector
+from ..registry import DETECTORS


+@DETECTORS.register_module
class FasterRCNN(TwoStageDetector):

    def __init__(self,
@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector):
                 test_cfg,
                 pretrained=None):
        super(FasterRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)

from .two_stage import TwoStageDetector
+from ..registry import DETECTORS


+@DETECTORS.register_module
class MaskRCNN(TwoStageDetector):

    def __init__(self,
...

from .single_stage import SingleStageDetector
+from ..registry import DETECTORS


+@DETECTORS.register_module
class RetinaNet(SingleStageDetector):

    def __init__(self,
...

@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping
from .base import BaseDetector
from .test_mixins import RPNTestMixin
from .. import builder
+from ..registry import DETECTORS


+@DETECTORS.register_module
class RPN(BaseDetector, RPNTestMixin):

    def __init__(self,
@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin):
        super(RPN, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        self.neck = builder.build_neck(neck) if neck is not None else None
-        self.rpn_head = builder.build_rpn_head(rpn_head)
+        self.rpn_head = builder.build_head(rpn_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)
...

@@ -2,9 +2,11 @@ import torch.nn as nn
from .base import BaseDetector
from .. import builder
+from ..registry import DETECTORS
from mmdet.core import bbox2result


+@DETECTORS.register_module
class SingleStageDetector(BaseDetector):

    def __init__(self,
@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector):
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
-        self.bbox_head = builder.build_single_stage_head(bbox_head)
+        self.bbox_head = builder.build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)
@@ -51,7 +53,7 @@ class SingleStageDetector(BaseDetector):
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
-        bbox_list = self.bbox_head.get_det_bboxes(*bbox_inputs)
+        bbox_list = self.bbox_head.get_bboxes(*bbox_inputs)
        bbox_results = [
            bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
            for det_bboxes, det_labels in bbox_list
...

@@ -7,7 +7,7 @@ class RPNTestMixin(object):
    def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
        rpn_outs = self.rpn_head(x)
        proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
-        proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+        proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        return proposal_list

    def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
...

@@ -4,9 +4,11 @@ import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
+from ..registry import DETECTORS
from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler


+@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
                       MaskTestMixin):
@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
            raise NotImplementedError

        if rpn_head is not None:
-            self.rpn_head = builder.build_rpn_head(rpn_head)
+            self.rpn_head = builder.build_head(rpn_head)

        if bbox_head is not None:
            self.bbox_roi_extractor = builder.build_roi_extractor(
                bbox_roi_extractor)
-            self.bbox_head = builder.build_bbox_head(bbox_head)
+            self.bbox_head = builder.build_head(bbox_head)

        if mask_head is not None:
            self.mask_roi_extractor = builder.build_roi_extractor(
                mask_roi_extractor)
-            self.mask_head = builder.build_mask_head(mask_head)
+            self.mask_head = builder.build_head(mask_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
@@ -96,7 +98,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
            losses.update(rpn_losses)

            proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
-            proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
+            proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
        else:
            proposal_list = proposals
...

@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util
import torch
import torch.nn as nn

+from ..registry import HEADS
from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target


+@HEADS.register_module
class FCNMaskHead(nn.Module):

    def __init__(self,
...

import torch.nn as nn
import torch.nn.functional as F
+from mmcv.cnn import xavier_init

from ..utils import ConvModule
-from ..utils import xavier_init
+from ..registry import NECKS


+@NECKS.register_module
class FPN(nn.Module):

    def __init__(self,
...

import torch.nn as nn


class Registry(object):

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def _register_module(self, module_class):
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not issubclass(module_class, nn.Module):
            raise TypeError(
                'module must be a child of nn.Module, but got {}'.format(
                    type(module_class)))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        self._register_module(cls)
        return cls


BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
HEADS = Registry('head')
DETECTORS = Registry('detector')

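A short usage sketch of the registry itself, assuming the `mmdet.models.registry` module from this commit is importable; the `TOY_HEADS` registry and `ToyHead` class are throwaway names for illustration, not part of the diff:

import torch.nn as nn
from mmdet.models.registry import Registry  # import path assumed from this commit's layout

TOY_HEADS = Registry('toy_head')  # a throwaway registry just for this sketch

@TOY_HEADS.register_module
class ToyHead(nn.Module):
    def __init__(self, num_classes=3):
        super(ToyHead, self).__init__()
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(x)

# The decorator stores the class under its own __name__ and returns it unchanged.
print(TOY_HEADS.name)         # 'toy_head'
print(TOY_HEADS.module_dict)  # {'ToyHead': <class 'ToyHead'>}
head = TOY_HEADS.module_dict['ToyHead'](num_classes=5)

Registering a class that is not an nn.Module subclass, or registering the same name twice, raises TypeError or KeyError respectively.
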
@@ -4,8 +4,10 @@ import torch
import torch.nn as nn

from mmdet import ops
+from ..registry import ROI_EXTRACTORS


+@ROI_EXTRACTORS.register_module
class SingleRoIExtractor(nn.Module):
    """Extract RoI features from a single level feature map.
...

from .rpn_head import RPNHead

__all__ = ['RPNHead']

from __future__ import division

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
                        multi_apply, weighted_cross_entropy, weighted_smoothl1,
                        weighted_binary_cross_entropy)
from mmdet.ops import nms
from ..utils import normal_init


class RPNHead(nn.Module):
    """Network head of RPN.

                                  / - rpn_cls (1x1 conv)
    input - rpn_conv (3x3 conv) -
                                  \ - rpn_reg (1x1 conv)

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels for the RPN feature map.
        anchor_scales (Iterable): Anchor scales.
        anchor_ratios (Iterable): Anchor aspect ratios.
        anchor_strides (Iterable): Anchor strides.
        anchor_base_sizes (Iterable): Anchor base sizes.
        target_means (Iterable): Mean values of regression targets.
        target_stds (Iterable): Std values of regression targets.
        use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
            (softmax by default)
    """ # noqa: W605

    def __init__(self,
                 in_channels,
                 feat_channels=256,
                 anchor_scales=[8, 16, 32],
                 anchor_ratios=[0.5, 1.0, 2.0],
                 anchor_strides=[4, 8, 16, 32, 64],
                 anchor_base_sizes=None,
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0),
                 use_sigmoid_cls=False):
        super(RPNHead, self).__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.anchor_scales = anchor_scales
        self.anchor_ratios = anchor_ratios
        self.anchor_strides = anchor_strides
        self.anchor_base_sizes = list(
            anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
        self.target_means = target_means
        self.target_stds = target_stds
        self.use_sigmoid_cls = use_sigmoid_cls

        self.anchor_generators = []
        for anchor_base in self.anchor_base_sizes:
            self.anchor_generators.append(
                AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
        self.rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
        out_channels = (self.num_anchors
                        if self.use_sigmoid_cls else self.num_anchors * 2)
        self.rpn_cls = nn.Conv2d(feat_channels, out_channels, 1)
        self.rpn_reg = nn.Conv2d(feat_channels, self.num_anchors * 4, 1)
        self.debug_imgs = None

    def init_weights(self):
        normal_init(self.rpn_conv, std=0.01)
        normal_init(self.rpn_cls, std=0.01)
        normal_init(self.rpn_reg, std=0.01)

    def forward_single(self, x):
        rpn_feat = self.relu(self.rpn_conv(x))
        rpn_cls_score = self.rpn_cls(rpn_feat)
        rpn_bbox_pred = self.rpn_reg(rpn_feat)
        return rpn_cls_score, rpn_bbox_pred

    def forward(self, feats):
        return multi_apply(self.forward_single, feats)

    def get_anchors(self, featmap_sizes, img_metas):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.

        Returns:
            tuple: anchors of each image, valid flags of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # anchors for one time
        multi_level_anchors = []
        for i in range(num_levels):
            anchors = self.anchor_generators[i].grid_anchors(
                featmap_sizes[i], self.anchor_strides[i])
            multi_level_anchors.append(anchors)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level anchors
        valid_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = []
            for i in range(num_levels):
                anchor_stride = self.anchor_strides[i]
                feat_h, feat_w = featmap_sizes[i]
                h, w, _ = img_meta['pad_shape']
                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                flags = self.anchor_generators[i].valid_flags(
                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                multi_level_flags.append(flags)
            valid_flag_list.append(multi_level_flags)

        return anchor_list, valid_flag_list

    def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples, cfg):
        # classification loss
        labels = labels.contiguous().view(-1)
        label_weights = label_weights.contiguous().view(-1)
        if self.use_sigmoid_cls:
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                  1).contiguous().view(-1)
            criterion = weighted_binary_cross_entropy
        else:
            rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
                                                  1).contiguous().view(-1, 2)
            criterion = weighted_cross_entropy
        loss_cls = criterion(
            rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.contiguous().view(-1, 4)
        bbox_weights = bbox_weights.contiguous().view(-1, 4)
        rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
            -1, 4)
        loss_reg = weighted_smoothl1(
            rpn_bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls, loss_reg

    def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
        featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_shapes)
        cls_reg_targets = anchor_target(
            anchor_list, valid_flag_list, gt_bboxes, img_shapes,
            self.target_means, self.target_stds, cfg)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        losses_cls, losses_reg = multi_apply(
            self.loss_single,
            rpn_cls_scores,
            rpn_bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_pos + num_total_neg,
            cfg=cfg)
        return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)

    def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
        num_imgs = len(img_meta)
        featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
        mlvl_anchors = [
            self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
                                                     self.anchor_strides[idx])
            for idx in range(len(featmap_sizes))
        ]
        proposal_list = []
        for img_id in range(num_imgs):
            rpn_cls_score_list = [
                rpn_cls_scores[idx][img_id].detach()
                for idx in range(len(rpn_cls_scores))
            ]
            rpn_bbox_pred_list = [
                rpn_bbox_preds[idx][img_id].detach()
                for idx in range(len(rpn_bbox_preds))
            ]
            assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
            proposals = self._get_proposals_single(
                rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
                img_meta[img_id]['img_shape'], cfg)
            proposal_list.append(proposals)
        return proposal_list

    def _get_proposals_single(self, rpn_cls_scores, rpn_bbox_preds,
                              mlvl_anchors, img_shape, cfg):
        mlvl_proposals = []
        for idx in range(len(rpn_cls_scores)):
            rpn_cls_score = rpn_cls_scores[idx]
            rpn_bbox_pred = rpn_bbox_preds[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            anchors = mlvl_anchors[idx]
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.permute(1, 2,
                                                      0).contiguous().view(-1)
                rpn_cls_prob = rpn_cls_score.sigmoid()
                scores = rpn_cls_prob
            else:
                rpn_cls_score = rpn_cls_score.permute(1, 2,
                                                      0).contiguous().view(
                                                          -1, 2)
                rpn_cls_prob = F.softmax(rpn_cls_score, dim=1)
                scores = rpn_cls_prob[:, 1]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).contiguous().view(
                -1, 4)
            _, order = scores.sort(0, descending=True)
            if cfg.nms_pre > 0:
                order = order[:cfg.nms_pre]
                rpn_bbox_pred = rpn_bbox_pred[order, :]
                anchors = anchors[order, :]
                scores = scores[order]
            proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
                                   self.target_stds, img_shape)
            w = proposals[:, 2] - proposals[:, 0] + 1
            h = proposals[:, 3] - proposals[:, 1] + 1
            valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
                                       (h >= cfg.min_bbox_size)).squeeze()
            proposals = proposals[valid_inds, :]
            scores = scores[valid_inds]
            proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.nms_across_levels:
            proposals, _ = nms(proposals, cfg.nms_thr)
            proposals = proposals[:cfg.max_num, :]
        else:
            scores = proposals[:, 4]
            _, order = scores.sort(0, descending=True)
            num = min(cfg.max_num, proposals.shape[0])
            order = order[:num]
            proposals = proposals[order, :]
        return proposals

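For reference, a self-contained shape check that mirrors the `forward_single` path above: with the default 3 scales and 3 ratios the head predicts 9 anchors per location, so `rpn_cls` outputs 18 channels under the softmax setting and `rpn_reg` outputs 36. The layer names below are local stand-ins for illustration, not the module itself:

import torch
import torch.nn as nn

# Standalone sketch of the RPN head's per-level convolutions.
in_channels, feat_channels = 256, 256
num_anchors = 3 * 3  # len(anchor_scales) * len(anchor_ratios) with the defaults

rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
rpn_cls = nn.Conv2d(feat_channels, num_anchors * 2, 1)  # 2-way softmax score per anchor
rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, 1)  # 4 box deltas per anchor

x = torch.randn(1, in_channels, 50, 76)   # one feature level of a padded input
feat = torch.relu(rpn_conv(x))
print(rpn_cls(feat).shape)                # torch.Size([1, 18, 50, 76])
print(rpn_reg(feat).shape)                # torch.Size([1, 36, 50, 76])
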
from .retina_head import RetinaHead

__all__ = ['RetinaHead']

@@ -53,7 +53,8 @@ class ConvModule(nn.Module):

        if self.with_norm:
            norm_channels = out_channels if self.activate_last else in_channels
-            self.norm = build_norm_layer(normalize, norm_channels)
+            self.norm_name, norm = build_norm_layer(normalize, norm_channels)
+            self.add_module(self.norm_name, norm)

        if self.with_activatation:
            assert activation in ['relu'], 'Only ReLU supported.'
@@ -63,6 +64,10 @@ class ConvModule(nn.Module):
        # Default using msra init
        self.init_weights()

+    @property
+    def norm(self):
+        return getattr(self, self.norm_name)
+
    def init_weights(self):
        nonlinearity = 'relu' if self.activation is None else self.activation
        kaiming_init(self.conv, nonlinearity=nonlinearity)
...

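The ConvModule change relies on `build_norm_layer` now returning a `(name, layer)` pair: the layer is registered under that name with `add_module`, and a read-only `norm` property keeps the old `self.norm` access working. A minimal sketch of that pattern with a hypothetical holder module (names are illustrative only):

import torch.nn as nn

class NormHolder(nn.Module):
    def __init__(self, num_features):
        super(NormHolder, self).__init__()
        self.norm_name = 'bn1'  # e.g. abbreviation + postfix from build_norm_layer
        # Register the layer under a dynamic name so state_dict keys carry it.
        self.add_module(self.norm_name, nn.BatchNorm2d(num_features))

    @property
    def norm(self):
        # Resolve the registered submodule by its stored name.
        return getattr(self, self.norm_name)

m = NormHolder(16)
print(m.norm)                            # BatchNorm2d(16, eps=1e-05, ...)
print(dict(m.named_modules()).keys())    # includes 'bn1'
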
import torch.nn as nn

-norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': None}
+norm_cfg = {
+    # format: layer_type: (abbreviation, module)
+    'BN': ('bn', nn.BatchNorm2d),
+    'SyncBN': ('bn', None),
+    'GN': ('gn', nn.GroupNorm),
+    # and potentially 'SN'
+}


-def build_norm_layer(cfg, num_features):
+def build_norm_layer(cfg, num_features, postfix=''):
+    """ Build normalization layer
+
+    Args:
+        cfg (dict): cfg should contain:
+            type (str): identify norm layer type.
+            layer args: args needed to instantiate a norm layer.
+            frozen (bool): [optional] whether stop gradient updates
+                of norm layer, it is helpful to set frozen mode
+                in backbone's norms.
+        num_features (int): number of channels from input
+        postfix (int, str): appended into norm abbreation to
+            create named layer.
+
+    Returns:
+        name (str): abbreation + postfix
+        layer (nn.Module): created norm layer
+    """
    assert isinstance(cfg, dict) and 'type' in cfg
    cfg_ = cfg.copy()
-    cfg_.setdefault('eps', 1e-5)
+
    layer_type = cfg_.pop('type')
    if layer_type not in norm_cfg:
        raise KeyError('Unrecognized norm type {}'.format(layer_type))
-    elif norm_cfg[layer_type] is None:
-        raise NotImplementedError
+    else:
+        abbr, norm_layer = norm_cfg[layer_type]
+        if norm_layer is None:
+            raise NotImplementedError
+
+    assert isinstance(postfix, (int, str))
+    name = abbr + str(postfix)
+
+    frozen = cfg_.pop('frozen', False)
+    cfg_.setdefault('eps', 1e-5)

-    return norm_cfg[layer_type](num_features, **cfg_)
+    if layer_type != 'GN':
+        layer = norm_layer(num_features, **cfg_)
+    else:
+        assert 'num_groups' in cfg_
+        layer = norm_layer(num_channels=num_features, **cfg_)
+
+    if frozen:
+        for param in layer.parameters():
+            param.requires_grad = False
+
+    return name, layer

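A usage sketch of the new interface; the import path is an assumption based on this commit's layout (mmdet/models/utils/norm.py), and the configs below are illustrative:

from mmdet.models.utils.norm import build_norm_layer  # path assumed

# A frozen BatchNorm with a numeric postfix: name becomes 'bn1'.
name, layer = build_norm_layer(dict(type='BN', frozen=True), 64, postfix=1)
print(name)    # 'bn1'
print(layer)   # BatchNorm2d(64, eps=1e-05, ...) with requires_grad=False params

# GroupNorm takes num_channels instead of num_features internally.
gn_name, gn = build_norm_layer(dict(type='GN', num_groups=32), 64, postfix='_head')
print(gn_name)  # 'gn_head'
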