Unverified Commit 185be3a9 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added typing annotations to models/segmentation (#4227)

* style: Added typing annotations to segmentation/_utils

* style: Added typing annotations to segmentation/segmentation

* style: Added typing annotations to remaining segmentation models

* style: Fixed typing of DeepLab

* style: Fixed typing

* fix: Fixed typing annotations & default values

* Fixing python_type_check
parent 7947fc8f
from collections import OrderedDict
from typing import Optional, Dict
from torch import nn
from torch import nn, Tensor
from torch.nn import functional as F
class _SimpleSegmentationModel(nn.Module):
__constants__ = ['aux_classifier']
def __init__(self, backbone, classifier, aux_classifier=None):
def __init__(
self,
backbone: nn.Module,
classifier: nn.Module,
aux_classifier: Optional[nn.Module] = None
) -> None:
super(_SimpleSegmentationModel, self).__init__()
self.backbone = backbone
self.classifier = classifier
self.aux_classifier = aux_classifier
def forward(self, x):
def forward(self, x: Tensor) -> Dict[str, Tensor]:
input_shape = x.shape[-2:]
# contract: features is a dict of tensors
features = self.backbone(x)
......
import torch
from torch import nn
from torch.nn import functional as F
from typing import List
from ._utils import _SimpleSegmentationModel
......@@ -27,7 +28,7 @@ class DeepLabV3(_SimpleSegmentationModel):
class DeepLabHead(nn.Sequential):
def __init__(self, in_channels, num_classes):
def __init__(self, in_channels: int, num_classes: int) -> None:
super(DeepLabHead, self).__init__(
ASPP(in_channels, [12, 24, 36]),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
......@@ -38,7 +39,7 @@ class DeepLabHead(nn.Sequential):
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:
modules = [
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels),
......@@ -48,14 +49,14 @@ class ASPPConv(nn.Sequential):
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels: int, out_channels: int) -> None:
super(ASPPPooling, self).__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU())
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
size = x.shape[-2:]
for mod in self:
x = mod(x)
......@@ -63,7 +64,7 @@ class ASPPPooling(nn.Sequential):
class ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates, out_channels=256):
def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:
super(ASPP, self).__init__()
modules = []
modules.append(nn.Sequential(
......@@ -85,9 +86,9 @@ class ASPP(nn.Module):
nn.ReLU(),
nn.Dropout(0.5))
def forward(self, x):
res = []
def forward(self, x: torch.Tensor) -> torch.Tensor:
_res = []
for conv in self.convs:
res.append(conv(x))
res = torch.cat(res, dim=1)
_res.append(conv(x))
res = torch.cat(_res, dim=1)
return self.project(res)
......@@ -23,7 +23,7 @@ class FCN(_SimpleSegmentationModel):
class FCNHead(nn.Sequential):
def __init__(self, in_channels, channels):
def __init__(self, in_channels: int, channels: int) -> None:
inter_channels = in_channels // 4
layers = [
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
......
......@@ -24,12 +24,19 @@ class LRASPP(nn.Module):
inter_channels (int, optional): the number of channels for intermediate computations.
"""
def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128):
def __init__(
self,
backbone: nn.Module,
low_channels: int,
high_channels: int,
num_classes: int,
inter_channels: int = 128
) -> None:
super().__init__()
self.backbone = backbone
self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)
def forward(self, input):
def forward(self, input: Tensor) -> Dict[str, Tensor]:
features = self.backbone(input)
out = self.classifier(features)
out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)
......@@ -42,7 +49,13 @@ class LRASPP(nn.Module):
class LRASPPHead(nn.Module):
def __init__(self, low_channels, high_channels, num_classes, inter_channels):
def __init__(
self,
low_channels: int,
high_channels: int,
num_classes: int,
inter_channels: int
) -> None:
super().__init__()
self.cbr = nn.Sequential(
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
......
from torch import nn
from typing import Any, Optional
from .._utils import IntermediateLayerGetter
from ..._internally_replaced_utils import load_state_dict_from_url
from .. import mobilenetv3
......@@ -22,7 +24,13 @@ model_urls = {
}
def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True):
def _segm_model(
name: str,
backbone_name: str,
num_classes: int,
aux: Optional[bool],
pretrained_backbone: bool = True
) -> nn.Module:
if 'resnet' in backbone_name:
backbone = resnet.__dict__[backbone_name](
pretrained=pretrained_backbone,
......@@ -66,7 +74,15 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
return model
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
def _load_model(
arch_type: str,
backbone: str,
pretrained: bool,
progress: bool,
num_classes: int,
aux_loss: Optional[bool],
**kwargs: Any
) -> nn.Module:
if pretrained:
aux_loss = True
kwargs["pretrained_backbone"] = False
......@@ -76,7 +92,7 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss
return model
def _load_weights(model, arch_type, backbone, progress):
def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None:
arch = arch_type + '_' + backbone + '_coco'
model_url = model_urls.get(arch, None)
if model_url is None:
......@@ -86,7 +102,7 @@ def _load_weights(model, arch_type, backbone, progress):
model.load_state_dict(state_dict)
def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_backbone: bool = True) -> LRASPP:
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
......@@ -103,8 +119,13 @@ def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=Tru
return model
def fcn_resnet50(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
def fcn_resnet50(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
Args:
......@@ -117,8 +138,13 @@ def fcn_resnet50(pretrained=False, progress=True,
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
def fcn_resnet101(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
def fcn_resnet101(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a Fully-Convolutional Network model with a ResNet-101 backbone.
Args:
......@@ -131,8 +157,13 @@ def fcn_resnet101(pretrained=False, progress=True,
return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_resnet50(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
def deeplabv3_resnet50(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
Args:
......@@ -145,8 +176,13 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_resnet101(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
def deeplabv3_resnet101(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a DeepLabV3 model with a ResNet-101 backbone.
Args:
......@@ -159,8 +195,13 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
def deeplabv3_mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
aux_loss: Optional[bool] = None,
**kwargs: Any
) -> nn.Module:
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
Args:
......@@ -173,7 +214,12 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs):
def lraspp_mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 21,
**kwargs: Any
) -> nn.Module:
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment