Unverified Commit ac6aee81 authored by chicm-ms, committed by GitHub

Multiphase refactor and support OpenPAI training service. (#1138)

* Refactor multiphase interface

* Implement multiphase on PAI

* update multiphase doc
parent 1bd3637f
......@@ -31,7 +31,7 @@ class CustomizedTuner(Tuner):
def __init__(self, ...):
...
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''
Receive trial's final result.
parameter_id: int
......@@ -41,7 +41,7 @@ class CustomizedTuner(Tuner):
# your code implements here.
...
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
'''
Returns a set of trial (hyper-)parameters, as a serializable object
parameter_id: int
......@@ -51,7 +51,7 @@ class CustomizedTuner(Tuner):
...
```
`receive_trial_result` receives `parameter_id, parameters, value` as input, and the `value` object the Tuner receives is exactly the same value that the Trial sends.
`receive_trial_result` receives `parameter_id, parameters, value` as input, and the `value` object the Tuner receives is exactly the same value that the Trial sends. If `multiPhase` is set to `true` in the experiment configuration file, an additional `trial_job_id` argument is passed to `receive_trial_result` and `generate_parameters` through `**kwargs`.
The `your_parameters` returned by the `generate_parameters` function will be packaged as a JSON object by the NNI SDK. The SDK then unpacks that JSON object, so the Trial receives exactly the same `your_parameters` from the Tuner.
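For instance, a minimal sketch of a customized tuner that reads `trial_job_id` from `**kwargs` (the class name and return values here are made up for illustration):
```python
from nni.tuner import Tuner

class EchoTuner(Tuner):
    '''Toy sketch, not a real search algorithm: it only shows where
    trial_job_id shows up once multiPhase is enabled.'''

    def update_search_space(self, search_space):
        self.search_space = search_space

    def generate_parameters(self, parameter_id, **kwargs):
        # trial_job_id is only present when multiPhase is set to true
        trial_job_id = kwargs.get('trial_job_id')
        return {'parameter_id': parameter_id, 'requested_by': trial_job_id}

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        print('trial', kwargs.get('trial_job_id'), 'reported', value)
```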
......@@ -59,7 +59,7 @@ For example:
If you implement `generate_parameters` like this:
```python
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
'''
Returns a set of trial (hyper-)parameters, as a serializable object
parameter_id: int
......
......@@ -38,7 +38,15 @@ To enable multi-phase, you should also add `multiPhase: true` in your experiment
### Write a tuner that leverages multi-phase:
Before writing a multi-phase tuner, we highly suggest you go through [Customize Tuner](https://nni.readthedocs.io/en/latest/Customize_Tuner.html). Different from writing a normal tuner, your tuner needs to inherit from `MultiPhaseTuner` (in `nni.multi_phase_tuner`). The key difference between `Tuner` and `MultiPhaseTuner` is that the methods in `MultiPhaseTuner` are aware of an additional piece of information, `trial_job_id`. With this information, the tuner knows which trial is requesting a configuration and which trial is reporting results. This provides enough flexibility for your tuner to deal with different trials and different phases. For example, you may want to use the `trial_job_id` parameter of the `generate_parameters` method to generate hyperparameters for a specific trial job.
Before writing a multi-phase tuner, we highly suggest you go through [Customize Tuner](https://nni.readthedocs.io/en/latest/Customize_Tuner.html). The same as for a normal tuner, your tuner needs to inherit from the `Tuner` class. When you enable multi-phase through configuration (set `multiPhase` to `true`), your tuner receives an additional `trial_job_id` argument in the following methods:
```
generate_parameters
generate_multiple_parameters
receive_trial_result
receive_customized_trial_result
trial_end
```
With this information, the tuner knows which trial is requesting a configuration and which trial is reporting results, which gives your tuner enough flexibility to deal with different trials and different phases. For example, you may use the `trial_job_id` argument of `generate_parameters` to generate hyperparameters for a specific trial job, as in the sketch below.
Of course, to use your multi-phase tuner, __you must add `multiPhase: true` to your experiment YAML configuration file__.
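As an illustration, a sketch of a tuner that keeps per-trial state (the class name and the `phase` scheme are invented; it assumes `multiPhase: true`):
```python
from collections import defaultdict
from nni.tuner import Tuner

class PerTrialTuner(Tuner):
    '''Sketch only: grows a configuration per trial job across phases.'''

    def __init__(self):
        super().__init__()
        self.history = defaultdict(list)  # trial_job_id -> reported values

    def update_search_space(self, search_space):
        self.search_space = search_space

    def generate_parameters(self, parameter_id, **kwargs):
        trial_job_id = kwargs.get('trial_job_id')
        # Decide the next configuration from what this trial reported so far.
        phase = len(self.history[trial_job_id])
        return {'phase': phase}

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        self.history[kwargs.get('trial_job_id')].append(value)

    def trial_end(self, parameter_id, success, **kwargs):
        self.history.pop(kwargs.get('trial_job_id'), None)
```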
......
......@@ -79,7 +79,7 @@ class CustomerTuner(Tuner):
logger.debug('init population done.')
return
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Returns a set of trial graph config, as a serializable object.
parameter_id : int
"""
......@@ -109,7 +109,7 @@ class CustomerTuner(Tuner):
return temp
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''
Record an observation of the objective function
parameter_id : int
......
......@@ -49,12 +49,12 @@ class RandomNASTuner(Tuner):
self.searchspace_json = search_space
self.random_state = np.random.RandomState()
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
'''generate
'''
return random_archi_generator(self.searchspace_json, self.random_state)
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''receive
'''
pass
......@@ -112,7 +112,7 @@ class CustomerTuner(Tuner):
population.append(Individual(indiv_id=self.generate_new_id(), graph_cfg=graph_tmp, result=None))
return population
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Returns a set of trial graph config, as a serializable object.
An example configuration:
```json
......@@ -196,7 +196,7 @@ class CustomerTuner(Tuner):
logger.debug("trial {} ready".format(indiv.indiv_id))
return param_json
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''
Record an observation of the objective function
parameter_id : int
......
......@@ -58,6 +58,10 @@ export abstract class ClusterJobRestServer extends RestServer {
this.port = basePort + 1;
}
get apiRootUrl(): string {
return this.API_ROOT_URL;
}
public get clusterRestServerPort(): number {
if (this.port === undefined) {
throw new Error('PAI Rest server port is undefined');
......@@ -87,7 +91,7 @@ export abstract class ClusterJobRestServer extends RestServer {
protected abstract handleTrialMetrics(jobId : string, trialMetrics : any[]) : void;
// tslint:disable: no-unsafe-any no-any
private createRestHandler() : Router {
protected createRestHandler() : Router {
const router: Router = Router();
router.use((req: Request, res: Response, next: any) => {
......
......@@ -64,11 +64,11 @@ else
fi`;
export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} \
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
&& cd $NNI_SYS_DIR && sh install_nni.sh \
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}' \
--pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1' \
--nni_manager_version '{12}' --log_collection '{13}'`;
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' --nnimanager_ip '{7}' --nnimanager_port '{8}' \
--pai_hdfs_output_dir '{9}' --pai_hdfs_host '{10}' --pai_user_name {11} --nni_hdfs_exp_dir '{12}' --webhdfs_path '/webhdfs/api/v1' \
--nni_manager_version '{13}' --log_collection '{14}'`;
export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`;
......
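Since `MULTI_PHASE={5}` is spliced in ahead of the trial command, every subsequent placeholder shifts up by one ({5}..{13} become {6}..{14}). A quick sanity check of that shift, with fabricated values:
```python
# Illustration only: the real string is assembled by PAITrainingService
# when it submits a PAI job.
fmt = ('export NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} '
       'NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} '
       "&& python3 -m nni_trial_tool.trial_keeper --trial_command '{6}'")
print(fmt.format('/sys', '/out', 'Ab3de', 'exp1', 0, True, 'python3 trial.py'))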
......@@ -19,17 +19,26 @@
'use strict';
import { Request, Response, Router } from 'express';
import { Inject } from 'typescript-ioc';
import * as component from '../../common/component';
import { ClusterJobRestServer } from '../common/clusterJobRestServer';
import { PAITrainingService } from './paiTrainingService';
export interface ParameterFileMeta {
readonly experimentId: string;
readonly trialId: string;
readonly filePath: string;
}
/**
* PAI Training service Rest server, provides rest API to support pai job metrics update
*
*/
@component.Singleton
export class PAIJobRestServer extends ClusterJobRestServer {
private parameterFileMetaList: ParameterFileMeta[] = [];
@Inject
private readonly paiTrainingService : PAITrainingService;
......@@ -52,4 +61,33 @@ export class PAIJobRestServer extends ClusterJobRestServer {
});
}
}
protected createRestHandler(): Router {
const router: Router = super.createRestHandler();
router.post(`/parameter-file-meta`, (req: Request, res: Response) => {
try {
this.log.info(`POST /parameter-file-meta, body is ${JSON.stringify(req.body)}`);
this.parameterFileMetaList.push(req.body);
res.send();
} catch (err) {
this.log.error(`POST parameter-file-meta error: ${err}`);
res.status(500);
res.send(err.message);
}
});
router.get(`/parameter-file-meta`, (req: Request, res: Response) => {
try {
this.log.info(`GET /parameter-file-meta`);
res.send(this.parameterFileMetaList);
} catch (err) {
this.log.error(`GET parameter-file-meta error: ${err}`);
res.status(500);
res.send(err.message);
}
});
return router;
}
}
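To make the new endpoints' contract concrete, a hedged client-side sketch (port, API root, and ids are fabricated; the payload shape follows the `ParameterFileMeta` interface above):
```python
import requests  # third-party HTTP client, assumed available

# Fabricated base URL: the real endPoint and apiRootUrl come from the
# running PAIJobRestServer instance.
base = 'http://127.0.0.1:51189/api/v1/nni-pai'
meta = {
    'experimentId': 'GEb9kOd',  # made-up experiment id
    'trialId': 'Ab3de',         # made-up trial id
    'filePath': '/nni/experiments/GEb9kOd/trials/Ab3de/parameter_1.cfg',
}
requests.post(f'{base}/parameter-file-meta', json=meta).raise_for_status()
print(requests.get(f'{base}/parameter-file-meta').json())
```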
......@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
JobApplicationForm, NNIManagerIpConfig, TrainingService,
HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import { delay, generateParamFileName,
......@@ -45,7 +45,7 @@ import { HDFSClientUtility } from './hdfsClientUtility';
import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig';
import { PAI_LOG_PATH_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAITrialJobDetail } from './paiData';
import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { PAIJobRestServer } from './paiJobRestServer';
import { PAIJobRestServer, ParameterFileMeta } from './paiJobRestServer';
import * as WebHDFS from 'webhdfs';
......@@ -79,6 +79,7 @@ class PAITrainingService implements TrainingService {
private copyExpCodeDirPromise?: Promise<void>;
private versionCheck: boolean = true;
private logCollection: string;
private isMultiPhase: boolean = false;
constructor() {
this.log = getLogger();
......@@ -179,12 +180,22 @@ class PAITrainingService implements TrainingService {
return deferred.promise;
}
public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
throw new MethodNotImplementedError();
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
if (form.jobType === 'TRIAL') {
await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters);
} else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
}
return trialJobDetail;
}
public get isMultiPhaseJobSupported(): boolean {
return false;
return true;
}
// tslint:disable:no-http-string
......@@ -336,6 +347,9 @@ class PAITrainingService implements TrainingService {
case TrialConfigMetadataKey.LOG_COLLECTION:
this.logCollection = value;
break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default:
//Reject for unknown keys
throw new Error(`Unknown key: ${key}`);
......@@ -445,6 +459,7 @@ class PAITrainingService implements TrainingService {
trialJobId,
this.experimentId,
trialJobDetail.sequenceId,
this.isMultiPhase,
this.paiTrialConfig.command,
nniManagerIp,
this.paiRestServerPort,
......@@ -632,7 +647,50 @@ class PAITrainingService implements TrainingService {
return Promise.race([timeoutDelay, deferred.promise])
.finally(() => { clearTimeout(timeoutId); });
}
// tslint:enable:no-any no-unsafe-any no-http-string
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
if (this.paiTrialConfig === undefined) {
throw new Error('PAI trial config is not initialized');
}
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
const hpFileName: string = generateParamFileName(hyperParameters);
const localFilepath: string = path.join(trialLocalTempFolder, hpFileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsHpFilePath: string = path.join(hdfsCodeDir, hpFileName);
await HDFSClientUtility.copyFileToHdfs(localFilepath, hdfsHpFilePath, this.hdfsClient);
await this.postParameterFileMeta({
experimentId: this.experimentId,
trialId: trialJobId,
filePath: hdfsHpFilePath
});
}
private postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
const deferred : Deferred<void> = new Deferred<void>();
const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
const req: request.Options = {
uri: `${restServer.endPoint}${restServer.apiRootUrl}/parameter-file-meta`,
method: 'POST',
json: true,
body: parameterFileMeta
};
request(req, (err: Error, res: request.Response) => {
if (err) {
deferred.reject(err);
} else {
deferred.resolve();
}
});
return deferred.promise;
}
}
export { PAITrainingService };
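On the trial side, multi-phase means calling `nni.get_next_parameter()` more than once: each call produces a `REQUEST_PARAMETER` message, and on PAI the answer arrives through the parameter file that `writeParameterFile` uploads to HDFS. A minimal trial sketch, with a stand-in `train` function:
```python
import nni

def train(params):
    '''Stand-in for real training; returns a fabricated metric.'''
    return float(params.get('phase', 0))

# A trial chooses its own number of phases; four is arbitrary here.
for _ in range(4):
    params = nni.get_next_parameter()
    if params is None:  # no more configurations available
        break
    nni.report_final_result(train(params))
```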
......@@ -28,9 +28,8 @@ import json
import importlib
from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
from nni.common import enable_multi_thread
from nni.common import enable_multi_thread, enable_multi_phase
from nni.msg_dispatcher import MsgDispatcher
from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher
logger = logging.getLogger('nni.main')
logger.debug('START')
......@@ -126,6 +125,8 @@ def main():
args = parse_args()
if args.multi_thread:
enable_multi_thread()
if args.multi_phase:
enable_multi_phase()
if args.advisor_class_name:
# advisor is enabled and starts to run
......@@ -180,9 +181,6 @@ def main():
if assessor is None:
raise AssertionError('Failed to create Assessor instance')
if args.multi_phase:
dispatcher = MultiPhaseMsgDispatcher(tuner, assessor)
else:
dispatcher = MsgDispatcher(tuner, assessor)
try:
......
......@@ -78,7 +78,7 @@ class BatchTuner(Tuner):
"""
self.values = self.is_valid(search_space)
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
......@@ -90,7 +90,7 @@ class BatchTuner(Tuner):
raise nni.NoMoreTrialError('no more parameters now.')
return self.values[self.count]
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
pass
def import_data(self, data):
......
......@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'):
sys.stdout = _LoggerFileWrapper(logger_file)
_multi_thread = False
_multi_phase = False
def enable_multi_thread():
global _multi_thread
......@@ -76,3 +77,10 @@ def enable_multi_thread():
def multi_thread_enabled():
return _multi_thread
def enable_multi_phase():
global _multi_phase
_multi_phase = True
def multi_phase_enabled():
return _multi_phase
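The new flag mirrors the existing `_multi_thread` pattern: a process-wide boolean that `main()` flips once at dispatcher startup and everything else only reads. Usage sketch:
```python
from nni.common import enable_multi_phase, multi_phase_enabled

assert not multi_phase_enabled()  # off by default
enable_multi_phase()              # called once when --multi_phase is set
assert multi_phase_enabled()
```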
......@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner):
self.searchspace_json, is_rand, self.random_state)
self.population.append(Individual(config=config))
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
......@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner):
config = split_index(total_config)
return config
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''Record the result from a trial
Parameters
......
......@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner):
'''
self.expanded_search_space = self.json2parameter(search_space)
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
self.count += 1
while (self.count <= len(self.expanded_search_space)-1):
_params_tuple = convert_dict2tuple(self.expanded_search_space[self.count])
......@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner):
return self.expanded_search_space[self.count]
raise nni.NoMoreTrialError('no more parameters now.')
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
pass
def import_data(self, data):
......
......@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner):
verbose=0)
self.rval.catch_eval_exceptions = False
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""
Returns a set of trial (hyper-)parameters, as a serializable object.
......@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner):
params = split_index(total_params)
return params
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Record an observation of the objective function
......
......@@ -174,7 +174,7 @@ class MetisTuner(Tuner):
return output
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Generate next parameter for trial
If the number of trial result is lower than cold start number,
metis will first random generate some parameters.
......@@ -205,7 +205,7 @@ class MetisTuner(Tuner):
return results
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""Tuner receive result from trial.
Parameters
......
......@@ -18,7 +18,6 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
import logging
from collections import defaultdict
import json_tricks
......@@ -26,7 +25,7 @@ import json_tricks
from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from .common import multi_thread_enabled
from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars
_logger = logging.getLogger(__name__)
......@@ -61,13 +60,19 @@ def _create_parameter_id():
_next_parameter_id += 1
return _next_parameter_id - 1
def _pack_parameter(parameter_id, params, customized=False):
def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
_trial_params[parameter_id] = params
ret = {
'parameter_id': parameter_id,
'parameter_source': 'customized' if customized else 'algorithm',
'parameters': params
}
if trial_job_id is not None:
ret['trial_job_id'] = trial_job_id
if parameter_index is not None:
ret['parameter_index'] = parameter_index
else:
ret['parameter_index'] = 0
return json_tricks.dumps(ret)
class MsgDispatcher(MsgDispatcherBase):
......@@ -133,8 +138,13 @@ class MsgDispatcher(MsgDispatcherBase):
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
else:
pass
elif data['type'] == 'REQUEST_PARAMETER':
assert multi_phase_enabled()
assert data['trial_job_id'] is not None
assert data['parameter_index'] is not None
param_id = _create_parameter_id()
param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
......@@ -160,7 +170,13 @@ class MsgDispatcher(MsgDispatcherBase):
id_ = data['parameter_id']
value = data['value']
if id_ in _customized_parameter_ids:
if multi_phase_enabled():
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
else:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
else:
if multi_phase_enabled():
self.tuner.receive_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], value)
......
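For reference, a message shaped the way `_pack_parameter` now builds it for a `REQUEST_PARAMETER` reply, with fabricated values:
```python
import json

packed = {
    'parameter_id': 7,                       # from _create_parameter_id()
    'parameter_source': 'algorithm',
    'parameters': {'learning_rate': 0.01},   # made-up hyperparameters
    'trial_job_id': 'Ab3de',                 # echoed back for routing
    'parameter_index': 1,                    # defaults to 0 when omitted
}
print(json.dumps(packed))
```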
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
from collections import defaultdict
import json_tricks
from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.assessor import AssessResult
_logger = logging.getLogger(__name__)
# Assessor global variables
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''
def _sort_history(history):
ret = [ ]
for i, _ in enumerate(history):
if i in history:
ret.append(history[i])
else:
break
return ret
# Tuner global variables
_next_parameter_id = 0
_trial_params = {}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids = set()
def _create_parameter_id():
global _next_parameter_id # pylint: disable=global-statement
_next_parameter_id += 1
return _next_parameter_id - 1
def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
_trial_params[parameter_id] = params
ret = {
'parameter_id': parameter_id,
'parameter_source': 'customized' if customized else 'algorithm',
'parameters': params
}
if trial_job_id is not None:
ret['trial_job_id'] = trial_job_id
if parameter_index is not None:
ret['parameter_index'] = parameter_index
else:
ret['parameter_index'] = 0
return json_tricks.dumps(ret)
class MultiPhaseMsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super(MultiPhaseMsgDispatcher, self).__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
_logger.debug('Assessor is not configured')
def load_checkpoint(self):
self.tuner.load_checkpoint()
if self.assessor is not None:
self.assessor.load_checkpoint()
def save_checkpoint(self):
self.tuner.save_checkpoint()
if self.assessor is not None:
self.assessor.save_checkpoint()
def handle_initialize(self, data):
'''
data is search space
'''
self.tuner.update_search_space(data)
send(CommandType.Initialized, '')
return True
def handle_request_trial_jobs(self, data):
# data: number of trial jobs
ids = [_create_parameter_id() for _ in range(data)]
params_list = self.tuner.generate_multiple_parameters(ids)
assert len(ids) == len(params_list)
for i, _ in enumerate(ids):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
return True
def handle_update_search_space(self, data):
self.tuner.update_search_space(data)
return True
def handle_import_data(self, data):
"""import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
self.tuner.import_data(data)
return True
def handle_add_customized_trial(self, data):
# data: parameters
id_ = _create_parameter_id()
_customized_parameter_ids.add(id_)
send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))
return True
def handle_report_metric_data(self, data):
trial_job_id = data['trial_job_id']
if data['type'] == 'FINAL':
id_ = data['parameter_id']
if id_ in _customized_parameter_ids:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], data['value'], trial_job_id)
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], data['value'], trial_job_id)
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
else:
pass
elif data['type'] == 'REQUEST_PARAMETER':
assert data['trial_job_id'] is not None
assert data['parameter_index'] is not None
param_id = _create_parameter_id()
param = self.tuner.generate_parameters(param_id, trial_job_id)
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
return True
def handle_trial_end(self, data):
trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id)
if trial_job_id in _trial_history:
_trial_history.pop(trial_job_id)
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED', trial_job_id)
return True
def handle_import_data(self, data):
pass
def _handle_intermediate_metric_data(self, data):
if data['type'] != 'PERIODICAL':
return True
if self.assessor is None:
return True
trial_job_id = data['trial_job_id']
if trial_job_id in _ended_trials:
return True
history = _trial_history[trial_job_id]
history[data['sequence']] = data['value']
ordered_history = _sort_history(history)
if len(ordered_history) < data['sequence']: # no user-visible update since last time
return True
try:
result = self.assessor.assess_trial(trial_job_id, ordered_history)
except Exception as e:
_logger.exception('Assessor error')
if isinstance(result, bool):
result = AssessResult.Good if result else AssessResult.Bad
elif not isinstance(result, AssessResult):
msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
raise RuntimeError(msg % type(result))
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
else:
_logger.debug('GOOD')
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
from nni.recoverable import Recoverable
_logger = logging.getLogger(__name__)
class MultiPhaseTuner(Recoverable):
# pylint: disable=no-self-use,unused-argument
def generate_parameters(self, parameter_id, trial_job_id=None):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: identifier of the parameter (int)
"""
raise NotImplementedError('Tuner: generate_parameters not implemented')
def generate_multiple_parameters(self, parameter_id_list):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Calls 'generate_parameters()' once for each id in 'parameter_id_list' by default.
User code must override either this function or 'generate_parameters()'.
parameter_id_list: list of int
"""
return [self.generate_parameters(parameter_id) for parameter_id in parameter_id_list]
def receive_trial_result(self, parameter_id, parameters, value, trial_job_id):
"""Invoked when a trial reports its final result. Must override.
parameter_id: identifier of the parameter (int)
parameters: object created by 'generate_parameters()'
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
raise NotImplementedError('Tuner: receive_trial_result not implemented')
def receive_customized_trial_result(self, parameter_id, parameters, value, trial_job_id):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: identifier of the parameter (int)
parameters: object created by user
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
_logger.info('Customized trial job %s ignored by tuner', parameter_id)
def trial_end(self, parameter_id, success, trial_job_id):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: identifier of the parameter (int)
success: True if the trial successfully completed; False if failed or terminated
trial_job_id: identifier of the trial (str)
"""
pass
def update_search_space(self, search_space):
"""Update the search space of tuner. Must override.
search_space: JSON object
"""
raise NotImplementedError('Tuner: update_search_space not implemented')
def import_data(self, data):
"""Import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def load_checkpoint(self):
"""Load the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoint_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path)
def save_checkpoint(self):
"""Save the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoint_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path)
def _on_exit(self):
pass
def _on_error(self):
pass
def import_data(self, data):
pass