Unverified Commit c447249c authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Support loading from `state_dict` of supernet (#4544)

parent 6b828681
...@@ -6,6 +6,7 @@ import re ...@@ -6,6 +6,7 @@ import re
from typing import Dict, List, Tuple, Any from typing import Dict, List, Tuple, Any
from nni.retiarii.operation_def.torch_op_def import ToDevice from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.retiarii.utils import STATE_DICT_PY_MAPPING
from nni.common.device import Device, GPUDevice from nni.common.device import Device, GPUDevice
from ..graph import IllegalGraphError, Edge, Graph, Node, Model from ..graph import IllegalGraphError, Edge, Graph, Node, Model
...@@ -97,7 +98,18 @@ def _format_variable_name(name: str, graph_name: str) -> str: ...@@ -97,7 +98,18 @@ def _format_variable_name(name: str, graph_name: str) -> str:
name = name.replace('/', '__') name = name.replace('/', '__')
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python # https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
return re.sub('\W|^(?=\d)','_', name) name = re.sub('\W|^(?=\d)','_', name)
if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
# name can't start with double underscore
# it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
# but it's actually very common in our generated code
name = name[1:]
elif name.startswith('_'):
# to avoid conflicts between '_' and '__'
name = 'i' + name
return name
def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]: def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
...@@ -125,6 +137,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str ...@@ -125,6 +137,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
# only need to generate code for module here # only need to generate code for module here
import_pkgs = set() import_pkgs = set()
node_codes = [] node_codes = []
node_python_mappings = {}
cuda_remapped_id = None cuda_remapped_id = None
if placement: if placement:
cuda_remapped_id = generate_cuda_mapping(placement) cuda_remapped_id = generate_cuda_mapping(placement)
...@@ -138,7 +151,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str ...@@ -138,7 +151,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
pkg_name = node.operation.get_import_pkg() pkg_name = node.operation.get_import_pkg()
if pkg_name is not None: if pkg_name is not None:
import_pkgs.add(pkg_name) import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(_format_variable_name(node.name, graph_name))
py_variable_name = _format_variable_name(node.name, graph_name)
node_code = node.operation.to_init_code(py_variable_name)
if node_code is not None: if node_code is not None:
if placement and node in placement and len(node_code) > 0: if placement and node in placement and len(node_code) > 0:
if isinstance(placement[node], GPUDevice): if isinstance(placement[node], GPUDevice):
...@@ -149,6 +164,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str ...@@ -149,6 +164,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
else: else:
node_codes.append(node_code) node_codes.append(node_code)
# Map to module hierarchies in original search space python code
node_python_mappings[py_variable_name] = node.python_name
node_codes.append(f'self.{STATE_DICT_PY_MAPPING} = {node_python_mappings}')
if graph.input_node.operation.io_names is None: if graph.input_node.operation.io_names is None:
input_code = '*_inputs' input_code = '*_inputs'
else: else:
......
...@@ -11,6 +11,7 @@ import torch.nn as nn ...@@ -11,6 +11,7 @@ import torch.nn as nn
from nni.common.serializer import Translatable from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from .utils import Mutable, generate_new_label, get_fixed_value from .utils import Mutable, generate_new_label, get_fixed_value
...@@ -65,9 +66,22 @@ class LayerChoice(Mutable): ...@@ -65,9 +66,22 @@ class LayerChoice(Mutable):
label: Optional[str] = None, **kwargs): label: Optional[str] = None, **kwargs):
chosen = get_fixed_value(label) chosen = get_fixed_value(label)
if isinstance(candidates, list): if isinstance(candidates, list):
return candidates[int(chosen)] result = candidates[int(chosen)]
else: else:
return candidates[chosen] result = candidates[chosen]
# map the named hierarchies to support weight inheritance for python engine
if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
# handle cases where layer choices are nested
# already has a mapping, will merge with it
prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'{chosen}.{v}' for k, v in prev_mapping.items()})
else:
# "result" needs to know where to map itself.
# Ideally, we should put a _mapping_ in the module where "result" is located,
# but it's impossible to put mapping into parent module here.
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': str(chosen)})
return result
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *, def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs): prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
......
...@@ -5,6 +5,8 @@ from typing import Callable, List, Union, Tuple, Optional ...@@ -5,6 +5,8 @@ from typing import Callable, List, Union, Tuple, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from .api import LayerChoice from .api import LayerChoice
from .cell import Cell from .cell import Cell
from .nasbench101 import NasBench101Cell, NasBench101Mutator from .nasbench101 import NasBench101Cell, NasBench101Mutator
...@@ -38,7 +40,15 @@ class Repeat(Mutable): ...@@ -38,7 +40,15 @@ class Repeat(Mutable):
List[nn.Module]], List[nn.Module]],
depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None): depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
repeat = get_fixed_value(label) repeat = get_fixed_value(label)
return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat)) result = nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
# already has a mapping, will merge with it
prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'blocks.{v}' for k, v in prev_mapping.items()})
else:
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': 'blocks'})
return result
def __init__(self, def __init__(self,
blocks: Union[Callable[[int], nn.Module], blocks: Union[Callable[[int], nn.Module],
......
...@@ -304,6 +304,8 @@ class NasBench101Cell(Mutable): ...@@ -304,6 +304,8 @@ class NasBench101Cell(Mutable):
[op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)], [op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
adjacency_list, in_features, out_features, num_nodes, projection) adjacency_list, in_features, out_features, num_nodes, projection)
# FIXME: weight inheritance on nasbench101 is not supported yet
def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]], def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module], in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None): max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
......
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import inspect import inspect
import itertools
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager
from typing import Any, List, Dict from typing import Any, List, Dict
from pathlib import Path from pathlib import Path
...@@ -154,3 +156,119 @@ class ModelNamespace: ...@@ -154,3 +156,119 @@ class ModelNamespace:
def get_current_context(key: str) -> Any: def get_current_context(key: str) -> Any:
return ContextStack.top(key) return ContextStack.top(key)
# map variables to prefix in the state dict
# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
STATE_DICT_PY_MAPPING = '_mapping_'

# map variables to `prefix`.`value` in the state dict
# e.g., {'upsample': 'choice3.upsample_layer'},
# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'},
# and 'upsample' is also in `mynet.module`.
STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'


@contextmanager
def original_state_dict_hooks(model: Any):
    """
    Use this patch if you want to save/load state dict in the original state dict hierarchy.

    For example, when you already have a state dict for the base model / search space (which often
    happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
    in the same way as when a sub-model is sampled from the search space. This patch will help
    the modules in the sub-model find the corresponding module in the base model.

    The code looks like,

    .. code-block:: python

        with original_state_dict_hooks(model):
            model.load_state_dict(state_dict_from_supernet, strict=False)  # supernet has extra keys

    Or vice-versa,

    .. code-block:: python

        with original_state_dict_hooks(model):
            supernet_style_state_dict = model.state_dict()

    The name translation is read from the ``_mapping_`` / ``_mapping_partial_`` attributes found
    on ``model`` and its sub-modules. Mapping entries whose value is ``None`` are treated as
    "no mapping known" and fall back to the current prefix. The special ``'__self__'`` key only
    redirects the prefix of the module that carries it; it never becomes a state-dict key itself.

    Both hooks are registered on ``model`` when the context is entered and removed when it exits,
    even if the body raises.

    Parameters
    ----------
    model
        The model whose state dict keys should be translated. Must be a ``torch.nn.Module``.

    Raises
    ------
    AssertionError
        If ``model`` is not a ``torch.nn.Module``.
    KeyError
        Raised from ``model.state_dict()`` inside the context if a mapped key is absent
        from the produced state dict.
    """
    import torch.nn as nn
    assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'

    # the following are written for pytorch only

    # first get the full mapping: key in this model's state dict -> key in the original state dict
    full_mapping = {}

    def full_mapping_in_module(src_prefix, tar_prefix, module):
        if hasattr(module, STATE_DICT_PY_MAPPING):
            # only values are complete
            local_map = getattr(module, STATE_DICT_PY_MAPPING)
            # a None value carries no naming information; dropping it avoids `None + '.'` below
            local_map = {k: v for k, v in local_map.items() if v is not None}
        elif hasattr(module, STATE_DICT_PY_MAPPING_PARTIAL):
            # keys and values are both incomplete
            local_map = getattr(module, STATE_DICT_PY_MAPPING_PARTIAL)
            local_map = {k: tar_prefix + v for k, v in local_map.items() if v is not None}
        else:
            # no mapping
            local_map = {}

        if '__self__' in local_map:
            # special case, overwrite prefix
            tar_prefix = local_map['__self__'] + '.'

        for key, value in local_map.items():
            # '__self__' is a directive rather than a tensor name: recording it would make
            # the state-dict hook look for a key that can never exist.
            if key in ('', '__self__') or key in module._modules:
                continue
            # not a sub-module, probably a parameter
            full_mapping[src_prefix + key] = value

        if src_prefix != tar_prefix:  # To deal with leaf nodes.
            for name, value in itertools.chain(module._parameters.items(), module._buffers.items()):  # direct children
                if value is None or name in module._non_persistent_buffers_set:
                    # it won't appear in state dict
                    continue
                if (src_prefix + name) not in full_mapping:
                    full_mapping[src_prefix + name] = tar_prefix + name

        for name, child in module.named_children():
            # sub-modules
            full_mapping_in_module(
                src_prefix + name + '.',
                local_map.get(name, tar_prefix + name) + '.',  # if mapping doesn't exist, respect the prefix
                child
            )

    full_mapping_in_module('', '', model)

    # original key -> every key of this model that should receive its value;
    # the mapping is fixed from here on, so build the inverse once.
    reverse_mapping = defaultdict(list)
    for src, tar in full_mapping.items():
        reverse_mapping[tar].append(src)

    def load_state_dict_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # rename incoming (original-style) keys to the names this model expects
        transf_state_dict = {}
        for original_key, model_keys in reverse_mapping.items():
            if original_key in state_dict:
                value = state_dict.pop(original_key)
                for model_key in model_keys:
                    transf_state_dict[model_key] = value
            else:
                missing_keys.append(original_key)
        state_dict.update(transf_state_dict)

    def state_dict_hook(module, destination, prefix, local_metadata):
        # rename outgoing keys of this model to the original-style names
        result = {}
        for src, tar in full_mapping.items():
            if src in destination:
                result[tar] = destination.pop(src)
            else:
                raise KeyError(f'"{src}" not in state dict, but found in mapping.')
        destination.update(result)

    hooks = []
    try:
        hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
        hooks.append(model._register_state_dict_hook(state_dict_hook))
        yield
    finally:
        for hook in hooks:
            hook.remove()
...@@ -16,6 +16,7 @@ class _model(nn.Module): ...@@ -16,6 +16,7 @@ class _model(nn.Module):
self.fc1 = torch.nn.Linear(out_features=256, in_features=1024) self.fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.fc2 = torch.nn.Linear(out_features=10, in_features=256) self.fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.softmax = torch.nn.Softmax() self.softmax = torch.nn.Softmax()
self._mapping_ = {'stem': None, 'flatten': None, 'fc1': None, 'fc2': None, 'softmax': None}
def forward(self, image): def forward(self, image):
stem = self.stem(image) stem = self.stem(image)
...@@ -34,6 +35,7 @@ class stem(nn.Module): ...@@ -34,6 +35,7 @@ class stem(nn.Module):
self.pool1 = torch.nn.MaxPool2d(kernel_size=2) self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5) self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = torch.nn.MaxPool2d(kernel_size=2) self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
self._mapping_ = {'conv1': None, 'pool1': None, 'conv2': None, 'pool2': None}
def forward(self, *_inputs): def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0]) conv1 = self.conv1(_inputs[0])
......
...@@ -14,6 +14,7 @@ import torchvision ...@@ -14,6 +14,7 @@ import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit from nni.retiarii import basic_unit
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
...@@ -50,16 +51,6 @@ class Linear(nn.Module): ...@@ -50,16 +51,6 @@ class Linear(nn.Module):
return out.view(size[0], size[1], -1) return out.view(size[0], size[1], -1)
class TestConvert(unittest.TestCase, ConvertMixin): class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def checkExportImport(self, model, input): def checkExportImport(self, model, input):
model_ir = self._convert_model(model, input) model_ir = self._convert_model(model, input)
...@@ -68,9 +59,8 @@ class TestConvert(unittest.TestCase, ConvertMixin): ...@@ -68,9 +59,8 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars = {} exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars) exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model'] converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()), with original_state_dict_hooks(converted_model):
dict(converted_model.state_dict())) converted_model.load_state_dict(dict(model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with torch.no_grad(): with torch.no_grad():
expected_output = model.eval()(*input) expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input) converted_output = converted_model.eval()(*input)
......
...@@ -12,20 +12,11 @@ from nni.retiarii import basic_unit ...@@ -12,20 +12,11 @@ from nni.retiarii import basic_unit
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
# following pytorch v1.7.1 # following pytorch v1.7.1
class TestConvert(unittest.TestCase, ConvertMixin): class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def checkExportImport(self, model, input, check_value=True): def checkExportImport(self, model, input, check_value=True):
model_ir = self._convert_model(model, input) model_ir = self._convert_model(model, input)
...@@ -35,9 +26,10 @@ class TestConvert(unittest.TestCase, ConvertMixin): ...@@ -35,9 +26,10 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars = {} exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars) exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model'] converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict())) with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(converted_state_dict) converted_model.load_state_dict(model.state_dict())
with torch.no_grad(): with torch.no_grad():
expected_output = model.eval()(*input) expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input) converted_output = converted_model.eval()(*input)
......
...@@ -9,23 +9,13 @@ import torch.nn.functional as F ...@@ -9,23 +9,13 @@ import torch.nn.functional as F
import torchvision import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestModels(unittest.TestCase, ConvertMixin): class TestModels(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def run_test(self, model, input, check_value=True): def run_test(self, model, input, check_value=True):
model_ir = self._convert_model(model, input) model_ir = self._convert_model(model, input)
...@@ -35,9 +25,10 @@ class TestModels(unittest.TestCase, ConvertMixin): ...@@ -35,9 +25,10 @@ class TestModels(unittest.TestCase, ConvertMixin):
exec_vars = {} exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars) exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model'] converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict())) with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(converted_state_dict) converted_model.load_state_dict(model.state_dict())
with torch.no_grad(): with torch.no_grad():
expected_output = model.eval()(*input) expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input) converted_output = converted_model.eval()(*input)
......
...@@ -16,6 +16,7 @@ import torchvision ...@@ -16,6 +16,7 @@ import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
...@@ -23,16 +24,6 @@ from .convert_mixin import ConvertMixin, ConvertWithShapeMixin ...@@ -23,16 +24,6 @@ from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestOperators(unittest.TestCase, ConvertMixin): class TestOperators(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def checkExportImport(self, model, input, check_value=True): def checkExportImport(self, model, input, check_value=True):
model_ir = self._convert_model(model, input) model_ir = self._convert_model(model, input)
...@@ -42,9 +33,10 @@ class TestOperators(unittest.TestCase, ConvertMixin): ...@@ -42,9 +33,10 @@ class TestOperators(unittest.TestCase, ConvertMixin):
exec_vars = {} exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars) exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model'] converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict())) with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(converted_state_dict) converted_model.load_state_dict(model.state_dict())
with torch.no_grad(): with torch.no_grad():
expected_output = model.eval()(*input) expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input) converted_output = converted_model.eval()(*input)
......
...@@ -14,28 +14,17 @@ import torch.nn.functional as F ...@@ -14,28 +14,17 @@ import torch.nn.functional as F
import torchvision import torchvision
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestPytorch(unittest.TestCase, ConvertMixin): class TestPytorch(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format): def run_test(self, model, input, check_value=True, strict_load=True):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def run_test(self, model, input, check_value=True):
model_ir = self._convert_model(model, input) model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir) model_code = model_to_pytorch_script(model_ir)
print(model_code)
from .inject_nn import remove_inject_pytorch_nn from .inject_nn import remove_inject_pytorch_nn
remove_inject_pytorch_nn() remove_inject_pytorch_nn()
...@@ -43,9 +32,10 @@ class TestPytorch(unittest.TestCase, ConvertMixin): ...@@ -43,9 +32,10 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
exec_vars = {} exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars) exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model'] converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict())) with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(converted_state_dict) converted_model.load_state_dict(model.state_dict(), strict=strict_load)
with torch.no_grad(): with torch.no_grad():
expected_output = model.eval()(*input) expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input) converted_output = converted_model.eval()(*input)
...@@ -76,7 +66,8 @@ class TestPytorch(unittest.TestCase, ConvertMixin): ...@@ -76,7 +66,8 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
model = LargeModel() model = LargeModel()
x = torch.tensor([2], dtype=torch.long) x = torch.tensor([2], dtype=torch.long)
self.run_test(model, (x, )) # emb and lin1 is actually not used so they won't appear in generated model
self.run_test(model, (x, ), strict_load=False)
@unittest.skip('skip for now, as it needs inject_nn') @unittest.skip('skip for now, as it needs inject_nn')
def test_mobilenet_v2_with_external_data(self): def test_mobilenet_v2_with_external_data(self):
......
...@@ -17,7 +17,7 @@ from nni.retiarii.graph import Model ...@@ -17,7 +17,7 @@ from nni.retiarii.graph import Model
from nni.retiarii.nn.pytorch.api import ValueChoice from nni.retiarii.nn.pytorch.api import ValueChoice
from nni.retiarii.nn.pytorch.mutator import process_evaluator_mutations, process_inline_mutation, extract_mutation_from_pt_module from nni.retiarii.nn.pytorch.mutator import process_evaluator_mutations, process_inline_mutation, extract_mutation_from_pt_module
from nni.retiarii.serializer import model_wrapper from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack from nni.retiarii.utils import ContextStack, original_state_dict_hooks
class EnumerateSampler(Sampler): class EnumerateSampler(Sampler):
...@@ -123,6 +123,29 @@ class GraphIR(unittest.TestCase): ...@@ -123,6 +123,29 @@ class GraphIR(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model_new)(torch.randn(1, 3, 3, 3)).size(), self.assertEqual(self._get_converted_pytorch_model(model_new)(torch.randn(1, 3, 3, 3)).size(),
torch.Size([1, i, 3, 3])) torch.Size([1, i, 3, 3]))
def test_layer_choice_weight_inheritance(self):
    """Check that a sampled sub-model can inherit weights from the search-space model.

    Ten models are sampled in turn; the test expects the n-th sampled model
    (n = 0..9) to behave like candidate ``n`` of the original ``LayerChoice``
    once the original state dict has been loaded through
    ``original_state_dict_hooks``.
    """
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.module = nn.LayerChoice([nn.Conv2d(3, i, kernel_size=1) for i in range(1, 11)])

        def forward(self, x):
            return self.module(x)

    orig_model = Net()
    model, mutators = self._get_model_with_mutators(orig_model)
    mutator = mutators[0].bind_sampler(EnumerateSampler())

    for candidate_index in range(10):
        sampled_ir = mutator.apply(model)
        model_new = self._get_converted_pytorch_model(sampled_ir)

        # strict=False: the original state dict also holds the candidates that were not chosen
        with original_state_dict_hooks(model_new):
            model_new.load_state_dict(orig_model.state_dict(), strict=False)

        inp = torch.randn(1, 3, 3, 3)
        reference_layer = getattr(orig_model.module, str(candidate_index))
        expected = reference_layer(inp)
        actual = model_new(inp)
        max_abs_diff = (expected - actual).abs().max().item()
        self.assertLess(max_abs_diff, 1E-4)
def test_nested_layer_choice(self): def test_nested_layer_choice(self):
@model_wrapper @model_wrapper
class Net(nn.Module): class Net(nn.Module):
...@@ -150,6 +173,40 @@ class GraphIR(unittest.TestCase): ...@@ -150,6 +173,40 @@ class GraphIR(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(), self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
torch.Size([1, 5, 5, 5])) torch.Size([1, 5, 5, 5]))
def test_nested_layer_choice_weight_inheritance(self):
    """Check weight inheritance when one ``LayerChoice`` is nested inside another.

    Three models are sampled; for each one the test names the attribute path,
    starting at ``orig_model.module``, of the original layer the sampled model
    is expected to match after loading the original state dict.
    """
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.module = nn.LayerChoice([
                nn.LayerChoice([nn.Conv2d(3, 3, kernel_size=1),
                                nn.Conv2d(3, 4, kernel_size=1),
                                nn.Conv2d(3, 5, kernel_size=1)]),
                nn.Conv2d(3, 1, kernel_size=1)
            ])

        def forward(self, x):
            return self.module(x)

    orig_model = Net()
    model, mutators = self._get_model_with_mutators(orig_model)
    mutators[0].bind_sampler(EnumerateSampler())
    mutators[1].bind_sampler(EnumerateSampler())
    sample_input = torch.randn(1, 3, 5, 5)

    # expected original layer for the 1st, 2nd and 3rd sampled model, in that order
    expected_paths = (('0', '0'), ('1',), ('0', '2'))

    for attr_path in expected_paths:
        sampled_ir = mutators[1].apply(mutators[0].apply(model))
        model_new = self._get_converted_pytorch_model(sampled_ir)

        # strict=False: the original state dict also holds the candidates that were not chosen
        with original_state_dict_hooks(model_new):
            model_new.load_state_dict(orig_model.state_dict(), strict=False)

        reference_layer = orig_model.module
        for attr_name in attr_path:
            reference_layer = getattr(reference_layer, attr_name)

        expected = reference_layer(sample_input)
        actual = model_new(sample_input)
        self.assertLess((expected - actual).abs().max().item(), 1E-4)
def test_input_choice(self): def test_input_choice(self):
@model_wrapper @model_wrapper
class Net(nn.Module): class Net(nn.Module):
...@@ -578,6 +635,30 @@ class GraphIR(unittest.TestCase): ...@@ -578,6 +635,30 @@ class GraphIR(unittest.TestCase):
self.assertIn(1., result) self.assertIn(1., result)
def test_repeat_weight_inheritance(self):
    """Check that a sampled ``Repeat`` sub-model can inherit weights from the original model.

    Four models are sampled; the test expects them to behave like the first
    2, 3, 4 and 5 blocks of the original ``Repeat`` respectively, once the
    original state dict has been loaded through ``original_state_dict_hooks``.
    """
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.module = nn.Repeat(lambda index: nn.Conv2d(3, 3, 1), (2, 5))

        def forward(self, x):
            return self.module(x)

    orig_model = Net()
    model, mutators = self._get_model_with_mutators(orig_model)
    mutator = mutators[0].bind_sampler(EnumerateSampler())
    inp = torch.randn(1, 3, 5, 5)

    for depth in range(2, 6):
        sampled_ir = mutator.apply(model)
        model_new = self._get_converted_pytorch_model(sampled_ir)

        # strict=False: the original state dict also holds blocks beyond the sampled depth
        with original_state_dict_hooks(model_new):
            model_new.load_state_dict(orig_model.state_dict(), strict=False)

        reference = nn.Sequential(*orig_model.module.blocks[:depth])
        expected = reference(inp)
        actual = model_new(inp)
        self.assertLess((expected - actual).abs().max().item(), 1E-4)
def test_cell(self): def test_cell(self):
@model_wrapper @model_wrapper
class Net(nn.Module): class Net(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment