"vscode:/vscode.git/clone" did not exist on "bc26a2faea287cec6ceca03d6b8d4bbcc2e9a635"
Unverified commit c447249c, authored by Yuge Zhang and committed by GitHub
Browse files

Support loading from `state_dict` of supernet (#4544)

parent 6b828681
......@@ -6,6 +6,7 @@ import re
from typing import Dict, List, Tuple, Any
from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.retiarii.utils import STATE_DICT_PY_MAPPING
from nni.common.device import Device, GPUDevice
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
......@@ -97,7 +98,18 @@ def _format_variable_name(name: str, graph_name: str) -> str:
name = name.replace('/', '__')
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
return re.sub('\W|^(?=\d)','_', name)
name = re.sub('\W|^(?=\d)','_', name)
if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
# name can't start with double underscore
# it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
# but it's actually very common in our generated code
name = name[1:]
elif name.startswith('_'):
# to avoid conflicts between '_' and '__'
name = 'i' + name
return name
def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
......@@ -125,6 +137,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
# only need to generate code for module here
import_pkgs = set()
node_codes = []
node_python_mappings = {}
cuda_remapped_id = None
if placement:
cuda_remapped_id = generate_cuda_mapping(placement)
......@@ -138,7 +151,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(_format_variable_name(node.name, graph_name))
py_variable_name = _format_variable_name(node.name, graph_name)
node_code = node.operation.to_init_code(py_variable_name)
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
if isinstance(placement[node], GPUDevice):
......@@ -149,6 +164,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
else:
node_codes.append(node_code)
# Map to module hierarchies in original search space python code
node_python_mappings[py_variable_name] = node.python_name
node_codes.append(f'self.{STATE_DICT_PY_MAPPING} = {node_python_mappings}')
if graph.input_node.operation.io_names is None:
input_code = '*_inputs'
else:
......
......@@ -11,6 +11,7 @@ import torch.nn as nn
from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from .utils import Mutable, generate_new_label, get_fixed_value
......@@ -65,9 +66,22 @@ class LayerChoice(Mutable):
label: Optional[str] = None, **kwargs):
chosen = get_fixed_value(label)
if isinstance(candidates, list):
return candidates[int(chosen)]
result = candidates[int(chosen)]
else:
return candidates[chosen]
result = candidates[chosen]
# map the named hierarchies to support weight inheritance for python engine
if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
# handle cases where layer choices are nested
# already has a mapping, will merge with it
prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'{chosen}.{v}' for k, v in prev_mapping.items()})
else:
# "result" needs to know where to map itself.
# Ideally, we should put a _mapping_ in the module where "result" is located,
# but it's impossible to put mapping into parent module here.
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': str(chosen)})
return result
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
......
......@@ -5,6 +5,8 @@ from typing import Callable, List, Union, Tuple, Optional
import torch
import torch.nn as nn
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from .api import LayerChoice
from .cell import Cell
from .nasbench101 import NasBench101Cell, NasBench101Mutator
......@@ -38,7 +40,15 @@ class Repeat(Mutable):
List[nn.Module]],
depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
repeat = get_fixed_value(label)
return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
result = nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
# already has a mapping, will merge with it
prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'blocks.{v}' for k, v in prev_mapping.items()})
else:
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': 'blocks'})
return result
def __init__(self,
blocks: Union[Callable[[int], nn.Module],
......
......@@ -304,6 +304,8 @@ class NasBench101Cell(Mutable):
[op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
adjacency_list, in_features, out_features, num_nodes, projection)
# FIXME: weight inheritance on nasbench101 is not supported yet
def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
......
......@@ -2,8 +2,10 @@
# Licensed under the MIT license.
import inspect
import itertools
import warnings
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, List, Dict
from pathlib import Path
......@@ -154,3 +156,119 @@ class ModelNamespace:
def get_current_context(key: str) -> Any:
return ContextStack.top(key)
# Name of the attribute (injected into generated model classes) that maps a
# generated module's python variable name to its *full* prefix in the original
# search space's state dict,
# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
STATE_DICT_PY_MAPPING = '_mapping_'

# Same idea, but values are *relative*: they map variables to `prefix`.`value`
# in the state dict,
# e.g., {'upsample': 'choice3.upsample_layer'},
# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'}
# when the enclosing module ('upsample' included) lives under `mynet.module`.
STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'
@contextmanager
def original_state_dict_hooks(model: Any):
    """
    Use this patch if you want to save/load state dict in the original state dict hierarchy.

    For example, when you already have a state dict for the base model / search space (which often
    happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
    in the same way as when a sub-model is sampled from the search space. This patch will help
    the modules in the sub-model find the corresponding module in the base model.

    The code looks like,

    .. code-block:: python

        with original_state_dict_hooks(model):
            model.load_state_dict(state_dict_from_supernet, strict=False)  # supernet has extra keys

    Or vice-versa,

    .. code-block:: python

        with original_state_dict_hooks(model):
            supernet_style_state_dict = model.state_dict()
    """
    import torch.nn as nn
    assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'

    # the following are written for pytorch only

    # first get the full mapping: sub-model state-dict key -> original (supernet) key
    full_mapping = {}

    def full_mapping_in_module(src_prefix, tar_prefix, module):
        # Recursively walk ``module``, accumulating entries into ``full_mapping``.
        # ``src_prefix`` is the key prefix in the sub-model's own state dict;
        # ``tar_prefix`` is the corresponding prefix in the original model.
        if hasattr(module, STATE_DICT_PY_MAPPING):
            # only values are complete
            local_map = getattr(module, STATE_DICT_PY_MAPPING)
        elif hasattr(module, STATE_DICT_PY_MAPPING_PARTIAL):
            # keys and values are both incomplete; complete the values by
            # prepending the target prefix of the enclosing module
            local_map = getattr(module, STATE_DICT_PY_MAPPING_PARTIAL)
            local_map = {k: tar_prefix + v for k, v in local_map.items()}
        else:
            # no mapping
            local_map = {}

        if '__self__' in local_map:
            # special case, overwrite prefix
            tar_prefix = local_map['__self__'] + '.'

        for key, value in local_map.items():
            if key != '' and key not in module._modules:  # not a sub-module, probably a parameter
                full_mapping[src_prefix + key] = value

        if src_prefix != tar_prefix:  # To deal with leaf nodes.
            for name, value in itertools.chain(module._parameters.items(), module._buffers.items()):  # direct children
                if value is None or name in module._non_persistent_buffers_set:
                    # it won't appear in state dict
                    continue
                if (src_prefix + name) not in full_mapping:
                    full_mapping[src_prefix + name] = tar_prefix + name

        for name, child in module.named_children():
            # sub-modules
            full_mapping_in_module(
                src_prefix + name + '.',
                local_map.get(name, tar_prefix + name) + '.',  # if mapping doesn't exist, respect the prefix
                child
            )

    full_mapping_in_module('', '', model)

    def load_state_dict_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Pre-hook for load_state_dict: translate original-model keys in the
        # incoming ``state_dict`` into this sub-model's keys, in place.
        # NOTE(review): names here are inverted w.r.t. ``full_mapping`` — the
        # loop variable ``src`` is an original-model key, ``tar_keys`` the
        # sub-model keys that share it (one original key may feed several).
        reverse_mapping = defaultdict(list)
        for src, tar in full_mapping.items():
            reverse_mapping[tar].append(src)

        transf_state_dict = {}
        for src, tar_keys in reverse_mapping.items():
            if src in state_dict:
                value = state_dict.pop(src)
                for tar in tar_keys:
                    transf_state_dict[tar] = value
            else:
                missing_keys.append(src)
        state_dict.update(transf_state_dict)

    def state_dict_hook(module, destination, prefix, local_metadata):
        # Post-hook for state_dict: rename the freshly-computed sub-model keys
        # in ``destination`` to original-model (supernet) style keys.
        result = {}
        for src, tar in full_mapping.items():
            if src in destination:
                result[tar] = destination.pop(src)
            else:
                # mapping must be a subset of the actual state dict
                raise KeyError(f'"{src}" not in state dict, but found in mapping.')
        destination.update(result)

    try:
        hooks = []
        hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
        hooks.append(model._register_state_dict_hook(state_dict_hook))
        yield
    finally:
        # always unhook, even if the body raised
        for hook in hooks:
            hook.remove()
......@@ -16,6 +16,7 @@ class _model(nn.Module):
self.fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.softmax = torch.nn.Softmax()
self._mapping_ = {'stem': None, 'flatten': None, 'fc1': None, 'fc2': None, 'softmax': None}
def forward(self, image):
stem = self.stem(image)
......@@ -34,6 +35,7 @@ class stem(nn.Module):
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
self._mapping_ = {'conv1': None, 'pool1': None, 'conv2': None, 'pool2': None}
def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0])
......
......@@ -14,6 +14,7 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
......@@ -50,16 +51,6 @@ class Linear(nn.Module):
return out.view(size[0], size[1], -1)
class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
    """Greedily pair each expected key with the first remaining value whose
    ``.shape`` matches.

    ``current_values`` is consumed in place (matched entries are removed);
    keys with no shape match are simply absent from the result.
    """
    matched = {}
    for key, template in expected_format.items():
        for position, candidate in enumerate(current_values):
            if candidate.shape != template.shape:
                continue
            matched[key] = candidate
            del current_values[position]
            break
    return matched
def checkExportImport(self, model, input):
model_ir = self._convert_model(model, input)
......@@ -68,9 +59,8 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(dict(model.state_dict()))
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
......
......@@ -12,20 +12,11 @@ from nni.retiarii import basic_unit
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
# following pytorch v1.7.1
class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
    """Greedily pair each expected key with the first remaining value whose
    ``.shape`` matches.

    ``current_values`` is consumed in place (matched entries are removed);
    keys with no shape match are simply absent from the result.
    """
    matched = {}
    for key, template in expected_format.items():
        for position, candidate in enumerate(current_values):
            if candidate.shape != template.shape:
                continue
            matched[key] = candidate
            del current_values[position]
            break
    return matched
def checkExportImport(self, model, input, check_value=True):
model_ir = self._convert_model(model, input)
......@@ -35,9 +26,10 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(model.state_dict())
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
......
......@@ -9,23 +9,13 @@ import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestModels(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
    """Greedily pair each expected key with the first remaining value whose
    ``.shape`` matches.

    ``current_values`` is consumed in place (matched entries are removed);
    keys with no shape match are simply absent from the result.
    """
    matched = {}
    for key, template in expected_format.items():
        for position, candidate in enumerate(current_values):
            if candidate.shape != template.shape:
                continue
            matched[key] = candidate
            del current_values[position]
            break
    return matched
def run_test(self, model, input, check_value=True):
model_ir = self._convert_model(model, input)
......@@ -35,9 +25,10 @@ class TestModels(unittest.TestCase, ConvertMixin):
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(model.state_dict())
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
......
......@@ -16,6 +16,7 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
......@@ -23,16 +24,6 @@ from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestOperators(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
    """Greedily pair each expected key with the first remaining value whose
    ``.shape`` matches.

    ``current_values`` is consumed in place (matched entries are removed);
    keys with no shape match are simply absent from the result.
    """
    matched = {}
    for key, template in expected_format.items():
        for position, candidate in enumerate(current_values):
            if candidate.shape != template.shape:
                continue
            matched[key] = candidate
            del current_values[position]
            break
    return matched
def checkExportImport(self, model, input, check_value=True):
model_ir = self._convert_model(model, input)
......@@ -42,9 +33,10 @@ class TestOperators(unittest.TestCase, ConvertMixin):
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(model.state_dict())
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
......
......@@ -14,28 +14,17 @@ import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestPytorch(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
    """Greedily pair each expected key with the first remaining value whose
    ``.shape`` matches.

    ``current_values`` is consumed in place (matched entries are removed);
    keys with no shape match are simply absent from the result.
    """
    matched = {}
    for key, template in expected_format.items():
        for position, candidate in enumerate(current_values):
            if candidate.shape != template.shape:
                continue
            matched[key] = candidate
            del current_values[position]
            break
    return matched
def run_test(self, model, input, check_value=True):
def run_test(self, model, input, check_value=True, strict_load=True):
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
from .inject_nn import remove_inject_pytorch_nn
remove_inject_pytorch_nn()
......@@ -43,9 +32,10 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(model.state_dict(), strict=strict_load)
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
......@@ -76,7 +66,8 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
model = LargeModel()
x = torch.tensor([2], dtype=torch.long)
self.run_test(model, (x, ))
# emb and lin1 is actually not used so they won't appear in generated model
self.run_test(model, (x, ), strict_load=False)
@unittest.skip('skip for now, as it needs inject_nn')
def test_mobilenet_v2_with_external_data(self):
......
......@@ -17,7 +17,7 @@ from nni.retiarii.graph import Model
from nni.retiarii.nn.pytorch.api import ValueChoice
from nni.retiarii.nn.pytorch.mutator import process_evaluator_mutations, process_inline_mutation, extract_mutation_from_pt_module
from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack
from nni.retiarii.utils import ContextStack, original_state_dict_hooks
class EnumerateSampler(Sampler):
......@@ -123,6 +123,29 @@ class GraphIR(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model_new)(torch.randn(1, 3, 3, 3)).size(),
torch.Size([1, i, 3, 3]))
def test_layer_choice_weight_inheritance(self):
    # For each of the 10 sampled candidates, the converted sub-model should
    # inherit the chosen conv's weights from the supernet via
    # ``original_state_dict_hooks`` and produce (near-)identical outputs.
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.module = nn.LayerChoice([nn.Conv2d(3, i, kernel_size=1) for i in range(1, 11)])

        def forward(self, x):
            return self.module(x)

    orig_model = Net()
    model, mutators = self._get_model_with_mutators(orig_model)
    mutator = mutators[0].bind_sampler(EnumerateSampler())

    for i in range(1, 11):
        model_new = mutator.apply(model)
        model_new = self._get_converted_pytorch_model(model_new)
        # load supernet weights into the sampled sub-model;
        # strict=False because the supernet has keys for unchosen candidates
        with original_state_dict_hooks(model_new):
            model_new.load_state_dict(orig_model.state_dict(), strict=False)
        inp = torch.randn(1, 3, 3, 3)
        # EnumerateSampler picks candidate ``i - 1`` on this iteration
        a = getattr(orig_model.module, str(i - 1))(inp)
        b = model_new(inp)
        self.assertLess((a - b).abs().max().item(), 1E-4)
def test_nested_layer_choice(self):
@model_wrapper
class Net(nn.Module):
......@@ -150,6 +173,40 @@ class GraphIR(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
torch.Size([1, 5, 5, 5]))
def test_nested_layer_choice_weight_inheritance(self):
    # Same as the non-nested test, but with a LayerChoice inside a LayerChoice,
    # exercising the mapping-merge branch of weight inheritance.
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.module = nn.LayerChoice([
                nn.LayerChoice([nn.Conv2d(3, 3, kernel_size=1),
                                nn.Conv2d(3, 4, kernel_size=1),
                                nn.Conv2d(3, 5, kernel_size=1)]),
                nn.Conv2d(3, 1, kernel_size=1)
            ])

        def forward(self, x):
            return self.module(x)

    orig_model = Net()
    model, mutators = self._get_model_with_mutators(orig_model)
    mutators[0].bind_sampler(EnumerateSampler())
    mutators[1].bind_sampler(EnumerateSampler())
    input = torch.randn(1, 3, 5, 5)

    for i in range(3):
        model_new = self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))
        # strict=False: supernet carries keys for unchosen candidates
        with original_state_dict_hooks(model_new):
            model_new.load_state_dict(orig_model.state_dict(), strict=False)

        # the two enumerate samplers advance independently, so the
        # (outer, inner) selection sequence is (0,0), (1,1), (0,2)
        if i == 0:
            a = getattr(getattr(orig_model.module, '0'), '0')(input)
        elif i == 1:
            a = getattr(orig_model.module, '1')(input)
        elif i == 2:
            a = getattr(getattr(orig_model.module, '0'), '2')(input)
        b = model_new(input)
        self.assertLess((a - b).abs().max().item(), 1E-4)
def test_input_choice(self):
@model_wrapper
class Net(nn.Module):
......@@ -578,6 +635,30 @@ class GraphIR(unittest.TestCase):
self.assertIn(1., result)
def test_repeat_weight_inheritance(self):
    # A sampled Repeat of depth d should inherit the weights of the first d
    # blocks of the supernet's ``blocks`` sequence.
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.module = nn.Repeat(lambda index: nn.Conv2d(3, 3, 1), (2, 5))

        def forward(self, x):
            return self.module(x)

    orig_model = Net()
    model, mutators = self._get_model_with_mutators(orig_model)
    mutator = mutators[0].bind_sampler(EnumerateSampler())
    inp = torch.randn(1, 3, 5, 5)

    for i in range(4):
        model_new = self._get_converted_pytorch_model(mutator.apply(model))
        # strict=False: supernet may carry keys for repeats beyond the sampled depth
        with original_state_dict_hooks(model_new):
            model_new.load_state_dict(orig_model.state_dict(), strict=False)
        # EnumerateSampler yields depths 2, 3, 4, 5 in turn
        a = nn.Sequential(*orig_model.module.blocks[:i + 2])(inp)
        b = model_new(inp)
        self.assertLess((a - b).abs().max().item(), 1E-4)
def test_cell(self):
@model_wrapper
class Net(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment