Commit 2f64dd90 authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Refactor Segmentation model (#1009)

This PR uses a protected method for loading and initializing the segmentation models. Relevant #875
parent e4e167a3
......@@ -32,7 +32,7 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True
aux_classifier = FCNHead(inplanes, num_classes)
model_map = {
'deeplab': (DeepLabHead, DeepLabV3),
'deeplabv3': (DeepLabHead, DeepLabV3),
'fcn': (FCNHead, FCN),
}
inplanes = 2048
......@@ -43,20 +43,12 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True
return model
def fcn_resnet50(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
if pretrained:
aux_loss = True
model = _segm_resnet("fcn", "resnet50", num_classes, aux_loss, **kwargs)
model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
if pretrained:
arch = 'fcn_resnet50_coco'
arch = arch_type + '_' + backbone + '_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
......@@ -66,6 +58,18 @@ def fcn_resnet50(pretrained=False, progress=True,
return model
def fcn_resnet50(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
def fcn_resnet101(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
......@@ -75,18 +79,7 @@ def fcn_resnet101(pretrained=False, progress=True,
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
aux_loss = True
model = _segm_resnet("fcn", "resnet101", num_classes, aux_loss, **kwargs)
if pretrained:
arch = 'fcn_resnet101_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model
return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_resnet50(pretrained=False, progress=True,
......@@ -98,18 +91,7 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
aux_loss = True
model = _segm_resnet("deeplab", "resnet50", num_classes, aux_loss, **kwargs)
if pretrained:
arch = 'deeplabv3_resnet50_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_resnet101(pretrained=False, progress=True,
......@@ -121,15 +103,4 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
aux_loss = True
model = _segm_resnet("deeplab", "resnet101", num_classes, aux_loss, **kwargs)
if pretrained:
arch = 'deeplabv3_resnet101_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment