Unverified commit a1d6b314, authored by Vasilis Vryniotis, committed by GitHub
Browse files

Refactor Segmentation models (#4646)

* Move FCN methods to their own package.

* Fix lint.

* Move LRASPP methods to their own package.

* Move DeepLabV3 methods to their own package.

* Adding deprecation warning for torchvision.models.segmentation.segmentation.

* Refactoring deeplab.

* Setting aux default to false.

* Fixing imports.

* Passing backbones instead of backbone names to builders.

* Fixing mypy

* Addressing review comments.

* Correcting typing.

* Restoring special handling for references.
parent e4a4a29a
from .segmentation import *
from .fcn import *
from .deeplabv3 import *
from .lraspp import *
......@@ -4,6 +4,8 @@ from typing import Optional, Dict
from torch import nn, Tensor
from torch.nn import functional as F
from ..._internally_replaced_utils import load_state_dict_from_url
class _SimpleSegmentationModel(nn.Module):
__constants__ = ["aux_classifier"]
......@@ -32,3 +34,10 @@ class _SimpleSegmentationModel(nn.Module):
result["aux"] = x
return result
def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None:
if model_url is None:
raise ValueError("No checkpoint is available for {}".format(arch))
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
from typing import List
from typing import List, Optional
import torch
from torch import nn
from torch.nn import functional as F
from ._utils import _SimpleSegmentationModel
from .. import mobilenetv3
from .. import resnet
from ..feature_extraction import create_feature_extractor
from ._utils import _SimpleSegmentationModel, _load_weights
from .fcn import FCNHead
# Public API of this module: the model class plus its builder functions.
# (A redundant earlier ``__all__`` assignment that was immediately overwritten has been dropped.)
__all__ = [
    "DeepLabV3",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
    "deeplabv3_mobilenet_v3_large",
]


# Checkpoint URLs, keyed by the ``arch`` string that each builder in this module looks up.
model_urls = {
    "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
    "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
    "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
}
class DeepLabV3(_SimpleSegmentationModel):
......@@ -95,3 +111,131 @@ class ASPP(nn.Module):
_res.append(conv(x))
res = torch.cat(_res, dim=1)
return self.project(res)
def _deeplabv3_resnet(
    backbone: resnet.ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Assemble a DeepLabV3 model on top of an already-constructed ResNet.

    Args:
        backbone (resnet.ResNet): Network from which ``layer4`` (and ``layer3`` when ``aux`` is
            truthy) is extracted.
        num_classes (int): Passed to the classification head(s).
        aux (bool, optional): If truthy, an ``FCNHead`` auxiliary classifier is attached as well.
    """
    # "layer4" always feeds the main head; "layer3" is only tapped for the auxiliary head.
    if aux:
        return_layers = {"layer4": "out", "layer3": "aux"}
    else:
        return_layers = {"layer4": "out"}
    extractor = create_feature_extractor(backbone, return_layers)

    # 1024 / 2048 are presumably the channel widths of layer3 / layer4 -- confirm against the ResNet variant.
    aux_head = None
    if aux:
        aux_head = FCNHead(1024, num_classes)
    head = DeepLabHead(2048, num_classes)
    return DeepLabV3(extractor, head, aux_head)
def _deeplabv3_mobilenetv3(
    backbone: mobilenetv3.MobileNetV3,
    num_classes: int,
    aux: Optional[bool],
) -> DeepLabV3:
    """Assemble a DeepLabV3 model on top of an already-constructed MobileNetV3.

    Args:
        backbone (mobilenetv3.MobileNetV3): Network whose ``features`` sequence is used.
        num_classes (int): Passed to the classification head(s).
        aux (bool, optional): If truthy, an ``FCNHead`` auxiliary classifier is attached as well.
    """
    features = backbone.features

    # Blocks flagged with ``_is_cn`` are the strided ones, i.e. the C1, ..., Cn-1 locations.
    # Index 0 and the last index are always included because they are C0 (conv1) and Cn.
    strided = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices = [0, *strided, len(features) - 1]

    out_pos = stage_indices[-1]  # C5, output_stride = 16
    out_inplanes = features[out_pos].out_channels
    aux_pos = stage_indices[-4]  # C2, output_stride = 8
    aux_inplanes = features[aux_pos].out_channels

    return_layers = {str(out_pos): "out"}
    if aux:
        return_layers[str(aux_pos)] = "aux"
    extractor = create_feature_extractor(features, return_layers)

    aux_head = None
    if aux:
        aux_head = FCNHead(aux_inplanes, num_classes)
    head = DeepLabHead(out_inplanes, num_classes)
    return DeepLabV3(extractor, head, aux_head)
def deeplabv3_resnet50(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Build a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC. Forces ``aux_loss=True`` and
            ``pretrained_backbone=False``.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model (including the background).
        aux_loss (bool, optional): If True, the model uses an auxiliary loss.
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        # Full checkpoint requested: build with the aux head and do not ask for pre-trained backbone weights.
        aux_loss, pretrained_backbone = True, False

    backbone = resnet.resnet50(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True],
    )
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
    if not pretrained:
        return model

    arch = "deeplabv3_resnet50_coco"
    _load_weights(arch, model, model_urls.get(arch), progress)
    return model
def deeplabv3_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Build a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC. Forces ``aux_loss=True`` and
            ``pretrained_backbone=False``.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model.
        aux_loss (bool, optional): If True, the model includes an auxiliary classifier.
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        # Full checkpoint requested: build with the aux head and do not ask for pre-trained backbone weights.
        aux_loss, pretrained_backbone = True, False

    backbone = resnet.resnet101(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True],
    )
    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
    if not pretrained:
        return model

    arch = "deeplabv3_resnet101_coco"
    _load_weights(arch, model, model_urls.get(arch), progress)
    return model
def deeplabv3_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> DeepLabV3:
    """Build a DeepLabV3 model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC. Forces ``aux_loss=True`` and
            ``pretrained_backbone=False``.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model (including the background).
        aux_loss (bool, optional): If True, the model uses an auxiliary loss.
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        # Full checkpoint requested: build with the aux head and do not ask for pre-trained backbone weights.
        aux_loss, pretrained_backbone = True, False

    backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
    model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
    if not pretrained:
        return model

    arch = "deeplabv3_mobilenet_v3_large_coco"
    _load_weights(arch, model, model_urls.get(arch), progress)
    return model
from typing import Optional
from torch import nn
from ._utils import _SimpleSegmentationModel
from .. import resnet
from ..feature_extraction import create_feature_extractor
from ._utils import _SimpleSegmentationModel, _load_weights
# Public API of this module. The builder functions must be listed so that
# ``from .fcn import *`` re-exports them; a trailing ``__all__ = ["FCN"]`` assignment
# used to override this list and hide them.
__all__ = ["FCN", "fcn_resnet50", "fcn_resnet101"]


# Checkpoint URLs, keyed by the ``arch`` string that each builder in this module looks up.
model_urls = {
    "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
    "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
}
class FCN(_SimpleSegmentationModel):
......@@ -35,3 +45,78 @@ class FCNHead(nn.Sequential):
]
super(FCNHead, self).__init__(*layers)
def _fcn_resnet(
    backbone: resnet.ResNet,
    num_classes: int,
    aux: Optional[bool],
) -> FCN:
    """Assemble an FCN model on top of an already-constructed ResNet.

    Args:
        backbone (resnet.ResNet): Network from which ``layer4`` (and ``layer3`` when ``aux`` is
            truthy) is extracted.
        num_classes (int): Passed to the classification head(s).
        aux (bool, optional): If truthy, a second ``FCNHead`` is attached as auxiliary classifier.
    """
    # "layer4" always feeds the main head; "layer3" is only tapped for the auxiliary head.
    if aux:
        return_layers = {"layer4": "out", "layer3": "aux"}
    else:
        return_layers = {"layer4": "out"}
    extractor = create_feature_extractor(backbone, return_layers)

    # 1024 / 2048 are presumably the channel widths of layer3 / layer4 -- confirm against the ResNet variant.
    aux_head = None
    if aux:
        aux_head = FCNHead(1024, num_classes)
    head = FCNHead(2048, num_classes)
    return FCN(extractor, head, aux_head)
def fcn_resnet50(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> FCN:
    """Build a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC. Forces ``aux_loss=True`` and
            ``pretrained_backbone=False``.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model (including the background).
        aux_loss (bool, optional): If True, the model uses an auxiliary loss.
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        # Full checkpoint requested: build with the aux head and do not ask for pre-trained backbone weights.
        aux_loss, pretrained_backbone = True, False

    backbone = resnet.resnet50(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True],
    )
    model = _fcn_resnet(backbone, num_classes, aux_loss)
    if not pretrained:
        return model

    arch = "fcn_resnet50_coco"
    _load_weights(arch, model, model_urls.get(arch), progress)
    return model
def fcn_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    pretrained_backbone: bool = True,
) -> FCN:
    """Build a Fully-Convolutional Network model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC. Forces ``aux_loss=True`` and
            ``pretrained_backbone=False``.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model (including the background).
        aux_loss (bool, optional): If True, the model uses an auxiliary loss.
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
    """
    if pretrained:
        # Full checkpoint requested: build with the aux head and do not ask for pre-trained backbone weights.
        aux_loss, pretrained_backbone = True, False

    backbone = resnet.resnet101(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True],
    )
    model = _fcn_resnet(backbone, num_classes, aux_loss)
    if not pretrained:
        return model

    arch = "fcn_resnet101_coco"
    _load_weights(arch, model, model_urls.get(arch), progress)
    return model
from collections import OrderedDict
from typing import Dict
from typing import Any, Dict
from torch import nn, Tensor
from torch.nn import functional as F
from .. import mobilenetv3
from ..feature_extraction import create_feature_extractor
from ._utils import _load_weights
# Public API of this module: the model class plus its builder function.
# (A redundant earlier ``__all__`` assignment that was immediately overwritten has been dropped.)
__all__ = ["LRASPP", "lraspp_mobilenet_v3_large"]


# Checkpoint URLs, keyed by the ``arch`` string that the builder in this module looks up.
model_urls = {
    "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
}
class LRASPP(nn.Module):
......@@ -68,3 +77,47 @@ class LRASPPHead(nn.Module):
x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)
return self.low_classifier(low) + self.high_classifier(x)
def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP:
    """Assemble an LRASPP model on top of an already-constructed MobileNetV3.

    Args:
        backbone (mobilenetv3.MobileNetV3): Network whose ``features`` sequence is used.
        num_classes (int): Passed to ``LRASPP``.
    """
    features = backbone.features

    # Blocks flagged with ``_is_cn`` are the strided ones, i.e. the C1, ..., Cn-1 locations.
    # Index 0 and the last index are always included because they are C0 (conv1) and Cn.
    strided = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices = [0, *strided, len(features) - 1]

    low_pos = stage_indices[-4]  # C2, output_stride = 8
    high_pos = stage_indices[-1]  # C5, output_stride = 16
    low_channels = features[low_pos].out_channels
    high_channels = features[high_pos].out_channels

    return_layers = {str(low_pos): "low", str(high_pos): "high"}
    extractor = create_feature_extractor(features, return_layers)
    return LRASPP(extractor, low_channels, high_channels, num_classes)
def lraspp_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    pretrained_backbone: bool = True,
    **kwargs: Any,
) -> LRASPP:
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC. Forces ``pretrained_backbone=False``.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model (including the background).
        pretrained_backbone (bool): If True, the backbone will be pre-trained.
        **kwargs: Only ``aux_loss`` is recognised. Any other keyword is ignored with a warning.

    Raises:
        NotImplementedError: If a truthy ``aux_loss`` is passed.
    """
    # ``aux_loss`` is accepted so callers can use the same call shape as the other builders,
    # but this architecture has no auxiliary head.
    if kwargs.pop("aux_loss", False):
        raise NotImplementedError("This model does not use auxiliary loss")
    if kwargs:
        # Leftover keywords were previously dropped silently, which hid misspelled arguments.
        # Warn instead of raising so that existing callers keep working.
        import warnings

        warnings.warn(
            f"lraspp_mobilenet_v3_large() ignores unexpected keyword arguments: {', '.join(sorted(kwargs))}",
            stacklevel=2,
        )

    if pretrained:
        # Full checkpoint requested: do not ask for pre-trained backbone weights.
        pretrained_backbone = False

    backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
    model = _lraspp_mobilenetv3(backbone, num_classes)

    if pretrained:
        arch = "lraspp_mobilenet_v3_large_coco"
        _load_weights(arch, model, model_urls.get(arch, None), progress)
    return model
from typing import Any, Optional
import warnings
from torch import nn
# Import all methods/classes for BC:
from . import * # noqa: F401, F403
from ..._internally_replaced_utils import load_state_dict_from_url
from .. import mobilenetv3
from .. import resnet
from ..feature_extraction import create_feature_extractor
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead
from .lraspp import LRASPP
# Names exported by this module: the six model builder functions defined below.
__all__ = [
    "fcn_resnet50",
    "fcn_resnet101",
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
    "deeplabv3_mobilenet_v3_large",
    "lraspp_mobilenet_v3_large",
]
# Checkpoint URLs. Keys follow the "<arch_type>_<backbone>_coco" pattern that
# ``_load_weights`` in this module builds before looking a URL up.
model_urls = {
    "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
    "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
    "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
    "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
    "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
    "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
}
def _segm_model(
    name: str, backbone_name: str, num_classes: int, aux: Optional[bool], pretrained_backbone: bool = True
) -> nn.Module:
    """Build an FCN or DeepLabV3 model from an architecture name and a backbone name.

    Args:
        name (str): Key into the head/model table; ``"deeplabv3"`` or ``"fcn"``.
        backbone_name (str): Name of a constructor in ``resnet`` or ``mobilenetv3``. It must contain
            ``"resnet"`` or ``"mobilenet_v3"``.
        num_classes (int): Passed to the classification head(s).
        aux (bool, optional): If truthy, an ``FCNHead`` auxiliary classifier is attached.
        pretrained_backbone (bool): Forwarded to the backbone constructor as ``pretrained``.

    Raises:
        NotImplementedError: If ``backbone_name`` matches neither supported family.
    """
    is_resnet = "resnet" in backbone_name
    if not is_resnet and "mobilenet_v3" not in backbone_name:
        raise NotImplementedError("backbone {} is not supported as of now".format(backbone_name))

    if is_resnet:
        backbone = resnet.__dict__[backbone_name](
            pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]
        )
        out_layer, out_inplanes = "layer4", 2048
        aux_layer, aux_inplanes = "layer3", 1024
    else:
        backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features

        # Blocks flagged with ``_is_cn`` are the strided ones, i.e. the C1, ..., Cn-1 locations.
        # Index 0 and the last index are always included because they are C0 (conv1) and Cn.
        strided = [idx for idx, block in enumerate(backbone) if getattr(block, "_is_cn", False)]
        stage_indices = [0, *strided, len(backbone) - 1]

        out_pos = stage_indices[-1]  # C5, output_stride = 16
        out_layer = str(out_pos)
        out_inplanes = backbone[out_pos].out_channels
        aux_pos = stage_indices[-4]  # C2, output_stride = 8
        aux_layer = str(aux_pos)
        aux_inplanes = backbone[aux_pos].out_channels

    return_layers = {out_layer: "out"}
    if aux:
        return_layers[aux_layer] = "aux"
    backbone = create_feature_extractor(backbone, return_layers)

    aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None

    # Architecture name -> (classification head class, model class).
    model_map = {
        "deeplabv3": (DeepLabHead, DeepLabV3),
        "fcn": (FCNHead, FCN),
    }
    head_cls, model_cls = model_map[name]
    classifier = head_cls(out_inplanes, num_classes)
    return model_cls(backbone, classifier, aux_classifier)
def _load_model(
    arch_type: str,
    backbone: str,
    pretrained: bool,
    progress: bool,
    num_classes: int,
    aux_loss: Optional[bool],
    **kwargs: Any,
) -> nn.Module:
    """Build a segmentation model via ``_segm_model`` and optionally load its checkpoint.

    Args:
        arch_type (str): Architecture name handed to ``_segm_model``.
        backbone (str): Backbone name handed to ``_segm_model``.
        pretrained (bool): If True, forces ``aux_loss=True`` and ``pretrained_backbone=False``, then
            loads the checkpoint with ``_load_weights``.
        progress (bool): Forwarded to ``_load_weights``.
        num_classes (int): Forwarded to ``_segm_model``.
        aux_loss (bool, optional): Forwarded to ``_segm_model`` unless overridden by ``pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``_segm_model``.
    """
    if not pretrained:
        return _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)

    # Full checkpoint requested: build with the aux head and do not ask for pre-trained backbone weights.
    kwargs.update(pretrained_backbone=False)
    model = _segm_model(arch_type, backbone, num_classes, True, **kwargs)
    _load_weights(model, arch_type, backbone, progress)
    return model
def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None:
    """Look up the checkpoint for ``<arch_type>_<backbone>_coco`` and load it into ``model``.

    Args:
        model (nn.Module): Module whose ``load_state_dict`` receives the downloaded state dict.
        arch_type (str): First component of the ``model_urls`` key.
        backbone (str): Second component of the ``model_urls`` key.
        progress (bool): Forwarded to ``load_state_dict_from_url`` as its ``progress`` argument.

    Raises:
        NotImplementedError: If ``model_urls`` has no entry for the combined key.
    """
    arch = "_".join((arch_type, backbone, "coco"))
    model_url = model_urls.get(arch)
    if model_url is None:
        raise NotImplementedError("pretrained {} is not supported as of now".format(arch))
    model.load_state_dict(load_state_dict_from_url(model_url, progress=progress))
def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP:
    """Build an LRASPP model on a MobileNetV3 backbone selected by name.

    Args:
        backbone_name (str): Name of a constructor looked up in ``mobilenetv3.__dict__``.
        num_classes (int): Passed to ``LRASPP``.
        pretrained_backbone (bool): Forwarded to the backbone constructor as ``pretrained``.
    """
    build_backbone = mobilenetv3.__dict__[backbone_name]
    features = build_backbone(pretrained=pretrained_backbone, dilated=True).features

    # Blocks flagged with ``_is_cn`` are the strided ones, i.e. the C1, ..., Cn-1 locations.
    # Index 0 and the last index are always included because they are C0 (conv1) and Cn.
    strided = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices = [0, *strided, len(features) - 1]

    low_pos = stage_indices[-4]  # C2, output_stride = 8
    high_pos = stage_indices[-1]  # C5, output_stride = 16
    low_channels = features[low_pos].out_channels
    high_channels = features[high_pos].out_channels

    return_layers = {str(low_pos): "low", str(high_pos): "high"}
    extractor = create_feature_extractor(features, return_layers)
    return LRASPP(extractor, low_channels, high_channels, num_classes)
def fcn_resnet50(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Build a Fully-Convolutional Network model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model (including the background).
        aux_loss (bool): If True, the model uses an auxiliary loss.
        **kwargs: Forwarded to ``_load_model``.
    """
    return _load_model(
        arch_type="fcn",
        backbone="resnet50",
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs,
    )
def fcn_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Build a Fully-Convolutional Network model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model (including the background).
        aux_loss (bool): If True, the model uses an auxiliary loss.
        **kwargs: Forwarded to ``_load_model``.
    """
    return _load_model(
        arch_type="fcn",
        backbone="resnet101",
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs,
    )
def deeplabv3_resnet50(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Build a DeepLabV3 model with a ResNet-50 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model (including the background).
        aux_loss (bool): If True, the model uses an auxiliary loss.
        **kwargs: Forwarded to ``_load_model``.
    """
    return _load_model(
        arch_type="deeplabv3",
        backbone="resnet50",
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs,
    )
def deeplabv3_resnet101(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Build a DeepLabV3 model with a ResNet-101 backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model.
        aux_loss (bool): If True, the model includes an auxiliary classifier.
        **kwargs: Forwarded to ``_load_model``.
    """
    return _load_model(
        arch_type="deeplabv3",
        backbone="resnet101",
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs,
    )
def deeplabv3_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 21,
    aux_loss: Optional[bool] = None,
    **kwargs: Any,
) -> nn.Module:
    """Build a DeepLabV3 model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model (including the background).
        aux_loss (bool): If True, the model uses an auxiliary loss.
        **kwargs: Forwarded to ``_load_model``.
    """
    return _load_model(
        arch_type="deeplabv3",
        backbone="mobilenet_v3_large",
        pretrained=pretrained,
        progress=progress,
        num_classes=num_classes,
        aux_loss=aux_loss,
        **kwargs,
    )
def lraspp_mobilenet_v3_large(
    pretrained: bool = False, progress: bool = True, num_classes: int = 21, **kwargs: Any
) -> nn.Module:
    """Build a Lite R-ASPP Network model with a MobileNetV3-Large backbone.

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC. Forces ``pretrained_backbone=False``.
        progress (bool): If True, displays a progress bar of the download to stderr.
        num_classes (int): Number of output classes of the model (including the background).
        **kwargs: ``aux_loss`` is rejected when truthy; the remaining keywords are forwarded to
            ``_segm_lraspp_mobilenetv3``.

    Raises:
        NotImplementedError: If a truthy ``aux_loss`` is passed.
    """
    wants_aux = kwargs.pop("aux_loss", False)
    if wants_aux:
        raise NotImplementedError("This model does not use auxiliary loss")

    backbone_name = "mobilenet_v3_large"
    if not pretrained:
        return _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)

    # Full checkpoint requested: do not ask for pre-trained backbone weights.
    kwargs.update(pretrained_backbone=False)
    model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)
    _load_weights(model, "lraspp", backbone_name, progress)
    return model
# Emitted when this module is executed (i.e. on import): it is only kept for backward
# compatibility. No category is passed, so the warning is a UserWarning.
warnings.warn(
    "The 'torchvision.models.segmentation.segmentation' module is deprecated. "
    "Please use directly the parent module instead."
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment