Commit 742fd13c authored by Sepehr Sameni's avatar Sepehr Sameni Committed by Francisco Massa
Browse files

remove duplicate code from densenet (#827)

* remove duplicate code from densenet

* correct indentation
parent 6f2f9213
...@@ -117,23 +117,14 @@ class DenseNet(nn.Module): ...@@ -117,23 +117,14 @@ class DenseNet(nn.Module):
return out return out
def densenet121(pretrained=False, **kwargs): def _load_state_dict(model, model_url):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
**kwargs)
if pretrained:
# '.'s are no longer allowed in module names, but pervious _DenseLayer # '.'s are no longer allowed in module names, but pervious _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used # They are also in the checkpoints in model_urls. This pattern is used
# to find such keys. # to find such keys.
pattern = re.compile( pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet121']) state_dict = model_zoo.load_url(model_url)
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
res = pattern.match(key) res = pattern.match(key)
if res: if res:
...@@ -141,6 +132,19 @@ def densenet121(pretrained=False, **kwargs): ...@@ -141,6 +132,19 @@ def densenet121(pretrained=False, **kwargs):
state_dict[new_key] = state_dict[key] state_dict[new_key] = state_dict[key]
del state_dict[key] del state_dict[key]
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
def densenet121(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
**kwargs)
if pretrained:
_load_state_dict(model, model_urls['densenet121'])
return model return model
...@@ -154,20 +158,7 @@ def densenet169(pretrained=False, **kwargs): ...@@ -154,20 +158,7 @@ def densenet169(pretrained=False, **kwargs):
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
**kwargs) **kwargs)
if pretrained: if pretrained:
# '.'s are no longer allowed in module names, but pervious _DenseLayer _load_state_dict(model, model_urls['densenet169'])
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet169'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model return model
...@@ -181,20 +172,7 @@ def densenet201(pretrained=False, **kwargs): ...@@ -181,20 +172,7 @@ def densenet201(pretrained=False, **kwargs):
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
**kwargs) **kwargs)
if pretrained: if pretrained:
# '.'s are no longer allowed in module names, but pervious _DenseLayer _load_state_dict(model, model_urls['densenet201'])
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet201'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model return model
...@@ -208,18 +186,5 @@ def densenet161(pretrained=False, **kwargs): ...@@ -208,18 +186,5 @@ def densenet161(pretrained=False, **kwargs):
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
**kwargs) **kwargs)
if pretrained: if pretrained:
# '.'s are no longer allowed in module names, but pervious _DenseLayer _load_state_dict(model, model_urls['densenet161'])
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = model_zoo.load_url(model_urls['densenet161'])
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
state_dict[new_key] = state_dict[key]
del state_dict[key]
model.load_state_dict(state_dict)
return model return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment