"vscode:/vscode.git/clone" did not exist on "949b6c01e074c6f7712d7da37079218b3192b102"
Commit 441015ea authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'master' into pytorch-1.0

parents 2017c81e 3b6ae96d
import torch.nn as nn
from .bbox_head import BBoxHead
from ..registry import HEADS
from ..utils import ConvModule
@HEADS.register_module
class ConvFCBBoxHead(BBoxHead):
"""More general bbox head, with shared conv and fc layers and two optional
separated branches.
......@@ -165,6 +167,7 @@ class ConvFCBBoxHead(BBoxHead):
return cls_score, bbox_pred
@HEADS.register_module
class SharedFCBBoxHead(ConvFCBBoxHead):
def __init__(self, num_fcs=2, fc_out_channels=1024, *args, **kwargs):
......
from mmcv.runner import obj_from_dict
import mmcv
from torch import nn
from . import (backbones, necks, roi_extractors, rpn_heads, bbox_heads,
mask_heads, single_stage_heads)
__all__ = [
'build_backbone', 'build_neck', 'build_rpn_head', 'build_roi_extractor',
'build_bbox_head', 'build_mask_head', 'build_single_stage_head',
'build_detector'
]
def _build_module(cfg, parrent=None, default_args=None):
return cfg if isinstance(cfg, nn.Module) else obj_from_dict(
cfg, parrent, default_args)
def build(cfg, parrent=None, default_args=None):
from .registry import BACKBONES, NECKS, ROI_EXTRACTORS, HEADS, DETECTORS
def _build_module(cfg, registry, default_args):
assert isinstance(cfg, dict) and 'type' in cfg
assert isinstance(default_args, dict) or default_args is None
args = cfg.copy()
obj_type = args.pop('type')
if mmcv.is_str(obj_type):
if obj_type not in registry.module_dict:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
obj_type = registry.module_dict[obj_type]
elif not isinstance(obj_type, type):
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_type(**args)
def build(cfg, registry, default_args=None):
if isinstance(cfg, list):
modules = [_build_module(cfg_, parrent, default_args) for cfg_ in cfg]
modules = [_build_module(cfg_, registry, default_args) for cfg_ in cfg]
return nn.Sequential(*modules)
else:
return _build_module(cfg, parrent, default_args)
return _build_module(cfg, registry, default_args)
def build_backbone(cfg):
return build(cfg, backbones)
return build(cfg, BACKBONES)
def build_neck(cfg):
return build(cfg, necks)
def build_rpn_head(cfg):
return build(cfg, rpn_heads)
return build(cfg, NECKS)
def build_roi_extractor(cfg):
return build(cfg, roi_extractors)
def build_bbox_head(cfg):
return build(cfg, bbox_heads)
def build_mask_head(cfg):
return build(cfg, mask_heads)
return build(cfg, ROI_EXTRACTORS)
def build_single_stage_head(cfg):
return build(cfg, single_stage_heads)
def build_head(cfg):
return build(cfg, HEADS)
def build_detector(cfg, train_cfg=None, test_cfg=None):
from . import detectors
return build(cfg, detectors, dict(train_cfg=train_cfg, test_cfg=test_cfg))
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
......@@ -6,10 +6,12 @@ import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin
from .. import builder
from ..registry import DETECTORS
from mmdet.core import (assign_and_sample, bbox2roi, bbox2result, multi_apply,
merge_aug_masks)
@DETECTORS.register_module
class CascadeRCNN(BaseDetector, RPNTestMixin):
def __init__(self,
......@@ -37,7 +39,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
raise NotImplementedError
if rpn_head is not None:
self.rpn_head = builder.build_rpn_head(rpn_head)
self.rpn_head = builder.build_head(rpn_head)
if bbox_head is not None:
self.bbox_roi_extractor = nn.ModuleList()
......@@ -52,7 +54,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
self.bbox_roi_extractor.append(
builder.build_roi_extractor(roi_extractor))
self.bbox_head.append(builder.build_bbox_head(head))
self.bbox_head.append(builder.build_head(head))
if mask_head is not None:
self.mask_roi_extractor = nn.ModuleList()
......@@ -67,7 +69,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
for roi_extractor, head in zip(mask_roi_extractor, mask_head):
self.mask_roi_extractor.append(
builder.build_roi_extractor(roi_extractor))
self.mask_head.append(builder.build_mask_head(head))
self.mask_head.append(builder.build_head(head))
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......@@ -123,7 +125,7 @@ class CascadeRCNN(BaseDetector, RPNTestMixin):
losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
else:
proposal_list = proposals
......
from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FastRCNN(TwoStageDetector):
def __init__(self,
......
from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class FasterRCNN(TwoStageDetector):
def __init__(self,
......@@ -13,11 +15,11 @@ class FasterRCNN(TwoStageDetector):
test_cfg,
pretrained=None):
super(FasterRCNN, self).__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained)
from .two_stage import TwoStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class MaskRCNN(TwoStageDetector):
def __init__(self,
......
from .single_stage import SingleStageDetector
from ..registry import DETECTORS
@DETECTORS.register_module
class RetinaNet(SingleStageDetector):
def __init__(self,
......
......@@ -4,8 +4,10 @@ from mmdet.core import tensor2imgs, bbox_mapping
from .base import BaseDetector
from .test_mixins import RPNTestMixin
from .. import builder
from ..registry import DETECTORS
@DETECTORS.register_module
class RPN(BaseDetector, RPNTestMixin):
def __init__(self,
......@@ -18,7 +20,7 @@ class RPN(BaseDetector, RPNTestMixin):
super(RPN, self).__init__()
self.backbone = builder.build_backbone(backbone)
self.neck = builder.build_neck(neck) if neck is not None else None
self.rpn_head = builder.build_rpn_head(rpn_head)
self.rpn_head = builder.build_head(rpn_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
......
......@@ -2,9 +2,11 @@ import torch.nn as nn
from .base import BaseDetector
from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2result
@DETECTORS.register_module
class SingleStageDetector(BaseDetector):
def __init__(self,
......@@ -18,7 +20,7 @@ class SingleStageDetector(BaseDetector):
self.backbone = builder.build_backbone(backbone)
if neck is not None:
self.neck = builder.build_neck(neck)
self.bbox_head = builder.build_single_stage_head(bbox_head)
self.bbox_head = builder.build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
......@@ -51,7 +53,7 @@ class SingleStageDetector(BaseDetector):
x = self.extract_feat(img)
outs = self.bbox_head(x)
bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
bbox_list = self.bbox_head.get_det_bboxes(*bbox_inputs)
bbox_list = self.bbox_head.get_bboxes(*bbox_inputs)
bbox_results = [
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
for det_bboxes, det_labels in bbox_list
......
......@@ -7,7 +7,7 @@ class RPNTestMixin(object):
def simple_test_rpn(self, x, img_meta, rpn_test_cfg):
rpn_outs = self.rpn_head(x)
proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
return proposal_list
def aug_test_rpn(self, feats, img_metas, rpn_test_cfg):
......
......@@ -4,9 +4,11 @@ import torch.nn as nn
from .base import BaseDetector
from .test_mixins import RPNTestMixin, BBoxTestMixin, MaskTestMixin
from .. import builder
from ..registry import DETECTORS
from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler
@DETECTORS.register_module
class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
MaskTestMixin):
......@@ -30,17 +32,17 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
raise NotImplementedError
if rpn_head is not None:
self.rpn_head = builder.build_rpn_head(rpn_head)
self.rpn_head = builder.build_head(rpn_head)
if bbox_head is not None:
self.bbox_roi_extractor = builder.build_roi_extractor(
bbox_roi_extractor)
self.bbox_head = builder.build_bbox_head(bbox_head)
self.bbox_head = builder.build_head(bbox_head)
if mask_head is not None:
self.mask_roi_extractor = builder.build_roi_extractor(
mask_roi_extractor)
self.mask_head = builder.build_mask_head(mask_head)
self.mask_head = builder.build_head(mask_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
......@@ -96,7 +98,7 @@ class TwoStageDetector(BaseDetector, RPNTestMixin, BBoxTestMixin,
losses.update(rpn_losses)
proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
proposal_list = self.rpn_head.get_proposals(*proposal_inputs)
proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
else:
proposal_list = proposals
......
......@@ -4,10 +4,12 @@ import pycocotools.mask as mask_util
import torch
import torch.nn as nn
from ..registry import HEADS
from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target
@HEADS.register_module
class FCNMaskHead(nn.Module):
def __init__(self,
......
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init
from ..utils import ConvModule
from ..utils import xavier_init
from ..registry import NECKS
@NECKS.register_module
class FPN(nn.Module):
def __init__(self,
......
import torch.nn as nn
class Registry(object):
def __init__(self, name):
self._name = name
self._module_dict = dict()
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def _register_module(self, module_class):
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not issubclass(module_class, nn.Module):
raise TypeError(
'module must be a child of nn.Module, but got {}'.format(
type(module_class)))
module_name = module_class.__name__
if module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class
def register_module(self, cls):
self._register_module(cls)
return cls
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
HEADS = Registry('head')
DETECTORS = Registry('detector')
......@@ -4,8 +4,10 @@ import torch
import torch.nn as nn
from mmdet import ops
from ..registry import ROI_EXTRACTORS
@ROI_EXTRACTORS.register_module
class SingleRoIExtractor(nn.Module):
"""Extract RoI features from a single level feature map.
......
from .rpn_head import RPNHead
__all__ = ['RPNHead']
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
multi_apply, weighted_cross_entropy, weighted_smoothl1,
weighted_binary_cross_entropy)
from mmdet.ops import nms
from ..utils import normal_init
class RPNHead(nn.Module):
"""Network head of RPN.
/ - rpn_cls (1x1 conv)
input - rpn_conv (3x3 conv) -
\ - rpn_reg (1x1 conv)
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels for the RPN feature map.
anchor_scales (Iterable): Anchor scales.
anchor_ratios (Iterable): Anchor aspect ratios.
anchor_strides (Iterable): Anchor strides.
anchor_base_sizes (Iterable): Anchor base sizes.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
use_sigmoid_cls (bool): Whether to use sigmoid loss for classification.
(softmax by default)
""" # noqa: W605
def __init__(self,
in_channels,
feat_channels=256,
anchor_scales=[8, 16, 32],
anchor_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
anchor_base_sizes=None,
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
use_sigmoid_cls=False):
super(RPNHead, self).__init__()
self.in_channels = in_channels
self.feat_channels = feat_channels
self.anchor_scales = anchor_scales
self.anchor_ratios = anchor_ratios
self.anchor_strides = anchor_strides
self.anchor_base_sizes = list(
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = use_sigmoid_cls
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
AnchorGenerator(anchor_base, anchor_scales, anchor_ratios))
self.rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
self.num_anchors = len(self.anchor_ratios) * len(self.anchor_scales)
out_channels = (self.num_anchors
if self.use_sigmoid_cls else self.num_anchors * 2)
self.rpn_cls = nn.Conv2d(feat_channels, out_channels, 1)
self.rpn_reg = nn.Conv2d(feat_channels, self.num_anchors * 4, 1)
self.debug_imgs = None
def init_weights(self):
normal_init(self.rpn_conv, std=0.01)
normal_init(self.rpn_cls, std=0.01)
normal_init(self.rpn_reg, std=0.01)
def forward_single(self, x):
rpn_feat = self.relu(self.rpn_conv(x))
rpn_cls_score = self.rpn_cls(rpn_feat)
rpn_bbox_pred = self.rpn_reg(rpn_feat)
return rpn_cls_score, rpn_bbox_pred
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def get_anchors(self, featmap_sizes, img_metas):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
Returns:
tuple: anchors of each image, valid flags of each image
"""
num_imgs = len(img_metas)
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
multi_level_anchors = []
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
multi_level_anchors.append(anchors)
anchor_list = [multi_level_anchors for _ in range(num_imgs)]
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_flags = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w, _ = img_meta['pad_shape']
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
multi_level_flags.append(flags)
valid_flag_list.append(multi_level_flags)
return anchor_list, valid_flag_list
def loss_single(self, rpn_cls_score, rpn_bbox_pred, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples, cfg):
# classification loss
labels = labels.contiguous().view(-1)
label_weights = label_weights.contiguous().view(-1)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
1).contiguous().view(-1)
criterion = weighted_binary_cross_entropy
else:
rpn_cls_score = rpn_cls_score.permute(0, 2, 3,
1).contiguous().view(-1, 2)
criterion = weighted_cross_entropy
loss_cls = criterion(
rpn_cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_targets = bbox_targets.contiguous().view(-1, 4)
bbox_weights = bbox_weights.contiguous().view(-1, 4)
rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).contiguous().view(
-1, 4)
loss_reg = weighted_smoothl1(
rpn_bbox_pred,
bbox_targets,
bbox_weights,
beta=cfg.smoothl1_beta,
avg_factor=num_total_samples)
return loss_cls, loss_reg
def loss(self, rpn_cls_scores, rpn_bbox_preds, gt_bboxes, img_shapes, cfg):
featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_shapes)
cls_reg_targets = anchor_target(
anchor_list, valid_flag_list, gt_bboxes, img_shapes,
self.target_means, self.target_stds, cfg)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
losses_cls, losses_reg = multi_apply(
self.loss_single,
rpn_cls_scores,
rpn_bbox_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_pos + num_total_neg,
cfg=cfg)
return dict(loss_rpn_cls=losses_cls, loss_rpn_reg=losses_reg)
def get_proposals(self, rpn_cls_scores, rpn_bbox_preds, img_meta, cfg):
num_imgs = len(img_meta)
featmap_sizes = [featmap.size()[-2:] for featmap in rpn_cls_scores]
mlvl_anchors = [
self.anchor_generators[idx].grid_anchors(featmap_sizes[idx],
self.anchor_strides[idx])
for idx in range(len(featmap_sizes))
]
proposal_list = []
for img_id in range(num_imgs):
rpn_cls_score_list = [
rpn_cls_scores[idx][img_id].detach()
for idx in range(len(rpn_cls_scores))
]
rpn_bbox_pred_list = [
rpn_bbox_preds[idx][img_id].detach()
for idx in range(len(rpn_bbox_preds))
]
assert len(rpn_cls_score_list) == len(rpn_bbox_pred_list)
proposals = self._get_proposals_single(
rpn_cls_score_list, rpn_bbox_pred_list, mlvl_anchors,
img_meta[img_id]['img_shape'], cfg)
proposal_list.append(proposals)
return proposal_list
def _get_proposals_single(self, rpn_cls_scores, rpn_bbox_preds,
mlvl_anchors, img_shape, cfg):
mlvl_proposals = []
for idx in range(len(rpn_cls_scores)):
rpn_cls_score = rpn_cls_scores[idx]
rpn_bbox_pred = rpn_bbox_preds[idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
anchors = mlvl_anchors[idx]
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.permute(1, 2,
0).contiguous().view(-1)
rpn_cls_prob = rpn_cls_score.sigmoid()
scores = rpn_cls_prob
else:
rpn_cls_score = rpn_cls_score.permute(1, 2,
0).contiguous().view(
-1, 2)
rpn_cls_prob = F.softmax(rpn_cls_score, dim=1)
scores = rpn_cls_prob[:, 1]
rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).contiguous().view(
-1, 4)
_, order = scores.sort(0, descending=True)
if cfg.nms_pre > 0:
order = order[:cfg.nms_pre]
rpn_bbox_pred = rpn_bbox_pred[order, :]
anchors = anchors[order, :]
scores = scores[order]
proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
self.target_stds, img_shape)
w = proposals[:, 2] - proposals[:, 0] + 1
h = proposals[:, 3] - proposals[:, 1] + 1
valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
(h >= cfg.min_bbox_size)).squeeze()
proposals = proposals[valid_inds, :]
scores = scores[valid_inds]
proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.nms_post, :]
mlvl_proposals.append(proposals)
proposals = torch.cat(mlvl_proposals, 0)
if cfg.nms_across_levels:
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.max_num, :]
else:
scores = proposals[:, 4]
_, order = scores.sort(0, descending=True)
num = min(cfg.max_num, proposals.shape[0])
order = order[:num]
proposals = proposals[order, :]
return proposals
from .retina_head import RetinaHead
__all__ = ['RetinaHead']
......@@ -53,7 +53,8 @@ class ConvModule(nn.Module):
if self.with_norm:
norm_channels = out_channels if self.activate_last else in_channels
self.norm = build_norm_layer(normalize, norm_channels)
self.norm_name, norm = build_norm_layer(normalize, norm_channels)
self.add_module(self.norm_name, norm)
if self.with_activatation:
assert activation in ['relu'], 'Only ReLU supported.'
......@@ -63,6 +64,10 @@ class ConvModule(nn.Module):
# Default using msra init
self.init_weights()
@property
def norm(self):
return getattr(self, self.norm_name)
def init_weights(self):
nonlinearity = 'relu' if self.activation is None else self.activation
kaiming_init(self.conv, nonlinearity=nonlinearity)
......
import torch.nn as nn
norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': None}
norm_cfg = {
# format: layer_type: (abbreviation, module)
'BN': ('bn', nn.BatchNorm2d),
'SyncBN': ('bn', None),
'GN': ('gn', nn.GroupNorm),
# and potentially 'SN'
}
def build_norm_layer(cfg, num_features):
def build_norm_layer(cfg, num_features, postfix=''):
""" Build normalization layer
Args:
cfg (dict): cfg should contain:
type (str): identify norm layer type.
layer args: args needed to instantiate a norm layer.
frozen (bool): [optional] whether stop gradient updates
of norm layer, it is helpful to set frozen mode
in backbone's norms.
num_features (int): number of channels from input
postfix (int, str): appended into norm abbreation to
create named layer.
Returns:
name (str): abbreation + postfix
layer (nn.Module): created norm layer
"""
assert isinstance(cfg, dict) and 'type' in cfg
cfg_ = cfg.copy()
cfg_.setdefault('eps', 1e-5)
layer_type = cfg_.pop('type')
layer_type = cfg_.pop('type')
if layer_type not in norm_cfg:
raise KeyError('Unrecognized norm type {}'.format(layer_type))
elif norm_cfg[layer_type] is None:
raise NotImplementedError
else:
abbr, norm_layer = norm_cfg[layer_type]
if norm_layer is None:
raise NotImplementedError
assert isinstance(postfix, (int, str))
name = abbr + str(postfix)
frozen = cfg_.pop('frozen', False)
cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN':
layer = norm_layer(num_features, **cfg_)
else:
assert 'num_groups' in cfg_
layer = norm_layer(num_channels=num_features, **cfg_)
if frozen:
for param in layer.parameters():
param.requires_grad = False
return norm_cfg[layer_type](num_features, **cfg_)
return name, layer
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment