Unverified Commit e2db2edd authored by Vasilis Vryniotis, committed by GitHub

Add MobileNetV3 architecture for Segmentation (#3276)

* Making _segm_resnet() generic and reusable.

* Adding fcn and deeplabv3 directly on mobilenetv3 backbone.

* Adding tests for segmentation models.

* Rename is_strided to _is_cn.

* Add dilation support on MobileNetV3 for Segmentation.

* Add Lite R-ASPP with MobileNetV3 backbone.

* Add pretrained model weights.

* Removing model fcn_mobilenet_v3_large.

* Adding docs and imports.

* Fixing typo and readme.
parent cae5400d
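
As a quick orientation before the diff itself, a minimal smoke test of the two builders this commit introduces (a sketch, assuming a torchvision build that contains this commit; both models return an OrderedDict with an "out" key upsampled to the input resolution):

```
import torch
from torchvision.models.segmentation import (
    deeplabv3_mobilenet_v3_large,
    lraspp_mobilenet_v3_large,
)

# pretrained=False / pretrained_backbone=False avoids any weight download
# for a pure shape check.
model = lraspp_mobilenet_v3_large(pretrained=False, pretrained_backbone=False).eval()

with torch.no_grad():
    out = model(torch.rand(1, 3, 520, 520))["out"]
print(out.shape)  # torch.Size([1, 21, 520, 520]), 21 Pascal VOC classes
```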
......@@ -271,7 +271,8 @@ The models subpackage contains definitions for the following model
architectures for semantic segmentation:
- `FCN ResNet50, ResNet101 <https://arxiv.org/abs/1411.4038>`_
- `DeepLabV3 ResNet50, ResNet101 <https://arxiv.org/abs/1706.05587>`_
- `DeepLabV3 ResNet50, ResNet101, MobileNetV3-Large <https://arxiv.org/abs/1706.05587>`_
- `LR-ASPP MobileNetV3-Large <https://arxiv.org/abs/1905.02244>`_
As with image classification models, all pre-trained models expect input images normalized in the same way.
The images have to be loaded into a range of ``[0, 1]`` and then normalized using ``mean = [0.485, 0.456, 0.406]`` and ``std = [0.229, 0.224, 0.225]``.
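
For reference, a sketch of that preprocessing with torchvision transforms (the mean/std values are the standard ImageNet statistics quoted above):

```
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.ToTensor(),  # PIL image or ndarray -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```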
......@@ -298,6 +299,8 @@ FCN ResNet50 60.5 91.4
FCN ResNet101 63.7 91.9
DeepLabV3 ResNet50 66.4 92.4
DeepLabV3 ResNet101 67.4 92.4
DeepLabV3 MobileNetV3-Large 60.3 91.2
LR-ASPP MobileNetV3-Large 57.9 91.2
================================ ============= ====================
......@@ -313,6 +316,13 @@ DeepLabV3
.. autofunction:: torchvision.models.segmentation.deeplabv3_resnet50
.. autofunction:: torchvision.models.segmentation.deeplabv3_resnet101
.. autofunction:: torchvision.models.segmentation.deeplabv3_mobilenet_v3_large
LR-ASPP
-------
.. autofunction:: torchvision.models.segmentation.lraspp_mobilenet_v3_large
Object Detection, Instance Segmentation and Person Keypoint Detection
......
......@@ -18,4 +18,4 @@ from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
# segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
deeplabv3_resnet50, deeplabv3_resnet101
deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large
......@@ -31,3 +31,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss
```
## deeplabv3_mobilenet_v3_large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001
```
## lraspp_mobilenet_v3_large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001
```
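
The same script can also evaluate the released checkpoints. The `--test-only` and `--pretrained` flags below are assumptions about `references/segmentation/train.py` rather than part of this diff:

```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --test-only --pretrained
```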
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
......@@ -61,8 +61,10 @@ autocast_flaky_numerics = (
"wide_resnet101_2",
"deeplabv3_resnet50",
"deeplabv3_resnet101",
"deeplabv3_mobilenet_v3_large",
"fcn_resnet50",
"fcn_resnet101",
"lraspp_mobilenet_v3_large",
)
......
......@@ -136,9 +136,9 @@ def mobilenet_backbone(
):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
# Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
num_stages = len(stage_indices)
# find the index of the layer from which we won't freeze
......
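
A sketch of what that index gathering yields for the new backbone; the printed list is what the hunk above computes from the renamed `_is_cn` flags (the expected values below assume the stock mobilenet_v3_large configuration):

```
from torchvision.models import mobilenet_v3_large

backbone = mobilenet_v3_large(pretrained=False).features
# Blocks that downsample set _is_cn; the first and last blocks bound C0 and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
print(stage_indices)  # expected: [0, 2, 4, 7, 13, 16]
```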
......@@ -38,14 +38,16 @@ class ConvBNActivation(nn.Sequential):
groups: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
dilation: int = 1,
) -> None:
padding = (kernel_size - 1) // 2
padding = (kernel_size - 1) // 2 * dilation
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if activation_layer is None:
activation_layer = nn.ReLU6
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
bias=False),
norm_layer(out_planes),
activation_layer(inplace=True)
)
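
The padding change above is what keeps dilated convolutions size-preserving: for stride 1 and an odd kernel, the output height is H + 2*padding - dilation*(kernel_size - 1), which equals H exactly when padding = (kernel_size - 1) // 2 * dilation. A quick check of that identity:

```
import torch
from torch import nn

for k, d in [(3, 1), (3, 2), (5, 2)]:
    # padding chosen by the formula above keeps the 32x32 spatial size
    conv = nn.Conv2d(8, 8, k, stride=1, padding=(k - 1) // 2 * d, dilation=d)
    assert conv(torch.rand(1, 8, 32, 32)).shape[-2:] == (32, 32)
```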
......@@ -88,7 +90,7 @@ class InvertedResidual(nn.Module):
])
self.conv = nn.Sequential(*layers)
self.out_channels = oup
self.is_strided = stride > 1
self._is_cn = stride > 1
def forward(self, x: Tensor) -> Tensor:
if self.use_res_connect:
......
......@@ -38,7 +38,7 @@ class SqueezeExcitation(nn.Module):
class InvertedResidualConfig:
def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
activation: str, stride: int, width_mult: float):
activation: str, stride: int, dilation: int, width_mult: float):
self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
......@@ -46,6 +46,7 @@ class InvertedResidualConfig:
self.use_se = use_se
self.use_hs = activation == "HS"
self.stride = stride
self.dilation = dilation
@staticmethod
def adjust_channels(channels: int, width_mult: float):
......@@ -70,9 +71,10 @@ class InvertedResidual(nn.Module):
norm_layer=norm_layer, activation_layer=activation_layer))
# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
stride=cnf.stride, groups=cnf.expanded_channels, norm_layer=norm_layer,
activation_layer=activation_layer))
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
if cnf.use_se:
layers.append(SqueezeExcitation(cnf.expanded_channels))
......@@ -82,7 +84,7 @@ class InvertedResidual(nn.Module):
self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
self.is_strided = cnf.stride > 1
self._is_cn = cnf.stride > 1
def forward(self, input: Tensor) -> Tensor:
result = self.block(input)
......@@ -194,8 +196,7 @@ def _mobilenet_v3(
return model
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
**kwargs: Any) -> MobileNetV3:
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
"""
Constructs a large MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
......@@ -203,40 +204,38 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
reduced_tail (bool): If True, reduces the channel counts of all feature layers
between C4 and C5 by 2. It is used to reduce the channel redundancy in the
backbone for Detection and Segmentation.
"""
# non-public config parameters
reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
dilation = 2 if kwargs.pop('_dilated', False) else 1
width_mult = 1.0
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
reduce_divider = 2 if reduced_tail else 1
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, False, "RE", 1),
bneck_conf(16, 3, 64, 24, False, "RE", 2), # C1
bneck_conf(24, 3, 72, 24, False, "RE", 1),
bneck_conf(24, 5, 72, 40, True, "RE", 2), # C2
bneck_conf(40, 5, 120, 40, True, "RE", 1),
bneck_conf(40, 5, 120, 40, True, "RE", 1),
bneck_conf(40, 3, 240, 80, False, "HS", 2), # C3
bneck_conf(80, 3, 200, 80, False, "HS", 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1),
bneck_conf(80, 3, 480, 112, True, "HS", 1),
bneck_conf(112, 3, 672, 112, True, "HS", 1),
bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2), # C4
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1
bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3
bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1280 // reduce_divider) # C5
return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
**kwargs: Any) -> MobileNetV3:
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
"""
Constructs a small MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
......@@ -244,28 +243,27 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
reduced_tail (bool): If True, reduces the channel counts of all feature layers
between C4 and C5 by 2. It is used to reduce the channel redundancy in the
backbone for Detection and Segmentation.
"""
# non-public config parameters
reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
dilation = 2 if kwargs.pop('_dilated', False) else 1
width_mult = 1.0
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
reduce_divider = 2 if reduced_tail else 1
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, True, "RE", 2), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 2), # C2
bneck_conf(24, 3, 88, 24, False, "RE", 1),
bneck_conf(24, 5, 96, 40, True, "HS", 2), # C3
bneck_conf(40, 5, 240, 40, True, "HS", 1),
bneck_conf(40, 5, 240, 40, True, "HS", 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2), # C4
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2
bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1024 // reduce_divider) # C5
......
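
A small sketch of what the two private switches do to the built network, assuming the `out_channels` attribute that the segmentation code below relies on: `_reduced_tail` halves the C4/C5 widths (960 -> 480 for the large variant), while `_dilated` trades the last stride for dilation without changing channel counts.

```
from torchvision.models import mobilenet_v3_large

plain = mobilenet_v3_large(pretrained=False)
# Private kwargs consumed via kwargs.pop() in the builders above.
compact = mobilenet_v3_large(pretrained=False, _dilated=True, _reduced_tail=True)

print(plain.features[-1].out_channels)    # 960
print(compact.features[-1].out_channels)  # 480
```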
from .segmentation import *
from .fcn import *
from .deeplabv3 import *
from .lraspp import *
from collections import OrderedDict
import torch
from torch import nn
from torch.nn import functional as F
......
from collections import OrderedDict
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Dict
__all__ = ["LRASPP"]
class LRASPP(nn.Module):
"""
Implements a Lite R-ASPP Network for semantic segmentation from
`"Searching for MobileNetV3"
<https://arxiv.org/abs/1905.02244>`_.
Args:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[str, Tensor], with the key being
"high" for the high level feature map and "low" for the low level feature map.
low_channels (int): the number of channels of the low level features.
high_channels (int): the number of channels of the high level features.
num_classes (int): number of output classes of the model (including the background).
inter_channels (int, optional): the number of channels for intermediate computations.
"""
def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128):
super().__init__()
self.backbone = backbone
self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)
def forward(self, input):
features = self.backbone(input)
out = self.classifier(features)
out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)
result = OrderedDict()
result["out"] = out
return result
class LRASPPHead(nn.Module):
def __init__(self, low_channels, high_channels, num_classes, inter_channels):
super().__init__()
self.cbr = nn.Sequential(
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True)
)
self.scale = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
nn.Sigmoid(),
)
self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)
def forward(self, input: Dict[str, Tensor]) -> Tensor:
low = input["low"]
high = input["high"]
x = self.cbr(high)
s = self.scale(high)
x = x * s
x = F.interpolate(x, size=low.shape[-2:], mode='bilinear', align_corners=False)
return self.low_classifier(low) + self.high_classifier(x)
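
A shape check of the head in isolation, using dummy feature maps with the channel counts the MobileNetV3-Large backbone produces (40-channel C2 "low" features at stride 8, 960-channel C5 "high" features at stride 16); the sizes correspond to a hypothetical 512x512 input:

```
import torch
from torchvision.models.segmentation.lraspp import LRASPPHead  # not re-exported via __all__

head = LRASPPHead(low_channels=40, high_channels=960, num_classes=21, inter_channels=128)
low, high = torch.rand(1, 40, 64, 64), torch.rand(1, 960, 32, 32)
out = head({"low": low, "high": high})
print(out.shape)  # torch.Size([1, 21, 64, 64]), logits at the "low" resolution
```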
from .._utils import IntermediateLayerGetter
from ..utils import load_state_dict_from_url
from .. import mobilenetv3
from .. import resnet
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead
from .lraspp import LRASPP
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large']
model_urls = {
......@@ -13,30 +16,50 @@ model_urls = {
'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
'deeplabv3_mobilenet_v3_large_coco':
'https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth',
'lraspp_mobilenet_v3_large_coco': 'https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth',
}
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True):
if 'resnet' in backbone_name:
backbone = resnet.__dict__[backbone_name](
pretrained=pretrained_backbone,
replace_stride_with_dilation=[False, True, True])
out_layer = 'layer4'
out_inplanes = 2048
aux_layer = 'layer3'
aux_inplanes = 1024
elif 'mobilenet_v3' in backbone_name:
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
out_pos = stage_indices[-1] # use C5 which has output_stride = 16
out_layer = str(out_pos)
out_inplanes = backbone[out_pos].out_channels
aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
aux_layer = str(aux_pos)
aux_inplanes = backbone[aux_pos].out_channels
else:
raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name))
return_layers = {'layer4': 'out'}
return_layers = {out_layer: 'out'}
if aux:
return_layers['layer3'] = 'aux'
return_layers[aux_layer] = 'aux'
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
aux_classifier = None
if aux:
inplanes = 1024
aux_classifier = FCNHead(inplanes, num_classes)
aux_classifier = FCNHead(aux_inplanes, num_classes)
model_map = {
'deeplabv3': (DeepLabHead, DeepLabV3),
'fcn': (FCNHead, FCN),
}
inplanes = 2048
classifier = model_map[name][0](inplanes, num_classes)
classifier = model_map[name][0](out_inplanes, num_classes)
base_model = model_map[name][1]
model = base_model(backbone, classifier, aux_classifier)
......@@ -46,15 +69,36 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
if pretrained:
aux_loss = True
model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
if pretrained:
_load_weights(model, arch_type, backbone, progress)
return model
def _load_weights(model, arch_type, backbone, progress):
arch = arch_type + '_' + backbone + '_coco'
model_url = model_urls[arch]
model_url = model_urls.get(arch, None)
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
low_pos = stage_indices[-4] # use C2 here which has output_stride = 8
high_pos = stage_indices[-1] # use C5 which has output_stride = 16
low_channels = backbone[low_pos].out_channels
high_channels = backbone[high_pos].out_channels
backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): 'low', str(high_pos): 'high'})
model = LRASPP(backbone, low_channels, high_channels, num_classes)
return model
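
The tap points chosen above can be verified directly. A sketch using the private IntermediateLayerGetter helper on the dilated backbone, where indices 4 and 16 are the C2 and C5 positions computed from stage_indices:

```
import torch
from torchvision.models import mobilenet_v3_large
from torchvision.models._utils import IntermediateLayerGetter

backbone = mobilenet_v3_large(pretrained=False, _dilated=True).features
getter = IntermediateLayerGetter(backbone, return_layers={'4': 'low', '16': 'high'})
feats = getter(torch.rand(1, 3, 512, 512))
print(feats['low'].shape)   # torch.Size([1, 40, 64, 64]),  output_stride = 8
print(feats['high'].shape)  # torch.Size([1, 960, 32, 32]), output_stride = 16
```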
......@@ -66,6 +110,8 @@ def fcn_resnet50(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
......@@ -78,6 +124,8 @@ def fcn_resnet101(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
......@@ -90,6 +138,8 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
......@@ -102,5 +152,42 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)
def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)
def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs):
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
"""
if kwargs.pop("aux_loss", False):
raise NotImplementedError('This model does not use auxiliary loss')
backbone_name = 'mobilenet_v3_large'
model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)
if pretrained:
_load_weights(model, 'lraspp', backbone_name, progress)
return model
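
Finally, an end-to-end sketch of the new DeepLabV3 variant: with pretrained=True the COCO-trained weights listed in model_urls are downloaded, and the per-pixel prediction is the argmax over the 21 class channels.

```
import torch
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large

model = deeplabv3_mobilenet_v3_large(pretrained=True).eval()
batch = torch.rand(1, 3, 520, 520)  # stand-in for a normalized image batch

with torch.no_grad():
    logits = model(batch)['out']   # [1, 21, 520, 520]
pred = logits.argmax(dim=1)        # per-pixel class indices, [1, 520, 520]
```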