Commit a6a926bc authored by eellison's avatar eellison Committed by Francisco Massa
Browse files

Make fcn_resnet Scriptable (#1352)

* script_fcn_resnet

* Make old models load

* DeepLabV3 also got torchscript-ready
parent 05ca824e
......@@ -28,10 +28,10 @@ def get_available_video_models():
# model_name, expected to script without error
torchub_models = {
"deeplabv3_resnet101": False,
"deeplabv3_resnet101": True,
"mobilenet_v2": True,
"resnext50_32x4d": True,
"fcn_resnet101": False,
"fcn_resnet101": True,
"googlenet": False,
"densenet121": True,
"resnet18": True,
......
......@@ -2,9 +2,10 @@ from collections import OrderedDict
import torch
from torch import nn
from torch.jit.annotations import Dict
class IntermediateLayerGetter(nn.ModuleDict):
class IntermediateLayerGetter(nn.Module):
"""
Module wrapper that returns intermediate layers from a model
......@@ -35,9 +36,16 @@ class IntermediateLayerGetter(nn.ModuleDict):
>>> [('feat1', torch.Size([1, 64, 56, 56])),
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
"""
_version = 2
__constants__ = ['layers']
__annotations__ = {
"return_layers": Dict[str, str],
}
def __init__(self, model, return_layers):
if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")
super(IntermediateLayerGetter, self).__init__()
orig_return_layers = return_layers
return_layers = {k: v for k, v in return_layers.items()}
......@@ -49,14 +57,33 @@ class IntermediateLayerGetter(nn.ModuleDict):
if not return_layers:
break
super(IntermediateLayerGetter, self).__init__(layers)
self.layers = nn.ModuleDict(layers)
self.return_layers = orig_return_layers
def forward(self, x):
out = OrderedDict()
for name, module in self.named_children():
for name, module in self.layers.items():
x = module(x)
if name in self.return_layers:
out_name = self.return_layers[name]
out[out_name] = x
return out
@torch.jit.ignore
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if (version is None or version < 2):
# now we have a new nesting level for torchscript support
for new_key in self.state_dict().keys():
# remove prefix "layers."
old_key = new_key[len("layers."):]
old_key = prefix + old_key
new_key = prefix + new_key
if old_key in state_dict:
value = state_dict[old_key]
del state_dict[old_key]
state_dict[new_key] = value
super(IntermediateLayerGetter, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
......@@ -6,6 +6,8 @@ from torch.nn import functional as F
class _SimpleSegmentationModel(nn.Module):
__constants__ = ['aux_classifier']
def __init__(self, backbone, classifier, aux_classifier=None):
super(_SimpleSegmentationModel, self).__init__()
self.backbone = backbone
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment