Unverified Commit 7bf6e7b1 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add MobileNetV3 architecture for Classification (#3252)

* Add MobileNetV3 Architecture in TorchVision (#3182)

* Adding implementation of network architecture

* Adding rmsprop support on the train.py

* Adding auto-augment and random-erase in the training scripts.

* Adding support for reduced tail on MobileNetV3.

* Tagging blocks with comments.

* Adding documentation, pre-trained model URL and a minor refactoring.

* Handling better untrained supported models.
parent 8ebfd2f5
......@@ -22,7 +22,8 @@ architectures for image classification:
- `Inception`_ v3
- `GoogLeNet`_
- `ShuffleNet`_ v2
- `MobileNet`_ v2
- `MobileNetV2`_
- `MobileNetV3`_
- `ResNeXt`_
- `Wide ResNet`_
- `MNASNet`_
......@@ -40,7 +41,9 @@ You can construct a model with random weights by calling its constructor:
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet = models.mobilenet_v2()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
......@@ -59,7 +62,8 @@ These can be constructed by passing ``pretrained=True``:
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet = models.mobilenet_v2(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
......@@ -137,6 +141,7 @@ Inception v3 22.55 6.44
GoogleNet 30.22 10.47
ShuffleNet V2 30.64 11.68
MobileNet V2 28.12 9.71
MobileNet V3 Large 25.96 8.66
ResNeXt-50-32x4d 22.38 6.30
ResNeXt-101-32x8d 20.69 5.47
Wide ResNet-50-2 21.49 5.91
......@@ -153,7 +158,8 @@ MNASNet 1.0 26.49 8.456
.. _Inception: https://arxiv.org/abs/1512.00567
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
.. _MobileNet: https://arxiv.org/abs/1801.04381
.. _MobileNetV2: https://arxiv.org/abs/1801.04381
.. _MobileNetV3: https://arxiv.org/abs/1905.02244
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626
......@@ -231,6 +237,12 @@ MobileNet v2
.. autofunction:: mobilenet_v2
MobileNet v3
-------------
.. autofunction:: mobilenet_v3_large
.. autofunction:: mobilenet_v3_small
ResNext
-------
......
......@@ -11,7 +11,8 @@ from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.mobilenetv2 import mobilenet_v2
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
mnasnet1_3
......
......@@ -53,6 +53,16 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--lr-step-size 1 --lr-gamma 0.98
```
### MobileNetV3 Large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--model mobilenet_v3_large --epochs 600 --opt rmsprop --batch-size 128 --lr 0.064\
--wd 0.00001 --lr-step-size 2 --lr-gamma 0.973 --auto-augment imagenet --random-erase 0.2
```
We then averaged the parameters of the last 3 checkpoints that improved Acc@1. See [#3182](https://github.com/pytorch/vision/pull/3182) for details.
## Mixed precision training
Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex).
......
......@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
return cache_path
def load_data(traindir, valdir, cache_dataset, distributed):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
......@@ -88,20 +88,28 @@ def load_data(traindir, valdir, cache_dataset, distributed):
print("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
if cache_dataset and os.path.exists(cache_path):
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
else:
dataset = torchvision.datasets.ImageFolder(
traindir,
transforms.Compose([
trans = [
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
]
if args.auto_augment is not None:
aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
trans.append(transforms.AutoAugment(policy=aa_policy))
trans.extend([
transforms.ToTensor(),
normalize,
]))
if cache_dataset:
])
if args.random_erase > 0:
trans.append(transforms.RandomErasing(p=args.random_erase))
dataset = torchvision.datasets.ImageFolder(
traindir,
transforms.Compose(trans))
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
......@@ -109,7 +117,7 @@ def load_data(traindir, valdir, cache_dataset, distributed):
print("Loading validation data")
cache_path = _get_cache_path(valdir)
if cache_dataset and os.path.exists(cache_path):
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
......@@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed):
transforms.ToTensor(),
normalize,
]))
if cache_dataset:
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)
print("Creating data loaders")
if distributed:
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
else:
......@@ -155,8 +163,7 @@ def main(args):
train_dir = os.path.join(args.data_path, 'train')
val_dir = os.path.join(args.data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
args.cache_dataset, args.distributed)
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers, pin_memory=True)
......@@ -173,8 +180,15 @@ def main(args):
criterion = nn.CrossEntropyLoss()
opt_name = args.opt.lower()
if opt_name == 'sgd':
optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
elif opt_name == 'rmsprop':
optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
if args.apex:
model, optimizer = amp.initialize(model, optimizer,
......@@ -238,6 +252,7 @@ def parse_args():
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
......@@ -275,6 +290,8 @@ def parse_args():
help="Use pre-trained models from the modelzoo",
action="store_true",
)
parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
# Mixed precision training parameters
parser.add_argument('--apex', action='store_true',
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
......@@ -275,14 +275,15 @@ class ModelTester(TestCase):
out = model(x)
self.assertEqual(out.shape[-1], 1000)
def test_mobilenetv2_norm_layer(self):
model = models.__dict__["mobilenet_v2"]()
def test_mobilenet_norm_layer(self):
for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]:
model = models.__dict__[name]()
self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
def get_gn(num_channels):
return nn.GroupNorm(32, num_channels)
model = models.__dict__["mobilenet_v2"](norm_layer=get_gn)
model = models.__dict__[name](norm_layer=get_gn)
self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
......
from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all
from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all
__all__ = mv2_all
__all__ = mv2_all + mv3_all
......@@ -53,6 +53,7 @@ class ConvBNActivation(nn.Sequential):
norm_layer(out_planes),
activation_layer(inplace=True)
)
self.out_channels = out_planes
# necessary for backwards compatibility
......@@ -90,6 +91,8 @@ class InvertedResidual(nn.Module):
norm_layer(oup),
])
self.conv = nn.Sequential(*layers)
self.out_channels = oup
self.is_strided = stride > 1
def forward(self, x: Tensor) -> Tensor:
if self.use_res_connect:
......
import torch
from functools import partial
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Any, Callable, List, Optional, Sequence
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
# Public names of this module; the package ``__init__`` re-exports them via ``__all__ as mv3_all``.
__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
# Checkpoint URLs keyed by architecture name, looked up by ``_mobilenet_v3`` when ``pretrained=True``.
# A ``None`` value means no checkpoint is registered for that architecture; ``_mobilenet_v3`` then
# raises ``ValueError`` instead of attempting a download.
model_urls = {
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
"mobilenet_v3_small": None,
}
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns exactly the tensor it was given.

    It is used where an ``activation_layer`` factory is required but no
    activation is wanted (the projection conv of ``InvertedResidual``).
    ``ConvBNActivation`` instantiates its activation as
    ``activation_layer(inplace=True)``, so the constructor accepts and stores
    an ``inplace`` flag; the flag has no effect on ``forward``.

    Args:
        inplace (bool): Accepted for call-compatibility with real activations.
    """

    def __init__(self, inplace: bool = False):
        super(Identity, self).__init__()
        self.inplace = inplace

    def forward(self, input: Tensor) -> Tensor:
        # No computation and no copy: the same tensor object is handed back.
        return input
class SqueezeExcitation(nn.Module):
    """Channel-gating block: rescales the input by a per-channel gate.

    The gate is computed from the globally average-pooled input by two 1x1
    convolutions (``fc1`` then ``fc2``) with a ReLU between them and a
    hard-sigmoid at the end.

    Args:
        input_channels (int): Number of channels of the incoming tensor.
        squeeze_factor (int): Divisor applied to ``input_channels`` to get the
            bottleneck width, which is then passed through
            ``_make_divisible(..., 8)`` (presumably rounding to a multiple
            of 8 -- confirm against that helper).
    """

    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        super().__init__()
        bottleneck = _make_divisible(input_channels // squeeze_factor, 8)
        # Attribute names and creation order are kept: they define the
        # state-dict keys and the order in which parameters are initialised.
        self.fc1 = nn.Conv2d(input_channels, bottleneck, 1)
        self.fc2 = nn.Conv2d(bottleneck, input_channels, 1)

    def forward(self, input: Tensor) -> Tensor:
        # Squeeze: collapse the spatial dimensions to 1x1.
        pooled = F.adaptive_avg_pool2d(input, 1)
        # Excite: bottleneck -> ReLU -> expand -> hard-sigmoid gate.
        hidden = F.relu(self.fc1(pooled), inplace=True)
        gate = F.hardsigmoid(self.fc2(hidden), inplace=True)
        return gate * input
class InvertedResidualConfig:
    """Plain settings holder describing one inverted-residual block.

    The three channel counts are stored after being scaled by ``width_mult``
    through :meth:`adjust_channels`; the remaining values are stored as given.

    Args:
        input_channels (int): Channels entering the block (before scaling).
        kernel (int): Kernel size of the depthwise convolution.
        expanded_channels (int): Channels after the expansion conv (before scaling).
        out_channels (int): Channels leaving the block (before scaling).
        use_se (bool): Whether the block contains a squeeze-excitation module.
        activation (str): ``"HS"`` selects hard-swish (``use_hs`` is True);
            any other value leaves ``use_hs`` False.
        stride (int): Stride of the depthwise convolution.
        width_mult (float): Multiplier applied to the channel counts.
    """

    def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
                 activation: str, stride: int, width_mult: float):
        scaled = self.adjust_channels
        self.input_channels = scaled(input_channels, width_mult)
        self.kernel = kernel
        self.expanded_channels = scaled(expanded_channels, width_mult)
        self.out_channels = scaled(out_channels, width_mult)
        self.use_se = use_se
        self.use_hs = activation == "HS"
        self.stride = stride

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        """Scale ``channels`` by ``width_mult`` and pass the product to ``_make_divisible(..., 8)``."""
        scaled_channels = channels * width_mult
        return _make_divisible(scaled_channels, 8)
class InvertedResidual(nn.Module):
    """MobileNetV3 building block assembled from an ``InvertedResidualConfig``.

    Layer order inside ``self.block``: optional 1x1 expansion conv, depthwise
    conv, optional squeeze-excitation, 1x1 projection conv without activation.
    The input is added to the output when the stride is 1 and the input and
    output channel counts match.

    Args:
        cnf (InvertedResidualConfig): Settings of this block.
        norm_layer (Callable[..., nn.Module]): Factory of the normalization layer.

    Raises:
        ValueError: If ``cnf.stride`` is outside the range ``[1, 2]``.
    """

    def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError('illegal stride value')

        same_shape = cnf.input_channels == cnf.out_channels
        self.use_res_connect = cnf.stride == 1 and same_shape

        if cnf.use_hs:
            act = nn.Hardswish
        else:
            act = nn.ReLU
        # Every conv in the block shares the same normalization factory.
        conv = partial(ConvBNActivation, norm_layer=norm_layer)

        stages: List[nn.Module] = []

        # expand (skipped when the expansion would not change the width)
        if cnf.expanded_channels != cnf.input_channels:
            stages.append(conv(cnf.input_channels, cnf.expanded_channels, kernel_size=1, activation_layer=act))

        # depthwise (groups == channels)
        stages.append(conv(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
                           stride=cnf.stride, groups=cnf.expanded_channels, activation_layer=act))
        if cnf.use_se:
            stages.append(SqueezeExcitation(cnf.expanded_channels))

        # project (linear: Identity stands in for the activation)
        stages.append(conv(cnf.expanded_channels, cnf.out_channels, kernel_size=1, activation_layer=Identity))

        self.block = nn.Sequential(*stages)
        self.out_channels = cnf.out_channels
        self.is_strided = cnf.stride > 1

    def forward(self, input: Tensor) -> Tensor:
        out = self.block(input)
        if not self.use_res_connect:
            return out
        # Residual shortcut, accumulated in place on the block output.
        out += input
        return out
class MobileNetV3(nn.Module):
    """MobileNet V3 classifier: conv feature extractor, global average pool, two-layer head."""

    def __init__(
        self,
        inverted_residual_setting: List[InvertedResidualConfig],
        last_channel: int,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        Build the network from a list of block settings.

        Args:
            inverted_residual_setting (List[InvertedResidualConfig]): Network structure
            last_channel (int): The number of channels on the penultimate layer
            num_classes (int): Number of classes
            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for
                mobilenet. Defaults to ``InvertedResidual``.
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use.
                Defaults to ``nn.BatchNorm2d`` with ``eps=0.001`` and ``momentum=0.01``.

        Raises:
            ValueError: If ``inverted_residual_setting`` is empty.
            TypeError: If ``inverted_residual_setting`` is not a sequence of ``InvertedResidualConfig``.
        """
        super().__init__()

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        is_sequence = isinstance(inverted_residual_setting, Sequence)
        if not is_sequence or not all(isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual
        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        # The stem and the last conv share norm layer and hard-swish activation.
        hs_conv = partial(ConvBNActivation, norm_layer=norm_layer, activation_layer=nn.Hardswish)

        stem_channels = inverted_residual_setting[0].input_channels
        head_in_channels = inverted_residual_setting[-1].out_channels
        head_out_channels = 6 * head_in_channels

        # Arguments are evaluated left to right, so modules are created in the
        # same order as they appear in the network: stem, blocks, last conv.
        self.features = nn.Sequential(
            # building first layer
            hs_conv(3, stem_channels, kernel_size=3, stride=2),
            # building inverted residual blocks
            *[block(cnf, norm_layer) for cnf in inverted_residual_setting],
            # building last several layers
            hs_conv(head_in_channels, head_out_channels, kernel_size=1),
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(head_out_channels, last_channel),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(last_channel, num_classes),
        )

        # Weight initialisation, applied per module type.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run features -> average pool -> flatten (from dim 1) -> classifier."""
        pooled = self.avgpool(self.features(x))
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
def _mobilenet_v3(
    arch: str,
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any
):
    """Construct a ``MobileNetV3`` and optionally load its registered checkpoint.

    Args:
        arch (str): Key into ``model_urls`` identifying the architecture.
        inverted_residual_setting (List[InvertedResidualConfig]): Network structure.
        last_channel (int): The number of channels on the penultimate layer.
        pretrained (bool): If True, load the weights registered for ``arch``.
        progress (bool): Forwarded to ``load_state_dict_from_url``.
        **kwargs: Forwarded to the ``MobileNetV3`` constructor.

    Raises:
        ValueError: If ``pretrained`` is True but ``model_urls`` has no URL for ``arch``.
    """
    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
    if not pretrained:
        return model

    checkpoint_url = model_urls.get(arch)
    if checkpoint_url is None:
        raise ValueError("No checkpoint is available for model type {}".format(arch))
    weights = load_state_dict_from_url(checkpoint_url, progress=progress)
    model.load_state_dict(weights)
    return model
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
                       **kwargs: Any) -> MobileNetV3:
    """
    Constructs a large MobileNetV3 architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        reduced_tail (bool): If True, reduces the channel counts of all feature layers
            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
            backbone for Detection and Segmentation.
    """
    width_mult = 1.0
    reduce_divider = 2 if reduced_tail else 1
    # Tail widths: the only values affected by ``reduced_tail``.
    tail_channels = 160 // reduce_divider
    tail_expanded = 960 // reduce_divider

    # One row per block:
    # (input_channels, kernel, expanded_channels, out_channels, use_se, activation, stride)
    block_table = [
        (16, 3, 16, 16, False, "RE", 1),
        (16, 3, 64, 24, False, "RE", 2),  # C1
        (24, 3, 72, 24, False, "RE", 1),
        (24, 5, 72, 40, True, "RE", 2),  # C2
        (40, 5, 120, 40, True, "RE", 1),
        (40, 5, 120, 40, True, "RE", 1),
        (40, 3, 240, 80, False, "HS", 2),  # C3
        (80, 3, 200, 80, False, "HS", 1),
        (80, 3, 184, 80, False, "HS", 1),
        (80, 3, 184, 80, False, "HS", 1),
        (80, 3, 480, 112, True, "HS", 1),
        (112, 3, 672, 112, True, "HS", 1),
        (112, 5, 672, tail_channels, True, "HS", 2),  # C4
        (tail_channels, 5, tail_expanded, tail_channels, True, "HS", 1),
        (tail_channels, 5, tail_expanded, tail_channels, True, "HS", 1),
    ]
    inverted_residual_setting = [InvertedResidualConfig(*row, width_mult=width_mult) for row in block_table]
    last_channel = InvertedResidualConfig.adjust_channels(1280 // reduce_divider, width_mult=width_mult)  # C5

    return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
                       **kwargs: Any) -> MobileNetV3:
    """
    Constructs a small MobileNetV3 architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        reduced_tail (bool): If True, reduces the channel counts of all feature layers
            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
            backbone for Detection and Segmentation.
    """
    width_mult = 1.0
    reduce_divider = 2 if reduced_tail else 1
    # Tail widths: the only values affected by ``reduced_tail``.
    tail_channels = 96 // reduce_divider
    tail_expanded = 576 // reduce_divider

    # One row per block:
    # (input_channels, kernel, expanded_channels, out_channels, use_se, activation, stride)
    block_table = [
        (16, 3, 16, 16, True, "RE", 2),  # C1
        (16, 3, 72, 24, False, "RE", 2),  # C2
        (24, 3, 88, 24, False, "RE", 1),
        (24, 5, 96, 40, True, "HS", 2),  # C3
        (40, 5, 240, 40, True, "HS", 1),
        (40, 5, 240, 40, True, "HS", 1),
        (40, 5, 120, 48, True, "HS", 1),
        (48, 5, 144, 48, True, "HS", 1),
        (48, 5, 288, tail_channels, True, "HS", 2),  # C4
        (tail_channels, 5, tail_expanded, tail_channels, True, "HS", 1),
        (tail_channels, 5, tail_expanded, tail_channels, True, "HS", 1),
    ]
    inverted_residual_setting = [InvertedResidualConfig(*row, width_mult=width_mult) for row in block_table]
    last_channel = InvertedResidualConfig.adjust_channels(1024 // reduce_divider, width_mult=width_mult)  # C5

    return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment