"...composable_kernel_onnxruntime.git" did not exist on "34c661e71cc8cf4753843a58786c8f6211ec5e22"
Commit 168d74e4 authored by Yuge Zhang, committed by QuanluZhang

Add example for customized advisor and some refactoring (#1569)

parent f1210a9c
# **How To** - Customize Your Own Advisor
*Warning: API is subject to change in future releases.*

Advisor targets the scenario where the AutoML algorithm needs the methods of both a tuner and an assessor. An advisor is similar to a tuner in that it receives trial parameter requests and final results and generates trial parameters. It is also similar to an assessor in that it receives intermediate results and trials' end states, and can send commands to kill trials. Note that if you use an advisor, you cannot use a tuner and an assessor at the same time.
To implement a customized advisor, a user only needs to follow the three steps below.
**1. Define an Advisor inheriting from the MsgDispatcherBase class.** For example:
```python
from nni.msg_dispatcher_base import MsgDispatcherBase

class CustomizedAdvisor(MsgDispatcherBase):
    ...
```
**2. Implement the methods with prefix `handle_` except `handle_request`.** You might find the [docs](https://nni.readthedocs.io/en/latest/sdk_reference.html#nni.msg_dispatcher_base.MsgDispatcherBase) for `MsgDispatcherBase` helpful. You can also refer to the implementation of Hyperband ([src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py](https://github.com/Microsoft/nni/tree/master/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py)) as a reference for how to implement these methods. A minimal skeleton is sketched below.
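Here is a rough skeleton of what these methods look like, modeled on the `DummyAdvisor` example listed later on this page. This is only an illustrative sketch, not a complete or recommended advisor; in particular, the parameter sampling is left empty:

```python
import json_tricks

from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.protocol import CommandType, send


class CustomizedAdvisor(MsgDispatcherBase):
    def handle_initialize(self, data):
        # `data` is the search space; acknowledge initialization when done.
        self.handle_update_search_space(data)
        self.parameters_count = 0
        send(CommandType.Initialized, '')

    def handle_update_search_space(self, data):
        self.searchspace_json = data

    def handle_request_trial_jobs(self, data):
        # `data` is the number of trials requested; send one NewTrialJob per trial.
        for _ in range(data):
            self.parameters_count += 1
            new_trial = {
                "parameter_id": self.parameters_count,
                "parameters": {},  # sample these from self.searchspace_json
                "parameter_source": "algorithm"
            }
            send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))

    def handle_report_metric_data(self, data):
        pass  # inspect intermediate/final metrics; may send CommandType.KillTrialJob

    def handle_trial_end(self, data):
        pass  # react to a trial finishing

    def handle_import_data(self, data):
        pass  # optionally import data from a previous experiment

    def handle_add_customized_trial(self, data):
        pass  # experimental API, optional
```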
**3. Configure your customized Advisor in the experiment YAML config file.**

Similar to tuners and assessors, NNI needs to locate your customized advisor class and instantiate it, so you need to specify the location of the class and pass literal values as parameters to its `__init__` constructor.
```yaml
advisor:
  # codeDir, classFileName and className go here, as in the example config shown below
  classArgs:
    arg1: value1
```
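For reference, the `classArgs` entries are passed as keyword arguments to the advisor's constructor, so the YAML above corresponds roughly to the following (a sketch using the placeholder names from the YAML; `arg1` is not a real parameter):

```python
from nni.msg_dispatcher_base import MsgDispatcherBase


class CustomizedAdvisor(MsgDispatcherBase):
    def __init__(self, arg1):
        super(CustomizedAdvisor, self).__init__()
        self.arg1 = arg1  # receives "value1" from classArgs in the YAML above
```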
## Example
Here we provide an [example](../../../examples/tuners/mnist_keras_customized_advisor). The files that make up this example, together with the related SDK refactoring, are listed below.
@@ -50,6 +50,9 @@ Assessor
Advisor
------------------------
.. autoclass:: nni.msg_dispatcher_base.MsgDispatcherBase
    :members:

.. autoclass:: nni.hyperband_advisor.hyperband_advisor.Hyperband
    :members:
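
# Experiment configuration for the customized advisor example
# (presumably config.yml under examples/tuners/mnist_keras_customized_advisor).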
authorName: default
experimentName: example_customized_advisor
trialConcurrency: 4
maxExecDuration: 1h
maxTrialNum: 200
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
advisor:
  codeDir: .
  classFileName: dummy_advisor.py
  className: DummyAdvisor
  classArgs:
    k: 3
trial:
  command: python3 mnist_keras.py --epochs 100 --num_train 600 --num_test 100
  codeDir: .
  gpuNum: 0
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
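
# dummy_advisor.py -- the customized advisor referenced by `classFileName` in the config above.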
import logging
from collections import defaultdict

import json_tricks
import numpy as np

from nni import parameter_expressions as param
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.protocol import CommandType, send
from nni.utils import MetricType

logger = logging.getLogger('customized_advisor')

class DummyAdvisor(MsgDispatcherBase):
    """WARNING: Advisor API is subject to change in future releases.

    This advisor creates a new trial when validation accuracy of any one of the trials just dropped.
    The trial is killed if the validation accuracy doesn't improve for at least k last-reported metrics.
    To demonstrate the high flexibility of writing advisors, we don't use tuners or the standard definition of
    search space. This is just a demo to customize an advisor. It's not intended to make any sense.
    """

    def __init__(self, k=3):
        super(DummyAdvisor, self).__init__()
        self.k = k
        self.random_state = np.random.RandomState()

    def handle_initialize(self, data):
        logger.info("Advisor initialized: {}".format(data))
        self.handle_update_search_space(data)
        self.parameters_count = 0
        self.parameter_best_metric = defaultdict(float)
        self.parameter_cooldown = defaultdict(int)
        send(CommandType.Initialized, '')

    def _send_new_trial(self):
        self.parameters_count += 1
        new_trial = {
            "parameter_id": self.parameters_count,
            "parameters": {
                "optimizer": param.choice(self.searchspace_json["optimizer"], self.random_state),
                "learning_rate": param.loguniform(self.searchspace_json["learning_rate"][0],
                                                  self.searchspace_json["learning_rate"][1],
                                                  self.random_state)
            },
            "parameter_source": "algorithm"
        }
        logger.info("New trial sent: {}".format(new_trial))
        send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))

    def handle_request_trial_jobs(self, data):
        logger.info("Request trial jobs: {}".format(data))
        for _ in range(data):
            self._send_new_trial()

    def handle_update_search_space(self, data):
        logger.info("Search space update: {}".format(data))
        self.searchspace_json = data

    def handle_trial_end(self, data):
        logger.info("Trial end: {}".format(data))  # do nothing

    def handle_report_metric_data(self, data):
        logger.info("Metric reported: {}".format(data))
        if data['type'] == MetricType.REQUEST_PARAMETER:
            raise ValueError("Request parameter not supported")
        elif data["type"] == MetricType.PERIODICAL:
            parameter_id = data["parameter_id"]
            if data["value"] > self.parameter_best_metric[parameter_id]:
                self.parameter_best_metric[parameter_id] = data["value"]
                self.parameter_cooldown[parameter_id] = 0
            else:
                self.parameter_cooldown[parameter_id] += 1
                logger.info("Accuracy dropped, cooldown {}, sending a new trial".format(
                    self.parameter_cooldown[parameter_id]))
                self._send_new_trial()
                if self.parameter_cooldown[parameter_id] >= self.k:
                    logger.info("Send kill signal to {}".format(data))
                    send(CommandType.KillTrialJob, json_tricks.dumps(data["trial_job_id"]))
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
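
# mnist_keras.py -- the trial code launched by the `trial` command in the config above.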
import argparse
import logging
import os

import keras
import numpy as np
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from keras.models import Sequential

import nni

LOG = logging.getLogger('mnist_keras')
K.set_image_data_format('channels_last')
TENSORBOARD_DIR = os.environ['NNI_OUTPUT_DIR']

H, W = 28, 28
NUM_CLASSES = 10

def create_mnist_model(hyper_params, input_shape=(H, W, 1), num_classes=NUM_CLASSES):
    """
    Create simple convolutional model
    """
    layers = [
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(100, activation='relu'),
        Dense(num_classes, activation='softmax')
    ]

    model = Sequential(layers)

    if hyper_params['optimizer'] == 'Adam':
        optimizer = keras.optimizers.Adam(lr=hyper_params['learning_rate'])
    else:
        optimizer = keras.optimizers.SGD(lr=hyper_params['learning_rate'], momentum=0.9)
    model.compile(loss=keras.losses.categorical_crossentropy, optimizer=optimizer, metrics=['accuracy'])

    return model


def load_mnist_data(args):
    """
    Load MNIST dataset
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = (np.expand_dims(x_train, -1).astype(np.float) / 255.)[:args.num_train]
    x_test = (np.expand_dims(x_test, -1).astype(np.float) / 255.)[:args.num_test]
    y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)[:args.num_train]
    y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)[:args.num_test]

    LOG.debug('x_train shape: %s', (x_train.shape,))
    LOG.debug('x_test shape: %s', (x_test.shape,))

    return x_train, y_train, x_test, y_test

class SendMetrics(keras.callbacks.Callback):
    """
    Keras callback to send metrics to NNI framework
    """
    def on_epoch_end(self, epoch, logs={}):
        """
        Run on end of each epoch
        """
        LOG.debug(logs)
        # Should this be val_acc or val_accuracy? Keras behavior seems inconsistent.
        nni.report_intermediate_result(logs["val_accuracy"])


def train(args, params):
    """
    Train model
    """
    x_train, y_train, x_test, y_test = load_mnist_data(args)
    model = create_mnist_model(params)

    model.fit(x_train, y_train, batch_size=args.batch_size, epochs=args.epochs, verbose=1,
              validation_data=(x_test, y_test), callbacks=[SendMetrics(), TensorBoard(log_dir=TENSORBOARD_DIR)])

    _, acc = model.evaluate(x_test, y_test, verbose=0)
    LOG.debug('Final result is: %g', acc)
    nni.report_final_result(acc)


def generate_default_params():
    """
    Generate default hyper parameters
    """
    return {
        'optimizer': 'Adam',
        'learning_rate': 0.001
    }


if __name__ == '__main__':
    PARSER = argparse.ArgumentParser()
    PARSER.add_argument("--batch_size", type=int, default=200, help="batch size", required=False)
    PARSER.add_argument("--epochs", type=int, default=10, help="Train epochs", required=False)
    PARSER.add_argument("--num_train", type=int, default=60000,
                        help="Number of train samples to be used, maximum 60000", required=False)
    PARSER.add_argument("--num_test", type=int, default=10000, help="Number of test samples to be used, maximum 10000",
                        required=False)

    ARGS, UNKNOWN = PARSER.parse_known_args()

    # get parameters from tuner
    RECEIVED_PARAMS = nni.get_next_parameter()
    LOG.debug(RECEIVED_PARAMS)
    PARAMS = generate_default_params()
    PARAMS.update(RECEIVED_PARAMS)

    # train
    train(ARGS, PARAMS)
{
    "README": "To demonstrate the flexibility, this search space does not follow the standard definition.",
    "optimizer": ["Adam", "SGD"],
    "learning_rate": [0.001, 0.1]
}
@@ -21,18 +21,17 @@
hyperband_advisor.py
"""
import copy
import logging
import math
import sys

import json_tricks
import numpy as np

from nni.common import multi_phase_enabled
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.protocol import CommandType, send
from nni.utils import NodeType, OptimizeMode, MetricType, extract_scalar_reward
import nni.parameter_expressions as parameter_expressions
_logger = logging.getLogger(__name__)
@@ -53,6 +52,7 @@ def create_parameter_id():
    _next_parameter_id += 1
    return _next_parameter_id - 1


def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-1):
    """Create a full id for a specific bracket's hyperparameter configuration
@@ -77,6 +77,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
                          increased_id])
    return params_id


def json2parameter(ss_spec, random_state):
    """Randomly generate values for hyperparameters from hyperparameter space i.e., x.
@@ -100,7 +101,7 @@ def json2parameter(ss_spec, random_state):
                _index = random_state.randint(len(_value))
                chosen_params = json2parameter(ss_spec[NodeType.VALUE][_index], random_state)
            else:
                chosen_params = eval('parameter_expressions.' +  # pylint: disable=eval-used
                                     _type)(*(_value + [random_state]))
        else:
            chosen_params = dict()
@@ -114,6 +115,7 @@ def json2parameter(ss_spec, random_state):
        chosen_params = copy.deepcopy(ss_spec)
    return chosen_params


class Bracket():
    """A bracket in Hyperband, all the information of a bracket is managed by an instance of this class
@@ -137,12 +139,12 @@ class Bracket():
        self.bracket_id = s
        self.s_max = s_max
        self.eta = eta
        self.n = math.ceil((s_max + 1) * (eta ** s) / (s + 1) - _epsilon)  # pylint: disable=invalid-name
        self.r = R / eta ** s  # pylint: disable=invalid-name
        self.i = 0
        self.hyper_configs = []  # [ {id: params}, {}, ... ]
        self.configs_perf = []  # [ {id: [seq, acc]}, {}, ... ]
        self.num_configs_to_run = []  # [ n, n, n, ... ]
        self.num_finished_configs = []  # [ n, n, n, ... ]
        self.optimize_mode = OptimizeMode(optimize_mode)
        self.no_more_trial = False
@@ -153,7 +155,7 @@ class Bracket():
    def get_n_r(self):
        """return the values of n and r for the next round"""
        return math.floor(self.n / self.eta ** self.i + _epsilon), math.floor(self.r * self.eta ** self.i + _epsilon)
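    # A rough numeric illustration (added for clarity, not part of the original source): with the
    # defaults R=60 and eta=3, the bracket s=3 starts from n=27 configurations at r=60/27 (about 2.2)
    # resource units each, and successive get_n_r() calls for i=0..3 yield (27, 2), (9, 6), (3, 20), (1, 60).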

    def increase_i(self):
        """i means the ith round. Increase i by 1"""
@@ -185,7 +187,6 @@ class Bracket():
        else:
            self.configs_perf[i][parameter_id] = [seq, value]

    def inform_trial_end(self, i):
        """If the trial is finished and the corresponding round (i.e., i) has all its trials finished,
        it will choose the top k trials for the next round (i.e., i+1)
@@ -195,16 +196,17 @@ class Bracket():
        i: int
            the ith round
        """
        global _KEY  # pylint: disable=global-statement
        self.num_finished_configs[i] += 1
        _logger.debug('bracket id: %d, round: %d %d, finished: %d, all: %d', self.bracket_id, self.i, i,
                      self.num_finished_configs[i], self.num_configs_to_run[i])
        if self.num_finished_configs[i] >= self.num_configs_to_run[i] \
                and self.no_more_trial is False:
            # choose candidate configs from finished configs to run in the next round
            assert self.i == i + 1
            this_round_perf = self.configs_perf[i]
            if self.optimize_mode is OptimizeMode.Maximize:
                sorted_perf = sorted(this_round_perf.items(), key=lambda kv: kv[1][1], reverse=True)  # reverse
            else:
                sorted_perf = sorted(this_round_perf.items(), key=lambda kv: kv[1][1])
            _logger.debug('bracket %s next round %s, sorted hyper configs: %s', self.bracket_id, self.i, sorted_perf)
@@ -214,7 +216,7 @@ class Bracket():
            for k in range(next_n):
                params_id = sorted_perf[k][0]
                params = self.hyper_configs[i][params_id]
                params[_KEY] = next_r  # modify r
                # generate new id
                increased_id = params_id.split('_')[-1]
                new_id = create_bracket_parameter_id(self.bracket_id, self.i, increased_id)
@@ -223,7 +225,7 @@ class Bracket():
            return [[key, value] for key, value in hyper_configs.items()]
        return None

    def get_hyperparameter_configurations(self, num, r, searchspace_json, random_state):  # pylint: disable=invalid-name
        """Randomly generate num hyperparameter configurations from search space

        Parameters
@@ -236,7 +238,7 @@ class Bracket():
        list
            a list of hyperparameter configurations. Format: [[key1, value1], [key2, value2], ...]
        """
        global _KEY  # pylint: disable=global-statement
        assert self.i == 0
        hyperparameter_configs = dict()
        for _ in range(num):
@@ -263,6 +265,7 @@ class Bracket():
        self.num_configs_to_run.append(len(hyper_configs))
        self.increase_i()


class Hyperband(MsgDispatcherBase):
    """Hyperband inherits from MsgDispatcherBase rather than Tuner, because it integrates both tuner's functions and assessor's functions.
    This is an implementation that could fully leverage available resources, i.e., high parallelism.
@@ -277,14 +280,15 @@ class Hyperband(MsgDispatcherBase):
    optimize_mode: str
        optimize mode, 'maximize' or 'minimize'
    """

    def __init__(self, R=60, eta=3, optimize_mode='maximize'):
        """B = (s_max + 1)R"""
        super(Hyperband, self).__init__()
        self.R = R  # pylint: disable=invalid-name
        self.eta = eta
        self.brackets = dict()  # dict of Bracket
        self.generated_hyper_configs = []  # all the configs waiting for run
        self.completed_hyper_configs = []  # all the completed configs

        self.s_max = math.floor(math.log(self.R, self.eta) + _epsilon)
        self.curr_s = self.s_max
@@ -302,12 +306,11 @@ class Hyperband(MsgDispatcherBase):
        self.job_id_para_id_map = dict()

    def handle_initialize(self, data):
        """callback for initializing the advisor

        Parameters
        ----------
        data: dict
            search space
        """
        self.handle_update_search_space(data)
        send(CommandType.Initialized, '')
@@ -348,14 +351,8 @@ class Hyperband(MsgDispatcherBase):
        }
        return ret

    def handle_update_search_space(self, data):
        self.searchspace_json = data
        self.random_state = np.random.RandomState()
@@ -42,8 +42,9 @@ We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''


def _sort_history(history):
    ret = []
    for i, _ in enumerate(history):
        if i in history:
            ret.append(history[i])
@@ -51,17 +52,20 @@ def _sort_history(history):
            break
    return ret


# Tuner global variables
_next_parameter_id = 0

_trial_params = {}
'''key: trial job ID; value: parameters'''

_customized_parameter_ids = set()


def _create_parameter_id():
    global _next_parameter_id  # pylint: disable=global-statement
    _next_parameter_id += 1
    return _next_parameter_id - 1


def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
    _trial_params[parameter_id] = params
    ret = {
@@ -77,6 +81,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
        ret['parameter_index'] = 0
    return json_tricks.dumps(ret)


class MsgDispatcher(MsgDispatcherBase):
    def __init__(self, tuner, assessor=None):
        super(MsgDispatcher, self).__init__()
@@ -123,7 +128,7 @@ class MsgDispatcher(MsgDispatcherBase):
    def handle_import_data(self, data):
        """Import additional data for tuning

        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
        """
        self.tuner.import_data(data)
@@ -154,7 +159,8 @@ class MsgDispatcher(MsgDispatcherBase):
                param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
            except NoMoreTrialError:
                param = None
            send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'],
                                                                    parameter_index=data['parameter_index']))
        else:
            raise ValueError('Data type not supported: {}'.format(data['type']))
@@ -188,7 +194,8 @@ class MsgDispatcher(MsgDispatcherBase):
            customized = True
        else:
            customized = False
        self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized,
                                        trial_job_id=data.get('trial_job_id'))

    def _handle_intermediate_metric_data(self, data):
        """Call assessor to process intermediate results
@@ -223,7 +230,8 @@ class MsgDispatcher(MsgDispatcherBase):
            _logger.debug('BAD, kill %s', trial_job_id)
            send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
            # notify tuner
            _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]',
                          dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
            if dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS == 'true':
                self._earlystop_notify_tuner(data)
        else:
@@ -18,7 +18,6 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================

import os
import threading
import logging
@@ -39,7 +38,12 @@ _logger = logging.getLogger(__name__)
QUEUE_LEN_WARNING_MARK = 20

_worker_fast_exit_on_terminate = True

class MsgDispatcherBase(Recoverable):
    """The base class for advisors, in which tuner and assessor are not yet separated.
    Inherit this class to make your own advisor.
    """

    def __init__(self):
        if multi_thread_enabled():
            self.pool = ThreadPool()
@@ -49,7 +53,8 @@ class MsgDispatcherBase(Recoverable):
            self.default_command_queue = Queue()
            self.assessor_command_queue = Queue()
            self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
            self.assessor_worker = threading.Thread(target=self.command_queue_worker,
                                                    args=(self.assessor_command_queue,))
            self.default_worker.start()
            self.assessor_worker.start()
            self.worker_exceptions = []
@@ -72,7 +77,8 @@ class MsgDispatcherBase(Recoverable):
            if multi_thread_enabled():
                result = self.pool.map_async(self.process_command_thread, [(command, data)])
                self.thread_results.append(result)
                if any([thread_result.ready() and not thread_result.successful() for thread_result in
                        self.thread_results]):
                    _logger.debug('Caught thread exception')
                    break
            else:
@@ -112,7 +118,8 @@ class MsgDispatcherBase(Recoverable):
    def enqueue_command(self, command, data):
        """Enqueue command into command queues
        """
        if command == CommandType.TrialEnd or (
                command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'):
            self.assessor_command_queue.put((command, data))
        else:
            self.default_command_queue.put((command, data))
@@ -142,14 +149,14 @@ class MsgDispatcherBase(Recoverable):
        _logger.debug('process_command: command: [{}], data: [{}]'.format(command, data))

        command_handlers = {
            # Tuner commands:
            CommandType.Initialize: self.handle_initialize,
            CommandType.RequestTrialJobs: self.handle_request_trial_jobs,
            CommandType.UpdateSearchSpace: self.handle_update_search_space,
            CommandType.ImportData: self.handle_import_data,
            CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial,

            # Tuner/Assessor commands:
            CommandType.ReportMetricData: self.handle_report_metric_data,
            CommandType.TrialEnd: self.handle_trial_end,
@@ -163,22 +170,88 @@ class MsgDispatcherBase(Recoverable):
        pass

    def handle_initialize(self, data):
        """Initialize search space and tuner, if any.
        This method is meant to be called only once for each experiment. After this method is called,
        the dispatcher should `send(CommandType.Initialized, '')` to set the experiment status to "INITIALIZED".

        Parameters
        ----------
        data: dict
            search space
        """
        raise NotImplementedError('handle_initialize not implemented')

    def handle_request_trial_jobs(self, data):
        """The message dispatcher is asked to generate `data` trial jobs.
        These trial jobs should be sent via `send(CommandType.NewTrialJob, json_tricks.dumps(parameter))`,
        where `parameter` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
        Semantically, the message dispatcher should do this `send` exactly `data` times.

        The JSON sent by this method should follow the format of

            {
                "parameter_id": 42,
                "parameters": {
                    // this will be received by trial
                },
                "parameter_source": "algorithm"  // optional
            }

        Parameters
        ----------
        data: int
            number of trial jobs
        """
        raise NotImplementedError('handle_request_trial_jobs not implemented')

    def handle_update_search_space(self, data):
        """This method will be called when the search space is updated.
        It's recommended to call this method in `handle_initialize` to initialize the search space.
        *No need to* notify NNI Manager when this update is done.

        Parameters
        ----------
        data: dict
            search space
        """
        raise NotImplementedError('handle_update_search_space not implemented')

    def handle_import_data(self, data):
        """Import previous data when the experiment is resumed.

        Parameters
        ----------
        data: list
            a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
        """
        raise NotImplementedError('handle_import_data not implemented')

    def handle_add_customized_trial(self, data):
        """Experimental API. Not recommended for usage.
        """
        raise NotImplementedError('handle_add_customized_trial not implemented')

    def handle_report_metric_data(self, data):
        """Called when metric data is reported or new parameters are requested (for multiphase).
        When new parameters are requested, this method should send a new parameter.

        Parameters
        ----------
        data: dict
            a dict which contains 'parameter_id', 'value', 'trial_job_id', 'type', 'sequence'.
            type: can be `MetricType.REQUEST_PARAMETER`, `MetricType.FINAL` or `MetricType.PERIODICAL`.
            `REQUEST_PARAMETER` is used to request new parameters for a multiphase trial job. In this case,
            the dict will contain additional keys: `trial_job_id`, `parameter_index`. Refer to `msg_dispatcher.py`
            for an example.

        Raises
        ------
        ValueError
            Data type not supported
        """
        raise NotImplementedError('handle_report_metric_data not implemented')

    def handle_trial_end(self, data):
        """Called when the state of one of the trials is changed.

        Parameters
        ----------
        data: dict
            a dict with keys: trial_job_id, event, hyper_params.
            trial_job_id: the id generated by the training service.
            event: the job's state.
            hyper_params: the string that was sent by the message dispatcher during the creation of trials.
        """
        raise NotImplementedError('handle_trial_end not implemented')
@@ -28,9 +28,9 @@ from io import BytesIO
import json
from unittest import TestCase, main
_trials = []
_end_trials = []


class NaiveAssessor(Assessor):
    def assess_trial(self, trial_job_id, trial_history):
@@ -47,12 +47,14 @@ class NaiveAssessor(Assessor):
_in_buf = BytesIO()
_out_buf = BytesIO()


def _reverse_io():
    _in_buf.seek(0)
    _out_buf.seek(0)
    nni.protocol._out_file = _in_buf
    nni.protocol._in_file = _out_buf


def _restore_io():
    _in_buf.seek(0)
    _out_buf.seek(0)
@@ -32,7 +32,7 @@ from unittest import TestCase, main
class NaiveTuner(Tuner):
    def __init__(self):
        self.param = 0
        self.trial_results = []
        self.search_space = None
        self.accept_customized_trials()
@@ -57,12 +57,14 @@ class NaiveTuner(Tuner):
_in_buf = BytesIO()
_out_buf = BytesIO()


def _reverse_io():
    _in_buf.seek(0)
    _out_buf.seek(0)
    nni.protocol._out_file = _in_buf
    nni.protocol._in_file = _out_buf


def _restore_io():
    _in_buf.seek(0)
    _out_buf.seek(0)
@@ -70,7 +72,6 @@ def _restore_io():
    nni.protocol._out_file = _out_buf


class TunerTestCase(TestCase):
    def test_tuner(self):
        _reverse_io()  # now we are sending to Tuner's incoming stream
@@ -94,21 +95,20 @@ class TunerTestCase(TestCase):
        self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')

        _reverse_io()  # now we are receiving from Tuner's outgoing stream
        self._assert_params(0, 2, [], None)
        self._assert_params(1, 4, [], None)

        command, data = receive()  # this one is customized
        data = json.loads(data)
        self.assertIs(command, CommandType.NewTrialJob)
        self.assertEqual(data['parameter_id'], 2)
        self.assertEqual(data['parameter_source'], 'customized')
        self.assertEqual(data['parameters'], {'param': -1})

        self._assert_params(3, 6, [[1, 4, 11, False], [2, -1, 22, True]], {'name': 'SS0'})

        self.assertEqual(len(_out_buf.read()), 0)  # no more commands

    def _assert_params(self, parameter_id, param, trial_results, search_space):
        command, data = receive()
        self.assertIs(command, CommandType.NewTrialJob)