Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
041b8ba1
Unverified
Commit
041b8ba1
authored
May 21, 2019
by
Francisco Massa
Committed by
GitHub
May 21, 2019
Browse files
Add pre-trained models for semantic segmentation (#930)
Also adds documentation for the segmentation models
parent
f76e598d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
115 additions
and
8 deletions
+115
-8
torchvision/models/segmentation/__init__.py
torchvision/models/segmentation/__init__.py
+2
-0
torchvision/models/segmentation/deeplabv3.py
torchvision/models/segmentation/deeplabv3.py
+17
-0
torchvision/models/segmentation/fcn.py
torchvision/models/segmentation/fcn.py
+15
-0
torchvision/models/segmentation/segmentation.py
torchvision/models/segmentation/segmentation.py
+81
-8
No files found.
torchvision/models/segmentation/__init__.py
View file @
041b8ba1
from
.segmentation
import
*
from
.fcn
import
*
from
.deeplabv3
import
*
torchvision/models/segmentation/deeplabv3.py
View file @
041b8ba1
...
...
@@ -5,7 +5,24 @@ from torch.nn import functional as F
from
._utils
import
_SimpleSegmentationModel
__all__
=
[
"DeepLabV3"
]
class
DeepLabV3
(
_SimpleSegmentationModel
):
"""
Implements DeepLabV3 model from
`"Rethinking Atrous Convolution for Semantic Image Segmentation"
<https://arxiv.org/abs/1706.05587>`_.
Arguments:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"out" for the last feature map used, and "aux" if an auxiliary classifier
is used.
classifier (nn.Module): module that takes the "out" element returned from
the backbone and returns a dense prediction.
aux_classifier (nn.Module, optional): auxiliary classifier used during training
"""
pass
...
...
torchvision/models/segmentation/fcn.py
View file @
041b8ba1
...
...
@@ -3,7 +3,22 @@ from torch import nn
from
._utils
import
_SimpleSegmentationModel
__all__
=
[
"FCN"
]
class
FCN
(
_SimpleSegmentationModel
):
"""
Implements a Fully-Convolutional Network for semantic segmentation.
Arguments:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"out" for the last feature map used, and "aux" if an auxiliary classifier
is used.
classifier (nn.Module): module that takes the "out" element returned from
the backbone and returns a dense prediction.
aux_classifier (nn.Module, optional): auxiliary classifier used during training
"""
pass
...
...
torchvision/models/segmentation/segmentation.py
View file @
041b8ba1
from
.._utils
import
IntermediateLayerGetter
from
..utils
import
load_state_dict_from_url
from
..
import
resnet
from
.deeplabv3
import
DeepLabHead
,
DeepLabV3
from
.fcn
import
FCN
,
FCNHead
...
...
@@ -7,6 +8,14 @@ from .fcn import FCN, FCNHead
__all__
=
[
'fcn_resnet50'
,
'fcn_resnet101'
,
'deeplabv3_resnet50'
,
'deeplabv3_resnet101'
]
model_urls
=
{
'fcn_resnet50_coco'
:
None
,
'fcn_resnet101_coco'
:
'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth'
,
'deeplabv3_resnet50_coco'
:
None
,
'deeplabv3_resnet101_coco'
:
'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth'
,
}
def
_segm_resnet
(
name
,
backbone_name
,
num_classes
,
aux
,
pretrained_backbone
=
True
):
backbone
=
resnet
.
__dict__
[
backbone_name
](
pretrained
=
pretrained_backbone
,
...
...
@@ -34,29 +43,93 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True
return
model
def
fcn_resnet50
(
pretrained
=
False
,
num_classes
=
21
,
aux_loss
=
None
,
**
kwargs
):
def
fcn_resnet50
(
pretrained
=
False
,
progress
=
True
,
num_classes
=
21
,
aux_loss
=
None
,
**
kwargs
):
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if
pretrained
:
aux_loss
=
True
model
=
_segm_resnet
(
"fcn"
,
"resnet50"
,
num_classes
,
aux_loss
,
**
kwargs
)
if
pretrained
:
pass
arch
=
'fcn_resnet50_coco'
model_url
=
model_urls
[
arch
]
if
model_url
is
None
:
raise
NotImplementedError
(
'pretrained {} is not supported as of now'
.
format
(
arch
))
else
:
state_dict
=
load_state_dict_from_url
(
model_url
,
progress
=
progress
)
model
.
load_state_dict
(
state_dict
)
return
model
def
fcn_resnet101
(
pretrained
=
False
,
num_classes
=
21
,
aux_loss
=
None
,
**
kwargs
):
def
fcn_resnet101
(
pretrained
=
False
,
progress
=
True
,
num_classes
=
21
,
aux_loss
=
None
,
**
kwargs
):
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if
pretrained
:
aux_loss
=
True
model
=
_segm_resnet
(
"fcn"
,
"resnet101"
,
num_classes
,
aux_loss
,
**
kwargs
)
if
pretrained
:
pass
arch
=
'fcn_resnet101_coco'
model_url
=
model_urls
[
arch
]
if
model_url
is
None
:
raise
NotImplementedError
(
'pretrained {} is not supported as of now'
.
format
(
arch
))
else
:
state_dict
=
load_state_dict_from_url
(
model_url
,
progress
=
progress
)
model
.
load_state_dict
(
state_dict
)
return
model
def
deeplabv3_resnet50
(
pretrained
=
False
,
num_classes
=
21
,
aux_loss
=
None
,
**
kwargs
):
def
deeplabv3_resnet50
(
pretrained
=
False
,
progress
=
True
,
num_classes
=
21
,
aux_loss
=
None
,
**
kwargs
):
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if
pretrained
:
aux_loss
=
True
model
=
_segm_resnet
(
"deeplab"
,
"resnet50"
,
num_classes
,
aux_loss
,
**
kwargs
)
if
pretrained
:
pass
arch
=
'deeplabv3_resnet50_coco'
model_url
=
model_urls
[
arch
]
if
model_url
is
None
:
raise
NotImplementedError
(
'pretrained {} is not supported as of now'
.
format
(
arch
))
else
:
state_dict
=
load_state_dict_from_url
(
model_url
,
progress
=
progress
)
model
.
load_state_dict
(
state_dict
)
return
model
def
deeplabv3_resnet101
(
pretrained
=
False
,
num_classes
=
21
,
aux_loss
=
None
,
**
kwargs
):
def
deeplabv3_resnet101
(
pretrained
=
False
,
progress
=
True
,
num_classes
=
21
,
aux_loss
=
None
,
**
kwargs
):
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
"""
if
pretrained
:
aux_loss
=
True
model
=
_segm_resnet
(
"deeplab"
,
"resnet101"
,
num_classes
,
aux_loss
,
**
kwargs
)
if
pretrained
:
pass
arch
=
'deeplabv3_resnet101_coco'
model_url
=
model_urls
[
arch
]
if
model_url
is
None
:
raise
NotImplementedError
(
'pretrained {} is not supported as of now'
.
format
(
arch
))
else
:
state_dict
=
load_state_dict_from_url
(
model_url
,
progress
=
progress
)
model
.
load_state_dict
(
state_dict
)
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment