Unverified commit d165905d authored by QuanluZhang, committed by GitHub

[Retiarii] end2end (#3122)

parent 7d1acfbd
# FIXME: For demonstration only. It should not be here
from pathlib import Path
from nni.experiment import Experiment
from nni.algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner
tuner = HyperoptTuner('tpe')
search_space = {
"dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] },
"conv_size": { "_type": "choice", "_value": [2, 3, 5, 7] },
"hidden_size": { "_type": "choice", "_value": [124, 512, 1024] },
"batch_size": { "_type": "choice", "_value": [16, 32] },
"learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] }
}
experiment = Experiment(tuner, 'local')
experiment.config.experiment_name = 'test'
experiment.config.trial_concurrency = 2
experiment.config.max_trial_number = 5
experiment.config.search_space = search_space
experiment.config.trial_command = 'python3 mnist.py'
experiment.config.trial_code_directory = Path(__file__).parent
experiment.run(8081, debug=True)
@@ -3,6 +3,9 @@
__version__ = '999.0.0-developing'
from .runtime.log import init_logger
init_logger()
from .runtime.env_vars import dispatcher_env_vars
from .utils import ClassArgsValidator
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .config import *
from .experiment import Experiment, RetiariiExperiment
from .nni_client import *
from .base import ExperimentConfig, RetiariiExpConfig
from .local import LocalExperimentConfig
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import dataclasses
import json
from pathlib import Path
from typing import Any, Dict, Optional, Union
@dataclasses.dataclass(init=False)
class ExperimentConfig:
experiment_name: str
search_space: Any
max_execution_seconds: Optional[int] = None
max_trial_number: Optional[int] = None
trial_concurrency: int
trial_command: str
trial_code_directory: Union[Path, str]
trial_gpu_number: int = 0
extra_config: Optional[Dict[str, str]] = None
_training_service: str
# these values are used to create the template object,
# and the user should overwrite them later.
_placeholder = {
'experiment_name': '_unset_',
'search_space': '_unset_',
'trial_concurrency': -1,
'trial_command': '_unset_',
'trial_code_directory': '_unset_'
}
# simple validation functions
# complex validation logic with special error message should go to `validate()` method instead
_value_range = {
'max_execution_seconds': lambda x: x is None or x > 0,
'max_trial_number': lambda x: x is None or x > 0,
'trial_concurrency': lambda x: x > 0,
'trial_gpu_number': lambda x: x >= 0
}
def __init__(self, **kwargs):
for field in dataclasses.fields(self):
if field.name in kwargs:
setattr(self, field.name, kwargs[field.name])
elif field.default != dataclasses.MISSING:
setattr(self, field.name, field.default)
else:
setattr(self, field.name, type(self)._placeholder[field.name])
def validate(self) -> None:
# check existence
for key, placeholder_value in type(self)._placeholder.items():
if getattr(self, key) == placeholder_value:
raise ValueError(f'Field "{key}" is not set')
# TODO: check type
# check value
for key, condition in type(self)._value_range.items():
value = getattr(self, key)
if not condition(value):
raise ValueError(f'Field "{key}" ({repr(value)}) out of range')
# check special fields
if not Path(self.trial_code_directory).is_dir():
raise ValueError(f'Trial code directory "{self.trial_code_directory}" does not exist or is not a directory')
def experiment_config_json(self) -> Dict[str, Any]:
# this only contains the common part for most (if not all) training services
# subclasses should override it to provide exclusive fields
return {
'authorName': '_',
'experimentName': self.experiment_name,
'trialConcurrency': self.trial_concurrency,
'maxExecDuration': self.max_execution_seconds or (999 * 24 * 3600),
'maxTrialNum': self.max_trial_number or 99999,
'searchSpace': json.dumps(self.search_space),
'trainingServicePlatform': self._training_service,
'tuner': {'builtinTunerName': '_user_created_'},
**(self.extra_config or {})
}
def cluster_metadata_json(self) -> Any:
# the cluster metadata format is a total mess
# leave it to each subclass until we refactor nni manager
raise NotImplementedError()
@staticmethod
def create_template(training_service: str) -> 'ExperimentConfig':
for cls in ExperimentConfig.__subclasses__():
for field in dataclasses.fields(cls):
if field.name == '_training_service' and field.default == training_service:
return cls()
raise ValueError(f'Unrecognized training service {training_service}')
class RetiariiExpConfig(ExperimentConfig):
@staticmethod
def create_template(training_service: str) -> 'ExperimentConfig':
for cls in ExperimentConfig.__subclasses__():
for field in dataclasses.fields(cls):
if field.name == '_training_service' and field.default == training_service:
config_obj = cls()
config_obj.search_space = {}
config_obj.trial_command = 'python3 -m nni.retiarii.trial_entry'
# FIXME: expose this field to users
config_obj.trial_code_directory = '../..'
return config_obj
raise ValueError(f'Unrecognized training service {training_service}')
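# A minimal usage sketch (editor's illustration; the field values are assumed):
#
#     config = ExperimentConfig.create_template('local')
#     config.experiment_name = 'demo'
#     config.search_space = {'lr': {'_type': 'choice', '_value': [0.01, 0.1]}}
#     config.trial_concurrency = 1
#     config.trial_command = 'python3 trial.py'
#     config.trial_code_directory = '.'
#     config.validate()  # raises ValueError if any placeholder is still unset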
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict
from .base import ExperimentConfig
@dataclass(init=False)
class LocalExperimentConfig(ExperimentConfig):
use_active_gpu: bool = False
_training_service: str = 'local'
def experiment_config_json(self) -> Dict[str, Any]:
ret = super().experiment_config_json()
ret['clusterMetaData'] = [
{
'key': 'codeDir',
'value': str(Path(self.trial_code_directory).resolve())
},
{
'key': 'command',
'value': self.trial_command
}
]
#ret['local_config'] = {
# 'useActiveGpu': self.use_active_gpu
#}
return ret
def cluster_metadata_json(self) -> Any:
return {
'trial_config': {
'command': self.trial_command,
'codeDir': str(Path(self.trial_code_directory).resolve())
}
}
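# For illustration (values assumed): with trial_command='python3 mnist.py' and
# trial_code_directory='/home/user/proj', experiment_config_json() returns a
# dict along the lines of:
#
#     {
#         'authorName': '_',
#         'experimentName': 'test',
#         'trainingServicePlatform': 'local',
#         ...,
#         'clusterMetaData': [
#             {'key': 'codeDir', 'value': '/home/user/proj'},
#             {'key': 'command', 'value': 'python3 mnist.py'}
#         ]
#     }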
import logging
from subprocess import Popen
import time
from threading import Thread
from typing import Optional, overload, List, Union, Callable
from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.tuner import Tuner
from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.converter.graph_gen import convert_to_graph
from .config import ExperimentConfig
from . import launcher
from .pipe import Pipe
from . import rest
_logger = logging.getLogger(__name__)
class Experiment:
"""
Controls an NNI experiment.
You may either create a new NNI experiment with the constructor and `Experiment.start()`,
# TODO: or control an existing experiment with `Experiment.connect()`.
Attributes
----------
config
Experiment configuration.
port
Web UI port of the experiment, or `None` if it is not running.
"""
@overload
def __init__(self, tuner: Tuner, config: ExperimentConfig) -> None:
"""
Prepare an experiment.
Use `Experiment.start()` to launch it.
Parameters
----------
tuner
A tuner instance. # TODO: assessor / advisor
config
Experiment configuration.
"""
...
@overload
def __init__(self, tuner: Tuner, training_service: str) -> None:
"""
Prepare an experiment, leaving configuration fields to be set later.
Example usage::
experiment = Experiment(my_tuner, 'remote')
experiment.config.trial_command = 'python3 trial.py'
experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
...
experiment.start(8080)
Parameters
----------
tuner
A tuner instance.
training_service
Name of training service.
Supported values: "local", "remote", "openpai"/"pai".
"""
...
def __init__(self, tuner: Tuner, config=None, training_service=None):
self.config: ExperimentConfig
self.port: Optional[int] = None
self._dispatcher = MsgDispatcher(tuner, None)
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
if isinstance(config, str):
config, training_service = None, config
if training_service == 'openpai':
training_service = 'pai'
if config is None:
self.config = ExperimentConfig.create_template(training_service)
else:
self.config = config
def start(self, port: int = 8080, debug: bool = False) -> None:
"""
Start the experiment in background.
This method will raise exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
"""
if debug:
logging.getLogger('nni').setLevel(logging.DEBUG)
self._proc, self._pipe = launcher.start_experiment(self.config, port, debug)
assert self._proc is not None
assert self._pipe is not None
self.port = port # port will be None if start up failed
# dispatcher must be created after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
Thread(target=self._dispatcher.run).start()
# TODO: register experiment management metadata
def stop(self) -> None:
"""
Stop background experiment.
"""
self._proc.kill()
self._pipe.close()
self.port = None
self._proc = None
self._pipe = None
def run(self, port: int = 8080, debug: bool = False) -> str:
"""
Run the experiment.
This function blocks until the experiment finishes or fails.
"""
self.start(port, debug)
try:
while True:
time.sleep(10)
status = self.get_status()
if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
return status
finally:
self.stop()
def get_status(self) -> str:
if self.port is None:
raise RuntimeError('Experiment is not running')
resp = rest.get(self.port, '/check-status')
return resp['status']
class RetiariiExperiment(Experiment):
def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
applied_mutators: List['Mutator'], strategy: 'BaseStrategy',
tca: 'TraceClassArguments' = None):
self.config: ExperimentConfig = None
self.port: Optional[int] = None
self.base_model = base_model
self.trainer = trainer
self.applied_mutators = applied_mutators
self.strategy = strategy
self.recorded_module_args = tca.recorded_arguments # FIXME: remove this argument
self._dispatcher = RetiariiAdvisor()
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
def _start_strategy(self):
import torch
script_module = torch.jit.script(self.base_model)
base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
assert id(self.trainer) in self.recorded_module_args
trainer_config = self.recorded_module_args[id(self.trainer)]
_logger.info('Starting strategy...')
Thread(target=self.strategy.run, args=(base_model, self.applied_mutators, trainer_config)).start()
_logger.info('Strategy started!')
def start(self, config: ExperimentConfig, port: int = 8080, debug: bool = False) -> None:
"""
Start the experiment in background.
This method will raise exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
config
Experiment configuration.
port
The port of web UI.
debug
Whether to start in debug mode.
"""
if debug:
logging.getLogger('nni').setLevel(logging.DEBUG)
self._proc, self._pipe = launcher.start_experiment(config, port, debug)
assert self._proc is not None
assert self._pipe is not None
self.port = port # port will be None if start up failed
# dispatcher must be created after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
Thread(target=self._dispatcher.run).start()
self._start_strategy()
# TODO: register experiment management metadata
def stop(self) -> None:
"""
Stop background experiment.
"""
self._proc.kill()
self._pipe.close()
self.port = None
self._proc = None
self._pipe = None
def run(self, config: ExperimentConfig, port: int = 8080, debug: bool = False) -> str:
"""
Run the experiment.
This function blocks until the experiment finishes or fails.
"""
self.start(config, port, debug)
try:
while True:
time.sleep(10)
status = self.get_status()
if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
return status
finally:
self.stop()
def get_status(self) -> str:
if self.port is None:
raise RuntimeError('Experiment is not running')
resp = rest.get(self.port, '/check-status')
return resp['status']
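# A hypothetical end-to-end sketch of RetiariiExperiment (base_model, trainer,
# applied_mutators, strategy and tca are placeholders the caller must supply):
#
#     config = RetiariiExpConfig.create_template('local')
#     exp = RetiariiExperiment(base_model, trainer, applied_mutators, strategy, tca)
#     exp.run(config, port=8081, debug=True)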
import contextlib
from pathlib import Path
import socket
from subprocess import Popen
import sys
import time
from typing import Optional, Tuple
import nni.runtime.protocol
import nni_node
from .config import ExperimentConfig
from . import management
from .pipe import Pipe
from . import rest
def start_experiment(config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
pipe = None
proc = None
config.validate()
_ensure_port_idle(port)
if config._training_service == 'pai':
_ensure_port_idle(port + 1, 'OpenPAI requires an additional port')
exp_id = management.generate_experiment_id()
try:
print(f'Creating experiment {exp_id}...')
pipe = Pipe(exp_id)
proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
pipe_file = pipe.connect()
nni.runtime.protocol._in_file = pipe_file
nni.runtime.protocol._out_file = pipe_file
print('Starting web server...')
_check_rest_server(port)
print('Setting up...')
_init_experiment(config, port, debug) # todo: kill on fail
return proc, pipe
except Exception as e:
print('Failed to create experiment')
if proc is not None:
with contextlib.suppress(Exception):
proc.kill()
if pipe is not None:
with contextlib.suppress(Exception):
pipe.close()
raise e
def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
sock = socket.socket()
busy = sock.connect_ex(('localhost', port)) == 0
sock.close()
if busy:
message = f'({message})' if message else ''
raise RuntimeError(f'Port {port} is not idle {message}')
def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str) -> Popen:
args = {
'port': port,
'mode': config._training_service,
'experiment_id': experiment_id,
'start_mode': 'new',
'log_level': 'debug' if debug else 'info',
'dispatcher_pipe': pipe_path,
}
node_dir = Path(nni_node.__path__[0])
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js]
for arg_key, arg_value in args.items():
cmd.append('--' + arg_key)
cmd.append(str(arg_value))
return Popen(cmd, cwd=node_dir)
def _check_rest_server(port: int, retry: int = 10) -> None:
for _ in range(retry):
with contextlib.suppress(Exception):
rest.get(port, '/check-status')
return
time.sleep(1)
rest.get(port, '/check-status')
def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
rest.put(port, '/experiment/cluster-metadata', config.cluster_metadata_json())
rest.post(port, '/experiment', config.experiment_config_json())
from pathlib import Path
import random
import string
def generate_experiment_id() -> str:
return ''.join(random.sample(string.ascii_lowercase + string.digits, 8))
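# e.g. generate_experiment_id() -> 'k3q0xz7a' (illustrative); random.sample
# draws without replacement, so the 8 characters are always distinct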
def create_experiment_directory(experiment_id: str) -> Path:
path = Path.home() / 'nni-experiments' / experiment_id
path.mkdir(parents=True, exist_ok=True)
return path
# TODO: port shangning's work here, and use it in Experiment.start()/.stop()
@@ -29,7 +29,7 @@ import requests
import yaml
__all__ = [
'ExternalExperiment',
'TrialResult',
'TrialMetricData',
'TrialHyperParameters',
@@ -229,7 +229,7 @@ class TrialJob:
.format(self.trialJobId, self.status, self.hyperParameters, self.logPath,
self.startTime, self.endTime, self.finalMetricData, self.stderrPath)
class ExternalExperiment:
def __init__(self):
self._endpoint = None
self._exp_id = None
......
from io import BufferedIOBase
import os
import sys
if sys.platform == 'win32':
import _winapi
import msvcrt
class WindowsPipe:
def __init__(self, experiment_id: str):
self.path: str = r'\\.\pipe\nni-' + experiment_id
self.file = None
self._handle = _winapi.CreateNamedPipe(
self.path,
_winapi.PIPE_ACCESS_DUPLEX,
_winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_WAIT,
1,
8192,
8192,
0,
_winapi.NULL
)
def connect(self) -> BufferedIOBase:
_winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
fd = msvcrt.open_osfhandle(self._handle, 0)
self.file = os.fdopen(fd, 'w+b')
return self.file
def close(self) -> None:
if self.file is not None:
self.file.close()
_winapi.CloseHandle(self._handle)
Pipe = WindowsPipe
else:
import socket
from . import management
class UnixPipe:
def __init__(self, experiment_id: str):
self.path: str = str(management.create_experiment_directory(experiment_id) / 'dispatcher-pipe')
self.file = None
self._socket = socket.socket(socket.AF_UNIX)
self._socket.bind(self.path)
self._socket.listen(1) # only accepts one connection
def connect(self) -> BufferedIOBase:
conn, _ = self._socket.accept()
self.file = conn.makefile('rwb')
return self.file
def close(self) -> None:
if self.file is not None:
self.file.close()
self._socket.close()
os.unlink(self.path)
Pipe = UnixPipe
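# Intended lifecycle, mirroring launcher.start_experiment() (a sketch, not an API):
#
#     pipe = Pipe(experiment_id)   # create before spawning the NNI manager
#     # ... spawn the manager, passing pipe.path as --dispatcher_pipe ...
#     file = pipe.connect()        # blocks until the manager connects
#     # ... dispatcher protocol reads/writes go through `file` ...
#     pipe.close()                 # called from Experiment.stop()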
import logging
from typing import Any
import requests
_logger = logging.getLogger(__name__)
url_template = 'http://localhost:{}/api/v1/nni{}'
timeout = 20
def get(port: int, api: str) -> Any:
url = url_template.format(port, api)
resp = requests.get(url, timeout=timeout)
if not resp.ok:
_logger.error('rest request GET %s %s failed: %s %s', port, api, resp.status_code, resp.text)
resp.raise_for_status()
return resp.json()
def post(port: int, api: str, data: Any) -> Any:
url = url_template.format(port, api)
resp = requests.post(url, json=data, timeout=timeout)
if not resp.ok:
_logger.error('rest request POST %s %s failed: %s %s', port, api, resp.status_code, resp.text)
resp.raise_for_status()
return resp.json()
def put(port: int, api: str, data: Any) -> None:
url = url_template.format(port, api)
resp = requests.put(url, json=data, timeout=timeout)
if not resp.ok:
_logger.error('rest request PUT %s %s failed: %s', port, api, resp.status_code)
resp.raise_for_status()
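# Usage sketch (cf. Experiment.get_status and launcher._init_experiment above):
#
#     status = get(8080, '/check-status')['status']
#     put(8080, '/experiment/cluster-metadata', config.cluster_metadata_json())
#     post(8080, '/experiment', config.experiment_config_json())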
@@ -2,4 +2,3 @@ from .operation import Operation
from .graph import *
from .execution import *
from .mutator import *
@@ -15,7 +15,7 @@ def model_to_pytorch_script(model: Model, placement = None) -> str:
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement = placement)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
# FIXME: set correct PATH for the packages (after launch refactor)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
@@ -54,6 +54,22 @@ def _format_inputs(node: Node) -> List[str]:
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
return inputs
def _remove_prefix(names, graph_name):
"""
variable names (with full namespace) are too long;
shorten them by removing the ```graph_name``` prefix
"""
if isinstance(names, list):
converted_names = []
for name in names:
if name.startswith(graph_name):
converted_names.append(name[len(graph_name):])
else:
converted_names.append(name)
return converted_names
else:
return names[len(graph_name):] if names.startswith(graph_name) else names
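# For illustration:
#     _remove_prefix('_model__conv1', '_model__')         -> 'conv1'
#     _remove_prefix(['_model__fc', 'x.out'], '_model__') -> ['fc', 'x.out']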
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str:
nodes = graph.topo_sort() # FIXME: topological sort is needed here
@@ -67,7 +83,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(node.name)
node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
node_codes.append(f"{node_code}.to('{placement[node].device}')")
@@ -77,6 +93,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s
if graph.input_node.operation.io_names is None:
input_code = '*_inputs'
else:
for name in graph.input_node.operation.io_names:
assert not name.startswith(graph_name)
input_code = ', '.join(graph.input_node.operation.io_names)
edge_codes = []
@@ -84,9 +102,12 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s
for node in sorted_nodes:
if node.operation:
inputs = _format_inputs(node)
inputs = _remove_prefix(inputs, graph_name)
node_name = _remove_prefix(node.name, graph_name)
edge_codes.append(node.operation.to_forward_code(node_name, node_name, inputs))
output_names = _format_inputs(graph.output_node)
output_names = _remove_prefix(output_names, graph_name)
if not output_names:
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
output_code = ', '.join(output_names)
@@ -95,7 +116,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s
return import_pkgs, _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else graph_name),
inputs=input_code,
outputs=output_code,
nodes=linebreak.join(node_codes),
edges=linebreak.join(edge_codes)
)
@@ -109,6 +130,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# FIXME: remove these two lines
import sys
sys.path.append("test/convert_test")
......
@@ -4,9 +4,9 @@ import torch
from ..graph import Graph, Node, Edge, Model
from ..operation import Cell, Operation
from ..nn.pytorch import Placeholder, LayerChoice, InputChoice
from .op_types import MODULE_EXCEPT_LIST, Type
from .utils import build_full_name, _convert_name
@@ -14,7 +14,7 @@ global_seq = 0
global_graph_id = 0
modules_arg = None
def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
"""
Parameters
----------
@@ -25,9 +25,11 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, ignore_first=F
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
output_remap : Dict
ignore_first : bool
if it is true, skip the first input
"""
is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
new_node_input_idx = 0
for _input in node.inputs():
if ignore_first:
@@ -39,6 +41,13 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, ignore_first=F
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
src_node_idx = idx
elif _input in output_remap:
assert output_remap[_input].kind() == 'aten::append'
predecessor_node = output_remap[_input]
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
src_node_idx = None
src_node = node_index[predecessor_node]
assert isinstance(src_node, Node)
else:
predecessor_node = _input.node()
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
@@ -50,44 +59,21 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, ignore_first=F
idx = predecessor_outputs.index(_input)
ir_predecessor_node = node_index[predecessor_node]
src_node_idx = idx
# get source node
# the input is output of a basic node
assert isinstance(ir_predecessor_node, Node)
src_node = ir_predecessor_node
# handle destination node
dst_node = new_node
if is_single_input:
dst_node_idx = None
else:
dst_node_idx = new_node_input_idx
# create edge
ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))
new_node_input_idx += 1
def create_prim_constant_node(ir_graph, node, module_name):
global global_seq
attrs = {}
@@ -100,46 +86,63 @@ def create_prim_constant_node(ir_graph, node, module_name):
def handle_prim_attr_node(node):
assert node.hasAttribute('name')
assert node.inputsAt(0).unique() == 0
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
return node.kind(), attrs
def _remove_mangle(module_type_str):
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
def remove_unconnected_nodes(ir_graph, targeted_type=None):
"""
Parameters
----------
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes of ```targeted_type``` will be removed from the graph if their fanout is 0.
```None``` means removing every node whose fanout is 0.
"""
# collect the nodes that have at least one outgoing edge
node_fanout = set()
for edge in ir_graph.edges:
if edge.head.id not in node_fanout:
node_fanout.add(edge.head.id)
to_removes = []
for hidden_node in ir_graph.hidden_nodes:
if hidden_node.id not in node_fanout:
assert isinstance(hidden_node, Node)
if targeted_type is None:
to_removes.append(hidden_node)
elif hidden_node.operation.type == targeted_type:
to_removes.append(hidden_node)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
"""
Convert torch script nodes to our node ir, and build our graph ir
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the torch script of ```module```
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
```module```'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
Returns
-------
dict
the mapping from graph node to our graph ir node
"""
# handle inputs
graph_inputs = []
@@ -153,6 +156,13 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
node_index = {} # graph node to graph ir node
# some nodes do not have an output but modify a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add an output to this type of node and connect it to the following node that uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap = {}
def handle_if_node(node):
"""
Parameters
@@ -163,10 +173,10 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
Returns
-------
Node
the created ir node
"""
# for now, only handle the case where the input of prim::If is a constant or an attribute
# TODO: support constant expressions
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
if not inputs[0].node().kind() in ['prim::Constant', 'prim::GetAttr']:
@@ -194,7 +204,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
Returns
-------
Node
the created ir node
"""
global global_seq
if node.kind() == 'prim::CallMethod':
@@ -207,14 +217,41 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
assert submodule.kind() == 'prim::GetAttr'
assert submodule.hasAttribute('name')
submodule_name = submodule.s('name')
# TODO: try not-connected placeholder in TorchScript
if submodule.inputsAt(0).debugName() == 'self':
# module is usually instantiated in __init__.
# when calling a module in forward,
# prim::GetAttr is used to obtain the module in torch script.
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, ir_model)
else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList
predecessor = submodule.inputsAt(0).node()
assert predecessor.kind() == 'prim::GetAttr'
assert predecessor.hasAttribute('name')
assert predecessor.inputsAt(0).debugName() == 'self'
predecessor_name = predecessor.s('name')
# FIXME: exchange
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
predecessor_obj = getattr(module, predecessor_name)
submodule_obj = getattr(predecessor_obj, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
submodule_obj, submodule_full_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
# TODO: match subgraph with maintained graphs
# build cell
if subgraph is None:
@@ -222,14 +259,15 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
subcell.update_label(sub_m_attrs['label'])
else:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
node_index[node] = subcell
# connect the cell into graph
_add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else:
raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
elif node.kind() == 'prim::CallFunction':
@@ -243,63 +281,162 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
func_node = ir_graph.add_node(build_full_name(module_name, func_name, global_seq),
'{}.{}'.format(func_type_str, func_name))
node_index[node] = func_node
_add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
# TODO: how about calling a function twice? two constant nodes or one?
new_node = create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.ListConstruct, global_seq), node.kind())
node_index[node] = new_node
_add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append':
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, Type.BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, Type.BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = handle_prim_attr_node(node)
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, Type.Attr, global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
raise RuntimeError('Loop has not been supported yet!')
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
return node_index[node]
for node in sm_graph.nodes():
handle_single_node(node)
return node_index
def merge_aten_slices(ir_graph):
"""
If there are aten::slice nodes, merge the consecutive ones together.
```x[:, :, 1:, 1:]``` in Python code is converted into 4 nodes in TorchScript;
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes = []
has_slice_node = False
for node in ir_graph.hidden_nodes:
if node.operation.type == 'aten::slice':
has_slice_node = True
for pred in node.predecessors:
if pred.operation.type not in ['aten::slice', 'prim::Constant']:
head_slice_nodes.append(node)
break
if has_slice_node:
assert head_slice_nodes
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), Type.MergedSlice)
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
slot += 5
node = head_node
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
suc_node = node.successors[0]
assert len(suc_node.incoming_edges) == 5
for edge in suc_node.incoming_edges:
if edge.tail_slot == 0:
edge.remove()
else:
edge.tail = new_slice_node
edge.tail_slot = slot + edge.tail_slot - 1
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
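# Worked example (editor's illustration): `x[:, 1:]` lowers to two chained
# aten::slice nodes, each with 5 inputs (tensor, dim, start, end, step).
# The head node's 5 incoming edges move to slots 0-4 of the merged node; for
# the second node, the slot-0 edge (the head's output tensor) is removed and
# its remaining 4 edges are re-pointed at slots 5-8.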
def refine_graph(ir_graph):
"""
Simplify the graph by doing the following:
1. remove unconnected constant nodes
2. remove unconnected getattr nodes
3. merge consecutive aten::slice nodes
"""
# some constant is not used, for example, function name as prim::Constant
remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
merge_aten_slices(ir_graph)
def _handle_layerchoice(module):
global modules_arg
m_attrs = {}
candidates = module.candidate_ops
for i, cand in enumerate(candidates):
assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
assert isinstance(modules_arg[id(cand)], dict)
m_attrs[f'choice_{i}'] = modules_arg[id(cand)]
m_attrs['label'] = module.label
return m_attrs
def _handle_inputchoice(module):
m_attrs = {}
m_attrs['n_chosen'] = module.n_chosen
m_attrs['reduction'] = module.reduction
m_attrs['label'] = module.label
return m_attrs
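# For illustration, with two candidate ops whose constructor arguments were
# recorded in modules_arg:
#     _handle_layerchoice(module) ->
#         {'choice_0': {...}, 'choice_1': {...}, 'label': module.label}
#     _handle_inputchoice(module) ->
#         {'n_chosen': ..., 'reduction': ..., 'label': module.label}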
def convert_module(script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
Returns
-------
Graph
the graph ir built from ```module```; ```None``` means the module is not parsed further
dict
the input arguments of this module
"""
global global_graph_id
global modules_arg
assert id(module) in modules_arg, 'id not exist: {}, {}'.format(id(module), module_name)
if isinstance(modules_arg[id(module)], tuple):
positional_args, keyword_args = modules_arg[id(module)]
m_attrs = keyword_args
# TODO: remove positional args
m_attrs['positional_args'] = positional_args
else:
m_attrs = modules_arg[id(module)]
# NOTE: nested LayerChoice is not yet supported, i.e., a candidate module
# that itself contains LayerChoice, InputChoice or ValueChoice
original_type_name = script_module.original_name
if original_type_name == Type.LayerChoice:
m_attrs = _handle_layerchoice(module)
return None, m_attrs
if original_type_name == Type.InputChoice:
m_attrs = _handle_inputchoice(module)
return None, m_attrs
if original_type_name == Type.Placeholder:
m_attrs = modules_arg[id(module)]
return None, m_attrs
if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
return None, m_attrs
# handle TorchScript graph
@@ -312,9 +449,7 @@ def convert_module(script_module, module, module_name, ir_model):
module_name, ir_model, ir_graph)
# handle graph outputs
for _output in sm_graph.outputs():
ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
@@ -324,18 +459,20 @@ def convert_module(script_module, module, module_name, ir_model):
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None))
refine_graph(ir_graph)
ir_graph._register()
return ir_graph, modules_arg[id(module)]
def convert_to_graph(script_module, module, recorded_modules_arg):
"""
Convert module to our graph ir, i.e., build a ```Model``` type
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module obtained with torch.jit.script
module : nn.Module
the targeted module instance
recorded_modules_arg : dict
@@ -350,6 +487,6 @@ def convert_to_graph(script_module, module, recorded_modules_arg):
model = Model(_internal=True)
module_name = '_model'
convert_module(script_module, module, module_name, model)
return model
MODULE_EXCEPT_LIST = ['Sequential']
class Type:
"""Node Type class
@@ -7,12 +7,25 @@ class Type:
Attr = 'Attr'
Constant = 'Constant'
ListConstruct = 'ListConstruct'
LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
'aten::relu': 'Relu',
'aten::add': 'Add',
'aten::__getitem__': 'getitem',
'aten::append': 'Append',
'aten::len': 'Len',
'aten::slice': 'Slice',
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View'
}
BasicOpsTF = {}
\ No newline at end of file
def build_full_name(prefix, name, seq=None):
if isinstance(name, list):
name = '__'.join(name)
if seq is None:
return '{}__{}'.format(prefix, name)
else:
......
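# For illustration (the seq branch is elided above):
#     build_full_name('_model', 'conv1')        -> '_model__conv1'
#     build_full_name('_model', ['cells', '0']) -> '_model__cells__0'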
@@ -34,6 +34,6 @@ def convert_to_visualize(graph_ir, vgraph):
subgraph.edge(src, dst)
def visualize_model(graph_ir):
vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
convert_to_visualize(graph_ir, vgraph)
vgraph.render()
\ No newline at end of file
@@ -13,8 +13,7 @@ _execution_engine = None
_default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener',
'submit_models', 'wait_models', 'query_available_resources']
def get_execution_engine() -> BaseExecutionEngine:
@@ -45,27 +44,6 @@ def _get_search_space() -> 'Dict':
break
return engine.get_search_space()
def submit_models(*models: Model) -> None:
engine = get_execution_engine()
get_and_register_default_listener(engine)
......
@@ -106,6 +106,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
Initialize the model, hand it over to trainer.
"""
graph_data = BaseGraphData.load(receive_trial_parameters())
# FIXME: update this part to dump code to a correct path!!!
with open('_generated_model.py', 'w') as f:
f.write(graph_data.model_script)
trainer_cls = utils.import_(graph_data.training_module)
......