Commit 8314d6ee authored by Deshui Yu, committed by fishyds

Merge from dogfood branch to master

parent 98530fd2
......@@ -31,16 +31,21 @@ export class RemoteMachineMeta {
public readonly ip : string;
public readonly port : number;
public readonly username : string;
public readonly passwd: string;
public readonly passwd?: string;
public readonly sshKeyPath?: string;
public readonly passphrase?: string;
public gpuSummary : GPUSummary | undefined;
/* GPU Reservation info, the key is GPU index, the value is the job id which reserves this GPU*/
public gpuReservation : Map<number, string>;
constructor(ip : string, port : number, username : string, passwd : string) {
constructor(ip : string, port : number, username : string, passwd : string,
sshKeyPath : string, passphrase : string) {
this.ip = ip;
this.port = port;
this.username = username;
this.passwd = passwd;
this.sshKeyPath = sshKeyPath;
this.passphrase = passphrase;
this.gpuReservation = new Map<number, string>();
}
}
......
......@@ -24,11 +24,11 @@ import { EventEmitter } from 'events';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import { Client } from 'ssh2';
import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import * as component from '../../common/component';
import { NNIError, NNIErrorNames } from '../../common/errors';
import { MethodNotImplementedError, NNIError, NNIErrorNames } from '../../common/errors';
import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import { ObservableTimer } from '../../common/observableTimer';
......@@ -195,6 +195,22 @@ class RemoteMachineTrainingService implements TrainingService {
}
}
/**
* Update trial job for multi-phase
* @param trialJobId trial job id
* @param form job application form
*/
public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
throw new MethodNotImplementedError();
}
/**
* Whether multi-phase jobs are supported by the current training service
*/
public get isMultiPhaseJobSupported(): boolean {
return false;
}
/**
* Cancel trial job
* @param trialJobId ID of trial job
......@@ -290,6 +306,24 @@ class RemoteMachineTrainingService implements TrainingService {
let connectedRMNum: number = 0;
rmMetaList.forEach((rmMeta: RemoteMachineMeta) => {
const conn: Client = new Client();
let connectConfig: ConnectConfig = {
host: rmMeta.ip,
port: rmMeta.port,
username: rmMeta.username };
if (rmMeta.passwd) {
connectConfig.password = rmMeta.passwd;
} else if (rmMeta.sshKeyPath) {
if (!fs.existsSync(rmMeta.sshKeyPath)) {
//SSH key path is not a valid file: reject and stop before trying to read it
deferred.reject(new Error(`${rmMeta.sshKeyPath} does not exist.`));
return;
}
const privateKey: string = fs.readFileSync(rmMeta.sshKeyPath, 'utf8');
connectConfig.privateKey = privateKey;
connectConfig.passphrase = rmMeta.passphrase;
} else {
deferred.reject(new Error(`No valid passwd or sshKeyPath is configured.`));
}
this.machineSSHClientMap.set(rmMeta, conn);
conn.on('ready', async () => {
await this.initRemoteMachineOnConnected(rmMeta, conn);
......@@ -299,12 +333,7 @@ class RemoteMachineTrainingService implements TrainingService {
}).on('error', (err: Error) => {
// SSH connection error, reject with error message
deferred.reject(new Error(err.message));
}).connect({
host: rmMeta.ip,
port: rmMeta.port,
username: rmMeta.username,
password: rmMeta.passwd
});
}).connect(connectConfig);
});
return deferred.promise;
......@@ -402,7 +431,7 @@ class RemoteMachineTrainingService implements TrainingService {
(typeof cuda_visible_device === 'string' && cuda_visible_device.length > 0) ?
`CUDA_VISIBLE_DEVICES=${cuda_visible_device} ` : `CUDA_VISIBLE_DEVICES=" " `,
this.trialConfig.command,
path.join(trialWorkingFolder, '.nni', 'stderr'),
path.join(trialWorkingFolder, 'stderr'),
path.join(trialWorkingFolder, '.nni', 'code'));
//create tmp trial working folder locally.
......
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
__main__.py
'''
import os
import sys
import argparse
import logging
import json
import importlib
from nni.msg_dispatcher import MsgDispatcher
from nni.hyperopt_tuner.hyperopt_tuner import HyperoptTuner
from nni.evolution_tuner.evolution_tuner import EvolutionTuner
from nni.medianstop_assessor.medianstop_assessor import MedianstopAssessor
logger = logging.getLogger('nni.main')
logger.debug('START')
BUILT_IN_CLASS_NAMES = ['HyperoptTuner', 'EvolutionTuner', 'MedianstopAssessor']
def create_builtin_class_instance(classname, jsonstr_args):
if jsonstr_args:
class_args = json.loads(jsonstr_args)
instance = eval(classname)(**class_args)
else:
instance = eval(classname)()
return instance
def create_customized_class_instance(class_dir, class_filename, classname, jsonstr_args):
if not os.path.isfile(os.path.join(class_dir, class_filename)):
raise ValueError('Class file not found: {}'.format(os.path.join(class_dir, class_filename)))
sys.path.append(class_dir)
module_name = class_filename.split('.')[0]
class_module = importlib.import_module(module_name)
class_constructor = getattr(class_module, classname)
if jsonstr_args:
class_args = json.loads(jsonstr_args)
instance = class_constructor(**class_args)
else:
instance = class_constructor()
return instance
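As an illustration of how a customized class is loaded by the helper above, the directory, file name, and class name below are hypothetical (not part of this commit) and only show the expected call shape:

# Hypothetical usage sketch: /path/to/tuner/my_tuner.py is assumed to define class MyTuner.
tuner = create_customized_class_instance(
    '/path/to/tuner',                     # class_dir, appended to sys.path
    'my_tuner.py',                        # class_filename, imported as module 'my_tuner'
    'MyTuner',                            # classname, resolved with getattr
    '{"optimize_mode": "maximize"}')      # jsonstr_args, parsed and passed as **kwargs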
def parse_args():
parser = argparse.ArgumentParser(description='parse command line parameters.')
parser.add_argument('--tuner_class_name', type=str, required=True,
help='Tuner class name, the class must be a subclass of nni.Tuner')
parser.add_argument('--tuner_args', type=str, required=False,
help='Parameters passed to the tuner __init__ constructor')
parser.add_argument('--tuner_directory', type=str, required=False,
help='Tuner directory')
parser.add_argument('--tuner_class_filename', type=str, required=False,
help='Tuner class file path')
parser.add_argument('--assessor_class_name', type=str, required=False,
help='Assessor class name, the class must be a subclass of nni.Assessor')
parser.add_argument('--assessor_args', type=str, required=False,
help='Parameters passed to the assessor __init__ constructor')
parser.add_argument('--assessor_directory', type=str, required=False,
help='Assessor directory')
parser.add_argument('--assessor_class_filename', type=str, required=False,
help='Assessor class file path')
flags, _ = parser.parse_known_args()
return flags
def main():
'''
main function.
'''
args = parse_args()
tuner = None
assessor = None
if args.tuner_class_name is None:
raise ValueError('Tuner must be specified')
if args.tuner_class_name in BUILT_IN_CLASS_NAMES:
tuner = create_builtin_class_instance(args.tuner_class_name, args.tuner_args)
else:
tuner = create_customized_class_instance(args.tuner_directory, args.tuner_class_filename, args.tuner_class_name, args.tuner_args)
if args.assessor_class_name:
if args.assessor_class_name in BUILT_IN_CLASS_NAMES:
assessor = create_builtin_class_instance(args.assessor_class_name, args.assessor_args)
else:
assessor = create_customized_class_instance(args.assessor_directory, \
args.assessor_class_filename, args.assessor_class_name, args.assessor_args)
if tuner is None:
raise AssertionError('Failed to create Tuner instance')
dispatcher = MsgDispatcher(tuner, assessor)
try:
dispatcher.run()
tuner._on_exit()
if assessor is not None:
assessor._on_exit()
except Exception as exception:
logger.exception(exception)
tuner._on_error()
if assessor is not None:
assessor._on_error()
raise
if __name__ == '__main__':
try:
main()
except Exception as exception:
logger.exception(exception)
raise
......@@ -19,27 +19,18 @@
# ==================================================================================================
from collections import defaultdict
from enum import Enum
import logging
import os
import json_tricks
from .common import init_logger
from .protocol import CommandType, send, receive
from enum import Enum
from .recoverable import Recoverable
init_logger('assessor.log')
_logger = logging.getLogger(__name__)
class AssessResult(Enum):
Good = True
Bad = False
class Assessor:
class Assessor(Recoverable):
# pylint: disable=no-self-use,unused-argument
def assess_trial(self, trial_job_id, trial_history):
......@@ -57,101 +48,22 @@ class Assessor:
"""
pass
def load_checkpoint(self, path):
"""Load the checkpoint of assessor.
path: checkpoint directory of assessor
def load_checkpoint(self):
"""Load the checkpoint of assessr.
path: checkpoint directory for assessor
"""
_logger.info('Load checkpoint ignored by assessor')
checkpoint_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by assessor, checkpoint path: %s' % checkpoint_path)
def save_checkpoint(self, path):
def save_checkpoint(self):
"""Save the checkpoint of assessor.
path: checkpoint directory of assessor
"""
_logger.info('Save checkpoint ignored by assessor')
def request_save_checkpoint(self):
"""Request to save the checkpoint of assessor
"""
self.save_checkpoint(os.getenv('NNI_CHECKPOINT_DIRECTORY'))
def run(self):
"""Run the assessor.
This function will never return unless raise.
path: checkpoint directory for assessor
"""
mode = os.getenv('NNI_MODE')
if mode == 'resume':
self.load_checkpoint(os.getenv('NNI_CHECKPOINT_DIRECTORY'))
while _handle_request(self):
pass
_logger.info('Terminated by NNI manager')
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''
def _sort_history(history):
ret = [ ]
for i, _ in enumerate(history):
if i in history:
ret.append(history[i])
else:
break
return ret
checkpoint_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by assessor, checkpoint path: %s' % checkpoint_path)
def _handle_request(assessor):
_logger.debug('waiting receive_message')
command, data = receive()
_logger.debug(command)
_logger.debug(data)
if command is CommandType.Terminate:
return False
data = json_tricks.loads(data)
if command is CommandType.ReportMetricData:
if data['type'] != 'PERIODICAL':
return True
trial_job_id = data['trial_job_id']
if trial_job_id in _ended_trials:
return True
history = _trial_history[trial_job_id]
history[data['sequence']] = data['value']
ordered_history = _sort_history(history)
if len(ordered_history) < data['sequence']: # no user-visible update since last time
return True
result = assessor.assess_trial(trial_job_id, ordered_history)
if isinstance(result, bool):
result = AssessResult.Good if result else AssessResult.Bad
elif not isinstance(result, AssessResult):
msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
raise RuntimeError(msg % type(result))
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
else:
_logger.debug('GOOD')
elif command is CommandType.TrialEnd:
trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id)
if trial_job_id in _trial_history:
_trial_history.pop(trial_job_id)
assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
else:
raise AssertionError('Unsupported command: %s' % command)
def _on_exit(self):
pass
return True
def _on_error(self):
pass
......@@ -26,7 +26,7 @@ class MedianstopAssessor(Assessor):
if the trial’s best objective value by step S is strictly worse than the median value
of the running averages of all completed trials’ objectives reported up to step S
'''
def __init__(self, start_step, optimize_mode):
def __init__(self, optimize_mode='maximize', start_step=0):
self.start_step = start_step
self.running_history = dict()
self.completed_avg_history = dict()
......
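Since both constructor arguments now carry defaults, the assessor can be built without positional values; a minimal sketch, assuming the remaining keyword names are unchanged:

# Sketch only: both calls rely on the new keyword defaults shown in the diff above.
assessor = MedianstopAssessor()                                         # optimize_mode='maximize', start_step=0
assessor = MedianstopAssessor(optimize_mode='minimize', start_step=5)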
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
from collections import defaultdict
import json_tricks
from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
_logger = logging.getLogger(__name__)
# Assessor global variables
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''
def _sort_history(history):
ret = [ ]
for i, _ in enumerate(history):
if i in history:
ret.append(history[i])
else:
break
return ret
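_sort_history returns only the contiguous prefix of intermediate results starting at sequence 0, so a result that arrives after a gap is held back until the gap is filled; a small sketch of that behavior:

# Intermediate results keyed by sequence number; sequence 2 has not arrived yet.
history = {0: 0.91, 1: 0.88, 3: 0.80}
print(_sort_history(history))   # [0.91, 0.88] -- iteration stops at the first missing index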
# Tuner global variables
_next_parameter_id = 0
_trial_params = {}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids = set()
def _create_parameter_id():
global _next_parameter_id # pylint: disable=global-statement
_next_parameter_id += 1
return _next_parameter_id - 1
def _pack_parameter(parameter_id, params, customized=False):
_trial_params[parameter_id] = params
ret = {
'parameter_id': parameter_id,
'parameter_source': 'customized' if customized else 'algorithm',
'parameters': params
}
return json_tricks.dumps(ret)
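Each parameter sent to the NNI manager is wrapped by _pack_parameter into a small JSON document; the hyperparameter values below are made up for illustration:

# Illustrative only: the parameter dict is invented.
packed = _pack_parameter(_create_parameter_id(), {'learning_rate': 0.01})
print(packed)
# e.g. {"parameter_id": 0, "parameter_source": "algorithm", "parameters": {"learning_rate": 0.01}}
# (the id is 0 only if this is the first parameter handed out)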
class MsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super().__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
_logger.debug('Assessor is not configured')
def load_checkpoint(self):
self.tuner.load_checkpoint()
if self.assessor is not None:
self.assessor.load_checkpoint()
def save_checkpoint(self):
self.tuner.save_checkpoint()
if self.assessor is not None:
self.assessor.save_checkpoint()
def handle_request_trial_jobs(self, data):
# data: number of trial jobs
ids = [_create_parameter_id() for _ in range(data)]
params_list = self.tuner.generate_multiple_parameters(ids)
assert len(ids) == len(params_list)
for i, _ in enumerate(ids):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
return True
def handle_update_search_space(self, data):
self.tuner.update_search_space(data)
return True
def handle_add_customized_trial(self, data):
# data: parameters
id_ = _create_parameter_id()
_customized_parameter_ids.add(id_)
send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))
return True
def handle_report_metric_data(self, data):
if data['type'] == 'FINAL':
id_ = data['parameter_id']
if id_ in _customized_parameter_ids:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], data['value'])
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], data['value'])
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
else:
pass
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
return True
def handle_trial_end(self, data):
trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id)
if trial_job_id in _trial_history:
_trial_history.pop(trial_job_id)
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
return True
def _handle_intermediate_metric_data(self, data):
if data['type'] != 'PERIODICAL':
return True
if self.assessor is None:
return True
trial_job_id = data['trial_job_id']
if trial_job_id in _ended_trials:
return True
history = _trial_history[trial_job_id]
history[data['sequence']] = data['value']
ordered_history = _sort_history(history)
if len(ordered_history) < data['sequence']: # no user-visible update since last time
return True
try:
result = self.assessor.assess_trial(trial_job_id, ordered_history)
except Exception:
_logger.exception('Assessor error')
# Skip this assessment when the assessor raises; otherwise `result` would be unbound below.
return True
if isinstance(result, bool):
result = AssessResult.Good if result else AssessResult.Bad
elif not isinstance(result, AssessResult):
msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
raise RuntimeError(msg % type(result))
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
else:
_logger.debug('GOOD')
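Because the block above coerces a plain bool into AssessResult, a custom assessor can stay very small; the early-stopping rule below is an assumption chosen purely for illustration:

from nni.assessor import Assessor

class ThresholdAssessor(Assessor):
    """Illustrative sketch: stop a trial once its latest metric falls below 0.5."""
    def assess_trial(self, trial_job_id, trial_history):
        # Returning a bool is enough: the dispatcher maps True to AssessResult.Good
        # and False to AssessResult.Bad before deciding whether to kill the trial.
        return trial_history[-1] >= 0.5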
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
#import json_tricks
import os
import logging
import json_tricks
from .common import init_logger
from .recoverable import Recoverable
from .protocol import CommandType, receive
init_logger('dispatcher.log')
_logger = logging.getLogger(__name__)
class MsgDispatcherBase(Recoverable):
def run(self):
"""Run the tuner.
This function will never return unless raise.
"""
mode = os.getenv('NNI_MODE')
if mode == 'resume':
self.load_checkpoint()
while self.handle_request():
pass
_logger.info('Terminated by NNI manager')
def handle_request(self):
_logger.debug('waiting receive_message')
command, data = receive()
if command is None:
return False
_logger.debug('handle request: command: [{}], data: [{}]'.format(command, data))
if command is CommandType.Terminate:
return False
data = json_tricks.loads(data)
command_handlers = {
# Tuner commands:
CommandType.RequestTrialJobs: self.handle_request_trial_jobs,
CommandType.UpdateSearchSpace: self.handle_update_search_space,
CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial,
# Tuner/Assessor commands:
CommandType.ReportMetricData: self.handle_report_metric_data,
CommandType.TrialEnd: self.handle_trial_end,
}
if command not in command_handlers:
raise AssertionError('Unsupported command: {}'.format(command))
return command_handlers[command](data)
def handle_request_trial_jobs(self, data):
raise NotImplementedError('handle_request_trial_jobs not implemented')
def handle_update_search_space(self, data):
raise NotImplementedError('handle_update_search_space not implemented')
def handle_add_customized_trial(self, data):
raise NotImplementedError('handle_add_customized_trial not implemented')
def handle_report_metric_data(self, data):
raise NotImplementedError('handle_report_metric_data not implemented')
def handle_trial_end(self, data):
raise NotImplementedError('handle_trial_end not implemented')
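MsgDispatcherBase owns the receive loop and the command routing table, so a concrete dispatcher only has to fill in the five handle_* hooks; MsgDispatcher above is the real implementation, and the skeleton below merely restates that contract with an illustrative name:

# Skeleton only: each handler must return True to keep MsgDispatcherBase.run() looping.
class NoopDispatcher(MsgDispatcherBase):
    def handle_request_trial_jobs(self, data):     # data: number of trial jobs
        return True
    def handle_update_search_space(self, data):    # data: search space
        return True
    def handle_add_customized_trial(self, data):   # data: parameters
        return True
    def handle_report_metric_data(self, data):     # data: metric report
        return True
    def handle_trial_end(self, data):              # data: trial end event
        return True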
......@@ -28,7 +28,7 @@ from ..common import init_logger
_dir = os.environ['NNI_SYS_DIR']
_metric_file = open(os.path.join(_dir, '.nni', 'metrics'), 'wb')
_log_file_path = os.path.join(_dir, '.nni', 'trial.log')
_log_file_path = os.path.join(_dir, 'trial.log')
init_logger(_log_file_path)
......
......@@ -55,6 +55,7 @@ def send(command, data):
data = data.encode('utf8')
assert len(data) < 1000000, 'Command too long'
msg = b'%b%06d%b' % (command.value, len(data), data)
logging.getLogger(__name__).debug('Sending command, data: [%s]' % data)
_out_file.write(msg)
_out_file.flush()
......
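The framing used by send() is the command's value, a six-digit zero-padded payload length, then the UTF-8 payload; the two-byte code b'TR' below is an assumed placeholder, not taken from this diff:

# Framing sketch for the b'%b%06d%b' format above.
command_value = b'TR'                       # assumed command code, for illustration only
data = '{"hello": 1}'.encode('utf8')        # 12-byte payload
msg = b'%b%06d%b' % (command_value, len(data), data)
print(msg)                                  # b'TR000012{"hello": 1}'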
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
class Recoverable:
def load_checkpoint(self):
pass
def save_checkpoint(self):
pass
def get_checkpoint_path(self):
ckp_path = os.getenv('NNI_CHECKPOINT_DIRECTORY')
if ckp_path is not None and os.path.isdir(ckp_path):
return ckp_path
return None
\ No newline at end of file
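get_checkpoint_path simply reads the NNI_CHECKPOINT_DIRECTORY environment variable and validates it; a minimal sketch with a made-up path:

import os
os.environ['NNI_CHECKPOINT_DIRECTORY'] = '/tmp/nni-checkpoints'   # hypothetical location
print(Recoverable().get_checkpoint_path())   # the path if that directory exists on disk, otherwise None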
......@@ -20,19 +20,13 @@
import logging
import os
import json_tricks
from .recoverable import Recoverable
from .common import init_logger
from .protocol import CommandType, send, receive
init_logger('tuner.log')
_logger = logging.getLogger(__name__)
class Tuner:
class Tuner(Recoverable):
# pylint: disable=no-self-use,unused-argument
def generate_parameters(self, parameter_id):
......@@ -72,100 +66,22 @@ class Tuner:
"""
raise NotImplementedError('Tuner: update_search_space not implemented')
def load_checkpoint(self, path):
def load_checkpoint(self):
"""Load the checkpoint of tuner.
path: checkpoint directory for tuner
"""
_logger.info('Load checkpoint ignored by tuner')
checkpoint_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s' % checkpoint_path)
def save_checkpoint(self, path):
def save_checkpoint(self):
"""Save the checkpoint of tuner.
path: checkpoint directory for tuner
"""
_logger.info('Save checkpoint ignored by tuner')
def request_save_checkpoint(self):
"""Request to save the checkpoint of tuner
"""
self.save_checkpoint(os.getenv('NNI_CHECKPOINT_DIRECTORY'))
def run(self):
"""Run the tuner.
This function will never return unless raise.
"""
mode = os.getenv('NNI_MODE')
if mode == 'resume':
self.load_checkpoint(os.getenv('NNI_CHECKPOINT_DIRECTORY'))
while _handle_request(self):
pass
_logger.info('Terminated by NNI manager')
_next_parameter_id = 0
_trial_params = {}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids = set()
def _create_parameter_id():
global _next_parameter_id # pylint: disable=global-statement
_next_parameter_id += 1
return _next_parameter_id - 1
def _pack_parameter(parameter_id, params, customized=False):
_trial_params[parameter_id] = params
ret = {
'parameter_id': parameter_id,
'parameter_source': 'customized' if customized else 'algorithm',
'parameters': params
}
return json_tricks.dumps(ret)
def _handle_request(tuner):
_logger.debug('waiting receive_message')
command, data = receive()
if command is None:
return False
_logger.debug(command)
_logger.debug(data)
if command is CommandType.Terminate:
return False
data = json_tricks.loads(data)
if command is CommandType.RequestTrialJobs:
# data: number of trial jobs
ids = [_create_parameter_id() for _ in range(data)]
params_list = list(tuner.generate_multiple_parameters(ids))
assert len(ids) == len(params_list)
for i, _ in enumerate(ids):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
elif command is CommandType.ReportMetricData:
# data: { 'type': 'FINAL', 'parameter_id': ..., 'value': ... }
if data['type'] == 'FINAL':
id_ = data['parameter_id']
if id_ in _customized_parameter_ids:
tuner.receive_customized_trial_result(id_, _trial_params[id_], data['value'])
else:
tuner.receive_trial_result(id_, _trial_params[id_], data['value'])
elif command is CommandType.UpdateSearchSpace:
# data: search space
tuner.update_search_space(data)
elif command is CommandType.AddCustomizedTrialJob:
# data: parameters
id_ = _create_parameter_id()
_customized_parameter_ids.add(id_)
send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))
checkpoint_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s' % checkpoint_path)
else:
raise AssertionError('Unsupported command: %s' % command)
def _on_exit(self):
pass
return True
def _on_error(self):
pass
......@@ -261,17 +261,12 @@ class Control extends React.Component<{}, ControlState> {
} else {
this.addButtonLoad();
// new experiment obj
const parameter = [];
parameter.push({
parameters: addTrial
});
const sendPara = JSON.stringify(parameter[0]);
axios(`${MANAGER_IP}/trial-jobs`, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
data: sendPara
data: addTrial
}).then(res => {
if (res.status === 200) {
message.success('Submit successfully');
......
......@@ -71,7 +71,7 @@ class Para extends React.Component<{}, ParaState> {
paraNodata: '',
};
}
hyperParaPic = () => {
axios
.all([
......@@ -238,9 +238,6 @@ class Para extends React.Component<{}, ParaState> {
type: 'continuous',
min: 0,
max: 1,
realtime: false,
calculable: true,
precision: 1,
// gradient color
color: ['#fb7c7c', 'yellow', 'lightblue']
},
......@@ -357,6 +354,7 @@ class Para extends React.Component<{}, ParaState> {
this._isMounted = false;
window.clearInterval(this.intervalIDPara);
}
render() {
const { option, paraNodata, dimName } = this.state;
return (
......@@ -365,6 +363,7 @@ class Para extends React.Component<{}, ParaState> {
<div className="paraTitle">
<div className="paraLeft">Hyper Parameter</div>
<div className="paraRight">
{/* <span>top</span> */}
<Select
className="parapercent"
style={{ width: '20%' }}
......@@ -372,10 +371,10 @@ class Para extends React.Component<{}, ParaState> {
optionFilterProp="children"
onSelect={this.percentNum}
>
<Option value="0.2">0.2</Option>
<Option value="0.5">0.5</Option>
<Option value="0.8">0.8</Option>
<Option value="1">1</Option>
<Option value="0.2">20%</Option>
<Option value="0.5">50%</Option>
<Option value="0.8">80%</Option>
<Option value="1">100%</Option>
</Select>
<Select
style={{ width: '60%' }}
......
......@@ -39,12 +39,6 @@ class SlideBar extends React.Component<{}, {}> {
<Icon className="floicon" type="right" />
</Link>
</li>
<li>
<Link to={'/tensor'} activeClassName="high">
<Icon className="icon" type="link" />Tensorboard
<Icon className="floicon" type="right" />
</Link>
</li>
</ul>
</div>
);
......
......@@ -234,25 +234,21 @@ class TrialStatus extends React.Component<{}, TabState> {
// kill job
killJob = (key: number, id: string, status: string) => {
if (status === 'RUNNING') {
axios(`${MANAGER_IP}/trial-jobs/${id}`, {
method: 'DELETE',
headers: {
'Content-Type': 'application/json;charset=utf-8'
axios(`${MANAGER_IP}/trial-jobs/${id}`, {
method: 'DELETE',
headers: {
'Content-Type': 'application/json;charset=utf-8'
}
})
.then(res => {
if (res.status === 200) {
message.success('Cancel the job successfully');
// render the table
this.drawTable();
} else {
message.error('fail to cancel the job');
}
})
.then(res => {
if (res.status === 200) {
message.success('Cancel the job successfully');
// render the table
this.drawTable();
} else {
message.error('fail to cancel the job');
}
});
} else {
message.error('You can only kill a job whose status is RUNNING');
}
});
}
// get tensorflow address
......@@ -347,13 +343,34 @@ class TrialStatus extends React.Component<{}, TabState> {
key: 'operation',
width: '10%',
render: (text: string, record: TableObj) => {
let trialStatus = record.status;
let flagKill = false;
if (trialStatus === 'RUNNING') {
flagKill = true;
} else {
flagKill = false;
}
return (
<Popconfirm
title="Are you sure to delete this trial?"
onConfirm={this.killJob.bind(this, record.key, record.id, record.status)}
>
<Button type="primary" className="tableButton">Kill</Button>
</Popconfirm>
flagKill
?
(
<Popconfirm
title="Are you sure to delete this trial?"
onConfirm={this.killJob.bind(this, record.key, record.id, record.status)}
>
<Button type="primary" className="tableButton">Kill</Button>
</Popconfirm>
)
:
(
<Button
type="primary"
className="tableButton"
disabled={true}
>
Kill
</Button>
)
);
},
}, {
......
......@@ -24,6 +24,11 @@
float: right;
width: 60%;
}
.paraRight>span{
font-size: 14px;
color: #333;
margin-right: 5px;
}
.paraRight .parapercent{
margin-right: 10px;
}
......
This diff is collapsed.
......@@ -4,7 +4,7 @@ from nni.assessor import Assessor, AssessResult
_logger = logging.getLogger('NaiveAssessor')
_logger.info('start')
_result = open('assessor_result.txt', 'w')
_result = open('/tmp/nni_assessor_result.txt', 'w')
class NaiveAssessor(Assessor):
def __init__(self):
......@@ -29,10 +29,10 @@ class NaiveAssessor(Assessor):
return AssessResult.Good
try:
NaiveAssessor().run()
_result.write('DONE\n')
except Exception as e:
_logger.exception(e)
_result.write('ERROR\n')
_result.close()
def _on_exit(self):
_result.write('DONE\n')
_result.close()
def _on_error(self):
_result.write('ERROR\n')
_result.close()
This diff is collapsed.
#!/bin/sh
python3 -m nnicmd.nnictl $@
WEB_UI_FOLDER=${PWD}/../../src/webui python3 -m nnicmd.nnictl $@