Unverified Commit b9cbc227 authored by eellison's avatar eellison Committed by GitHub
Browse files

Make Googlnet & InceptionNet scriptable (#1349)

* make googlnet scriptable

* Remove typing import in favor of torch.jit.annotations

* add inceptionnet

* flake fixes

* fix asssert true

* add import division for torchscript

* fix script compilation

* fix flake, py2 division error

* fix py2 division error
parent e4f80bf1
...@@ -26,21 +26,20 @@ def get_available_video_models(): ...@@ -26,21 +26,20 @@ def get_available_video_models():
return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
# model_name, expected to script without error torchub_models = [
torchub_models = { "deeplabv3_resnet101",
"deeplabv3_resnet101": True, "mobilenet_v2",
"mobilenet_v2": True, "resnext50_32x4d",
"resnext50_32x4d": True, "fcn_resnet101",
"fcn_resnet101": True, "googlenet",
"googlenet": False, "densenet121",
"densenet121": True, "resnet18",
"resnet18": True, "alexnet",
"alexnet": True, "shufflenet_v2_x1_0",
"shufflenet_v2_x1_0": True, "squeezenet1_0",
"squeezenet1_0": True, "vgg11",
"vgg11": True, "inception_v3",
"inception_v3": False, ]
}
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
...@@ -55,7 +54,7 @@ class Tester(unittest.TestCase): ...@@ -55,7 +54,7 @@ class Tester(unittest.TestCase):
tb = traceback.format_exc() tb = traceback.format_exc()
scriptable = False scriptable = False
msg = str(e) + str(tb) msg = str(e) + str(tb)
self.assertEqual(torchub_models[name], scriptable, msg) self.assertTrue(scriptable, msg)
def _test_classification_model(self, name, input_shape): def _test_classification_model(self, name, input_shape):
# passing num_class equal to a number other than 1000 helps in making the test # passing num_class equal to a number other than 1000 helps in making the test
......
from __future__ import division
import warnings import warnings
from collections import namedtuple from collections import namedtuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.jit.annotations import Optional
from torch import Tensor
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
__all__ = ['GoogLeNet', 'googlenet'] __all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]
model_urls = { model_urls = {
# GoogLeNet ported from TensorFlow # GoogLeNet ported from TensorFlow
'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth', 'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
} }
_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor],
'aux_logits1': Optional[Tensor]}
# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat
_GoogLeNetOutputs = GoogLeNetOutputs
def googlenet(pretrained=False, progress=True, **kwargs): def googlenet(pretrained=False, progress=True, **kwargs):
...@@ -51,6 +61,7 @@ def googlenet(pretrained=False, progress=True, **kwargs): ...@@ -51,6 +61,7 @@ def googlenet(pretrained=False, progress=True, **kwargs):
class GoogLeNet(nn.Module): class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input']
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True): def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
super(GoogLeNet, self).__init__() super(GoogLeNet, self).__init__()
...@@ -102,6 +113,7 @@ class GoogLeNet(nn.Module): ...@@ -102,6 +113,7 @@ class GoogLeNet(nn.Module):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def forward(self, x): def forward(self, x):
# type: (Tensor) -> GoogLeNetOutputs
if self.transform_input: if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
...@@ -128,8 +140,11 @@ class GoogLeNet(nn.Module): ...@@ -128,8 +140,11 @@ class GoogLeNet(nn.Module):
# N x 480 x 14 x 14 # N x 480 x 14 x 14
x = self.inception4a(x) x = self.inception4a(x)
# N x 512 x 14 x 14 # N x 512 x 14 x 14
if self.training and self.aux_logits: aux_defined = self.training and self.aux_logits
if aux_defined:
aux1 = self.aux1(x) aux1 = self.aux1(x)
else:
aux1 = None
x = self.inception4b(x) x = self.inception4b(x)
# N x 512 x 14 x 14 # N x 512 x 14 x 14
...@@ -137,8 +152,10 @@ class GoogLeNet(nn.Module): ...@@ -137,8 +152,10 @@ class GoogLeNet(nn.Module):
# N x 512 x 14 x 14 # N x 512 x 14 x 14
x = self.inception4d(x) x = self.inception4d(x)
# N x 528 x 14 x 14 # N x 528 x 14 x 14
if self.training and self.aux_logits: if aux_defined:
aux2 = self.aux2(x) aux2 = self.aux2(x)
else:
aux2 = None
x = self.inception4e(x) x = self.inception4e(x)
# N x 832 x 14 x 14 # N x 832 x 14 x 14
...@@ -156,12 +173,24 @@ class GoogLeNet(nn.Module): ...@@ -156,12 +173,24 @@ class GoogLeNet(nn.Module):
x = self.dropout(x) x = self.dropout(x)
x = self.fc(x) x = self.fc(x)
# N x 1000 (num_classes) # N x 1000 (num_classes)
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
return GoogLeNetOutputs(x, aux2, aux1)
else:
return self.eager_outputs(x, aux2, aux1)
@torch.jit.unused
def eager_outputs(self, x, aux2, aux1):
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> GoogLeNetOutputs
if self.training and self.aux_logits: if self.training and self.aux_logits:
return _GoogLeNetOutputs(x, aux2, aux1) return _GoogLeNetOutputs(x, aux2, aux1)
return x else:
return x
class Inception(nn.Module): class Inception(nn.Module):
__constants__ = ['branch2', 'branch3', 'branch4']
def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
super(Inception, self).__init__() super(Inception, self).__init__()
......
from __future__ import division
from collections import namedtuple from collections import namedtuple
import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.jit.annotations import Optional
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
__all__ = ['Inception3', 'inception_v3'] __all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']
model_urls = { model_urls = {
...@@ -13,7 +17,12 @@ model_urls = { ...@@ -13,7 +17,12 @@ model_urls = {
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
} }
_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits']) InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]}
# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs
def inception_v3(pretrained=False, progress=True, **kwargs): def inception_v3(pretrained=False, progress=True, **kwargs):
...@@ -128,8 +137,11 @@ class Inception3(nn.Module): ...@@ -128,8 +137,11 @@ class Inception3(nn.Module):
# N x 768 x 17 x 17 # N x 768 x 17 x 17
x = self.Mixed_6e(x) x = self.Mixed_6e(x)
# N x 768 x 17 x 17 # N x 768 x 17 x 17
if self.training and self.aux_logits: aux_defined = self.training and self.aux_logits
if aux_defined:
aux = self.AuxLogits(x) aux = self.AuxLogits(x)
else:
aux = None
# N x 768 x 17 x 17 # N x 768 x 17 x 17
x = self.Mixed_7a(x) x = self.Mixed_7a(x)
# N x 1280 x 8 x 8 # N x 1280 x 8 x 8
...@@ -146,8 +158,18 @@ class Inception3(nn.Module): ...@@ -146,8 +158,18 @@ class Inception3(nn.Module):
# N x 2048 # N x 2048
x = self.fc(x) x = self.fc(x)
# N x 1000 (num_classes) # N x 1000 (num_classes)
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted InceptionNet always returns InceptionOutputs Tuple")
return InceptionOutputs(x, aux)
else:
return self.eager_outputs(x, aux)
@torch.jit.unused
def eager_outputs(self, x, aux):
# type: (torch.Tensor, Optional[torch.Tensor]) -> InceptionOutputs
if self.training and self.aux_logits: if self.training and self.aux_logits:
return _InceptionOutputs(x, aux) return InceptionOutputs(x, aux)
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment