Commit 10a71116 authored by eellison's avatar eellison Committed by Francisco Massa
Browse files

remove BC-breaking changes (#1560)

* remove changes that induced BC

* Re-enable tests that have been disabled

* Remove outdated comment

* Remove outdated comment
parent b05fc269
......@@ -5,7 +5,7 @@ from torch import nn
from torch.jit.annotations import Dict
class IntermediateLayerGetter(nn.Module):
class IntermediateLayerGetter(nn.ModuleDict):
"""
Module wrapper that returns intermediate layers from a model
......@@ -45,8 +45,6 @@ class IntermediateLayerGetter(nn.Module):
def __init__(self, model, return_layers):
if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")
super(IntermediateLayerGetter, self).__init__()
orig_return_layers = return_layers
return_layers = {k: v for k, v in return_layers.items()}
layers = OrderedDict()
......@@ -57,33 +55,14 @@ class IntermediateLayerGetter(nn.Module):
if not return_layers:
break
self.layers = nn.ModuleDict(layers)
super(IntermediateLayerGetter, self).__init__(layers)
self.return_layers = orig_return_layers
def forward(self, x):
out = OrderedDict()
for name, module in self.layers.items():
for name, module in self.items():
x = module(x)
if name in self.return_layers:
out_name = self.return_layers[name]
out[out_name] = x
return out
@torch.jit.ignore
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if (version is None or version < 2):
# now we have a new nesting level for torchscript support
for new_key in self.state_dict().keys():
# remove prefix "layers."
old_key = new_key[len("layers."):]
old_key = prefix + old_key
new_key = prefix + new_key
if old_key in state_dict:
value = state_dict[old_key]
del state_dict[old_key]
state_dict[new_key] = value
super(IntermediateLayerGetter, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
......@@ -90,13 +90,12 @@ class _DenseLayer(nn.Module):
return new_features
class _DenseBlock(nn.Module):
class _DenseBlock(nn.ModuleDict):
_version = 2
__constants__ = ['layers']
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
super(_DenseBlock, self).__init__()
self.layers = nn.ModuleDict()
for i in range(num_layers):
layer = _DenseLayer(
num_input_features + i * growth_rate,
......@@ -105,34 +104,15 @@ class _DenseBlock(nn.Module):
drop_rate=drop_rate,
memory_efficient=memory_efficient,
)
self.layers['denselayer%d' % (i + 1)] = layer
self.add_module('denselayer%d' % (i + 1), layer)
def forward(self, init_features):
features = [init_features]
for name, layer in self.layers.items():
for name, layer in self.items():
new_features = layer(features)
features.append(new_features)
return torch.cat(features, 1)
@torch.jit.ignore
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if (version is None or version < 2):
# now we have a new nesting level for torchscript support
for new_key in self.state_dict().keys():
# remove prefix "layers."
old_key = new_key[len("layers."):]
old_key = prefix + old_key
new_key = prefix + new_key
if old_key in state_dict:
value = state_dict[old_key]
del state_dict[old_key]
state_dict[new_key] = value
super(_DenseBlock, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment