Unverified Commit 08fe2924 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Merge pull request #3938 from microsoft/nn-meter

[DO NOT SQUASH] Support nn-Meter in Retiarii framework
parents 3bce6926 5e04d56c
Hardware-aware NAS
==================
.. contents::
EndToEnd Multi-trial SPOS Demo
------------------------------
Basically, this demo will select the model whose latency satisfy constraints to train.
To run this demo, first install nn-Meter from source code (currently we haven't released this package, so development installation is required).
.. code-block:: bash
python setup.py develop
Then run multi-trail SPOS demo:
.. code-block:: bash
python ${NNI_ROOT}/examples/nas/oneshot/spos/multi_trial.py
How the demo works
------------------
To support latency-aware NAS, you first need a `Strategy` that supports filtering the models by latency. We provide such a filter named `LatencyFilter` in NNI and initialize a `Random` strategy with the filter:
.. code-block:: python
simple_strategy = strategy.Random(model_filter=LatencyFilter(100)
``LatencyFilter`` will predict the models\' latency by using nn-Meter and filter out the models whose latency are larger than the threshold (i.e., ``100`` in this example).
You can also build your own strategies and filters to support more flexible NAS such as sorting the models according to latency.
Then, pass this strategy to ``RetiariiExperiment`` along with some additional arguments: ``parse_shape=True, dummy_input=dummy_input``:
.. code-block:: python
RetiariiExperiment(base_model, trainer, [], simple_strategy, True, dummy_input)
Here, ``parse_shape=True`` means extracting shape info from the torch model as it is required by nn-Meter to predict latency. ``dummy_input`` is required for tracing shape info.
......@@ -11,3 +11,4 @@ In multi-trial NAS, users need model evaluator to evaluate the performance of ea
Exploration Strategies <ExplorationStrategies>
Customize Exploration Strategies <WriteStrategy>
Execution Engines <ExecutionEngines>
Hardware-aware NAS <HardwareAwareNAS>
......@@ -51,7 +51,7 @@ extensions = [
]
# Add mock modules
autodoc_mock_imports = ['apex', 'nni_node', 'tensorrt', 'pycuda']
autodoc_mock_imports = ['apex', 'nni_node', 'tensorrt', 'pycuda', 'nn_meter']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
......
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import torch
import torch.nn as nn
import nni.retiarii.nn.pytorch as nn
class ShuffleNetBlock(nn.Module):
......@@ -27,7 +27,8 @@ class ShuffleNetBlock(nn.Module):
self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))
if stride == 2:
# FIXME: restore before merging into master
# remove if stride == 2 for torchscript
self.branch_proj = nn.Sequential(
# dw
nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad,
......@@ -76,8 +77,7 @@ class ShuffleNetBlock(nn.Module):
return result
def _channel_shuffle(self, x):
bs, num_channels, height, width = x.data.size()
assert (num_channels % 4 == 0)
bs, num_channels, height, width = x.size()
x = x.reshape(bs * num_channels // 2, 2, height * width)
x = x.permute(1, 0, 2)
x = x.reshape(2, -1, num_channels // 2, height, width)
......
# This file is to demo the usage of multi-trial NAS in the usage of SPOS search space.
import click
import nni.retiarii.evaluator.pytorch as pl
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
import torch
from nni.retiarii import serialize
from nni.retiarii.nn.pytorch import LayerChoice
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from torchvision import transforms
from torchvision.datasets import CIFAR10
from blocks import ShuffleNetBlock, ShuffleXceptionBlock
from nn_meter import get_default_config, load_latency_predictors
class ShuffleNetV2(nn.Module):
block_keys = [
'shufflenet_3x3',
'shufflenet_5x5',
'shufflenet_7x7',
'xception_3x3',
]
def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024, n_classes=1000, affine=False):
super().__init__()
assert input_size % 32 == 0
self.stage_blocks = [4, 4, 8, 4]
self.stage_channels = [64, 160, 320, 640]
self._parsed_flops = dict()
self._input_size = input_size
self._feature_map_size = input_size
self._first_conv_channels = first_conv_channels
self._last_conv_channels = last_conv_channels
self._n_classes = n_classes
self._affine = affine
# building first layer
self.first_conv = nn.Sequential(
nn.Conv2d(3, first_conv_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(first_conv_channels, affine=affine),
nn.ReLU(inplace=True),
)
self._feature_map_size //= 2
p_channels = first_conv_channels
features = []
for num_blocks, channels in zip(self.stage_blocks, self.stage_channels):
features.extend(self._make_blocks(num_blocks, p_channels, channels))
p_channels = channels
self.features = nn.Sequential(*features)
self.conv_last = nn.Sequential(
nn.Conv2d(p_channels, last_conv_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(last_conv_channels, affine=affine),
nn.ReLU(inplace=True),
)
self.globalpool = nn.AvgPool2d(self._feature_map_size)
self.dropout = nn.Dropout(0.1)
self.classifier = nn.Sequential(
nn.Linear(last_conv_channels, n_classes, bias=False),
)
self._initialize_weights()
def _make_blocks(self, blocks, in_channels, channels):
result = []
for i in range(blocks):
stride = 2 if i == 0 else 1
inp = in_channels if i == 0 else channels
oup = channels
base_mid_channels = channels // 2
mid_channels = int(base_mid_channels) # prepare for scale
choice_block = LayerChoice([
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
])
result.append(choice_block)
if stride == 2:
self._feature_map_size //= 2
return result
def forward(self, x):
bs = x.size(0)
x = self.first_conv(x)
x = self.features(x)
x = self.conv_last(x)
x = self.globalpool(x)
x = self.dropout(x)
x = x.contiguous().view(bs, -1)
x = self.classifier(x)
return x
def _initialize_weights(self):
# FIXME this won't work in base engine
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'first' in name:
torch.nn.init.normal_(m.weight, 0, 0.01)
else:
torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
if m.weight is not None:
torch.nn.init.constant_(m.weight, 1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
torch.nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.BatchNorm1d):
torch.nn.init.constant_(m.weight, 1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
torch.nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
class LatencyFilter:
def __init__(self, threshold, config=None, hardware='', reverse=False):
"""
Filter the models according to predcted latency.
Parameters
----------
threshold: `float`
the threshold of latency
config, hardware:
determine the targeted device
reverse: `bool`
if reverse is `False`, then the model returns `True` when `latency < threshold`,
else otherwisse
"""
default_config, default_hardware = get_default_config()
if config is None:
config = default_config
if not hardware:
hardware = default_hardware
self.predictors = load_latency_predictors(config, hardware)
self.threshold = threshold
def __call__(self, ir_model):
latency = self.predictors.predict(ir_model, 'nni')
return latency < self.threshold
@click.command()
@click.option('--port', default=8081, help='On which port the experiment is run.')
def _main(port):
base_model = ShuffleNetV2(32)
transf = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip()
]
normalize = [
transforms.ToTensor(),
transforms.Normalize([0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768])
]
train_dataset = serialize(CIFAR10, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize))
test_dataset = serialize(CIFAR10, 'data', train=False, transform=transforms.Compose(normalize))
trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=64),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=64),
max_epochs=2, gpus=1)
simple_strategy = strategy.Random(model_filter=LatencyFilter(100))
exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 2
exp_config.trial_gpu_number = 1
exp_config.training_service.use_active_gpu = False
exp_config.execution_engine = 'base'
exp_config.dummy_input = [1, 3, 32, 32]
exp.run(exp_config, port)
print('Exported models:')
for model in exp.export_top_models(formatter='dict'):
print(model)
if __name__ == '__main__':
_main()
......@@ -5,13 +5,17 @@ import re
import torch
from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, Placeholder
from ..graph import Graph, Model, Node, Edge
from ..nn.pytorch import InputChoice, Placeholder, LayerChoice
from ..operation import Cell, Operation
from ..serializer import get_init_parameters_or_fail
from ..utils import get_importable_name
from .op_types import MODULE_EXCEPT_LIST, OpTypeName
from .utils import _convert_name, build_full_name
from .utils import (
_convert_name, build_full_name, _without_shape_info,
_extract_info_from_trace_node, get_full_name_by_scope_name,
is_layerchoice_node, match_node, build_cand_name
)
class GraphConverter:
......@@ -305,7 +309,7 @@ class GraphConverter:
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name],
subgraph, sub_m_attrs = self._convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, ir_model)
else:
......@@ -339,7 +343,7 @@ class GraphConverter:
for each_name in list(reversed(module_name_space)):
submodule_obj = getattr(submodule_obj, each_name)
script_submodule = script_submodule._modules[each_name]
subgraph, sub_m_attrs = self.convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model)
subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
......@@ -566,29 +570,7 @@ class GraphConverter:
'accessor': module._accessor
}
def convert_module(self, script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
Returns
-------
Graph
the built graph ir from module, ```None``` means do not further parse the module
dict
the input arguments of this module
"""
def _convert_module(self, script_module, module, module_name, ir_model):
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name
......@@ -597,10 +579,18 @@ class GraphConverter:
pass # do nothing
elif original_type_name == OpTypeName.LayerChoice:
graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now
candidate_name_list = [f'layerchoice_{module.label}_{cand_name}' for cand_name in module.names]
for cand_name, cand in zip(candidate_name_list, module):
candidate_name_list = []
for cand_name in module.names:
cand = module[cand_name]
script_cand = script_module._modules[cand_name]
cand_name = build_cand_name(cand_name, module.label)
candidate_name_list.append(cand_name)
subgraph, attrs = self._convert_module(script_cand, cand, cand_name, ir_model)
if subgraph is not None:
graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs))
else:
cand_type = '__torch__.' + get_importable_name(cand.__class__)
graph.add_node(cand_name, cand_type, get_init_parameters_or_fail(cand))
graph.add_node(cand_name, cand_type, attrs)
graph._register()
return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
elif original_type_name == OpTypeName.InputChoice:
......@@ -654,8 +644,214 @@ class GraphConverter:
return ir_graph, {}
def convert_module(self, script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
Returns
-------
Graph
the built graph ir from module, ```None``` means do not further parse the module
dict
the input arguments of this module
"""
return self._convert_module(script_module, module, module_name, ir_model)
class GraphConverterWithShape(GraphConverter):
"""
Convert a pytorch model to nni ir along with input/output shape info.
Based ir acquired through `torch.jit.script`
and shape info acquired through `torch.jit.trace`.
Known issues
------------
1. `InputChoice` and `ValueChoice` not supported yet.
2. Currently random inputs are feeded while tracing layerchoice.
If forward path of candidates depends on input data, then wrong path will be traced.
This will result in incomplete shape info.
"""
def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
module.eval()
ir_graph, attrs = self._convert_module(script_module, module, module_name, ir_model)
self.remove_dummy_nodes(ir_model)
self._initialize_parameters(ir_model)
self._trace_module(module, module_name, ir_model, dummy_input)
return ir_graph, attrs
def _initialize_parameters(self, ir_model: 'Model'):
for ir_node in ir_model.get_nodes():
if ir_node.operation.parameters is None:
ir_node.operation.parameters = {}
ir_node.operation.parameters.setdefault('input_shape', [])
ir_node.operation.parameters.setdefault('output_shape', [])
def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
# First, trace the whole graph
tm_graph = self._trace(module, dummy_input)
for node in tm_graph.nodes():
parameters = _extract_info_from_trace_node(node)
# '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
ir_node = match_node(ir_model, node, module_name)
if ir_node is not None:
ir_node.operation.parameters.update(parameters)
self.propagate_shape(ir_model)
# trace each layerchoice
for name, submodule in module.named_modules():
# TODO: support InputChoice and ValueChioce
if isinstance(submodule, LayerChoice):
full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
lc_node = ir_model.get_node_by_name(full_name)
for cand_name in submodule.names:
cand = submodule[cand_name]
cand_name = build_cand_name(cand_name, submodule.label)
# TODO: Feed the exact input tensor if user provides input,
# in case the path changes according to input data.
lc_inputs = [torch.randn(shape) for shape in lc_node.operation.parameters['input_shape']]
self._trace_module(cand, cand_name, ir_model, lc_inputs)
def propagate_shape(self, ir_model: 'Model'):
def propagate_shape_for_graph(graph: 'Graph'):
if graph == ir_model.root_graph:
return
graph_node = ir_model.get_node_by_name(graph.name)
if not _without_shape_info(graph_node):
return
if is_layerchoice_node(graph_node):
cand_name = graph_node.operation.parameters['candidates'][0]
cand_node = ir_model.get_node_by_name(cand_name)
if _without_shape_info(cand_node):
propagate_shape_for_graph(ir_model.graphs[cand_name])
graph_node.operation.parameters['input_shape'] = cand_node.operation.parameters['input_shape']
graph_node.operation.parameters['output_shape'] = cand_node.operation.parameters['output_shape']
else:
input_shape = [[]] * len(graph.input_node.operation.io_names or [])
output_shape = [[]] * len(graph.output_node.operation.io_names or [])
for edge in graph.input_node.outgoing_edges:
node = edge.tail
if _without_shape_info(node):
if node.name in ir_model.graphs:
propagate_shape_for_graph(ir_model.graphs[node.name])
if node.operation.parameters['input_shape']:
input_shape[edge.head_slot or 0] = node.operation.parameters['input_shape'][edge.tail_slot or 0]
graph_node.operation.parameters['input_shape'] = input_shape
for edge in graph.output_node.incoming_edges:
node = edge.head
if _without_shape_info(node):
if node.name in ir_model.graphs:
propagate_shape_for_graph(ir_model.graphs[node.name])
if node.operation.parameters['output_shape']:
output_shape[edge.tail_slot or 0] = node.operation.parameters['output_shape'][edge.head_slot or 0]
graph_node.operation.parameters['output_shape'] = output_shape
propagate_shape_for_graph(graph_node.graph)
# propagate from node to graph
for node in ir_model.get_nodes():
propagate_shape_for_graph(node.graph)
def flatten(self, ir_model: 'Model'):
"""
Flatten the subgraph into root graph.
"""
def _flatten(graph: 'Graph'):
"""
flatten this graph
"""
model = graph.model
node_to_remove = []
for node in graph.hidden_nodes:
node_graph = model.graphs.get(node.name)
if node_graph is not None:
_flatten(node_graph)
# flatten node graph into this graph
id_to_new_node = {}
for node_graph_node in node_graph.hidden_nodes:
new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
new_node.update_label(node_graph_node.label)
new_node._register()
id_to_new_node[new_node.id] = new_node
# reconnect node edges
for in_edge in node.incoming_edges:
graph.del_edge(in_edge)
for input_node_edge in node_graph.input_node.outgoing_edges:
if input_node_edge.head_slot == in_edge.tail_slot:
graph.add_edge(
head=(in_edge.head, in_edge.head_slot),
tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))
for out_edge in node.outgoing_edges:
graph.del_edge(out_edge)
for output_node_edge in node_graph.output_node.incoming_edges:
if output_node_edge.head_slot == out_edge.tail_slot:
try:
graph.add_edge(
head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
tail=(out_edge.tail, out_edge.tail_slot))
except:
import pdb; pdb.set_trace()
for edge in node_graph.edges:
if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
continue
new_head = id_to_new_node[edge.head.id]
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
node_to_remove.append(node)
del model.graphs[node.name]
for node in node_to_remove:
node.remove()
_flatten(ir_model.root_graph)
# remove subgraphs
ir_model.graphs = {ir_model._root_graph_name: ir_model.root_graph}
def _trace(self, module, dummy_input):
traced_module = torch.jit.trace(module, dummy_input)
torch._C._jit_pass_inline(traced_module.graph)
return traced_module.graph
def remove_dummy_nodes(self, ir_model: 'Model'):
# remove identity nodes
for node in ir_model.get_nodes_by_type('noop_identity'):
graph = node.graph
for in_edge in node.incoming_edges:
for out_edge in node.outgoing_edges:
if in_edge.tail_slot == out_edge.head_slot:
graph.add_edge(head=(in_edge.head, in_edge.head_slot), tail=(out_edge.tail, out_edge.tail_slot))
graph.del_edge(in_edge)
graph.del_edge(out_edge)
break
node.remove()
def convert_to_graph(script_module, module):
def convert_to_graph(script_module, module, converter=None, **kwargs):
"""
Convert module to our graph ir, i.e., build a ```Model``` type
......@@ -665,6 +861,10 @@ def convert_to_graph(script_module, module):
the script module obtained with torch.jit.script
module : nn.Module
the targeted module instance
converter : `TorchConverter`
default `GraphConverter` is used
kwargs:
will be passed to `converter.convert_module()`
Returns
-------
......@@ -674,6 +874,8 @@ def convert_to_graph(script_module, module):
model = Model(_internal=True)
module_name = '_model'
GraphConverter().convert_module(script_module, module, module_name, model)
if converter is None:
converter = GraphConverter()
converter.convert_module(script_module, module, module_name, model, **kwargs)
return model
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..operation import Cell
from ..graph import Model, Node
def build_full_name(prefix, name, seq=None):
if isinstance(name, list):
name = '__'.join(name)
......@@ -10,8 +14,98 @@ def build_full_name(prefix, name, seq=None):
return '{}__{}{}'.format(prefix, name, str(seq))
def build_cand_name(name, label):
return f'layerchoice_{label}_{name}'
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code
"""
return name.replace('.', '__')
def _extract_info_from_trace_node(trace_node):
"""
Extract parameters from a trace node.
Parameters
----------
trace_node: torch._C.Value
"""
input_shape = []
output_shape = []
inputs = list(trace_node.inputs())
# cat input tensors are in a strange place
if trace_node.kind() == 'aten::cat':
input_shape = [input.type().sizes() for input in inputs[0].node().inputs()]
else:
for _input in inputs:
input_type = _input.type()
if input_type.kind() == 'TensorType':
shape = input_type.sizes()
if shape:
input_shape.append(shape)
for _output in trace_node.outputs():
output_type = _output.type()
if output_type.kind() == 'TensorType':
shape = output_type.sizes()
if shape:
output_shape.append(shape)
parameters = {
'input_shape': input_shape,
'output_shape': output_shape,
}
if trace_node.kind() == 'aten::cat':
parameters['dim'] = inputs[1].toIValue()
return parameters
def is_layerchoice_node(ir_node: Node):
if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice':
return True
else:
return False
def get_full_name_by_scope_name(ir_model: Model, scope_names, prefix=''):
full_name = prefix
for last_scope in range(len(scope_names)):
ir_node = ir_model.get_node_by_name(full_name)
# check if it's layerchoice
if is_layerchoice_node(ir_node):
full_name = f'layerchoice_{ir_node.operation.parameters["label"]}_{scope_names[last_scope]}'
else:
full_name = build_full_name(full_name, scope_names[last_scope])
return full_name
def match_node(ir_model: Model, torch_node, prefix=''):
"""
Match the corresponding node of a torch._C.Value
"""
scope_names = torch_node.scopeName().split('/')[-1].split('.')[1:]
full_name = get_full_name_by_scope_name(ir_model, scope_names, prefix)
# handle the case when node is not nn.Module, but directly used in forward()
# Because name can't be directly matched, so I use a hacky way.
# I match the first unshaped node of that kind
graph = ir_model.graphs.get(full_name)
if graph is not None:
for node in graph.get_nodes_by_type(torch_node.kind()):
if not node.operation.parameters['input_shape']:
return node
return None
else:
return ir_model.get_node_by_name(full_name)
def _without_shape_info(node: Node):
return not node.operation.parameters['input_shape'] and not node.operation.parameters['output_shape']
......@@ -29,6 +29,7 @@ from nni.common.device import GPUDevice
from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..converter.graph_gen import GraphConverterWithShape
from ..execution import list_models, set_execution_engine
from ..execution.python import get_mutation_dict
from ..graph import Model, Evaluator
......@@ -59,6 +60,9 @@ class RetiariiExeConfig(ConfigBase):
training_service: TrainingServiceConfig
execution_engine: str = 'py'
# input used in GraphConverterWithShape. Currently support shape tuple only.
dummy_input: Optional[List[int]] = None
def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
......@@ -107,7 +111,7 @@ _validation_rules = {
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_input=None):
# TODO: this logic might need to be refactored into execution engine
if full_ir:
try:
......@@ -115,6 +119,12 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
if dummy_input is not None:
# FIXME: this is a workaround as full tensor is not supported in configs
dummy_input = torch.randn(*dummy_input)
converter = GraphConverterWithShape()
base_model_ir = convert_to_graph(script_module, base_model, converter, dummy_input=dummy_input)
else:
base_model_ir = convert_to_graph(script_module, base_model)
# handle inline mutations
mutators = process_inline_mutation(base_model_ir)
......@@ -172,7 +182,8 @@ class RetiariiExperiment(Experiment):
def _start_strategy(self):
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py')
self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py',
dummy_input=self.config.dummy_input)
_logger.info('Start strategy...')
self.strategy.run(base_model_ir, self.applied_mutators)
......
......@@ -307,9 +307,9 @@ class Graph:
@overload
def add_node(self, name: str, operation: Operation) -> 'Node': ...
@overload
def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...
def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
def add_node(self, name, operation_or_type, parameters={}):
def add_node(self, name, operation_or_type, parameters=None):
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
......@@ -319,9 +319,9 @@ class Graph:
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
def insert_node_on_edge(self, edge, name, operation_or_type, parameters={}) -> 'Node':
def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
......@@ -562,9 +562,9 @@ class Node:
@overload
def update_operation(self, operation: Operation) -> None: ...
@overload
def update_operation(self, type_name: str, parameters: Dict[str, Any] = {}) -> None: ...
def update_operation(self, type_name: str, parameters: Dict[str, Any] = None) -> None: ...
def update_operation(self, operation_or_type, parameters={}):
def update_operation(self, operation_or_type, parameters=None):
if isinstance(operation_or_type, Operation):
self.operation = operation_or_type
else:
......
......@@ -98,6 +98,7 @@ class LayerChoice(nn.Module):
self.names.append(str(i))
else:
raise TypeError("Unsupported candidates type: {}".format(type(candidates)))
self._first_module = self._modules[self.names[0]] # to make the dummy forward meaningful
@property
def key(self):
......@@ -151,7 +152,7 @@ class LayerChoice(nn.Module):
def forward(self, x):
warnings.warn('You should not run forward of this module directly.')
return x
return self._first_module(x)
def __repr__(self):
return f'LayerChoice({self.candidates}, label={repr(self.label)})'
......
......@@ -52,7 +52,9 @@ class Operation:
return True
@staticmethod
def new(type_name: str, parameters: Dict[str, Any] = {}, cell_name: str = None) -> 'Operation':
def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None) -> 'Operation':
if parameters is None:
parameters = {}
if type_name == '_cell':
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
return Cell(cell_name, parameters)
......@@ -199,9 +201,11 @@ class Cell(PyTorchOperation):
No real usage. Exists for compatibility with base class.
"""
def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}):
def __init__(self, cell_name: str, parameters: Dict[str, Any] = None):
self.type = '_cell'
self.cell_name = cell_name
if parameters is None:
parameters = {}
self.parameters = parameters
def _to_class_name(self):
......
......@@ -10,7 +10,7 @@ from typing import Any, Dict, List
from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model
from .utils import dry_run_for_search_space, get_targeted_model, filter_model
_logger = logging.getLogger(__name__)
......@@ -84,15 +84,18 @@ class Random(BaseStrategy):
Do not dry run to get the full search space. Used when the search space has variational size or candidates. Default: false.
dedup : bool
Do not try the same configuration twice. When variational is true, deduplication is not supported. Default: true.
model_filter: Callable[[Model], bool]
Feed the model and return a bool. This will filter the models in search space and select which to submit.
"""
def __init__(self, variational=False, dedup=True):
def __init__(self, variational=False, dedup=True, model_filter=None):
self.variational = variational
self.dedup = dedup
if variational and dedup:
raise ValueError('Dedup is not supported in variational mode.')
self.random_sampler = _RandomSampler()
self._polling_interval = 2.
self.filter = model_filter
def run(self, base_model, applied_mutators):
if self.variational:
......@@ -107,6 +110,7 @@ class Random(BaseStrategy):
for mutator in applied_mutators:
model = mutator.apply(model)
_logger.debug('New model created. Applied mutators are: %s', str(applied_mutators))
if filter_model(self.filter, model):
submit_models(model)
elif budget_exhausted():
break
......@@ -122,6 +126,8 @@ class Random(BaseStrategy):
return
time.sleep(self._polling_interval)
try:
submit_models(get_targeted_model(base_model, applied_mutators, sample))
model = get_targeted_model(base_model, applied_mutators, sample)
if filter_model(self.filter, model):
submit_models(model)
except InvalidMutation as e:
_logger.warning(f'Invalid mutation: {e}. Skip.')
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import logging
from typing import Dict, Any, List
from ..graph import Model
from ..mutator import Mutator, Sampler
_logger = logging.getLogger(__name__)
class _FixedSampler(Sampler):
def __init__(self, sample):
......@@ -30,3 +34,16 @@ def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict)
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
return model
def filter_model(model_filter, ir_model):
if model_filter is not None:
_logger.debug(f'Check if model satisfies constraints.')
if model_filter(ir_model):
_logger.debug(f'Model satisfied. Submit the model.')
return True
else:
_logger.debug(f'Model unsatisfied. Discard the model.')
return False
else:
return True
import torch
from nni.retiarii.converter.graph_gen import convert_to_graph, GraphConverterWithShape
class ConvertMixin:
@staticmethod
def _convert_model(model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
return model_ir
class ConvertWithShapeMixin:
@staticmethod
def _convert_model(model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), dummy_input=input)
return model_ir
......@@ -13,9 +13,10 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class MnistNet(nn.Module):
def __init__(self):
super(MnistNet, self).__init__()
......@@ -48,7 +49,7 @@ class Linear(nn.Module):
out = self.linear(input.view(size[0] * size[1], -1))
return out.view(size[0], size[1], -1)
class TestConvert(unittest.TestCase):
class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
......@@ -61,8 +62,7 @@ class TestConvert(unittest.TestCase):
return result
def checkExportImport(self, model, input):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
exec_vars = {}
......@@ -579,3 +579,6 @@ class TestConvert(unittest.TestCase):
self.checkExportImport(model, (x,))
finally:
remove_inject_pytorch_nn()
class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
pass
......@@ -9,12 +9,13 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import basic_unit
from nni.retiarii.converter import convert_to_graph
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
from nni.retiarii.codegen import model_to_pytorch_script
# following pytorch v1.7.1
class TestConvert(unittest.TestCase):
class TestConvert(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
......@@ -27,8 +28,7 @@ class TestConvert(unittest.TestCase):
return result
def checkExportImport(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
......@@ -280,3 +280,7 @@ class TestConvert(unittest.TestCase):
out1 = x.ceil()
return out1
self.checkExportImport(SimpleOp(), (torch.randn(4), ))
class TestConvertWithShape(TestConvert, ConvertWithShapeMixin):
pass
......@@ -10,11 +10,12 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestModels(unittest.TestCase):
class TestModels(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
......@@ -27,8 +28,7 @@ class TestModels(unittest.TestCase):
return result
def run_test(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
......@@ -89,3 +89,6 @@ class TestModels(unittest.TestCase):
model = Net(4)
x = torch.rand((1, 16), dtype=torch.float)
self.run_test(model, ([x], ))
class TestModelsWithShape(TestModels, ConvertWithShapeMixin):
pass
......@@ -15,13 +15,14 @@ import torch.nn.functional as F
import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
# following pytorch v1.7.1
class TestOperators(unittest.TestCase):
class TestOperators(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
......@@ -34,8 +35,7 @@ class TestOperators(unittest.TestCase):
return result
def checkExportImport(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
#print(model_code)
......@@ -1386,3 +1386,6 @@ class TestOperators(unittest.TestCase):
x = torch.randn(20, 5, 10, 10)
self.checkExportImport(SimpleOp(), (x, ))
class TestOperatorsWithShape(TestOperators, ConvertWithShapeMixin):
pass
......@@ -15,11 +15,12 @@ import torchvision
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import serialize
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from .convert_mixin import ConvertMixin, ConvertWithShapeMixin
class TestPytorch(unittest.TestCase):
class TestPytorch(unittest.TestCase, ConvertMixin):
@staticmethod
def _match_state_dict(current_values, expected_format):
result = {}
......@@ -32,8 +33,7 @@ class TestPytorch(unittest.TestCase):
return result
def run_test(self, model, input, check_value=True):
script_module = torch.jit.script(model)
model_ir = convert_to_graph(script_module, model)
model_ir = self._convert_model(model, input)
model_code = model_to_pytorch_script(model_ir)
print(model_code)
......@@ -1231,3 +1231,6 @@ class TestPytorch(unittest.TestCase):
x = torch.randn(5, 3, 2)
self.run_test(SizeModel(10, 5), (x, ))
class TestPytorchWithShape(TestPytorch, ConvertWithShapeMixin):
pass
import unittest
import torch
import nni.retiarii.nn.pytorch as nn
from .convert_mixin import ConvertWithShapeMixin
class TestShape(unittest.TestCase, ConvertWithShapeMixin):
def test_simple_convnet(self):
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 1, 3)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2)
def forward(self, x):
return self.pool(self.relu(self.conv(x)))
net = ConvNet()
input = torch.randn((1, 3, 224, 224))
model_ir = self._convert_model(net, input)
conv_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')[0]
relu_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.activation.ReLU')[0]
pool_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.pooling.MaxPool2d')[0]
self.assertEqual(conv_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
self.assertEqual(conv_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
self.assertEqual(relu_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
self.assertEqual(relu_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
self.assertEqual(pool_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
self.assertEqual(pool_node.operation.parameters.get('output_shape'), [[1, 1, 111, 111]])
def test_nested_module(self):
class ConvRelu(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 1, 3)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = ConvRelu()
self.pool = nn.MaxPool2d(kernel_size=2)
def forward(self, x):
return self.pool(self.conv(x))
net = ConvNet()
input = torch.randn((1, 3, 224, 224))
model_ir = self._convert_model(net, input)
# check if shape propagation works
cell_node = model_ir.get_nodes_by_type('_cell')[0]
self.assertEqual(cell_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
self.assertEqual(cell_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
def test_layerchoice(self):
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.LayerChoice([
nn.Conv2d(3, 1, 3),
nn.Conv2d(3, 1, 5, padding=1),
])
self.pool = nn.MaxPool2d(kernel_size=2)
def forward(self, x):
return self.pool(self.conv(x))
net = ConvNet()
input = torch.randn((1, 3, 224, 224))
model_ir = self._convert_model(net, input)
# check shape info of each candidates
conv_nodes = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')
self.assertEqual(conv_nodes[0].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
self.assertEqual(conv_nodes[1].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment