Commit a6a926bc authored by eellison, committed by Francisco Massa

Make fcn_resnet Scriptable (#1352)

* script_fcn_resnet

* Make old models load

* Make DeepLabV3 scriptable as well
parent 05ca824e
@@ -28,10 +28,10 @@ def get_available_video_models():
 # model_name, expected to script without error
 torchub_models = {
-    "deeplabv3_resnet101": False,
+    "deeplabv3_resnet101": True,
     "mobilenet_v2": True,
     "resnext50_32x4d": True,
-    "fcn_resnet101": False,
+    "fcn_resnet101": True,
     "googlenet": False,
     "densenet121": True,
     "resnet18": True,
...
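With both expectation flags flipped to True, the two segmentation models are expected to compile under TorchScript. A minimal sketch of what that means in practice (this is not the repository's test code; depending on the torchvision version, the default constructor arguments may download backbone weights):

import torch
from torchvision.models.segmentation import deeplabv3_resnet101, fcn_resnet101

# Both models should now script without error.
scripted_fcn = torch.jit.script(fcn_resnet101())
scripted_deeplab = torch.jit.script(deeplabv3_resnet101())

# The scripted module behaves like the eager one and returns a dict of outputs.
out = scripted_fcn(torch.rand(1, 3, 224, 224))
print(out["out"].shape)  # torch.Size([1, 21, 224, 224])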
@@ -2,9 +2,10 @@ from collections import OrderedDict
 import torch
 from torch import nn
+from torch.jit.annotations import Dict
-class IntermediateLayerGetter(nn.ModuleDict):
+class IntermediateLayerGetter(nn.Module):
     """
     Module wrapper that returns intermediate layers from a model
@@ -35,9 +36,16 @@ class IntermediateLayerGetter(nn.ModuleDict):
         >>> [('feat1', torch.Size([1, 64, 56, 56])),
         >>> ('feat2', torch.Size([1, 256, 14, 14]))]
     """
+    _version = 2
+    __constants__ = ['layers']
+    __annotations__ = {
+        "return_layers": Dict[str, str],
+    }
+
     def __init__(self, model, return_layers):
         if not set(return_layers).issubset([name for name, _ in model.named_children()]):
             raise ValueError("return_layers are not present in model")
+        super(IntermediateLayerGetter, self).__init__()
         orig_return_layers = return_layers
         return_layers = {k: v for k, v in return_layers.items()}
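The new `__annotations__` entry declares the type of the `return_layers` attribute up front, so the TorchScript compiler does not have to infer it from whatever value happens to be assigned in `__init__`. Below is a small self-contained sketch of the same pattern; the `NameMapper` class is illustrative only, and it uses `typing.Dict`, which serves the same purpose as the `torch.jit.annotations.Dict` import added in this file:

import torch
from torch import nn
from typing import Dict

class NameMapper(nn.Module):
    # Declare the attribute's type so the compiler treats it as Dict[str, str].
    __annotations__ = {"return_layers": Dict[str, str]}

    def __init__(self, return_layers):
        super(NameMapper, self).__init__()
        self.return_layers = return_layers

    def forward(self, name: str) -> str:
        if name in self.return_layers:
            return self.return_layers[name]
        return name

scripted = torch.jit.script(NameMapper({"layer4": "out"}))
print(scripted("layer4"))  # prints "out"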
@@ -49,14 +57,33 @@
             if not return_layers:
                 break
-        super(IntermediateLayerGetter, self).__init__(layers)
+        self.layers = nn.ModuleDict(layers)
         self.return_layers = orig_return_layers

     def forward(self, x):
         out = OrderedDict()
-        for name, module in self.named_children():
+        for name, module in self.layers.items():
             x = module(x)
             if name in self.return_layers:
                 out_name = self.return_layers[name]
                 out[out_name] = x
         return out
+
+    @torch.jit.ignore
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+        if (version is None or version < 2):
+            # now we have a new nesting level for torchscript support
+            for new_key in self.state_dict().keys():
+                # remove prefix "layers."
+                old_key = new_key[len("layers."):]
+                old_key = prefix + old_key
+                new_key = prefix + new_key
+                if old_key in state_dict:
+                    value = state_dict[old_key]
+                    del state_dict[old_key]
+                    state_dict[new_key] = value
+        super(IntermediateLayerGetter, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict,
+            missing_keys, unexpected_keys, error_msgs)
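The bumped `_version` and the `_load_from_state_dict` override keep old checkpoints loadable: parameters that used to sit directly under the getter's prefix now live one level deeper, under `layers.`. A standalone illustration of the rename applied to a version < 2 checkpoint, using made-up keys:

# Hypothetical old-format checkpoint: the getter's parameters have no
# "layers." nesting level yet.
state_dict = {
    "backbone.conv1.weight": "old tensor",
    "classifier.0.weight": "other tensor",  # not under the getter, left alone
}
prefix = "backbone."

# Keys as the new module layout would report them via self.state_dict().keys().
new_layout_keys = ["layers.conv1.weight"]

for new_key in new_layout_keys:
    old_key = prefix + new_key[len("layers."):]  # drop the new nesting level
    if old_key in state_dict:
        state_dict[prefix + new_key] = state_dict.pop(old_key)

print(sorted(state_dict))
# ['backbone.layers.conv1.weight', 'classifier.0.weight']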
@@ -6,6 +6,8 @@ from torch.nn import functional as F
 class _SimpleSegmentationModel(nn.Module):
+    __constants__ = ['aux_classifier']
+
     def __init__(self, backbone, classifier, aux_classifier=None):
         super(_SimpleSegmentationModel, self).__init__()
         self.backbone = backbone
...
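Listing `aux_classifier` in `__constants__` follows the TorchScript recipe of this PyTorch era for optional submodules: when the attribute is None, the compiler can resolve the `is not None` check at compile time instead of failing on a missing module (newer PyTorch releases no longer need the declaration and may warn about it). A self-contained sketch of the pattern; the `WithOptionalHead` class below is illustrative, not code from this file:

import torch
from torch import nn

class WithOptionalHead(nn.Module):
    __constants__ = ['head']

    def __init__(self, head=None):
        super(WithOptionalHead, self).__init__()
        self.body = nn.Linear(8, 8)
        self.head = head  # may be None, mirroring aux_classifier

    def forward(self, x):
        x = self.body(x)
        if self.head is not None:  # folded away when head is None
            x = self.head(x)
        return x

# Scripts both with and without the optional head.
scripted_no_head = torch.jit.script(WithOptionalHead())
scripted_with_head = torch.jit.script(WithOptionalHead(nn.Linear(8, 4)))
print(scripted_no_head(torch.rand(2, 8)).shape)    # torch.Size([2, 8])
print(scripted_with_head(torch.rand(2, 8)).shape)  # torch.Size([2, 4])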