"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4f3ec5364e1543ce0e0b866eeed239f1aedcb9f4"
Commit 9cdc8144 authored by hx89's avatar hx89 Committed by Francisco Massa
Browse files

Quantizable googlenet, inceptionv3 and shufflenetv2 models (#1503)

* quantizable googlenet

* Minor improvements

* Rename basic_conv2d with conv_block plus additional fixes

* More renamings and fixes

* Bugfix

* Fix missing import for mypy

* Add pretrained weights
parent b438d321
...@@ -5,7 +5,7 @@ from collections import namedtuple ...@@ -5,7 +5,7 @@ from collections import namedtuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.jit.annotations import Optional from torch.jit.annotations import Optional, Tuple
from torch import Tensor from torch import Tensor
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
...@@ -63,34 +63,42 @@ def googlenet(pretrained=False, progress=True, **kwargs): ...@@ -63,34 +63,42 @@ def googlenet(pretrained=False, progress=True, **kwargs):
class GoogLeNet(nn.Module): class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input'] __constants__ = ['aux_logits', 'transform_input']
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True): def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True,
blocks=None):
super(GoogLeNet, self).__init__() super(GoogLeNet, self).__init__()
if blocks is None:
blocks = [BasicConv2d, Inception, InceptionAux]
assert len(blocks) == 3
conv_block = blocks[0]
inception_block = blocks[1]
inception_aux_block = blocks[2]
self.aux_logits = aux_logits self.aux_logits = aux_logits
self.transform_input = transform_input self.transform_input = transform_input
self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
self.conv2 = BasicConv2d(64, 64, kernel_size=1) self.conv2 = conv_block(64, 64, kernel_size=1)
self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)
if aux_logits: if aux_logits:
self.aux1 = InceptionAux(512, num_classes) self.aux1 = inception_aux_block(512, num_classes)
self.aux2 = InceptionAux(528, num_classes) self.aux2 = inception_aux_block(528, num_classes)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout(0.2) self.dropout = nn.Dropout(0.2)
...@@ -112,14 +120,17 @@ class GoogLeNet(nn.Module): ...@@ -112,14 +120,17 @@ class GoogLeNet(nn.Module):
nn.init.constant_(m.weight, 1) nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def forward(self, x): def _transform_input(self, x):
# type: (Tensor) -> GoogLeNetOutputs # type: (Tensor) -> Tensor
if self.transform_input: if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
x = torch.cat((x_ch0, x_ch1, x_ch2), 1) x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
return x
def _forward(self, x):
# type: (Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
# N x 3 x 224 x 224 # N x 3 x 224 x 224
x = self.conv1(x) x = self.conv1(x)
# N x 64 x 112 x 112 # N x 64 x 112 x 112
...@@ -173,12 +184,7 @@ class GoogLeNet(nn.Module): ...@@ -173,12 +184,7 @@ class GoogLeNet(nn.Module):
x = self.dropout(x) x = self.dropout(x)
x = self.fc(x) x = self.fc(x)
# N x 1000 (num_classes) # N x 1000 (num_classes)
if torch.jit.is_scripting(): return x, aux2, aux1
if not aux_defined:
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
return GoogLeNetOutputs(x, aux2, aux1)
else:
return self.eager_outputs(x, aux2, aux1)
@torch.jit.unused @torch.jit.unused
def eager_outputs(self, x, aux2, aux1): def eager_outputs(self, x, aux2, aux1):
...@@ -188,45 +194,65 @@ class GoogLeNet(nn.Module): ...@@ -188,45 +194,65 @@ class GoogLeNet(nn.Module):
else: else:
return x return x
def forward(self, x):
# type: (Tensor) -> GoogLeNetOutputs
x = self._transform_input(x)
x, aux1, aux2 = self._forward(x)
aux_defined = self.training and self.aux_logits
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
return GoogLeNetOutputs(x, aux2, aux1)
else:
return self.eager_outputs(x, aux2, aux1)
class Inception(nn.Module): class Inception(nn.Module):
__constants__ = ['branch2', 'branch3', 'branch4'] __constants__ = ['branch2', 'branch3', 'branch4']
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj,
conv_block=None):
super(Inception, self).__init__() super(Inception, self).__init__()
if conv_block is None:
self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) conv_block = BasicConv2d
self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)
self.branch2 = nn.Sequential( self.branch2 = nn.Sequential(
BasicConv2d(in_channels, ch3x3red, kernel_size=1), conv_block(in_channels, ch3x3red, kernel_size=1),
BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
) )
self.branch3 = nn.Sequential( self.branch3 = nn.Sequential(
BasicConv2d(in_channels, ch5x5red, kernel_size=1), conv_block(in_channels, ch5x5red, kernel_size=1),
BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1) conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1)
) )
self.branch4 = nn.Sequential( self.branch4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
BasicConv2d(in_channels, pool_proj, kernel_size=1) conv_block(in_channels, pool_proj, kernel_size=1)
) )
def forward(self, x): def _forward(self, x):
branch1 = self.branch1(x) branch1 = self.branch1(x)
branch2 = self.branch2(x) branch2 = self.branch2(x)
branch3 = self.branch3(x) branch3 = self.branch3(x)
branch4 = self.branch4(x) branch4 = self.branch4(x)
outputs = [branch1, branch2, branch3, branch4] outputs = [branch1, branch2, branch3, branch4]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1) return torch.cat(outputs, 1)
class InceptionAux(nn.Module): class InceptionAux(nn.Module):
def __init__(self, in_channels, num_classes): def __init__(self, in_channels, num_classes, conv_block=None):
super(InceptionAux, self).__init__() super(InceptionAux, self).__init__()
self.conv = BasicConv2d(in_channels, 128, kernel_size=1) if conv_block is None:
conv_block = BasicConv2d
self.conv = conv_block(in_channels, 128, kernel_size=1)
self.fc1 = nn.Linear(2048, 1024) self.fc1 = nn.Linear(2048, 1024)
self.fc2 = nn.Linear(1024, num_classes) self.fc2 = nn.Linear(1024, num_classes)
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.jit.annotations import Optional from torch.jit.annotations import Optional
from torch import Tensor
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
...@@ -63,28 +64,43 @@ def inception_v3(pretrained=False, progress=True, **kwargs): ...@@ -63,28 +64,43 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
class Inception3(nn.Module): class Inception3(nn.Module):
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
inception_blocks=None):
super(Inception3, self).__init__() super(Inception3, self).__init__()
if inception_blocks is None:
inception_blocks = [
BasicConv2d, InceptionA, InceptionB, InceptionC,
InceptionD, InceptionE, InceptionAux
]
assert len(inception_blocks) == 7
conv_block = inception_blocks[0]
inception_a = inception_blocks[1]
inception_b = inception_blocks[2]
inception_c = inception_blocks[3]
inception_d = inception_blocks[4]
inception_e = inception_blocks[5]
inception_aux = inception_blocks[6]
self.aux_logits = aux_logits self.aux_logits = aux_logits
self.transform_input = transform_input self.transform_input = transform_input
self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
self.Mixed_5b = InceptionA(192, pool_features=32) self.Mixed_5b = inception_a(192, pool_features=32)
self.Mixed_5c = InceptionA(256, pool_features=64) self.Mixed_5c = inception_a(256, pool_features=64)
self.Mixed_5d = InceptionA(288, pool_features=64) self.Mixed_5d = inception_a(288, pool_features=64)
self.Mixed_6a = InceptionB(288) self.Mixed_6a = inception_b(288)
self.Mixed_6b = InceptionC(768, channels_7x7=128) self.Mixed_6b = inception_c(768, channels_7x7=128)
self.Mixed_6c = InceptionC(768, channels_7x7=160) self.Mixed_6c = inception_c(768, channels_7x7=160)
self.Mixed_6d = InceptionC(768, channels_7x7=160) self.Mixed_6d = inception_c(768, channels_7x7=160)
self.Mixed_6e = InceptionC(768, channels_7x7=192) self.Mixed_6e = inception_c(768, channels_7x7=192)
if aux_logits: if aux_logits:
self.AuxLogits = InceptionAux(768, num_classes) self.AuxLogits = inception_aux(768, num_classes)
self.Mixed_7a = InceptionD(768) self.Mixed_7a = inception_d(768)
self.Mixed_7b = InceptionE(1280) self.Mixed_7b = inception_e(1280)
self.Mixed_7c = InceptionE(2048) self.Mixed_7c = inception_e(2048)
self.fc = nn.Linear(2048, num_classes) self.fc = nn.Linear(2048, num_classes)
for m in self.modules(): for m in self.modules():
...@@ -100,12 +116,15 @@ class Inception3(nn.Module): ...@@ -100,12 +116,15 @@ class Inception3(nn.Module):
nn.init.constant_(m.weight, 1) nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def forward(self, x): def _transform_input(self, x):
if self.transform_input: if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
x = torch.cat((x_ch0, x_ch1, x_ch2), 1) x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
return x
def _forward(self, x):
# N x 3 x 299 x 299 # N x 3 x 299 x 299
x = self.Conv2d_1a_3x3(x) x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149 # N x 32 x 149 x 149
...@@ -158,37 +177,46 @@ class Inception3(nn.Module): ...@@ -158,37 +177,46 @@ class Inception3(nn.Module):
# N x 2048 # N x 2048
x = self.fc(x) x = self.fc(x)
# N x 1000 (num_classes) # N x 1000 (num_classes)
if torch.jit.is_scripting(): return x, aux
if not aux_defined:
warnings.warn("Scripted InceptionNet always returns InceptionOutputs Tuple")
return InceptionOutputs(x, aux)
else:
return self.eager_outputs(x, aux)
@torch.jit.unused @torch.jit.unused
def eager_outputs(self, x, aux): def eager_outputs(self, x, aux):
# type: (torch.Tensor, Optional[torch.Tensor]) -> InceptionOutputs # type: (Tensor, Optional[Tensor]) -> InceptionOutputs
if self.training and self.aux_logits: if self.training and self.aux_logits:
return InceptionOutputs(x, aux) return InceptionOutputs(x, aux)
return x else:
return x
def forward(self, x):
x = self._transform_input(x)
x, aux = self._forward(x)
aux_defined = self.training and self.aux_logits
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
return InceptionOutputs(x, aux)
else:
return self.eager_outputs(x, aux)
class InceptionA(nn.Module): class InceptionA(nn.Module):
def __init__(self, in_channels, pool_features): def __init__(self, in_channels, pool_features, conv_block=None):
super(InceptionA, self).__init__() super(InceptionA, self).__init__()
self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) if conv_block is None:
conv_block = BasicConv2d
self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
def forward(self, x): def _forward(self, x):
branch1x1 = self.branch1x1(x) branch1x1 = self.branch1x1(x)
branch5x5 = self.branch5x5_1(x) branch5x5 = self.branch5x5_1(x)
...@@ -202,20 +230,26 @@ class InceptionA(nn.Module): ...@@ -202,20 +230,26 @@ class InceptionA(nn.Module):
branch_pool = self.branch_pool(branch_pool) branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1) return torch.cat(outputs, 1)
class InceptionB(nn.Module): class InceptionB(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels, conv_block=None):
super(InceptionB, self).__init__() super(InceptionB, self).__init__()
self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) if conv_block is None:
conv_block = BasicConv2d
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
def forward(self, x): def _forward(self, x):
branch3x3 = self.branch3x3(x) branch3x3 = self.branch3x3(x)
branch3x3dbl = self.branch3x3dbl_1(x) branch3x3dbl = self.branch3x3dbl_1(x)
...@@ -225,29 +259,35 @@ class InceptionB(nn.Module): ...@@ -225,29 +259,35 @@ class InceptionB(nn.Module):
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
outputs = [branch3x3, branch3x3dbl, branch_pool] outputs = [branch3x3, branch3x3dbl, branch_pool]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1) return torch.cat(outputs, 1)
class InceptionC(nn.Module): class InceptionC(nn.Module):
def __init__(self, in_channels, channels_7x7): def __init__(self, in_channels, channels_7x7, conv_block=None):
super(InceptionC, self).__init__() super(InceptionC, self).__init__()
self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) if conv_block is None:
conv_block = BasicConv2d
self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
c7 = channels_7x7 c7 = channels_7x7
self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
def forward(self, x): def _forward(self, x):
branch1x1 = self.branch1x1(x) branch1x1 = self.branch1x1(x)
branch7x7 = self.branch7x7_1(x) branch7x7 = self.branch7x7_1(x)
...@@ -264,22 +304,28 @@ class InceptionC(nn.Module): ...@@ -264,22 +304,28 @@ class InceptionC(nn.Module):
branch_pool = self.branch_pool(branch_pool) branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1) return torch.cat(outputs, 1)
class InceptionD(nn.Module): class InceptionD(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels, conv_block=None):
super(InceptionD, self).__init__() super(InceptionD, self).__init__()
self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) if conv_block is None:
self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) conv_block = BasicConv2d
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
def forward(self, x): def _forward(self, x):
branch3x3 = self.branch3x3_1(x) branch3x3 = self.branch3x3_1(x)
branch3x3 = self.branch3x3_2(branch3x3) branch3x3 = self.branch3x3_2(branch3x3)
...@@ -290,27 +336,33 @@ class InceptionD(nn.Module): ...@@ -290,27 +336,33 @@ class InceptionD(nn.Module):
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
outputs = [branch3x3, branch7x7x3, branch_pool] outputs = [branch3x3, branch7x7x3, branch_pool]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1) return torch.cat(outputs, 1)
class InceptionE(nn.Module): class InceptionE(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels, conv_block=None):
super(InceptionE, self).__init__() super(InceptionE, self).__init__()
self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) if conv_block is None:
conv_block = BasicConv2d
self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
def forward(self, x): def _forward(self, x):
branch1x1 = self.branch1x1(x) branch1x1 = self.branch1x1(x)
branch3x3 = self.branch3x3_1(x) branch3x3 = self.branch3x3_1(x)
...@@ -332,15 +384,21 @@ class InceptionE(nn.Module): ...@@ -332,15 +384,21 @@ class InceptionE(nn.Module):
branch_pool = self.branch_pool(branch_pool) branch_pool = self.branch_pool(branch_pool)
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
return outputs
def forward(self, x):
outputs = self._forward(x)
return torch.cat(outputs, 1) return torch.cat(outputs, 1)
class InceptionAux(nn.Module): class InceptionAux(nn.Module):
def __init__(self, in_channels, num_classes): def __init__(self, in_channels, num_classes, conv_block=None):
super(InceptionAux, self).__init__() super(InceptionAux, self).__init__()
self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) if conv_block is None:
self.conv1 = BasicConv2d(128, 768, kernel_size=5) conv_block = BasicConv2d
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
self.conv1 = conv_block(128, 768, kernel_size=5)
self.conv1.stddev = 0.01 self.conv1.stddev = 0.01
self.fc = nn.Linear(768, num_classes) self.fc = nn.Linear(768, num_classes)
self.fc.stddev = 0.001 self.fc.stddev = 0.001
......
from .mobilenet import * from .mobilenet import *
from .resnet import * from .resnet import *
from .googlenet import *
from .inception import *
from .shufflenetv2 import *
import warnings
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.jit.annotations import Optional
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.googlenet import (
GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls)
from .utils import _replace_relu, quantize_model
# Names exported by this module.
__all__ = ['QuantizableGoogLeNet', 'googlenet']
# Download URLs for pre-quantized checkpoints. Keys have the form
# '<model>_<backend>'; googlenet() builds the key the same way when it
# looks up the quantized weights.
quant_model_urls = {
    # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch
    'googlenet_fbgemm': 'https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth',
}
def googlenet(pretrained=False, progress=True, quantize=False, **kwargs):
    r"""Build a quantizable GoogLeNet (Inception v1), optionally pretrained and quantized.

    Architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, the model is passed through ``quantize_model``
            with the ``'fbgemm'`` backend and, when ``pretrained`` is also True,
            the quantized checkpoint is loaded instead of the fp32 one
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True, otherwise the model
            constructor's own default
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *True* when pretrained is True and the
            caller did not set it, otherwise the model constructor's own default

    Returns:
        The constructed ``QuantizableGoogLeNet`` instance.
    """
    if pretrained:
        kwargs.setdefault('transform_input', True)
        kwargs.setdefault('aux_logits', False)
        # Remember what the caller asked for before overriding it below.
        original_aux_logits = kwargs['aux_logits']
        if original_aux_logits:
            warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, '
                          'so make sure to train them')
        # Always build the auxiliary heads while loading; they are removed
        # again after load_state_dict when the caller did not request them.
        kwargs['aux_logits'] = True
        kwargs['init_weights'] = False

    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = 'fbgemm'
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if not pretrained:
        return model

    # Pick the checkpoint matching the (possibly quantized) model just built.
    model_url = quant_model_urls['googlenet' + '_' + backend] if quantize else model_urls['googlenet']
    state_dict = load_state_dict_from_url(model_url, progress=progress)
    model.load_state_dict(state_dict)

    if not original_aux_logits:
        model.aux_logits = False
        del model.aux1, model.aux2
    return model
class QuantizableBasicConv2d(BasicConv2d):
    """``BasicConv2d`` variant that applies its ReLU through an ``nn.ReLU`` submodule.

    Having the activation as a named child (``relu``) lets ``fuse_model``
    fuse it together with ``conv`` and ``bn`` by attribute name.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded untouched to BasicConv2d.
        super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> bn -> relu and return the result."""
        # self.conv / self.bn are presumably created by BasicConv2d.__init__.
        return self.relu(self.bn(self.conv(x)))

    def fuse_model(self):
        """Fuse the ``conv``, ``bn`` and ``relu`` children of this block in place."""
        fused_names = ["conv", "bn", "relu"]
        torch.quantization.fuse_modules(self, fused_names, inplace=True)
class QuantizableInception(Inception):
    """``Inception`` block built from ``QuantizableBasicConv2d`` convolutions.

    The final concatenation goes through ``nn.quantized.FloatFunctional``
    (stored as ``self.cat``) rather than being called directly in ``forward``.
    """

    def __init__(self, *args, **kwargs):
        # conv_block is fixed here; passing it again via kwargs raises TypeError.
        super(QuantizableInception, self).__init__(
            *args, conv_block=QuantizableBasicConv2d, **kwargs)
        self.cat = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Run the parent's ``_forward`` and concatenate its outputs along dim 1."""
        branch_outputs = self._forward(x)
        return self.cat.cat(branch_outputs, 1)
class QuantizableInceptionAux(InceptionAux):
    """Auxiliary classifier head built from ``QuantizableBasicConv2d``.

    ReLU and dropout are held as submodules (``self.relu``, ``self.dropout``)
    and used by ``forward``.
    """

    def __init__(self, *args, **kwargs):
        # conv_block is fixed here; passing it again via kwargs raises TypeError.
        super(QuantizableInceptionAux, self).__init__(
            *args, conv_block=QuantizableBasicConv2d, **kwargs)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.7)

    def forward(self, x):
        """Pool to 4x4, apply the conv block, flatten, then fc1 -> relu -> dropout -> fc2."""
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        pooled = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        features = self.conv(pooled)
        # N x 128 x 4 x 4
        flat = torch.flatten(features, 1)
        # N x 2048
        hidden = self.relu(self.fc1(flat))
        # N x 1024
        hidden = self.dropout(hidden)
        # N x 1024
        logits = self.fc2(hidden)
        # N x 1000 (num_classes)
        return logits
class QuantizableGoogLeNet(GoogLeNet):
    """``GoogLeNet`` assembled from the quantizable building blocks of this module.

    Adds a ``QuantStub`` applied to the (transformed) input and a
    ``DeQuantStub`` applied to the main output of ``_forward``.
    """

    def __init__(self, *args, **kwargs):
        # Order matters: (conv block, inception block, auxiliary head).
        quantizable_blocks = [
            QuantizableBasicConv2d,
            QuantizableInception,
            QuantizableInceptionAux,
        ]
        super(QuantizableGoogLeNet, self).__init__(
            *args, blocks=quantizable_blocks, **kwargs)
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Transform, quantize, run the network, dequantize the main output.

        When scripted, always returns a ``GoogLeNetOutputs`` tuple (with a
        warning if the auxiliary outputs are not defined); otherwise the
        result is whatever ``eager_outputs`` returns.
        """
        x = self._transform_input(x)
        x = self.quant(x)
        # NOTE(review): the parent's _forward appears to return
        # (x, aux2, aux1); unpacking it as (x, aux1, aux2) and then passing
        # (x, aux2, aux1) on would swap the two auxiliary outputs. Confirm
        # against GoogLeNet._forward before relying on which is which.
        x, aux1, aux2 = self._forward(x)
        # Only the main output is dequantized here.
        x = self.dequant(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted QuantizableGoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)

    def fuse_model(self):
        r"""Fuse conv/bn/relu modules in googlenet model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """
        for module in self.modules():
            # Exact type match: subclasses of QuantizableBasicConv2d are skipped.
            if type(module) is QuantizableBasicConv2d:
                module.fuse_model()
import warnings
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import inception as inception_module
from torchvision.models.inception import InceptionOutputs
from torch.jit.annotations import Optional
from torchvision.models.utils import load_state_dict_from_url
from .utils import _replace_relu, quantize_model
# Names exported by this module.
__all__ = [
    "QuantizableInception3",
    "inception_v3",
]
# Download URLs for pre-quantized checkpoints. Keys have the form
# '<model>_<backend>'; inception_v3() builds the key the same way when it
# looks up the quantized weights.
quant_model_urls = {
    # fp32 weights ported from TensorFlow, quantized in PyTorch
    "inception_v3_google_fbgemm":
    "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-4f6e4894.pth"
}
def inception_v3(pretrained=False, progress=True, quantize=False, **kwargs):
    r"""Build a quantizable Inception v3, optionally pretrained and quantized.

    Architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, the model is passed through ``quantize_model``
            with the ``'fbgemm'`` backend and, when ``pretrained`` is also True,
            the quantized checkpoint is loaded instead of the fp32 one
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            When ``pretrained`` is True and this is not given (or is falsy), the
            auxiliary branch is removed after the weights are loaded; when
            ``pretrained`` is False the model constructor's own default applies
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *True* when pretrained is True and the
            caller did not set it, otherwise the model constructor's own default

    Returns:
        The constructed ``QuantizableInception3`` instance.
    """
    if pretrained:
        kwargs.setdefault("transform_input", True)
        # Remember what the caller asked for (absent counts as False) ...
        original_aux_logits = kwargs.get("aux_logits", False)
        # ... and, only if the key was given, force the branch on for loading.
        if "aux_logits" in kwargs:
            kwargs["aux_logits"] = True

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = 'fbgemm'
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if not pretrained:
        return model

    # Pick the checkpoint matching the (possibly quantized) model just built.
    if quantize:
        model_url = quant_model_urls['inception_v3_google' + '_' + backend]
    else:
        model_url = inception_module.model_urls['inception_v3_google']
    state_dict = load_state_dict_from_url(model_url, progress=progress)
    model.load_state_dict(state_dict)

    if not original_aux_logits:
        model.aux_logits = False
        del model.AuxLogits
    return model
class QuantizableBasicConv2d(inception_module.BasicConv2d):
    """``BasicConv2d`` variant that applies its ReLU through an ``nn.ReLU`` submodule.

    Having the activation as a named child (``relu``) lets ``fuse_model``
    fuse it together with ``conv`` and ``bn`` by attribute name.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded untouched to the parent conv block.
        super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> bn -> relu and return the result."""
        # self.conv / self.bn are presumably created by the parent __init__.
        return self.relu(self.bn(self.conv(x)))

    def fuse_model(self):
        """Fuse the ``conv``, ``bn`` and ``relu`` children of this block in place."""
        fused_names = ["conv", "bn", "relu"]
        torch.quantization.fuse_modules(self, fused_names, inplace=True)
class QuantizableInceptionA(inception_module.InceptionA):
    """``InceptionA`` built from ``QuantizableBasicConv2d`` convolutions.

    The final concatenation goes through ``nn.quantized.FloatFunctional``
    (stored as ``self.myop``).
    """

    def __init__(self, *args, **kwargs):
        # conv_block is fixed here; passing it again via kwargs raises TypeError.
        super(QuantizableInceptionA, self).__init__(
            *args, conv_block=QuantizableBasicConv2d, **kwargs)
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Run the parent's ``_forward`` and concatenate its outputs along dim 1."""
        branch_outputs = self._forward(x)
        return self.myop.cat(branch_outputs, 1)
class QuantizableInceptionB(inception_module.InceptionB):
    """InceptionB built from ``QuantizableBasicConv2d`` blocks.

    The branch outputs are concatenated through an
    ``nn.quantized.FloatFunctional`` child (``myop``) instead of a direct
    ``torch.cat`` call.
    """

    def __init__(self, *args, **kwargs):
        super(QuantizableInceptionB, self).__init__(
            *args, conv_block=QuantizableBasicConv2d, **kwargs
        )
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Run the parent's ``_forward`` and join its outputs along dim 1."""
        return self.myop.cat(self._forward(x), 1)
class QuantizableInceptionC(inception_module.InceptionC):
    """InceptionC built from ``QuantizableBasicConv2d`` blocks.

    The branch outputs are concatenated through an
    ``nn.quantized.FloatFunctional`` child (``myop``) instead of a direct
    ``torch.cat`` call.
    """

    def __init__(self, *args, **kwargs):
        super(QuantizableInceptionC, self).__init__(
            *args, conv_block=QuantizableBasicConv2d, **kwargs
        )
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Run the parent's ``_forward`` and join its outputs along dim 1."""
        return self.myop.cat(self._forward(x), 1)
class QuantizableInceptionD(inception_module.InceptionD):
    """InceptionD built from ``QuantizableBasicConv2d`` blocks.

    The branch outputs are concatenated through an
    ``nn.quantized.FloatFunctional`` child (``myop``) instead of a direct
    ``torch.cat`` call.
    """

    def __init__(self, *args, **kwargs):
        super(QuantizableInceptionD, self).__init__(
            *args, conv_block=QuantizableBasicConv2d, **kwargs
        )
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Run the parent's ``_forward`` and join its outputs along dim 1."""
        return self.myop.cat(self._forward(x), 1)
class QuantizableInceptionE(inception_module.InceptionE):
    """InceptionE built from ``QuantizableBasicConv2d`` blocks.

    ``_forward`` is overridden here so that the two concatenations inside
    the 3x3 branches, as well as the final one in :meth:`forward`, all go
    through the ``nn.quantized.FloatFunctional`` child ``myop``.
    """

    def __init__(self, *args, **kwargs):
        super(QuantizableInceptionE, self).__init__(
            *args, conv_block=QuantizableBasicConv2d, **kwargs
        )
        # NOTE(review): this single FloatFunctional instance serves three
        # separate concatenations (two in _forward, one in forward); any
        # state it carries is therefore shared between them -- confirm that
        # is intended before changing it, since renaming would alter the
        # module's attribute names.
        self.myop = nn.quantized.FloatFunctional()

    def _forward(self, x):
        """Compute the four branch outputs and return them as a list."""
        out_1x1 = self.branch1x1(x)

        # 3x3 branch: one shared stem feeding two sibling convolutions.
        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = self.myop.cat(
            [self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1
        )

        # Double-3x3 branch: two stacked stems, then two sibling convolutions.
        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_dbl = self.myop.cat(
            [self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)], 1
        )

        # Pooling branch.
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)

        return [out_1x1, out_3x3, out_dbl, out_pool]

    def forward(self, x):
        """Run :meth:`_forward` and join its four outputs along dim 1."""
        return self.myop.cat(self._forward(x), 1)
class QuantizableInceptionAux(inception_module.InceptionAux):
    """InceptionAux that builds its convolutions from ``QuantizableBasicConv2d``."""

    def __init__(self, *args, **kwargs):
        super(QuantizableInceptionAux, self).__init__(
            *args, conv_block=QuantizableBasicConv2d, **kwargs
        )
class QuantizableInception3(inception_module.Inception3):
    """Inception3 assembled from the quantizable block classes.

    A ``QuantStub`` is applied to the (optionally transformed) input and a
    ``DeQuantStub`` to the main output of the parent's ``_forward``.

    Args:
        num_classes (int): forwarded to ``Inception3``. Default: 1000
        aux_logits (bool): forwarded to ``Inception3``. Default: *True*
        transform_input (bool): forwarded to ``Inception3``. Default: *False*
    """

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
        # Order matters: the parent receives these as its ``inception_blocks``.
        quantizable_blocks = [
            QuantizableBasicConv2d,
            QuantizableInceptionA,
            QuantizableInceptionB,
            QuantizableInceptionC,
            QuantizableInceptionD,
            QuantizableInceptionE,
            QuantizableInceptionAux,
        ]
        super(QuantizableInception3, self).__init__(
            num_classes=num_classes,
            aux_logits=aux_logits,
            transform_input=transform_input,
            inception_blocks=quantizable_blocks,
        )
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Transform, quantize, run the network body and dequantize.

        Only the main output goes through ``dequant``; the auxiliary output
        is handed on exactly as ``_forward`` produced it. Under scripting an
        ``InceptionOutputs`` tuple is always returned; otherwise the result
        comes from ``eager_outputs``.
        """
        x = self.quant(self._transform_input(x))
        x, aux = self._forward(x)
        x = self.dequant(x)

        aux_defined = self.training and self.aux_logits
        # Keep the plain if/else on is_scripting(): the scripting compiler
        # may rely on this exact shape to pick a single branch.
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted QuantizableInception3 always returns QuantizableInception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)

    def fuse_model(self):
        r"""Fuse conv/bn/relu modules in inception model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """
        for module in self.modules():
            # Exact type match on purpose: subclasses are not touched.
            if type(module) is QuantizableBasicConv2d:
                module.fuse_model()
import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
import torchvision.models.shufflenetv2
import sys
from .utils import _replace_relu, quantize_model
shufflenetv2 = sys.modules['torchvision.models.shufflenetv2']
# Public names exported by this module: the model class plus one
# constructor function per ShuffleNetV2 width.
__all__ = [
    'QuantizableShuffleNetV2',
    'shufflenet_v2_x0_5',
    'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5',
    'shufflenet_v2_x2_0',
]
# Checkpoint URLs for the quantized models, keyed by "<arch>_<backend>".
# ``None`` marks an architecture with no quantized checkpoint registered.
quant_model_urls = {
    'shufflenetv2_x0.5_fbgemm': None,
    'shufflenetv2_x1.0_fbgemm': (
        'https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-751f210b.pth'
    ),
    'shufflenetv2_x1.5_fbgemm': None,
    'shufflenetv2_x2.0_fbgemm': None,
}
class QuantizableInvertedResidual(shufflenetv2.InvertedResidual):
    """InvertedResidual whose concatenation goes through a FloatFunctional.

    The ``cat`` child is an ``nn.quantized.FloatFunctional`` used in place
    of a direct ``torch.cat`` call.
    """

    def __init__(self, *args, **kwargs):
        super(QuantizableInvertedResidual, self).__init__(*args, **kwargs)
        self.cat = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Join the two halves/branches along dim 1, then shuffle channels."""
        if self.stride == 1:
            # Split the channels in two; only the second half is transformed.
            passthrough, residual = x.chunk(2, dim=1)
            halves = (passthrough, self.branch2(residual))
        else:
            # Both branches consume the full input.
            halves = (self.branch1(x), self.branch2(x))

        merged = self.cat.cat(halves, dim=1)
        return shufflenetv2.channel_shuffle(merged, 2)
class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
    """ShuffleNetV2 built from ``QuantizableInvertedResidual`` units.

    A ``QuantStub`` is applied before, and a ``DeQuantStub`` after, the
    parent's ``_forward``.
    """

    def __init__(self, *args, **kwargs):
        super(QuantizableShuffleNetV2, self).__init__(
            *args, inverted_residual=QuantizableInvertedResidual, **kwargs
        )
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Quantize ``x``, run the network body, dequantize the result."""
        return self.dequant(self._forward(self.quant(x)))

    def fuse_model(self):
        r"""Fuse conv/bn/relu modules in shufflenetv2 model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """
        fuse = torch.quantization.fuse_modules

        # Direct children conv1/conv5: fuse their first three submodules.
        for child_name, child in self._modules.items():
            if child_name in ("conv1", "conv5"):
                fuse(child, [["0", "1", "2"]], inplace=True)

        # Residual units, matched on exact type so subclasses are skipped.
        for module in self.modules():
            if type(module) is not QuantizableInvertedResidual:
                continue
            # branch1 is only fused when it actually has submodules.
            if len(module.branch1._modules) > 0:
                fuse(module.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
            fuse(
                module.branch2,
                [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
                inplace=True,
            )
def _shufflenetv2(arch, pretrained, progress, quantize, *args, **kwargs):
    """Build a QuantizableShuffleNetV2 and optionally quantize / load weights.

    Args:
        arch (str): architecture key, e.g. ``'shufflenetv2_x1.0'``; used to
            look up the checkpoint URL
        pretrained (bool): If True, download and load a checkpoint
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, the model is passed through
            ``quantize_model`` with the ``'fbgemm'`` backend and the
            checkpoint is looked up in ``quant_model_urls``; otherwise the
            checkpoint comes from ``shufflenetv2.model_urls``
        *args, **kwargs: forwarded to ``QuantizableShuffleNetV2``

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: if ``pretrained`` is requested but the URL
            table holds ``None`` for this architecture.
    """
    model = QuantizableShuffleNetV2(*args, **kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = 'fbgemm'
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            model_url = quant_model_urls[arch + '_' + backend]
        else:
            model_url = shufflenetv2.model_urls[arch]

        # Several table entries are None placeholders. Passing None on to
        # the downloader fails with an unrelated, confusing error, so stop
        # here with an explicit message instead.
        if model_url is None:
            raise NotImplementedError(
                'pretrained {} (quantize={}) is not supported as of now'.format(arch, quantize)
            )

        state_dict = load_state_dict_from_url(model_url,
                                              progress=progress)
        model.load_state_dict(state_dict)

    return model
def shufflenet_v2_x0_5(pretrained=False, progress=True, quantize=False, **kwargs):
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, the model is quantized before being
            returned. Default: *False*
        **kwargs: extra keyword arguments forwarded to the model constructor
    """
    stages_repeats = [4, 8, 4]
    stages_out_channels = [24, 48, 96, 192, 1024]
    return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, quantize,
                         stages_repeats, stages_out_channels, **kwargs)
def shufflenet_v2_x1_0(pretrained=False, progress=True, quantize=False, **kwargs):
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, the model is quantized before being
            returned. Default: *False*
        **kwargs: extra keyword arguments forwarded to the model constructor
    """
    stages_repeats = [4, 8, 4]
    stages_out_channels = [24, 116, 232, 464, 1024]
    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, quantize,
                         stages_repeats, stages_out_channels, **kwargs)
def shufflenet_v2_x1_5(pretrained=False, progress=True, quantize=False, **kwargs):
    """
    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, the model is quantized before being
            returned. Default: *False*
        **kwargs: extra keyword arguments forwarded to the model constructor
    """
    stages_repeats = [4, 8, 4]
    stages_out_channels = [24, 176, 352, 704, 1024]
    return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, quantize,
                         stages_repeats, stages_out_channels, **kwargs)
def shufflenet_v2_x2_0(pretrained=False, progress=True, quantize=False, **kwargs):
    """
    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, the model is quantized before being
            returned. Default: *False*
        **kwargs: extra keyword arguments forwarded to the model constructor
    """
    stages_repeats = [4, 8, 4]
    stages_out_channels = [24, 244, 488, 976, 2048]
    return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, quantize,
                         stages_repeats, stages_out_channels, **kwargs)
...@@ -84,7 +84,7 @@ class InvertedResidual(nn.Module): ...@@ -84,7 +84,7 @@ class InvertedResidual(nn.Module):
class ShuffleNetV2(nn.Module): class ShuffleNetV2(nn.Module):
def __init__(self, stages_repeats, stages_out_channels, num_classes=1000): def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):
super(ShuffleNetV2, self).__init__() super(ShuffleNetV2, self).__init__()
if len(stages_repeats) != 3: if len(stages_repeats) != 3:
...@@ -107,9 +107,9 @@ class ShuffleNetV2(nn.Module): ...@@ -107,9 +107,9 @@ class ShuffleNetV2(nn.Module):
stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip( for name, repeats, output_channels in zip(
stage_names, stages_repeats, self._stage_out_channels[1:]): stage_names, stages_repeats, self._stage_out_channels[1:]):
seq = [InvertedResidual(input_channels, output_channels, 2)] seq = [inverted_residual(input_channels, output_channels, 2)]
for i in range(repeats - 1): for i in range(repeats - 1):
seq.append(InvertedResidual(output_channels, output_channels, 1)) seq.append(inverted_residual(output_channels, output_channels, 1))
setattr(self, name, nn.Sequential(*seq)) setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels input_channels = output_channels
...@@ -122,7 +122,7 @@ class ShuffleNetV2(nn.Module): ...@@ -122,7 +122,7 @@ class ShuffleNetV2(nn.Module):
self.fc = nn.Linear(output_channels, num_classes) self.fc = nn.Linear(output_channels, num_classes)
def forward(self, x): def _forward(self, x):
x = self.conv1(x) x = self.conv1(x)
x = self.maxpool(x) x = self.maxpool(x)
x = self.stage2(x) x = self.stage2(x)
...@@ -133,6 +133,8 @@ class ShuffleNetV2(nn.Module): ...@@ -133,6 +133,8 @@ class ShuffleNetV2(nn.Module):
x = self.fc(x) x = self.fc(x)
return x return x
forward = _forward
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
model = ShuffleNetV2(*args, **kwargs) model = ShuffleNetV2(*args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment