Unverified Commit 346d49d5 authored by SparkSnail's avatar SparkSnail Committed by GitHub

Merge pull request #156 from Microsoft/master

merge master
parents d95c3513 58b259a5
...@@ -62,4 +62,8 @@ machineList:
nnictl create --config ~/nni/examples/trials/mnist-annotation/config_remote.yml
```
to start the Experiment.
\ No newline at end of file
## Version check
Starting from version 0.6, NNI supports version check. For details, see [here](PAIMode.md).
\ No newline at end of file
...@@ -4,7 +4,7 @@
## 1. Search Space
In the reading comprehension domain, attention and recurrent neural networks (RNNs) have both proven to be very effective methods, so the search space is defined as follows:
1. IDENTITY (effectively means continue training).
2. INSERT-RNN-LAYER (insert an LSTM. After comparing the performance of GRU and LSTM in the Experiment, we decided to use LSTM here.)
...
...@@ -130,6 +130,34 @@ For the syntax and usage of Annotation, see [Annotation](AnnotationSpec.md).
useAnnotation: true
## Where are the Trial files stored?
### Local mode
Each Trial has a separate directory for its own output. On every Trial run, the directory defined by the environment variable `NNI_OUTPUT_DIR` is exported. In this directory you can find the Trial's code, data, and logs. In addition, the Trial's log (including stdout) is redirected to the `trial.log` file in this directory.
If Annotation is used, the converted Trial code is placed in a separate temporary directory. This directory can be found in the `NNI_OUTPUT_DIR` variable in the `run.sh` file. The second line of that file (the `cd` command) switches to the actual directory of the code. A sample `run.sh` file:
```shell
#!/bin/bash
cd /tmp/user_name/nni/annotation/tmpzj0h72x6 #This is the actual directory
export NNI_PLATFORM=local
export NNI_SYS_DIR=/home/user_name/nni/experiments/$experiment_id$/trials/$trial_id$
export NNI_TRIAL_JOB_ID=nrbb2
export NNI_OUTPUT_DIR=/home/user_name/nni/experiments/$experiment_id$/trials/$trial_id$
export NNI_TRIAL_SEQ_ID=1
export MULTI_PHASE=false
export CUDA_VISIBLE_DEVICES=
eval python3 mnist.py 2>/home/user_name/nni/experiments/$experiment_id$/trials/$trial_id$/stderr
echo $? `date +%s000` >/home/user_name/nni/experiments/$experiment_id$/trials/$trial_id$/.nni/state
```
### Other modes
When a Trial runs on a remote machine such as OpenPAI, `NNI_OUTPUT_DIR` points only to the Trial's output directory, and `run.sh` is not in that directory. The `trial.log` file is copied back to the local Trial directory, which by default is `~/nni/experiments/$experiment_id$/trials/$trial_id$/`.
For more details, see the [debugging guide](HowToDebug.md).
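The default local path can also be assembled programmatically, which is handy when scripting over many trials. A minimal sketch, assuming the default directory layout; the experiment and trial IDs below are hypothetical placeholders:

```python
import os

# Hypothetical IDs for illustration only; substitute your own experiment/trial IDs.
experiment_id = "GgJJn2Fc"
trial_id = "nrbb2"

# Default local location of a trial's copied-back logs (assuming default settings).
trial_dir = os.path.expanduser(
    os.path.join("~", "nni", "experiments", experiment_id, "trials", trial_id))
log_file = os.path.join(trial_dir, "trial.log")
print(log_file)
```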
<a name="more-examples"></a>
## More Trial examples
...
...@@ -12,4 +12,5 @@
Web UI <WebUI>
Training services <training_services>
How to use Docker <HowToUseDocker>
Advanced features <advanced>
How to debug <HowToDebug>
\ No newline at end of file
...@@ -34,6 +34,16 @@
![](../img/trial_duration.png)
## View intermediate results of Trials
Click the "Intermediate Result" tab to see the line chart.
![](../img/webui-img/trials_intermeidate.png)
The chart supports filtering. Click the filter button, then enter the sequence number you care about in the first input box and the range of intermediate results in the second to select the data you need.
![](../img/webui-img/filter_intermediate.png)
## View Trial status
Click the "Trials Detail" tab to see the status of all Trials. In particular:
...
...@@ -52,6 +52,52 @@ net = build_graph_from_json(RCV_CONFIG)
nni.report_final_result(best_acc)
```
If you need to save and **load the best model**, the following approaches are recommended.
```python
# 1. Use the NNI API
## Get the ID of the best model from the Web UI,
## or check the `nni/experiments/experiment_id/log/model_path/best_model.txt` file.
## Read the model from the JSON file and load it with the NNI API.
with open("best-model.json") as json_file:
    json_of_model = json_file.read()
model = build_graph_from_json(json_of_model)

# 2. Use the framework's API (framework specific)
## 2.1 Keras API

## Save with the Keras API in the Trial code.
## It is best to name the file with NNI's sequence ID.
model_id = nni.get_sequence_id()
## Serialize the model to JSON
model_json = model.to_json()
with open("model-{}.json".format(model_id), "w") as json_file:
    json_file.write(model_json)
## Serialize the weights to HDF5
model.save_weights("model-{}.h5".format(model_id))

## To reuse the model, load it with the Keras API.
## Read the JSON file and create the model
model_id = ""  # ID of the model to reuse
with open('model-{}.json'.format(model_id), 'r') as json_file:
    loaded_model_json = json_file.read()
loaded_model = model_from_json(loaded_model_json)
## Load the weights into the new model
loaded_model.load_weights("model-{}.h5".format(model_id))

## 2.2 PyTorch API

## Save with the PyTorch API in the Trial code
model_id = nni.get_sequence_id()
torch.save(model, "model-{}.pt".format(model_id))

## To reuse the model, load it with the PyTorch API
model_id = ""  # ID of the model to reuse
loaded_model = torch.load("model-{}.pt".format(model_id))
```
## 3. File structure
The Tuner has many files, functions, and classes. Here we only briefly introduce the most important ones:
...@@ -77,7 +123,7 @@
## 4. Example JSON of the network representation
This is an example of the intermediate-representation JSON that is passed from the Tuner to the Trial during architecture search. The `json_to_graph()` function can be called to convert the JSON into a PyTorch or Keras model. An example follows.
```json
{
...
...@@ -9,11 +9,11 @@ searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
advisor:
  #choice: Hyperband, BOHB
  builtinAdvisorName: Hyperband
  classArgs:
    #R: the maximum trial budget (could be the number of mini-batches or epochs) that can be
    #   allocated to a trial. Each trial should use the trial budget to control how long it runs.
    R: 100
    #eta: proportion of discarded trials
    eta: 3
...
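With Hyperband, `R` and `eta` fully determine the bracket schedule. A rough sketch of that arithmetic for the values above; this is an illustrative reconstruction (the function name `hyperband_brackets` is ours), not NNI's internal implementation:

```python
import math

def hyperband_brackets(R=100, eta=3):
    """Yield (s, n, r): bracket index, initial #configs, initial budget per config."""
    s_max = int(math.log(R) / math.log(eta))  # number of brackets minus one
    for s in range(s_max, -1, -1):
        n = int(math.ceil((s_max + 1) / (s + 1) * eta ** s))  # configs in bracket s
        r = R * eta ** (-s)                                    # starting budget each
        yield s, n, r

for s, n, r in hyperband_brackets():
    print(s, n, r)
```

For `R: 100, eta: 3` this produces five brackets, ranging from many configurations on a tiny budget (s=4) down to a handful of configurations that each get the full budget `R` (s=0).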
...@@ -9,10 +9,10 @@ searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
advisor:
  #choice: Hyperband, BOHB
  builtinAdvisorName: Hyperband
  classArgs:
    #R: the maximum trial budget
    R: 100
    #eta: proportion of discarded trials
    eta: 3
...
"""A deep MNIST classifier using convolutional layers.""" """A deep MNIST classifier using convolutional layers."""
import argparse
import logging import logging
import math import math
import tempfile import tempfile
...@@ -17,7 +18,7 @@ logger = logging.getLogger('mnist_AutoML') ...@@ -17,7 +18,7 @@ logger = logging.getLogger('mnist_AutoML')
class MnistNetwork(object): class MnistNetwork(object):
''' '''
MnistNetwork is for initlizing and building basic network for mnist. MnistNetwork is for initializing and building basic network for mnist.
''' '''
def __init__(self, def __init__(self,
channel_1_num, channel_1_num,
...@@ -188,7 +189,7 @@ def main(params):
                mnist_network.keep_prob: 1 - params['dropout_rate']}
            )
        if i % 100 == 0:
            test_acc = mnist_network.accuracy.eval(
                feed_dict={mnist_network.images: mnist.test.images,
                           mnist_network.labels: mnist.test.labels,
...@@ -207,38 +208,31 @@ def main(params):
    logger.debug('Final result is %g', test_acc)
    logger.debug('Send final result done.')
def get_params():
    ''' Get parameters from command line '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default='/tmp/tensorflow/mnist/input_data', help="data directory")
    parser.add_argument("--dropout_rate", type=float, default=0.5, help="dropout rate")
    parser.add_argument("--channel_1_num", type=int, default=32)
    parser.add_argument("--channel_2_num", type=int, default=64)
    parser.add_argument("--conv_size", type=int, default=5)
    parser.add_argument("--pool_size", type=int, default=2)
    parser.add_argument("--hidden_size", type=int, default=1024)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--batch_num", type=int, default=2700)
    parser.add_argument("--batch_size", type=int, default=32)

    args, _ = parser.parse_known_args()
    return args
if __name__ == '__main__':
    try:
        # get parameters from tuner
        tuner_params = nni.get_next_parameter()
        logger.debug(tuner_params)
        tuner_params['batch_num'] = tuner_params['TRIAL_BUDGET'] * 100
        params = vars(get_params())
        params.update(tuner_params)
        main(params)
    except Exception as exception:
        logger.exception(exception)
...
...@@ -24,6 +24,7 @@ import random
import numpy as np
from nni.tuner import Tuner
from nni.utils import extract_scalar_reward

logger = logging.getLogger('ga_customer_tuner')
...@@ -115,7 +116,7 @@ class CustomerTuner(Tuner):
        parameters : dict of parameters
        value: final metrics of the trial, including reward
        '''
        reward = extract_scalar_reward(value)
        if self.optimize_mode is OptimizeMode.Minimize:
            reward = -reward
...
...@@ -25,6 +25,7 @@ import os
from threading import Event, Lock, current_thread
from nni.tuner import Tuner
from nni.utils import extract_scalar_reward
from graph import Graph, Layer, LayerType, Enum, graph_dumps, graph_loads, unique
...@@ -205,7 +206,7 @@ class CustomerTuner(Tuner):
        logger.debug('acquiring lock for param {}'.format(parameter_id))
        self.thread_lock.acquire()
        logger.debug('lock for current acquired')
        reward = extract_scalar_reward(value)
        if self.optimize_mode is OptimizeMode.Minimize:
            reward = -reward
...
...@@ -291,9 +291,11 @@ class PAITrainingService implements TrainingService {
        };
        request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
            if (error || response.statusCode >= 400) {
                const errorMessage : string = error ? error.message :
                    `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${response.body}`;
                this.log.error(errorMessage);
                trialJobDetail.status = 'FAILED';
                deferred.reject(new Error(errorMessage));
            } else {
                trialJobDetail.submitTime = Date.now();
                deferred.resolve(trialJobDetail);
...
...@@ -31,6 +31,7 @@ import random
import numpy as np
from nni.tuner import Tuner
from nni.utils import extract_scalar_reward
from .. import parameter_expressions
...@@ -287,7 +288,7 @@ class EvolutionTuner(Tuner):
        if value is dict, it should have "default" key.
        value is final metrics of the trial.
        '''
        reward = extract_scalar_reward(value)
        if parameter_id not in self.total_data:
            raise RuntimeError('Received parameter_id not in total_data.')
        # restore the parameters containing "_index"
...
...@@ -32,12 +32,13 @@ import json_tricks
from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.common import init_logger
from nni.utils import extract_scalar_reward
from .. import parameter_expressions

_logger = logging.getLogger(__name__)
_next_parameter_id = 0
_KEY = 'TRIAL_BUDGET'
_epsilon = 1e-6

@unique
...@@ -268,22 +269,6 @@ class Bracket():
        self.num_configs_to_run.append(len(hyper_configs))
        self.increase_i()
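The `extract_scalar_reward` helper that the tuners in this merge now import from `nni.utils` behaves as sketched below; this is a standalone reconstruction of its documented behavior, not the packaged code:

```python
def extract_scalar_reward(value, scalar_key='default'):
    """Return a scalar reward from a trial's final result.

    Accepts a bare float/int, or a dict holding a float/int under scalar_key;
    anything else raises RuntimeError, matching the docstring above.
    """
    if isinstance(value, (float, int)):
        return value
    if isinstance(value, dict) and isinstance(value.get(scalar_key), (float, int)):
        return value[scalar_key]
    raise RuntimeError(
        'Incorrect final result: the final result should be float/int, or a dict '
        'which has a key named "default" whose value is float/int.')

print(extract_scalar_reward(0.93))
print(extract_scalar_reward({'default': 5, 'loss': 0.2}))
```

Centralizing this in `nni.utils` removes the per-tuner copies of the same logic, which is why each tuner diff above only swaps `self.extract_scalar_reward` for the imported function.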
class Hyperband(MsgDispatcherBase):
    """Hyperband inherit from MsgDispatcherBase rather than Tuner, because it integrates both tuner's functions and assessor's functions.
    This is an implementation that could fully leverage available resources, i.e., high parallelism.
...@@ -320,7 +305,7 @@ class Hyperband(MsgDispatcherBase):
    def load_checkpoint(self):
        pass

    def save_checkpoint(self):
        pass

    def handle_initialize(self, data):
...@@ -333,7 +318,6 @@ class Hyperband(MsgDispatcherBase):
        """
        self.handle_update_search_space(data)
        send(CommandType.Initialized, '')

    def handle_request_trial_jobs(self, data):
        """
...@@ -345,21 +329,11 @@ class Hyperband(MsgDispatcherBase):
        for _ in range(data):
            self._request_one_trial_job()
    def _request_one_trial_job(self):
        """get one trial job, i.e., one hyperparameter configuration."""
        if not self.generated_hyper_configs:
            if self.curr_s < 0:
                self.curr_s = self.s_max
            _logger.debug('create a new bracket, self.curr_s=%d', self.curr_s)
            self.brackets[self.curr_s] = Bracket(self.curr_s, self.s_max, self.eta, self.R, self.optimize_mode)
            next_n, next_r = self.brackets[self.curr_s].get_n_r()
...@@ -380,8 +354,6 @@ class Hyperband(MsgDispatcherBase):
        }
        send(CommandType.NewTrialJob, json_tricks.dumps(ret))

    def handle_update_search_space(self, data):
        """data: JSON object, which is search space
...@@ -393,8 +365,6 @@ class Hyperband(MsgDispatcherBase):
        self.searchspace_json = data
        self.random_state = np.random.RandomState()

    def handle_trial_end(self, data):
        """
        Parameters
...@@ -423,8 +393,6 @@ class Hyperband(MsgDispatcherBase):
        send(CommandType.NewTrialJob, json_tricks.dumps(ret))
        self.credit -= 1

    def handle_report_metric_data(self, data):
        """
        Parameters
...@@ -450,7 +418,5 @@ class Hyperband(MsgDispatcherBase):
        else:
            raise ValueError('Data type not supported: {}'.format(data['type']))

    def handle_add_customized_trial(self, data):
        pass
...@@ -29,6 +29,7 @@ import numpy as np
import hyperopt as hp
from nni.tuner import Tuner
from nni.utils import extract_scalar_reward

logger = logging.getLogger('hyperopt_AutoML')
...@@ -241,7 +242,7 @@ class HyperoptTuner(Tuner):
        if value is dict, it should have "default" key.
        value is final metrics of the trial.
        """
        reward = extract_scalar_reward(value)
        # restore the parameters containing '_index'
        if parameter_id not in self.total_data:
            raise RuntimeError('Received parameter_id not in total_data.')
...
...@@ -38,6 +38,7 @@ import nni.metis_tuner.Regression_GP.OutlierDetection as gp_outlier_detection
import nni.metis_tuner.Regression_GP.Prediction as gp_prediction
import nni.metis_tuner.Regression_GP.Selection as gp_selection
from nni.tuner import Tuner
from nni.utils import extract_scalar_reward

logger = logging.getLogger("Metis_Tuner_AutoML")
...@@ -220,7 +221,7 @@ class MetisTuner(Tuner):
        value : dict/float
        if value is dict, it should have "default" key.
        """
        value = extract_scalar_reward(value)
        if self.optimize_mode == OptimizeMode.Maximize:
            value = -value
...
...@@ -92,7 +92,6 @@ class MsgDispatcher(MsgDispatcherBase):
        """
        self.tuner.update_search_space(data)
        send(CommandType.Initialized, '')

    def handle_request_trial_jobs(self, data):
        # data: number or trial jobs
...@@ -105,18 +104,15 @@ class MsgDispatcher(MsgDispatcherBase):
        # when parameters is None.
        if len(params_list) < len(ids):
            send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], ''))

    def handle_update_search_space(self, data):
        self.tuner.update_search_space(data)

    def handle_add_customized_trial(self, data):
        # data: parameters
        id_ = _create_parameter_id()
        _customized_parameter_ids.add(id_)
        send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))

    def handle_report_metric_data(self, data):
        """
...@@ -135,8 +131,6 @@ class MsgDispatcher(MsgDispatcherBase):
        else:
            raise ValueError('Data type not supported: {}'.format(data['type']))

    def handle_trial_end(self, data):
        """
        data: it has three keys: trial_job_id, event, hyper_params
...@@ -150,7 +144,8 @@ class MsgDispatcher(MsgDispatcherBase):
        _trial_history.pop(trial_job_id)
        if self.assessor is not None:
            self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
        if self.tuner is not None:
            self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')

    def _handle_final_metric_data(self, data):
        """Call tuner to process final results
...@@ -166,19 +161,19 @@ class MsgDispatcher(MsgDispatcherBase):
        """Call assessor to process intermediate results
        """
        if data['type'] != 'PERIODICAL':
            return
        if self.assessor is None:
            return
        trial_job_id = data['trial_job_id']
        if trial_job_id in _ended_trials:
            return
        history = _trial_history[trial_job_id]
        history[data['sequence']] = data['value']
        ordered_history = _sort_history(history)
        if len(ordered_history) < data['sequence']:  # no user-visible update since last time
            return
        try:
            result = self.assessor.assess_trial(trial_job_id, ordered_history)
...
...@@ -152,8 +152,7 @@ class MsgDispatcherBase(Recoverable):
        }
        if command not in command_handlers:
            raise AssertionError('Unsupported command: {}'.format(command))
        command_handlers[command](data)

    def handle_ping(self, data):
        pass
...
...@@ -150,6 +150,8 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
        _trial_history.pop(trial_job_id)
        if self.assessor is not None:
            self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
        if self.tuner is not None:
            self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED', trial_job_id)
        return True

    def _handle_intermediate_metric_data(self, data):
...