Unverified Commit 041b8ba1 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add pre-trained models for semantic segmentation (#930)

Also adds documentation for the segmentation models
parent f76e598d
from .segmentation import * from .segmentation import *
from .fcn import *
from .deeplabv3 import *
...@@ -5,7 +5,24 @@ from torch.nn import functional as F ...@@ -5,7 +5,24 @@ from torch.nn import functional as F
from ._utils import _SimpleSegmentationModel from ._utils import _SimpleSegmentationModel
__all__ = ["DeepLabV3"]
class DeepLabV3(_SimpleSegmentationModel): class DeepLabV3(_SimpleSegmentationModel):
"""
Implements DeepLabV3 model from
`"Rethinking Atrous Convolution for Semantic Image Segmentation"
<https://arxiv.org/abs/1706.05587>`_.
Arguments:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"out" for the last feature map used, and "aux" if an auxiliary classifier
is used.
classifier (nn.Module): module that takes the "out" element returned from
the backbone and returns a dense prediction.
aux_classifier (nn.Module, optional): auxiliary classifier used during training
"""
pass pass
......
...@@ -3,7 +3,22 @@ from torch import nn ...@@ -3,7 +3,22 @@ from torch import nn
from ._utils import _SimpleSegmentationModel from ._utils import _SimpleSegmentationModel
__all__ = ["FCN"]
class FCN(_SimpleSegmentationModel): class FCN(_SimpleSegmentationModel):
"""
Implements a Fully-Convolutional Network for semantic segmentation.
Arguments:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"out" for the last feature map used, and "aux" if an auxiliary classifier
is used.
classifier (nn.Module): module that takes the "out" element returned from
the backbone and returns a dense prediction.
aux_classifier (nn.Module, optional): auxiliary classifier used during training
"""
pass pass
......
from .._utils import IntermediateLayerGetter from .._utils import IntermediateLayerGetter
from ..utils import load_state_dict_from_url
from .. import resnet from .. import resnet
from .deeplabv3 import DeepLabHead, DeepLabV3 from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead from .fcn import FCN, FCNHead
...@@ -7,6 +8,14 @@ from .fcn import FCN, FCNHead ...@@ -7,6 +8,14 @@ from .fcn import FCN, FCNHead
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101'] __all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']
model_urls = {
'fcn_resnet50_coco': None,
'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
'deeplabv3_resnet50_coco': None,
'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
}
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True): def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
backbone = resnet.__dict__[backbone_name]( backbone = resnet.__dict__[backbone_name](
pretrained=pretrained_backbone, pretrained=pretrained_backbone,
...@@ -34,29 +43,93 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True ...@@ -34,29 +43,93 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True
return model return model
def fcn_resnet50(pretrained=False, num_classes=21, aux_loss=None, **kwargs): def fcn_resnet50(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
aux_loss = True
model = _segm_resnet("fcn", "resnet50", num_classes, aux_loss, **kwargs) model = _segm_resnet("fcn", "resnet50", num_classes, aux_loss, **kwargs)
if pretrained: if pretrained:
pass arch = 'fcn_resnet50_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model return model
def fcn_resnet101(pretrained=False, num_classes=21, aux_loss=None, **kwargs): def fcn_resnet101(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
aux_loss = True
model = _segm_resnet("fcn", "resnet101", num_classes, aux_loss, **kwargs) model = _segm_resnet("fcn", "resnet101", num_classes, aux_loss, **kwargs)
if pretrained: if pretrained:
pass arch = 'fcn_resnet101_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model return model
def deeplabv3_resnet50(pretrained=False, num_classes=21, aux_loss=None, **kwargs): def deeplabv3_resnet50(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
aux_loss = True
model = _segm_resnet("deeplab", "resnet50", num_classes, aux_loss, **kwargs) model = _segm_resnet("deeplab", "resnet50", num_classes, aux_loss, **kwargs)
if pretrained: if pretrained:
pass arch = 'deeplabv3_resnet50_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model return model
def deeplabv3_resnet101(pretrained=False, num_classes=21, aux_loss=None, **kwargs): def deeplabv3_resnet101(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if pretrained:
aux_loss = True
model = _segm_resnet("deeplab", "resnet101", num_classes, aux_loss, **kwargs) model = _segm_resnet("deeplab", "resnet101", num_classes, aux_loss, **kwargs)
if pretrained: if pretrained:
pass arch = 'deeplabv3_resnet101_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment