Commit 2f64dd90 authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Refactor Segmentation model (#1009)

This PR uses a protected method for loading and initializing the segmentation models. Relevant #875
parent e4e167a3
...@@ -32,7 +32,7 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True ...@@ -32,7 +32,7 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True
aux_classifier = FCNHead(inplanes, num_classes) aux_classifier = FCNHead(inplanes, num_classes)
model_map = { model_map = {
'deeplab': (DeepLabHead, DeepLabV3), 'deeplabv3': (DeepLabHead, DeepLabV3),
'fcn': (FCNHead, FCN), 'fcn': (FCNHead, FCN),
} }
inplanes = 2048 inplanes = 2048
...@@ -43,20 +43,12 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True ...@@ -43,20 +43,12 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True
return model return model
def fcn_resnet50(pretrained=False, progress=True, def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained: if pretrained:
aux_loss = True aux_loss = True
model = _segm_resnet("fcn", "resnet50", num_classes, aux_loss, **kwargs) model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
if pretrained: if pretrained:
arch = 'fcn_resnet50_coco' arch = arch_type + '_' + backbone + '_coco'
model_url = model_urls[arch] model_url = model_urls[arch]
if model_url is None: if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
...@@ -66,6 +58,18 @@ def fcn_resnet50(pretrained=False, progress=True, ...@@ -66,6 +58,18 @@ def fcn_resnet50(pretrained=False, progress=True,
return model return model
def fcn_resnet50(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
def fcn_resnet101(pretrained=False, progress=True, def fcn_resnet101(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs): num_classes=21, aux_loss=None, **kwargs):
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
...@@ -75,18 +79,7 @@ def fcn_resnet101(pretrained=False, progress=True, ...@@ -75,18 +79,7 @@ def fcn_resnet101(pretrained=False, progress=True,
contains the same classes as Pascal VOC contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
if pretrained: return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
aux_loss = True
model = _segm_resnet("fcn", "resnet101", num_classes, aux_loss, **kwargs)
if pretrained:
arch = 'fcn_resnet101_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model
def deeplabv3_resnet50(pretrained=False, progress=True, def deeplabv3_resnet50(pretrained=False, progress=True,
...@@ -98,18 +91,7 @@ def deeplabv3_resnet50(pretrained=False, progress=True, ...@@ -98,18 +91,7 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
contains the same classes as Pascal VOC contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
if pretrained: return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
aux_loss = True
model = _segm_resnet("deeplab", "resnet50", num_classes, aux_loss, **kwargs)
if pretrained:
arch = 'deeplabv3_resnet50_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model
def deeplabv3_resnet101(pretrained=False, progress=True, def deeplabv3_resnet101(pretrained=False, progress=True,
...@@ -121,15 +103,4 @@ def deeplabv3_resnet101(pretrained=False, progress=True, ...@@ -121,15 +103,4 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
contains the same classes as Pascal VOC contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
if pretrained: return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
aux_loss = True
model = _segm_resnet("deeplab", "resnet101", num_classes, aux_loss, **kwargs)
if pretrained:
arch = 'deeplabv3_resnet101_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment