"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "eab0da154b4c8cf68f95ef294844649c9e17ee60"
Commit 168d74e4 authored by Yuge Zhang's avatar Yuge Zhang Committed by QuanluZhang
Browse files

Add example for customized advisor and some refactoring (#1569)

Add example for customized advisor and some refactoring
parent f1210a9c
# **How To** - Customize Your Own Advisor

*Warning: API is subject to change in future releases.*

Advisor targets the scenario where the automl algorithm wants the methods of both tuner and assessor. An advisor is similar to a tuner in that it receives trial parameter requests and final results and generates trial parameters. It is also similar to an assessor in that it receives intermediate results and trials' end states, and can send commands to kill trials. Note that if you use an advisor, you cannot use a tuner or an assessor at the same time.

To implement a customized Advisor, a user only needs to:

1. Define an Advisor inheriting from the MsgDispatcherBase class
1. Implement the methods with prefix `handle_` except `handle_request`
1. Configure your customized Advisor in experiment YAML config file

**1. Define an Advisor inheriting from the MsgDispatcherBase class.** For example:
```python
from nni.msg_dispatcher_base import MsgDispatcherBase

class CustomizedAdvisor(MsgDispatcherBase):
    ...
```
**2. Implement the methods with prefix `handle_` except `handle_request`.** You might find the [docs](https://nni.readthedocs.io/en/latest/sdk_reference.html#nni.msg_dispatcher_base.MsgDispatcherBase) for `MsgDispatcherBase` helpful, and you can refer to the implementation of Hyperband ([src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py](https://github.com/Microsoft/nni/tree/master/src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py)) for a full example; a minimal sketch follows below.
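As a quick orientation, here is a minimal sketch (not from the official docs; the handler semantics follow the `MsgDispatcherBase` docstrings added later in this commit) of the `handle_` methods a customized Advisor typically overrides:

```python
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.protocol import CommandType, send

class CustomizedAdvisor(MsgDispatcherBase):
    def handle_initialize(self, data):
        # `data` is the search space; acknowledge so the experiment enters "INITIALIZED"
        self.handle_update_search_space(data)
        send(CommandType.Initialized, '')

    def handle_update_search_space(self, data):
        self.searchspace_json = data  # remember the search space for later sampling

    def handle_request_trial_jobs(self, data):
        # `data` is the number of trials requested; send one NewTrialJob command per trial
        ...

    def handle_report_metric_data(self, data):
        ...  # react to intermediate/final metrics (and optionally kill trials)

    def handle_trial_end(self, data):
        ...  # clean up per-trial state
```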
**3. Configure your customized Advisor in experiment YAML config file.**

Similar to tuner and assessor, NNI needs to locate your customized Advisor class and instantiate it, so you need to specify the location of the customized Advisor class and pass literal values as parameters to its `__init__` constructor.
```yaml
advisor:
  # ... codeDir, classFileName and className, as in the example config below ...
  classArgs:
    arg1: value1
```
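The entries under `classArgs` are passed as keyword arguments to the Advisor's constructor, so the snippet above amounts to instantiating `CustomizedAdvisor(arg1=value1)`.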
## Example
Here we provide an [example](../../../examples/tuners/mnist_keras_customized_advisor).
In the SDK reference docs:

@@ -50,6 +50,9 @@ Assessor
Advisor
------------------------

.. autoclass:: nni.msg_dispatcher_base.MsgDispatcherBase
    :members:

.. autoclass:: nni.hyperband_advisor.hyperband_advisor.Hyperband
    :members:
The example's experiment configuration (YAML):

authorName: default
experimentName: example_customized_advisor
trialConcurrency: 4
maxExecDuration: 1h
maxTrialNum: 200
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
advisor:
codeDir: .
classFileName: dummy_advisor.py
className: DummyAdvisor
classArgs:
k: 3
trial:
command: python3 mnist_keras.py --epochs 100 --num_train 600 --num_test 100
codeDir: .
gpuNum: 0
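Assuming the YAML above is saved as `config.yml` in the example directory, the experiment can be launched with `nnictl create --config config.yml`.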
`dummy_advisor.py`:

# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import logging
from collections import defaultdict
import json_tricks
import numpy as np
from nni import parameter_expressions as param
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.protocol import CommandType, send
from nni.utils import MetricType
logger = logging.getLogger('customized_advisor')

class DummyAdvisor(MsgDispatcherBase):
    """WARNING: Advisor API is subject to change in future releases.

    This advisor creates a new trial whenever the validation accuracy of any running trial drops,
    and kills a trial if its validation accuracy has not improved over its last k reported metrics.

    To demonstrate the high flexibility of writing advisors, it uses neither tuners nor the standard
    definition of search space. This is just a demo of how to customize an advisor; the tuning
    strategy is not intended to make any sense.
    """
    def __init__(self, k=3):
        super(DummyAdvisor, self).__init__()
        self.k = k
        self.random_state = np.random.RandomState()

    def handle_initialize(self, data):
        logger.info("Advisor initialized: {}".format(data))
        self.handle_update_search_space(data)
        self.parameters_count = 0
        self.parameter_best_metric = defaultdict(float)  # best metric seen so far, per parameter id
        self.parameter_cooldown = defaultdict(int)       # consecutive reports without improvement
        send(CommandType.Initialized, '')

    def _send_new_trial(self):
        self.parameters_count += 1
        new_trial = {
            "parameter_id": self.parameters_count,
            "parameters": {
                "optimizer": param.choice(self.searchspace_json["optimizer"], self.random_state),
                "learning_rate": param.loguniform(self.searchspace_json["learning_rate"][0],
                                                  self.searchspace_json["learning_rate"][1],
                                                  self.random_state)
            },
            "parameter_source": "algorithm"
        }
        logger.info("New trial sent: {}".format(new_trial))
        send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))

    def handle_request_trial_jobs(self, data):
        logger.info("Request trial jobs: {}".format(data))
        for _ in range(data):  # `data` is the number of trial jobs requested
            self._send_new_trial()

    def handle_update_search_space(self, data):
        logger.info("Search space update: {}".format(data))
        self.searchspace_json = data

    def handle_trial_end(self, data):
        logger.info("Trial end: {}".format(data))  # do nothing

    def handle_report_metric_data(self, data):
        logger.info("Metric reported: {}".format(data))
        if data['type'] == MetricType.REQUEST_PARAMETER:
            raise ValueError("Request parameter not supported")
        elif data["type"] == MetricType.PERIODICAL:
            parameter_id = data["parameter_id"]
            if data["value"] > self.parameter_best_metric[parameter_id]:
                self.parameter_best_metric[parameter_id] = data["value"]
                self.parameter_cooldown[parameter_id] = 0
            else:
                self.parameter_cooldown[parameter_id] += 1
                logger.info("Accuracy dropped, cooldown {}, sending a new trial".format(
                    self.parameter_cooldown[parameter_id]))
                self._send_new_trial()
                if self.parameter_cooldown[parameter_id] >= self.k:
                    logger.info("Send kill signal to {}".format(data))
                    send(CommandType.KillTrialJob, json_tricks.dumps(data["trial_job_id"]))
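With the `k: 3` from the configuration above, every drop in validation accuracy immediately spawns a replacement trial, and a trial whose accuracy fails to improve for three consecutive reports is killed.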
`mnist_keras.py` (the trial code):

# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import argparse
import logging
import os
import keras
import numpy as np
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from keras.models import Sequential
import nni
LOG = logging.getLogger('mnist_keras')
K.set_image_data_format('channels_last')
TENSORBOARD_DIR = os.environ['NNI_OUTPUT_DIR']
H, W = 28, 28
NUM_CLASSES = 10
def create_mnist_model(hyper_params, input_shape=(H, W, 1), num_classes=NUM_CLASSES):
    """
    Create simple convolutional model
    """
    layers = [
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(100, activation='relu'),
        Dense(num_classes, activation='softmax')
    ]
    model = Sequential(layers)

    if hyper_params['optimizer'] == 'Adam':
        optimizer = keras.optimizers.Adam(lr=hyper_params['learning_rate'])
    else:
        optimizer = keras.optimizers.SGD(lr=hyper_params['learning_rate'], momentum=0.9)
    model.compile(loss=keras.losses.categorical_crossentropy, optimizer=optimizer, metrics=['accuracy'])

    return model
def load_mnist_data(args):
    """
    Load MNIST dataset
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # normalize to [0, 1] and truncate to the requested number of samples
    x_train = (np.expand_dims(x_train, -1).astype(float) / 255.)[:args.num_train]
    x_test = (np.expand_dims(x_test, -1).astype(float) / 255.)[:args.num_test]
    y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)[:args.num_train]
    y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)[:args.num_test]

    LOG.debug('x_train shape: %s', (x_train.shape,))
    LOG.debug('x_test shape: %s', (x_test.shape,))

    return x_train, y_train, x_test, y_test
class SendMetrics(keras.callbacks.Callback):
    """
    Keras callback to send metrics to NNI framework
    """
    def on_epoch_end(self, epoch, logs={}):
        """
        Run on end of each epoch
        """
        LOG.debug(logs)
        # Keras 2.3 renamed 'val_acc' to 'val_accuracy'; read whichever key is present
        nni.report_intermediate_result(logs.get("val_accuracy", logs.get("val_acc")))
def train(args, params):
    """
    Train model
    """
    x_train, y_train, x_test, y_test = load_mnist_data(args)
    model = create_mnist_model(params)

    model.fit(x_train, y_train, batch_size=args.batch_size, epochs=args.epochs, verbose=1,
              validation_data=(x_test, y_test), callbacks=[SendMetrics(), TensorBoard(log_dir=TENSORBOARD_DIR)])

    _, acc = model.evaluate(x_test, y_test, verbose=0)
    LOG.debug('Final result is: %g', acc)  # %g, not %d: accuracy is a float in [0, 1]
    nni.report_final_result(acc)
def generate_default_params():
    """
    Generate default hyper parameters
    """
    return {
        'optimizer': 'Adam',
        'learning_rate': 0.001
    }
if __name__ == '__main__':
    PARSER = argparse.ArgumentParser()
    PARSER.add_argument("--batch_size", type=int, default=200, help="batch size", required=False)
    PARSER.add_argument("--epochs", type=int, default=10, help="Train epochs", required=False)
    PARSER.add_argument("--num_train", type=int, default=60000,
                        help="Number of train samples to be used, maximum 60000", required=False)
    PARSER.add_argument("--num_test", type=int, default=10000,
                        help="Number of test samples to be used, maximum 10000", required=False)

    ARGS, UNKNOWN = PARSER.parse_known_args()

    # get parameters from the tuner (here: from the customized advisor)
    RECEIVED_PARAMS = nni.get_next_parameter()
    LOG.debug(RECEIVED_PARAMS)
    PARAMS = generate_default_params()
    PARAMS.update(RECEIVED_PARAMS)

    # train
    train(ARGS, PARAMS)
`search_space.json`:

{
"README": "To demonstrate the flexibility, this search space does not follow the standard definition.",
"optimizer": ["Adam", "SGD"],
"learning_rate": [0.001, 0.1]
}
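For a concrete view of how the advisor consumes this non-standard space, here is a small standalone sketch using only the `parameter_expressions` functions already imported in `dummy_advisor.py`:

```python
import numpy as np
from nni import parameter_expressions as param

space = {"optimizer": ["Adam", "SGD"], "learning_rate": [0.001, 0.1]}
rng = np.random.RandomState(42)

# `choice` picks one element; `loguniform` samples between the two bounds on a log scale
print(param.choice(space["optimizer"], rng))
print(param.loguniform(space["learning_rate"][0], space["learning_rate"][1], rng))
```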
`hyperband_advisor.py`:

@@ -21,18 +21,17 @@
hyperband_advisor.py
"""
import copy
import logging
import math
import sys

import json_tricks
import numpy as np

from nni.common import multi_phase_enabled
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.protocol import CommandType, send
from nni.utils import NodeType, OptimizeMode, MetricType, extract_scalar_reward
import nni.parameter_expressions as parameter_expressions

_logger = logging.getLogger(__name__)
@@ -53,6 +52,7 @@ def create_parameter_id():
    _next_parameter_id += 1
    return _next_parameter_id - 1


def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-1):
    """Create a full id for a specific bracket's hyperparameter configuration

@@ -77,6 +77,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-1):
                           increased_id])
    return params_id

def json2parameter(ss_spec, random_state):
    """Randomly generate values for hyperparameters from hyperparameter space i.e., x.

@@ -100,7 +101,7 @@ def json2parameter(ss_spec, random_state):
            _index = random_state.randint(len(_value))
            chosen_params = json2parameter(ss_spec[NodeType.VALUE][_index], random_state)
        else:
            chosen_params = eval('parameter_expressions.' +  # pylint: disable=eval-used
                                 _type)(*(_value + [random_state]))
    else:
        chosen_params = dict()

@@ -114,6 +115,7 @@ def json2parameter(ss_spec, random_state):
        chosen_params = copy.deepcopy(ss_spec)
    return chosen_params

class Bracket():
    """A bracket in Hyperband, all the information of a bracket is managed by an instance of this class

@@ -137,12 +139,12 @@ class Bracket():
        self.bracket_id = s
        self.s_max = s_max
        self.eta = eta
        self.n = math.ceil((s_max + 1) * (eta ** s) / (s + 1) - _epsilon)  # pylint: disable=invalid-name
        self.r = R / eta ** s  # pylint: disable=invalid-name
        self.i = 0
        self.hyper_configs = []         # [ {id: params}, {}, ... ]
        self.configs_perf = []          # [ {id: [seq, acc]}, {}, ... ]
        self.num_configs_to_run = []    # [ n, n, n, ... ]
        self.num_finished_configs = []  # [ n, n, n, ... ]
        self.optimize_mode = OptimizeMode(optimize_mode)
        self.no_more_trial = False
@@ -153,7 +155,7 @@ class Bracket():

    def get_n_r(self):
        """return the values of n and r for the next round"""
        return math.floor(self.n / self.eta ** self.i + _epsilon), math.floor(self.r * self.eta ** self.i + _epsilon)
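    # Worked check of the formulas above (not in the original source), assuming the
    # defaults R=60, eta=3 used by Hyperband below: s_max = floor(log_3 60) = 3, and the
    # bracket s=3 gets n = ceil(4 * 27 / 4) = 27 and r = 60 / 27 ≈ 2.22, so get_n_r()
    # yields (27, 2), (9, 6), (3, 20), (1, 60) for rounds i = 0..3: 27 configurations at
    # budget 2, then the best 9 at budget 6, and so on.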

    def increase_i(self):
        """i means the ith round. Increase i by 1"""
@@ -185,7 +187,6 @@
        else:
            self.configs_perf[i][parameter_id] = [seq, value]

    def inform_trial_end(self, i):
        """If the trial is finished and the corresponding round (i.e., i) has all its trials finished,
        it will choose the top k trials for the next round (i.e., i+1)

@@ -195,16 +196,17 @@
        i: int
            the ith round
        """
        global _KEY  # pylint: disable=global-statement
        self.num_finished_configs[i] += 1
        _logger.debug('bracket id: %d, round: %d %d, finished: %d, all: %d', self.bracket_id, self.i, i,
                      self.num_finished_configs[i], self.num_configs_to_run[i])
        if self.num_finished_configs[i] >= self.num_configs_to_run[i] \
                and self.no_more_trial is False:
            # choose candidate configs from finished configs to run in the next round
            assert self.i == i + 1
            this_round_perf = self.configs_perf[i]
            if self.optimize_mode is OptimizeMode.Maximize:
                sorted_perf = sorted(this_round_perf.items(), key=lambda kv: kv[1][1], reverse=True)  # reverse
            else:
                sorted_perf = sorted(this_round_perf.items(), key=lambda kv: kv[1][1])
            _logger.debug('bracket %s next round %s, sorted hyper configs: %s', self.bracket_id, self.i, sorted_perf)
@@ -214,7 +216,7 @@ class Bracket():
            for k in range(next_n):
                params_id = sorted_perf[k][0]
                params = self.hyper_configs[i][params_id]
                params[_KEY] = next_r  # modify r
                # generate new id
                increased_id = params_id.split('_')[-1]
                new_id = create_bracket_parameter_id(self.bracket_id, self.i, increased_id)

@@ -223,7 +225,7 @@ class Bracket():
            return [[key, value] for key, value in hyper_configs.items()]
        return None

    def get_hyperparameter_configurations(self, num, r, searchspace_json, random_state):  # pylint: disable=invalid-name
        """Randomly generate num hyperparameter configurations from search space

        Parameters

@@ -236,7 +238,7 @@ class Bracket():
        list
            a list of hyperparameter configurations. Format: [[key1, value1], [key2, value2], ...]
        """
        global _KEY  # pylint: disable=global-statement
        assert self.i == 0
        hyperparameter_configs = dict()
        for _ in range(num):

@@ -263,6 +265,7 @@ class Bracket():
        self.num_configs_to_run.append(len(hyper_configs))
        self.increase_i()

class Hyperband(MsgDispatcherBase):
    """Hyperband inherits from MsgDispatcherBase rather than Tuner, because it integrates both tuner's functions and assessor's functions.
    This is an implementation that could fully leverage available resources, i.e., high parallelism.

@@ -277,14 +280,15 @@ class Hyperband(MsgDispatcherBase):
    optimize_mode: str
        optimize mode, 'maximize' or 'minimize'
    """

    def __init__(self, R=60, eta=3, optimize_mode='maximize'):
        """B = (s_max + 1)R"""
        super(Hyperband, self).__init__()
        self.R = R  # pylint: disable=invalid-name
        self.eta = eta
        self.brackets = dict()             # dict of Bracket
        self.generated_hyper_configs = []  # all the configs waiting for run
        self.completed_hyper_configs = []  # all the completed configs

        self.s_max = math.floor(math.log(self.R, self.eta) + _epsilon)
        self.curr_s = self.s_max
@@ -302,12 +306,11 @@ class Hyperband(MsgDispatcherBase):
        self.job_id_para_id_map = dict()

    def handle_initialize(self, data):
        """callback for initializing the advisor

        Parameters
        ----------
        data: dict
            search space
        """
        self.handle_update_search_space(data)
        send(CommandType.Initialized, '')

@@ -348,14 +351,8 @@ class Hyperband(MsgDispatcherBase):
        }
        return ret

    def handle_update_search_space(self, data):
        """data: JSON object, which is search space
        """
        self.searchspace_json = data
        self.random_state = np.random.RandomState()
`msg_dispatcher.py`:

@@ -42,8 +42,9 @@ We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''


def _sort_history(history):
    ret = []
    for i, _ in enumerate(history):
        if i in history:
            ret.append(history[i])

@@ -51,17 +52,20 @@ def _sort_history(history):
            break
    return ret

# Tuner global variables
_next_parameter_id = 0
_trial_params = {}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids = set()


def _create_parameter_id():
    global _next_parameter_id  # pylint: disable=global-statement
    _next_parameter_id += 1
    return _next_parameter_id - 1


def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
    _trial_params[parameter_id] = params
    ret = {

@@ -77,6 +81,7 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
        ret['parameter_index'] = 0
    return json_tricks.dumps(ret)

class MsgDispatcher(MsgDispatcherBase):
    def __init__(self, tuner, assessor=None):
        super(MsgDispatcher, self).__init__()

@@ -123,7 +128,7 @@ class MsgDispatcher(MsgDispatcherBase):
    def handle_import_data(self, data):
        """Import additional data for tuning

        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
        """
        self.tuner.import_data(data)
@@ -154,7 +159,8 @@ class MsgDispatcher(MsgDispatcherBase):
                param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
            except NoMoreTrialError:
                param = None
            send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'],
                                                                    parameter_index=data['parameter_index']))
        else:
            raise ValueError('Data type not supported: {}'.format(data['type']))
@@ -188,7 +194,8 @@ class MsgDispatcher(MsgDispatcherBase):
            customized = True
        else:
            customized = False
        self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized,
                                        trial_job_id=data.get('trial_job_id'))

    def _handle_intermediate_metric_data(self, data):
        """Call assessor to process intermediate results

@@ -223,7 +230,8 @@ class MsgDispatcher(MsgDispatcherBase):
            _logger.debug('BAD, kill %s', trial_job_id)
            send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
        # notify tuner
        _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]',
                      dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
        if dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS == 'true':
            self._earlystop_notify_tuner(data)
        else:
`msg_dispatcher_base.py`:

@@ -18,7 +18,6 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================

import os
import threading
import logging

@@ -39,7 +38,12 @@ _logger = logging.getLogger(__name__)

QUEUE_LEN_WARNING_MARK = 20
_worker_fast_exit_on_terminate = True


class MsgDispatcherBase(Recoverable):
    """The base class for advisors, where tuners and assessors are not defined yet.
    Inherit this class to make your own advisor.
    """

    def __init__(self):
        if multi_thread_enabled():
            self.pool = ThreadPool()
@@ -49,7 +53,8 @@ class MsgDispatcherBase(Recoverable):
        self.default_command_queue = Queue()
        self.assessor_command_queue = Queue()
        self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
        self.assessor_worker = threading.Thread(target=self.command_queue_worker,
                                                args=(self.assessor_command_queue,))
        self.default_worker.start()
        self.assessor_worker.start()
        self.worker_exceptions = []
@@ -72,7 +77,8 @@ class MsgDispatcherBase(Recoverable):
            if multi_thread_enabled():
                result = self.pool.map_async(self.process_command_thread, [(command, data)])
                self.thread_results.append(result)
                if any([thread_result.ready() and not thread_result.successful() for thread_result in
                        self.thread_results]):
                    _logger.debug('Caught thread exception')
                    break
            else:
@@ -112,7 +118,8 @@ class MsgDispatcherBase(Recoverable):
    def enqueue_command(self, command, data):
        """Enqueue command into command queues
        """
        if command == CommandType.TrialEnd or (
                command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'):
            self.assessor_command_queue.put((command, data))
        else:
            self.default_command_queue.put((command, data))
@@ -142,14 +149,14 @@ class MsgDispatcherBase(Recoverable):
        _logger.debug('process_command: command: [{}], data: [{}]'.format(command, data))

        command_handlers = {
            # Tuner commands:
            CommandType.Initialize: self.handle_initialize,
            CommandType.RequestTrialJobs: self.handle_request_trial_jobs,
            CommandType.UpdateSearchSpace: self.handle_update_search_space,
            CommandType.ImportData: self.handle_import_data,
            CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial,

            # Tuner/Assessor commands:
            CommandType.ReportMetricData: self.handle_report_metric_data,
            CommandType.TrialEnd: self.handle_trial_end,
@@ -163,22 +170,88 @@ class MsgDispatcherBase(Recoverable):
        pass

    def handle_initialize(self, data):
        """Initialize search space and tuner, if any.
        This method is meant to be called only once for each experiment. After calling this method,
        the dispatcher should `send(CommandType.Initialized, '')` to set the experiment status to "INITIALIZED".

        Parameters
        ----------
        data: dict
            search space
        """
        raise NotImplementedError('handle_initialize not implemented')

    def handle_request_trial_jobs(self, data):
        """The message dispatcher is demanded to generate `data` trial jobs.
        These trial jobs should be sent via `send(CommandType.NewTrialJob, json_tricks.dumps(parameter))`,
        where `parameter` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
        Semantically, the message dispatcher should do this `send` exactly `data` times.

        The JSON sent by this method should follow the format of

        {
            "parameter_id": 42,
            "parameters": {
                // this will be received by trial
            },
            "parameter_source": "algorithm" // optional
        }

        Parameters
        ----------
        data: int
            number of trial jobs
        """
        raise NotImplementedError('handle_request_trial_jobs not implemented')

    def handle_update_search_space(self, data):
        """This method will be called when the search space is updated.
        It's recommended to call this method in `handle_initialize` to initialize the search space.
        *No need to* notify NNI Manager when this update is done.

        Parameters
        ----------
        data: dict
            search space
        """
        raise NotImplementedError('handle_update_search_space not implemented')

    def handle_import_data(self, data):
        """Import previous data when the experiment is resumed.

        Parameters
        ----------
        data: list
            a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
        """
        raise NotImplementedError('handle_import_data not implemented')

    def handle_add_customized_trial(self, data):
        """Experimental API. Not recommended for usage.
        """
        raise NotImplementedError('handle_add_customized_trial not implemented')

    def handle_report_metric_data(self, data):
        """Called when metric data is reported or new parameters are requested (for multiphase).
        When new parameters are requested, this method should send a new parameter.

        Parameters
        ----------
        data: dict
            a dict which contains 'parameter_id', 'value', 'trial_job_id', 'type', 'sequence'.
            type: can be `MetricType.REQUEST_PARAMETER`, `MetricType.FINAL` or `MetricType.PERIODICAL`.
            `REQUEST_PARAMETER` is used to request new parameters for a multiphase trial job. In this case,
            the dict will contain additional keys: `trial_job_id`, `parameter_index`. Refer to `msg_dispatcher.py`
            for an example.

        Raises
        ------
        ValueError
            Data type is not supported
        """
        raise NotImplementedError('handle_report_metric_data not implemented')

    def handle_trial_end(self, data):
        """Called when the state of one of the trials changes.

        Parameters
        ----------
        data: dict
            a dict with keys: trial_job_id, event, hyper_params.
            trial_job_id: the id generated by the training service.
            event: the job's state.
            hyper_params: the string sent by the message dispatcher during the creation of trials.
        """
        raise NotImplementedError('handle_trial_end not implemented')
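To tie the docstrings above to a concrete message, here is a hedged sketch (values are illustrative) of one trial-job parameter that `handle_request_trial_jobs` is expected to send, following the format documented above and used by `DummyAdvisor`:

```python
import json_tricks
from nni.protocol import CommandType, send

# Sent once per requested trial job; `parameters` is exactly what the trial
# receives from nni.get_next_parameter().
payload = {
    "parameter_id": 42,
    "parameters": {"optimizer": "Adam", "learning_rate": 0.001},
    "parameter_source": "algorithm",
}
send(CommandType.NewTrialJob, json_tricks.dumps(payload))
```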
From the assessor unit tests:

@@ -28,9 +28,9 @@ from io import BytesIO
import json
from unittest import TestCase, main

_trials = []
_end_trials = []


class NaiveAssessor(Assessor):
    def assess_trial(self, trial_job_id, trial_history):

@@ -47,12 +47,14 @@ class NaiveAssessor(Assessor):
_in_buf = BytesIO()
_out_buf = BytesIO()


def _reverse_io():
    _in_buf.seek(0)
    _out_buf.seek(0)
    nni.protocol._out_file = _in_buf
    nni.protocol._in_file = _out_buf


def _restore_io():
    _in_buf.seek(0)
    _out_buf.seek(0)
From the tuner unit tests:

@@ -32,7 +32,7 @@ from unittest import TestCase, main

class NaiveTuner(Tuner):
    def __init__(self):
        self.param = 0
        self.trial_results = []
        self.search_space = None
        self.accept_customized_trials()

@@ -57,12 +57,14 @@ class NaiveTuner(Tuner):
_in_buf = BytesIO()
_out_buf = BytesIO()


def _reverse_io():
    _in_buf.seek(0)
    _out_buf.seek(0)
    nni.protocol._out_file = _in_buf
    nni.protocol._in_file = _out_buf


def _restore_io():
    _in_buf.seek(0)
    _out_buf.seek(0)

@@ -70,7 +72,6 @@ def _restore_io():
    nni.protocol._out_file = _out_buf


class TunerTestCase(TestCase):
    def test_tuner(self):
        _reverse_io()  # now we are sending to Tuner's incoming stream

@@ -94,21 +95,20 @@ class TunerTestCase(TestCase):
        self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')
        _reverse_io()  # now we are receiving from Tuner's outgoing stream

        self._assert_params(0, 2, [], None)
        self._assert_params(1, 4, [], None)

        command, data = receive()  # this one is customized
        data = json.loads(data)
        self.assertIs(command, CommandType.NewTrialJob)
        self.assertEqual(data['parameter_id'], 2)
        self.assertEqual(data['parameter_source'], 'customized')
        self.assertEqual(data['parameters'], {'param': -1})

        self._assert_params(3, 6, [[1, 4, 11, False], [2, -1, 22, True]], {'name': 'SS0'})

        self.assertEqual(len(_out_buf.read()), 0)  # no more commands

    def _assert_params(self, parameter_id, param, trial_results, search_space):
        command, data = receive()
        self.assertIs(command, CommandType.NewTrialJob)