Commit 742fd13c authored by Sepehr Sameni's avatar Sepehr Sameni Committed by Francisco Massa
Browse files

remove duplicate code from densenet (#827)

* remove duplicate code from densenet

* correct indentation
parent 6f2f9213
...@@ -117,6 +117,23 @@ class DenseNet(nn.Module): ...@@ -117,6 +117,23 @@ class DenseNet(nn.Module):
return out return out
def _load_state_dict(model, model_url):
# '.'s are no longer allowed in module names, but pervious _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_url)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
def densenet121(pretrained=False, **kwargs): def densenet121(pretrained=False, **kwargs):
r"""Densenet-121 model from r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
...@@ -127,20 +144,7 @@ def densenet121(pretrained=False, **kwargs): ...@@ -127,20 +144,7 @@ def densenet121(pretrained=False, **kwargs):
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
**kwargs) **kwargs)
if pretrained: if pretrained:
# '.'s are no longer allowed in module names, but pervious _DenseLayer _load_state_dict(model, model_urls['densenet121'])
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet121'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model return model
...@@ -154,20 +158,7 @@ def densenet169(pretrained=False, **kwargs): ...@@ -154,20 +158,7 @@ def densenet169(pretrained=False, **kwargs):
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
**kwargs) **kwargs)
if pretrained: if pretrained:
# '.'s are no longer allowed in module names, but pervious _DenseLayer _load_state_dict(model, model_urls['densenet169'])
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet169'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model return model
...@@ -181,20 +172,7 @@ def densenet201(pretrained=False, **kwargs): ...@@ -181,20 +172,7 @@ def densenet201(pretrained=False, **kwargs):
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
**kwargs) **kwargs)
if pretrained: if pretrained:
# '.'s are no longer allowed in module names, but pervious _DenseLayer _load_state_dict(model, model_urls['densenet201'])
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet201'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model return model
...@@ -208,18 +186,5 @@ def densenet161(pretrained=False, **kwargs): ...@@ -208,18 +186,5 @@ def densenet161(pretrained=False, **kwargs):
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
**kwargs) **kwargs)
if pretrained: if pretrained:
# '.'s are no longer allowed in module names, but pervious _DenseLayer _load_state_dict(model, model_urls['densenet161'])
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet161'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment