Commit 0564df43 authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Refactoring of ShuffleNetV2 (#889)

* Minor refactoring of ShuffleNetV2

Added progress flag following #875. Further the following refactoring was also done:

1) added `version` argument in shufflenetv2 method and removed the operations for converting the `width_mult` arg to float and string.
2) removed `num_classes` argument and **kwargs from functions except `ShuffleNetV2`

* removed `version` arg

* Update shufflenetv2.py

* Removed the try except block

* Update shufflenetv2.py

* Changed version from float to str

* Replace `width_mult` with `stages_out_channels`

Removes the need of  `_getStages` function.
parent 78ed423d
......@@ -2,10 +2,9 @@ import functools
import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
__all__ = ['ShuffleNetV2', 'shufflenetv2',
'shufflenetv2_x0_5', 'shufflenetv2_x1_0',
'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
__all__ = ['ShuffleNetV2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
model_urls = {
'shufflenetv2_x0.5':
......@@ -85,16 +84,13 @@ class InvertedResidual(nn.Module):
class ShuffleNetV2(nn.Module):
def __init__(self, num_classes=1000, width_mult=1):
def __init__(self, stage_out_channels, num_classes=1000):
super(ShuffleNetV2, self).__init__()
try:
self.stage_out_channels = self._getStages(float(width_mult))
except KeyError:
raise ValueError('width_mult {} is not supported'.format(width_mult))
self.stage_out_channels = stage_out_channels
input_channels = 3
output_channels = self.stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
......@@ -134,47 +130,32 @@ class ShuffleNetV2(nn.Module):
x = self.fc(x)
return x
@staticmethod
def _getStages(mult):
stages = {
'0.5': [24, 48, 96, 192, 1024],
'1.0': [24, 116, 232, 464, 1024],
'1.5': [24, 176, 352, 704, 1024],
'2.0': [24, 244, 488, 976, 2048],
}
return stages[str(mult)]
def shufflenetv2(pretrained=False, num_classes=1000, width_mult=1, **kwargs):
model = ShuffleNetV2(num_classes=num_classes, width_mult=width_mult)
def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):
model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs)
if pretrained:
# change width_mult to float
if isinstance(width_mult, int):
width_mult = float(width_mult)
model_type = ('_'.join([ShuffleNetV2.__name__, 'x' + str(width_mult)]))
try:
model_url = model_urls[model_type.lower()]
except KeyError:
raise ValueError('model {} is not support'.format(model_type))
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported'.format(model_type))
model.load_state_dict(torch.utils.model_zoo.load_url(model_url))
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_urls, progress=progress)
model.load_state_dict(state_dict)
return model
def shufflenetv2_x0_5(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, 0.5)
def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [24, 48, 96, 192, 1024], **kwargs)
def shufflenetv2_x1_0(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, 1)
def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [24, 116, 232, 464, 1024], **kwargs)
def shufflenetv2_x1_5(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, 1.5)
def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [24, 176, 352, 704, 1024], **kwargs)
def shufflenetv2_x2_0(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, 2)
def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [24, 244, 488, 976, 2048], **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment