Commit 26c9630b authored by eellison's avatar eellison Committed by Francisco Massa
Browse files

make shufflenet and resnet scriptable (#1270)

* make shufflenet scriptable

* make resnet18 scriptable

* set downsample to identity instead of __constants__ api

* use __constants__ for downsample instead of identity

* import tensor to fix flake

* use torch.Tensor type annotation instead of import
parent 20a4a42d
......@@ -33,9 +33,9 @@ torchub_models = {
"fcn_resnet101": False,
"googlenet": False,
"densenet121": False,
"resnet18": False,
"resnet18": True,
"alexnet": True,
"shufflenet_v2_x1_0": False,
"shufflenet_v2_x1_0": True,
"squeezenet1_0": True,
"vgg11": True,
"inception_v3": False,
......
......@@ -34,6 +34,7 @@ def conv1x1(in_planes, out_planes, stride=1):
class BasicBlock(nn.Module):
expansion = 1
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
......
......@@ -17,6 +17,7 @@ model_urls = {
def channel_shuffle(x, groups):
# type: (torch.Tensor, int) -> torch.Tensor
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
......@@ -51,6 +52,8 @@ class InvertedResidual(nn.Module):
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
else:
self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential(
nn.Conv2d(inp if (self.stride > 1) else branch_features,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment