Unverified commit 8317295c authored by Vasilis Vryniotis, committed by GitHub

Add Quantizable MobilenetV3 architecture for Classification (#3323)

* Refactoring mobilenetv3 to make code reusable.

* Adding quantizable MobileNetV3 architecture.

* Fix bug on reference script.

* Moving documentation of quantized models in the right place.

* Update documentation.

* Workaround for loading correct weights of quant model.

* Update weight URL and readme.

* Adding eval.
parent 17393cb7
@@ -263,6 +263,53 @@ MNASNet
.. autofunction:: mnasnet1_0
.. autofunction:: mnasnet1_3

Quantized Models
----------------

The following architectures provide support for INT8 quantized models. You can get
a model with random weights by calling its constructor:

.. code:: python

    import torchvision.models as models
    googlenet = models.quantization.googlenet()
    inception_v3 = models.quantization.inception_v3()
    mobilenet_v2 = models.quantization.mobilenet_v2()
    mobilenet_v3_large = models.quantization.mobilenet_v3_large()
    mobilenet_v3_small = models.quantization.mobilenet_v3_small()
    resnet18 = models.quantization.resnet18()
    resnet50 = models.quantization.resnet50()
    resnext101_32x8d = models.quantization.resnext101_32x8d()
    shufflenet_v2_x0_5 = models.quantization.shufflenet_v2_x0_5()
    shufflenet_v2_x1_0 = models.quantization.shufflenet_v2_x1_0()
    shufflenet_v2_x1_5 = models.quantization.shufflenet_v2_x1_5()
    shufflenet_v2_x2_0 = models.quantization.shufflenet_v2_x2_0()

Obtaining a pre-trained quantized model can be done with a few lines of code:

.. code:: python

    import torch
    import torchvision.models as models

    model = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
    model.eval()
    # run the model with quantized inputs and weights
    out = model(torch.rand(1, 3, 224, 224))

We provide pre-trained quantized weights for the following models:

================================ ============= =============
Model                            Acc@1         Acc@5
================================ ============= =============
MobileNet V2                     71.658        90.150
MobileNet V3 Large               73.004        90.858
ShuffleNet V2                    68.360        87.582
ResNet 18                        69.494        88.882
ResNet 50                        75.920        92.814
ResNext 101 32x8d                78.986        94.480
Inception V3                     77.176        93.354
GoogleNet                        69.826        89.404
================================ ============= =============
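
For instance, the MobileNet V3 Large weights added by this PR can be exercised with the same pattern. A quick smoke test (selecting the engine explicitly is our assumption; these weights were produced against the qnnpack backend, per the reference scripts below):

.. code:: python

    import torch
    import torchvision.models as models

    # The MobileNet V3 Large checkpoint was trained against qnnpack,
    # so select that engine before running quantized inference.
    torch.backends.quantized.engine = 'qnnpack'

    model = models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
    model.eval()
    out = model(torch.rand(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])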

Semantic Segmentation
=====================
...
@@ -74,27 +74,6 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
```
## Quantized
### INT8 models
We add INT8 quantized models following the quantization support added in PyTorch 1.3.
A pre-trained quantized model can be obtained with a few lines of code:
```
import torch
import torchvision

model = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True)
model.eval()
# run the model with quantized inputs and weights
out = model(torch.rand(1, 3, 224, 224))
```
We provide pre-trained quantized weights for the following models:
| Model | Acc@1 | Acc@5 |
|:-----------------:|:------:|:------:|
| MobileNet V2 | 71.658 | 90.150 |
| ShuffleNet V2     | 68.360 | 87.582 |
| ResNet 18 | 69.494 | 88.882 |
| ResNet 50 | 75.920 | 92.814 |
| ResNext 101 32x8d | 78.986 | 94.480 |
| Inception V3 | 77.176 | 93.354 |
| GoogleNet | 69.826 | 89.404 |
### Parameters used for generating quantized models:
@@ -106,6 +85,10 @@ For all post training quantized models (All quantized models except mobilenet-v2
4. eval_batch_size: 128
5. backend: 'fbgemm'
```
python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='<model_name>'
```
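
The `--post-training-quantize` path boils down to the standard PyTorch eager-mode flow. A minimal sketch of that flow, assuming a float `model` with a `fuse_model()` method and a `calibration_loader` (both names are illustrative, not the script's actual variables):

```python
import torch

model.eval()
model.fuse_model()  # fold Conv+BN(+ReLU) before quantization
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)   # attach observers
with torch.no_grad():
    for image, _ in calibration_loader:           # a few training batches
        model(image)                              # collect activation statistics
torch.quantization.convert(model, inplace=True)   # emit int8 modules
```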
For Mobilenet-v2, the model was trained with quantization-aware training; the settings used are:
1. num_workers: 16
2. batch_size: 32
@@ -118,15 +101,38 @@ For Mobilenet-v2, the model was trained with quantization aware training, the se
9. momentum: 0.9
10. lr_step_size: 30
11. lr_gamma: 0.1
12. weight-decay: 0.0001
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quantization.py --model='mobilenet_v2'
```
Training converges at about 10 epochs.

For Mobilenet-v3 Large, the model was trained with quantization-aware training; the settings used are:
1. num_workers: 16
2. batch_size: 32
3. eval_batch_size: 128
4. backend: 'qnnpack'
5. learning-rate: 0.001
6. num_epochs: 90
7. num_observer_update_epochs: 4
8. num_batch_norm_update_epochs: 3
9. momentum: 0.9
10. lr_step_size: 30
11. lr_gamma: 0.1
12. weight-decay: 0.00001
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quantization.py --model='mobilenet_v3_large' \
--wd 0.00001 --lr 0.001
```
For post-training quantization, the device is set to CPU. For training, the device is set to CUDA.
### Command to evaluate quantized models using the pre-trained weights:
For all quantized models:
```
python train_quantization.py --device='cpu' --test-only --backend='<backend>' --model='<model_name>'
```
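
For example, evaluating the new MobileNet V3 Large weights (trained against the qnnpack backend) would look like:

```
python train_quantization.py --device='cpu' --test-only --backend='qnnpack' --model='mobilenet_v3_large'
```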
@@ -92,10 +92,12 @@ def load_data(traindir, valdir, args):
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
else:
auto_augment_policy = getattr(args, "auto_augment", None)
random_erase_prob = getattr(args, "random_erase", 0.0)
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob))
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
...
@@ -37,8 +37,7 @@ def main(args):
train_dir = os.path.join(args.data_path, 'train')
val_dir = os.path.join(args.data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers, pin_memory=True)
...
@@ -3,7 +3,7 @@ import torch
from functools import partial
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Any, Callable, Dict, List, Optional, Sequence
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
@@ -24,14 +24,18 @@ class SqueezeExcitation(nn.Module):
super().__init__()
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
def _scale(self, input: Tensor, inplace: bool) -> Tensor:
scale = F.adaptive_avg_pool2d(input, 1)
scale = self.fc1(scale)
scale = self.relu(scale)
scale = self.fc2(scale)
return F.hardsigmoid(scale, inplace=inplace)
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input, True)
return scale * input
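
For reference, the refactored class now reads roughly as below (the constructor signature is assumed from the surrounding context, since the hunk does not show it). The point of the change is that the ReLU becomes a named submodule, so `['fc1', 'relu']` can later be fused for quantization, and `_scale` isolates the logic that the quantizable subclass reuses with a `FloatFunctional` multiply:

```python
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.models.mobilenetv2 import _make_divisible

class SqueezeExcitation(nn.Module):
    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        super().__init__()
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
        self.relu = nn.ReLU(inplace=True)  # named module -> fusable with fc1
        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)

    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
        scale = F.adaptive_avg_pool2d(input, 1)
        scale = self.fc1(scale)
        scale = self.relu(scale)
        scale = self.fc2(scale)
        return F.hardsigmoid(scale, inplace=inplace)

    def forward(self, input: Tensor) -> Tensor:
        return self._scale(input, True) * input
```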
@@ -55,7 +59,8 @@ class InvertedResidualConfig:
class InvertedResidual(nn.Module):
def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = SqueezeExcitation):
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError('illegal stride value')
@@ -76,7 +81,7 @@ class InvertedResidual(nn.Module):
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
if cnf.use_se:
layers.append(se_layer(cnf.expanded_channels))
# project
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
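
The new `se_layer` parameter is the injection point the quantizable variant (added below) relies on: any callable with the same interface can replace the stock block. A hypothetical example, where `cnf` stands for an `InvertedResidualConfig` built elsewhere:

```python
from torch import nn

# Swap in a custom squeeze-and-excitation implementation (cnf is hypothetical here).
block = InvertedResidual(cnf, norm_layer=nn.BatchNorm2d,
                         se_layer=QuantizableSqueezeExcitation)
```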
@@ -179,40 +184,16 @@ class MobileNetV3(nn.Module):
return self._forward_impl(x)
def _mobilenet_v3(
arch: str,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
pretrained: bool,
progress: bool,
**kwargs: Any
):
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if pretrained:
if model_urls.get(arch, None) is None:
raise ValueError("No checkpoint is available for model type {}".format(arch))
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
"""
Constructs a large MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
def _mobilenet_v3_conf(arch: str, params: Dict[str, Any]):
# non-public config parameters
reduce_divider = 2 if params.pop('_reduced_tail', False) else 1
dilation = 2 if params.pop('_dilated', False) else 1
width_mult = params.pop('_width_mult', 1.0)
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
if arch == "mobilenet_v3_large":
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1
@@ -231,27 +212,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1280 // reduce_divider) # C5
return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
"""
Constructs a small MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
# non-public config parameters
reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
dilation = 2 if kwargs.pop('_dilated', False) else 1
width_mult = 1.0
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
elif arch == "mobilenet_v3_small":
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2
@@ -266,5 +227,52 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1024 // reduce_divider) # C5
else:
raise ValueError("Unsupported model type {}".format(arch))
return inverted_residual_setting, last_channel
def _mobilenet_v3_model(
arch: str,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
pretrained: bool,
progress: bool,
**kwargs: Any
):
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if pretrained:
if model_urls.get(arch, None) is None:
raise ValueError("No checkpoint is available for model type {}".format(arch))
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
"""
Constructs a large MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "mobilenet_v3_large"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
"""
Constructs a small MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "mobilenet_v3_small"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
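
With this split, the two public builders and the quantization wrappers added below share one source of truth for the layer configuration. Note that `_mobilenet_v3_conf` pops the private keys out of `kwargs` before the remainder reaches the model constructor; a sketch of that flow (values illustrative):

```python
# Private flags ride along in kwargs and are consumed by _mobilenet_v3_conf.
kwargs = {'_reduced_tail': True, '_dilated': True}
setting, last_channel = _mobilenet_v3_conf('mobilenet_v3_large', kwargs)
model = MobileNetV3(setting, last_channel, **kwargs)  # private keys already popped
```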
from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all
from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all
__all__ = mv2_all + mv3_all
import torch
from torch import nn, Tensor
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\
SqueezeExcitation, model_urls, _mobilenet_v3_conf
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from typing import Any, List, Optional
from .utils import _replace_relu
__all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large', 'mobilenet_v3_small']
quant_model_urls = {
'mobilenet_v3_large_qnnpack':
"https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
'mobilenet_v3_small_qnnpack': None,
}
class QuantizableSqueezeExcitation(SqueezeExcitation):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.skip_mul = nn.quantized.FloatFunctional()
def forward(self, input: Tensor) -> Tensor:
return self.skip_mul.mul(self._scale(input, False), input)
def fuse_model(self):
fuse_modules(self, ['fc1', 'relu'], inplace=True)
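
`FloatFunctional` is used because a bare `*` between quantized tensors has no observer attached; routing the multiply through `skip_mul` gives that op its own scale and zero point. A standalone illustration:

```python
import torch
from torch import nn

mul = nn.quantized.FloatFunctional()
a, b = torch.rand(4), torch.rand(4)
out = mul.mul(a, b)  # in float mode behaves like a * b, but is observable/quantizable
```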
class QuantizableInvertedResidual(InvertedResidual):
def __init__(self, *args, **kwargs):
super().__init__(*args, se_layer=QuantizableSqueezeExcitation, **kwargs)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
if self.use_res_connect:
return self.skip_add.add(x, self.block(x))
else:
return self.block(x)
class QuantizableMobileNetV3(MobileNetV3):
def __init__(self, *args, **kwargs):
"""
MobileNet V3 main class
Args:
Inherits args from floating point MobileNetV3
"""
super().__init__(*args, **kwargs)
self.quant = QuantStub()
self.dequant = DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self._forward_impl(x)
x = self.dequant(x)
return x
def fuse_model(self):
for m in self.modules():
if type(m) == ConvBNActivation:
modules_to_fuse = ['0', '1']
if type(m[2]) == nn.ReLU:
modules_to_fuse.append('2')
fuse_modules(m, modules_to_fuse, inplace=True)
elif type(m) == QuantizableSqueezeExcitation:
m.fuse_model()
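
`fuse_modules` folds adjacent Conv2d + BatchNorm2d (+ ReLU) children into a single module so the BN statistics are absorbed into the conv weights before quantization; activations other than ReLU (here, Hardswish) cannot be fused, hence the `type(m[2]) == nn.ReLU` check. A standalone illustration of the same call:

```python
import torch.nn as nn
from torch.quantization import fuse_modules

m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
m.eval()  # eval-mode fusion, as used for inference / post-training quantization
fuse_modules(m, ['0', '1', '2'], inplace=True)
# m[0] is now a fused ConvReLU2d; m[1] and m[2] become nn.Identity
```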
def _load_weights(
arch: str,
model: QuantizableMobileNetV3,
model_url: Optional[str],
progress: bool,
):
if model_url is None:
raise ValueError("No checkpoint is available for {}".format(arch))
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
def _mobilenet_v3_model(
arch: str,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
pretrained: bool,
progress: bool,
quantize: bool,
**kwargs: Any
):
model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
_replace_relu(model)
if quantize:
backend = 'qnnpack'
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(model, inplace=True)
if pretrained:
_load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)
torch.quantization.convert(model, inplace=True)
model.eval()
else:
if pretrained:
_load_weights(arch, model, model_urls.get(arch, None), progress)
return model
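
Note the ordering in the `quantize=True` branch: the checkpoint is loaded after `prepare_qat` but before `convert`, because the published weights were saved from the QAT-prepared model (the "workaround for loading correct weights" mentioned in the commit message). A condensed restatement of the same sequence:

```python
# Mirrors the branch above; backend is fixed to qnnpack for these weights.
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.quantization.prepare_qat(model, inplace=True)  # the state dict targets this stage
# ... load_state_dict(...) happens here when pretrained=True ...
torch.quantization.convert(model, inplace=True)      # emit the int8 modules
model.eval()
```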
def mobilenet_v3_large(pretrained=False, progress=True, quantize=False, **kwargs):
"""
Constructs a MobileNetV3 Large architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Note that quantize = True returns a quantized model with 8 bit
weights. Quantized models only support inference and run on CPUs.
GPU inference is not yet supported.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet.
progress (bool): If True, displays a progress bar of the download to stderr
quantize (bool): If True, returns a quantized model, else returns a float model
"""
arch = "mobilenet_v3_large"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)
def mobilenet_v3_small(pretrained=False, progress=True, quantize=False, **kwargs):
"""
Constructs a MobileNetV3 Small architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Note that quantize = True returns a quantized model with 8 bit
weights. Quantized models only support inference and run on CPUs.
GPU inference is not yet supported.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet.
progress (bool): If True, displays a progress bar of the download to stderr
quantize (bool): If True, returns a quantized model, else returns a float model
"""
arch = "mobilenet_v3_small"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)
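
Because `quantize=False` still returns a `QuantizableMobileNetV3` (with `_replace_relu` applied), one can also start from the pretrained float weights and run a custom recipe on top. A sketch, assuming post-training quantization with the fbgemm backend:

```python
import torch
from torchvision.models.quantization import mobilenet_v3_large

model = mobilenet_v3_large(pretrained=True, quantize=False)  # float weights
model.eval()
model.fuse_model()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
# ... run representative calibration batches through the model here ...
torch.quantization.convert(model, inplace=True)
```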