"vscode:/vscode.git/clone" did not exist on "8d6f237018952a2aa3de56b5c523b598bf1a2fbb"
Unverified commit 1c56fea8 authored by chicm-ms, committed by GitHub

Merge pull request #21 from microsoft/master

pull code
parents 12410686 97829ccd
......@@ -64,11 +64,11 @@ else
fi`;
export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} \
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
&& cd $NNI_SYS_DIR && sh install_nni.sh \
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}' \
--pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1' \
--nni_manager_version '{12}' --log_collection '{13}'`;
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' --nnimanager_ip '{7}' --nnimanager_port '{8}' \
--pai_hdfs_output_dir '{9}' --pai_hdfs_host '{10}' --pai_user_name {11} --nni_hdfs_exp_dir '{12}' --webhdfs_path '/webhdfs/api/v1' \
--nni_manager_version '{13}' --log_collection '{14}'`;
export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`;
......
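The change above inserts `MULTI_PHASE={5}` into the trial command template, which shifts every later positional placeholder up by one ({5}→{6} for the trial command through {13}→{14} for log collection). Since Python's `str.format` uses the same `{n}` syntax, the renumbering can be sanity-checked with a quick sketch (all argument values below are illustrative, not taken from this diff):

```python
# Hedged sketch: check that the renumbered placeholders line up. Values are made up.
template = ("export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} "
            "NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} "
            "&& python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' "
            "--nnimanager_ip '{7}' --nnimanager_port '{8}'")

print(template.format('/nni/sys', '/nni/out', 'Ab3Cd', 'exp01', 0, 'true',
                      'python3 mnist.py', '10.0.0.1', 51188))
```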
......@@ -19,17 +19,26 @@
'use strict';
import { Request, Response, Router } from 'express';
import { Inject } from 'typescript-ioc';
import * as component from '../../common/component';
import { ClusterJobRestServer } from '../common/clusterJobRestServer';
import { PAITrainingService } from './paiTrainingService';
export interface ParameterFileMeta {
readonly experimentId: string;
readonly trialId: string;
readonly filePath: string;
}
/**
* PAI training service REST server; provides a REST API for updating PAI job metrics.
*/
@component.Singleton
export class PAIJobRestServer extends ClusterJobRestServer {
private parameterFileMetaList: ParameterFileMeta[] = [];
@Inject
private readonly paiTrainingService : PAITrainingService;
......@@ -52,4 +61,33 @@ export class PAIJobRestServer extends ClusterJobRestServer {
});
}
}
protected createRestHandler(): Router {
const router: Router = super.createRestHandler();
router.post(`/parameter-file-meta`, (req: Request, res: Response) => {
try {
this.log.info(`POST /parameter-file-meta, body is ${JSON.stringify(req.body)}`);
this.parameterFileMetaList.push(req.body);
res.send();
} catch (err) {
this.log.error(`POST parameter-file-meta error: ${err}`);
res.status(500);
res.send(err.message);
}
});
router.get(`/parameter-file-meta`, (req: Request, res: Response) => {
try {
this.log.info(`GET /parameter-file-meta`);
res.send(this.parameterFileMetaList);
} catch (err) {
this.log.error(`GET parameter-file-meta error: ${err}`);
res.status(500);
res.send(err.message);
}
});
return router;
}
}
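The two routes above give the training service a simple store-and-fetch channel for parameter file metadata. A hedged sketch of exercising them from Python with `requests` (the host, port, API root, and payload values are illustrative; in the service itself the POST is issued by `postParameterFileMeta`, further down in this diff):

```python
import requests

# Illustrative endpoint root; the real one is built from the rest server's
# endPoint and apiRootUrl properties.
base = 'http://127.0.0.1:51188/api/v1/nni-pai'

meta = {
    'experimentId': 'exp01',          # fields match the ParameterFileMeta interface
    'trialId': 'Ab3Cd',
    'filePath': '/user/alice/nni/exp01/trials/Ab3Cd/parameter_1.cfg',
}
requests.post(f'{base}/parameter-file-meta', json=meta).raise_for_status()

# GET returns every record pushed onto parameterFileMetaList so far.
print(requests.get(f'{base}/parameter-file-meta').json())
```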
......@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
JobApplicationForm, NNIManagerIpConfig, TrainingService,
HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import { delay, generateParamFileName,
......@@ -45,7 +45,7 @@ import { HDFSClientUtility } from './hdfsClientUtility';
import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig';
import { PAI_LOG_PATH_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAITrialJobDetail } from './paiData';
import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { PAIJobRestServer } from './paiJobRestServer';
import { PAIJobRestServer, ParameterFileMeta } from './paiJobRestServer';
import * as WebHDFS from 'webhdfs';
......@@ -79,6 +79,7 @@ class PAITrainingService implements TrainingService {
private copyExpCodeDirPromise?: Promise<void>;
private versionCheck: boolean = true;
private logCollection: string;
private isMultiPhase: boolean = false;
constructor() {
this.log = getLogger();
......@@ -179,12 +180,22 @@ class PAITrainingService implements TrainingService {
return deferred.promise;
}
public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
throw new MethodNotImplementedError();
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
if (form.jobType === 'TRIAL') {
await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters);
} else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
}
return trialJobDetail;
}
public get isMultiPhaseJobSupported(): boolean {
return false;
return true;
}
// tslint:disable:no-http-string
......@@ -336,6 +347,9 @@ class PAITrainingService implements TrainingService {
case TrialConfigMetadataKey.LOG_COLLECTION:
this.logCollection = value;
break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default:
//Reject for unknown keys
throw new Error(`Unknown key: ${key}`);
......@@ -445,6 +459,7 @@ class PAITrainingService implements TrainingService {
trialJobId,
this.experimentId,
trialJobDetail.sequenceId,
this.isMultiPhase,
this.paiTrialConfig.command,
nniManagerIp,
this.paiRestServerPort,
......@@ -632,7 +647,50 @@ class PAITrainingService implements TrainingService {
return Promise.race([timeoutDelay, deferred.promise])
.finally(() => { clearTimeout(timeoutId); });
}
// tslint:enable:no-any no-unsafe-any no-http-string
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
if (this.paiTrialConfig === undefined) {
throw new Error('PAI trial config is not initialized');
}
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
const hpFileName: string = generateParamFileName(hyperParameters);
const localFilepath: string = path.join(trialLocalTempFolder, hpFileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsHpFilePath: string = path.join(hdfsCodeDir, hpFileName);
await HDFSClientUtility.copyFileToHdfs(localFilepath, hdfsHpFilePath, this.hdfsClient);
await this.postParameterFileMeta({
experimentId: this.experimentId,
trialId: trialJobId,
filePath: hdfsHpFilePath
});
}
private postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
const deferred : Deferred<void> = new Deferred<void>();
const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
const req: request.Options = {
uri: `${restServer.endPoint}${restServer.apiRootUrl}/parameter-file-meta`,
method: 'POST',
json: true,
body: parameterFileMeta
};
request(req, (err: Error, res: request.Response) => {
if (err) {
deferred.reject(err);
} else {
deferred.resolve();
}
});
return deferred.promise;
}
}
export { PAITrainingService };
declare module 'tail-stream' {
export interface Stream {
on(type: 'data', callback: (data: Buffer) => void): void;
destroy(): void;
end(data: number): void;
emit(data: string): void;
}
export function createReadStream(path: string): Stream;
}
\ No newline at end of file
......@@ -28,9 +28,8 @@ import json
import importlib
from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
from nni.common import enable_multi_thread
from nni.common import enable_multi_thread, enable_multi_phase
from nni.msg_dispatcher import MsgDispatcher
from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher
logger = logging.getLogger('nni.main')
logger.debug('START')
......@@ -126,6 +125,8 @@ def main():
args = parse_args()
if args.multi_thread:
enable_multi_thread()
if args.multi_phase:
enable_multi_phase()
if args.advisor_class_name:
# advisor is enabled and starts to run
......@@ -180,9 +181,6 @@ def main():
if assessor is None:
raise AssertionError('Failed to create Assessor instance')
if args.multi_phase:
dispatcher = MultiPhaseMsgDispatcher(tuner, assessor)
else:
dispatcher = MsgDispatcher(tuner, assessor)
try:
......
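With `enable_multi_phase()` replacing the dedicated `MultiPhaseMsgDispatcher` (removed later in this diff), the dispatcher wiring in `main()` collapses to a single code path. A condensed sketch of the resulting flow (the function name and argument handling are paraphrased from the surrounding code, not part of the diff):

```python
# Condensed sketch of the simplified startup: one dispatcher class for both modes.
from nni.common import enable_multi_phase
from nni.msg_dispatcher import MsgDispatcher

def start_dispatcher(args, tuner, assessor=None):
    if args.multi_phase:
        enable_multi_phase()    # flips the module-level _multi_phase flag
    dispatcher = MsgDispatcher(tuner, assessor)
    dispatcher.run()
```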
......@@ -78,7 +78,7 @@ class BatchTuner(Tuner):
"""
self.values = self.is_valid(search_space)
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
......@@ -90,7 +90,7 @@ class BatchTuner(Tuner):
raise nni.NoMoreTrialError('no more parameters now.')
return self.values[self.count]
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
pass
def import_data(self, data):
......
......@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'):
sys.stdout = _LoggerFileWrapper(logger_file)
_multi_thread = False
_multi_phase = False
def enable_multi_thread():
global _multi_thread
......@@ -76,3 +77,10 @@ def enable_multi_thread():
def multi_thread_enabled():
return _multi_thread
def enable_multi_phase():
global _multi_phase
_multi_phase = True
def multi_phase_enabled():
return _multi_phase
......@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner):
self.searchspace_json, is_rand, self.random_state)
self.population.append(Individual(config=config))
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
......@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner):
config = split_index(total_config)
return config
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''Record the result from a trial
Parameters
......
......@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner):
'''
self.expanded_search_space = self.json2parameter(search_space)
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
self.count += 1
while (self.count <= len(self.expanded_search_space)-1):
_params_tuple = convert_dict2tuple(self.expanded_search_space[self.count])
......@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner):
return self.expanded_search_space[self.count]
raise nni.NoMoreTrialError('no more parameters now.')
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
pass
def import_data(self, data):
......
......@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner):
verbose=0)
self.rval.catch_eval_exceptions = False
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""
Returns a set of trial (hyper-)parameters, as a serializable object.
......@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner):
params = split_index(total_params)
return params
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Record an observation of the objective function
......
......@@ -174,7 +174,7 @@ class MetisTuner(Tuner):
return output
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Generate next parameter for trial
If the number of trial result is lower than cold start number,
metis will first random generate some parameters.
......@@ -205,7 +205,7 @@ class MetisTuner(Tuner):
return results
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""Tuner receive result from trial.
Parameters
......
......@@ -18,7 +18,6 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
import logging
from collections import defaultdict
import json_tricks
......@@ -26,7 +25,7 @@ import json_tricks
from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from .common import multi_thread_enabled
from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars
_logger = logging.getLogger(__name__)
......@@ -61,13 +60,19 @@ def _create_parameter_id():
_next_parameter_id += 1
return _next_parameter_id - 1
def _pack_parameter(parameter_id, params, customized=False):
def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
_trial_params[parameter_id] = params
ret = {
'parameter_id': parameter_id,
'parameter_source': 'customized' if customized else 'algorithm',
'parameters': params
}
if trial_job_id is not None:
ret['trial_job_id'] = trial_job_id
if parameter_index is not None:
ret['parameter_index'] = parameter_index
else:
ret['parameter_index'] = 0
return json_tricks.dumps(ret)
class MsgDispatcher(MsgDispatcherBase):
......@@ -133,8 +138,13 @@ class MsgDispatcher(MsgDispatcherBase):
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
else:
pass
elif data['type'] == 'REQUEST_PARAMETER':
assert multi_phase_enabled()
assert data['trial_job_id'] is not None
assert data['parameter_index'] is not None
param_id = _create_parameter_id()
param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
......@@ -160,7 +170,13 @@ class MsgDispatcher(MsgDispatcherBase):
id_ = data['parameter_id']
value = data['value']
if id_ in _customized_parameter_ids:
if multi_phase_enabled():
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
else:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
else:
if multi_phase_enabled():
self.tuner.receive_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], value)
......
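The new REQUEST_PARAMETER branch is what lets a running trial ask for additional hyper-parameter sets mid-job: each request is answered with SendTrialJobParameter, with `trial_job_id` and `parameter_index` packed into the payload by `_pack_parameter`. On the trial side this corresponds to calling `nni.get_next_parameter()` more than once; a minimal sketch (the phase count and training routine are illustrative):

```python
import nni

def train_one_phase(params):
    """Hypothetical training routine; stands in for real trial code."""
    return 0.9

# Each get_next_parameter() call after the first triggers a REQUEST_PARAMETER
# message, which the dispatcher answers via the branch shown above.
for phase in range(4):                  # illustrative number of phases
    params = nni.get_next_parameter()
    if params is None:                  # assumed behavior when no more parameters
        break
    nni.report_final_result(train_one_phase(params))
```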
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
from collections import defaultdict
import json_tricks
from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.assessor import AssessResult
_logger = logging.getLogger(__name__)
# Assessor global variables
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''
def _sort_history(history):
ret = [ ]
for i, _ in enumerate(history):
if i in history:
ret.append(history[i])
else:
break
return ret
# Tuner global variables
_next_parameter_id = 0
_trial_params = {}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids = set()
def _create_parameter_id():
global _next_parameter_id # pylint: disable=global-statement
_next_parameter_id += 1
return _next_parameter_id - 1
def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
_trial_params[parameter_id] = params
ret = {
'parameter_id': parameter_id,
'parameter_source': 'customized' if customized else 'algorithm',
'parameters': params
}
if trial_job_id is not None:
ret['trial_job_id'] = trial_job_id
if parameter_index is not None:
ret['parameter_index'] = parameter_index
else:
ret['parameter_index'] = 0
return json_tricks.dumps(ret)
class MultiPhaseMsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super(MultiPhaseMsgDispatcher, self).__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
_logger.debug('Assessor is not configured')
def load_checkpoint(self):
self.tuner.load_checkpoint()
if self.assessor is not None:
self.assessor.load_checkpoint()
def save_checkpoint(self):
self.tuner.save_checkpoint()
if self.assessor is not None:
self.assessor.save_checkpoint()
def handle_initialize(self, data):
'''
data is search space
'''
self.tuner.update_search_space(data)
send(CommandType.Initialized, '')
return True
def handle_request_trial_jobs(self, data):
# data: number of trial jobs
ids = [_create_parameter_id() for _ in range(data)]
params_list = self.tuner.generate_multiple_parameters(ids)
assert len(ids) == len(params_list)
for i, _ in enumerate(ids):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
return True
def handle_update_search_space(self, data):
self.tuner.update_search_space(data)
return True
def handle_import_data(self, data):
"""import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
self.tuner.import_data(data)
return True
def handle_add_customized_trial(self, data):
# data: parameters
id_ = _create_parameter_id()
_customized_parameter_ids.add(id_)
send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))
return True
def handle_report_metric_data(self, data):
trial_job_id = data['trial_job_id']
if data['type'] == 'FINAL':
id_ = data['parameter_id']
if id_ in _customized_parameter_ids:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], data['value'], trial_job_id)
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], data['value'], trial_job_id)
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
else:
pass
elif data['type'] == 'REQUEST_PARAMETER':
assert data['trial_job_id'] is not None
assert data['parameter_index'] is not None
param_id = _create_parameter_id()
param = self.tuner.generate_parameters(param_id, trial_job_id)
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
return True
def handle_trial_end(self, data):
trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id)
if trial_job_id in _trial_history:
_trial_history.pop(trial_job_id)
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED', trial_job_id)
return True
def handle_import_data(self, data):
pass
def _handle_intermediate_metric_data(self, data):
if data['type'] != 'PERIODICAL':
return True
if self.assessor is None:
return True
trial_job_id = data['trial_job_id']
if trial_job_id in _ended_trials:
return True
history = _trial_history[trial_job_id]
history[data['sequence']] = data['value']
ordered_history = _sort_history(history)
if len(ordered_history) < data['sequence']: # no user-visible update since last time
return True
try:
result = self.assessor.assess_trial(trial_job_id, ordered_history)
except Exception as e:
_logger.exception('Assessor error')
if isinstance(result, bool):
result = AssessResult.Good if result else AssessResult.Bad
elif not isinstance(result, AssessResult):
msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
raise RuntimeError(msg % type(result))
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
else:
_logger.debug('GOOD')
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
from nni.recoverable import Recoverable
_logger = logging.getLogger(__name__)
class MultiPhaseTuner(Recoverable):
# pylint: disable=no-self-use,unused-argument
def generate_parameters(self, parameter_id, trial_job_id=None):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: identifier of the parameter (int)
"""
raise NotImplementedError('Tuner: generate_parameters not implemented')
def generate_multiple_parameters(self, parameter_id_list):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Calls 'generate_parameters()' once per parameter id by default.
User code must override either this function or 'generate_parameters()'.
parameter_id_list: list of int
"""
return [self.generate_parameters(parameter_id) for parameter_id in parameter_id_list]
def receive_trial_result(self, parameter_id, parameters, value, trial_job_id):
"""Invoked when a trial reports its final result. Must override.
parameter_id: identifier of the parameter (int)
parameters: object created by 'generate_parameters()'
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
raise NotImplementedError('Tuner: receive_trial_result not implemented')
def receive_customized_trial_result(self, parameter_id, parameters, value, trial_job_id):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: identifier of the parameter (int)
parameters: object created by user
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
_logger.info('Customized trial job %s ignored by tuner', parameter_id)
def trial_end(self, parameter_id, success, trial_job_id):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: identifier of the parameter (int)
success: True if the trial successfully completed; False if failed or terminated
trial_job_id: identifier of the trial (str)
"""
pass
def update_search_space(self, search_space):
"""Update the search space of tuner. Must override.
search_space: JSON object
"""
raise NotImplementedError('Tuner: update_search_space not implemented')
def import_data(self, data):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def load_checkpoint(self):
"""Load the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoint_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path)
def save_checkpoint(self):
"""Save the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoint_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path)
def _on_exit(self):
pass
def _on_error(self):
pass
def import_data(self, data):
pass
......@@ -123,7 +123,7 @@ class NetworkMorphismTuner(Tuner):
"""
self.search_space = search_space
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""
Returns a set of trial neural architecture, as a serializable object.
......@@ -152,7 +152,7 @@ class NetworkMorphismTuner(Tuner):
return json_out
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
""" Record an observation of the objective function.
Parameters
......
......@@ -151,7 +151,7 @@ class SMACTuner(Tuner):
else:
self.logger.warning('update search space is not supported.')
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""receive_trial_result
Parameters
......@@ -209,7 +209,7 @@ class SMACTuner(Tuner):
converted_dict[key] = value
return converted_dict
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""generate one instance of hyperparameters
Parameters
......@@ -232,7 +232,7 @@ class SMACTuner(Tuner):
self.total_data[parameter_id] = challenger
return self.convert_loguniform_categorical(challenger.get_dictionary())
def generate_multiple_parameters(self, parameter_id_list):
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""generate mutiple instances of hyperparameters
Parameters
......
......@@ -30,14 +30,14 @@ _logger = logging.getLogger(__name__)
class Tuner(Recoverable):
# pylint: disable=no-self-use,unused-argument
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
raise NotImplementedError('Tuner: generate_parameters not implemented')
def generate_multiple_parameters(self, parameter_id_list):
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Calls 'generate_parameters()' once per parameter id by default.
User code must override either this function or 'generate_parameters()'.
......@@ -49,13 +49,13 @@ class Tuner(Recoverable):
for parameter_id in parameter_id_list:
try:
_logger.debug("generating param for {}".format(parameter_id))
res = self.generate_parameters(parameter_id)
res = self.generate_parameters(parameter_id, **kwargs)
except nni.NoMoreTrialError:
return result
result.append(res)
return result
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""Invoked when a trial reports its final result. Must override.
parameter_id: int
parameters: object created by 'generate_parameters()'
......@@ -63,7 +63,7 @@ class Tuner(Recoverable):
"""
raise NotImplementedError('Tuner: receive_trial_result not implemented')
def receive_customized_trial_result(self, parameter_id, parameters, value):
def receive_customized_trial_result(self, parameter_id, parameters, value, **kwargs):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: int
parameters: object created by user
......@@ -71,7 +71,7 @@ class Tuner(Recoverable):
"""
_logger.info('Customized trial job %s ignored by tuner', parameter_id)
def trial_end(self, parameter_id, success):
def trial_end(self, parameter_id, success, **kwargs):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: int
success: True if the trial successfully completed; False if failed or terminated
......
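Because the base class now takes `**kwargs`, a single `Tuner` subclass serves both single- and multi-phase experiments: when multi-phase is enabled, the dispatcher passes `trial_job_id` as a keyword argument (see the REQUEST_PARAMETER and receive_trial_result call sites earlier in this diff). A minimal kwargs-aware tuner, with choice-only sampling in the spirit of the NaiveMultiPhaseTuner in the deleted test below (illustrative, not part of this diff):

```python
import random
from nni.tuner import Tuner

class ChoiceTuner(Tuner):
    """Illustrative tuner: samples a random value for each 'choice' parameter."""

    def __init__(self):
        self.search_space = {}

    def update_search_space(self, search_space):
        self.search_space = search_space

    def generate_parameters(self, parameter_id, **kwargs):
        # In multi-phase mode the dispatcher supplies trial_job_id here.
        trial_job_id = kwargs.get('trial_job_id')
        return {name: random.choice(spec['_value'])
                for name, spec in self.search_space.items()}

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        pass    # kwargs may again carry trial_job_id
```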
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
import random
from io import BytesIO
import nni
import nni.protocol
from nni.protocol import CommandType, send, receive
from nni.multi_phase.multi_phase_tuner import MultiPhaseTuner
from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher
from unittest import TestCase, main
class NaiveMultiPhaseTuner(MultiPhaseTuner):
'''
supports only choices
'''
def __init__(self):
self.search_space = None
def generate_parameters(self, parameter_id, trial_job_id=None):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
generated_parameters = {}
if self.search_space is None:
raise AssertionError('Search space not specified')
for k in self.search_space:
param = self.search_space[k]
if not param['_type'] == 'choice':
raise ValueError('Only choice type is supported')
param_values = param['_value']
generated_parameters[k] = param_values[random.randint(0, len(param_values)-1)]
logging.getLogger(__name__).debug(generated_parameters)
return generated_parameters
def receive_trial_result(self, parameter_id, parameters, value, trial_job_id):
logging.getLogger(__name__).debug('receive_trial_result: {},{},{},{}'.format(parameter_id, parameters, value, trial_job_id))
def receive_customized_trial_result(self, parameter_id, parameters, value, trial_job_id):
pass
def update_search_space(self, search_space):
self.search_space = search_space
_in_buf = BytesIO()
_out_buf = BytesIO()
def _reverse_io():
_in_buf.seek(0)
_out_buf.seek(0)
nni.protocol._out_file = _in_buf
nni.protocol._in_file = _out_buf
def _restore_io():
_in_buf.seek(0)
_out_buf.seek(0)
nni.protocol._in_file = _in_buf
nni.protocol._out_file = _out_buf
def _test_tuner():
_reverse_io() # now we are sending to Tuner's incoming stream
send(CommandType.UpdateSearchSpace, "{\"learning_rate\": {\"_value\": [0.0001, 0.001, 0.002, 0.005, 0.01], \"_type\": \"choice\"}, \"optimizer\": {\"_value\": [\"Adam\", \"SGD\"], \"_type\": \"choice\"}}")
send(CommandType.RequestTrialJobs, '2')
send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10,"trial_job_id":"abc"}')
send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11,"trial_job_id":"abc"}')
send(CommandType.AddCustomizedTrialJob, '{"param":-1}')
send(CommandType.ReportMetricData, '{"parameter_id":2,"type":"FINAL","value":22,"trial_job_id":"abc"}')
send(CommandType.RequestTrialJobs, '1')
send(CommandType.TrialEnd, '{"trial_job_id":"abc"}')
_restore_io()
tuner = NaiveMultiPhaseTuner()
dispatcher = MultiPhaseMsgDispatcher(tuner)
dispatcher.run()
_reverse_io() # now we are receiving from Tuner's outgoing stream
command, data = receive() # this one is customized
print(command, data)
class MultiPhaseTestCase(TestCase):
def test_tuner(self):
_test_tuner()
if __name__ == '__main__':
main()
\ No newline at end of file
......@@ -35,7 +35,7 @@ class NaiveTuner(Tuner):
self.trial_results = [ ]
self.search_space = None
def generate_parameters(self, parameter_id):
def generate_parameters(self, parameter_id, **kwargs):
# report Tuner's internal states to generated parameters,
# so we don't need to pause the main loop
self.param += 2
......@@ -45,7 +45,7 @@ class NaiveTuner(Tuner):
'search_space': self.search_space
}
def receive_trial_result(self, parameter_id, parameters, value):
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
reward = extract_scalar_reward(value)
self.trial_results.append((parameter_id, parameters['param'], reward, False))
......@@ -103,11 +103,9 @@ class TunerTestCase(TestCase):
command, data = receive() # this one is customized
data = json.loads(data)
self.assertIs(command, CommandType.NewTrialJob)
self.assertEqual(data, {
'parameter_id': 2,
'parameter_source': 'customized',
'parameters': { 'param': -1 }
})
self.assertEqual(data['parameter_id'], 2)
self.assertEqual(data['parameter_source'], 'customized')
self.assertEqual(data['parameters'], { 'param': -1 })
self._assert_params(3, 6, [[1,4,11,False], [2,-1,22,True]], {'name':'SS0'})
......
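The assertion change above follows directly from the new `_pack_parameter`: every payload now carries a `parameter_index` field (defaulting to 0), so exact-dict equality against the old three-key payload would fail. A hedged check, importing the module-level helper shown earlier in this diff:

```python
import json_tricks
from nni.msg_dispatcher import _pack_parameter  # module-level helper from this diff

payload = json_tricks.loads(_pack_parameter(2, {'param': -1}, customized=True))
assert payload['parameter_id'] == 2
assert payload['parameter_source'] == 'customized'
assert payload['parameters'] == {'param': -1}
assert payload['parameter_index'] == 0   # new field; absent before this change
```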