Unverified Commit b9cbc227 authored by eellison's avatar eellison Committed by GitHub
Browse files

Make Googlnet & InceptionNet scriptable (#1349)

* make googlnet scriptable

* Remove typing import in favor of torch.jit.annotations

* add inceptionnet

* flake fixes

* fix asssert true

* add import division for torchscript

* fix script compilation

* fix flake, py2 division error

* fix py2 division error
parent e4f80bf1
......@@ -26,21 +26,20 @@ def get_available_video_models():
return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
# model_name, expected to script without error
torchub_models = {
"deeplabv3_resnet101": True,
"mobilenet_v2": True,
"resnext50_32x4d": True,
"fcn_resnet101": True,
"googlenet": False,
"densenet121": True,
"resnet18": True,
"alexnet": True,
"shufflenet_v2_x1_0": True,
"squeezenet1_0": True,
"vgg11": True,
"inception_v3": False,
}
torchub_models = [
"deeplabv3_resnet101",
"mobilenet_v2",
"resnext50_32x4d",
"fcn_resnet101",
"googlenet",
"densenet121",
"resnet18",
"alexnet",
"shufflenet_v2_x1_0",
"squeezenet1_0",
"vgg11",
"inception_v3",
]
class Tester(unittest.TestCase):
......@@ -55,7 +54,7 @@ class Tester(unittest.TestCase):
tb = traceback.format_exc()
scriptable = False
msg = str(e) + str(tb)
self.assertEqual(torchub_models[name], scriptable, msg)
self.assertTrue(scriptable, msg)
def _test_classification_model(self, name, input_shape):
# passing num_class equal to a number other than 1000 helps in making the test
......
from __future__ import division
import warnings
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional
from torch import Tensor
from .utils import load_state_dict_from_url
__all__ = ['GoogLeNet', 'googlenet']
__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]
model_urls = {
# GoogLeNet ported from TensorFlow
'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}
_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor],
'aux_logits1': Optional[Tensor]}
# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat
_GoogLeNetOutputs = GoogLeNetOutputs
def googlenet(pretrained=False, progress=True, **kwargs):
......@@ -51,6 +61,7 @@ def googlenet(pretrained=False, progress=True, **kwargs):
class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input']
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
super(GoogLeNet, self).__init__()
......@@ -102,6 +113,7 @@ class GoogLeNet(nn.Module):
nn.init.constant_(m.bias, 0)
def forward(self, x):
# type: (Tensor) -> GoogLeNetOutputs
if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
......@@ -128,8 +140,11 @@ class GoogLeNet(nn.Module):
# N x 480 x 14 x 14
x = self.inception4a(x)
# N x 512 x 14 x 14
if self.training and self.aux_logits:
aux_defined = self.training and self.aux_logits
if aux_defined:
aux1 = self.aux1(x)
else:
aux1 = None
x = self.inception4b(x)
# N x 512 x 14 x 14
......@@ -137,8 +152,10 @@ class GoogLeNet(nn.Module):
# N x 512 x 14 x 14
x = self.inception4d(x)
# N x 528 x 14 x 14
if self.training and self.aux_logits:
if aux_defined:
aux2 = self.aux2(x)
else:
aux2 = None
x = self.inception4e(x)
# N x 832 x 14 x 14
......@@ -156,12 +173,24 @@ class GoogLeNet(nn.Module):
x = self.dropout(x)
x = self.fc(x)
# N x 1000 (num_classes)
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
return GoogLeNetOutputs(x, aux2, aux1)
else:
return self.eager_outputs(x, aux2, aux1)
@torch.jit.unused
def eager_outputs(self, x, aux2, aux1):
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> GoogLeNetOutputs
if self.training and self.aux_logits:
return _GoogLeNetOutputs(x, aux2, aux1)
return x
else:
return x
class Inception(nn.Module):
__constants__ = ['branch2', 'branch3', 'branch4']
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
super(Inception, self).__init__()
......
from __future__ import division
from collections import namedtuple
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional
from .utils import load_state_dict_from_url
__all__ = ['Inception3', 'inception_v3']
__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']
model_urls = {
......@@ -13,7 +17,12 @@ model_urls = {
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}
_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]}
# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs
def inception_v3(pretrained=False, progress=True, **kwargs):
......@@ -128,8 +137,11 @@ class Inception3(nn.Module):
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
if self.training and self.aux_logits:
aux_defined = self.training and self.aux_logits
if aux_defined:
aux = self.AuxLogits(x)
else:
aux = None
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
......@@ -146,8 +158,18 @@ class Inception3(nn.Module):
# N x 2048
x = self.fc(x)
# N x 1000 (num_classes)
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted InceptionNet always returns InceptionOutputs Tuple")
return InceptionOutputs(x, aux)
else:
return self.eager_outputs(x, aux)
@torch.jit.unused
def eager_outputs(self, x, aux):
# type: (torch.Tensor, Optional[torch.Tensor]) -> InceptionOutputs
if self.training and self.aux_logits:
return _InceptionOutputs(x, aux)
return InceptionOutputs(x, aux)
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment