Commit 13dc0f8f authored by liuzhe

Merge branch 'master' into doc-refactor

parents 22165cea 3b27ac76
@@ -6,6 +6,7 @@ import re
 from typing import Dict, List, Tuple, Any
 from nni.retiarii.operation_def.torch_op_def import ToDevice
+from nni.retiarii.utils import STATE_DICT_PY_MAPPING
 from nni.common.device import Device, GPUDevice
 from ..graph import IllegalGraphError, Edge, Graph, Node, Model
@@ -97,7 +98,18 @@ def _format_variable_name(name: str, graph_name: str) -> str:
     name = name.replace('/', '__')
     # https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
-    return re.sub('\W|^(?=\d)','_', name)
+    name = re.sub('\W|^(?=\d)','_', name)
+
+    if name.startswith('__') and (len(name) > 2 and name[2] != '_'):
+        # name can't start with double underscore
+        # it's reserved in Python: https://stackoverflow.com/a/1301409/6837658
+        # but it's actually very common in our generated code
+        name = name[1:]
+    elif name.startswith('_'):
+        # to avoid conflicts between '_' and '__'
+        name = 'i' + name
+
+    return name
 
 def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
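For intuition, the extended renaming rules behave as follows (a standalone sketch of the logic above, with the graph-name handling omitted; example names are made up):

    import re

    def format_variable_name(name: str) -> str:
        # same rules as _format_variable_name above
        name = name.replace('/', '__')
        name = re.sub(r'\W|^(?=\d)', '_', name)   # make the name a valid Python identifier
        if name.startswith('__') and len(name) > 2 and name[2] != '_':
            name = name[1:]      # '__x' is reserved in Python; demote to '_x'
        elif name.startswith('_'):
            name = 'i' + name    # keep a plain '_x' distinct from a demoted '__x'
        return name

    assert format_variable_name('model/conv1') == 'model__conv1'
    assert format_variable_name('__cell1') == '_cell1'
    assert format_variable_name('_fc') == 'i_fc'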
@@ -125,6 +137,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
     # only need to generate code for module here
     import_pkgs = set()
     node_codes = []
+    node_python_mappings = {}
     cuda_remapped_id = None
     if placement:
         cuda_remapped_id = generate_cuda_mapping(placement)
@@ -138,7 +151,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
         pkg_name = node.operation.get_import_pkg()
         if pkg_name is not None:
             import_pkgs.add(pkg_name)
-        node_code = node.operation.to_init_code(_format_variable_name(node.name, graph_name))
+
+        py_variable_name = _format_variable_name(node.name, graph_name)
+        node_code = node.operation.to_init_code(py_variable_name)
         if node_code is not None:
             if placement and node in placement and len(node_code) > 0:
                 if isinstance(placement[node], GPUDevice):
@@ -149,6 +164,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
             else:
                 node_codes.append(node_code)
 
+        # Map to module hierarchies in original search space python code
+        node_python_mappings[py_variable_name] = node.python_name
+
+    node_codes.append(f'self.{STATE_DICT_PY_MAPPING} = {node_python_mappings}')
+
     if graph.input_node.operation.io_names is None:
         input_code = '*_inputs'
     else:
...
@@ -101,6 +101,7 @@ class _MultiModelSupervisedLearningModule(LightningModule):
         return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}
 
+@nni.trace
 class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
     """
     Lightning Module of SupervisedLearning for Cross-Graph Optimization.
...
@@ -11,6 +11,7 @@ import torch.nn as nn
 from nni.common.serializer import Translatable
 from nni.retiarii.serializer import basic_unit
+from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
 from .utils import Mutable, generate_new_label, get_fixed_value
@@ -82,9 +83,22 @@ class LayerChoice(Mutable):
                 label: Optional[str] = None, **kwargs):
         chosen = get_fixed_value(label)
         if isinstance(candidates, list):
-            return candidates[int(chosen)]
+            result = candidates[int(chosen)]
         else:
-            return candidates[chosen]
+            result = candidates[chosen]
+
+        # map the named hierarchies to support weight inheritance for python engine
+        if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
+            # handle cases where layer choices are nested
+            # already has a mapping, will merge with it
+            prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
+            setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'{chosen}.{v}' for k, v in prev_mapping.items()})
+        else:
+            # "result" needs to know where to map itself.
+            # Ideally, we should put a _mapping_ in the module where "result" is located,
+            # but it's impossible to put mapping into parent module here.
+            setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': str(chosen)})
+
+        return result
 
     def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
                  prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
...
@@ -5,6 +5,8 @@ from typing import Callable, List, Union, Tuple, Optional
 import torch
 import torch.nn as nn
 
+from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
+
 from .api import LayerChoice
 from .cell import Cell
 from .nasbench101 import NasBench101Cell, NasBench101Mutator
@@ -59,7 +61,15 @@ class Repeat(Mutable):
                              List[nn.Module]],
                depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
         repeat = get_fixed_value(label)
-        return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
+        result = nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
+
+        if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
+            # already has a mapping, will merge with it
+            prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
+            setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'blocks.{v}' for k, v in prev_mapping.items()})
+        else:
+            setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': 'blocks'})
+
+        return result
 
     def __init__(self,
                  blocks: Union[Callable[[int], nn.Module],
...
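A fixed-architecture Repeat returns the nn.Sequential directly, whereas the search space kept it under an attribute named blocks; the {'__self__': 'blocks'} record preserves that offset. A hedged sketch of the effect on state dict keys (the layer is hypothetical):

    import torch.nn as nn

    STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'  # same constant as in nni.retiarii.utils

    result = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
    setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': 'blocks'})
    # sub-model keys such as '0.weight' will later be resolved to 'blocks.0.weight'
    # when original_state_dict_hooks (see below) walks the module tree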
@@ -301,6 +301,8 @@ class NasBench101Cell(Mutable):
             [op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
             adjacency_list, in_features, out_features, num_nodes, projection)
 
+    # FIXME: weight inheritance on nasbench101 is not supported yet
+
     def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
                  in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
                  max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
...
@@ -6,7 +6,7 @@ import os
 import warnings
 from typing import Any, TypeVar, Union
 
-from nni.common.serializer import Traceable, is_traceable, trace, _copy_class_wrapper_attributes
+from nni.common.serializer import Traceable, is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
 from .utils import ModelNamespace
 
 __all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
@@ -71,7 +71,8 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
     if nni_trace_flag.lower() == 'disable':
         return cls
 
-    _check_wrapped(cls)
+    if _check_wrapped(cls, 'basic_unit'):
+        return cls
 
     import torch.nn as nn
     assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
@@ -79,15 +80,7 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
     cls = trace(cls)
     cls._nni_basic_unit = basic_unit_tag
 
-    # HACK: for torch script
-    # https://github.com/pytorch/pytorch/pull/45261
-    # https://github.com/pytorch/pytorch/issues/54688
-    # I'm not sure whether there will be potential issues
-    import torch
-    cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
-    cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
-    cls.trace_args = torch.jit.unused(cls.trace_args)
-    cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
+    _torchscript_patch(cls)
 
     return cls
@@ -116,12 +109,14 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
     if nni_trace_flag.lower() == 'disable':
         return cls
 
-    _check_wrapped(cls)
+    if _check_wrapped(cls, 'model_wrapper'):
+        return cls
 
     import torch.nn as nn
     assert issubclass(cls, nn.Module)
 
-    wrapper = trace(cls)
+    # subclass can still use trace info
+    wrapper = trace(cls, inheritable=True)
 
     class reset_wrapper(wrapper):
         def __init__(self, *args, **kwargs):
@@ -129,8 +124,12 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
             super().__init__(*args, **kwargs)
 
     _copy_class_wrapper_attributes(wrapper, reset_wrapper)
-    reset_wrapper.__wrapped__ = wrapper.__wrapped__
+    reset_wrapper.__wrapped__ = getattr(wrapper, '__wrapped__', wrapper)
     reset_wrapper._nni_model_wrapper = True
+    reset_wrapper._traced = True
+
+    _torchscript_patch(cls)
+
     return reset_wrapper
@@ -146,6 +145,32 @@ def is_model_wrapped(cls_or_instance) -> bool:
     return getattr(cls_or_instance, '_nni_model_wrapper', False)
 
-def _check_wrapped(cls: T) -> bool:
-    if getattr(cls, '_traced', False) or getattr(cls, '_nni_model_wrapper', False):
-        raise TypeError(f'{cls} is already wrapped with trace wrapper (basic_unit / model_wrapper / trace). Cannot wrap again.')
+def _check_wrapped(cls: T, rewrap: str) -> bool:
+    wrapped = None
+    if is_model_wrapped(cls):
+        wrapped = 'model_wrapper'
+    elif is_basic_unit(cls):
+        wrapped = 'basic_unit'
+    elif is_wrapped_with_trace(cls):
+        wrapped = 'nni.trace'
+    if wrapped:
+        if wrapped != rewrap:
+            raise TypeError(f'{cls} is already wrapped with {wrapped}. Cannot rewrap with {rewrap}.')
+        return True
+    return False
+
+
+def _torchscript_patch(cls) -> None:
+    # HACK: for torch script
+    # https://github.com/pytorch/pytorch/pull/45261
+    # https://github.com/pytorch/pytorch/issues/54688
+    # I'm not sure whether there will be potential issues
+    import torch
+    if hasattr(cls, '_get_nni_attr'):  # could not exist on non-linux
+        cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
+    if hasattr(cls, 'trace_symbol'):
+        # these must all exist or all non-exist
+        cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
+        cls.trace_args = torch.jit.unused(cls.trace_args)
+        cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
+        cls.trace_copy = torch.jit.ignore(cls.trace_copy)
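With the reworked `_check_wrapped`, re-applying the same wrapper is now a silent no-op, while mixing different wrappers still fails fast. A hedged sketch of the intended behavior (the class name is hypothetical):

    import torch.nn as nn
    from nni.retiarii import basic_unit, model_wrapper

    @basic_unit
    @basic_unit              # second application returns the class unchanged
    class MyUnit(nn.Module):
        def forward(self, x):
            return x

    # model_wrapper(MyUnit) would still raise:
    # TypeError: ... is already wrapped with basic_unit. Cannot rewrap with model_wrapper.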
@@ -43,6 +43,17 @@ class MultiThreadEnvWorker(EnvWorker):
     def reset(self):
         return self.env.reset()
 
+    def send(self, action):
+        # for tianshou >= 0.4.6
+        if action is None:
+            self.result = self.pool.apply_async(self.env.reset)
+        else:
+            self.send_action(action)
+
+    def recv(self):
+        # for tianshou >= 0.4.6
+        return self.result.get()
+
     @staticmethod
     def wait(*args, **kwargs):
         raise NotImplementedError('Async collect is not supported yet.')
...
@@ -2,8 +2,10 @@
 # Licensed under the MIT license.
 
 import inspect
+import itertools
 import warnings
 from collections import defaultdict
+from contextlib import contextmanager
 from typing import Any, List, Dict
 from pathlib import Path
@@ -150,3 +152,119 @@ class ModelNamespace:
 def get_current_context(key: str) -> Any:
     return ContextStack.top(key)
+
+
+# map variables to prefix in the state dict
+# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
+STATE_DICT_PY_MAPPING = '_mapping_'
+
+# map variables to `prefix`.`value` in the state dict
+# e.g., {'upsample': 'choice3.upsample_layer'},
+# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'},
+# and 'upsample' is also in `mynet.module`.
+STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'
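The difference between the two constants: values in `_mapping_` are complete key prefixes, while values in `_mapping_partial_` are relative to the parent's target prefix and get expanded during traversal. A small illustration with made-up names:

    full = {'upsample': 'mynet.module.deconv2.upsample_layer'}   # already complete
    partial = {'upsample': 'choice3.upsample_layer'}             # relative to the parent
    tar_prefix = 'mynet.module.'
    resolved = {k: tar_prefix + v for k, v in partial.items()}   # what the traversal below does
    assert resolved == {'upsample': 'mynet.module.choice3.upsample_layer'}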
+
+
+@contextmanager
+def original_state_dict_hooks(model: Any):
+    """
+    Use this patch if you want to save/load state dict in the original state dict hierarchy.
+
+    For example, when you already have a state dict for the base model / search space (which often
+    happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
+    in the same way as when a sub-model is sampled from the search space. This patch will help
+    the modules in the sub-model find the corresponding module in the base model.
+
+    The code looks like,
+
+    .. code-block:: python
+
+        with original_state_dict_hooks(model):
+            model.load_state_dict(state_dict_from_supernet, strict=False)  # supernet has extra keys
+
+    Or vice-versa,
+
+    .. code-block:: python
+
+        with original_state_dict_hooks(model):
+            supernet_style_state_dict = model.state_dict()
+    """
+
+    import torch.nn as nn
+    assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'
+
+    # the following are written for pytorch only
+
+    # first get the full mapping
+    full_mapping = {}
+
+    def full_mapping_in_module(src_prefix, tar_prefix, module):
+        if hasattr(module, STATE_DICT_PY_MAPPING):
+            # only values are complete
+            local_map = getattr(module, STATE_DICT_PY_MAPPING)
+        elif hasattr(module, STATE_DICT_PY_MAPPING_PARTIAL):
+            # keys and values are both incomplete
+            local_map = getattr(module, STATE_DICT_PY_MAPPING_PARTIAL)
+            local_map = {k: tar_prefix + v for k, v in local_map.items()}
+        else:
+            # no mapping
+            local_map = {}
+
+        if '__self__' in local_map:
+            # special case, overwrite prefix
+            tar_prefix = local_map['__self__'] + '.'
+
+        for key, value in local_map.items():
+            if key != '' and key not in module._modules:  # not a sub-module, probably a parameter
+                full_mapping[src_prefix + key] = value
+
+        if src_prefix != tar_prefix:  # To deal with leaf nodes.
+            for name, value in itertools.chain(module._parameters.items(), module._buffers.items()):  # direct children
+                if value is None or name in module._non_persistent_buffers_set:
+                    # it won't appear in state dict
+                    continue
+                if (src_prefix + name) not in full_mapping:
+                    full_mapping[src_prefix + name] = tar_prefix + name
+
+        for name, child in module.named_children():
+            # sub-modules
+            full_mapping_in_module(
+                src_prefix + name + '.',
+                local_map.get(name, tar_prefix + name) + '.',  # if mapping doesn't exist, respect the prefix
+                child
+            )
+
+    full_mapping_in_module('', '', model)
+
+    def load_state_dict_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        reverse_mapping = defaultdict(list)
+        for src, tar in full_mapping.items():
+            reverse_mapping[tar].append(src)
+
+        transf_state_dict = {}
+        for src, tar_keys in reverse_mapping.items():
+            if src in state_dict:
+                value = state_dict.pop(src)
+                for tar in tar_keys:
+                    transf_state_dict[tar] = value
+            else:
+                missing_keys.append(src)
+        state_dict.update(transf_state_dict)
+
+    def state_dict_hook(module, destination, prefix, local_metadata):
+        result = {}
+        for src, tar in full_mapping.items():
+            if src in destination:
+                result[tar] = destination.pop(src)
+            else:
+                raise KeyError(f'"{src}" not in state dict, but found in mapping.')
+        destination.update(result)
+
+    try:
+        hooks = []
+        hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
+        hooks.append(model._register_state_dict_hook(state_dict_hook))
+        yield
+    finally:
+        for hook in hooks:
+            hook.remove()
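End to end, the hooks let a sampled sub-model emit or consume state dicts in the supernet's hierarchy. A hedged toy example (module and key names are made up):

    import torch.nn as nn
    from nni.retiarii.utils import STATE_DICT_PY_MAPPING, original_state_dict_hooks

    class SubModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)
            # in the original search space, this conv lived at 'features.choice3'
            setattr(self, STATE_DICT_PY_MAPPING, {'conv': 'features.choice3'})

    model = SubModel()
    with original_state_dict_hooks(model):
        # keys are emitted in the supernet's hierarchy instead of the sub-model's
        assert set(model.state_dict().keys()) == {'features.choice3.weight', 'features.choice3.bias'}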
@@ -70,21 +70,17 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
     node_command = os.path.join(entry_dir, 'node')
     cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform, \
         '--experiment_id', experiment_id]
-    if mode == 'view':
-        cmds += ['--start_mode', 'resume']
-        cmds += ['--readonly', 'true']
-    else:
-        cmds += ['--start_mode', mode]
+    cmds += ['--action', mode]
     if log_dir is not None:
-        cmds += ['--log_dir', log_dir]
+        cmds += ['--experiments-directory', log_dir]
     if log_level is not None:
-        cmds += ['--log_level', log_level]
+        cmds += ['--log-level', log_level]
     if foreground:
         cmds += ['--foreground', 'true']
     if url_prefix:
         _validate_prefix_path(url_prefix)
         set_prefix_url(url_prefix)
-        cmds += ['--url_prefix', url_prefix]
+        cmds += ['--url-prefix', url_prefix.strip('/')]
     stdout_full_path, stderr_full_path = get_log_path(experiment_id)
     with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
@@ -520,9 +516,9 @@ def create_experiment(args):
     try:
         if schema == 1:
-            launch_experiment(args, config_v1, 'new', experiment_id, 1)
+            launch_experiment(args, config_v1, 'create', experiment_id, 1)
         else:
-            launch_experiment(args, config_v2, 'new', experiment_id, 2)
+            launch_experiment(args, config_v2, 'create', experiment_id, 2)
     except Exception as exception:
         restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
         if restServerPid:
...
@@ -177,7 +177,7 @@ stages:
 - job: windows
   pool:
     vmImage: windows-latest
-  timeoutInMinutes: 70
+  timeoutInMinutes: 75
   steps:
   - template: templates/install-dependencies.yml
...
@@ -512,6 +512,46 @@ class SpeedupTestCase(TestCase):
         print("Fine-grained speeduped model")
         print(model)
 
+    def test_multiplication_speedup(self):
+        """
+        Model from issue 4540.
+        """
+        class Net(torch.nn.Module):
+            def __init__(self,):
+                super(Net, self).__init__()
+                self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
+                self.input = torch.nn.Conv2d(3, 8, 3)
+                self.bn = torch.nn.BatchNorm2d(8)
+                self.fc1 = torch.nn.Conv2d(8, 16, 1)
+                self.fc2 = torch.nn.Conv2d(16, 8, 1)
+                self.activation = torch.nn.ReLU()
+                self.scale_activation = torch.nn.Hardsigmoid()
+                self.out = torch.nn.Conv2d(8, 12, 1)
+
+            def forward(self, input):
+                input = self.activation(self.bn(self.input(input)))
+                scale = self.avgpool(input)
+                out1 = self.activation(self.fc1(scale))
+                out1 = self.scale_activation(self.fc2(out1))
+                return self.out(out1 * input)
+
+        model = Net().to(device)
+        model.eval()
+        im = torch.ones(1, 3, 512, 512).to(device)
+        model(im)
+        cfg_list = []
+        for name, module in model.named_modules():
+            if isinstance(module, torch.nn.Conv2d):
+                cfg_list.append({'op_types': ['Conv2d'], 'sparsity': 0.3, 'op_names': [name]})
+
+        pruner = L1FilterPruner(model, cfg_list)
+        pruner.compress()
+        pruner.export_model(MODEL_FILE, MASK_FILE)
+        pruner._unwrap_model()
+        ms = ModelSpeedup(model, im, MASK_FILE)
+        ms.speedup_model()
+
     def tearDown(self):
         if os.path.exists(MODEL_FILE):
             os.remove(MODEL_FILE)
...
@@ -4,6 +4,7 @@
 import random
 import unittest
 
+import numpy
 import torch
 import torch.nn.functional as F
@@ -105,6 +106,17 @@ class IterativePrunerTestCase(unittest.TestCase):
         sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
         assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
 
+    def test_amc_pruner(self):
+        model = TorchModel()
+        config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
+        dummy_input = torch.rand(10, 1, 28, 28)
+        ddpg_params = {'hidden1': 300, 'hidden2': 300, 'lr_c': 1e-3, 'lr_a': 1e-4, 'warmup': 5, 'discount': 1.,
+                       'bsize': 64, 'rmsize': 100, 'window_length': 1, 'tau': 0.01, 'init_delta': 0.5, 'delta_decay': 0.99,
+                       'max_episode_length': 1e9, 'epsilon': 50000}
+        pruner = AMCPruner(10, model, config_list, dummy_input, evaluator, finetuner=finetuner, ddpg_params=ddpg_params, target='flops', log_dir='../../../logs')
+        pruner.compress()
+
+class FixSeedPrunerTestCase(unittest.TestCase):
     def test_auto_compress_pruner(self):
         model = TorchModel()
         config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
@@ -126,15 +138,21 @@ class IterativePrunerTestCase(unittest.TestCase):
         print(sparsity_list)
         assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
 
-    def test_amc_pruner(self):
-        model = TorchModel()
-        config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
-        dummy_input = torch.rand(10, 1, 28, 28)
-        ddpg_params = {'hidden1': 300, 'hidden2': 300, 'lr_c': 1e-3, 'lr_a': 1e-4, 'warmup': 5, 'discount': 1.,
-                       'bsize': 64, 'rmsize': 100, 'window_length': 1, 'tau': 0.01, 'init_delta': 0.5, 'delta_decay': 0.99,
-                       'max_episode_length': 1e9, 'epsilon': 50000}
-        pruner = AMCPruner(10, model, config_list, dummy_input, evaluator, finetuner=finetuner, ddpg_params=ddpg_params, target='flops', log_dir='../../../logs')
-        pruner.compress()
+    def setUp(self) -> None:
+        # fix seed in order to solve the random failure of ut
+        random.seed(1024)
+        numpy.random.seed(1024)
+        torch.manual_seed(1024)
+
+    def tearDown(self) -> None:
+        # reset seed
+        import time
+        now = int(time.time() * 100)
+        random.seed(now)
+        seed = random.randint(0, 2 ** 32 - 1)
+        random.seed(seed)
+        numpy.random.seed(seed)
+        torch.manual_seed(seed)
 
 if __name__ == '__main__':
     unittest.main()
...
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+import random
 import unittest
 
+import numpy
 import torch
 import torch.nn.functional as F
@@ -122,18 +124,6 @@ class PrunerTestCase(unittest.TestCase):
         sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
         assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
 
-    def test_activation_apoz_rank_pruner(self):
-        model = TorchModel()
-        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
-        pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=trainer,
-                                          traced_optimizer=get_optimizer(model), criterion=criterion, training_batches=5,
-                                          activation='relu', mode='dependency_aware',
-                                          dummy_input=torch.rand(10, 1, 28, 28))
-        pruned_model, masks = pruner.compress()
-        pruner._unwrap_model()
-        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
-        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
-
     def test_activation_mean_rank_pruner(self):
         model = TorchModel()
         config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
@@ -177,6 +167,34 @@ class PrunerTestCase(unittest.TestCase):
         sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
         assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
 
+class FixSeedPrunerTestCase(unittest.TestCase):
+    def test_activation_apoz_rank_pruner(self):
+        model = TorchModel()
+        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
+        pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=trainer,
+                                          traced_optimizer=get_optimizer(model), criterion=criterion, training_batches=5,
+                                          activation='relu', mode='dependency_aware',
+                                          dummy_input=torch.rand(10, 1, 28, 28))
+        pruned_model, masks = pruner.compress()
+        pruner._unwrap_model()
+        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
+        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
+
+    def setUp(self) -> None:
+        # fix seed in order to solve the random failure of ut
+        random.seed(1024)
+        numpy.random.seed(1024)
+        torch.manual_seed(1024)
+
+    def tearDown(self) -> None:
+        # reset seed
+        import time
+        now = int(time.time() * 100)
+        random.seed(now)
+        seed = random.randint(0, 2 ** 32 - 1)
+        random.seed(seed)
+        numpy.random.seed(seed)
+        torch.manual_seed(seed)
 
 if __name__ == '__main__':
     unittest.main()
...
@@ -16,6 +16,7 @@ class _model(nn.Module):
         self.fc1 = torch.nn.Linear(out_features=256, in_features=1024)
         self.fc2 = torch.nn.Linear(out_features=10, in_features=256)
         self.softmax = torch.nn.Softmax()
+        self._mapping_ = {'stem': None, 'flatten': None, 'fc1': None, 'fc2': None, 'softmax': None}
 
     def forward(self, image):
         stem = self.stem(image)
@@ -34,6 +35,7 @@ class stem(nn.Module):
         self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
         self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
         self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
+        self._mapping_ = {'conv1': None, 'pool1': None, 'conv2': None, 'pool2': None}
 
     def forward(self, *_inputs):
         conv1 = self.conv1(_inputs[0])
...
@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
 from pathlib import Path
 
 import nni
+import nni.runtime.platform.test
 
 try:
     from nni.common.device import GPUDevice
...
@@ -14,6 +14,7 @@ import torchvision
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import basic_unit
 from nni.retiarii.codegen import model_to_pytorch_script
+from nni.retiarii.utils import original_state_dict_hooks
 
 from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
@@ -50,16 +51,6 @@ class Linear(nn.Module):
         return out.view(size[0], size[1], -1)
 
 class TestConvert(unittest.TestCase, ConvertMixin):
-    @staticmethod
-    def _match_state_dict(current_values, expected_format):
-        result = {}
-        for k, v in expected_format.items():
-            for idx, cv in enumerate(current_values):
-                if cv.shape == v.shape:
-                    result[k] = cv
-                    current_values.pop(idx)
-                    break
-        return result
 
     def checkExportImport(self, model, input):
         model_ir = self._convert_model(model, input)
@@ -68,9 +59,8 @@ class TestConvert(unittest.TestCase, ConvertMixin):
         exec_vars = {}
         exec(model_code + '\n\nconverted_model = _model()', exec_vars)
         converted_model = exec_vars['converted_model']
-        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
-                                                      dict(converted_model.state_dict()))
-        converted_model.load_state_dict(converted_state_dict)
+        with original_state_dict_hooks(converted_model):
+            converted_model.load_state_dict(dict(model.state_dict()))
         with torch.no_grad():
             expected_output = model.eval()(*input)
             converted_output = converted_model.eval()(*input)
...
@@ -12,20 +12,11 @@ from nni.retiarii import basic_unit
 from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
 from nni.retiarii.codegen import model_to_pytorch_script
+from nni.retiarii.utils import original_state_dict_hooks
 
 # following pytorch v1.7.1
 class TestConvert(unittest.TestCase, ConvertMixin):
-    @staticmethod
-    def _match_state_dict(current_values, expected_format):
-        result = {}
-        for k, v in expected_format.items():
-            for idx, cv in enumerate(current_values):
-                if cv.shape == v.shape:
-                    result[k] = cv
-                    current_values.pop(idx)
-                    break
-        return result
 
     def checkExportImport(self, model, input, check_value=True):
         model_ir = self._convert_model(model, input)
@@ -35,9 +26,10 @@ class TestConvert(unittest.TestCase, ConvertMixin):
         exec_vars = {}
         exec(model_code + '\n\nconverted_model = _model()', exec_vars)
         converted_model = exec_vars['converted_model']
-        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
-                                                      dict(converted_model.state_dict()))
-        converted_model.load_state_dict(converted_state_dict)
+
+        with original_state_dict_hooks(converted_model):
+            converted_model.load_state_dict(model.state_dict())
+
         with torch.no_grad():
             expected_output = model.eval()(*input)
             converted_output = converted_model.eval()(*input)
...
@@ -9,23 +9,13 @@ import torch.nn.functional as F
 import torchvision
 
 import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import serialize
 from nni.retiarii.codegen import model_to_pytorch_script
+from nni.retiarii.utils import original_state_dict_hooks
 
 from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
 
 
 class TestModels(unittest.TestCase, ConvertMixin):
-    @staticmethod
-    def _match_state_dict(current_values, expected_format):
-        result = {}
-        for k, v in expected_format.items():
-            for idx, cv in enumerate(current_values):
-                if cv.shape == v.shape:
-                    result[k] = cv
-                    current_values.pop(idx)
-                    break
-        return result
 
     def run_test(self, model, input, check_value=True):
         model_ir = self._convert_model(model, input)
@@ -35,9 +25,10 @@ class TestModels(unittest.TestCase, ConvertMixin):
         exec_vars = {}
         exec(model_code + '\n\nconverted_model = _model()', exec_vars)
         converted_model = exec_vars['converted_model']
-        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
-                                                      dict(converted_model.state_dict()))
-        converted_model.load_state_dict(converted_state_dict)
+
+        with original_state_dict_hooks(converted_model):
+            converted_model.load_state_dict(model.state_dict())
+
         with torch.no_grad():
             expected_output = model.eval()(*input)
             converted_output = converted_model.eval()(*input)
...
@@ -16,6 +16,7 @@ import torchvision
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii.codegen import model_to_pytorch_script
+from nni.retiarii.utils import original_state_dict_hooks
 
 from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
@@ -23,16 +24,6 @@ from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
 
 class TestOperators(unittest.TestCase, ConvertMixin):
-    @staticmethod
-    def _match_state_dict(current_values, expected_format):
-        result = {}
-        for k, v in expected_format.items():
-            for idx, cv in enumerate(current_values):
-                if cv.shape == v.shape:
-                    result[k] = cv
-                    current_values.pop(idx)
-                    break
-        return result
 
     def checkExportImport(self, model, input, check_value=True):
         model_ir = self._convert_model(model, input)
@@ -42,9 +33,10 @@ class TestOperators(unittest.TestCase, ConvertMixin):
         exec_vars = {}
         exec(model_code + '\n\nconverted_model = _model()', exec_vars)
         converted_model = exec_vars['converted_model']
-        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
-                                                      dict(converted_model.state_dict()))
-        converted_model.load_state_dict(converted_state_dict)
+
+        with original_state_dict_hooks(converted_model):
+            converted_model.load_state_dict(model.state_dict())
+
         with torch.no_grad():
             expected_output = model.eval()(*input)
             converted_output = converted_model.eval()(*input)
...
@@ -14,28 +14,17 @@ import torch.nn.functional as F
 import torchvision
 
 import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import serialize
 from nni.retiarii.codegen import model_to_pytorch_script
+from nni.retiarii.utils import original_state_dict_hooks
 
 from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
 
 
 class TestPytorch(unittest.TestCase, ConvertMixin):
-    @staticmethod
-    def _match_state_dict(current_values, expected_format):
-        result = {}
-        for k, v in expected_format.items():
-            for idx, cv in enumerate(current_values):
-                if cv.shape == v.shape:
-                    result[k] = cv
-                    current_values.pop(idx)
-                    break
-        return result
-
-    def run_test(self, model, input, check_value=True):
+    def run_test(self, model, input, check_value=True, strict_load=True):
         model_ir = self._convert_model(model, input)
         model_code = model_to_pytorch_script(model_ir)
-        print(model_code)
 
         from .inject_nn import remove_inject_pytorch_nn
         remove_inject_pytorch_nn()
@@ -43,9 +32,10 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
         exec_vars = {}
         exec(model_code + '\n\nconverted_model = _model()', exec_vars)
         converted_model = exec_vars['converted_model']
-        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
-                                                      dict(converted_model.state_dict()))
-        converted_model.load_state_dict(converted_state_dict)
+
+        with original_state_dict_hooks(converted_model):
+            converted_model.load_state_dict(model.state_dict(), strict=strict_load)
+
         with torch.no_grad():
             expected_output = model.eval()(*input)
             converted_output = converted_model.eval()(*input)
@@ -76,7 +66,8 @@ class TestPytorch(unittest.TestCase, ConvertMixin):
         model = LargeModel()
         x = torch.tensor([2], dtype=torch.long)
-        self.run_test(model, (x, ))
+        # emb and lin1 is actually not used so they won't appear in generated model
+        self.run_test(model, (x, ), strict_load=False)
 
     @unittest.skip('skip for now, as it needs inject_nn')
     def test_mobilenet_v2_with_external_data(self):
...