Commit 25015e1b authored by ThangVu

add group norm support

parent 2b17166a
 import torch.nn as nn
 from .bbox_head import BBoxHead
-from ..utils import ConvModule
+from ..utils import ConvModule, build_norm_layer


 class ConvFCBBoxHead(BBoxHead):
@@ -113,8 +113,13 @@ class ConvFCBBoxHead(BBoxHead):
             for i in range(num_branch_fcs):
                 fc_in_channels = (last_layer_dim
                                   if i == 0 else self.fc_out_channels)
-                branch_fcs.append(
-                    nn.Linear(fc_in_channels, self.fc_out_channels))
+                if self.normalize is not None:
+                    branch_fcs.append(nn.Sequential(
+                        nn.Linear(fc_in_channels, self.fc_out_channels, False),
+                        build_norm_layer(self.normalize, self.fc_out_channels)))
+                else:
+                    branch_fcs.append(
+                        nn.Linear(fc_in_channels, self.fc_out_channels))
                 last_layer_dim = self.fc_out_channels
         return branch_convs, branch_fcs, last_layer_dim
@@ -124,7 +129,8 @@ class ConvFCBBoxHead(BBoxHead):
         for m in module_list.modules():
             if isinstance(m, nn.Linear):
                 nn.init.xavier_uniform_(m.weight)
-                nn.init.constant_(m.bias, 0)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)

     def forward(self, x):
         # shared part
...
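For context, a minimal sketch of what the new fc branch amounts to when a GN config is passed as normalize. The config values and channel sizes below are illustrative, not taken from the commit:

import torch
import torch.nn as nn

# Hypothetical config: normalize = dict(type='GN', num_groups=32)
fc_in_channels, fc_out_channels = 256, 1024  # illustrative sizes

# The Linear layer has its bias disabled (third positional arg False) because
# the norm layer that follows supplies the learnable shift.
fc = nn.Sequential(
    nn.Linear(fc_in_channels, fc_out_channels, False),
    nn.GroupNorm(num_groups=32, num_channels=fc_out_channels))  # what build_norm_layer returns for 'GN'

out = fc(torch.randn(8, fc_in_channels))  # shape (8, 1024); GroupNorm accepts (N, C) input

This also explains the init_weights change: Linear layers created without a bias would otherwise fail on nn.init.constant_(m.bias, 0).
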
 import torch.nn as nn
 import torch.nn.functional as F
 from ..utils import ConvModule
-from ..utils import xavier_init
+from mmcv.cnn import xavier_init


 class FPN(nn.Module):
...
 import torch.nn as nn

-norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': None}
+norm_cfg = {'BN': nn.BatchNorm2d, 'SyncBN': None, 'GN': nn.GroupNorm}


 def build_norm_layer(cfg, num_features):
@@ -9,9 +9,15 @@ def build_norm_layer(cfg, num_features):
     cfg_.setdefault('eps', 1e-5)
     layer_type = cfg_.pop('type')
+    # args name matching
+    if layer_type == 'GN':
+        cfg_.setdefault('num_channels', num_features)
+    else:
+        cfg_.setdefault('num_features', num_features)
     if layer_type not in norm_cfg:
         raise KeyError('Unrecognized norm type {}'.format(layer_type))
     elif norm_cfg[layer_type] is None:
         raise NotImplementedError
-    return norm_cfg[layer_type](num_features, **cfg_)
+    return norm_cfg[layer_type](**cfg_)
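
To make the argument-name matching concrete, a sketch of what the updated function returns for a BN and a GN config; the channel count and num_groups here are illustrative:

import torch.nn as nn

# build_norm_layer(dict(type='BN'), 256) now resolves to:
bn = nn.BatchNorm2d(num_features=256, eps=1e-5)
# build_norm_layer(dict(type='GN', num_groups=32), 256) now resolves to:
gn = nn.GroupNorm(num_groups=32, num_channels=256, eps=1e-5)
# 'SyncBN' is still mapped to None and raises NotImplementedError.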