"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "e7389d7c209752062cea0853eb2747d567fdb0c5"
Commit 10a71116 authored by eellison, committed by Francisco Massa

remove BC-breaking changes (#1560)

* Remove changes that broke backwards compatibility (BC)

* Re-enable tests that had been disabled

* Remove outdated comment

* Remove outdated comment
parent b05fc269
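
For context: the reverted changes had moved each wrapped submodule one nesting level deeper, under a `layers.` prefix, to ease TorchScript support. That renames every `state_dict` key and breaks previously saved checkpoints. A minimal sketch of the effect, using made-up module names:

import torch.nn as nn
from collections import OrderedDict

# Hypothetical toy modules; only the registration style differs.
class Flat(nn.ModuleDict):              # submodules registered directly on self
    def __init__(self):
        super().__init__(OrderedDict([('conv1', nn.Conv2d(3, 8, 1))]))

class Nested(nn.Module):                # submodules nested under self.layers
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict(OrderedDict([('conv1', nn.Conv2d(3, 8, 1))]))

print(list(Flat().state_dict()))    # ['conv1.weight', 'conv1.bias']
print(list(Nested().state_dict()))  # ['layers.conv1.weight', 'layers.conv1.bias']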
@@ -5,7 +5,7 @@ from torch import nn
 from torch.jit.annotations import Dict
 
 
-class IntermediateLayerGetter(nn.Module):
+class IntermediateLayerGetter(nn.ModuleDict):
     """
     Module wrapper that returns intermediate layers from a model
@@ -45,8 +45,6 @@ class IntermediateLayerGetter(nn.Module):
     def __init__(self, model, return_layers):
         if not set(return_layers).issubset([name for name, _ in model.named_children()]):
             raise ValueError("return_layers are not present in model")
-
-        super(IntermediateLayerGetter, self).__init__()
         orig_return_layers = return_layers
         return_layers = {k: v for k, v in return_layers.items()}
         layers = OrderedDict()
@@ -57,33 +55,14 @@ class IntermediateLayerGetter(nn.Module):
             if not return_layers:
                 break
 
-        self.layers = nn.ModuleDict(layers)
+        super(IntermediateLayerGetter, self).__init__(layers)
         self.return_layers = orig_return_layers
 
     def forward(self, x):
         out = OrderedDict()
-        for name, module in self.layers.items():
+        for name, module in self.items():
             x = module(x)
             if name in self.return_layers:
                 out_name = self.return_layers[name]
                 out[out_name] = x
         return out
-
-    @torch.jit.ignore
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
-        if (version is None or version < 2):
-            # now we have a new nesting level for torchscript support
-            for new_key in self.state_dict().keys():
-                # remove prefix "layers."
-                old_key = new_key[len("layers."):]
-                old_key = prefix + old_key
-                new_key = prefix + new_key
-                if old_key in state_dict:
-                    value = state_dict[old_key]
-                    del state_dict[old_key]
-                    state_dict[new_key] = value
-        super(IntermediateLayerGetter, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
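
With IntermediateLayerGetter subclassing nn.ModuleDict again, the selected layers are registered directly on the getter itself, so its state_dict keys keep the backbone's original names with no extra `layers.` segment, and the version-remapping _load_from_state_dict hook above is no longer needed. A rough usage sketch; the backbone choice and layer names are only an example, and the import path is torchvision's internal `_utils` module as of this era:

import torch
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter

# Map the backbone's child names to the names we want in the output dict.
backbone = resnet18(pretrained=False)
getter = IntermediateLayerGetter(backbone,
                                 return_layers={'layer2': 'feat2', 'layer4': 'feat4'})

out = getter(torch.randn(1, 3, 224, 224))
print([(name, feat.shape) for name, feat in out.items()])  # 'feat2' and 'feat4'
print(list(getter.state_dict())[:3])                       # flat keys, e.g. 'conv1.weight'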
@@ -90,13 +90,12 @@ class _DenseLayer(nn.Module):
         return new_features
 
 
-class _DenseBlock(nn.Module):
+class _DenseBlock(nn.ModuleDict):
     _version = 2
     __constants__ = ['layers']
 
     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
         super(_DenseBlock, self).__init__()
-        self.layers = nn.ModuleDict()
         for i in range(num_layers):
             layer = _DenseLayer(
                 num_input_features + i * growth_rate,
@@ -105,34 +104,15 @@ class _DenseBlock(nn.Module):
                 drop_rate=drop_rate,
                 memory_efficient=memory_efficient,
             )
-            self.layers['denselayer%d' % (i + 1)] = layer
+            self.add_module('denselayer%d' % (i + 1), layer)
 
     def forward(self, init_features):
         features = [init_features]
-        for name, layer in self.layers.items():
+        for name, layer in self.items():
             new_features = layer(features)
             features.append(new_features)
         return torch.cat(features, 1)
-
-    @torch.jit.ignore
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
-        if (version is None or version < 2):
-            # now we have a new nesting level for torchscript support
-            for new_key in self.state_dict().keys():
-                # remove prefix "layers."
-                old_key = new_key[len("layers."):]
-                old_key = prefix + old_key
-                new_key = prefix + new_key
-                if old_key in state_dict:
-                    value = state_dict[old_key]
-                    del state_dict[old_key]
-                    state_dict[new_key] = value
-        super(_DenseBlock, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
 
 
 class _Transition(nn.Sequential):
     def __init__(self, num_input_features, num_output_features):
...
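
Because _DenseBlock is again an nn.ModuleDict, the dense layers it registers via add_module stay at the same depth as before, so DenseNet checkpoints keep their familiar key layout. A quick sanity check of that invariant; densenet121 here is just one example model:

import torch
from torchvision.models import densenet121

model = densenet121(pretrained=False)
keys = list(model.state_dict())
# Keys look like 'features.denseblock1.denselayer1.norm1.weight',
# with no extra '.layers.' segment inserted by the blocks.
print(keys[:5])
assert not any('.layers.' in k for k in keys)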