Commit 26c9630b authored by eellison's avatar eellison Committed by Francisco Massa
Browse files

make shufflenet and resnet scriptable (#1270)

* make shufflenet scriptable

* make resnet18 scriptable

* set downsample to identity instead of __constants__ api

* use __constants__ for downsample instead of identity

* import tensor to fix flake

* use torch.Tensor type annotation instead of import
parent 20a4a42d
...@@ -33,9 +33,9 @@ torchub_models = { ...@@ -33,9 +33,9 @@ torchub_models = {
"fcn_resnet101": False, "fcn_resnet101": False,
"googlenet": False, "googlenet": False,
"densenet121": False, "densenet121": False,
"resnet18": False, "resnet18": True,
"alexnet": True, "alexnet": True,
"shufflenet_v2_x1_0": False, "shufflenet_v2_x1_0": True,
"squeezenet1_0": True, "squeezenet1_0": True,
"vgg11": True, "vgg11": True,
"inception_v3": False, "inception_v3": False,
......
...@@ -34,6 +34,7 @@ def conv1x1(in_planes, out_planes, stride=1): ...@@ -34,6 +34,7 @@ def conv1x1(in_planes, out_planes, stride=1):
class BasicBlock(nn.Module): class BasicBlock(nn.Module):
expansion = 1 expansion = 1
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None): base_width=64, dilation=1, norm_layer=None):
......
...@@ -17,6 +17,7 @@ model_urls = { ...@@ -17,6 +17,7 @@ model_urls = {
def channel_shuffle(x, groups): def channel_shuffle(x, groups):
# type: (torch.Tensor, int) -> torch.Tensor
batchsize, num_channels, height, width = x.data.size() batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups channels_per_group = num_channels // groups
...@@ -51,6 +52,8 @@ class InvertedResidual(nn.Module): ...@@ -51,6 +52,8 @@ class InvertedResidual(nn.Module):
nn.BatchNorm2d(branch_features), nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
else:
self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential( self.branch2 = nn.Sequential(
nn.Conv2d(inp if (self.stride > 1) else branch_features, nn.Conv2d(inp if (self.stride > 1) else branch_features,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment