"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d06750a5fd19781de68066bb34a3520af83cf124"
Commit 0564df43 authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Refactoring of ShuffleNetV2 (#889)

* Minor refactoring of ShuffleNetV2

Added progress flag following #875. Further the following refactoring was also done:

1) added `version` argument in shufflenetv2 method and removed the operations for converting the `width_mult` arg to float and string.
2) removed `num_classes` argument and **kwargs from functions except `ShuffleNetV2`

* removed `version` arg

* Update shufflenetv2.py

* Removed the try except block

* Update shufflenetv2.py

* Changed version from float to str

* Replace `width_mult` with `stages_out_channels`

Removes the need of  `_getStages` function.
parent 78ed423d
...@@ -2,10 +2,9 @@ import functools ...@@ -2,10 +2,9 @@ import functools
import torch import torch
import torch.nn as nn import torch.nn as nn
from .utils import load_state_dict_from_url
__all__ = ['ShuffleNetV2', 'shufflenetv2', __all__ = ['ShuffleNetV2', 'shufflenetv2_x0_5', 'shufflenetv2_x1_0', 'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
'shufflenetv2_x0_5', 'shufflenetv2_x1_0',
'shufflenetv2_x1_5', 'shufflenetv2_x2_0']
model_urls = { model_urls = {
'shufflenetv2_x0.5': 'shufflenetv2_x0.5':
...@@ -85,16 +84,13 @@ class InvertedResidual(nn.Module): ...@@ -85,16 +84,13 @@ class InvertedResidual(nn.Module):
class ShuffleNetV2(nn.Module): class ShuffleNetV2(nn.Module):
def __init__(self, num_classes=1000, width_mult=1): def __init__(self, stage_out_channels, num_classes=1000):
super(ShuffleNetV2, self).__init__() super(ShuffleNetV2, self).__init__()
try: self.stage_out_channels = stage_out_channels
self.stage_out_channels = self._getStages(float(width_mult))
except KeyError:
raise ValueError('width_mult {} is not supported'.format(width_mult))
input_channels = 3 input_channels = 3
output_channels = self.stage_out_channels[0] output_channels = self.stage_out_channels[0]
self.conv1 = nn.Sequential( self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels), nn.BatchNorm2d(output_channels),
...@@ -134,47 +130,32 @@ class ShuffleNetV2(nn.Module): ...@@ -134,47 +130,32 @@ class ShuffleNetV2(nn.Module):
x = self.fc(x) x = self.fc(x)
return x return x
@staticmethod
def _getStages(mult):
stages = {
'0.5': [24, 48, 96, 192, 1024],
'1.0': [24, 116, 232, 464, 1024],
'1.5': [24, 176, 352, 704, 1024],
'2.0': [24, 244, 488, 976, 2048],
}
return stages[str(mult)]
def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):
def shufflenetv2(pretrained=False, num_classes=1000, width_mult=1, **kwargs): model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs)
model = ShuffleNetV2(num_classes=num_classes, width_mult=width_mult)
if pretrained: if pretrained:
# change width_mult to float model_url = model_urls[arch]
if isinstance(width_mult, int):
width_mult = float(width_mult)
model_type = ('_'.join([ShuffleNetV2.__name__, 'x' + str(width_mult)]))
try:
model_url = model_urls[model_type.lower()]
except KeyError:
raise ValueError('model {} is not support'.format(model_type))
if model_url is None: if model_url is None:
raise NotImplementedError('pretrained {} is not supported'.format(model_type)) raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
model.load_state_dict(torch.utils.model_zoo.load_url(model_url)) else:
state_dict = load_state_dict_from_url(model_urls, progress=progress)
model.load_state_dict(state_dict)
return model return model
def shufflenetv2_x0_5(pretrained=False, num_classes=1000, **kwargs): def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs):
return shufflenetv2(pretrained, num_classes, 0.5) return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [24, 48, 96, 192, 1024], **kwargs)
def shufflenetv2_x1_0(pretrained=False, num_classes=1000, **kwargs): def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs):
return shufflenetv2(pretrained, num_classes, 1) return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [24, 116, 232, 464, 1024], **kwargs)
def shufflenetv2_x1_5(pretrained=False, num_classes=1000, **kwargs): def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs):
return shufflenetv2(pretrained, num_classes, 1.5) return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [24, 176, 352, 704, 1024], **kwargs)
def shufflenetv2_x2_0(pretrained=False, num_classes=1000, **kwargs): def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs):
return shufflenetv2(pretrained, num_classes, 2) return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [24, 244, 488, 976, 2048], **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment