Commit 13dc0f8f authored by liuzhe

Merge branch 'master' into doc-refactor

parents 22165cea 3b27ac76
......@@ -6,6 +6,7 @@ import re
from typing import Dict, List, Tuple, Any
from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.retiarii.utils import STATE_DICT_PY_MAPPING
from nni.common.device import Device, GPUDevice
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
......@@ -97,7 +98,18 @@ def _format_variable_name(name: str, graph_name: str) -> str:
name = name.replace('/', '__')
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
return re.sub('\W|^(?=\d)','_', name)
name = re.sub('\W|^(?=\d)','_', name)
if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
# the name can't start with a double underscore
# it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
# but such names are very common in our generated code
name = name[1:]
elif name.startswith('_'):
# to avoid conflicts between '_' and '__'
name = 'i' + name
return name
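For intuition, a minimal standalone sketch of the renaming rules above (the graph-name stripping earlier in the function is omitted, and the sample names are hypothetical):

.. code-block:: python

    import re

    def mangle(name: str) -> str:
        name = name.replace('/', '__')
        name = re.sub(r'\W|^(?=\d)', '_', name)
        if name.startswith('__') and len(name) > 2 and name[2] != '_':
            name = name[1:]      # '__x' is reserved in Python: drop one underscore
        elif name.startswith('_'):
            name = 'i' + name    # keep a plain '_x' away from a reduced '__x'
        return name

    assert mangle('cell/op.1') == 'cell__op_1'
    assert mangle('/conv') == '_conv'    # '__conv' loses one leading underscore
    assert mangle('_conv') == 'i_conv'   # so '_conv' must be shifted out of the way
    assert mangle('3x3') == 'i_3x3'      # identifiers cannot start with a digit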
def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
......@@ -125,6 +137,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
# only need to generate code for module here
import_pkgs = set()
node_codes = []
node_python_mappings = {}
cuda_remapped_id = None
if placement:
cuda_remapped_id = generate_cuda_mapping(placement)
......@@ -138,7 +151,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(_format_variable_name(node.name, graph_name))
py_variable_name = _format_variable_name(node.name, graph_name)
node_code = node.operation.to_init_code(py_variable_name)
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
if isinstance(placement[node], GPUDevice):
......@@ -149,6 +164,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
else:
node_codes.append(node_code)
# Map to module hierarchies in original search space python code
node_python_mappings[py_variable_name] = node.python_name
node_codes.append(f'self.{STATE_DICT_PY_MAPPING} = {node_python_mappings}')
if graph.input_node.operation.io_names is None:
input_code = '*_inputs'
else:
......
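Concretely, the generated ``__init__`` now ends with a literal dict assignment mapping each mangled attribute back to its original ``python_name``. A minimal sketch of the emitted shape (mirroring the updated test fixtures later in this commit, where ``python_name`` is ``None``):

.. code-block:: python

    import torch

    class _model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3)
            # emitted by the loop above; the value is the node's python_name,
            # which is None when the original hierarchy is unknown
            self._mapping_ = {'conv1': None}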
......@@ -101,6 +101,7 @@ class _MultiModelSupervisedLearningModule(LightningModule):
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
@nni.trace
class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
"""
Lightning Module of SupervisedLearning for Cross-Graph Optimization.
......
......@@ -11,6 +11,7 @@ import torch.nn as nn
from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from .utils import Mutable, generate_new_label, get_fixed_value
......@@ -82,9 +83,22 @@ class LayerChoice(Mutable):
label: Optional[str] = None, **kwargs):
chosen = get_fixed_value(label)
if isinstance(candidates, list):
return candidates[int(chosen)]
result = candidates[int(chosen)]
else:
return candidates[chosen]
result = candidates[chosen]
# map the named hierarchies to support weight inheritance for python engine
if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
# handle cases where layer choices are nested
# already has a mapping, will merge with it
prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'{chosen}.{v}' for k, v in prev_mapping.items()})
else:
# "result" needs to know where to map itself.
# Ideally, we should put a _mapping_ in the module where "result" is located,
# but it's impossible to put the mapping into the parent module here.
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': str(chosen)})
return result
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
......
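To see what the merge branch above produces for nested layer choices, a tiny standalone sketch (the candidate keys are hypothetical):

.. code-block:: python

    STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'

    class Stub:          # stand-in for the chosen nn.Module
        pass

    inner = Stub()       # an inner choice already resolved to its 'branch2' candidate
    setattr(inner, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': 'branch2'})

    chosen = 'choice0'   # the outer LayerChoice picked key 'choice0'; merge prefixes
    prev = getattr(inner, STATE_DICT_PY_MAPPING_PARTIAL)
    setattr(inner, STATE_DICT_PY_MAPPING_PARTIAL,
            {k: f'{chosen}.{v}' for k, v in prev.items()})

    assert getattr(inner, STATE_DICT_PY_MAPPING_PARTIAL) == {'__self__': 'choice0.branch2'}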
......@@ -5,6 +5,8 @@ from typing import Callable, List, Union, Tuple, Optional
import torch
import torch.nn as nn
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
from .api import LayerChoice
from .cell import Cell
from .nasbench101 import NasBench101Cell, NasBench101Mutator
......@@ -59,7 +61,15 @@ class Repeat(Mutable):
List[nn.Module]],
depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
repeat = get_fixed_value(label)
return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
result = nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
# already has a mapping, will merge with it
prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'blocks.{v}' for k, v in prev_mapping.items()})
else:
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': 'blocks'})
return result
def __init__(self,
blocks: Union[Callable[[int], nn.Module],
......
......@@ -301,6 +301,8 @@ class NasBench101Cell(Mutable):
[op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
adjacency_list, in_features, out_features, num_nodes, projection)
# FIXME: weight inheritance on nasbench101 is not supported yet
def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
......
......@@ -6,7 +6,7 @@ import os
import warnings
from typing import Any, TypeVar, Union
from nni.common.serializer import Traceable, is_traceable, trace, _copy_class_wrapper_attributes
from nni.common.serializer import Traceable, is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
......@@ -71,7 +71,8 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
if nni_trace_flag.lower() == 'disable':
return cls
_check_wrapped(cls)
if _check_wrapped(cls, 'basic_unit'):
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
......@@ -79,15 +80,7 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
cls = trace(cls)
cls._nni_basic_unit = basic_unit_tag
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import torch
cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
_torchscript_patch(cls)
return cls
......@@ -116,12 +109,14 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
if nni_trace_flag.lower() == 'disable':
return cls
_check_wrapped(cls)
if _check_wrapped(cls, 'model_wrapper'):
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module)
wrapper = trace(cls)
# subclass can still use trace info
wrapper = trace(cls, inheritable=True)
class reset_wrapper(wrapper):
def __init__(self, *args, **kwargs):
......@@ -129,8 +124,12 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
super().__init__(*args, **kwargs)
_copy_class_wrapper_attributes(wrapper, reset_wrapper)
reset_wrapper.__wrapped__ = wrapper.__wrapped__
reset_wrapper.__wrapped__ = getattr(wrapper, '__wrapped__', wrapper)
reset_wrapper._nni_model_wrapper = True
reset_wrapper._traced = True
_torchscript_patch(cls)
return reset_wrapper
......@@ -146,6 +145,32 @@ def is_model_wrapped(cls_or_instance) -> bool:
return getattr(cls_or_instance, '_nni_model_wrapper', False)
def _check_wrapped(cls: T) -> bool:
if getattr(cls, '_traced', False) or getattr(cls, '_nni_model_wrapper', False):
raise TypeError(f'{cls} is already wrapped with trace wrapper (basic_unit / model_wrapper / trace). Cannot wrap again.')
def _check_wrapped(cls: T, rewrap: str) -> bool:
wrapped = None
if is_model_wrapped(cls):
wrapped = 'model_wrapper'
elif is_basic_unit(cls):
wrapped = 'basic_unit'
elif is_wrapped_with_trace(cls):
wrapped = 'nni.trace'
if wrapped:
if wrapped != rewrap:
raise TypeError(f'{cls} is already wrapped with {wrapped}. Cannot rewrap with {rewrap}.')
return True
return False
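In short: rewrapping with the same wrapper is now an idempotent no-op, while mixing wrappers remains an error. A condensed, self-contained sketch of these semantics (the ``nni.trace`` case is omitted; the attribute names follow the diff):

.. code-block:: python

    def check_wrapped(cls, rewrap: str) -> bool:
        wrapped = None
        if getattr(cls, '_nni_model_wrapper', False):
            wrapped = 'model_wrapper'
        elif getattr(cls, '_nni_basic_unit', False):
            wrapped = 'basic_unit'
        if wrapped:
            if wrapped != rewrap:
                raise TypeError(f'{cls} is already wrapped with {wrapped}.')
            return True    # same wrapper again: caller returns cls unchanged
        return False       # not wrapped yet: caller proceeds to wrap

    class Plain: pass
    class Wrapped: _nni_model_wrapper = True

    assert check_wrapped(Plain, 'model_wrapper') is False   # go ahead and wrap
    assert check_wrapped(Wrapped, 'model_wrapper') is True  # idempotent no-op
    try:
        check_wrapped(Wrapped, 'basic_unit')                # mixing wrappers fails
    except TypeError:
        pass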
def _torchscript_patch(cls) -> None:
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import torch
if hasattr(cls, '_get_nni_attr'): # could not exist on non-linux
cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
if hasattr(cls, 'trace_symbol'):
# these must either all exist or all be absent
cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
cls.trace_copy = torch.jit.ignore(cls.trace_copy)
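The patch hides trace bookkeeping from the TorchScript compiler via ``torch.jit.ignore`` / ``torch.jit.unused``. A minimal sketch of the mechanism with a toy method (not NNI's actual attributes):

.. code-block:: python

    import torch

    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.magic()

        @torch.jit.ignore
        def magic(self) -> int:
            # arbitrary Python the compiler never sees; at runtime the call
            # is dispatched back to the Python interpreter
            return len({'any': object()})

    m = torch.jit.script(M())   # scripting succeeds despite the unscriptable body
    assert int(m(torch.zeros(1))) == 1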
......@@ -43,6 +43,17 @@ class MultiThreadEnvWorker(EnvWorker):
def reset(self):
return self.env.reset()
def send(self, action):
# for tianshou >= 0.4.6
if action is None:
self.result = self.pool.apply_async(self.env.reset)
else:
self.send_action(action)
def recv(self):
# for tianshou >= 0.4.6
return self.result.get()
@staticmethod
def wait(*args, **kwargs):
raise NotImplementedError('Async collect is not supported yet.')
......
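For context: tianshou >= 0.4.6 switched workers to a unified ``send``/``recv`` protocol in which ``send(None)`` requests a reset. A stripped-down sketch of the round trip (the pool and env are hypothetical stand-ins):

.. code-block:: python

    from multiprocessing.pool import ThreadPool

    class ToyEnv:
        def reset(self):
            return 'obs0'

    class Worker:
        def __init__(self):
            self.env, self.pool = ToyEnv(), ThreadPool(1)
        def send(self, action):
            if action is None:   # None means "reset" in tianshou >= 0.4.6
                self.result = self.pool.apply_async(self.env.reset)
        def recv(self):
            return self.result.get()

    w = Worker()
    w.send(None)
    assert w.recv() == 'obs0'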
......@@ -2,8 +2,10 @@
# Licensed under the MIT license.
import inspect
import itertools
import warnings
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, List, Dict
from pathlib import Path
......@@ -150,3 +152,119 @@ class ModelNamespace:
def get_current_context(key: str) -> Any:
return ContextStack.top(key)
# map variables to prefix in the state dict
# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
STATE_DICT_PY_MAPPING = '_mapping_'
# map variables to `prefix`.`value` in the state dict
# e.g., {'upsample': 'choice3.upsample_layer'},
# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'},
# and 'upsample' is also in `mynet.module`.
STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'
@contextmanager
def original_state_dict_hooks(model: Any):
"""
Use this patch if you want to save/load a state dict in the original state-dict hierarchy.
For example, when you already have a state dict for the base model / search space (which often
happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
in the same way as the one of a sub-model sampled from the search space. This patch helps
the modules in the sub-model find their corresponding modules in the base model.
The code looks like,
.. code-block:: python
with original_state_dict_hooks(model):
model.load_state_dict(state_dict_from_supernet, strict=False) # supernet has extra keys
Or vice-versa,
.. code-block:: python
with original_state_dict_hooks(model):
supernet_style_state_dict = model.state_dict()
"""
import torch.nn as nn
assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'
# the following are written for pytorch only
# first get the full mapping
full_mapping = {}
def full_mapping_in_module(src_prefix, tar_prefix, module):
if hasattr(module, STATE_DICT_PY_MAPPING):
# only values are complete
local_map = getattr(module, STATE_DICT_PY_MAPPING)
elif hasattr(module, STATE_DICT_PY_MAPPING_PARTIAL):
# keys and values are both incomplete
local_map = getattr(module, STATE_DICT_PY_MAPPING_PARTIAL)
local_map = {k: tar_prefix + v for k, v in local_map.items()}
else:
# no mapping
local_map = {}
if '__self__' in local_map:
# special case, overwrite prefix
tar_prefix = local_map['__self__'] + '.'
for key, value in local_map.items():
if key != '' and key not in module._modules: # not a sub-module, probably a parameter
full_mapping[src_prefix + key] = value
if src_prefix != tar_prefix: # To deal with leaf nodes.
for name, value in itertools.chain(module._parameters.items(), module._buffers.items()): # direct children
if value is None or name in module._non_persistent_buffers_set:
# it won't appear in state dict
continue
if (src_prefix + name) not in full_mapping:
full_mapping[src_prefix + name] = tar_prefix + name
for name, child in module.named_children():
# sub-modules
full_mapping_in_module(
src_prefix + name + '.',
local_map.get(name, tar_prefix + name) + '.', # if mapping doesn't exist, respect the prefix
child
)
full_mapping_in_module('', '', model)
def load_state_dict_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
reverse_mapping = defaultdict(list)
for src, tar in full_mapping.items():
reverse_mapping[tar].append(src)
transf_state_dict = {}
for src, tar_keys in reverse_mapping.items():
if src in state_dict:
value = state_dict.pop(src)
for tar in tar_keys:
transf_state_dict[tar] = value
else:
missing_keys.append(src)
state_dict.update(transf_state_dict)
def state_dict_hook(module, destination, prefix, local_metadata):
result = {}
for src, tar in full_mapping.items():
if src in destination:
result[tar] = destination.pop(src)
else:
raise KeyError(f'"{src}" not in state dict, but found in mapping.')
destination.update(result)
try:
hooks = []
hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
hooks.append(model._register_state_dict_hook(state_dict_hook))
yield
finally:
for hook in hooks:
hook.remove()
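Both hooks are pure key rewrites driven by ``full_mapping`` (sub-model key -> original key). A standalone sketch of the load-time direction with hypothetical keys:

.. code-block:: python

    from collections import defaultdict

    # hypothetical mapping: key in the sampled sub-model -> key in the original checkpoint
    full_mapping = {'op_1.weight': 'cell1.choice3.weight',
                    'op_1.bias': 'cell1.choice3.bias'}

    state_dict = {'cell1.choice3.weight': 'W', 'cell1.choice3.bias': 'b'}

    reverse_mapping = defaultdict(list)
    for sub_key, orig_key in full_mapping.items():
        reverse_mapping[orig_key].append(sub_key)

    transformed, missing = {}, []
    for orig_key, sub_keys in reverse_mapping.items():
        if orig_key in state_dict:
            value = state_dict.pop(orig_key)
            for sub_key in sub_keys:
                transformed[sub_key] = value
        else:
            missing.append(orig_key)

    state_dict.update(transformed)
    assert state_dict == {'op_1.weight': 'W', 'op_1.bias': 'b'} and not missing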
......@@ -70,21 +70,17 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
node_command = os.path.join(entry_dir, 'node')
cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform, \
'--experiment_id', experiment_id]
if mode == 'view':
cmds += ['--start_mode', 'resume']
cmds += ['--readonly', 'true']
else:
cmds += ['--start_mode', mode]
cmds += ['--action', mode]
if log_dir is not None:
cmds += ['--log_dir', log_dir]
cmds += ['--experiments-directory', log_dir]
if log_level is not None:
cmds += ['--log_level', log_level]
cmds += ['--log-level', log_level]
if foreground:
cmds += ['--foreground', 'true']
if url_prefix:
_validate_prefix_path(url_prefix)
set_prefix_url(url_prefix)
cmds += ['--url_prefix', url_prefix]
cmds += ['--url-prefix', url_prefix.strip('/')]
stdout_full_path, stderr_full_path = get_log_path(experiment_id)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
......@@ -520,9 +516,9 @@ def create_experiment(args):
try:
if schema == 1:
launch_experiment(args, config_v1, 'new', experiment_id, 1)
launch_experiment(args, config_v1, 'create', experiment_id, 1)
else:
launch_experiment(args, config_v2, 'new', experiment_id, 2)
launch_experiment(args, config_v2, 'create', experiment_id, 2)
except Exception as exception:
restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
if restServerPid:
......
......@@ -177,7 +177,7 @@ stages:
- job: windows
pool:
vmImage: windows-latest
timeoutInMinutes: 70
timeoutInMinutes: 75
steps:
- template: templates/install-dependencies.yml
......
......@@ -512,6 +512,46 @@ class SpeedupTestCase(TestCase):
print("Fine-grained speeduped model")
print(model)
def test_multiplication_speedup(self):
"""
Model from issue 4540.
"""
class Net(torch.nn.Module):
def __init__(self,):
super(Net, self).__init__()
self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
self.input = torch.nn.Conv2d(3, 8, 3)
self.bn = torch.nn.BatchNorm2d(8)
self.fc1 = torch.nn.Conv2d(8, 16, 1)
self.fc2 = torch.nn.Conv2d(16, 8, 1)
self.activation = torch.nn.ReLU()
self.scale_activation = torch.nn.Hardsigmoid()
self.out = torch.nn.Conv2d(8, 12, 1)
def forward(self, input):
input = self.activation(self.bn(self.input(input)))
scale = self.avgpool(input)
out1 = self.activation(self.fc1(scale))
out1 = self.scale_activation(self.fc2(out1))
return self.out(out1 * input)
model = Net().to(device)
model.eval()
im = torch.ones(1, 3, 512, 512).to(device)
model(im)
cfg_list = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
cfg_list.append({'op_types':['Conv2d'], 'sparsity':0.3, 'op_names':[name]})
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model()
ms = ModelSpeedup(model, im, MASK_FILE)
ms.speedup_model()
def tearDown(self):
if os.path.exists(MODEL_FILE):
os.remove(MODEL_FILE)
......
......@@ -4,6 +4,7 @@
import random
import unittest
import numpy
import torch
import torch.nn.functional as F
......@@ -105,6 +106,17 @@ class IterativePrunerTestCase(unittest.TestCase):
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_amc_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
dummy_input = torch.rand(10, 1, 28, 28)
ddpg_params = {'hidden1': 300, 'hidden2': 300, 'lr_c': 1e-3, 'lr_a': 1e-4, 'warmup': 5, 'discount': 1.,
'bsize': 64, 'rmsize': 100, 'window_length': 1, 'tau': 0.01, 'init_delta': 0.5, 'delta_decay': 0.99,
'max_episode_length': 1e9, 'epsilon': 50000}
pruner = AMCPruner(10, model, config_list, dummy_input, evaluator, finetuner=finetuner, ddpg_params=ddpg_params, target='flops', log_dir='../../../logs')
pruner.compress()
class FixSeedPrunerTestCase(unittest.TestCase):
def test_auto_compress_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
......@@ -126,15 +138,21 @@ class IterativePrunerTestCase(unittest.TestCase):
print(sparsity_list)
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_amc_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
dummy_input = torch.rand(10, 1, 28, 28)
ddpg_params = {'hidden1': 300, 'hidden2': 300, 'lr_c': 1e-3, 'lr_a': 1e-4, 'warmup': 5, 'discount': 1.,
'bsize': 64, 'rmsize': 100, 'window_length': 1, 'tau': 0.01, 'init_delta': 0.5, 'delta_decay': 0.99,
'max_episode_length': 1e9, 'epsilon': 50000}
pruner = AMCPruner(10, model, config_list, dummy_input, evaluator, finetuner=finetuner, ddpg_params=ddpg_params, target='flops', log_dir='../../../logs')
pruner.compress()
def setUp(self) -> None:
# fix the seed to avoid flaky unit-test failures
random.seed(1024)
numpy.random.seed(1024)
torch.manual_seed(1024)
def tearDown(self) -> None:
# reset seed
import time
now = int(time.time() * 100)
random.seed(now)
seed = random.randint(0, 2 ** 32 - 1)
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
if __name__ == '__main__':
unittest.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
import unittest
import numpy
import torch
import torch.nn.functional as F
......@@ -122,18 +124,6 @@ class PrunerTestCase(unittest.TestCase):
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_activation_apoz_rank_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=trainer,
traced_optimizer=get_optimizer(model), criterion=criterion, training_batches=5,
activation='relu', mode='dependency_aware',
dummy_input=torch.rand(10, 1, 28, 28))
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_activation_mean_rank_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
......@@ -177,6 +167,34 @@ class PrunerTestCase(unittest.TestCase):
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
class FixSeedPrunerTestCase(unittest.TestCase):
def test_activation_apoz_rank_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=trainer,
traced_optimizer=get_optimizer(model), criterion=criterion, training_batches=5,
activation='relu', mode='dependency_aware',
dummy_input=torch.rand(10, 1, 28, 28))
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def setUp(self) -> None:
# fix the seed to avoid flaky unit-test failures
random.seed(1024)
numpy.random.seed(1024)
torch.manual_seed(1024)
def tearDown(self) -> None:
# reset seed
import time
now = int(time.time() * 100)
random.seed(now)
seed = random.randint(0, 2 ** 32 - 1)
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
if __name__ == '__main__':
unittest.main()
......@@ -16,6 +16,7 @@ class _model(nn.Module):
self.fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.softmax = torch.nn.Softmax()
self._mapping_ = {'stem': None, 'flatten': None, 'fc1': None, 'fc2': None, 'softmax': None}
def forward(self, image):
stem = self.stem(image)
......@@ -34,6 +35,7 @@ class stem(nn.Module):
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
self._mapping_ = {'conv1': None, 'pool1': None, 'conv2': None, 'pool2': None}
def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0])
......
......@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
from pathlib import Path
import nni
import nni.runtime.platform.test
try:
from nni.common.device import GPUDevice
......
......@@ -14,6 +14,7 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
......@@ -50,16 +51,6 @@ class Linear(nn.Module):
return out.view(size[0], size[1], -1)
class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def checkExportImport(self, model, input):
model_ir = self._convert_model(model, input)
......@@ -68,9 +59,8 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(dict(model.state_dict()))
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
......
......@@ -12,20 +12,11 @@ from nni.retiarii import basic_unit
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
# following pytorch v1.7.1
class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def checkExportImport(self, model, input, check_value=True):
model_ir = self._convert_model(model, input)
......@@ -35,9 +26,10 @@ class TestConvert(unittest.TestCase, ConvertMixin):
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(model.state_dict())
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
......
......@@ -9,23 +9,13 @@ import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestModels(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def run_test(self, model, input, check_value=True):
model_ir = self._convert_model(model, input)
......@@ -35,9 +25,10 @@ class TestModels(unittest.TestCase, ConvertMixin):
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(model.state_dict())
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
......
......@@ -16,6 +16,7 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
......@@ -23,16 +24,6 @@ from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestOperators(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def checkExportImport(self, model, input, check_value=True):
model_ir = self._convert_model(model, input)
......@@ -42,9 +33,10 @@ class TestOperators(unittest.TestCase, ConvertMixin):
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(model.state_dict())
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
......
......@@ -14,28 +14,17 @@ import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import original_state_dict_hooks
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestPytorch(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
for k, v in expected_format.items():
for idx, cv in enumerate(current_values):
if cv.shape == v.shape:
result[k] = cv
current_values.pop(idx)
break
return result
def run_test(self, model, input, check_value=True):
def run_test(self, model, input, check_value=True, strict_load=True):
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
from .inject_nn import remove_inject_pytorch_nn
remove_inject_pytorch_nn()
......@@ -43,9 +32,10 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
exec_vars = {}
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
converted_model = exec_vars['converted_model']
converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
dict(converted_model.state_dict()))
converted_model.load_state_dict(converted_state_dict)
with original_state_dict_hooks(converted_model):
converted_model.load_state_dict(model.state_dict(), strict=strict_load)
with torch.no_grad():
expected_output = model.eval()(*input)
converted_output = converted_model.eval()(*input)
......@@ -76,7 +66,8 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
model = LargeModel()
x = torch.tensor([2], dtype=torch.long)
self.run_test(model, (x, ))
# emb and lin1 are actually not used, so they won't appear in the generated model
self.run_test(model, (x, ), strict_load=False)
@unittest.skip('skip for now, as it needs inject_nn')
def test_mobilenet_v2_with_external_data(self):
......