Commit 8a086f02 authored by ThangVu's avatar ThangVu
Browse files

add frozen stage for group norm

parent 628441b7
...@@ -234,9 +234,9 @@ class ResNet(nn.Module): ...@@ -234,9 +234,9 @@ class ResNet(nn.Module):
dilations=(1, 1, 1, 1), dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
style='pytorch', style='pytorch',
frozen_stages=-1,
normalize=dict( normalize=dict(
type='BN', type='BN',
frozen_stages=-1,
bn_eval=True, bn_eval=True,
bn_frozen=False), bn_frozen=False),
with_cp=False): with_cp=False):
...@@ -245,7 +245,7 @@ class ResNet(nn.Module): ...@@ -245,7 +245,7 @@ class ResNet(nn.Module):
raise KeyError('invalid depth {} for resnet'.format(depth)) raise KeyError('invalid depth {} for resnet'.format(depth))
assert num_stages >= 1 and num_stages <= 4 assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth] block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages] self.stage_blocks = stage_blocks[:num_stages]
assert len(strides) == len(dilations) == num_stages assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages assert max(out_indices) < num_stages
...@@ -254,14 +254,14 @@ class ResNet(nn.Module): ...@@ -254,14 +254,14 @@ class ResNet(nn.Module):
if normalize['type'] == 'GN': if normalize['type'] == 'GN':
assert 'num_groups' in normalize assert 'num_groups' in normalize
else: else:
assert (set(['type', 'frozen_stages', 'bn_eval', 'bn_frozen']) assert (set(['type', 'bn_eval', 'bn_frozen'])
== set(normalize)) == set(normalize))
self.out_indices = out_indices self.out_indices = out_indices
self.style = style self.style = style
self.with_cp = with_cp self.with_cp = with_cp
self.frozen_stages = frozen_stages
if normalize['type'] == 'BN': if normalize['type'] == 'BN':
self.frozen_stages = normalize['frozen_stages']
self.bn_eval = normalize['bn_eval'] self.bn_eval = normalize['bn_eval']
self.bn_frozen = normalize['bn_frozen'] self.bn_frozen = normalize['bn_frozen']
self.normalize = normalize self.normalize = normalize
...@@ -334,27 +334,27 @@ class ResNet(nn.Module): ...@@ -334,27 +334,27 @@ class ResNet(nn.Module):
def train(self, mode=True): def train(self, mode=True):
super(ResNet, self).train(mode) super(ResNet, self).train(mode)
if self.normalize['type'] == 'BN': if self.normalize['type'] == 'BN' and self.bn_eval:
if self.bn_eval: for m in self.modules():
for m in self.modules(): if isinstance(m, nn.BatchNorm2d):
if isinstance(m, nn.BatchNorm2d): m.eval()
m.eval() if self.bn_frozen:
if self.bn_frozen: for params in m.parameters():
for params in m.parameters(): params.requires_grad = False
params.requires_grad = False if mode and self.frozen_stages >= 0:
if mode and self.frozen_stages >= 0: for param in self.conv1.parameters():
for param in self.conv1.parameters(): param.requires_grad = False
param.requires_grad = False
for param in self.bn1.parameters(): stem_norm = getattr(self, self.stem_norm_name)
stem_norm.eval()
for param in stem_norm.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
mod = getattr(self, 'layer{}'.format(i))
mod.eval()
for param in mod.parameters():
param.requires_grad = False param.requires_grad = False
self.bn1.eval()
self.bn1.weight.requires_grad = False
self.bn1.bias.requires_grad = False
for i in range(1, self.frozen_stages + 1):
mod = getattr(self, 'layer{}'.format(i))
mod.eval()
for param in mod.parameters():
param.requires_grad = False
class ResNetClassifier(ResNet): class ResNetClassifier(ResNet):
...@@ -433,8 +433,8 @@ class ResNetClassifier(ResNet): ...@@ -433,8 +433,8 @@ class ResNetClassifier(ResNet):
cf_state = pickle.load(f, encoding='latin1') cf_state = pickle.load(f, encoding='latin1')
if 'blobs' in cf_state: if 'blobs' in cf_state:
cf_state = cf_state['blobs'] cf_state = cf_state['blobs']
for py_k, cf_k in mapping.items(): for i, (py_k, cf_k) in enumerate(mapping.items(), 1):
print('Loading {} to {}'.format(cf_k, py_k)) print('[{}/{}] Loading {} to {}'.format(i, len(mapping), cf_k, py_k))
assert py_k in py_state and cf_k in cf_state assert py_k in py_state and cf_k in cf_state
py_state[py_k] = torch.Tensor(cf_state[cf_k]) py_state[py_k] = torch.Tensor(cf_state[cf_k])
self.load_state_dict(py_state) self.load_state_dict(py_state)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment