You need to sign in or sign up before continuing.
Unverified Commit e2db2edd authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add MobileNetV3 architecture for Segmentation (#3276)

* Making _segm_resnet() generic and reusable.

* Adding fcn and deeplabv3 directly on mobilenetv3 backbone.

* Adding tests for segmentation models.

* Rename is_strided with _is_cn.

* Add dilation support on MobileNetV3 for Segmentation.

* Add Lite R-ASPP with MobileNetV3 backbone.

* Add pretrained model weights.

* Removing model fcn_mobilenet_v3_large.

* Adding docs and imports.

* Fixing typo and readme.
parent cae5400d
...@@ -271,7 +271,8 @@ The models subpackage contains definitions for the following model ...@@ -271,7 +271,8 @@ The models subpackage contains definitions for the following model
architectures for semantic segmentation: architectures for semantic segmentation:
- `FCN ResNet50, ResNet101 <https://arxiv.org/abs/1411.4038>`_ - `FCN ResNet50, ResNet101 <https://arxiv.org/abs/1411.4038>`_
- `DeepLabV3 ResNet50, ResNet101 <https://arxiv.org/abs/1706.05587>`_ - `DeepLabV3 ResNet50, ResNet101, MobileNetV3-Large <https://arxiv.org/abs/1706.05587>`_
- `LR-ASPP MobileNetV3-Large <https://arxiv.org/abs/1905.02244>`_
As with image classification models, all pre-trained models expect input images normalized in the same way. As with image classification models, all pre-trained models expect input images normalized in the same way.
The images have to be loaded in to a range of ``[0, 1]`` and then normalized using The images have to be loaded in to a range of ``[0, 1]`` and then normalized using
...@@ -298,6 +299,8 @@ FCN ResNet50 60.5 91.4 ...@@ -298,6 +299,8 @@ FCN ResNet50 60.5 91.4
FCN ResNet101 63.7 91.9 FCN ResNet101 63.7 91.9
DeepLabV3 ResNet50 66.4 92.4 DeepLabV3 ResNet50 66.4 92.4
DeepLabV3 ResNet101 67.4 92.4 DeepLabV3 ResNet101 67.4 92.4
DeepLabV3 MobileNetV3-Large 60.3 91.2
LR-ASPP MobileNetV3-Large 57.9 91.2
================================ ============= ==================== ================================ ============= ====================
...@@ -313,6 +316,13 @@ DeepLabV3 ...@@ -313,6 +316,13 @@ DeepLabV3
.. autofunction:: torchvision.models.segmentation.deeplabv3_resnet50 .. autofunction:: torchvision.models.segmentation.deeplabv3_resnet50
.. autofunction:: torchvision.models.segmentation.deeplabv3_resnet101 .. autofunction:: torchvision.models.segmentation.deeplabv3_resnet101
.. autofunction:: torchvision.models.segmentation.deeplabv3_mobilenet_v3_large
LR-ASPP
-------
.. autofunction:: torchvision.models.segmentation.lraspp_mobilenet_v3_large
Object Detection, Instance Segmentation and Person Keypoint Detection Object Detection, Instance Segmentation and Person Keypoint Detection
......
...@@ -18,4 +18,4 @@ from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \ ...@@ -18,4 +18,4 @@ from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
# segmentation # segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \ from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
deeplabv3_resnet50, deeplabv3_resnet101 deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large
...@@ -31,3 +31,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0. ...@@ -31,3 +31,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.
``` ```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss
``` ```
## deeplabv3_mobilenet_v3_large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001
```
## lraspp_mobilenet_v3_large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001
```
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
...@@ -61,8 +61,10 @@ autocast_flaky_numerics = ( ...@@ -61,8 +61,10 @@ autocast_flaky_numerics = (
"wide_resnet101_2", "wide_resnet101_2",
"deeplabv3_resnet50", "deeplabv3_resnet50",
"deeplabv3_resnet101", "deeplabv3_resnet101",
"deeplabv3_mobilenet_v3_large",
"fcn_resnet50", "fcn_resnet50",
"fcn_resnet101", "fcn_resnet101",
"lraspp_mobilenet_v3_large",
) )
......
...@@ -136,9 +136,9 @@ def mobilenet_backbone( ...@@ -136,9 +136,9 @@ def mobilenet_backbone(
): ):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
# Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn. # The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1] stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
num_stages = len(stage_indices) num_stages = len(stage_indices)
# find the index of the layer from which we wont freeze # find the index of the layer from which we wont freeze
......
...@@ -38,14 +38,16 @@ class ConvBNActivation(nn.Sequential): ...@@ -38,14 +38,16 @@ class ConvBNActivation(nn.Sequential):
groups: int = 1, groups: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None, activation_layer: Optional[Callable[..., nn.Module]] = None,
dilation: int = 1,
) -> None: ) -> None:
padding = (kernel_size - 1) // 2 padding = (kernel_size - 1) // 2 * dilation
if norm_layer is None: if norm_layer is None:
norm_layer = nn.BatchNorm2d norm_layer = nn.BatchNorm2d
if activation_layer is None: if activation_layer is None:
activation_layer = nn.ReLU6 activation_layer = nn.ReLU6
super(ConvBNReLU, self).__init__( super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
bias=False),
norm_layer(out_planes), norm_layer(out_planes),
activation_layer(inplace=True) activation_layer(inplace=True)
) )
...@@ -88,7 +90,7 @@ class InvertedResidual(nn.Module): ...@@ -88,7 +90,7 @@ class InvertedResidual(nn.Module):
]) ])
self.conv = nn.Sequential(*layers) self.conv = nn.Sequential(*layers)
self.out_channels = oup self.out_channels = oup
self.is_strided = stride > 1 self._is_cn = stride > 1
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
if self.use_res_connect: if self.use_res_connect:
......
...@@ -38,7 +38,7 @@ class SqueezeExcitation(nn.Module): ...@@ -38,7 +38,7 @@ class SqueezeExcitation(nn.Module):
class InvertedResidualConfig: class InvertedResidualConfig:
def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool, def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
activation: str, stride: int, width_mult: float): activation: str, stride: int, dilation: int, width_mult: float):
self.input_channels = self.adjust_channels(input_channels, width_mult) self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
...@@ -46,6 +46,7 @@ class InvertedResidualConfig: ...@@ -46,6 +46,7 @@ class InvertedResidualConfig:
self.use_se = use_se self.use_se = use_se
self.use_hs = activation == "HS" self.use_hs = activation == "HS"
self.stride = stride self.stride = stride
self.dilation = dilation
@staticmethod @staticmethod
def adjust_channels(channels: int, width_mult: float): def adjust_channels(channels: int, width_mult: float):
...@@ -70,9 +71,10 @@ class InvertedResidual(nn.Module): ...@@ -70,9 +71,10 @@ class InvertedResidual(nn.Module):
norm_layer=norm_layer, activation_layer=activation_layer)) norm_layer=norm_layer, activation_layer=activation_layer))
# depthwise # depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
stride=cnf.stride, groups=cnf.expanded_channels, norm_layer=norm_layer, stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
activation_layer=activation_layer)) norm_layer=norm_layer, activation_layer=activation_layer))
if cnf.use_se: if cnf.use_se:
layers.append(SqueezeExcitation(cnf.expanded_channels)) layers.append(SqueezeExcitation(cnf.expanded_channels))
...@@ -82,7 +84,7 @@ class InvertedResidual(nn.Module): ...@@ -82,7 +84,7 @@ class InvertedResidual(nn.Module):
self.block = nn.Sequential(*layers) self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels self.out_channels = cnf.out_channels
self.is_strided = cnf.stride > 1 self._is_cn = cnf.stride > 1
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
result = self.block(input) result = self.block(input)
...@@ -194,8 +196,7 @@ def _mobilenet_v3( ...@@ -194,8 +196,7 @@ def _mobilenet_v3(
return model return model
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False, def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
**kwargs: Any) -> MobileNetV3:
""" """
Constructs a large MobileNetV3 architecture from Constructs a large MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_. `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
...@@ -203,40 +204,38 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_ ...@@ -203,40 +204,38 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
reduced_tail (bool): If True, reduces the channel counts of all feature layers
between C4 and C5 by 2. It is used to reduce the channel redundancy in the
backbone for Detection and Segmentation.
""" """
# non-public config parameters
reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
dilation = 2 if kwargs.pop('_dilated', False) else 1
width_mult = 1.0 width_mult = 1.0
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
reduce_divider = 2 if reduced_tail else 1
inverted_residual_setting = [ inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, False, "RE", 1), bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
bneck_conf(16, 3, 64, 24, False, "RE", 2), # C1 bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1
bneck_conf(24, 3, 72, 24, False, "RE", 1), bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 72, 40, True, "RE", 2), # C2 bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2
bneck_conf(40, 5, 120, 40, True, "RE", 1), bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 5, 120, 40, True, "RE", 1), bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 3, 240, 80, False, "HS", 2), # C3 bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3
bneck_conf(80, 3, 200, 80, False, "HS", 1), bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1), bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1), bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 480, 112, True, "HS", 1), bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
bneck_conf(112, 3, 672, 112, True, "HS", 1), bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2), # C4 bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1), bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1), bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
] ]
last_channel = adjust_channels(1280 // reduce_divider) # C5 last_channel = adjust_channels(1280 // reduce_divider) # C5
return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False, def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
**kwargs: Any) -> MobileNetV3:
""" """
Constructs a small MobileNetV3 architecture from Constructs a small MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_. `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
...@@ -244,28 +243,27 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_ ...@@ -244,28 +243,27 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_
Args: Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
reduced_tail (bool): If True, reduces the channel counts of all feature layers
between C4 and C5 by 2. It is used to reduce the channel redundancy in the
backbone for Detection and Segmentation.
""" """
# non-public config parameters
reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
dilation = 2 if kwargs.pop('_dilated', False) else 1
width_mult = 1.0 width_mult = 1.0
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
reduce_divider = 2 if reduced_tail else 1
inverted_residual_setting = [ inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, True, "RE", 2), # C1 bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 2), # C2 bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2
bneck_conf(24, 3, 88, 24, False, "RE", 1), bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 96, 40, True, "HS", 2), # C3 bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3
bneck_conf(40, 5, 240, 40, True, "HS", 1), bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 240, 40, True, "HS", 1), bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1), bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1), bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2), # C4 bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1), bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1), bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
] ]
last_channel = adjust_channels(1024 // reduce_divider) # C5 last_channel = adjust_channels(1024 // reduce_divider) # C5
......
from .segmentation import * from .segmentation import *
from .fcn import * from .fcn import *
from .deeplabv3 import * from .deeplabv3 import *
from .lraspp import *
from collections import OrderedDict from collections import OrderedDict
import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
......
from collections import OrderedDict
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Dict
__all__ = ["LRASPP"]
class LRASPP(nn.Module):
"""
Implements a Lite R-ASPP Network for semantic segmentation from
`"Searching for MobileNetV3"
<https://arxiv.org/abs/1905.02244>`_.
Args:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"high" for the high level feature map and "low" for the low level feature map.
low_channels (int): the number of channels of the low level features.
high_channels (int): the number of channels of the high level features.
num_classes (int): number of output classes of the model (including the background).
inter_channels (int, optional): the number of channels for intermediate computations.
"""
def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128):
super().__init__()
self.backbone = backbone
self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)
def forward(self, input):
features = self.backbone(input)
out = self.classifier(features)
out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)
result = OrderedDict()
result["out"] = out
return result
class LRASPPHead(nn.Module):
def __init__(self, low_channels, high_channels, num_classes, inter_channels):
super().__init__()
self.cbr = nn.Sequential(
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True)
)
self.scale = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
nn.Sigmoid(),
)
self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)
def forward(self, input: Dict[str, Tensor]) -> Tensor:
low = input["low"]
high = input["high"]
x = self.cbr(high)
s = self.scale(high)
x = x * s
x = F.interpolate(x, size=low.shape[-2:], mode='bilinear', align_corners=False)
return self.low_classifier(low) + self.high_classifier(x)
from .._utils import IntermediateLayerGetter from .._utils import IntermediateLayerGetter
from ..utils import load_state_dict_from_url from ..utils import load_state_dict_from_url
from .. import mobilenetv3
from .. import resnet from .. import resnet
from .deeplabv3 import DeepLabHead, DeepLabV3 from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead from .fcn import FCN, FCNHead
from .lraspp import LRASPP
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101'] __all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large']
model_urls = { model_urls = {
...@@ -13,30 +16,50 @@ model_urls = { ...@@ -13,30 +16,50 @@ model_urls = {
'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth', 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth', 'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth', 'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
'deeplabv3_mobilenet_v3_large_coco':
'https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth',
'lraspp_mobilenet_v3_large_coco': 'https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth',
} }
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True): def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True):
backbone = resnet.__dict__[backbone_name]( if 'resnet' in backbone_name:
pretrained=pretrained_backbone, backbone = resnet.__dict__[backbone_name](
replace_stride_with_dilation=[False, True, True]) pretrained=pretrained_backbone,
replace_stride_with_dilation=[False, True, True])
return_layers = {'layer4': 'out'} out_layer = 'layer4'
out_inplanes = 2048
aux_layer = 'layer3'
aux_inplanes = 1024
elif 'mobilenet_v3' in backbone_name:
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
out_pos = stage_indices[-1] # use C5 which has output_stride = 16
out_layer = str(out_pos)
out_inplanes = backbone[out_pos].out_channels
aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
aux_layer = str(aux_pos)
aux_inplanes = backbone[aux_pos].out_channels
else:
raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name))
return_layers = {out_layer: 'out'}
if aux: if aux:
return_layers['layer3'] = 'aux' return_layers[aux_layer] = 'aux'
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
aux_classifier = None aux_classifier = None
if aux: if aux:
inplanes = 1024 aux_classifier = FCNHead(aux_inplanes, num_classes)
aux_classifier = FCNHead(inplanes, num_classes)
model_map = { model_map = {
'deeplabv3': (DeepLabHead, DeepLabV3), 'deeplabv3': (DeepLabHead, DeepLabV3),
'fcn': (FCNHead, FCN), 'fcn': (FCNHead, FCN),
} }
inplanes = 2048 classifier = model_map[name][0](out_inplanes, num_classes)
classifier = model_map[name][0](inplanes, num_classes)
base_model = model_map[name][1] base_model = model_map[name][1]
model = base_model(backbone, classifier, aux_classifier) model = base_model(backbone, classifier, aux_classifier)
...@@ -46,15 +69,36 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True ...@@ -46,15 +69,36 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs): def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
if pretrained: if pretrained:
aux_loss = True aux_loss = True
model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs) model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
if pretrained: if pretrained:
arch = arch_type + '_' + backbone + '_coco' _load_weights(model, arch_type, backbone, progress)
model_url = model_urls[arch] return model
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else: def _load_weights(model, arch_type, backbone, progress):
state_dict = load_state_dict_from_url(model_url, progress=progress) arch = arch_type + '_' + backbone + '_coco'
model.load_state_dict(state_dict) model_url = model_urls.get(arch, None)
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
low_pos = stage_indices[-4] # use C2 here which has output_stride = 8
high_pos = stage_indices[-1] # use C5 which has output_stride = 16
low_channels = backbone[low_pos].out_channels
high_channels = backbone[high_pos].out_channels
backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): 'low', str(high_pos): 'high'})
model = LRASPP(backbone, low_channels, high_channels, num_classes)
return model return model
...@@ -66,6 +110,8 @@ def fcn_resnet50(pretrained=False, progress=True, ...@@ -66,6 +110,8 @@ def fcn_resnet50(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
""" """
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
...@@ -78,6 +124,8 @@ def fcn_resnet101(pretrained=False, progress=True, ...@@ -78,6 +124,8 @@ def fcn_resnet101(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
""" """
return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
...@@ -90,6 +138,8 @@ def deeplabv3_resnet50(pretrained=False, progress=True, ...@@ -90,6 +138,8 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
""" """
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
...@@ -102,5 +152,42 @@ def deeplabv3_resnet101(pretrained=False, progress=True, ...@@ -102,5 +152,42 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
""" """
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs):
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
"""
if kwargs.pop("aux_loss", False):
raise NotImplementedError('This model does not use auxiliary loss')
backbone_name = 'mobilenet_v3_large'
model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)
if pretrained:
_load_weights(model, 'lraspp', backbone_name, progress)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment