"...distributed/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "72bce160dc6125a059ed2d3642d3eb104567b871"
Unverified Commit de387e8c authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Memory efficient densenet (#1003)

* GPU efficient Densenets

* removed `import math`

* Changed 'efficient' to 'memory_efficient'

* Add tests

* Bugfix in test

* Fix lint

* Remove unnecessary formatting
parent 060c10f0
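
For orientation before the diff: the change threads a new `memory_efficient` flag through the densenet factories, trading compute for memory via gradient checkpointing. A minimal usage sketch (the 224x224 input size here is an arbitrary choice, not taken from the diff):

    import torch
    import torchvision.models as models

    # Dense layers recompute their concat/BN/ReLU/conv bottleneck during
    # backward instead of storing it, cutting peak activation memory.
    model = models.densenet121(memory_efficient=True)
    out = model(torch.rand(1, 3, 224, 224))
    out.sum().backward()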
@@ -60,6 +60,26 @@ class Tester(unittest.TestCase):
         new_model = torch.nn.Sequential(layers)
         return new_model
 
+    def test_memory_efficient_densenet(self):
+        input_shape = (1, 3, 300, 300)
+        x = torch.rand(input_shape)
+
+        for name in ['densenet121', 'densenet169', 'densenet201', 'densenet161']:
+            model1 = models.__dict__[name](num_classes=50, memory_efficient=True)
+            params = model1.state_dict()
+            model1.eval()
+            out1 = model1(x)
+            out1.sum().backward()
+
+            model2 = models.__dict__[name](num_classes=50, memory_efficient=False)
+            model2.load_state_dict(params)
+            model2.eval()
+            out2 = model2(x)
+
+            max_diff = (out1 - out2).abs().max()
+
+            self.assertTrue(max_diff < 1e-5)
+
     def test_resnet_dilation(self):
         # TODO improve tests to also check that each layer has the right dimensionality
         for i in product([False, True], [False, True], [False, True]):
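
The test above checks numerical equivalence of the two modes; it does not measure the memory saving itself. A hedged companion sketch for that (not part of the commit; assumes a CUDA device and a PyTorch version providing `torch.cuda.reset_max_memory_allocated`):

    import torch
    import torchvision.models as models

    for memory_efficient in (False, True):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        model = models.densenet121(memory_efficient=memory_efficient).cuda()
        out = model(torch.rand(4, 3, 224, 224, device='cuda'))
        out.sum().backward()
        peak_mib = torch.cuda.max_memory_allocated() // 2**20
        print('memory_efficient=%s peak=%d MiB' % (memory_efficient, peak_mib))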
@@ -2,8 +2,10 @@ import re
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .utils import load_state_dict_from_url
+import torch.utils.checkpoint as cp
 from collections import OrderedDict
+from .utils import load_state_dict_from_url
 
 
 __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
@@ -15,8 +17,17 @@ model_urls = {
 }
 
 
+def _bn_function_factory(norm, relu, conv):
+    def bn_function(*inputs):
+        concated_features = torch.cat(inputs, 1)
+        bottleneck_output = conv(relu(norm(concated_features)))
+        return bottleneck_output
+
+    return bn_function
+
+
 class _DenseLayer(nn.Sequential):
-    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
+    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
         super(_DenseLayer, self).__init__()
         self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
         self.add_module('relu1', nn.ReLU(inplace=True)),
@@ -29,23 +40,41 @@ class _DenseLayer(nn.Sequential):
                                            kernel_size=3, stride=1, padding=1,
                                            bias=False)),
         self.drop_rate = drop_rate
+        self.memory_efficient = memory_efficient
 
-    def forward(self, x):
-        new_features = super(_DenseLayer, self).forward(x)
+    def forward(self, *prev_features):
+        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
+        if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
+            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
+        else:
+            bottleneck_output = bn_function(*prev_features)
+        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
         if self.drop_rate > 0:
             new_features = F.dropout(new_features, p=self.drop_rate,
                                      training=self.training)
-        return torch.cat([x, new_features], 1)
+        return new_features
 
 
-class _DenseBlock(nn.Sequential):
-    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
+class _DenseBlock(nn.Module):
+    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
         super(_DenseBlock, self).__init__()
         for i in range(num_layers):
-            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
-                                bn_size, drop_rate)
+            layer = _DenseLayer(
+                num_input_features + i * growth_rate,
+                growth_rate=growth_rate,
+                bn_size=bn_size,
+                drop_rate=drop_rate,
+                memory_efficient=memory_efficient,
+            )
             self.add_module('denselayer%d' % (i + 1), layer)
 
+    def forward(self, init_features):
+        features = [init_features]
+        for name, layer in self.named_children():
+            new_features = layer(*features)
+            features.append(new_features)
+        return torch.cat(features, 1)
+
 
 class _Transition(nn.Sequential):
     def __init__(self, num_input_features, num_output_features):
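
The key mechanism in the hunk above is `torch.utils.checkpoint`: the bottleneck closure returned by `_bn_function_factory` runs under `cp.checkpoint`, so its intermediate activations are discarded after the forward pass and recomputed during backward. A self-contained illustration of that behavior (the toy function is hypothetical, standing in for the norm/relu/conv bottleneck):

    import torch
    import torch.utils.checkpoint as cp

    def toy_bottleneck(*inputs):
        # stands in for conv(relu(norm(torch.cat(inputs, 1))))
        return torch.cat(inputs, 1).relu()

    a = torch.randn(1, 4, 8, 8, requires_grad=True)
    b = torch.randn(1, 4, 8, 8, requires_grad=True)

    # Forward saves only the inputs; the concat/relu intermediates are
    # recomputed when backward reaches this node.
    out = cp.checkpoint(toy_bottleneck, a, b)
    out.sum().backward()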
@@ -69,10 +98,12 @@ class DenseNet(nn.Module):
           (i.e. bn_size * k features in the bottleneck layer)
         drop_rate (float) - dropout rate after each dense layer
         num_classes (int) - number of classification classes
+        memory_efficient (bool) - set to True to use checkpointing. Much more memory efficient,
+          but slower. Default: *False*
     """
 
     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
-                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
+                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False):
 
         super(DenseNet, self).__init__()
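
Since the flag is also exposed on the `DenseNet` constructor itself, a direct instantiation reads as follows (a sketch; the block_config shown is just the DenseNet-121 layout that the constructor already uses by default):

    from torchvision.models.densenet import DenseNet

    model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16),
                     num_init_features=64, num_classes=1000,
                     memory_efficient=True)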
@@ -88,9 +119,14 @@ class DenseNet(nn.Module):
         # Each denseblock
         num_features = num_init_features
         for i, num_layers in enumerate(block_config):
-            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
-                                bn_size=bn_size, growth_rate=growth_rate,
-                                drop_rate=drop_rate)
+            block = _DenseBlock(
+                num_layers=num_layers,
+                num_input_features=num_features,
+                bn_size=bn_size,
+                growth_rate=growth_rate,
+                drop_rate=drop_rate,
+                memory_efficient=memory_efficient
+            )
             self.features.add_module('denseblock%d' % (i + 1), block)
             num_features = num_features + num_layers * growth_rate
             if i != len(block_config) - 1: