"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "8d858c380e59fb307e5e8774ec7fa1866384345c"
Unverified Commit 6272c412 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Upload pre-trained weights for MobileNet and ResNeXt (#917)

Also move weights from ShuffleNet to PyTorch bucket. Additionally, rename shufflenet to make it consistent with the other models
parent b384c4e7
......@@ -12,6 +12,8 @@ architectures:
- `Inception`_ v3
- `GoogLeNet`_
- `ShuffleNet`_ v2
- `MobileNet`_ v2
- `ResNeXt`_
You can construct a model with random weights by calling its constructor:
......@@ -25,7 +27,9 @@ You can construct a model with random weights by calling its constructor:
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenetv2()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet = models.mobilenet_v2()
resnext50_32x4d = models.resnext50_32x4d()
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
......@@ -40,7 +44,9 @@ These can be constructed by passing ``pretrained=True``:
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenetv2(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet = models.mobilenet_v2(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
......@@ -92,6 +98,9 @@ Densenet-161 22.35 6.20
Inception v3 22.55 6.44
GoogleNet 30.22 10.47
ShuffleNet V2 30.64 11.68
MobileNet V2 28.12 9.71
ResNeXt-50-32x4d 22.38 6.30
ResNeXt-101-32x8d 20.69 5.47
================================ ============= =============
......@@ -103,6 +112,8 @@ ShuffleNet V2 30.64 11.68
.. _Inception: https://arxiv.org/abs/1512.00567
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
.. _MobileNet: https://arxiv.org/abs/1801.04381
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. currentmodule:: torchvision.models
......@@ -162,3 +173,14 @@ ShuffleNet v2
.. autofunction:: shufflenet
MobileNet v2
-------------
.. autofunction:: mobilenet_v2
ResNext
-------------
.. autofunction:: resnext50_32x4d
.. autofunction:: resnext101_32x8d
from torch import nn
from .utils import load_state_dict_from_url
__all__ = ['MobileNetV2', 'mobilenet_v2']
model_urls = {
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}
class ConvBNReLU(nn.Sequential):
......@@ -99,5 +108,18 @@ class MobileNetV2(nn.Module):
return x
def mobilenet_v2(pretrained=False, **kwargs):
return MobileNetV2(**kwargs)
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
"""
Constructs a MobileNetV2 architecture from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = MobileNetV2(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
progress=progress)
model.load_state_dict(state_dict)
return model
......@@ -12,6 +12,8 @@ model_urls = {
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
}
......@@ -268,15 +270,27 @@ def resnet152(pretrained=False, progress=True, **kwargs):
**kwargs)
def resnext50_32x4d(**kwargs):
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNeXt-50 32x4d model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained=False, progress=True, **kwargs)
pretrained, progress, **kwargs)
def resnext101_32x8d(**kwargs):
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
"""Constructs a ResNeXt-101 32x8d model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained=False, progress=True, **kwargs)
pretrained, progress, **kwargs)
......@@ -4,13 +4,15 @@ import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
__all__ = ['ShuffleNetV2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
__all__ = [
'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
]
model_urls = {
'shufflenetv2_x0.5':
'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x0.5-f707e7126e.pt',
'shufflenetv2_x1.0':
'https://github.com/barrh/Shufflenet-v2-Pytorch/releases/download/v0.1.0/shufflenetv2_x1-5666bf0f80.pt',
'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
'shufflenetv2_x1.5': None,
'shufflenetv2_x2.0': None,
}
......@@ -142,27 +144,27 @@ def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_urls, progress=progress)
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model
def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs):
def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
[4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs):
def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
[4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs):
def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
[4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs):
def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
[4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment