"vscode:/vscode.git/clone" did not exist on "de277a8119019f5bcdcd49271d814ad601b68ba2"
Commit bee93859 authored by ThangVu's avatar ThangVu
Browse files

revise group norm (5)

parent 574a920a
...@@ -9,6 +9,7 @@ model = dict( ...@@ -9,6 +9,7 @@ model = dict(
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
frozen_stages=1, frozen_stages=1,
style='pytorch', style='pytorch',
# Note: eval_mode and frozen are required args for backbone
normalize=dict( normalize=dict(
type='GN', type='GN',
num_groups=32, num_groups=32,
......
...@@ -37,8 +37,8 @@ class BasicBlock(nn.Module): ...@@ -37,8 +37,8 @@ class BasicBlock(nn.Module):
self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1) self.norm1_name, norm1 = build_norm_layer(normalize, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2) self.norm2_name, norm2 = build_norm_layer(normalize, planes, postfix=2)
self.add_module(self.norm1, norm1) self.add_module(self.norm1_name, norm1)
self.add_module(self.norm2, norm2) self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes) self.conv2 = conv3x3(planes, planes)
...@@ -325,12 +325,12 @@ class ResNet(nn.Module): ...@@ -325,12 +325,12 @@ class ResNet(nn.Module):
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def _freeze_stages(self):
if self.frozen_stages >= 0: if self.frozen_stages >= 0:
for m in [self.conv1, self.norm1]: for m in [self.conv1, self.norm1]:
for param in m.parameters(): for param in m.parameters():
param.requires_grad = False param.requires_grad = False
def _freeze_stages(self):
for i in range(1, self.frozen_stages + 1): for i in range(1, self.frozen_stages + 1):
m = getattr(self, 'layer{}'.format(i)) m = getattr(self, 'layer{}'.format(i))
for param in m.parameters(): for param in m.parameters():
......
...@@ -31,6 +31,13 @@ def build_norm_layer(cfg, num_features, postfix=''): ...@@ -31,6 +31,13 @@ def build_norm_layer(cfg, num_features, postfix=''):
assert isinstance(cfg, dict) and 'type' in cfg assert isinstance(cfg, dict) and 'type' in cfg
cfg_ = cfg.copy() cfg_ = cfg.copy()
# eval_mode is supported and popped out for processing in module
# having pretrained weight only (e.g. backbone)
# raise an exception if eval_mode is in here
if 'eval_mode' in cfg:
raise Exception('eval_mode for modules without pretrained weights '
'is not supported')
layer_type = cfg_.pop('type') layer_type = cfg_.pop('type')
if layer_type not in norm_cfg: if layer_type not in norm_cfg:
raise KeyError('Unrecognized norm type {}'.format(layer_type)) raise KeyError('Unrecognized norm type {}'.format(layer_type))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment