Commit 21110d93 authored by eellison, committed by Francisco Massa

Make Densenet Scriptable (#1342)

* make densenet scriptable

* make py2 compat

* use torch List polyfill

* fix unpacking for checkpointing

* fewer changes to _DenseBlock

* improve error message

* print traceback

* add typing dependency

* add typing dependency to travis too

* Make loading old checkpoints work
parent f677ea31
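With this commit, DenseNet models can be compiled by TorchScript directly. A minimal usage sketch (assuming a torchvision build that includes this change; the file name is illustrative):

```python
import torch
from torchvision import models

model = models.densenet121()
scripted = torch.jit.script(model)  # failed before this change

out = scripted(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])

scripted.save("densenet121_scripted.pt")  # illustrative path
```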
@@ -17,12 +17,12 @@ matrix:
       script: ./travis-scripts/run-clang-format/run-clang-format.py -r torchvision/csrc
     - env: LINT_CHECK
       python: "2.7"
-      install: pip install flake8
+      install: pip install flake8 typing
       script: flake8 --exclude .circleci
       after_success: []
     - env: LINT_CHECK
       python: "3.6"
-      install: pip install flake8
+      install: pip install flake8 typing
       script: flake8 .circleci
       after_success: []
     - python: "2.7"
@@ -52,6 +52,7 @@ before_install:
   - pip install future
   - pip install pytest pytest-cov codecov
   - pip install mock
+  - pip install typing
   - |
     if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then
       pip install onnxruntime
......
@@ -47,6 +47,7 @@ test:
     - mock
     - av
     - ca-certificates
+    - typing
   commands:
     pytest .
......
@@ -3,6 +3,7 @@ from itertools import product
 import torch
 from torchvision import models
 import unittest
+import traceback


 def get_available_classification_models():
@@ -32,7 +33,7 @@ torchub_models = {
     "resnext50_32x4d": True,
     "fcn_resnet101": False,
     "googlenet": False,
-    "densenet121": False,
+    "densenet121": True,
     "resnet18": True,
     "alexnet": True,
     "shufflenet_v2_x1_0": True,
@@ -47,11 +48,14 @@ class Tester(unittest.TestCase):
         if name not in torchub_models:
             return
         scriptable = True
+        msg = ""
         try:
             torch.jit.script(model)
-        except Exception:
+        except Exception as e:
+            tb = traceback.format_exc()
             scriptable = False
-        self.assertEqual(torchub_models[name], scriptable)
+            msg = str(e) + str(tb)
+        self.assertEqual(torchub_models[name], scriptable, msg)

     def _test_classification_model(self, name, input_shape):
         # passing num_class equal to a number other than 1000 helps in making the test
......
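The test change above keeps the scriptability assertion but attaches the exception text and full traceback as the failure message. The same pattern can be used standalone to debug a model that fails to script; a sketch (the helper name is made up here):

```python
import traceback

import torch
from torchvision import models

def assert_scriptable(model, expected=True):
    scriptable, msg = True, ""
    try:
        torch.jit.script(model)
    except Exception as e:
        scriptable = False
        # surface the full traceback so failures are debuggable
        msg = str(e) + traceback.format_exc()
    assert scriptable == expected, msg

assert_scriptable(models.densenet121())
```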
@@ -5,6 +5,8 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as cp
 from collections import OrderedDict
 from .utils import load_state_dict_from_url
+from torch import Tensor
+from torch.jit.annotations import List


 __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
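The new imports support TorchScript typing on Python 2: `List` comes from `torch.jit.annotations` (TorchScript's polyfill over the `typing` backport, hence the new CI dependency), and signatures are declared in `# type:` comments rather than Python 3 annotation syntax. A minimal sketch of the pattern, assuming the installed PyTorch exposes `torch.jit.annotations.List`:

```python
import torch
from torch import Tensor
from torch.jit.annotations import List

@torch.jit.script
def concat_all(inputs):
    # type: (List[Tensor]) -> Tensor
    # comment-style annotation keeps the file importable on Python 2
    return torch.cat(inputs, 1)

print(concat_all([torch.ones(1, 2), torch.zeros(1, 3)]).shape)  # (1, 5)
```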
@@ -17,16 +19,7 @@ model_urls = {
 }


-def _bn_function_factory(norm, relu, conv):
-    def bn_function(*inputs):
-        concated_features = torch.cat(inputs, 1)
-        bottleneck_output = conv(relu(norm(concated_features)))
-        return bottleneck_output
-
-    return bn_function
-
-
-class _DenseLayer(nn.Sequential):
+class _DenseLayer(nn.Module):
     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
         super(_DenseLayer, self).__init__()
         self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
@@ -39,15 +32,57 @@ class _DenseLayer(nn.Sequential):
         self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                            kernel_size=3, stride=1, padding=1,
                                            bias=False)),
-        self.drop_rate = drop_rate
+        self.drop_rate = float(drop_rate)
         self.memory_efficient = memory_efficient

-    def forward(self, *prev_features):
-        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
-        if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
-            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
+    def bn_function(self, inputs):
+        # type: (List[Tensor]) -> Tensor
+        concated_features = torch.cat(inputs, 1)
+        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
+        return bottleneck_output
+
+    # todo: rewrite when torchscript supports any
+    def any_requires_grad(self, input):
+        # type: (List[Tensor]) -> bool
+        for tensor in input:
+            if tensor.requires_grad:
+                return True
+        return False
+
+    @torch.jit.unused  # noqa: T484
+    def call_checkpoint_bottleneck(self, input):
+        # type: (List[Tensor]) -> Tensor
+        def closure(*inputs):
+            return self.bn_function(*inputs)
+
+        return cp.checkpoint(closure, input)
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, input):
+        # type: (List[Tensor]) -> (Tensor)
+        pass
+
+    @torch.jit._overload_method  # noqa: F811
+    def forward(self, input):
+        # type: (Tensor) -> (Tensor)
+        pass
+
+    # torchscript does not yet support *args, so we overload method
+    # allowing it to take either a List[Tensor] or single Tensor
+    def forward(self, input):  # noqa: F811
+        if isinstance(input, Tensor):
+            prev_features = [input]
         else:
-            bottleneck_output = bn_function(*prev_features)
+            prev_features = input
+
+        if self.memory_efficient and self.any_requires_grad(prev_features):
+            if torch.jit.is_scripting():
+                raise Exception("Memory Efficient not supported in JIT")
+
+            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
+        else:
+            bottleneck_output = self.bn_function(prev_features)
+
         new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
         if self.drop_rate > 0:
             new_features = F.dropout(new_features, p=self.drop_rate,
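The stub-plus-implementation pattern above works around TorchScript's lack of `*args` support: two `@torch.jit._overload_method` stubs declare the accepted signatures, and the single real `forward` is compiled once per signature, with the `isinstance` check resolved statically in each specialization. A self-contained sketch of the same technique (note that `torch.jit._overload_method` is an internal, unstable API; all names below are illustrative):

```python
import torch
from torch import Tensor
from torch.jit.annotations import List

class SumInputs(torch.nn.Module):
    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (List[Tensor]) -> Tensor
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input):
        # type: (Tensor) -> Tensor
        pass

    def forward(self, input):  # noqa: F811
        # the isinstance branch is pruned at compile time per overload
        if isinstance(input, Tensor):
            tensors = [input]
        else:
            tensors = input
        out = tensors[0]
        for t in tensors[1:]:
            out = out + t
        return out

scripted = torch.jit.script(SumInputs())
print(scripted(torch.ones(2)))                   # single Tensor
print(scripted([torch.ones(2), torch.ones(2)]))  # List[Tensor]
```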
@@ -56,8 +91,12 @@ class _DenseLayer(nn.Sequential):

 class _DenseBlock(nn.Module):
+    _version = 2
+    __constants__ = ['layers']
+
     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
         super(_DenseBlock, self).__init__()
+        self.layers = nn.ModuleDict()
         for i in range(num_layers):
             layer = _DenseLayer(
                 num_input_features + i * growth_rate,
@@ -66,15 +105,34 @@ class _DenseBlock(nn.Module):
                 drop_rate=drop_rate,
                 memory_efficient=memory_efficient,
             )
-            self.add_module('denselayer%d' % (i + 1), layer)
+            self.layers['denselayer%d' % (i + 1)] = layer

     def forward(self, init_features):
         features = [init_features]
-        for name, layer in self.named_children():
-            new_features = layer(*features)
+        for name, layer in self.layers.items():
+            new_features = layer(features)
             features.append(new_features)
         return torch.cat(features, 1)

+    @torch.jit.ignore
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+        if (version is None or version < 2):
+            # now we have a new nesting level for torchscript support
+            for new_key in self.state_dict().keys():
+                # remove prefix "layers."
+                old_key = new_key[len("layers."):]
+                old_key = prefix + old_key
+                new_key = prefix + new_key
+                if old_key in state_dict:
+                    value = state_dict[old_key]
+                    del state_dict[old_key]
+                    state_dict[new_key] = value
+
+        super(_DenseBlock, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict,
+            missing_keys, unexpected_keys, error_msgs)
+

 class _Transition(nn.Sequential):
     def __init__(self, num_input_features, num_output_features):
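The `_version = 2` bump together with the `_load_from_state_dict` override is what keeps pre-change checkpoints loading: when a stored state dict carries no version (or version < 2), each block's keys are rewritten into the new `layers.` nesting on the fly. A sketch of the round trip, valid only for torchvision at exactly this revision (later releases changed the layout again):

```python
from torchvision import models

model = models.densenet121()

# Simulate an old checkpoint by stripping the new "layers." nesting level,
# e.g. "features.denseblock1.layers.denselayer1.norm1.weight"
#   -> "features.denseblock1.denselayer1.norm1.weight"
old_sd = {k.replace(".layers.", "."): v for k, v in model.state_dict().items()}

# A plain dict carries no _metadata, so version is None and the
# override remaps the old keys transparently during loading.
model.load_state_dict(old_sd)
```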
@@ -102,6 +160,8 @@ class DenseNet(nn.Module):
         but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
     """

+    __constants__ = ['features']
+
     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                  num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):
......
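`__constants__ = ['features']` marks the `features` attribute as a compile-time constant for TorchScript, letting the compiler treat it as immutable when scripting `DenseNet.forward`. A minimal sketch of the mechanism with a plain value (module and names are illustrative):

```python
import torch

class Scale(torch.nn.Module):
    __constants__ = ['factor']  # folded into the compiled graph

    def __init__(self, factor):
        super(Scale, self).__init__()
        self.factor = factor

    def forward(self, x):
        return x * self.factor

scripted = torch.jit.script(Scale(0.5))
print(scripted(torch.ones(3)))  # tensor([0.5000, 0.5000, 0.5000])
```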