Commit 21110d93 authored by eellison, committed by Francisco Massa

Make Densenet Scriptable (#1342)

* make densenet scriptable

* make py2 compat

* use torch List polyfill

* fix unpacking for checkpointing

* fewer changes to _DenseBlock

* improve error message

* print traceback

* add typing dependency

* add typing dependency to travis too

* Make loading old checkpoints work
parent f677ea31
@@ -17,12 +17,12 @@ matrix:
     script: ./travis-scripts/run-clang-format/run-clang-format.py -r torchvision/csrc
   - env: LINT_CHECK
     python: "2.7"
-    install: pip install flake8
+    install: pip install flake8 typing
     script: flake8 --exclude .circleci
     after_success: []
   - env: LINT_CHECK
     python: "3.6"
-    install: pip install flake8
+    install: pip install flake8 typing
     script: flake8 .circleci
     after_success: []
   - python: "2.7"
@@ -52,6 +52,7 @@ before_install:
   - pip install future
   - pip install pytest pytest-cov codecov
   - pip install mock
+  - pip install typing
   - |
     if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then
       pip install onnxruntime
......
@@ -47,6 +47,7 @@ test:
     - mock
     - av
     - ca-certificates
+    - typing
   commands:
     pytest .
......
@@ -3,6 +3,7 @@ from itertools import product
 import torch
 from torchvision import models
 import unittest
+import traceback


 def get_available_classification_models():
@@ -32,7 +33,7 @@ torchub_models = {
     "resnext50_32x4d": True,
     "fcn_resnet101": False,
     "googlenet": False,
-    "densenet121": False,
+    "densenet121": True,
     "resnet18": True,
     "alexnet": True,
     "shufflenet_v2_x1_0": True,
@@ -47,11 +48,14 @@ class Tester(unittest.TestCase):
         if name not in torchub_models:
             return
         scriptable = True
+        msg = ""
         try:
             torch.jit.script(model)
-        except Exception:
+        except Exception as e:
+            tb = traceback.format_exc()
             scriptable = False
-        self.assertEqual(torchub_models[name], scriptable)
+            msg = str(e) + str(tb)
+        self.assertEqual(torchub_models[name], scriptable, msg)

     def _test_classification_model(self, name, input_shape):
         # passing num_class equal to a number other than 1000 helps in making the test
......
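The test change keeps the table of scriptability expectations but now attaches the full traceback to the assertion message, so a model that unexpectedly fails (or starts passing) torch.jit.script reports why. The same check can be reproduced by hand with a sketch like this (standalone, not part of the test suite):

import traceback

import torch
from torchvision import models

model = models.densenet121()
try:
    torch.jit.script(model)
    print("densenet121 is scriptable")
except Exception:
    # The formatted traceback is what the test now appends to its message.
    print(traceback.format_exc())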
@@ -5,6 +5,8 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as cp
 from collections import OrderedDict
 from .utils import load_state_dict_from_url
+from torch import Tensor
+from torch.jit.annotations import List


 __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
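The two new imports set up the Python 2 compatible annotation style used throughout the rewrite: torch.jit.annotations.List stands in for typing.List (hence the typing backport added to the CI configs above), and types are written in Mypy-style "# type:" comments instead of Python 3 annotation syntax. A minimal sketch of the pattern:

import torch
from torch import Tensor
from torch.jit.annotations import List


@torch.jit.script
def cat_features(inputs):
    # type: (List[Tensor]) -> Tensor
    # TorchScript reads the type comment, so this compiles on
    # Python 2 as well as Python 3.
    return torch.cat(inputs, 1)


print(cat_features([torch.ones(1, 2), torch.zeros(1, 2)]).shape)  # torch.Size([1, 4])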
@@ -17,16 +19,7 @@
 }


-def _bn_function_factory(norm, relu, conv):
-    def bn_function(*inputs):
-        concated_features = torch.cat(inputs, 1)
-        bottleneck_output = conv(relu(norm(concated_features)))
-        return bottleneck_output
-
-    return bn_function
-
-
-class _DenseLayer(nn.Sequential):
+class _DenseLayer(nn.Module):
     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
         super(_DenseLayer, self).__init__()
         self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
@@ -39,15 +32,57 @@ class _DenseLayer(nn.Sequential):
         self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                            kernel_size=3, stride=1, padding=1,
                                            bias=False)),
-        self.drop_rate = drop_rate
+        self.drop_rate = float(drop_rate)
         self.memory_efficient = memory_efficient

-    def forward(self, *prev_features):
-        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
-        if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
-            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
-        else:
-            bottleneck_output = bn_function(*prev_features)
+    def bn_function(self, inputs):
+        # type: (List[Tensor]) -> Tensor
+        concated_features = torch.cat(inputs, 1)
+        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
+        return bottleneck_output
+
+    # todo: rewrite when torchscript supports any
+    def any_requires_grad(self, input):
+        # type: (List[Tensor]) -> bool
+        for tensor in input:
+            if tensor.requires_grad:
+                return True
+        return False
+
+    @torch.jit.unused  # noqa: T484
+    def call_checkpoint_bottleneck(self, input):
+        # type: (List[Tensor]) -> Tensor
+        def closure(*inputs):
+            return self.bn_function(*inputs)
+        return cp.checkpoint(closure, input)
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, input):
+        # type: (List[Tensor]) -> (Tensor)
+        pass
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, input):
+        # type: (Tensor) -> (Tensor)
+        pass
+
+    # torchscript does not yet support *args, so we overload method
+    # allowing it to take either a List[Tensor] or single Tensor
+    def forward(self, input):  # noqa: F811
+        if isinstance(input, Tensor):
+            prev_features = [input]
+        else:
+            prev_features = input
+
+        if self.memory_efficient and self.any_requires_grad(prev_features):
+            if torch.jit.is_scripting():
+                raise Exception("Memory Efficient not supported in JIT")
+
+            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
+        else:
+            bottleneck_output = self.bn_function(prev_features)
+
         new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
         if self.drop_rate > 0:
             new_features = F.dropout(new_features, p=self.drop_rate,
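TorchScript did not support *args at the time, so the variadic forward is replaced by two torch.jit._overload_method stubs (one per accepted type) plus a single real body that refines the argument with isinstance. The same pattern in isolation, as a sketch using a hypothetical _Example module rather than the diff's _DenseLayer:

import torch
import torch.nn as nn
from torch import Tensor
from torch.jit.annotations import List


class _Example(nn.Module):
    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (List[Tensor]) -> Tensor
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (Tensor) -> Tensor
        pass

    def forward(self, input):  # noqa: F811
        # isinstance narrows the union of the two declared signatures.
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input
        return torch.cat(prev_features, 1)


scripted = torch.jit.script(_Example())
print(scripted(torch.ones(1, 2)).shape)  # Tensor overload; scripted callers
                                         # can also pass a List[Tensor]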
@@ -56,8 +91,12 @@ class _DenseLayer(nn.Sequential):


 class _DenseBlock(nn.Module):
+    _version = 2
+    __constants__ = ['layers']
+
     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
         super(_DenseBlock, self).__init__()
+        self.layers = nn.ModuleDict()
         for i in range(num_layers):
             layer = _DenseLayer(
                 num_input_features + i * growth_rate,
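Storing the layers in an nn.ModuleDict (and listing it in __constants__) is what lets the block's forward iterate over its submodules under TorchScript; the previous add_module plus named_children() combination was not scriptable. A stripped-down sketch of the container pattern, with a hypothetical _Block:

import torch
import torch.nn as nn


class _Block(nn.Module):
    __constants__ = ['layers']

    def __init__(self):
        super(_Block, self).__init__()
        self.layers = nn.ModuleDict()
        for i in range(3):
            self.layers['linear%d' % (i + 1)] = nn.Linear(4, 4)

    def forward(self, x):
        # The compiler unrolls this loop over the fixed set of submodules.
        for name, layer in self.layers.items():
            x = layer(x)
        return x


scripted = torch.jit.script(_Block())
print(scripted(torch.ones(2, 4)).shape)  # torch.Size([2, 4])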
@@ -66,15 +105,34 @@
                 drop_rate=drop_rate,
                 memory_efficient=memory_efficient,
             )
-            self.add_module('denselayer%d' % (i + 1), layer)
+            self.layers['denselayer%d' % (i + 1)] = layer

     def forward(self, init_features):
         features = [init_features]
-        for name, layer in self.named_children():
-            new_features = layer(*features)
+        for name, layer in self.layers.items():
+            new_features = layer(features)
             features.append(new_features)
         return torch.cat(features, 1)

+    @torch.jit.ignore
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+        if (version is None or version < 2):
+            # now we have a new nesting level for torchscript support
+            for new_key in self.state_dict().keys():
+                # remove prefix "layers."
+                old_key = new_key[len("layers."):]
+                old_key = prefix + old_key
+                new_key = prefix + new_key
+                if old_key in state_dict:
+                    value = state_dict[old_key]
+                    del state_dict[old_key]
+                    state_dict[new_key] = value
+
+        super(_DenseBlock, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict,
+            missing_keys, unexpected_keys, error_msgs)
+

 class _Transition(nn.Sequential):
     def __init__(self, num_input_features, num_output_features):
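Because the submodules now live one level deeper (under "layers."), checkpoints saved before this change have keys without that segment. Bumping _version to 2 and overriding _load_from_state_dict lets the block detect an old checkpoint (version absent or < 2) and rename its keys in place before delegating to the base class. In isolation, the remapping amounts to inserting the new nesting level, as in this standalone sketch (placeholder values, hypothetical keys):

old_state_dict = {
    'features.denseblock1.denselayer1.norm1.weight': '<tensor>',
    'classifier.bias': '<tensor>',
}

new_state_dict = {}
for key, value in old_state_dict.items():
    # Insert the "layers." level introduced by the ModuleDict.
    if '.denselayer' in key:
        head, tail = key.split('.denselayer', 1)
        key = head + '.layers.denselayer' + tail
    new_state_dict[key] = value

print(sorted(new_state_dict))
# ['classifier.bias',
#  'features.denseblock1.layers.denselayer1.norm1.weight']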
@@ -102,6 +160,8 @@ class DenseNet(nn.Module):
           but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
     """

+    __constants__ = ['features']
+
     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                  num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):
......
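Taken together, the changes make densenet121 (and the other DenseNet variants) scriptable while keeping pre-existing checkpoints loadable; only the memory-efficient checkpointing path is explicitly rejected under scripting. A usage sketch:

import torch
from torchvision import models

model = models.densenet121()
scripted = torch.jit.script(model)  # previously failed to compile

out = scripted(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])

# Checkpoints saved before this change also still load, because
# _DenseBlock._load_from_state_dict remaps pre-version-2 keys on the fly.
# ('old_densenet121.pth' is a placeholder path.)
# model.load_state_dict(torch.load('old_densenet121.pth'))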