"vscode:/vscode.git/clone" did not exist on "f8ba4017007dd189c8a0b9968a1f84b32e61a839"
Unverified Commit d165905d authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[Retiarii] end2end (#3122)

parent 7d1acfbd
# FIXME: For demonstration only. It should not be here
from pathlib import Path

from nni.experiment import Experiment
from nni.algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner

# Hyper-parameter search space for the MNIST demo trial.
search_space = {
    "dropout_rate": {"_type": "uniform", "_value": [0.5, 0.9]},
    "conv_size": {"_type": "choice", "_value": [2, 3, 5, 7]},
    "hidden_size": {"_type": "choice", "_value": [124, 512, 1024]},
    "batch_size": {"_type": "choice", "_value": [16, 32]},
    "learning_rate": {"_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1]},
}

# TPE tuner, trials executed on the local machine.
tuner = HyperoptTuner('tpe')
experiment = Experiment(tuner, 'local')

config = experiment.config
config.experiment_name = 'test'
config.trial_concurrency = 2
config.max_trial_number = 5
config.search_space = search_space
config.trial_command = 'python3 mnist.py'
config.trial_code_directory = Path(__file__).parent

# Blocks until the experiment finishes; web UI served on port 8081.
experiment.run(8081, debug=True)
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
__version__ = '999.0.0-developing' __version__ = '999.0.0-developing'
from .runtime.log import init_logger
init_logger()
from .runtime.env_vars import dispatcher_env_vars from .runtime.env_vars import dispatcher_env_vars
from .utils import ClassArgsValidator from .utils import ClassArgsValidator
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from .config import *
from .experiment import Experiment, RetiariiExperiment
from .nni_client import * from .nni_client import *
from .base import ExperimentConfig, RetiariiExpConfig
from .local import LocalExperimentConfig
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import dataclasses
import json
from pathlib import Path
from typing import Any, Dict, Optional, Union
@dataclasses.dataclass(init=False)
class ExperimentConfig:
    """
    Base class for experiment configuration.

    Each training service is a subclass that gives `_training_service` a
    default value equal to its name; `create_template()` selects the right
    subclass by that value.  Template objects are created with placeholder
    values which the user must overwrite before `validate()` is called.

    `init=False` is required: several fields without defaults follow fields
    with defaults, which the auto-generated `__init__` would reject; the
    hand-written `__init__` below fills missing fields from `_placeholder`.
    """

    experiment_name: str
    search_space: Any
    max_execution_seconds: Optional[int] = None
    max_trial_number: Optional[int] = None
    trial_concurrency: int
    trial_command: str
    trial_code_directory: Union[Path, str]
    trial_gpu_number: int = 0
    extra_config: Optional[Dict[str, str]] = None
    # Training service name; concrete subclasses supply the default.
    _training_service: str

    # these values will be used to create template object,
    # and the user should overwrite them later.
    # (No annotation, so this is a plain class attribute, not a dataclass field.)
    # NOTE(review): there is no placeholder for `_training_service`, so
    # instantiating a subclass that does not default that field would raise
    # KeyError in __init__ — confirm all subclasses set a default.
    _placeholder = {
        'experiment_name': '_unset_',
        'search_space': '_unset_',
        'trial_concurrency': -1,
        'trial_command': '_unset_',
        'trial_code_directory': '_unset_'
    }

    # simple validation functions
    # complex validation logic with special error message should go to `validate()` method instead
    _value_range = {
        'max_execution_seconds': lambda x: x is None or x > 0,
        'max_trial_number': lambda x: x is None or x > 0,
        'trial_concurrency': lambda x: x > 0,
        'trial_gpu_number': lambda x: x >= 0
    }

    def __init__(self, **kwargs):
        """Set each field from kwargs, else its dataclass default, else its placeholder."""
        for field in dataclasses.fields(self):
            if field.name in kwargs:
                setattr(self, field.name, kwargs[field.name])
            elif field.default != dataclasses.MISSING:
                setattr(self, field.name, field.default)
            else:
                setattr(self, field.name, type(self)._placeholder[field.name])

    def validate(self) -> None:
        """Raise ValueError if any field is still a placeholder, out of range, or invalid."""
        # check existence
        for key, placeholder_value in type(self)._placeholder.items():
            if getattr(self, key) == placeholder_value:
                raise ValueError(f'Field "{key}" is not set')
        # TODO: check type
        # check value
        for key, condition in type(self)._value_range.items():
            value = getattr(self, key)
            if not condition(value):
                raise ValueError(f'Field "{key}" ({repr(value)}) out of range')
        # check special fields
        if not Path(self.trial_code_directory).is_dir():
            raise ValueError(f'Trial code directory "{self.trial_code_directory}" does not exist or is not directory')

    def experiment_config_json(self) -> Dict[str, Any]:
        """Build the JSON body POSTed to the REST server's `/experiment` endpoint."""
        # this only contains the common part for most (if not all) training services
        # subclasses should override it to provide exclusive fields
        return {
            'authorName': '_',
            'experimentName': self.experiment_name,
            'trialConcurrency': self.trial_concurrency,
            'maxExecDuration': self.max_execution_seconds or (999 * 24 * 3600),
            'maxTrialNum': self.max_trial_number or 99999,
            'searchSpace': json.dumps(self.search_space),
            'trainingServicePlatform': self._training_service,
            'tuner': {'builtinTunerName': '_user_created_'},
            **(self.extra_config or {})
        }

    def cluster_metadata_json(self) -> Any:
        """Build the body for `/experiment/cluster-metadata`; subclasses must override."""
        # the cluster metadata format is a total mess
        # leave it to each subclass before we refactoring nni manager
        raise NotImplementedError()

    @staticmethod
    def create_template(training_service: str) -> 'ExperimentConfig':
        """Instantiate the subclass whose `_training_service` default equals the given name."""
        for cls in ExperimentConfig.__subclasses__():
            for field in dataclasses.fields(cls):
                if field.name == '_training_service' and field.default == training_service:
                    return cls()
        raise ValueError(f'Unrecognized training service {training_service}')
class RetiariiExpConfig(ExperimentConfig):
    """Experiment configuration template for Retiarii (NAS) experiments."""

    @staticmethod
    def create_template(training_service: str) -> 'ExperimentConfig':
        """
        Create a pre-filled config for the given training service.

        Unlike the base implementation, the returned template has its search
        space, trial command and code directory pre-populated, since Retiarii
        generates those internally instead of taking them from the user.

        Raises
        ------
        ValueError
            If `training_service` matches no registered subclass.
        """
        for cls in ExperimentConfig.__subclasses__():
            for field in dataclasses.fields(cls):
                if field.name == '_training_service' and field.default == training_service:
                    config_obj = cls()
                    config_obj.search_space = {}
                    config_obj.trial_command = 'python3 -m nni.retiarii.trial_entry'
                    # FIXME: expose this field to users
                    config_obj.trial_code_directory = '../..'
                    return config_obj
        # Bug fix: previously fell through and implicitly returned None,
        # deferring the failure to a confusing AttributeError at the caller.
        # Now raises, consistent with ExperimentConfig.create_template().
        raise ValueError(f'Unrecognized training service {training_service}')
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict
from .base import ExperimentConfig
@dataclass(init=False)
class LocalExperimentConfig(ExperimentConfig):
    """Experiment configuration for the "local" training service."""

    # NOTE(review): presumably lets trials run on GPUs that already have
    # active processes — confirm against NNI docs.  Currently unused: the
    # block that would emit it below is commented out.
    use_active_gpu: bool = False
    _training_service: str = 'local'

    def experiment_config_json(self) -> Dict[str, Any]:
        """Extend the common experiment JSON with local trial metadata."""
        ret = super().experiment_config_json()
        ret['clusterMetaData'] = [
            {
                'key': 'codeDir',
                'value': str(Path(self.trial_code_directory).resolve())
            },
            {
                'key': 'command',
                'value': self.trial_command
            }
        ]
        #ret['local_config'] = {
        #    'useActiveGpu': self.use_active_gpu
        #}
        return ret

    def cluster_metadata_json(self) -> Any:
        """Build the body for the `/experiment/cluster-metadata` endpoint."""
        return {
            'trial_config': {
                'command': self.trial_command,
                'codeDir': str(Path(self.trial_code_directory).resolve())
            }
        }
import logging
from subprocess import Popen
import time
from threading import Thread
from typing import Optional, overload, List, Union, Callable
from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.tuner import Tuner
from nni.retiarii.integration import RetiariiAdvisor
from nni.retiarii.converter.graph_gen import convert_to_graph
from .config import ExperimentConfig
from . import launcher
from .pipe import Pipe
from . import rest
_logger = logging.getLogger(__name__)
class Experiment:
    """
    Controls an NNI experiment.

    You may either create a new NNI experiment with constructor and `Experiment.start()`,
    # TODO: or control an existing experiment with `Experiment.connect()`.

    Attributes
    ----------
    config
        Experiment configuration.
    port
        Web UI port of the experiment, or `None` if it is not running.
    """

    @overload
    def __init__(self, tuner: Tuner, config: ExperimentConfig) -> None:
        """
        Prepare an experiment.
        Use `Experiment.start()` to launch it.

        Parameters
        ----------
        tuner
            A tuner instance. # TODO: accessor / advisor
        config
            Experiment configuration.
        """
        ...

    @overload
    def __init__(self, tuner: Tuner, training_service: str) -> None:
        """
        Prepare an experiment, leaving configuration fields to be set later.

        Example usage::

            experiment = Experiment(my_tuner, 'remote')
            experiment.config.trial_command = 'python3 trial.py'
            experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
            ...
            experiment.start(8080)

        Parameters
        ----------
        tuner
            A tuner instance.
        training_service
            Name of training service.
            Supported value: "local", "remote", "openpai"/"pai".
        """
        ...

    def __init__(self, tuner: Tuner, config=None, training_service=None):
        self.config: ExperimentConfig
        self.port: Optional[int] = None            # None means "not running"
        self._dispatcher = MsgDispatcher(tuner, None)
        self._proc: Optional[Popen] = None         # NNI manager (REST server) process
        self._pipe: Optional[Pipe] = None          # IPC channel to the dispatcher
        # Second positional argument may be a training-service name
        # (string overload) instead of a config object.
        if isinstance(config, str):
            config, training_service = None, config
        if training_service == 'openpai':
            training_service = 'pai'  # "openpai" is accepted as an alias of "pai"
        if config is None:
            self.config = ExperimentConfig.create_template(training_service)
        else:
            self.config = config

    def start(self, port: int = 8080, debug: bool = False) -> None:
        """
        Start the experiment in background.
        This method will raise exception on failure.
        If it returns, the experiment should have been successfully started.

        Parameters
        ----------
        port
            The port of web UI.
        debug
            Whether to start in debug mode.
        """
        if debug:
            logging.getLogger('nni').setLevel(logging.DEBUG)
        self._proc, self._pipe = launcher.start_experiment(self.config, port, debug)
        assert self._proc is not None
        assert self._pipe is not None
        self.port = port  # port will be None if start up failed
        # dispatcher must be created after pipe initialized
        # the logic to launch dispatcher in background should be refactored into dispatcher api
        Thread(target=self._dispatcher.run).start()
        # TODO: register experiment management metadata

    def stop(self) -> None:
        """
        Stop background experiment.

        NOTE(review): assumes `start()` succeeded — calling this while the
        experiment is not running would fail on the `None` process handle.
        """
        self._proc.kill()
        self._pipe.close()
        self.port = None
        self._proc = None
        self._pipe = None

    def run(self, port: int = 8080, debug: bool = False) -> str:
        """
        Run the experiment.
        This function will block until experiment finish or error.

        Returns
        -------
        str
            Terminal status reported by the REST server:
            'ERROR', 'STOPPED' or 'NO_MORE_TRIAL'.
        """
        self.start(port, debug)
        try:
            # Poll the REST server every 10 seconds until a terminal status.
            while True:
                time.sleep(10)
                status = self.get_status()
                if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
                    return status
        finally:
            self.stop()

    def get_status(self) -> str:
        """Query the current experiment status from the REST server."""
        if self.port is None:
            raise RuntimeError('Experiment is not running')
        resp = rest.get(self.port, '/check-status')
        return resp['status']
class RetiariiExperiment(Experiment):
    """
    Experiment variant driving Retiarii (neural architecture search).

    Instead of a tuner, it runs a NAS `strategy` over a `base_model` with a
    set of `applied_mutators`, dispatching generated models through a
    `RetiariiAdvisor`.

    NOTE(review): `start()` and `run()` take an extra leading `config`
    parameter, which is incompatible with the base-class signatures —
    callers must not use this class polymorphically through `Experiment`.
    """

    def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
                 applied_mutators: List['Mutator'], strategy: 'BaseStrategy',
                 tca: 'TraceClassArguments' = None):
        self.config: ExperimentConfig = None
        self.port: Optional[int] = None
        self.base_model = base_model
        self.trainer = trainer
        self.applied_mutators = applied_mutators
        self.strategy = strategy
        # NOTE(review): crashes with AttributeError when `tca` is left at its
        # default of None — confirm whether the argument is really optional.
        self.recorded_module_args = tca.recorded_arguments  # FIXME: remove this argument
        self._dispatcher = RetiariiAdvisor()
        self._proc: Optional[Popen] = None
        self._pipe: Optional[Pipe] = None

    def _start_strategy(self):
        """Convert the base model to graph IR and launch the strategy in a background thread."""
        import torch
        script_module = torch.jit.script(self.base_model)
        base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
        # The trainer's construction arguments must have been recorded via `tca`.
        assert id(self.trainer) in self.recorded_module_args
        trainer_config = self.recorded_module_args[id(self.trainer)]
        _logger.info('Starting strategy...')
        Thread(target=self.strategy.run, args=(base_model, self.applied_mutators, trainer_config)).start()
        _logger.info('Strategy started!')

    def start(self, config: ExperimentConfig, port: int = 8080, debug: bool = False) -> None:
        """
        Start the experiment in background.
        This method will raise exception on failure.
        If it returns, the experiment should have been successfully started.

        Parameters
        ----------
        config
            Configuration to launch with (`self.config` is not read here).
        port
            The port of web UI.
        debug
            Whether to start in debug mode.
        """
        if debug:
            logging.getLogger('nni').setLevel(logging.DEBUG)
        self._proc, self._pipe = launcher.start_experiment(config, port, debug)
        assert self._proc is not None
        assert self._pipe is not None
        self.port = port  # port will be None if start up failed
        # dispatcher must be created after pipe initialized
        # the logic to launch dispatcher in background should be refactored into dispatcher api
        Thread(target=self._dispatcher.run).start()
        self._start_strategy()
        # TODO: register experiment management metadata

    def stop(self) -> None:
        """
        Stop background experiment.
        """
        self._proc.kill()
        self._pipe.close()
        self.port = None
        self._proc = None
        self._pipe = None

    def run(self, config: ExperimentConfig, port: int = 8080, debug: bool = False) -> str:
        """
        Run the experiment.
        This function will block until experiment finish or error.
        """
        self.start(config, port, debug)
        try:
            # Poll the REST server every 10 seconds until a terminal status.
            while True:
                time.sleep(10)
                status = self.get_status()
                if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
                    return status
        finally:
            self.stop()

    def get_status(self) -> str:
        """Query the current experiment status from the REST server."""
        if self.port is None:
            raise RuntimeError('Experiment is not running')
        resp = rest.get(self.port, '/check-status')
        return resp['status']
import contextlib
from pathlib import Path
import socket
from subprocess import Popen
import sys
import time
from typing import Optional, Tuple
import nni.runtime.protocol
import nni_node
from .config import ExperimentConfig
from . import management
from .pipe import Pipe
from . import rest
def start_experiment(config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
    """
    Validate the config, launch the NNI REST server and dispatcher pipe,
    and initialize the experiment.

    Parameters
    ----------
    config
        Fully populated experiment configuration; validated here.
    port
        Port for the web UI / REST API.
    debug
        Whether to run the REST server with debug logging.

    Returns
    -------
    Tuple[Popen, Pipe]
        The REST server process and the connected dispatcher pipe.

    Raises
    ------
    Exception
        Re-raises whatever failed during startup, after best-effort cleanup
        of any partially created process/pipe.
    """
    pipe = None
    proc = None
    config.validate()
    _ensure_port_idle(port)
    if config._training_service == 'pai':
        _ensure_port_idle(port + 1, 'OpenPAI requires an additional port')
    exp_id = management.generate_experiment_id()
    try:
        print(f'Creating experiment {exp_id}...')
        pipe = Pipe(exp_id)
        proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
        # The dispatcher talks to the REST server through this pipe; wire it
        # into the runtime protocol before anything tries to use it.
        pipe_file = pipe.connect()
        nni.runtime.protocol._in_file = pipe_file
        nni.runtime.protocol._out_file = pipe_file
        print('Starting web server...')  # bug fix: message read "Statring"
        _check_rest_server(port)
        print('Setting up...')
        _init_experiment(config, port, debug)  # todo: kill on fail
        return proc, pipe
    except Exception as e:
        print('Create experiment failed')
        # Best-effort cleanup of whatever was created before the failure.
        if proc is not None:
            with contextlib.suppress(Exception):
                proc.kill()
        if pipe is not None:
            with contextlib.suppress(Exception):
                pipe.close()
        raise e
def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
sock = socket.socket()
if sock.connect_ex(('localhost', port)) == 0:
sock.close()
message = f'(message)' if message else ''
raise RuntimeError(f'Port {port} is not idle {message}')
def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str) -> Popen:
    """Spawn the NNI manager (node.js REST server) and return its process handle."""
    node_dir = Path(nni_node.__path__[0])
    node_executable = 'node.exe' if sys.platform == 'win32' else 'node'
    cmd = [
        str(node_dir / node_executable),
        '--max-old-space-size=4096',
        str(node_dir / 'main.js'),
    ]
    # Command-line options understood by main.js.
    options = {
        'port': port,
        'mode': config._training_service,
        'experiment_id': experiment_id,
        'start_mode': 'new',
        'log_level': 'debug' if debug else 'info',
        'dispatcher_pipe': pipe_path,
    }
    for key, value in options.items():
        cmd += ['--' + key, str(value)]
    return Popen(cmd, cwd=node_dir)
def _check_rest_server(port: int, retry: int = 10) -> None:
    """
    Poll the REST server until it responds, retrying about once per second.

    The final request is issued outside the retry loop without suppression,
    so the underlying exception propagates if the server never comes up.
    """
    attempts_left = retry
    while attempts_left > 0:
        try:
            rest.get(port, '/check-status')
            return
        except Exception:
            time.sleep(1)
            attempts_left -= 1
    rest.get(port, '/check-status')
def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
    """Push cluster metadata, then create the experiment via the REST API."""
    metadata = config.cluster_metadata_json()
    rest.put(port, '/experiment/cluster-metadata', metadata)
    rest.post(port, '/experiment', config.experiment_config_json())
from pathlib import Path
import random
import string
def generate_experiment_id() -> str:
    """Return a random 8-character ID of distinct lowercase letters and digits."""
    alphabet = string.ascii_lowercase + string.digits
    return ''.join(random.sample(alphabet, 8))
def create_experiment_directory(experiment_id: str) -> Path:
    """Create (if needed) and return ``~/nni-experiments/<experiment_id>``."""
    experiment_dir = Path.home() / 'nni-experiments' / experiment_id
    experiment_dir.mkdir(parents=True, exist_ok=True)
    return experiment_dir
# TODO: port shangning's work here, and use it in Experiment.start()/.stop()
...@@ -29,7 +29,7 @@ import requests ...@@ -29,7 +29,7 @@ import requests
import yaml import yaml
__all__ = [ __all__ = [
'Experiment', 'ExternalExperiment',
'TrialResult', 'TrialResult',
'TrialMetricData', 'TrialMetricData',
'TrialHyperParameters', 'TrialHyperParameters',
...@@ -229,7 +229,7 @@ class TrialJob: ...@@ -229,7 +229,7 @@ class TrialJob:
.format(self.trialJobId, self.status, self.hyperParameters, self.logPath, .format(self.trialJobId, self.status, self.hyperParameters, self.logPath,
self.startTime, self.endTime, self.finalMetricData, self.stderrPath) self.startTime, self.endTime, self.finalMetricData, self.stderrPath)
class Experiment: class ExternalExperiment:
def __init__(self): def __init__(self):
self._endpoint = None self._endpoint = None
self._exp_id = None self._exp_id = None
......
from io import BufferedIOBase
import os
import sys
if sys.platform == 'win32':
    import _winapi  # bug fix: the CPython module is `_winapi`; `_win32` does not exist
    import msvcrt

    class WindowsPipe:
        """Named-pipe server used to talk to the NNI dispatcher on Windows."""

        def __init__(self, experiment_id: str):
            self.path: str = r'\\.\pipe\nni-' + experiment_id
            self.file = None
            self._handle = _winapi.CreateNamedPipe(
                self.path,
                _winapi.PIPE_ACCESS_DUPLEX,
                _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_WAIT,
                1,      # max instances
                8192,   # output buffer size
                8192,   # input buffer size
                0,      # default timeout
                _winapi.NULL
            )

        def connect(self) -> BufferedIOBase:
            """Block until the dispatcher connects, then return a duplex binary file."""
            _winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
            # Bug fixes: open_osfhandle() requires a flags argument, and
            # 'rwb' is not a valid fdopen() mode ('w+b' = read/write binary).
            fd = msvcrt.open_osfhandle(self._handle, 0)
            self.file = os.fdopen(fd, 'w+b')
            return self.file

        def close(self) -> None:
            """Close the file wrapper (if any) and the pipe handle."""
            if self.file is not None:
                self.file.close()
            _winapi.CloseHandle(self._handle)

    Pipe = WindowsPipe

else:
    import socket

    from . import management

    class UnixPipe:
        """Unix-domain-socket server used to talk to the NNI dispatcher on POSIX."""

        def __init__(self, experiment_id: str):
            self.path: str = str(management.create_experiment_directory(experiment_id) / 'dispatcher-pipe')
            self.file = None
            self._socket = socket.socket(socket.AF_UNIX)
            self._socket.bind(self.path)
            self._socket.listen(1)  # only accepts one connection

        def connect(self) -> BufferedIOBase:
            """Block until the dispatcher connects, then return a duplex binary file."""
            conn, _ = self._socket.accept()
            self.file = conn.makefile('rwb')
            return self.file

        def close(self) -> None:
            """Close the file wrapper (if any), the socket, and remove the socket file."""
            if self.file is not None:
                self.file.close()
            self._socket.close()
            os.unlink(self.path)

    Pipe = UnixPipe
import logging
from typing import Any
import requests
_logger = logging.getLogger(__name__)
# Base URL of the local NNI REST API; formatted with (port, api_path).
url_template = 'http://localhost:{}/api/v1/nni{}'
# Timeout in seconds applied to every REST request.
timeout = 20
def get(port: int, api: str) -> Any:
    """Issue a GET to the local NNI REST API and return the decoded JSON body."""
    response = requests.get(url_template.format(port, api), timeout=timeout)
    if not response.ok:
        _logger.error('rest request GET %s %s failed: %s %s', port, api, response.status_code, response.text)
    response.raise_for_status()
    return response.json()
def post(port: int, api: str, data: Any) -> Any:
    """POST `data` as JSON to the local NNI REST API and return the decoded JSON body."""
    response = requests.post(url_template.format(port, api), json=data, timeout=timeout)
    if not response.ok:
        _logger.error('rest request POST %s %s failed: %s %s', port, api, response.status_code, response.text)
    response.raise_for_status()
    return response.json()
def put(port: int, api: str, data: Any) -> None:
    """PUT `data` as JSON to the local NNI REST API; returns nothing (unlike get/post)."""
    url = url_template.format(port, api)
    resp = requests.put(url, json=data, timeout=timeout)
    if not resp.ok:
        # Consistency fix: include the response body in the log, as GET/POST do.
        _logger.error('rest request PUT %s %s failed: %s %s', port, api, resp.status_code, resp.text)
    resp.raise_for_status()
...@@ -2,4 +2,3 @@ from .operation import Operation ...@@ -2,4 +2,3 @@ from .operation import Operation
from .graph import * from .graph import *
from .execution import * from .execution import *
from .mutator import * from .mutator import *
from .model_apis import nn
...@@ -15,7 +15,7 @@ def model_to_pytorch_script(model: Model, placement = None) -> str: ...@@ -15,7 +15,7 @@ def model_to_pytorch_script(model: Model, placement = None) -> str:
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement = placement) import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement = placement)
graphs.append(graph_code) graphs.append(graph_code)
total_pkgs.update(import_pkgs) total_pkgs.update(import_pkgs)
# TODO: set correct PATH for the packages (after launch refactor) # FIXME: set correct PATH for the packages (after launch refactor)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs]) pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip() return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
...@@ -54,6 +54,22 @@ def _format_inputs(node: Node) -> List[str]: ...@@ -54,6 +54,22 @@ def _format_inputs(node: Node) -> List[str]:
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot)) inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
return inputs return inputs
def _remove_prefix(names, graph_name):
"""
variables name (full name space) is too long,
shorten the name by removing the prefix ```graph_name```
"""
if isinstance(names, list):
converted_names = []
for name in names:
if name.startswith(graph_name):
converted_names.append(name[len(graph_name):])
else:
converted_names.append(name)
return converted_names
else:
return names[len(graph_name):] if names.startswith(graph_name) else names
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str: def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str:
nodes = graph.topo_sort() # FIXME: topological sort is needed here nodes = graph.topo_sort() # FIXME: topological sort is needed here
...@@ -67,7 +83,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s ...@@ -67,7 +83,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s
pkg_name = node.operation.get_import_pkg() pkg_name = node.operation.get_import_pkg()
if pkg_name is not None: if pkg_name is not None:
import_pkgs.add(pkg_name) import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(node.name) node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
if node_code is not None: if node_code is not None:
if placement and node in placement and len(node_code) > 0: if placement and node in placement and len(node_code) > 0:
node_codes.append(f"{node_code}.to('{placement[node].device}')") node_codes.append(f"{node_code}.to('{placement[node].device}')")
...@@ -77,6 +93,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s ...@@ -77,6 +93,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s
if graph.input_node.operation.io_names is None: if graph.input_node.operation.io_names is None:
input_code = '*_inputs' input_code = '*_inputs'
else: else:
for name in graph.input_node.operation.io_names:
assert not name.startswith(graph_name)
input_code = ', '.join(graph.input_node.operation.io_names) input_code = ', '.join(graph.input_node.operation.io_names)
edge_codes = [] edge_codes = []
...@@ -84,9 +102,12 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s ...@@ -84,9 +102,12 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s
for node in sorted_nodes: for node in sorted_nodes:
if node.operation: if node.operation:
inputs = _format_inputs(node) inputs = _format_inputs(node)
edge_codes.append(node.operation.to_forward_code(node.name, node.name, inputs)) inputs = _remove_prefix(inputs, graph_name)
node_name = _remove_prefix(node.name, graph_name)
edge_codes.append(node.operation.to_forward_code(node_name, node_name, inputs))
output_names = _format_inputs(graph.output_node) output_names = _format_inputs(graph.output_node)
output_names = _remove_prefix(output_names, graph_name)
if not output_names: if not output_names:
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node)) raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
output_code = ', '.join(output_names) output_code = ', '.join(output_names)
...@@ -95,7 +116,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s ...@@ -95,7 +116,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> s
return import_pkgs, _PyTorchModelTemplate.format( return import_pkgs, _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else graph_name), graph_name=('Graph' if graph_name == '_graph' else graph_name),
inputs=input_code, inputs=input_code,
outputs=', '.join(output_names), outputs=output_code,
nodes=linebreak.join(node_codes), nodes=linebreak.join(node_codes),
edges=linebreak.join(edge_codes) edges=linebreak.join(edge_codes)
) )
...@@ -109,6 +130,7 @@ import torch.nn as nn ...@@ -109,6 +130,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
# FIXME: remove these two lines
import sys import sys
sys.path.append("test/convert_test") sys.path.append("test/convert_test")
......
This diff is collapsed.
MODULE_EXCEPT_LIST = ['Sequential'] MODULE_EXCEPT_LIST = ['Sequential']
RETIARII_BASE_OPS = ['Placeholder']
class Type: class Type:
"""Node Type class """Node Type class
...@@ -7,12 +7,25 @@ class Type: ...@@ -7,12 +7,25 @@ class Type:
Attr = 'Attr' Attr = 'Attr'
Constant = 'Constant' Constant = 'Constant'
ListConstruct = 'ListConstruct' ListConstruct = 'ListConstruct'
LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
# deal with aten op # deal with aten op
BasicOpsPT = { BasicOpsPT = {
'aten::mean': 'Mean', 'aten::mean': 'Mean',
'aten::relu': 'Relu', 'aten::relu': 'Relu',
'aten::add': 'Add' 'aten::add': 'Add',
'aten::__getitem__': 'getitem',
'aten::append': 'Append',
'aten::len': 'Len',
'aten::slice': 'Slice',
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View'
} }
BasicOpsTF = {} BasicOpsTF = {}
\ No newline at end of file
def build_full_name(prefix, name, seq=None): def build_full_name(prefix, name, seq=None):
if isinstance(name, list):
name = '__'.join(name)
if seq is None: if seq is None:
return '{}__{}'.format(prefix, name) return '{}__{}'.format(prefix, name)
else: else:
......
...@@ -34,6 +34,6 @@ def convert_to_visualize(graph_ir, vgraph): ...@@ -34,6 +34,6 @@ def convert_to_visualize(graph_ir, vgraph):
subgraph.edge(src, dst) subgraph.edge(src, dst)
def visualize_model(graph_ir): def visualize_model(graph_ir):
vgraph = graphviz.Digraph('G', filename='vgraph', format='png') vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
convert_to_visualize(graph_ir, vgraph) convert_to_visualize(graph_ir, vgraph)
vgraph.render() vgraph.render()
\ No newline at end of file
...@@ -13,8 +13,7 @@ _execution_engine = None ...@@ -13,8 +13,7 @@ _execution_engine = None
_default_listener = None _default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener', __all__ = ['get_execution_engine', 'get_and_register_default_listener',
'submit_models', 'wait_models', 'query_available_resources', 'submit_models', 'wait_models', 'query_available_resources']
'get_base_model_ir', 'get_specified_mutators', 'get_trainer']
def get_execution_engine() -> BaseExecutionEngine: def get_execution_engine() -> BaseExecutionEngine:
...@@ -45,27 +44,6 @@ def _get_search_space() -> 'Dict': ...@@ -45,27 +44,6 @@ def _get_search_space() -> 'Dict':
break break
return engine.get_search_space() return engine.get_search_space()
def get_base_model_ir() -> 'Model':
search_space = _get_search_space()
return Model._load(search_space['base_model_ir'])
def get_specified_mutators() -> List['Mutator']:
search_space = _get_search_space()
applied_mutators = []
for each in search_space['applied_mutators']:
spec = importlib.util.spec_from_file_location("module.name", each['filepath'])
m = importlib.util.module_from_spec(spec)
spec.loader.exec_module(m)
#m.BlockMutator()
class_constructor = getattr(m, each['classname'])
mutator = class_constructor(**each['args'])
applied_mutators.append(mutator)
return applied_mutators
def get_trainer() -> 'BaseTrainer':
search_space = _get_search_space()
return search_space['training_approach']
def submit_models(*models: Model) -> None: def submit_models(*models: Model) -> None:
engine = get_execution_engine() engine = get_execution_engine()
get_and_register_default_listener(engine) get_and_register_default_listener(engine)
......
...@@ -106,6 +106,7 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -106,6 +106,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
Initialize the model, hand it over to trainer. Initialize the model, hand it over to trainer.
""" """
graph_data = BaseGraphData.load(receive_trial_parameters()) graph_data = BaseGraphData.load(receive_trial_parameters())
# FIXME: update this part to dump code to a correct path!!!
with open('_generated_model.py', 'w') as f: with open('_generated_model.py', 'w') as f:
f.write(graph_data.model_script) f.write(graph_data.model_script)
trainer_cls = utils.import_(graph_data.training_module) trainer_cls = utils.import_(graph_data.training_module)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment