"builder/vscode:/vscode.git/clone" did not exist on "96f1b8ef751872cfe542e2a762f9b6fab7a69659"
Unverified commit 99aa8226, authored by kalineid, committed by GitHub

[Retiarii]: Add info required by nn-meter to graph ir (#3910)



* Fix mutable default

* LayerChoice.forward now runs the first candidate by default to support trace (#3910)

* New GraphConverter to parse shape info required by nn-meter (#3910)

* Support model filter in Random strategy

* Support latency aware search in SPOS multi-trial example

* Fix for review (#3910)

* Add doc for hardware-aware NAS

* Fix lint python & Add nn_meter to sphinx mock

* Add comments

* Move LatencyFilter to examples

* Move example inputs into configs

* Support nested layer choice
Co-authored-by: default avatarJianyu Wei <v-wjiany@microsoft.com>
Co-authored-by: default avatarkalineid <nnob@mail.ustc.edu.cn>
Co-authored-by: default avatarYuge Zhang <scottyugochang@gmail.com>
Co-authored-by: default avatarYuge Zhang <Yuge.Zhang@microsoft.com>
parent 6af99c55
Hardware-aware NAS
==================

.. contents::

End-to-End Multi-trial SPOS Demo
--------------------------------

This demo selects and trains the models whose predicted latency satisfies the constraint.

To run this demo, first install nn-Meter from source code (currently this package has not been released, so a development installation is required):

.. code-block:: bash

    python setup.py develop

Then run the multi-trial SPOS demo:

.. code-block:: bash

    python ${NNI_ROOT}/examples/nas/oneshot/spos/multi_trial.py

How the demo works
------------------

To support latency-aware NAS, you first need a ``Strategy`` that supports filtering models by latency. We provide such a filter named ``LatencyFilter`` in NNI and initialize a ``Random`` strategy with it:

.. code-block:: python

    simple_strategy = strategy.Random(model_filter=LatencyFilter(100))

``LatencyFilter`` predicts each model's latency with nn-Meter and filters out the models whose latency is larger than the threshold (i.e., ``100`` in this example).

You can also build your own strategies and filters to support more flexible NAS, such as sorting the models according to latency; see the sketch below.
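
For instance, a custom filter that accepts only models within a latency band might look like the following. This is a minimal sketch: ``LatencyBandFilter`` and its band bounds are hypothetical, while the nn-Meter calls mirror the ``LatencyFilter`` shipped with the SPOS example.

.. code-block:: python

    from nn_meter import get_default_config, load_latency_predictors

    class LatencyBandFilter:
        """Hypothetical filter: accept models whose predicted latency
        lies within [low, high]."""

        def __init__(self, low, high):
            # Use nn-Meter's default device config and hardware,
            # the same defaults LatencyFilter falls back to.
            config, hardware = get_default_config()
            self.predictors = load_latency_predictors(config, hardware)
            self.low, self.high = low, high

        def __call__(self, ir_model):
            # The strategy feeds the graph IR model to the filter and
            # expects a bool back: True means "submit this model".
            latency = self.predictors.predict(ir_model, 'nni')
            return self.low <= latency <= self.high

    simple_strategy = strategy.Random(model_filter=LatencyBandFilter(20, 100))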

Then, pass this strategy to ``RetiariiExperiment`` along with two additional arguments, ``parse_shape=True`` and ``example_inputs=example_inputs``:

.. code-block:: python

    RetiariiExperiment(base_model, trainer, [], simple_strategy, parse_shape=True, example_inputs=example_inputs)

Here, ``parse_shape=True`` means extracting shape information from the torch model, which nn-Meter requires to predict latency, and ``example_inputs`` provides the dummy input needed for tracing the shapes. A condensed end-to-end sketch follows.
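
For reference, the sketch below shows how the demo script wires these pieces together. It is condensed from ``examples/nas/oneshot/spos/multi_trial.py``; ``base_model``, ``trainer`` and ``LatencyFilter`` are defined earlier in that script, and the ``'local'`` training service platform is an assumption for illustration.

.. code-block:: python

    simple_strategy = strategy.Random(model_filter=LatencyFilter(100))
    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = False
    exp_config.execution_engine = 'base'
    # Shape of the dummy input used by GraphConverterWithShape for tracing;
    # the config currently accepts a shape tuple only, not a full tensor.
    exp_config.example_inputs = [1, 3, 32, 32]

    exp.run(exp_config, 8081)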
@@ -11,3 +11,4 @@ In multi-trial NAS, users need model evaluator to evaluate the performance of ea
     Exploration Strategies <ExplorationStrategies>
     Customize Exploration Strategies <WriteStrategy>
     Execution Engines <ExecutionEngines>
+    Hardware-aware NAS <HardwareAwareNAS>
@@ -51,7 +51,7 @@ extensions = [
 ]

 # Add mock modules
-autodoc_mock_imports = ['apex', 'nni_node', 'tensorrt', 'pycuda']
+autodoc_mock_imports = ['apex', 'nni_node', 'tensorrt', 'pycuda', 'nn_meter']

 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
...
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 import torch
-import torch.nn as nn
+import nni.retiarii.nn.pytorch as nn


 class ShuffleNetBlock(nn.Module):
@@ -78,7 +78,6 @@ class ShuffleNetBlock(nn.Module):
     def _channel_shuffle(self, x):
         bs, num_channels, height, width = x.size()
-        assert (num_channels % 4 == 0)
         x = x.reshape(bs * num_channels // 2, 2, height * width)
         x = x.permute(1, 0, 2)
         x = x.reshape(2, -1, num_channels // 2, height, width)
...
+# This file demos the usage of multi-trial NAS with the SPOS search space.
 import click

 import nni.retiarii.evaluator.pytorch as pl
 import nni.retiarii.nn.pytorch as nn
@@ -11,6 +13,8 @@ from torchvision.datasets import CIFAR10
 from blocks import ShuffleNetBlock, ShuffleXceptionBlock
+from nn_meter import get_default_config, load_latency_predictors


 class ShuffleNetV2(nn.Module):
     block_keys = [
@@ -73,10 +77,10 @@ class ShuffleNetV2(nn.Module):
         base_mid_channels = channels // 2
         mid_channels = int(base_mid_channels)  # prepare for scale
         choice_block = LayerChoice([
-            serialize(ShuffleNetBlock, inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
-            serialize(ShuffleNetBlock, inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
-            serialize(ShuffleNetBlock, inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
-            serialize(ShuffleXceptionBlock, inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
+            ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
+            ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
+            ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
+            ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
         ])
         result.append(choice_block)
@@ -123,6 +127,35 @@ class ShuffleNetV2(nn.Module):
                 torch.nn.init.constant_(m.bias, 0)
+class LatencyFilter:
+    def __init__(self, threshold, config=None, hardware='', reverse=False):
+        """
+        Filter the models according to predicted latency.
+
+        Parameters
+        ----------
+        threshold: `float`
+            the latency threshold
+        config, hardware:
+            determine the targeted device
+        reverse: `bool`
+            if reverse is `False`, the filter returns `True` when `latency < threshold`,
+            and the opposite otherwise
+        """
+        default_config, default_hardware = get_default_config()
+        if config is None:
+            config = default_config
+        if not hardware:
+            hardware = default_hardware
+        self.predictors = load_latency_predictors(config, hardware)
+        self.threshold = threshold
+
+    def __call__(self, ir_model):
+        latency = self.predictors.predict(ir_model, 'nni')
+        return latency < self.threshold
 @click.command()
 @click.option('--port', default=8081, help='On which port the experiment is run.')
 def _main(port):
@@ -142,7 +175,7 @@ def _main(port):
                       val_dataloaders=pl.DataLoader(test_dataset, batch_size=64),
                       max_epochs=2, gpus=1)

-    simple_strategy = strategy.Random()
+    simple_strategy = strategy.Random(model_filter=LatencyFilter(100))

     exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
@@ -152,6 +185,7 @@ def _main(port):
     exp_config.trial_gpu_number = 1
     exp_config.training_service.use_active_gpu = False
     exp_config.execution_engine = 'base'
+    exp_config.example_inputs = [1, 3, 32, 32]

     exp.run(exp_config, port)
...
@@ -5,13 +5,17 @@ import re

 import torch

-from ..graph import Graph, Model, Node
-from ..nn.pytorch import InputChoice, Placeholder
+from ..graph import Graph, Model, Node, Edge
+from ..nn.pytorch import InputChoice, Placeholder, LayerChoice
 from ..operation import Cell, Operation
 from ..serializer import get_init_parameters_or_fail
 from ..utils import get_importable_name
 from .op_types import MODULE_EXCEPT_LIST, OpTypeName
-from .utils import _convert_name, build_full_name
+from .utils import (
+    _convert_name, build_full_name, _without_shape_info,
+    _extract_info_from_trace_node, get_full_name_by_scope_name,
+    is_layerchoice_node, match_node, build_cand_name
+)


 class GraphConverter:
@@ -305,9 +309,9 @@ class GraphConverter:
                 submodule_full_name = build_full_name(module_name, submodule_name)
                 submodule_obj = getattr(module, submodule_name)
-                subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name],
-                                                            submodule_obj,
-                                                            submodule_full_name, ir_model)
+                subgraph, sub_m_attrs = self._convert_module(script_module._modules[submodule_name],
+                                                             submodule_obj,
+                                                             submodule_full_name, ir_model)
             else:
                 # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
                 # %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
@@ -339,7 +343,7 @@ class GraphConverter:
                 for each_name in list(reversed(module_name_space)):
                     submodule_obj = getattr(submodule_obj, each_name)
                     script_submodule = script_submodule._modules[each_name]
-                subgraph, sub_m_attrs = self.convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model)
+                subgraph, sub_m_attrs = self._convert_module(script_submodule, submodule_obj, submodule_full_name, ir_model)
             else:
                 raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
@@ -566,29 +570,7 @@ class GraphConverter:
             'accessor': module._accessor
         }

-    def convert_module(self, script_module, module, module_name, ir_model):
-        """
-        Convert a module to its graph ir (i.e., Graph) along with its input arguments
-
-        Parameters
-        ----------
-        script_module : torch.jit.RecursiveScriptModule
-            the script module of ``module`` obtained with torch.jit.script
-        module : nn.Module
-            the targeted module instance
-        module_name : str
-            the constructed name space of ``module``
-        ir_model : Model
-            the whole graph ir
-
-        Returns
-        -------
-        Graph
-            the built graph ir from module, ``None`` means do not further parse the module
-        dict
-            the input arguments of this module
-        """
+    def _convert_module(self, script_module, module, module_name, ir_model):
         # NOTE: nested LayerChoice is not supported yet, i.e., a candidate module
         # that itself contains LayerChoice, InputChoice or ValueChoice
         original_type_name = script_module.original_name
@@ -597,10 +579,18 @@ class GraphConverter:
             pass  # do nothing
         elif original_type_name == OpTypeName.LayerChoice:
             graph = Graph(ir_model, -100, module_name, _internal=True)  # graph_id is not used now
-            candidate_name_list = [f'layerchoice_{module.label}_{cand_name}' for cand_name in module.names]
-            for cand_name, cand in zip(candidate_name_list, module):
-                cand_type = '__torch__.' + get_importable_name(cand.__class__)
-                graph.add_node(cand_name, cand_type, get_init_parameters_or_fail(cand))
+            candidate_name_list = []
+            for cand_name in module.names:
+                cand = module[cand_name]
+                script_cand = script_module._modules[cand_name]
+                cand_name = build_cand_name(cand_name, module.label)
+                candidate_name_list.append(cand_name)
+                subgraph, attrs = self._convert_module(script_cand, cand, cand_name, ir_model)
+                if subgraph is not None:
+                    graph.add_node(subgraph.name, Cell(cell_name=subgraph.name, parameters=attrs))
+                else:
+                    cand_type = '__torch__.' + get_importable_name(cand.__class__)
+                    graph.add_node(cand_name, cand_type, attrs)
             graph._register()
             return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
         elif original_type_name == OpTypeName.InputChoice:
@@ -654,8 +644,214 @@ class GraphConverter:
         return ir_graph, {}

+    def convert_module(self, script_module, module, module_name, ir_model):
+        """
+        Convert a module to its graph ir (i.e., Graph) along with its input arguments
+
+        Parameters
+        ----------
+        script_module : torch.jit.RecursiveScriptModule
+            the script module of ``module`` obtained with torch.jit.script
+        module : nn.Module
+            the targeted module instance
+        module_name : str
+            the constructed name space of ``module``
+        ir_model : Model
+            the whole graph ir
+
+        Returns
+        -------
+        Graph
+            the built graph ir from module, ``None`` means do not further parse the module
+        dict
+            the input arguments of this module
+        """
+        return self._convert_module(script_module, module, module_name, ir_model)
+
+class GraphConverterWithShape(GraphConverter):
+    """
+    Convert a pytorch model to nni graph ir along with input/output shape info.
+    The base ir is acquired through `torch.jit.script`,
+    and the shape info is acquired through `torch.jit.trace`.
+
+    Known issues
+    ------------
+    1. `InputChoice` and `ValueChoice` are not supported yet.
+    2. Currently random inputs are fed while tracing a layerchoice.
+       If the forward path of a candidate depends on the input data, the wrong path will be traced,
+       which results in incomplete shape info.
+    """
+    def convert_module(self, script_module, module, module_name, ir_model, example_inputs):
+        module.eval()
+
+        ir_graph, attrs = self._convert_module(script_module, module, module_name, ir_model)
+        self.remove_dummy_nodes(ir_model)
+        self._initialize_parameters(ir_model)
+        self._trace_module(module, module_name, ir_model, example_inputs)
+        return ir_graph, attrs
+    def _initialize_parameters(self, ir_model: 'Model'):
+        for ir_node in ir_model.get_nodes():
+            if ir_node.operation.parameters is None:
+                ir_node.operation.parameters = {}
+            ir_node.operation.parameters.setdefault('input_shape', [])
+            ir_node.operation.parameters.setdefault('output_shape', [])
+
+    def _trace_module(self, module, module_name, ir_model: 'Model', example_inputs):
+        # first, trace the whole graph
+        tm_graph = self._trace(module, example_inputs)
+
+        for node in tm_graph.nodes():
+            parameters = _extract_info_from_trace_node(node)
+            # scope name example: '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
+            ir_node = match_node(ir_model, node, module_name)
+            if ir_node is not None:
+                ir_node.operation.parameters.update(parameters)
+
+        self.propagate_shape(ir_model)
+
+        # trace each layerchoice
+        for name, submodule in module.named_modules():
+            # TODO: support InputChoice and ValueChoice
+            if isinstance(submodule, LayerChoice):
+                full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
+                lc_node = ir_model.get_node_by_name(full_name)
+
+                for cand_name in submodule.names:
+                    cand = submodule[cand_name]
+                    cand_name = build_cand_name(cand_name, submodule.label)
+                    # TODO: feed the exact input tensor if the user provides one,
+                    # in case the traced path changes according to the input data
+                    lc_inputs = [torch.randn(shape) for shape in lc_node.operation.parameters['input_shape']]
+                    self._trace_module(cand, cand_name, ir_model, lc_inputs)
+    def propagate_shape(self, ir_model: 'Model'):
+
+        def propagate_shape_for_graph(graph: 'Graph'):
+            if graph == ir_model.root_graph:
+                return
+
+            graph_node = ir_model.get_node_by_name(graph.name)
+            if not _without_shape_info(graph_node):
+                return
+
+            if is_layerchoice_node(graph_node):
+                cand_name = graph_node.operation.parameters['candidates'][0]
+                cand_node = ir_model.get_node_by_name(cand_name)
+                if _without_shape_info(cand_node):
+                    propagate_shape_for_graph(ir_model.graphs[cand_name])
+                graph_node.operation.parameters['input_shape'] = cand_node.operation.parameters['input_shape']
+                graph_node.operation.parameters['output_shape'] = cand_node.operation.parameters['output_shape']
+            else:
+                input_shape = [[]] * len(graph.input_node.operation.io_names or [])
+                output_shape = [[]] * len(graph.output_node.operation.io_names or [])
+                for edge in graph.input_node.outgoing_edges:
+                    node = edge.tail
+                    if _without_shape_info(node):
+                        if node.name in ir_model.graphs:
+                            propagate_shape_for_graph(ir_model.graphs[node.name])
+                    if node.operation.parameters['input_shape']:
+                        input_shape[edge.head_slot or 0] = node.operation.parameters['input_shape'][edge.tail_slot or 0]
+                graph_node.operation.parameters['input_shape'] = input_shape
+                for edge in graph.output_node.incoming_edges:
+                    node = edge.head
+                    if _without_shape_info(node):
+                        if node.name in ir_model.graphs:
+                            propagate_shape_for_graph(ir_model.graphs[node.name])
+                    if node.operation.parameters['output_shape']:
+                        output_shape[edge.tail_slot or 0] = node.operation.parameters['output_shape'][edge.head_slot or 0]
+                graph_node.operation.parameters['output_shape'] = output_shape
+
+            propagate_shape_for_graph(graph_node.graph)
+
+        # propagate from node to graph
+        for node in ir_model.get_nodes():
+            propagate_shape_for_graph(node.graph)
+    def flatten(self, ir_model: 'Model'):
+        """
+        Flatten the subgraphs into the root graph.
+        """
+        def _flatten(graph: 'Graph'):
+            """
+            flatten this graph
+            """
+            model = graph.model
+            node_to_remove = []
+
+            for node in graph.hidden_nodes:
+                node_graph = model.graphs.get(node.name)
+                if node_graph is not None:
+                    _flatten(node_graph)
+
+                    # flatten the node's graph into this graph
+                    id_to_new_node = {}
+                    for node_graph_node in node_graph.hidden_nodes:
+                        new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
+                        new_node.update_label(node_graph_node.label)
+                        new_node._register()
+                        id_to_new_node[new_node.id] = new_node
+
+                    # reconnect node edges
+                    for in_edge in node.incoming_edges:
+                        graph.del_edge(in_edge)
+                        for input_node_edge in node_graph.input_node.outgoing_edges:
+                            if input_node_edge.head_slot == in_edge.tail_slot:
+                                graph.add_edge(
+                                    head=(in_edge.head, in_edge.head_slot),
+                                    tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))
+
+                    for out_edge in node.outgoing_edges:
+                        graph.del_edge(out_edge)
+                        for output_node_edge in node_graph.output_node.incoming_edges:
+                            if output_node_edge.head_slot == out_edge.tail_slot:
+                                graph.add_edge(
+                                    head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
+                                    tail=(out_edge.tail, out_edge.tail_slot))
+
+                    for edge in node_graph.edges:
+                        if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
+                            continue
+                        new_head = id_to_new_node[edge.head.id]
+                        new_tail = id_to_new_node[edge.tail.id]
+                        Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()
+
+                    node_to_remove.append(node)
+                    del model.graphs[node.name]
+
+            for node in node_to_remove:
+                node.remove()
+
+        _flatten(ir_model.root_graph)
+
+        # remove flattened subgraphs
+        ir_model.graphs = {ir_model._root_graph_name: ir_model.root_graph}
+    def _trace(self, module, example_inputs):
+        traced_module = torch.jit.trace(module, example_inputs)
+        torch._C._jit_pass_inline(traced_module.graph)
+        return traced_module.graph
+
+    def remove_dummy_nodes(self, ir_model: 'Model'):
+        # remove identity nodes
+        for node in ir_model.get_nodes_by_type('noop_identity'):
+            graph = node.graph
+            for in_edge in node.incoming_edges:
+                for out_edge in node.outgoing_edges:
+                    if in_edge.tail_slot == out_edge.head_slot:
+                        graph.add_edge(head=(in_edge.head, in_edge.head_slot), tail=(out_edge.tail, out_edge.tail_slot))
+                        graph.del_edge(in_edge)
+                        graph.del_edge(out_edge)
+                        break
+            node.remove()
+
-def convert_to_graph(script_module, module):
+def convert_to_graph(script_module, module, converter=None, **kwargs):
     """
     Convert module to our graph ir, i.e., build a ``Model`` type
@@ -665,6 +861,10 @@ def convert_to_graph(script_module, module):
         the script module obtained with torch.jit.script
     module : nn.Module
         the targeted module instance
+    converter : `TorchConverter`
+        if None, the default `GraphConverter` is used
+    kwargs:
+        passed through to `converter.convert_module()`

     Returns
     -------
@@ -674,6 +874,8 @@ def convert_to_graph(script_module, module):
     model = Model(_internal=True)
     module_name = '_model'
-    GraphConverter().convert_module(script_module, module, module_name, model)
+    if converter is None:
+        converter = GraphConverter()
+    converter.convert_module(script_module, module, module_name, model, **kwargs)
     return model
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from ..operation import Cell
+from ..graph import Model, Node
+

 def build_full_name(prefix, name, seq=None):
     if isinstance(name, list):
         name = '__'.join(name)
@@ -10,8 +14,98 @@ def build_full_name(prefix, name, seq=None):
         return '{}__{}{}'.format(prefix, name, str(seq))

+def build_cand_name(name, label):
+    return f'layerchoice_{label}_{name}'
+
+
 def _convert_name(name: str) -> str:
     """
     Convert the names using separator '.' to valid variable name in code
     """
     return name.replace('.', '__')
+def _extract_info_from_trace_node(trace_node):
+    """
+    Extract input/output shape parameters from a trace node.
+
+    Parameters
+    ----------
+    trace_node: torch._C.Node
+        a node of the graph obtained with torch.jit.trace
+    """
+    input_shape = []
+    output_shape = []
+
+    inputs = list(trace_node.inputs())
+
+    # cat input tensors are in a strange place
+    if trace_node.kind() == 'aten::cat':
+        input_shape = [input.type().sizes() for input in inputs[0].node().inputs()]
+    else:
+        for _input in inputs:
+            input_type = _input.type()
+            if input_type.kind() == 'TensorType':
+                shape = input_type.sizes()
+                if shape:
+                    input_shape.append(shape)
+
+    for _output in trace_node.outputs():
+        output_type = _output.type()
+        if output_type.kind() == 'TensorType':
+            shape = output_type.sizes()
+            if shape:
+                output_shape.append(shape)
+
+    parameters = {
+        'input_shape': input_shape,
+        'output_shape': output_shape,
+    }
+
+    if trace_node.kind() == 'aten::cat':
+        parameters['dim'] = inputs[1].toIValue()
+
+    return parameters
+def is_layerchoice_node(ir_node: Node):
+    if ir_node is not None and isinstance(ir_node.operation, Cell) and ir_node.operation.parameters.get('mutation') == 'layerchoice':
+        return True
+    else:
+        return False
+
+
+def get_full_name_by_scope_name(ir_model: Model, scope_names, prefix=''):
+    full_name = prefix
+
+    for last_scope in range(len(scope_names)):
+        ir_node = ir_model.get_node_by_name(full_name)
+        # check if it's a layerchoice
+        if is_layerchoice_node(ir_node):
+            full_name = f'layerchoice_{ir_node.operation.parameters["label"]}_{scope_names[last_scope]}'
+        else:
+            full_name = build_full_name(full_name, scope_names[last_scope])
+
+    return full_name
+
+
+def match_node(ir_model: Model, torch_node, prefix=''):
+    """
+    Match the corresponding IR node of a node from the traced graph
+    """
+    scope_names = torch_node.scopeName().split('/')[-1].split('.')[1:]
+    full_name = get_full_name_by_scope_name(ir_model, scope_names, prefix)
+    # handle the case when the node is not an nn.Module but is used directly in forward().
+    # Since the name cannot be matched directly, a hacky fallback is used:
+    # match the first node of that kind that has no shape info yet.
+    graph = ir_model.graphs.get(full_name)
+    if graph is not None:
+        for node in graph.get_nodes_by_type(torch_node.kind()):
+            if not node.operation.parameters['input_shape']:
+                return node
+        return None
+    else:
+        return ir_model.get_node_by_name(full_name)
+
+
+def _without_shape_info(node: Node):
+    return not node.operation.parameters['input_shape'] and not node.operation.parameters['output_shape']
@@ -28,6 +28,7 @@ from nni.tools.nnictl.command_utils import kill_command
 from ..codegen import model_to_pytorch_script
 from ..converter import convert_to_graph
+from ..converter.graph_gen import GraphConverterWithShape
 from ..execution import list_models, set_execution_engine
 from ..execution.python import get_mutation_dict
 from ..graph import Model, Evaluator
@@ -58,6 +59,9 @@ class RetiariiExeConfig(ConfigBase):
     training_service: TrainingServiceConfig
     execution_engine: str = 'py'

+    # input used by GraphConverterWithShape; currently only a shape tuple is supported
+    example_inputs: Optional[List[int]] = None
+
     def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
         super().__init__(**kwargs)
         if training_service_platform is not None:
@@ -106,7 +110,7 @@ _validation_rules = {
     'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
 }

-def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
+def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, example_inputs=None):
     # TODO: this logic might need to be refactored into execution engine
     if full_ir:
         try:
@@ -114,7 +118,13 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
         except Exception as e:
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
-        base_model_ir = convert_to_graph(script_module, base_model)
+        if example_inputs is not None:
+            # FIXME: this is a workaround, as a full tensor is not supported in configs
+            example_inputs = torch.randn(*example_inputs)
+            converter = GraphConverterWithShape()
+            base_model_ir = convert_to_graph(script_module, base_model, converter, example_inputs=example_inputs)
+        else:
+            base_model_ir = convert_to_graph(script_module, base_model)
         # handle inline mutations
         mutators = process_inline_mutation(base_model_ir)
     else:
@@ -171,7 +181,8 @@ class RetiariiExperiment(Experiment):
     def _start_strategy(self):
         base_model_ir, self.applied_mutators = preprocess_model(
-            self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py')
+            self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py',
+            example_inputs=self.config.example_inputs)

         _logger.info('Start strategy...')
         self.strategy.run(base_model_ir, self.applied_mutators)
...
@@ -307,9 +307,9 @@ class Graph:
     @overload
     def add_node(self, name: str, operation: Operation) -> 'Node': ...
     @overload
-    def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...
+    def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...

-    def add_node(self, name, operation_or_type, parameters={}):
+    def add_node(self, name, operation_or_type, parameters=None):
         if isinstance(operation_or_type, Operation):
             op = operation_or_type
         else:
@@ -319,9 +319,9 @@ class Graph:
     @overload
     def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
     @overload
-    def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...
+    def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...

-    def insert_node_on_edge(self, edge, name, operation_or_type, parameters={}) -> 'Node':
+    def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':
         if isinstance(operation_or_type, Operation):
             op = operation_or_type
         else:
@@ -562,9 +562,9 @@ class Node:
     @overload
     def update_operation(self, operation: Operation) -> None: ...
     @overload
-    def update_operation(self, type_name: str, parameters: Dict[str, Any] = {}) -> None: ...
+    def update_operation(self, type_name: str, parameters: Dict[str, Any] = None) -> None: ...

-    def update_operation(self, operation_or_type, parameters={}):
+    def update_operation(self, operation_or_type, parameters=None):
         if isinstance(operation_or_type, Operation):
             self.operation = operation_or_type
         else:
...
@@ -90,6 +90,7 @@ class LayerChoice(nn.Module):
                 self.names.append(str(i))
         else:
             raise TypeError("Unsupported candidates type: {}".format(type(candidates)))
+        self._first_module = self._modules[self.names[0]]  # to make the dummy forward meaningful

     @property
     def key(self):
@@ -143,7 +144,7 @@ class LayerChoice(nn.Module):
     def forward(self, x):
         warnings.warn('You should not run forward of this module directly.')
-        return x
+        return self._first_module(x)

     def __repr__(self):
         return f'LayerChoice({self.candidates}, label={repr(self.label)})'
...
@@ -52,7 +52,9 @@ class Operation:
         return True

     @staticmethod
-    def new(type_name: str, parameters: Dict[str, Any] = {}, cell_name: str = None) -> 'Operation':
+    def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None) -> 'Operation':
+        if parameters is None:
+            parameters = {}
         if type_name == '_cell':
             # NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
             return Cell(cell_name, parameters)
@@ -199,9 +201,11 @@ class Cell(PyTorchOperation):
     No real usage. Exists for compatibility with base class.
     """

-    def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}):
+    def __init__(self, cell_name: str, parameters: Dict[str, Any] = None):
         self.type = '_cell'
         self.cell_name = cell_name
+        if parameters is None:
+            parameters = {}
         self.parameters = parameters

     def _to_class_name(self):
...
@@ -10,7 +10,7 @@ from typing import Any, Dict, List

 from .. import Sampler, submit_models, query_available_resources, budget_exhausted
 from .base import BaseStrategy
-from .utils import dry_run_for_search_space, get_targeted_model
+from .utils import dry_run_for_search_space, get_targeted_model, filter_model

 _logger = logging.getLogger(__name__)

@@ -84,15 +84,18 @@ class Random(BaseStrategy):
         Do not dry run to get the full search space. Used when the search space has variational size or candidates. Default: false.
     dedup : bool
         Do not try the same configuration twice. When variational is true, deduplication is not supported. Default: true.
+    model_filter: Callable[[Model], bool]
+        Takes a model and returns a bool. Used to filter the models in the search space and decide which of them to submit.
     """

-    def __init__(self, variational=False, dedup=True):
+    def __init__(self, variational=False, dedup=True, model_filter=None):
         self.variational = variational
         self.dedup = dedup
         if variational and dedup:
             raise ValueError('Dedup is not supported in variational mode.')
         self.random_sampler = _RandomSampler()
         self._polling_interval = 2.
+        self.filter = model_filter

     def run(self, base_model, applied_mutators):
         if self.variational:
@@ -107,7 +110,8 @@ class Random(BaseStrategy):
                     for mutator in applied_mutators:
                         model = mutator.apply(model)
                     _logger.debug('New model created. Applied mutators are: %s', str(applied_mutators))
-                    submit_models(model)
+                    if filter_model(self.filter, model):
+                        submit_models(model)
                 elif budget_exhausted():
                     break
                 else:
@@ -121,4 +125,6 @@ class Random(BaseStrategy):
                     if budget_exhausted():
                         return
                     time.sleep(self._polling_interval)
-                submit_models(get_targeted_model(base_model, applied_mutators, sample))
+                model = get_targeted_model(base_model, applied_mutators, sample)
+                if filter_model(self.filter, model):
+                    submit_models(model)
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

 import collections
+import logging
 from typing import Dict, Any, List

 from ..graph import Model
 from ..mutator import Mutator, Sampler

+_logger = logging.getLogger(__name__)
+

 class _FixedSampler(Sampler):
     def __init__(self, sample):
@@ -30,3 +34,16 @@ def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict)
     for mutator in mutators:
         model = mutator.bind_sampler(sampler).apply(model)
     return model
+
+
+def filter_model(model_filter, ir_model):
+    if model_filter is not None:
+        _logger.debug('Checking whether the model satisfies the constraints.')
+        if model_filter(ir_model):
+            _logger.debug('Model satisfies the constraints; it will be submitted.')
+            return True
+        else:
+            _logger.debug('Model does not satisfy the constraints; it will be discarded.')
+            return False
+    else:
+        return True
@@ -111,6 +111,33 @@ class GraphIR(unittest.TestCase):
             self.assertEqual(self._get_converted_pytorch_model(model_new)(torch.randn(1, 3, 3, 3)).size(),
                              torch.Size([1, i, 3, 3]))

+    def test_nested_layer_choice(self):
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.module = nn.LayerChoice([
+                    nn.LayerChoice([nn.Conv2d(3, 3, kernel_size=1),
+                                    nn.Conv2d(3, 4, kernel_size=1),
+                                    nn.Conv2d(3, 5, kernel_size=1)]),
+                    nn.Conv2d(3, 1, kernel_size=1)
+                ])
+
+            def forward(self, x):
+                return self.module(x)
+
+        model, mutators = self._get_model_with_mutators(Net())
+        self.assertEqual(len(mutators), 2)
+        mutators[0].bind_sampler(EnumerateSampler())
+        mutators[1].bind_sampler(EnumerateSampler())
+        input = torch.randn(1, 3, 5, 5)
+        self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
+                         torch.Size([1, 3, 5, 5]))
+        self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
+                         torch.Size([1, 1, 5, 5]))
+        self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
+                         torch.Size([1, 5, 5, 5]))
+
     def test_input_choice(self):
         @self.get_serializer()
         class Net(nn.Module):
...