"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "53af9779fae66cade7870cf71a15e0abe4c7cd07"
Commit 43ab2fef authored by Bar's avatar Bar Committed by Francisco Massa
Browse files

Enhance ShufflenetV2 (#892)

* Enhance ShufflenetV2

Class shufflenetv2 receives `stages_repeats` and `stages_out_channels` arguments.

* remove explicit num_classes argument from utility functions
parent dc3ac290
...@@ -84,13 +84,17 @@ class InvertedResidual(nn.Module): ...@@ -84,13 +84,17 @@ class InvertedResidual(nn.Module):
class ShuffleNetV2(nn.Module): class ShuffleNetV2(nn.Module):
def __init__(self, stage_out_channels, num_classes=1000): def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
super(ShuffleNetV2, self).__init__() super(ShuffleNetV2, self).__init__()
self.stage_out_channels = stage_out_channels if len(stages_repeats) != 3:
input_channels = 3 raise ValueError('expected stages_repeats as list of 3 positive ints')
output_channels = self.stage_out_channels[0] if len(stages_out_channels) != 5:
raise ValueError('expected stages_out_channels as list of 5 positive ints')
self._stage_out_channels = stages_out_channels
input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential( self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels), nn.BatchNorm2d(output_channels),
...@@ -101,16 +105,15 @@ class ShuffleNetV2(nn.Module): ...@@ -101,16 +105,15 @@ class ShuffleNetV2(nn.Module):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
stage_repeats = [4, 8, 4]
for name, repeats, output_channels in zip( for name, repeats, output_channels in zip(
stage_names, stage_repeats, self.stage_out_channels[1:]): stage_names, stages_repeats, self._stage_out_channels[1:]):
seq = [InvertedResidual(input_channels, output_channels, 2)] seq = [InvertedResidual(input_channels, output_channels, 2)]
for i in range(repeats - 1): for i in range(repeats - 1):
seq.append(InvertedResidual(output_channels, output_channels, 1)) seq.append(InvertedResidual(output_channels, output_channels, 1))
setattr(self, name, nn.Sequential(*seq)) setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels input_channels = output_channels
output_channels = self.stage_out_channels[-1] output_channels = self._stage_out_channels[-1]
self.conv5 = nn.Sequential( self.conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(output_channels), nn.BatchNorm2d(output_channels),
...@@ -131,8 +134,8 @@ class ShuffleNetV2(nn.Module): ...@@ -131,8 +134,8 @@ class ShuffleNetV2(nn.Module):
return x return x
def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs): def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
model = ShuffleNetV2(stage_out_channels=stage_out_channels, **kwargs) model = ShuffleNetV2(*args, **kwargs)
if pretrained: if pretrained:
model_url = model_urls[arch] model_url = model_urls[arch]
...@@ -146,16 +149,20 @@ def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs): ...@@ -146,16 +149,20 @@ def _shufflenetv2(arch, pretrained, progress, stage_out_channels, **kwargs):
def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs): def shufflenetv2_x0_5(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [24, 48, 96, 192, 1024], **kwargs) return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
[4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs): def shufflenetv2_x1_0(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [24, 116, 232, 464, 1024], **kwargs) return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
[4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs): def shufflenetv2_x1_5(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [24, 176, 352, 704, 1024], **kwargs) return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
[4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs): def shufflenetv2_x2_0(pretrained=False, progress=True, **kwargs):
return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [24, 244, 488, 976, 2048], **kwargs) return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
[4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment