"vscode:/vscode.git/clone" did not exist on "3120114b5e6d0501253210fe5b697d05c37038bb"
Commit ae7a72bc authored by Hongarc, committed by Chi Song

Remove all whitespace at end of line (#1162)

parent 14c1b31c
@@ -63,13 +63,13 @@ abstract class KubernetesTrainingService {
protected kubernetesClusterConfig?: KubernetesClusterConfig;
protected versionCheck: boolean = true;
protected logCollection: string;
constructor() {
this.log = getLogger();
this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, KubernetesTrialJobDetail>();
this.trialLocalNFSTempFolder = path.join(getExperimentRootDir(), 'trials-nfs-tmp');
this.experimentId = getExperimentId();
this.nextTrialSequenceId = -1;
this.CONTAINER_MOUNT_PATH = '/tmp/mount';
this.genericK8sClient = new GeneralK8sClient();
@@ -86,8 +86,8 @@ abstract class KubernetesTrainingService {
public async listTrialJobs(): Promise<TrialJobDetail[]> {
const jobs: TrialJobDetail[] = [];
for (const [key, value] of this.trialJobsMap) {
if (value.form.jobType === 'TRIAL') {
jobs.push(await this.getTrialJob(key));
}
@@ -102,7 +102,7 @@ abstract class KubernetesTrainingService {
if (!kubernetesTrialJob) {
return Promise.reject(`trial job ${trialJobId} not found`)
}
return Promise.resolve(kubernetesTrialJob);
}
@@ -114,7 +114,7 @@ abstract class KubernetesTrainingService {
public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void) {
this.metricsEmitter.off('metric', listener);
}
public get isMultiPhaseJobSupported(): boolean {
return false;
}
@@ -153,7 +153,7 @@ abstract class KubernetesTrainingService {
{
apiVersion: 'v1',
kind: 'Secret',
metadata: {
name: this.azureStorageSecretName,
namespace: 'default',
labels: {
@@ -174,15 +174,15 @@ abstract class KubernetesTrainingService {
}
return Promise.resolve();
}
/**
* Generate run script for different roles (like worker or ps)
* @param trialJobId trial job id
* @param trialWorkingFolder working folder
* @param command
* @param trialSequenceId sequence id
*/
protected async generateRunScript(platform: string, trialJobId: string, trialWorkingFolder: string,
command: string, trialSequenceId: string, roleName: string, gpuNum: number): Promise<string> {
let nvidia_script: string = '';
// Nvidia device plugin for K8S has a known issue that requesting zero GPUs allocates all GPUs
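The comment above refers to a real quirk of the NVIDIA device plugin for Kubernetes: a container that requests zero GPUs can still end up seeing every GPU on the node. The remote machine service in this same diff uses a similar trick (setting CUDA_VISIBLE_DEVICES to a blank value) when no GPU is requested; a minimal sketch of that kind of guard for a generated run script is shown below. The helper name is illustrative, not the code actually emitted by generateRunScript.

```typescript
// Sketch: hide GPUs from the trial process when the user requested none, so the
// device-plugin quirk described above cannot expose the whole node's GPUs.
function nvidiaVisibilityGuard(gpuNum: number): string {
    // An empty CUDA_VISIBLE_DEVICES makes CUDA report no usable devices.
    return gpuNum === 0 ? 'export CUDA_VISIBLE_DEVICES=""' : '';
}
```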
@@ -229,7 +229,7 @@ abstract class KubernetesTrainingService {
const errorMessage: string = `CancelTrialJob: trial job id ${trialJobId} not found`;
this.log.error(errorMessage);
return Promise.reject(errorMessage);
}
if(!this.kubernetesCRDClient) {
const errorMessage: string = `CancelTrialJob: trial job id ${trialJobId} failed because operatorClient is undefined`;
this.log.error(errorMessage);
@@ -268,8 +268,8 @@ abstract class KubernetesTrainingService {
kubernetesTrialJob.status = 'SYS_CANCELED';
}
}
// Delete all kubernetes jobs whose expId label is current experiment id
try {
if(this.kubernetesCRDClient) {
await this.kubernetesCRDClient.deleteKubernetesJob(new Map(
@@ -290,7 +290,7 @@ abstract class KubernetesTrainingService {
this.log.error(`Unmount ${this.trialLocalNFSTempFolder} failed, error is ${error}`);
}
// Stop kubernetes rest server
if(!this.kubernetesJobRestServer) {
throw new Error('kubernetesJobRestServer not initialized!');
}
...
@@ -59,8 +59,8 @@ class GPUScheduler {
}
/**
* Generate gpu metric collector shell script in local machine,
* used to run in remote machine, and will be deleted after uploaded from local.
*/
private async runGpuMetricsCollectorScript(): Promise<void> {
await execMkdir(this.gpuMetricCollectorScriptFolder);
...
@@ -532,7 +532,7 @@ class LocalTrainingService implements TrainingService {
}
const scripts: string[] = this.getScript(this.localTrailConfig, trialJobDetail.workingDirectory);
scripts.forEach(script => {
runScriptLines.push(script);
});
await execMkdir(trialJobDetail.workingDirectory);
await execMkdir(path.join(trialJobDetail.workingDirectory, '.nni'));
...
@@ -57,7 +57,7 @@ export namespace HDFSClientUtility {
/**
* Copy a local file to hdfs directory
*
* @param localFilePath local file path(source)
* @param hdfsFilePath hdfs file path(target)
* @param hdfsClient hdfs client
@@ -87,7 +87,7 @@ export namespace HDFSClientUtility {
/**
* Recursively copy local directory to hdfs directory
*
* @param localDirectory local directory
* @param hdfsDirectory HDFS directory
* @param hdfsClient HDFS client
@@ -118,7 +118,7 @@ export namespace HDFSClientUtility {
/**
* Read content from HDFS file
*
* @param hdfsPath HDFS file path
* @param hdfsClient HDFS client
*/
@@ -141,7 +141,7 @@ export namespace HDFSClientUtility {
// Concat the data chunk to buffer
buffer = Buffer.concat([buffer, chunk]);
});
remoteFileStream.on('finish', function onFinish () {
// Upload is done, resolve
deferred.resolve(buffer);
@@ -152,7 +152,7 @@ export namespace HDFSClientUtility {
/**
* Check if an HDFS path already exists
*
* @param hdfsPath target path need to check in HDFS
* @param hdfsClient HDFS client
*/
@@ -164,7 +164,7 @@ export namespace HDFSClientUtility {
let timeoutId : NodeJS.Timer
const delayTimeout : Promise<boolean> = new Promise<boolean>((resolve : Function, reject : Function) : void => {
// Set timeout and reject the promise once reach timeout (5 seconds)
timeoutId = setTimeout(() => deferred.reject(`Check HDFS path ${hdfsPath} exists timeout`), 5000);
});
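The exists-check above arms a 5-second timer that rejects the pending Deferred if HDFS has not answered in time. A compact sketch of the same timeout idea using Promise.race follows; withTimeout and checkExists are illustrative names, not the utility's actual API.

```typescript
// Sketch: race an asynchronous operation against a timer that rejects after `ms` milliseconds.
function withTimeout<T>(work: Promise<T>, ms: number, message: string): Promise<T> {
    const timeout: Promise<T> = new Promise<T>((_resolve, reject) => {
        setTimeout(() => reject(new Error(message)), ms);
    });
    return Promise.race([work, timeout]);
}

// Usage sketch, mirroring the 5-second HDFS check above:
// await withTimeout(checkExists(hdfsPath), 5000, `Check HDFS path ${hdfsPath} exists timeout`);
```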
@@ -173,9 +173,9 @@ export namespace HDFSClientUtility {
/**
* Mkdir in HDFS, use default permission 755
*
* @param hdfsPath the path in HDFS. It could be either file or directory
* @param hdfsClient
*/
export function mkdir(hdfsPath : string, hdfsClient : any) : Promise<boolean> {
const deferred : Deferred<boolean> = new Deferred<boolean>();
@@ -193,9 +193,9 @@ export namespace HDFSClientUtility {
/**
* Read directory contents
*
* @param hdfsPath the path in HDFS. It could be either file or directory
* @param hdfsClient
*/
export async function readdir(hdfsPath : string, hdfsClient : any) : Promise<string[]> {
const deferred : Deferred<string[]> = new Deferred<string[]>();
@@ -218,7 +218,7 @@ export namespace HDFSClientUtility {
/**
* Delete HDFS path
* @param hdfsPath the path in HDFS. It could be either file or directory
* @param hdfsClient
* @param recursive Mark if need to delete recursively
*/
export function deletePath(hdfsPath : string, hdfsClient : any, recursive : boolean = true) : Promise<boolean> {
...
@@ -36,7 +36,7 @@ export class PAITaskRole {
public readonly command: string;
//Shared memory for one task in the task role
public readonly shmMB?: number;
/**
* Constructor
* @param name Name for the task role
@@ -52,7 +52,7 @@ export class PAITaskRole {
this.cpuNumber = cpuNumber;
this.memoryMB = memoryMB;
this.gpuNumber = gpuNumber;
this.command = command;
this.shmMB = shmMB;
}
}
@@ -83,7 +83,7 @@ export class PAIJobConfig{
* @param outputDir Output directory on HDFS
* @param taskRoles List of taskRole, one task role at least
*/
constructor(jobName: string, image : string, dataDir : string, outputDir : string, codeDir : string,
taskRoles : PAITaskRole[], virtualCluster: string) {
this.jobName = jobName;
this.image = image;
@@ -117,7 +117,7 @@ export class NNIPAITrialConfig extends TrialConfig{
public readonly cpuNum: number;
public readonly memoryMB: number;
public readonly image: string;
public readonly dataDir: string;
public outputDir: string;
//The virtual cluster job runs on. If omitted, the job will run on default virtual cluster
@@ -125,7 +125,7 @@ export class NNIPAITrialConfig extends TrialConfig{
//Shared memory for one task in the task role
public shmMB?: number;
constructor(command : string, codeDir : string, gpuNum : number, cpuNum: number, memoryMB: number,
image: string, dataDir: string, outputDir: string, virtualCluster?: string, shmMB?: number) {
super(command, codeDir, gpuNum);
this.cpuNum = cpuNum;
...
@@ -36,7 +36,7 @@ export class PAITrialJobDetail implements TrialJobDetail {
public hdfsLogPath: string;
public isEarlyStopped?: boolean;
constructor(id: string, status: TrialJobStatus, paiJobName : string,
submitTime: number, workingDirectory: string, form: JobApplicationForm, sequenceId: number, hdfsLogPath: string) {
this.id = id;
this.status = status;
@@ -50,7 +50,7 @@ export class PAITrialJobDetail implements TrialJobDetail {
}
}
export const PAI_INSTALL_NNI_SHELL_FORMAT: string =
`#!/bin/bash
if python3 -c 'import nni' > /dev/null 2>&1; then
# nni module is already installed, skip
@@ -62,12 +62,12 @@ fi`;
export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4}
&& cd $NNI_SYS_DIR && sh install_nni.sh
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}'
--pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1' --nni_manager_version '{12}' --log_collection '{13}'`;
export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`;
export const PAI_LOG_PATH_FORMAT: string =
`http://{0}/webhdfs/explorer.html#{1}`
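PAI_TRIAL_COMMAND_FORMAT and the two constants after it are positional templates with {0}…{13} placeholders; elsewhere in this diff the same style of template is expanded with String.Format (see the gpu metrics collector script further down). A hedged sketch of how the trial command template might be filled, assuming the typescript-string-operations package and using invented values purely for illustration:

```typescript
import { String } from 'typescript-string-operations';

// Sketch only: PAI_TRIAL_COMMAND_FORMAT is the constant declared above, and the
// argument order mirrors its {0}..{13} placeholders; every value here is made up.
const trialCommand: string = String.Format(
    PAI_TRIAL_COMMAND_FORMAT,
    '/nni/sys', '/nni/output', 'trialAbc', 'expXyz', '0',         // {0}-{4}: dirs, ids, sequence id
    'python3 mnist.py', '10.0.0.1', '51188',                      // {5}-{7}: trial command, manager ip/port
    '/pai/output', '10.0.0.2', 'paiUser', '/nni/experiments/x',   // {8}-{11}: hdfs output dir, host, user, exp dir
    '0.5', 'none');                                               // {12}-{13}: manager version, log collection
```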
@@ -44,7 +44,7 @@ export class PAIJobInfoCollector {
public async retrieveTrialStatus(paiToken? : string, paiClusterConfig?: PAIClusterConfig) : Promise<void> {
if (!paiClusterConfig || !paiToken) {
return Promise.resolve();
}
const updatePaiTrialJobs : Promise<void>[] = [];
@@ -76,7 +76,7 @@ export class PAIJobInfoCollector {
"Authorization": 'Bearer ' + paiToken
}
};
//TODO : pass in request timeout param?
request(getJobInfoRequest, (error: Error, response: request.Response, body: any) => {
if (error || response.statusCode >= 500) {
this.log.error(`PAI Training service: get job info for trial ${paiTrialJob.id} from PAI Cluster failed!`);
@@ -87,7 +87,7 @@ export class PAIJobInfoCollector {
} else {
if(response.body.jobStatus && response.body.jobStatus.state) {
switch(response.body.jobStatus.state) {
case 'WAITING':
paiTrialJob.status = 'WAITING';
break;
case 'RUNNING':
@@ -96,7 +96,7 @@ export class PAIJobInfoCollector {
paiTrialJob.startTime = response.body.jobStatus.appLaunchedTime;
}
if(!paiTrialJob.url) {
paiTrialJob.url = response.body.jobStatus.appTrackingUrl;
}
break;
case 'SUCCEEDED':
@@ -104,7 +104,7 @@ export class PAIJobInfoCollector {
break;
case 'STOPPED':
if (paiTrialJob.isEarlyStopped !== undefined) {
paiTrialJob.status = paiTrialJob.isEarlyStopped === true ?
'EARLY_STOPPED' : 'USER_CANCELED';
} else {
// if paiTrialJob's isEarlyStopped is undefined, that means we didn't stop it via cancellation, mark it as SYS_CANCELLED by PAI
@@ -112,7 +112,7 @@ export class PAIJobInfoCollector {
}
break;
case 'FAILED':
paiTrialJob.status = 'FAILED';
break;
default:
paiTrialJob.status = 'UNKNOWN';
...
@@ -26,7 +26,7 @@ import { ClusterJobRestServer } from '../common/clusterJobRestServer'
/**
* PAI Training service Rest server, provides rest API to support pai job metrics update
*
*/
@component.Singleton
export class PAIJobRestServer extends ClusterJobRestServer{
...
@@ -25,7 +25,7 @@ export class PAITrialConfig extends TrialConfig{
public readonly cpuNum: number;
public readonly memoryMB: number;
public readonly image: string;
public readonly dataDir: string;
public readonly outputDir: string;
constructor(command : string, codeDir : string, gpuNum : number, cpuNum: number, memoryMB: number, image: string, dataDir: string, outputDir: string) {
...
@@ -112,7 +112,7 @@ export class SSHClient {
this.sshClient = sshClient;
this.usedConnectionNumber = usedConnectionNumber;
}
public get getSSHClientInstance(): Client {
return this.sshClient;
}
@@ -151,7 +151,7 @@ export class SSHClientManager {
port: this.rmMeta.port,
username: this.rmMeta.username };
if (this.rmMeta.passwd) {
connectConfig.password = this.rmMeta.passwd;
} else if(this.rmMeta.sshKeyPath) {
if(!fs.existsSync(this.rmMeta.sshKeyPath)) {
//SSH key path is not a valid file, reject
@@ -171,10 +171,10 @@ export class SSHClientManager {
// SSH connection error, reject with error message
deferred.reject(new Error(err.message));
}).connect(connectConfig);
return deferred.promise;
}
/**
* find an available ssh client in ssh array, if no ssh client available, return undefined
*/
@@ -191,7 +191,7 @@ export class SSHClientManager {
//init a new ssh client if could not get an available one
return await this.initNewSSHClient();
}
/**
* add a new ssh client to sshClientArray
* @param sshClient
@@ -199,14 +199,14 @@ export class SSHClientManager {
public addNewSSHClient(client: Client) {
this.sshClientArray.push(new SSHClient(client, 1));
}
/**
* first ssh client instance is used for gpu collector and host job
*/
public getFirstSSHClient() {
return this.sshClientArray[0].getSSHClientInstance;
}
/**
* close all ssh clients
*/
@@ -215,7 +215,7 @@ export class SSHClientManager {
sshClient.getSSHClientInstance.end();
}
}
/**
* retrieve resource, minus a number for given ssh client
* @param client
@@ -231,7 +231,7 @@ export class SSHClientManager {
}
}
}
}
export type RemoteMachineScheduleResult = { scheduleInfo : RemoteMachineScheduleInfo | undefined; resultType : ScheduleResultType};
@@ -242,7 +242,7 @@ export enum ScheduleResultType {
/* Schedule succeeded*/
SUCCEED,
/* Temporarily, not enough available GPU right now */
TMP_NO_AVAILABLE_GPU,
/* Cannot match requirement even if all GPU are a*/
...
@@ -26,7 +26,7 @@ import { ClusterJobRestServer } from '../common/clusterJobRestServer'
/**
* RemoteMachine Training service Rest server, provides rest API to support remotemachine job metrics update
*
*/
@component.Singleton
export class RemoteMachineJobRestServer extends ClusterJobRestServer{
...
@@ -125,10 +125,10 @@ class RemoteMachineTrainingService implements TrainingService {
}
this.log.info('Remote machine training service exit.');
}
/**
* give trial a ssh connection
* @param trial
*/
public async allocateSSHClientForTrial(trial: RemoteMachineTrialJobDetail): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
@@ -144,10 +144,10 @@ class RemoteMachineTrainingService implements TrainingService {
deferred.resolve();
return deferred.promise;
}
/**
* If a trial is finished, release the connection resource
* @param trial
*/
public releaseTrialSSHClient(trial: RemoteMachineTrialJobDetail): void {
if(!trial.rmMeta) {
@@ -167,7 +167,7 @@ class RemoteMachineTrainingService implements TrainingService {
const jobs: TrialJobDetail[] = [];
const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>();
for (const [key, value] of this.trialJobsMap) {
if (value.form.jobType === 'TRIAL') {
jobs.push(await this.getTrialJob(key));
}
@@ -275,12 +275,12 @@ class RemoteMachineTrainingService implements TrainingService {
return trialJobDetail;
}
/**
* remove gpu reservation when job is not running
*/
private updateGpuReservation() {
for (const [key, value] of this.trialJobsMap) {
if(!['WAITING', 'RUNNING'].includes(value.status)) {
this.gpuScheduler.removeGpuReservation(key, this.trialJobsMap);
}
@@ -371,7 +371,7 @@ class RemoteMachineTrainingService implements TrainingService {
await validateCodeDir(remoteMachineTrailConfig.codeDir);
} catch(error) {
this.log.error(error);
return Promise.reject(new Error(error));
}
this.trialConfig = remoteMachineTrailConfig;
@@ -400,16 +400,16 @@ class RemoteMachineTrainingService implements TrainingService {
return deferred.promise;
}
/**
* cleanup() has a timeout of 10s to clean remote connections
*/
public async cleanUp(): Promise<void> {
this.log.info('Stopping remote machine training service...');
this.stopping = true;
await Promise.race([delay(10000), this.cleanupConnections()]);
}
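The 10-second bound in the cleanup docstring comes from racing the real connection cleanup against a delay promise: whichever settles first ends the wait, so a hung SSH session cannot block shutdown indefinitely. A minimal sketch of the pattern (delay here is a local stand-in for the utility this file imports):

```typescript
// Sketch: bound an async cleanup step to roughly 10 seconds.
const delay = (ms: number): Promise<void> =>
    new Promise<void>(resolve => setTimeout(resolve, ms));

async function boundedCleanup(cleanupConnections: () => Promise<void>): Promise<void> {
    // If cleanupConnections() hangs, the delay resolves first and the race ends anyway.
    await Promise.race([delay(10000), cleanupConnections()]);
}
```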
/**
* stop gpu_metric_collector process in remote machine and remove unused scripts
*/
@@ -430,8 +430,8 @@ class RemoteMachineTrainingService implements TrainingService {
}
return Promise.resolve();
}
/**
* Generate gpu metric collector directory to store temp gpu metric collector script files
*/
@@ -441,8 +441,8 @@ class RemoteMachineTrainingService implements TrainingService {
}
/**
* Generate gpu metric collector shell script in local machine,
* used to run in remote machine, and will be deleted after uploaded from local.
*/
private async generateGpuMetricsCollectorScript(userName: string): Promise<void> {
let gpuMetricCollectorScriptFolder : string = this.getLocalGpuMetricCollectorDir();
@@ -451,9 +451,9 @@ class RemoteMachineTrainingService implements TrainingService {
let gpuMetricsCollectorScriptPath: string = path.join(gpuMetricCollectorScriptFolder, userName, 'gpu_metrics_collector.sh');
const remoteGPUScriptsDir: string = this.getRemoteScriptsPath(userName); // This directory is used to store gpu_metrics and pid created by script
const gpuMetricsCollectorScriptContent: string = String.Format(
GPU_INFO_COLLECTOR_FORMAT_LINUX,
remoteGPUScriptsDir,
unixPathJoin(remoteGPUScriptsDir, 'pid'),
);
await fs.promises.writeFile(gpuMetricsCollectorScriptPath, gpuMetricsCollectorScriptContent, { encoding: 'utf8' });
}
@@ -589,7 +589,7 @@ class RemoteMachineTrainingService implements TrainingService {
} else {
command = `CUDA_VISIBLE_DEVICES=" " ${this.trialConfig.command}`;
}
const nniManagerIp = this.nniManagerIpConfig?this.nniManagerIpConfig.nniManagerIp:getIPV4Address();
if(!this.remoteRestServerPort) {
const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer);
...
@@ -37,7 +37,7 @@ describe('WebHDFS', function () {
{
"user": "user1",
"port": 50070,
"host": "10.0.0.0"
}
*/
let skip: boolean = false;
@@ -45,7 +45,7 @@ describe('WebHDFS', function () {
let hdfsClient: any;
try {
testHDFSInfo = JSON.parse(fs.readFileSync('../../.vscode/hdfsInfo.json', 'utf8'));
console.log(testHDFSInfo);
hdfsClient = WebHDFS.createClient({
user: testHDFSInfo.user,
port: testHDFSInfo.port,
@@ -120,7 +120,7 @@ describe('WebHDFS', function () {
chai.expect(actualFileData).to.be.equals(testFileData);
const testHDFSDirPath : string = path.join('/nni_unittest_' + uniqueString(6) + '_dir');
await HDFSClientUtility.copyDirectoryToHdfs(tmpLocalDirectoryPath, testHDFSDirPath, hdfsClient);
const files : any[] = await HDFSClientUtility.readdir(testHDFSDirPath, hdfsClient);
@@ -133,7 +133,7 @@ describe('WebHDFS', function () {
// Cleanup
rmdir(tmpLocalDirectoryPath);
let deleteRestult : boolean = await HDFSClientUtility.deletePath(testHDFSFilePath, hdfsClient);
chai.expect(deleteRestult).to.be.equals(true);
...
@@ -63,7 +63,7 @@ describe('Unit Test for KubeflowTrainingService', () => {
if (skip) {
return;
}
kubeflowTrainingService = component.get(KubeflowTrainingService);
});
afterEach(() => {
@@ -78,6 +78,6 @@ describe('Unit Test for KubeflowTrainingService', () => {
return;
}
await kubeflowTrainingService.setClusterMetadata(TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG, testKubeflowConfig),
await kubeflowTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, testKubeflowTrialConfig);
});
});
\ No newline at end of file
@@ -63,7 +63,7 @@ describe('Unit Test for LocalTrainingService', () => {
//trial jobs should be empty, since there are no submitted jobs
chai.expect(await localTrainingService.listTrialJobs()).to.be.empty;
});
it('setClusterMetadata and getClusterMetadata', async () => {
await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig);
localTrainingService.getClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG).then((data)=>{
@@ -87,7 +87,7 @@ describe('Unit Test for LocalTrainingService', () => {
await localTrainingService.cancelTrialJob(jobDetail.id);
chai.expect(jobDetail.status).to.be.equals('USER_CANCELED');
}).timeout(20000);
it('Read metrics, Add listener, and remove listener', async () => {
// set meta data
const trialConfig: string = `{\"command\":\"python3 mockedTrial.py\", \"codeDir\":\"${localCodeDir}\",\"gpuNum\":0}`
...
@@ -89,7 +89,7 @@ describe('Unit Test for PAITrainingService', () => {
chai.expect(trialDetail.status).to.be.equals('WAITING');
} catch(error) {
console.log('Submit job failed:' + error);
chai.assert(error)
}
});
});
\ No newline at end of file
@@ -7,5 +7,5 @@ declare module 'child-process-promise' {
stderr: string,
message: string
}
}
}
\ No newline at end of file
@@ -154,7 +154,7 @@ def main():
assessor = None
if args.tuner_class_name in ModuleName:
tuner = create_builtin_class_instance(
args.tuner_class_name,
args.tuner_args)
else:
tuner = create_customized_class_instance(
...
@@ -81,7 +81,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
class Bracket():
"""
A bracket in BOHB, all the information of a bracket is managed by
an instance of this class.
Parameters
@@ -251,7 +251,7 @@ class BOHB(MsgDispatcherBase):
BOHB performs robust and efficient hyperparameter optimization
at scale by combining the speed of Hyperband searches with the
guidance and guarantees of convergence of Bayesian Optimization.
Instead of sampling new configurations at random, BOHB uses
kernel density estimators to select promising candidates.
Parameters
@@ -335,7 +335,7 @@ class BOHB(MsgDispatcherBase):
pass
def handle_initialize(self, data):
"""Initialize Tuner, including creating Bayesian optimization-based parametric models
and search space formations
Parameters
@@ -403,7 +403,7 @@ class BOHB(MsgDispatcherBase):
If this function is called, Command will be sent by BOHB:
a. If there is a parameter need to run, will return "NewTrialJob" with a dict:
{
'parameter_id': id of new hyperparameter
'parameter_source': 'algorithm'
'parameters': value of new hyperparameter
@@ -458,30 +458,30 @@ class BOHB(MsgDispatcherBase):
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1]))
elif _type == 'quniform':
cs.add_hyperparameter(CSH.UniformFloatHyperparameter(
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1],
q=search_space[var]["_value"][2]))
elif _type == 'loguniform':
cs.add_hyperparameter(CSH.UniformFloatHyperparameter(
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1],
log=True))
elif _type == 'qloguniform':
cs.add_hyperparameter(CSH.UniformFloatHyperparameter(
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1],
q=search_space[var]["_value"][2], log=True))
elif _type == 'normal':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2]))
elif _type == 'qnormal':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2],
q=search_space[var]["_value"][3]))
elif _type == 'lognormal':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2],
log=True))
elif _type == 'qlognormal':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2],
q=search_space[var]["_value"][3], log=True))
else:
raise ValueError(
@@ -553,7 +553,7 @@ class BOHB(MsgDispatcherBase):
self.brackets[s].set_config_perf(
int(i), data['parameter_id'], sys.maxsize, value)
self.completed_hyper_configs.append(data)
_parameters = self.parameters[data['parameter_id']]
_parameters.pop(_KEY)
# update BO with loss, max_s budget, hyperparameters
...
@@ -117,7 +117,7 @@ class CG_BOHB(object):
separated by budget. This function samples a configuration from
largest budget. Firstly we sample "num_samples" configurations,
then prefer one with the largest l(x)/g(x).
Parameters:
-----------
info_dict: dict
...
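For context on the l(x)/g(x) ratio in the CG_BOHB docstring above: as in TPE, l models the density of configurations that performed well on the chosen budget and g the density of those that performed poorly, so the sampler prefers the candidate that maximizes

$$x^{*} = \arg\max_{x} \frac{l(x)}{g(x)},$$

i.e. a point that looks likely under the good configurations and unlikely under the bad ones.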