"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b64c5227595f5eed10f6ff3ac7953de0bb07ab2d"
Unverified Commit 185be3a9 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added typing annotations to models/segmentation (#4227)

* style: Added typing annotations to segmentation/_utils

* style: Added typing annotations to segmentation/segmentation

* style: Added typing annotations to remaining segmentation models

* style: Fixed typing of DeepLab

* style: Fixed typing

* fix: Fixed typing annotations & default values

* Fixing python_type_check
parent 7947fc8f
from collections import OrderedDict from collections import OrderedDict
from typing import Optional, Dict
from torch import nn from torch import nn, Tensor
from torch.nn import functional as F from torch.nn import functional as F
class _SimpleSegmentationModel(nn.Module): class _SimpleSegmentationModel(nn.Module):
__constants__ = ['aux_classifier'] __constants__ = ['aux_classifier']
def __init__(self, backbone, classifier, aux_classifier=None): def __init__(
self,
backbone: nn.Module,
classifier: nn.Module,
aux_classifier: Optional[nn.Module] = None
) -> None:
super(_SimpleSegmentationModel, self).__init__() super(_SimpleSegmentationModel, self).__init__()
self.backbone = backbone self.backbone = backbone
self.classifier = classifier self.classifier = classifier
self.aux_classifier = aux_classifier self.aux_classifier = aux_classifier
def forward(self, x): def forward(self, x: Tensor) -> Dict[str, Tensor]:
input_shape = x.shape[-2:] input_shape = x.shape[-2:]
# contract: features is a dict of tensors # contract: features is a dict of tensors
features = self.backbone(x) features = self.backbone(x)
......
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from typing import List
from ._utils import _SimpleSegmentationModel from ._utils import _SimpleSegmentationModel
...@@ -27,7 +28,7 @@ class DeepLabV3(_SimpleSegmentationModel): ...@@ -27,7 +28,7 @@ class DeepLabV3(_SimpleSegmentationModel):
class DeepLabHead(nn.Sequential): class DeepLabHead(nn.Sequential):
def __init__(self, in_channels, num_classes): def __init__(self, in_channels: int, num_classes: int) -> None:
super(DeepLabHead, self).__init__( super(DeepLabHead, self).__init__(
ASPP(in_channels, [12, 24, 36]), ASPP(in_channels, [12, 24, 36]),
nn.Conv2d(256, 256, 3, padding=1, bias=False), nn.Conv2d(256, 256, 3, padding=1, bias=False),
...@@ -38,7 +39,7 @@ class DeepLabHead(nn.Sequential): ...@@ -38,7 +39,7 @@ class DeepLabHead(nn.Sequential):
class ASPPConv(nn.Sequential): class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation): def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
modules = [ modules = [
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels), nn.BatchNorm2d(out_channels),
...@@ -48,14 +49,14 @@ class ASPPConv(nn.Sequential): ...@@ -48,14 +49,14 @@ class ASPPConv(nn.Sequential):
class ASPPPooling(nn.Sequential): class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels): def __init__(self, in_channels: int, out_channels: int) -> None:
super(ASPPPooling, self).__init__( super(ASPPPooling, self).__init__(
nn.AdaptiveAvgPool2d(1), nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels), nn.BatchNorm2d(out_channels),
nn.ReLU()) nn.ReLU())
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
size = x.shape[-2:] size = x.shape[-2:]
for mod in self: for mod in self:
x = mod(x) x = mod(x)
...@@ -63,7 +64,7 @@ class ASPPPooling(nn.Sequential): ...@@ -63,7 +64,7 @@ class ASPPPooling(nn.Sequential):
class ASPP(nn.Module): class ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates, out_channels=256): def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
super(ASPP, self).__init__() super(ASPP, self).__init__()
modules = [] modules = []
modules.append(nn.Sequential( modules.append(nn.Sequential(
...@@ -85,9 +86,9 @@ class ASPP(nn.Module): ...@@ -85,9 +86,9 @@ class ASPP(nn.Module):
nn.ReLU(), nn.ReLU(),
nn.Dropout(0.5)) nn.Dropout(0.5))
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
res = [] _res = []
for conv in self.convs: for conv in self.convs:
res.append(conv(x)) _res.append(conv(x))
res = torch.cat(res, dim=1) res = torch.cat(_res, dim=1)
return self.project(res) return self.project(res)
...@@ -23,7 +23,7 @@ class FCN(_SimpleSegmentationModel): ...@@ -23,7 +23,7 @@ class FCN(_SimpleSegmentationModel):
class FCNHead(nn.Sequential): class FCNHead(nn.Sequential):
def __init__(self, in_channels, channels): def __init__(self, in_channels: int, channels: int) -> None:
inter_channels = in_channels // 4 inter_channels = in_channels // 4
layers = [ layers = [
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
......
...@@ -24,12 +24,19 @@ class LRASPP(nn.Module): ...@@ -24,12 +24,19 @@ class LRASPP(nn.Module):
inter_channels (int, optional): the number of channels for intermediate computations. inter_channels (int, optional): the number of channels for intermediate computations.
""" """
def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128): def __init__(
self,
backbone: nn.Module,
low_channels: int,
high_channels: int,
num_classes: int,
inter_channels: int = 128
) -> None:
super().__init__() super().__init__()
self.backbone = backbone self.backbone = backbone
self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels) self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)
def forward(self, input): def forward(self, input: Tensor) -> Dict[str, Tensor]:
features = self.backbone(input) features = self.backbone(input)
out = self.classifier(features) out = self.classifier(features)
out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False) out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)
...@@ -42,7 +49,13 @@ class LRASPP(nn.Module): ...@@ -42,7 +49,13 @@ class LRASPP(nn.Module):
class LRASPPHead(nn.Module): class LRASPPHead(nn.Module):
def __init__(self, low_channels, high_channels, num_classes, inter_channels): def __init__(
self,
low_channels: int,
high_channels: int,
num_classes: int,
inter_channels: int
) -> None:
super().__init__() super().__init__()
self.cbr = nn.Sequential( self.cbr = nn.Sequential(
nn.Conv2d(high_channels, inter_channels, 1, bias=False), nn.Conv2d(high_channels, inter_channels, 1, bias=False),
......
from torch import nn
from typing import Any, Optional
from .._utils import IntermediateLayerGetter from .._utils import IntermediateLayerGetter
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
from .. import mobilenetv3 from .. import mobilenetv3
...@@ -22,7 +24,13 @@ model_urls = { ...@@ -22,7 +24,13 @@ model_urls = {
} }
def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True): def _segm_model(
name: str,
backbone_name: str,
num_classes: int,
aux: Optional[bool],
pretrained_backbone: bool = True
) -> nn.Module:
if 'resnet' in backbone_name: if 'resnet' in backbone_name:
backbone = resnet.__dict__[backbone_name]( backbone = resnet.__dict__[backbone_name](
pretrained=pretrained_backbone, pretrained=pretrained_backbone,
...@@ -66,7 +74,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True) ...@@ -66,7 +74,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
return model return model
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs): def _load_model(
arch_type: str,
backbone: str,
pretrained: bool,
progress: bool,
num_classes: int,
aux_loss: Optional[bool],
**kwargs: Any
) -> nn.Module:
if pretrained: if pretrained:
aux_loss = True aux_loss = True
kwargs["pretrained_backbone"] = False kwargs["pretrained_backbone"] = False
...@@ -76,7 +92,7 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss ...@@ -76,7 +92,7 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss
return model return model
def _load_weights(model, arch_type, backbone, progress): def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None:
arch = arch_type + '_' + backbone + '_coco' arch = arch_type + '_' + backbone + '_coco'
model_url = model_urls.get(arch, None) model_url = model_urls.get(arch, None)
if model_url is None: if model_url is None:
...@@ -86,7 +102,7 @@ def _load_weights(model, arch_type, backbone, progress): ...@@ -86,7 +102,7 @@ def _load_weights(model, arch_type, backbone, progress):
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True): def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP:
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
...@@ -103,8 +119,13 @@ def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=Tru ...@@ -103,8 +119,13 @@ def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=Tru
return model return model
def fcn_resnet50(pretrained=False, progress=True, def fcn_resnet50(
num_classes=21, aux_loss=None, **kwargs): pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args: Args:
...@@ -117,8 +138,13 @@ def fcn_resnet50(pretrained=False, progress=True, ...@@ -117,8 +138,13 @@ def fcn_resnet50(pretrained=False, progress=True,
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
def fcn_resnet101(pretrained=False, progress=True, def fcn_resnet101(
num_classes=21, aux_loss=None, **kwargs): pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
Args: Args:
...@@ -131,8 +157,13 @@ def fcn_resnet101(pretrained=False, progress=True, ...@@ -131,8 +157,13 @@ def fcn_resnet101(pretrained=False, progress=True,
return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_resnet50(pretrained=False, progress=True, def deeplabv3_resnet50(
num_classes=21, aux_loss=None, **kwargs): pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a DeepLabV3 model with a ResNet-50 backbone. """Constructs a DeepLabV3 model with a ResNet-50 backbone.
Args: Args:
...@@ -145,8 +176,13 @@ def deeplabv3_resnet50(pretrained=False, progress=True, ...@@ -145,8 +176,13 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_resnet101(pretrained=False, progress=True, def deeplabv3_resnet101(
num_classes=21, aux_loss=None, **kwargs): pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a DeepLabV3 model with a ResNet-101 backbone. """Constructs a DeepLabV3 model with a ResNet-101 backbone.
Args: Args:
...@@ -159,8 +195,13 @@ def deeplabv3_resnet101(pretrained=False, progress=True, ...@@ -159,8 +195,13 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True, def deeplabv3_mobilenet_v3_large(
num_classes=21, aux_loss=None, **kwargs): pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
Args: Args:
...@@ -173,7 +214,12 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True, ...@@ -173,7 +214,12 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs): def lraspp_mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
**kwargs: Any
) -> nn.Module:
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment