Unverified Commit d48ad027 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #184 from microsoft/master

merge master
parents 9352cc88 22993e5d
...@@ -21,16 +21,16 @@ ...@@ -21,16 +21,16 @@
import * as assert from 'assert'; import * as assert from 'assert';
import * as cpp from 'child-process-promise'; import * as cpp from 'child-process-promise';
import * as path from 'path';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path';
import { Client, ClientChannel, SFTPWrapper } from 'ssh2'; import { Client, ClientChannel, SFTPWrapper } from 'ssh2';
import * as stream from 'stream'; import * as stream from 'stream';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { uniqueString, getRemoteTmpDir, unixPathJoin } from '../../common/utils'; import { getRemoteTmpDir, uniqueString, unixPathJoin } from '../../common/utils';
import { RemoteCommandResult } from './remoteMachineData';
import { execRemove, tarAdd } from '../common/util'; import { execRemove, tarAdd } from '../common/util';
import { RemoteCommandResult } from './remoteMachineData';
/** /**
* *
...@@ -44,7 +44,8 @@ export namespace SSHClientUtility { ...@@ -44,7 +44,8 @@ export namespace SSHClientUtility {
* @param remoteDirectory remote directory * @param remoteDirectory remote directory
* @param sshClient SSH client * @param sshClient SSH client
*/ */
export async function copyDirectoryToRemote(localDirectory : string, remoteDirectory : string, sshClient : Client, remoteOS: string) : Promise<void> { export async function copyDirectoryToRemote(localDirectory : string, remoteDirectory : string, sshClient : Client, remoteOS: string)
: Promise<void> {
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
const tmpTarName: string = `${uniqueString(10)}.tar.gz`; const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
const localTarPath: string = path.join(os.tmpdir(), tmpTarName); const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
...@@ -75,7 +76,7 @@ export namespace SSHClientUtility { ...@@ -75,7 +76,7 @@ export namespace SSHClientUtility {
assert(sshClient !== undefined); assert(sshClient !== undefined);
const deferred: Deferred<boolean> = new Deferred<boolean>(); const deferred: Deferred<boolean> = new Deferred<boolean>();
sshClient.sftp((err : Error, sftp : SFTPWrapper) => { sshClient.sftp((err : Error, sftp : SFTPWrapper) => {
if (err) { if (err !== undefined && err !== null) {
log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`); log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`);
deferred.reject(err); deferred.reject(err);
...@@ -84,7 +85,7 @@ export namespace SSHClientUtility { ...@@ -84,7 +85,7 @@ export namespace SSHClientUtility {
assert(sftp !== undefined); assert(sftp !== undefined);
sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr : Error) => { sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr : Error) => {
sftp.end(); sftp.end();
if (fastPutErr) { if (fastPutErr !== undefined && fastPutErr !== null) {
deferred.reject(fastPutErr); deferred.reject(fastPutErr);
} else { } else {
deferred.resolve(true); deferred.resolve(true);
...@@ -100,6 +101,7 @@ export namespace SSHClientUtility { ...@@ -100,6 +101,7 @@ export namespace SSHClientUtility {
* @param command the command to execute remotely * @param command the command to execute remotely
* @param client SSH Client * @param client SSH Client
*/ */
// tslint:disable:no-unsafe-any no-any
export function remoteExeCommand(command : string, client : Client): Promise<RemoteCommandResult> { export function remoteExeCommand(command : string, client : Client): Promise<RemoteCommandResult> {
const log: Logger = getLogger(); const log: Logger = getLogger();
log.debug(`remoteExeCommand: command: [${command}]`); log.debug(`remoteExeCommand: command: [${command}]`);
...@@ -109,7 +111,7 @@ export namespace SSHClientUtility { ...@@ -109,7 +111,7 @@ export namespace SSHClientUtility {
let exitCode : number; let exitCode : number;
client.exec(command, (err : Error, channel : ClientChannel) => { client.exec(command, (err : Error, channel : ClientChannel) => {
if (err) { if (err !== undefined && err !== null) {
log.error(`remoteExeCommand: ${err.message}`); log.error(`remoteExeCommand: ${err.message}`);
deferred.reject(err); deferred.reject(err);
...@@ -117,13 +119,14 @@ export namespace SSHClientUtility { ...@@ -117,13 +119,14 @@ export namespace SSHClientUtility {
} }
channel.on('data', (data : any, dataStderr : any) => { channel.on('data', (data : any, dataStderr : any) => {
if (dataStderr) { if (dataStderr !== undefined && dataStderr !== null) {
stderr += data.toString(); stderr += data.toString();
} else { } else {
stdout += data.toString(); stdout += data.toString();
} }
}).on('exit', (code, signal) => { })
exitCode = code as number; .on('exit', (code : any, signal : any) => {
exitCode = <number>code;
deferred.resolve({ deferred.resolve({
stdout : stdout, stdout : stdout,
stderr : stderr, stderr : stderr,
...@@ -138,8 +141,9 @@ export namespace SSHClientUtility { ...@@ -138,8 +141,9 @@ export namespace SSHClientUtility {
export function getRemoteFileContent(filePath: string, sshClient: Client): Promise<string> { export function getRemoteFileContent(filePath: string, sshClient: Client): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>(); const deferred: Deferred<string> = new Deferred<string>();
sshClient.sftp((err: Error, sftp : SFTPWrapper) => { sshClient.sftp((err: Error, sftp : SFTPWrapper) => {
if (err) { if (err !== undefined && err !== null) {
getLogger().error(`getRemoteFileContent: ${err.message}`); getLogger()
.error(`getRemoteFileContent: ${err.message}`);
deferred.reject(new Error(`SFTP error: ${err.message}`)); deferred.reject(new Error(`SFTP error: ${err.message}`));
return; return;
...@@ -150,16 +154,19 @@ export namespace SSHClientUtility { ...@@ -150,16 +154,19 @@ export namespace SSHClientUtility {
let dataBuffer: string = ''; let dataBuffer: string = '';
sftpStream.on('data', (data : Buffer | string) => { sftpStream.on('data', (data : Buffer | string) => {
dataBuffer += data; dataBuffer += data;
}).on('error', (streamErr: Error) => { })
.on('error', (streamErr: Error) => {
sftp.end(); sftp.end();
deferred.reject(new NNIError(NNIErrorNames.NOT_FOUND, streamErr.message)); deferred.reject(new NNIError(NNIErrorNames.NOT_FOUND, streamErr.message));
}).on('end', () => { })
.on('end', () => {
// sftp connection need to be released manually once operation is done // sftp connection need to be released manually once operation is done
sftp.end(); sftp.end();
deferred.resolve(dataBuffer); deferred.resolve(dataBuffer);
}); });
} catch (error) { } catch (error) {
getLogger().error(`getRemoteFileContent: ${error.message}`); getLogger()
.error(`getRemoteFileContent: ${error.message}`);
sftp.end(); sftp.end();
deferred.reject(new Error(`SFTP error: ${error.message}`)); deferred.reject(new Error(`SFTP error: ${error.message}`));
} }
...@@ -167,4 +174,5 @@ export namespace SSHClientUtility { ...@@ -167,4 +174,5 @@ export namespace SSHClientUtility {
return deferred.promise; return deferred.promise;
} }
// tslint:enable:no-unsafe-any no-any
} }
...@@ -37,7 +37,7 @@ describe('WebHDFS', function () { ...@@ -37,7 +37,7 @@ describe('WebHDFS', function () {
{ {
"user": "user1", "user": "user1",
"port": 50070, "port": 50070,
"host": "10.0.0.0" "host": "10.0.0.0"
} }
*/ */
let skip: boolean = false; let skip: boolean = false;
...@@ -45,7 +45,7 @@ describe('WebHDFS', function () { ...@@ -45,7 +45,7 @@ describe('WebHDFS', function () {
let hdfsClient: any; let hdfsClient: any;
try { try {
testHDFSInfo = JSON.parse(fs.readFileSync('../../.vscode/hdfsInfo.json', 'utf8')); testHDFSInfo = JSON.parse(fs.readFileSync('../../.vscode/hdfsInfo.json', 'utf8'));
console.log(testHDFSInfo); console.log(testHDFSInfo);
hdfsClient = WebHDFS.createClient({ hdfsClient = WebHDFS.createClient({
user: testHDFSInfo.user, user: testHDFSInfo.user,
port: testHDFSInfo.port, port: testHDFSInfo.port,
...@@ -120,7 +120,7 @@ describe('WebHDFS', function () { ...@@ -120,7 +120,7 @@ describe('WebHDFS', function () {
chai.expect(actualFileData).to.be.equals(testFileData); chai.expect(actualFileData).to.be.equals(testFileData);
const testHDFSDirPath : string = path.join('/nni_unittest_' + uniqueString(6) + '_dir'); const testHDFSDirPath : string = path.join('/nni_unittest_' + uniqueString(6) + '_dir');
await HDFSClientUtility.copyDirectoryToHdfs(tmpLocalDirectoryPath, testHDFSDirPath, hdfsClient); await HDFSClientUtility.copyDirectoryToHdfs(tmpLocalDirectoryPath, testHDFSDirPath, hdfsClient);
const files : any[] = await HDFSClientUtility.readdir(testHDFSDirPath, hdfsClient); const files : any[] = await HDFSClientUtility.readdir(testHDFSDirPath, hdfsClient);
...@@ -133,7 +133,7 @@ describe('WebHDFS', function () { ...@@ -133,7 +133,7 @@ describe('WebHDFS', function () {
// Cleanup // Cleanup
rmdir(tmpLocalDirectoryPath); rmdir(tmpLocalDirectoryPath);
let deleteRestult : boolean = await HDFSClientUtility.deletePath(testHDFSFilePath, hdfsClient); let deleteRestult : boolean = await HDFSClientUtility.deletePath(testHDFSFilePath, hdfsClient);
chai.expect(deleteRestult).to.be.equals(true); chai.expect(deleteRestult).to.be.equals(true);
......
...@@ -63,7 +63,7 @@ describe('Unit Test for KubeflowTrainingService', () => { ...@@ -63,7 +63,7 @@ describe('Unit Test for KubeflowTrainingService', () => {
if (skip) { if (skip) {
return; return;
} }
kubeflowTrainingService = component.get(KubeflowTrainingService); kubeflowTrainingService = component.get(KubeflowTrainingService);
}); });
afterEach(() => { afterEach(() => {
...@@ -78,6 +78,6 @@ describe('Unit Test for KubeflowTrainingService', () => { ...@@ -78,6 +78,6 @@ describe('Unit Test for KubeflowTrainingService', () => {
return; return;
} }
await kubeflowTrainingService.setClusterMetadata(TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG, testKubeflowConfig), await kubeflowTrainingService.setClusterMetadata(TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG, testKubeflowConfig),
await kubeflowTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, testKubeflowTrialConfig); await kubeflowTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, testKubeflowTrialConfig);
}); });
}); });
\ No newline at end of file
...@@ -63,7 +63,7 @@ describe('Unit Test for LocalTrainingService', () => { ...@@ -63,7 +63,7 @@ describe('Unit Test for LocalTrainingService', () => {
//trial jobs should be empty, since there are no submitted jobs //trial jobs should be empty, since there are no submitted jobs
chai.expect(await localTrainingService.listTrialJobs()).to.be.empty; chai.expect(await localTrainingService.listTrialJobs()).to.be.empty;
}); });
it('setClusterMetadata and getClusterMetadata', async () => { it('setClusterMetadata and getClusterMetadata', async () => {
await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig); await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig);
localTrainingService.getClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG).then((data)=>{ localTrainingService.getClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG).then((data)=>{
...@@ -87,7 +87,7 @@ describe('Unit Test for LocalTrainingService', () => { ...@@ -87,7 +87,7 @@ describe('Unit Test for LocalTrainingService', () => {
await localTrainingService.cancelTrialJob(jobDetail.id); await localTrainingService.cancelTrialJob(jobDetail.id);
chai.expect(jobDetail.status).to.be.equals('USER_CANCELED'); chai.expect(jobDetail.status).to.be.equals('USER_CANCELED');
}).timeout(20000); }).timeout(20000);
it('Read metrics, Add listener, and remove listener', async () => { it('Read metrics, Add listener, and remove listener', async () => {
// set meta data // set meta data
const trialConfig: string = `{\"command\":\"python3 mockedTrial.py\", \"codeDir\":\"${localCodeDir}\",\"gpuNum\":0}` const trialConfig: string = `{\"command\":\"python3 mockedTrial.py\", \"codeDir\":\"${localCodeDir}\",\"gpuNum\":0}`
......
...@@ -89,7 +89,7 @@ describe('Unit Test for PAITrainingService', () => { ...@@ -89,7 +89,7 @@ describe('Unit Test for PAITrainingService', () => {
chai.expect(trialDetail.status).to.be.equals('WAITING'); chai.expect(trialDetail.status).to.be.equals('WAITING');
} catch(error) { } catch(error) {
console.log('Submit job failed:' + error); console.log('Submit job failed:' + error);
chai.assert(error) chai.assert(error)
} }
}); });
}); });
\ No newline at end of file
...@@ -9,7 +9,10 @@ ...@@ -9,7 +9,10 @@
"no-increment-decrement": false, "no-increment-decrement": false,
"promise-function-async": false, "promise-function-async": false,
"no-console": [true, "log"], "no-console": [true, "log"],
"no-multiline-string": false "no-multiline-string": false,
"no-suspicious-comment": false,
"no-backbone-get-set-outside-model": false,
"max-classes-per-file": false
}, },
"rulesDirectory": [], "rulesDirectory": [],
"linterOptions": { "linterOptions": {
......
...@@ -7,5 +7,5 @@ declare module 'child-process-promise' { ...@@ -7,5 +7,5 @@ declare module 'child-process-promise' {
stderr: string, stderr: string,
message: string message: string
} }
} }
} }
\ No newline at end of file
declare module 'webhdfs' {
export function createClient(arg: any): any;
}
\ No newline at end of file
...@@ -154,7 +154,7 @@ def main(): ...@@ -154,7 +154,7 @@ def main():
assessor = None assessor = None
if args.tuner_class_name in ModuleName: if args.tuner_class_name in ModuleName:
tuner = create_builtin_class_instance( tuner = create_builtin_class_instance(
args.tuner_class_name, args.tuner_class_name,
args.tuner_args) args.tuner_args)
else: else:
tuner = create_customized_class_instance( tuner = create_customized_class_instance(
......
...@@ -81,7 +81,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=- ...@@ -81,7 +81,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
class Bracket(): class Bracket():
""" """
A bracket in BOHB, all the information of a bracket is managed by A bracket in BOHB, all the information of a bracket is managed by
an instance of this class. an instance of this class.
Parameters Parameters
...@@ -251,7 +251,7 @@ class BOHB(MsgDispatcherBase): ...@@ -251,7 +251,7 @@ class BOHB(MsgDispatcherBase):
BOHB performs robust and efficient hyperparameter optimization BOHB performs robust and efficient hyperparameter optimization
at scale by combining the speed of Hyperband searches with the at scale by combining the speed of Hyperband searches with the
guidance and guarantees of convergence of Bayesian Optimization. guidance and guarantees of convergence of Bayesian Optimization.
Instead of sampling new configurations at random, BOHB uses Instead of sampling new configurations at random, BOHB uses
kernel density estimators to select promising candidates. kernel density estimators to select promising candidates.
Parameters Parameters
...@@ -335,7 +335,7 @@ class BOHB(MsgDispatcherBase): ...@@ -335,7 +335,7 @@ class BOHB(MsgDispatcherBase):
pass pass
def handle_initialize(self, data): def handle_initialize(self, data):
"""Initialize Tuner, including creating Bayesian optimization-based parametric models """Initialize Tuner, including creating Bayesian optimization-based parametric models
and search space formations and search space formations
Parameters Parameters
...@@ -403,7 +403,7 @@ class BOHB(MsgDispatcherBase): ...@@ -403,7 +403,7 @@ class BOHB(MsgDispatcherBase):
If this function is called, Command will be sent by BOHB: If this function is called, Command will be sent by BOHB:
a. If there is a parameter need to run, will return "NewTrialJob" with a dict: a. If there is a parameter need to run, will return "NewTrialJob" with a dict:
{ {
'parameter_id': id of new hyperparameter 'parameter_id': id of new hyperparameter
'parameter_source': 'algorithm' 'parameter_source': 'algorithm'
'parameters': value of new hyperparameter 'parameters': value of new hyperparameter
...@@ -458,30 +458,30 @@ class BOHB(MsgDispatcherBase): ...@@ -458,30 +458,30 @@ class BOHB(MsgDispatcherBase):
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1])) var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1]))
elif _type == 'quniform': elif _type == 'quniform':
cs.add_hyperparameter(CSH.UniformFloatHyperparameter( cs.add_hyperparameter(CSH.UniformFloatHyperparameter(
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1], var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1],
q=search_space[var]["_value"][2])) q=search_space[var]["_value"][2]))
elif _type == 'loguniform': elif _type == 'loguniform':
cs.add_hyperparameter(CSH.UniformFloatHyperparameter( cs.add_hyperparameter(CSH.UniformFloatHyperparameter(
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1], var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1],
log=True)) log=True))
elif _type == 'qloguniform': elif _type == 'qloguniform':
cs.add_hyperparameter(CSH.UniformFloatHyperparameter( cs.add_hyperparameter(CSH.UniformFloatHyperparameter(
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1], var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1],
q=search_space[var]["_value"][2], log=True)) q=search_space[var]["_value"][2], log=True))
elif _type == 'normal': elif _type == 'normal':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter( cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2])) var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2]))
elif _type == 'qnormal': elif _type == 'qnormal':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter( cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2], var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2],
q=search_space[var]["_value"][3])) q=search_space[var]["_value"][3]))
elif _type == 'lognormal': elif _type == 'lognormal':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter( cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2], var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2],
log=True)) log=True))
elif _type == 'qlognormal': elif _type == 'qlognormal':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter( cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2], var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2],
q=search_space[var]["_value"][3], log=True)) q=search_space[var]["_value"][3], log=True))
else: else:
raise ValueError( raise ValueError(
...@@ -553,7 +553,7 @@ class BOHB(MsgDispatcherBase): ...@@ -553,7 +553,7 @@ class BOHB(MsgDispatcherBase):
self.brackets[s].set_config_perf( self.brackets[s].set_config_perf(
int(i), data['parameter_id'], sys.maxsize, value) int(i), data['parameter_id'], sys.maxsize, value)
self.completed_hyper_configs.append(data) self.completed_hyper_configs.append(data)
_parameters = self.parameters[data['parameter_id']] _parameters = self.parameters[data['parameter_id']]
_parameters.pop(_KEY) _parameters.pop(_KEY)
# update BO with loss, max_s budget, hyperparameters # update BO with loss, max_s budget, hyperparameters
......
...@@ -117,7 +117,7 @@ class CG_BOHB(object): ...@@ -117,7 +117,7 @@ class CG_BOHB(object):
seperated by budget. This function sample a configuration from seperated by budget. This function sample a configuration from
largest budget. Firstly we sample "num_samples" configurations, largest budget. Firstly we sample "num_samples" configurations,
then prefer one with the largest l(x)/g(x). then prefer one with the largest l(x)/g(x).
Parameters: Parameters:
----------- -----------
info_dict: dict info_dict: dict
......
...@@ -34,7 +34,7 @@ log_level_map = { ...@@ -34,7 +34,7 @@ log_level_map = {
} }
_time_format = '%m/%d/%Y, %I:%M:%S %p' _time_format = '%m/%d/%Y, %I:%M:%S %p'
class _LoggerFileWrapper(TextIOBase): class _LoggerFileWrapper(TextIOBase):
def __init__(self, logger_file): def __init__(self, logger_file):
self.file = logger_file self.file = logger_file
......
...@@ -67,7 +67,7 @@ class CurvefittingAssessor(Assessor): ...@@ -67,7 +67,7 @@ class CurvefittingAssessor(Assessor):
def trial_end(self, trial_job_id, success): def trial_end(self, trial_job_id, success):
"""update the best performance of completed trial job """update the best performance of completed trial job
Parameters Parameters
---------- ----------
trial_job_id: int trial_job_id: int
...@@ -112,7 +112,7 @@ class CurvefittingAssessor(Assessor): ...@@ -112,7 +112,7 @@ class CurvefittingAssessor(Assessor):
curr_step = len(trial_history) curr_step = len(trial_history)
if curr_step < self.start_step: if curr_step < self.start_step:
return AssessResult.Good return AssessResult.Good
if trial_job_id in self.last_judgment_num.keys() and curr_step - self.last_judgment_num[trial_job_id] < self.gap: if trial_job_id in self.last_judgment_num.keys() and curr_step - self.last_judgment_num[trial_job_id] < self.gap:
return AssessResult.Good return AssessResult.Good
self.last_judgment_num[trial_job_id] = curr_step self.last_judgment_num[trial_job_id] = curr_step
......
...@@ -26,7 +26,7 @@ curve_combination_models = ['vap', 'pow3', 'linear', 'logx_linear', 'dr_hill_zer ...@@ -26,7 +26,7 @@ curve_combination_models = ['vap', 'pow3', 'linear', 'logx_linear', 'dr_hill_zer
def vap(x, a, b, c): def vap(x, a, b, c):
"""Vapor pressure model """Vapor pressure model
Parameters Parameters
---------- ----------
x: int x: int
...@@ -109,7 +109,7 @@ model_para_num['logx_linear'] = 2 ...@@ -109,7 +109,7 @@ model_para_num['logx_linear'] = 2
def dr_hill_zero_background(x, theta, eta, kappa): def dr_hill_zero_background(x, theta, eta, kappa):
"""dr hill zero background """dr hill zero background
Parameters Parameters
---------- ----------
x: int x: int
...@@ -261,7 +261,7 @@ model_para_num['weibull'] = 4 ...@@ -261,7 +261,7 @@ model_para_num['weibull'] = 4
def janoschek(x, a, beta, k, delta): def janoschek(x, a, beta, k, delta):
"""http://www.pisces-conservation.com/growthhelp/janoschek.htm """http://www.pisces-conservation.com/growthhelp/janoschek.htm
Parameters Parameters
---------- ----------
x: int x: int
......
...@@ -35,7 +35,7 @@ logger = logging.getLogger('curvefitting_Assessor') ...@@ -35,7 +35,7 @@ logger = logging.getLogger('curvefitting_Assessor')
class CurveModel(object): class CurveModel(object):
"""Build a Curve Model to predict the performance """Build a Curve Model to predict the performance
Algorithm: https://github.com/Microsoft/nni/blob/master/src/sdk/pynni/nni/curvefitting_assessor/README.md Algorithm: https://github.com/Microsoft/nni/blob/master/src/sdk/pynni/nni/curvefitting_assessor/README.md
Parameters Parameters
...@@ -53,7 +53,7 @@ class CurveModel(object): ...@@ -53,7 +53,7 @@ class CurveModel(object):
def fit_theta(self): def fit_theta(self):
"""use least squares to fit all default curves parameter seperately """use least squares to fit all default curves parameter seperately
Returns Returns
------- -------
None None
...@@ -87,7 +87,7 @@ class CurveModel(object): ...@@ -87,7 +87,7 @@ class CurveModel(object):
def filter_curve(self): def filter_curve(self):
"""filter the poor performing curve """filter the poor performing curve
Returns Returns
------- -------
None None
...@@ -117,7 +117,7 @@ class CurveModel(object): ...@@ -117,7 +117,7 @@ class CurveModel(object):
def predict_y(self, model, pos): def predict_y(self, model, pos):
"""return the predict y of 'model' when epoch = pos """return the predict y of 'model' when epoch = pos
Parameters Parameters
---------- ----------
model: string model: string
...@@ -162,7 +162,7 @@ class CurveModel(object): ...@@ -162,7 +162,7 @@ class CurveModel(object):
def normalize_weight(self, samples): def normalize_weight(self, samples):
"""normalize weight """normalize weight
Parameters Parameters
---------- ----------
samples: list samples: list
...@@ -184,7 +184,7 @@ class CurveModel(object): ...@@ -184,7 +184,7 @@ class CurveModel(object):
def sigma_sq(self, sample): def sigma_sq(self, sample):
"""returns the value of sigma square, given the weight's sample """returns the value of sigma square, given the weight's sample
Parameters Parameters
---------- ----------
sample: list sample: list
...@@ -203,7 +203,7 @@ class CurveModel(object): ...@@ -203,7 +203,7 @@ class CurveModel(object):
def normal_distribution(self, pos, sample): def normal_distribution(self, pos, sample):
"""returns the value of normal distribution, given the weight's sample and target position """returns the value of normal distribution, given the weight's sample and target position
Parameters Parameters
---------- ----------
pos: int pos: int
...@@ -227,7 +227,7 @@ class CurveModel(object): ...@@ -227,7 +227,7 @@ class CurveModel(object):
---------- ----------
sample: list sample: list
sample is a (1 * NUM_OF_FUNCTIONS) matrix, representing{w1, w2, ... wk} sample is a (1 * NUM_OF_FUNCTIONS) matrix, representing{w1, w2, ... wk}
Returns Returns
------- -------
float float
...@@ -241,13 +241,13 @@ class CurveModel(object): ...@@ -241,13 +241,13 @@ class CurveModel(object):
def prior(self, samples): def prior(self, samples):
"""priori distribution """priori distribution
Parameters Parameters
---------- ----------
samples: list samples: list
a collection of sample, it's a (NUM_OF_INSTANCE * NUM_OF_FUNCTIONS) matrix, a collection of sample, it's a (NUM_OF_INSTANCE * NUM_OF_FUNCTIONS) matrix,
representing{{w11, w12, ..., w1k}, {w21, w22, ... w2k}, ...{wk1, wk2,..., wkk}} representing{{w11, w12, ..., w1k}, {w21, w22, ... w2k}, ...{wk1, wk2,..., wkk}}
Returns Returns
------- -------
float float
...@@ -264,13 +264,13 @@ class CurveModel(object): ...@@ -264,13 +264,13 @@ class CurveModel(object):
def target_distribution(self, samples): def target_distribution(self, samples):
"""posterior probability """posterior probability
Parameters Parameters
---------- ----------
samples: list samples: list
a collection of sample, it's a (NUM_OF_INSTANCE * NUM_OF_FUNCTIONS) matrix, a collection of sample, it's a (NUM_OF_INSTANCE * NUM_OF_FUNCTIONS) matrix,
representing{{w11, w12, ..., w1k}, {w21, w22, ... w2k}, ...{wk1, wk2,..., wkk}} representing{{w11, w12, ..., w1k}, {w21, w22, ... w2k}, ...{wk1, wk2,..., wkk}}
Returns Returns
------- -------
float float
...@@ -319,7 +319,7 @@ class CurveModel(object): ...@@ -319,7 +319,7 @@ class CurveModel(object):
def predict(self, trial_history): def predict(self, trial_history):
"""predict the value of target position """predict the value of target position
Parameters Parameters
---------- ----------
trial_history: list trial_history: list
......
...@@ -167,7 +167,7 @@ class EvolutionTuner(Tuner): ...@@ -167,7 +167,7 @@ class EvolutionTuner(Tuner):
self.space = None self.space = None
def update_search_space(self, search_space): def update_search_space(self, search_space):
"""Update search space. """Update search space.
Search_space contains the information that user pre-defined. Search_space contains the information that user pre-defined.
Parameters Parameters
...@@ -194,7 +194,7 @@ class EvolutionTuner(Tuner): ...@@ -194,7 +194,7 @@ class EvolutionTuner(Tuner):
Parameters Parameters
---------- ----------
parameter_id : int parameter_id : int
Returns Returns
------- -------
config : dict config : dict
......
...@@ -43,7 +43,7 @@ _epsilon = 1e-6 ...@@ -43,7 +43,7 @@ _epsilon = 1e-6
def create_parameter_id(): def create_parameter_id():
"""Create an id """Create an id
Returns Returns
------- -------
int int
...@@ -55,7 +55,7 @@ def create_parameter_id(): ...@@ -55,7 +55,7 @@ def create_parameter_id():
def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-1): def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-1):
"""Create a full id for a specific bracket's hyperparameter configuration """Create a full id for a specific bracket's hyperparameter configuration
Parameters Parameters
---------- ----------
brackets_id: int brackets_id: int
...@@ -79,7 +79,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=- ...@@ -79,7 +79,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
def json2parameter(ss_spec, random_state): def json2parameter(ss_spec, random_state):
"""Randomly generate values for hyperparameters from hyperparameter space i.e., x. """Randomly generate values for hyperparameters from hyperparameter space i.e., x.
Parameters Parameters
---------- ----------
ss_spec: ss_spec:
...@@ -116,7 +116,7 @@ def json2parameter(ss_spec, random_state): ...@@ -116,7 +116,7 @@ def json2parameter(ss_spec, random_state):
class Bracket(): class Bracket():
"""A bracket in Hyperband, all the information of a bracket is managed by an instance of this class """A bracket in Hyperband, all the information of a bracket is managed by an instance of this class
Parameters Parameters
---------- ----------
s: int s: int
...@@ -132,7 +132,7 @@ class Bracket(): ...@@ -132,7 +132,7 @@ class Bracket():
optimize_mode: str optimize_mode: str
optimize mode, 'maximize' or 'minimize' optimize mode, 'maximize' or 'minimize'
""" """
def __init__(self, s, s_max, eta, R, optimize_mode): def __init__(self, s, s_max, eta, R, optimize_mode):
self.bracket_id = s self.bracket_id = s
self.s_max = s_max self.s_max = s_max
...@@ -163,7 +163,7 @@ class Bracket(): ...@@ -163,7 +163,7 @@ class Bracket():
def set_config_perf(self, i, parameter_id, seq, value): def set_config_perf(self, i, parameter_id, seq, value):
"""update trial's latest result with its sequence number, e.g., epoch number or batch number """update trial's latest result with its sequence number, e.g., epoch number or batch number
Parameters Parameters
---------- ----------
i: int i: int
...@@ -184,7 +184,7 @@ class Bracket(): ...@@ -184,7 +184,7 @@ class Bracket():
self.configs_perf[i][parameter_id] = [seq, value] self.configs_perf[i][parameter_id] = [seq, value]
else: else:
self.configs_perf[i][parameter_id] = [seq, value] self.configs_perf[i][parameter_id] = [seq, value]
def inform_trial_end(self, i): def inform_trial_end(self, i):
"""If the trial is finished and the corresponding round (i.e., i) has all its trials finished, """If the trial is finished and the corresponding round (i.e., i) has all its trials finished,
...@@ -230,7 +230,7 @@ class Bracket(): ...@@ -230,7 +230,7 @@ class Bracket():
---------- ----------
num: int num: int
the number of hyperparameter configurations the number of hyperparameter configurations
Returns Returns
------- -------
list list
...@@ -350,7 +350,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -350,7 +350,7 @@ class Hyperband(MsgDispatcherBase):
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
"""data: JSON object, which is search space """data: JSON object, which is search space
Parameters Parameters
---------- ----------
data: int data: int
...@@ -392,9 +392,9 @@ class Hyperband(MsgDispatcherBase): ...@@ -392,9 +392,9 @@ class Hyperband(MsgDispatcherBase):
""" """
Parameters Parameters
---------- ----------
data: data:
it is an object which has keys 'parameter_id', 'value', 'trial_job_id', 'type', 'sequence'. it is an object which has keys 'parameter_id', 'value', 'trial_job_id', 'type', 'sequence'.
Raises Raises
------ ------
ValueError ValueError
......
...@@ -21,10 +21,10 @@ from nni.assessor import Assessor, AssessResult ...@@ -21,10 +21,10 @@ from nni.assessor import Assessor, AssessResult
logger = logging.getLogger('medianstop_Assessor') logger = logging.getLogger('medianstop_Assessor')
class MedianstopAssessor(Assessor): class MedianstopAssessor(Assessor):
"""MedianstopAssessor is The median stopping rule stops a pending trial X at step S """MedianstopAssessor is The median stopping rule stops a pending trial X at step S
if the trial’s best objective value by step S is strictly worse than the median value if the trial’s best objective value by step S is strictly worse than the median value
of the running averages of all completed trials’ objectives reported up to step S of the running averages of all completed trials’ objectives reported up to step S
Parameters Parameters
---------- ----------
optimize_mode: str optimize_mode: str
...@@ -60,7 +60,7 @@ class MedianstopAssessor(Assessor): ...@@ -60,7 +60,7 @@ class MedianstopAssessor(Assessor):
def trial_end(self, trial_job_id, success): def trial_end(self, trial_job_id, success):
"""trial_end """trial_end
Parameters Parameters
---------- ----------
trial_job_id: int trial_job_id: int
...@@ -83,7 +83,7 @@ class MedianstopAssessor(Assessor): ...@@ -83,7 +83,7 @@ class MedianstopAssessor(Assessor):
def assess_trial(self, trial_job_id, trial_history): def assess_trial(self, trial_job_id, trial_history):
"""assess_trial """assess_trial
Parameters Parameters
---------- ----------
trial_job_id: int trial_job_id: int
......
...@@ -27,7 +27,7 @@ from scipy.optimize import minimize ...@@ -27,7 +27,7 @@ from scipy.optimize import minimize
import nni.metis_tuner.lib_data as lib_data import nni.metis_tuner.lib_data as lib_data
def next_hyperparameter_expected_improvement(fun_prediction, def next_hyperparameter_expected_improvement(fun_prediction,
fun_prediction_args, fun_prediction_args,
x_bounds, x_types, x_bounds, x_types,
samples_y_aggregation, samples_y_aggregation,
......
...@@ -69,7 +69,7 @@ class NetworkMorphismTuner(Tuner): ...@@ -69,7 +69,7 @@ class NetworkMorphismTuner(Tuner):
optimize_mode : str optimize_mode : str
optimize mode "minimize" or "maximize" (default: {"minimize"}) optimize mode "minimize" or "maximize" (default: {"minimize"})
path : str path : str
default mode path to save the model file (default: {"model_path"}) default mode path to save the model file (default: {"model_path"})
verbose : bool verbose : bool
verbose to print the log (default: {True}) verbose to print the log (default: {True})
beta : float beta : float
...@@ -154,7 +154,7 @@ class NetworkMorphismTuner(Tuner): ...@@ -154,7 +154,7 @@ class NetworkMorphismTuner(Tuner):
def receive_trial_result(self, parameter_id, parameters, value): def receive_trial_result(self, parameter_id, parameters, value):
""" Record an observation of the objective function. """ Record an observation of the objective function.
Parameters Parameters
---------- ----------
parameter_id : int parameter_id : int
...@@ -267,7 +267,7 @@ class NetworkMorphismTuner(Tuner): ...@@ -267,7 +267,7 @@ class NetworkMorphismTuner(Tuner):
---------- ----------
model_id : int model_id : int
model index model index
Returns Returns
------- -------
load_model : Graph load_model : Graph
...@@ -297,7 +297,7 @@ class NetworkMorphismTuner(Tuner): ...@@ -297,7 +297,7 @@ class NetworkMorphismTuner(Tuner):
---------- ----------
model_id : int model_id : int
model index model index
Returns Returns
------- -------
float float
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment