"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1328aeb274610f492c10a246ffba0bc4de8f689b"
Unverified Commit a1d6b314 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Refactor Segmentation models (#4646)

* Move FCN methods to itsown package.

* Fix lint.

* Move LRASPP methods to their own package.

* Move DeepLabV3 methods to their own package.

* Adding deprecation warning for torchvision.models.segmentation.segmentation.

* Refactoring deeplab.

* Setting aux default to false.

* Fixing imports.

* Passing backbones instead of backbone names to builders.

* Fixing mypy

* Addressing review comments.

* Correcting typing.

* Restoring special handling for references.
parent e4a4a29a
from .segmentation import *
from .fcn import * from .fcn import *
from .deeplabv3 import * from .deeplabv3 import *
from .lraspp import * from .lraspp import *
...@@ -4,6 +4,8 @@ from typing import Optional, Dict ...@@ -4,6 +4,8 @@ from typing import Optional, Dict
from torch import nn, Tensor from torch import nn, Tensor
from torch.nn import functional as F from torch.nn import functional as F
from ..._internally_replaced_utils import load_state_dict_from_url
class _SimpleSegmentationModel(nn.Module): class _SimpleSegmentationModel(nn.Module):
__constants__ = ["aux_classifier"] __constants__ = ["aux_classifier"]
...@@ -32,3 +34,10 @@ class _SimpleSegmentationModel(nn.Module): ...@@ -32,3 +34,10 @@ class _SimpleSegmentationModel(nn.Module):
result["aux"] = x result["aux"] = x
return result return result
def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None:
if model_url is None:
raise ValueError("No checkpoint is available for {}".format(arch))
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
from typing import List from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from ._utils import _SimpleSegmentationModel from .. import mobilenetv3
from .. import resnet
from ..feature_extraction import create_feature_extractor
from ._utils import _SimpleSegmentationModel, _load_weights
from .fcn import FCNHead
__all__ = ["DeepLabV3"] __all__ = [
"DeepLabV3",
"deeplabv3_resnet50",
"deeplabv3_resnet101",
"deeplabv3_mobilenet_v3_large",
]
model_urls = {
"deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
"deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
"deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
}
class DeepLabV3(_SimpleSegmentationModel): class DeepLabV3(_SimpleSegmentationModel):
...@@ -95,3 +111,131 @@ class ASPP(nn.Module): ...@@ -95,3 +111,131 @@ class ASPP(nn.Module):
_res.append(conv(x)) _res.append(conv(x))
res = torch.cat(_res, dim=1) res = torch.cat(_res, dim=1)
return self.project(res) return self.project(res)
def _deeplabv3_resnet(
backbone: resnet.ResNet,
num_classes: int,
aux: Optional[bool],
) -> DeepLabV3:
return_layers = {"layer4": "out"}
if aux:
return_layers["layer3"] = "aux"
backbone = create_feature_extractor(backbone, return_layers)
aux_classifier = FCNHead(1024, num_classes) if aux else None
classifier = DeepLabHead(2048, num_classes)
return DeepLabV3(backbone, classifier, aux_classifier)
def _deeplabv3_mobilenetv3(
backbone: mobilenetv3.MobileNetV3,
num_classes: int,
aux: Optional[bool],
) -> DeepLabV3:
backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
out_pos = stage_indices[-1] # use C5 which has output_stride = 16
out_inplanes = backbone[out_pos].out_channels
aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
aux_inplanes = backbone[aux_pos].out_channels
return_layers = {str(out_pos): "out"}
if aux:
return_layers[str(aux_pos)] = "aux"
backbone = create_feature_extractor(backbone, return_layers)
aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
classifier = DeepLabHead(out_inplanes, num_classes)
return DeepLabV3(backbone, classifier, aux_classifier)
def deeplabv3_resnet50(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
pretrained_backbone: bool = True,
) -> DeepLabV3:
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool, optional): If True, it uses an auxiliary loss
pretrained_backbone (bool): If True, the backbone will be pre-trained.
"""
if pretrained:
aux_loss = True
pretrained_backbone = False
backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if pretrained:
arch = "deeplabv3_resnet50_coco"
_load_weights(arch, model, model_urls.get(arch, None), progress)
return model
def deeplabv3_resnet101(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
pretrained_backbone: bool = True,
) -> DeepLabV3:
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): The number of classes
aux_loss (bool, optional): If True, include an auxiliary classifier
pretrained_backbone (bool): If True, the backbone will be pre-trained.
"""
if pretrained:
aux_loss = True
pretrained_backbone = False
backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if pretrained:
arch = "deeplabv3_resnet101_coco"
_load_weights(arch, model, model_urls.get(arch, None), progress)
return model
def deeplabv3_mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
pretrained_backbone: bool = True,
) -> DeepLabV3:
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool, optional): If True, it uses an auxiliary loss
pretrained_backbone (bool): If True, the backbone will be pre-trained.
"""
if pretrained:
aux_loss = True
pretrained_backbone = False
backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
if pretrained:
arch = "deeplabv3_mobilenet_v3_large_coco"
_load_weights(arch, model, model_urls.get(arch, None), progress)
return model
from typing import Optional
from torch import nn from torch import nn
from ._utils import _SimpleSegmentationModel from .. import resnet
from ..feature_extraction import create_feature_extractor
from ._utils import _SimpleSegmentationModel, _load_weights
__all__ = ["FCN", "fcn_resnet50", "fcn_resnet101"]
__all__ = ["FCN"] model_urls = {
"fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
"fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
}
class FCN(_SimpleSegmentationModel): class FCN(_SimpleSegmentationModel):
...@@ -35,3 +45,78 @@ class FCNHead(nn.Sequential): ...@@ -35,3 +45,78 @@ class FCNHead(nn.Sequential):
] ]
super(FCNHead, self).__init__(*layers) super(FCNHead, self).__init__(*layers)
def _fcn_resnet(
backbone: resnet.ResNet,
num_classes: int,
aux: Optional[bool],
) -> FCN:
return_layers = {"layer4": "out"}
if aux:
return_layers["layer3"] = "aux"
backbone = create_feature_extractor(backbone, return_layers)
aux_classifier = FCNHead(1024, num_classes) if aux else None
classifier = FCNHead(2048, num_classes)
return FCN(backbone, classifier, aux_classifier)
def fcn_resnet50(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
pretrained_backbone: bool = True,
) -> FCN:
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool, optional): If True, it uses an auxiliary loss
pretrained_backbone (bool): If True, the backbone will be pre-trained.
"""
if pretrained:
aux_loss = True
pretrained_backbone = False
backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
model = _fcn_resnet(backbone, num_classes, aux_loss)
if pretrained:
arch = "fcn_resnet50_coco"
_load_weights(arch, model, model_urls.get(arch, None), progress)
return model
def fcn_resnet101(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
pretrained_backbone: bool = True,
) -> FCN:
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool, optional): If True, it uses an auxiliary loss
pretrained_backbone (bool): If True, the backbone will be pre-trained.
"""
if pretrained:
aux_loss = True
pretrained_backbone = False
backbone = resnet.resnet101(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
model = _fcn_resnet(backbone, num_classes, aux_loss)
if pretrained:
arch = "fcn_resnet101_coco"
_load_weights(arch, model, model_urls.get(arch, None), progress)
return model
from collections import OrderedDict from collections import OrderedDict
from typing import Dict from typing import Any, Dict
from torch import nn, Tensor from torch import nn, Tensor
from torch.nn import functional as F from torch.nn import functional as F
from .. import mobilenetv3
from ..feature_extraction import create_feature_extractor
from ._utils import _load_weights
__all__ = ["LRASPP"]
__all__ = ["LRASPP", "lraspp_mobilenet_v3_large"]
model_urls = {
"lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
}
class LRASPP(nn.Module): class LRASPP(nn.Module):
...@@ -68,3 +77,47 @@ class LRASPPHead(nn.Module): ...@@ -68,3 +77,47 @@ class LRASPPHead(nn.Module):
x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False) x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)
return self.low_classifier(low) + self.high_classifier(x) return self.low_classifier(low) + self.high_classifier(x)
def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) -> LRASPP:
backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
low_pos = stage_indices[-4] # use C2 here which has output_stride = 8
high_pos = stage_indices[-1] # use C5 which has output_stride = 16
low_channels = backbone[low_pos].out_channels
high_channels = backbone[high_pos].out_channels
backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"})
return LRASPP(backbone, low_channels, high_channels, num_classes)
def lraspp_mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
pretrained_backbone: bool = True,
**kwargs: Any,
) -> LRASPP:
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, the backbone will be pre-trained.
"""
if kwargs.pop("aux_loss", False):
raise NotImplementedError("This model does not use auxiliary loss")
if pretrained:
pretrained_backbone = False
backbone = mobilenetv3.mobilenet_v3_large(pretrained=pretrained_backbone, dilated=True)
model = _lraspp_mobilenetv3(backbone, num_classes)
if pretrained:
arch = "lraspp_mobilenet_v3_large_coco"
_load_weights(arch, model, model_urls.get(arch, None), progress)
return model
from typing import Any, Optional import warnings
from torch import nn # Import all methods/classes for BC:
from . import * # noqa: F401, F403
from ..._internally_replaced_utils import load_state_dict_from_url
from .. import mobilenetv3
from .. import resnet
from ..feature_extraction import create_feature_extractor
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead
from .lraspp import LRASPP
warnings.warn(
__all__ = [ "The 'torchvision.models.segmentation.segmentation' module is deprecated. Please use directly the parent module "
"fcn_resnet50", "instead."
"fcn_resnet101", )
"deeplabv3_resnet50",
"deeplabv3_resnet101",
"deeplabv3_mobilenet_v3_large",
"lraspp_mobilenet_v3_large",
]
model_urls = {
"fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
"fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
"deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
"deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
"deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
"lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
}
def _segm_model(
name: str, backbone_name: str, num_classes: int, aux: Optional[bool], pretrained_backbone: bool = True
) -> nn.Module:
if "resnet" in backbone_name:
backbone = resnet.__dict__[backbone_name](
pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]
)
out_layer = "layer4"
out_inplanes = 2048
aux_layer = "layer3"
aux_inplanes = 1024
elif "mobilenet_v3" in backbone_name:
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
out_pos = stage_indices[-1] # use C5 which has output_stride = 16
out_layer = str(out_pos)
out_inplanes = backbone[out_pos].out_channels
aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
aux_layer = str(aux_pos)
aux_inplanes = backbone[aux_pos].out_channels
else:
raise NotImplementedError("backbone {} is not supported as of now".format(backbone_name))
return_layers = {out_layer: "out"}
if aux:
return_layers[aux_layer] = "aux"
backbone = create_feature_extractor(backbone, return_layers)
aux_classifier = None
if aux:
aux_classifier = FCNHead(aux_inplanes, num_classes)
model_map = {
"deeplabv3": (DeepLabHead, DeepLabV3),
"fcn": (FCNHead, FCN),
}
classifier = model_map[name][0](out_inplanes, num_classes)
base_model = model_map[name][1]
model = base_model(backbone, classifier, aux_classifier)
return model
def _load_model(
arch_type: str,
backbone: str,
pretrained: bool,
progress: bool,
num_classes: int,
aux_loss: Optional[bool],
**kwargs: Any,
) -> nn.Module:
if pretrained:
aux_loss = True
kwargs["pretrained_backbone"] = False
model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
if pretrained:
_load_weights(model, arch_type, backbone, progress)
return model
def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None:
arch = arch_type + "_" + backbone + "_coco"
model_url = model_urls.get(arch, None)
if model_url is None:
raise NotImplementedError("pretrained {} is not supported as of now".format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP:
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
low_pos = stage_indices[-4] # use C2 here which has output_stride = 8
high_pos = stage_indices[-1] # use C5 which has output_stride = 16
low_channels = backbone[low_pos].out_channels
high_channels = backbone[high_pos].out_channels
backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"})
model = LRASPP(backbone, low_channels, high_channels, num_classes)
return model
def fcn_resnet50(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any,
) -> nn.Module:
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model("fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs)
def fcn_resnet101(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any,
) -> nn.Module:
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model("fcn", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_resnet50(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any,
) -> nn.Module:
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model("deeplabv3", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_resnet101(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any,
) -> nn.Module:
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): The number of classes
aux_loss (bool): If True, include an auxiliary classifier
"""
return _load_model("deeplabv3", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any,
) -> nn.Module:
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model("deeplabv3", "mobilenet_v3_large", pretrained, progress, num_classes, aux_loss, **kwargs)
def lraspp_mobilenet_v3_large(
pretrained: bool = False, progress: bool = True, num_classes: int = 21, **kwargs: Any
) -> nn.Module:
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
"""
if kwargs.pop("aux_loss", False):
raise NotImplementedError("This model does not use auxiliary loss")
backbone_name = "mobilenet_v3_large"
if pretrained:
kwargs["pretrained_backbone"] = False
model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)
if pretrained:
_load_weights(model, "lraspp", backbone_name, progress)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment