Commit 8a086f02 authored by ThangVu
Browse files

add frozen stage for group norm

parent 628441b7
......@@ -234,9 +234,9 @@ class ResNet(nn.Module):
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
frozen_stages=-1,
normalize=dict(
type='BN',
frozen_stages=-1,
bn_eval=True,
bn_frozen=False),
with_cp=False):
......@@ -245,7 +245,7 @@ class ResNet(nn.Module):
raise KeyError('invalid depth {} for resnet'.format(depth))
assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages]
self.stage_blocks = stage_blocks[:num_stages]
assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages
......@@ -254,14 +254,14 @@ class ResNet(nn.Module):
if normalize['type'] == 'GN':
assert 'num_groups' in normalize
else:
assert (set(['type', 'frozen_stages', 'bn_eval', 'bn_frozen'])
assert (set(['type', 'bn_eval', 'bn_frozen'])
== set(normalize))
self.out_indices = out_indices
self.style = style
self.with_cp = with_cp
self.frozen_stages = frozen_stages
if normalize['type'] == 'BN':
self.frozen_stages = normalize['frozen_stages']
self.bn_eval = normalize['bn_eval']
self.bn_frozen = normalize['bn_frozen']
self.normalize = normalize
......@@ -334,27 +334,27 @@ class ResNet(nn.Module):
def train(self, mode=True):
super(ResNet, self).train(mode)
if self.normalize['type'] == 'BN':
if self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if self.bn_frozen:
for params in m.parameters():
params.requires_grad = False
if mode and self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for param in self.bn1.parameters():
if self.normalize['type'] == 'BN' and self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if self.bn_frozen:
for params in m.parameters():
params.requires_grad = False
if mode and self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
stem_norm = getattr(self, self.stem_norm_name)
stem_norm.eval()
for param in stem_norm.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
mod = getattr(self, 'layer{}'.format(i))
mod.eval()
for param in mod.parameters():
param.requires_grad = False
self.bn1.eval()
self.bn1.weight.requires_grad = False
self.bn1.bias.requires_grad = False
for i in range(1, self.frozen_stages + 1):
mod = getattr(self, 'layer{}'.format(i))
mod.eval()
for param in mod.parameters():
param.requires_grad = False
class ResNetClassifier(ResNet):
......@@ -433,8 +433,8 @@ class ResNetClassifier(ResNet):
cf_state = pickle.load(f, encoding='latin1')
if 'blobs' in cf_state:
cf_state = cf_state['blobs']
for py_k, cf_k in mapping.items():
print('Loading {} to {}'.format(cf_k, py_k))
for i, (py_k, cf_k) in enumerate(mapping.items(), 1):
print('[{}/{}] Loading {} to {}'.format(i, len(mapping), cf_k, py_k))
assert py_k in py_state and cf_k in cf_state
py_state[py_k] = torch.Tensor(cf_state[cf_k])
self.load_state_dict(py_state)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment