Commit ba8dccd6 authored by suiguoxin's avatar suiguoxin
Browse files

Merge branch 'master' of https://github.com/microsoft/nni

parents 56a1575b 150ee83a
...@@ -17,35 +17,36 @@ ...@@ -17,35 +17,36 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/ */
'use strict' 'use strict';
import * as cpp from 'child-process-promise'; import * as cpp from 'child-process-promise';
import * as path from 'path'; import * as path from 'path';
import * as azureStorage from 'azure-storage';
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
import { Base64 } from 'js-base64';
import { String } from 'typescript-string-operations';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo'; import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { getExperimentRootDir, uniqueString, getJobCancelStatus, getIPV4Address, getVersion } from '../../common/utils';
import { import {
TrialJobDetail, TrialJobMetric, NNIManagerIpConfig NNIManagerIpConfig, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { KubernetesTrialJobDetail, KubernetesScriptFormat } from './kubernetesData'; import { getExperimentRootDir, getIPV4Address, getJobCancelStatus, getVersion, uniqueString } from '../../common/utils';
import { KubernetesClusterConfig } from './kubernetesConfig';
import { GeneralK8sClient, KubernetesCRDClient } from './kubernetesApiClient';
import { AzureStorageClientUtility } from './azureStorageClientUtils'; import { AzureStorageClientUtility } from './azureStorageClientUtils';
import { GeneralK8sClient, KubernetesCRDClient } from './kubernetesApiClient';
import { KubernetesClusterConfig } from './kubernetesConfig';
import { kubernetesScriptFormat, KubernetesTrialJobDetail } from './kubernetesData';
import { KubernetesJobRestServer } from './kubernetesJobRestServer'; import { KubernetesJobRestServer } from './kubernetesJobRestServer';
import { String } from 'typescript-string-operations';
import * as azureStorage from 'azure-storage';
var azure = require('azure-storage');
var base64 = require('js-base64').Base64;
/**
* Training Service implementation for Kubernetes
*/
abstract class KubernetesTrainingService { abstract class KubernetesTrainingService {
protected readonly NNI_KUBERNETES_TRIAL_LABEL: string = 'nni-kubernetes-trial'; protected readonly NNI_KUBERNETES_TRIAL_LABEL: string = 'nni-kubernetes-trial';
protected readonly log!: Logger; protected readonly log!: Logger;
protected readonly metricsEmitter: EventEmitter; protected readonly metricsEmitter: EventEmitter;
protected readonly trialJobsMap: Map<string, KubernetesTrialJobDetail>; protected readonly trialJobsMap: Map<string, KubernetesTrialJobDetail>;
/** experiment root dir in NFS */ // experiment root dir in NFS
protected readonly trialLocalNFSTempFolder: string; protected readonly trialLocalNFSTempFolder: string;
protected stopping: boolean = false; protected stopping: boolean = false;
protected experimentId! : string; protected experimentId! : string;
...@@ -63,35 +64,36 @@ abstract class KubernetesTrainingService { ...@@ -63,35 +64,36 @@ abstract class KubernetesTrainingService {
protected kubernetesClusterConfig?: KubernetesClusterConfig; protected kubernetesClusterConfig?: KubernetesClusterConfig;
protected versionCheck: boolean = true; protected versionCheck: boolean = true;
protected logCollection: string; protected logCollection: string;
constructor() { constructor() {
this.log = getLogger(); this.log = getLogger();
this.metricsEmitter = new EventEmitter(); this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, KubernetesTrialJobDetail>(); this.trialJobsMap = new Map<string, KubernetesTrialJobDetail>();
this.trialLocalNFSTempFolder = path.join(getExperimentRootDir(), 'trials-nfs-tmp'); this.trialLocalNFSTempFolder = path.join(getExperimentRootDir(), 'trials-nfs-tmp');
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.nextTrialSequenceId = -1; this.nextTrialSequenceId = -1;
this.CONTAINER_MOUNT_PATH = '/tmp/mount'; this.CONTAINER_MOUNT_PATH = '/tmp/mount';
this.genericK8sClient = new GeneralK8sClient(); this.genericK8sClient = new GeneralK8sClient();
this.logCollection = 'none'; this.logCollection = 'none';
} }
public generatePodResource(memory: number, cpuNum: number, gpuNum: number) { // tslint:disable:no-any
public generatePodResource(memory: number, cpuNum: number, gpuNum: number): any {
return { return {
'memory': `${memory}Mi`, memory: `${memory}Mi`,
'cpu': `${cpuNum}`, cpu: `${cpuNum}`,
'nvidia.com/gpu': `${gpuNum}` 'nvidia.com/gpu': `${gpuNum}`
} };
} } // tslint:enable:no-any
public async listTrialJobs(): Promise<TrialJobDetail[]> { public async listTrialJobs(): Promise<TrialJobDetail[]> {
const jobs: TrialJobDetail[] = []; const jobs: TrialJobDetail[] = [];
for (const [key, value] of this.trialJobsMap) { for (const [key, value] of this.trialJobsMap) {
if (value.form.jobType === 'TRIAL') { if (value.form.jobType === 'TRIAL') {
jobs.push(await this.getTrialJob(key)); jobs.push(await this.getTrialJob(key));
} }
}; }
return Promise.resolve(jobs); return Promise.resolve(jobs);
} }
...@@ -100,21 +102,21 @@ abstract class KubernetesTrainingService { ...@@ -100,21 +102,21 @@ abstract class KubernetesTrainingService {
const kubernetesTrialJob: TrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const kubernetesTrialJob: TrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (!kubernetesTrialJob) { if (kubernetesTrialJob === undefined) {
return Promise.reject(`trial job ${trialJobId} not found`) return Promise.reject(`trial job ${trialJobId} not found`);
} }
return Promise.resolve(kubernetesTrialJob); return Promise.resolve(kubernetesTrialJob);
} }
public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void) { public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
this.metricsEmitter.on('metric', listener); this.metricsEmitter.on('metric', listener);
} }
public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void) { public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
this.metricsEmitter.off('metric', listener); this.metricsEmitter.off('metric', listener);
} }
public get isMultiPhaseJobSupported(): boolean { public get isMultiPhaseJobSupported(): boolean {
return false; return false;
} }
...@@ -127,6 +129,96 @@ abstract class KubernetesTrainingService { ...@@ -127,6 +129,96 @@ abstract class KubernetesTrainingService {
return this.metricsEmitter; return this.metricsEmitter;
} }
/**
 * Cancel one trial job by deleting its kubernetes job.
 *
 * The job is selected by the labels app / expId / trialId. After a successful
 * delete the trial's endTime is stamped and its status is set from
 * getJobCancelStatus(isEarlyStopped).
 *
 * @param trialJobId id of the trial job to cancel
 * @param isEarlyStopped forwarded to getJobCancelStatus to pick the final status
 * @returns a promise that rejects with a string message when the trial is unknown,
 *          kubernetesCRDClient is undefined, or the delete call fails
 */
public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
    // Every failure path logs the message first and then rejects with that same string.
    const logAndReject: (message: string) => Promise<void> = (message: string): Promise<void> => {
        this.log.error(message);

        return Promise.reject(message);
    };

    const targetJob: KubernetesTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
    if (targetJob === undefined) {
        return logAndReject(`CancelTrialJob: trial job id ${trialJobId} not found`);
    }
    if (this.kubernetesCRDClient === undefined) {
        return logAndReject(`CancelTrialJob: trial job id ${trialJobId} failed because operatorClient is undefined`);
    }

    try {
        // The label map is built inside the try so a failure while building it is reported the same way
        await this.kubernetesCRDClient.deleteKubernetesJob(new Map(
            [
                ['app', this.NNI_KUBERNETES_TRIAL_LABEL],
                ['expId', getExperimentId()],
                ['trialId', trialJobId]
            ]
        ));
    } catch (err) {
        return logAndReject(`Delete trial ${trialJobId} failed: ${err}`);
    }

    // Only reached once the delete call succeeded
    targetJob.endTime = Date.now();
    targetJob.status = getJobCancelStatus(isEarlyStopped);

    return Promise.resolve();
}
/**
 * Shut the training service down.
 *
 * Steps, in order: mark the service as stopping, cancel every trial that is
 * RUNNING / WAITING / UNKNOWN, delete all kubernetes jobs labelled with this
 * experiment id, unmount the local NFS temp folder, and stop the rest server.
 * Failures in the first three steps are logged or ignored so the later steps
 * still run; only the rest-server step can make the returned promise reject.
 */
public async cleanUp(): Promise<void> {
    this.stopping = true;

    // Step 1: cancel the trials that have not finished yet
    const unfinishedStatuses: string[] = ['RUNNING', 'WAITING', 'UNKNOWN'];
    for (const [jobId, jobDetail] of this.trialJobsMap) {
        if (!unfinishedStatuses.includes(jobDetail.status)) {
            continue;
        }
        try {
            await this.cancelTrialJob(jobId);
        } catch (error) {
            // Intentionally ignored: a failed cancel must not abort the cleanup
        }
        // Whatever the cancel outcome was, the trial is recorded as cancelled by the system
        jobDetail.status = 'SYS_CANCELED';
    }

    // Step 2: delete every kubernetes job carrying the current experiment id label
    try {
        if (this.kubernetesCRDClient !== undefined) {
            await this.kubernetesCRDClient.deleteKubernetesJob(new Map(
                [
                    ['app', this.NNI_KUBERNETES_TRIAL_LABEL],
                    ['expId', getExperimentId()]
                ]
            ));
        }
    } catch (error) {
        this.log.error(`Delete kubernetes job with label: app=${this.NNI_KUBERNETES_TRIAL_LABEL},expId=${getExperimentId()} failed, error is ${error}`);
    }

    // Step 3: unmount the NFS temp folder; a failure is logged only
    try {
        await cpp.exec(`sudo umount ${this.trialLocalNFSTempFolder}`);
    } catch (error) {
        this.log.error(`Unmount ${this.trialLocalNFSTempFolder} failed, error is ${error}`);
    }

    // Step 4: stop the kubernetes rest server; this is the only step whose failure is propagated
    if (this.kubernetesJobRestServer === undefined) {
        throw new Error('kubernetesJobRestServer not initialized!');
    }
    try {
        await this.kubernetesJobRestServer.stop();
        this.log.info('Kubernetes Training service rest server stopped successfully.');
    } catch (error) {
        // tslint:disable-next-line: no-unsafe-any
        this.log.error(`Kubernetes Training service rest server stopped failed, error: ${error.message}`);

        return Promise.reject(error);
    }

    return Promise.resolve();
}
protected generateSequenceId(): number { protected generateSequenceId(): number {
if (this.nextTrialSequenceId === -1) { if (this.nextTrialSequenceId === -1) {
this.nextTrialSequenceId = getInitTrialSequenceId(); this.nextTrialSequenceId = getInitTrialSequenceId();
...@@ -135,25 +227,31 @@ abstract class KubernetesTrainingService { ...@@ -135,25 +227,31 @@ abstract class KubernetesTrainingService {
return this.nextTrialSequenceId++; return this.nextTrialSequenceId++;
} }
// tslint:disable: no-unsafe-any no-any
protected async createAzureStorage(vaultName: string, valutKeyName: string, accountName: string, azureShare: string): Promise<void> { protected async createAzureStorage(vaultName: string, valutKeyName: string, accountName: string, azureShare: string): Promise<void> {
try { try {
const result = await cpp.exec(`az keyvault secret show --name ${valutKeyName} --vault-name ${vaultName}`); const result: any = await cpp.exec(`az keyvault secret show --name ${valutKeyName} --vault-name ${vaultName}`);
if(result.stderr) { if (result.stderr) {
const errorMessage: string = result.stderr; const errorMessage: string = result.stderr;
this.log.error(errorMessage); this.log.error(errorMessage);
return Promise.reject(errorMessage); return Promise.reject(errorMessage);
} }
const storageAccountKey =JSON.parse(result.stdout).value; const storageAccountKey: any = JSON.parse(result.stdout).value;
if (this.azureStorageAccountName === undefined) {
throw new Error('azureStorageAccountName not initialized!');
}
//create storage client //create storage client
this.azureStorageClient = azure.createFileService(this.azureStorageAccountName, storageAccountKey); this.azureStorageClient = azureStorage.createFileService(this.azureStorageAccountName, storageAccountKey);
await AzureStorageClientUtility.createShare(this.azureStorageClient, this.azureStorageShare); await AzureStorageClientUtility.createShare(this.azureStorageClient, this.azureStorageShare);
//create storage secret //create storage secret
this.azureStorageSecretName = 'nni-secret-' + uniqueString(8).toLowerCase(); this.azureStorageSecretName = String.Format('nni-secret-{0}', uniqueString(8)
.toLowerCase());
await this.genericK8sClient.createSecret( await this.genericK8sClient.createSecret(
{ {
apiVersion: 'v1', apiVersion: 'v1',
kind: 'Secret', kind: 'Secret',
metadata: { metadata: {
name: this.azureStorageSecretName, name: this.azureStorageSecretName,
namespace: 'default', namespace: 'default',
labels: { labels: {
...@@ -163,38 +261,42 @@ abstract class KubernetesTrainingService { ...@@ -163,38 +261,42 @@ abstract class KubernetesTrainingService {
}, },
type: 'Opaque', type: 'Opaque',
data: { data: {
azurestorageaccountname: base64.encode(this.azureStorageAccountName), azurestorageaccountname: Base64.encode(this.azureStorageAccountName),
azurestorageaccountkey: base64.encode(storageAccountKey) azurestorageaccountkey: Base64.encode(storageAccountKey)
} }
} }
); );
} catch(error) { } catch (error) {
this.log.error(error); this.log.error(error);
return Promise.reject(error); return Promise.reject(error);
} }
return Promise.resolve(); return Promise.resolve();
} }
// tslint:enable: no-unsafe-any no-any
/**
/**
* Generate run script for different roles(like worker or ps) * Generate run script for different roles(like worker or ps)
* @param trialJobId trial job id * @param trialJobId trial job id
* @param trialWorkingFolder working folder * @param trialWorkingFolder working folder
* @param command * @param command command
* @param trialSequenceId sequence id * @param trialSequenceId sequence id
*/ */
protected async generateRunScript(platform: string, trialJobId: string, trialWorkingFolder: string, protected async generateRunScript(platform: string, trialJobId: string, trialWorkingFolder: string,
command: string, trialSequenceId: string, roleName: string, gpuNum: number): Promise<string> { command: string, trialSequenceId: string, roleName: string, gpuNum: number): Promise<string> {
let nvidia_script: string = ''; let nvidiaScript: string = '';
// Nvidia device plugin for K8S has a known issue that requesting zero GPUs allocates all GPUs // Nvidia device plugin for K8S has a known issue that requesting zero GPUs allocates all GPUs
// Refer https://github.com/NVIDIA/k8s-device-plugin/issues/61 // Refer https://github.com/NVIDIA/k8s-device-plugin/issues/61
// So we have to explicitly set CUDA_VISIBLE_DEVICES to empty if user sets gpuNum to 0 in NNI config file // So we have to explicitly set CUDA_VISIBLE_DEVICES to empty if user sets gpuNum to 0 in NNI config file
if(gpuNum === 0) { if (gpuNum === 0) {
nvidia_script = `export CUDA_VISIBLE_DEVICES='0'`; nvidiaScript = `export CUDA_VISIBLE_DEVICES='0'`;
} }
const nniManagerIp = this.nniManagerIpConfig?this.nniManagerIpConfig.nniManagerIp:getIPV4Address(); // tslint:disable-next-line: strict-boolean-expressions
const version = this.versionCheck? await getVersion(): ''; const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const version: string = this.versionCheck ? await getVersion() : '';
const runScript: string = String.Format( const runScript: string = String.Format(
KubernetesScriptFormat, kubernetesScriptFormat,
platform, platform,
trialJobId, trialJobId,
path.join(trialWorkingFolder, 'output', `${roleName}_output`), path.join(trialWorkingFolder, 'output', `${roleName}_output`),
...@@ -202,108 +304,28 @@ abstract class KubernetesTrainingService { ...@@ -202,108 +304,28 @@ abstract class KubernetesTrainingService {
getExperimentId(), getExperimentId(),
trialWorkingFolder, trialWorkingFolder,
trialSequenceId, trialSequenceId,
nvidia_script, nvidiaScript,
command, command,
nniManagerIp, nniManagerIp,
this.kubernetesRestServerPort, this.kubernetesRestServerPort,
version, version,
this.logCollection this.logCollection
); );
return Promise.resolve(runScript); return Promise.resolve(runScript);
} }
protected async createNFSStorage(nfsServer: string, nfsPath: string): Promise<void> { protected async createNFSStorage(nfsServer: string, nfsPath: string): Promise<void> {
await cpp.exec(`mkdir -p ${this.trialLocalNFSTempFolder}`); await cpp.exec(`mkdir -p ${this.trialLocalNFSTempFolder}`);
try { try {
await cpp.exec(`sudo mount ${nfsServer}:${nfsPath} ${this.trialLocalNFSTempFolder}`); await cpp.exec(`sudo mount ${nfsServer}:${nfsPath} ${this.trialLocalNFSTempFolder}`);
} catch(error) { } catch (error) {
const mountError: string = `Mount NFS ${nfsServer}:${nfsPath} to ${this.trialLocalNFSTempFolder} failed, error is ${error}`; const mountError: string = `Mount NFS ${nfsServer}:${nfsPath} to ${this.trialLocalNFSTempFolder} failed, error is ${error}`;
this.log.error(mountError); this.log.error(mountError);
return Promise.reject(mountError);
}
return Promise.resolve();
}
public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> { return Promise.reject(mountError);
const trialJobDetail : KubernetesTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if(!trialJobDetail) {
const errorMessage: string = `CancelTrialJob: trial job id ${trialJobId} not found`;
this.log.error(errorMessage);
return Promise.reject(errorMessage);
}
if(!this.kubernetesCRDClient) {
const errorMessage: string = `CancelTrialJob: trial job id ${trialJobId} failed because operatorClient is undefined`;
this.log.error(errorMessage);
return Promise.reject(errorMessage);
}
try {
await this.kubernetesCRDClient.deleteKubernetesJob(new Map(
[
['app', this.NNI_KUBERNETES_TRIAL_LABEL],
['expId', getExperimentId()],
['trialId', trialJobId]
]
));
} catch(err) {
const errorMessage: string = `Delete trial ${trialJobId} failed: ${err}`;
this.log.error(errorMessage);
return Promise.reject(errorMessage);
}
trialJobDetail.endTime = Date.now();
trialJobDetail.status = getJobCancelStatus(isEarlyStopped);
return Promise.resolve();
}
public async cleanUp(): Promise<void> {
this.stopping = true;
// First, cancel all running kubernetes jobs
for(let [trialJobId, kubernetesTrialJob] of this.trialJobsMap) {
if(['RUNNING', 'WAITING', 'UNKNOWN'].includes(kubernetesTrialJob.status)) {
try {
await this.cancelTrialJob(trialJobId);
} catch(error) {} // DONT throw error during cleanup
kubernetesTrialJob.status = 'SYS_CANCELED';
}
}
// Delete all kubernetes jobs whose expId label is current experiment id
try {
if(this.kubernetesCRDClient) {
await this.kubernetesCRDClient.deleteKubernetesJob(new Map(
[
['app', this.NNI_KUBERNETES_TRIAL_LABEL],
['expId', getExperimentId()]
]
));
}
} catch(error) {
this.log.error(`Delete kubernetes job with label: app=${this.NNI_KUBERNETES_TRIAL_LABEL},expId=${getExperimentId()} failed, error is ${error}`);
}
// Unmount NFS
try {
await cpp.exec(`sudo umount ${this.trialLocalNFSTempFolder}`);
} catch(error) {
this.log.error(`Unmount ${this.trialLocalNFSTempFolder} failed, error is ${error}`);
}
// Stop kubernetes rest server
if(!this.kubernetesJobRestServer) {
throw new Error('kubernetesJobRestServer not initialized!');
}
try {
await this.kubernetesJobRestServer.stop();
this.log.info('Kubernetes Training service rest server stopped successfully.');
} catch (error) {
this.log.error(`Kubernetes Training service rest server stopped failed, error: ${error.message}`);
Promise.reject(error);
} }
return Promise.resolve(); return Promise.resolve();
} }
} }
export { KubernetesTrainingService };
export { KubernetesTrainingService }
...@@ -25,10 +25,10 @@ import * as fs from 'fs'; ...@@ -25,10 +25,10 @@ import * as fs from 'fs';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path'; import * as path from 'path';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import { execMkdir, getScriptName, getgpuMetricsCollectorScriptContent, execScript, execTail, execRemove, execKill } from '../common/util'
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { delay } from '../../common/utils'; import { delay } from '../../common/utils';
import { GPUInfo, GPUSummary } from '../common/gpuData'; import { GPUInfo, GPUSummary } from '../common/gpuData';
import { execKill, execMkdir, execRemove, execTail, getgpuMetricsCollectorScriptContent, getScriptName, runScript } from '../common/util';
/** /**
* GPUScheduler for local training service * GPUScheduler for local training service
...@@ -37,8 +37,8 @@ class GPUScheduler { ...@@ -37,8 +37,8 @@ class GPUScheduler {
private gpuSummary!: GPUSummary; private gpuSummary!: GPUSummary;
private stopping: boolean; private stopping: boolean;
private log: Logger; private readonly log: Logger;
private gpuMetricCollectorScriptFolder: string; private readonly gpuMetricCollectorScriptFolder: string;
constructor() { constructor() {
this.stopping = false; this.stopping = false;
...@@ -58,28 +58,15 @@ class GPUScheduler { ...@@ -58,28 +58,15 @@ class GPUScheduler {
} }
} }
/**
 * Write the gpu metrics collector script into gpuMetricCollectorScriptFolder
 * and start it with execScript.
 * The folder is created first; the script start itself is not awaited.
 */
private async runGpuMetricsCollectorScript(): Promise<void> {
    await execMkdir(this.gpuMetricCollectorScriptFolder);

    // Location of the generated gpu_metrics_collector script
    const collectorScriptPath: string = path.join(
        this.gpuMetricCollectorScriptFolder,
        getScriptName('gpu_metrics_collector')
    );
    await fs.promises.writeFile(
        collectorScriptPath,
        getgpuMetricsCollectorScriptContent(this.gpuMetricCollectorScriptFolder),
        { encoding: 'utf8' }
    );

    execScript(collectorScriptPath);
}
public getAvailableGPUIndices(useActiveGpu: boolean, occupiedGpuIndexNumMap: Map<number, number>): number[] { public getAvailableGPUIndices(useActiveGpu: boolean, occupiedGpuIndexNumMap: Map<number, number>): number[] {
if (this.gpuSummary !== undefined) { if (this.gpuSummary !== undefined) {
if(process.platform === 'win32' || useActiveGpu) { if (process.platform === 'win32' || useActiveGpu) {
return this.gpuSummary.gpuInfos.map((info: GPUInfo) => info.index); return this.gpuSummary.gpuInfos.map((info: GPUInfo) => info.index);
} } else {
else{ return this.gpuSummary.gpuInfos.filter((info: GPUInfo) =>
return this.gpuSummary.gpuInfos.filter((info: GPUInfo) => occupiedGpuIndexNumMap.get(info.index) === undefined && info.activeProcessNum === 0 ||
occupiedGpuIndexNumMap.get(info.index) === undefined && info.activeProcessNum === 0 || occupiedGpuIndexNumMap.get(info.index) !== undefined)
occupiedGpuIndexNumMap.get(info.index) !== undefined).map((info: GPUInfo) => info.index); .map((info: GPUInfo) => info.index);
} }
} }
...@@ -105,17 +92,32 @@ class GPUScheduler { ...@@ -105,17 +92,32 @@ class GPUScheduler {
} }
} }
/**
 * Create gpuMetricCollectorScriptFolder, write the gpu metrics collector
 * script into it, and start that script with runScript.
 * The script start itself is not awaited.
 */
private async runGpuMetricsCollectorScript(): Promise<void> {
    const scriptFolder: string = this.gpuMetricCollectorScriptFolder;
    await execMkdir(scriptFolder);

    // Build the script path and its content, then persist it as utf8 text
    const scriptPath: string = path.join(scriptFolder, getScriptName('gpu_metrics_collector'));
    const scriptContent: string = getgpuMetricsCollectorScriptContent(scriptFolder);
    await fs.promises.writeFile(scriptPath, scriptContent, { encoding: 'utf8' });

    runScript(scriptPath);
}
// tslint:disable:non-literal-fs-path
private async updateGPUSummary(): Promise<void> { private async updateGPUSummary(): Promise<void> {
let gpuMetricPath = path.join(this.gpuMetricCollectorScriptFolder, 'gpu_metrics'); const gpuMetricPath: string = path.join(this.gpuMetricCollectorScriptFolder, 'gpu_metrics');
if (fs.existsSync(gpuMetricPath)) { if (fs.existsSync(gpuMetricPath)) {
const cmdresult: cpp.childProcessPromise.Result = await execTail(gpuMetricPath); const cmdresult: cpp.childProcessPromise.Result = await execTail(gpuMetricPath);
if (cmdresult && cmdresult.stdout) { if (cmdresult !== undefined && cmdresult.stdout !== undefined) {
this.gpuSummary = <GPUSummary>JSON.parse(cmdresult.stdout); this.gpuSummary = <GPUSummary>JSON.parse(cmdresult.stdout);
} else { } else {
this.log.error('Could not get gpu metrics information!'); this.log.error('Could not get gpu metrics information!');
} }
} else{ } else {
this.log.warning('gpu_metrics file does not exist!') this.log.warning('gpu_metrics file does not exist!');
} }
} }
} }
......
...@@ -24,6 +24,7 @@ import { EventEmitter } from 'events'; ...@@ -24,6 +24,7 @@ import { EventEmitter } from 'events';
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
import * as ts from 'tail-stream'; import * as ts from 'tail-stream';
import * as tkill from 'tree-kill';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getInitTrialSequenceId } from '../../common/experimentStartupInfo'; import { getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
...@@ -31,14 +32,14 @@ import { ...@@ -31,14 +32,14 @@ import {
HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, uniqueString, isAlive, getNewLine } from '../../common/utils'; import {
import { execMkdir, getScriptName, execScript, setEnvironmentVariable, execNewFile } from '../common/util' delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString
} from '../../common/utils';
import { TrialConfig } from '../common/trialConfig'; import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { execMkdir, execNewFile, getScriptName, runScript, setEnvironmentVariable } from '../common/util';
import { GPUScheduler } from './gpuScheduler'; import { GPUScheduler } from './gpuScheduler';
const tkill = require('tree-kill');
/** /**
* Decode a command * Decode a command
* @param Buffer binary incoming data * @param Buffer binary incoming data
...@@ -46,7 +47,7 @@ const tkill = require('tree-kill'); ...@@ -46,7 +47,7 @@ const tkill = require('tree-kill');
* success: true if the buffer contains at least one complete command; otherwise false * success: true if the buffer contains at least one complete command; otherwise false
* remain: remaining data after the first command * remain: remaining data after the first command
*/ */
// tslint:disable-next-line:informative-docs // tslint:disable:newline-per-chained-call informative-docs
function decodeCommand(data: Buffer): [boolean, string, string, Buffer] { function decodeCommand(data: Buffer): [boolean, string, string, Buffer] {
if (data.length < 8) { if (data.length < 8) {
return [false, '', '', data]; return [false, '', '', data];
...@@ -61,6 +62,7 @@ function decodeCommand(data: Buffer): [boolean, string, string, Buffer] { ...@@ -61,6 +62,7 @@ function decodeCommand(data: Buffer): [boolean, string, string, Buffer] {
return [true, commandType, content, remain]; return [true, commandType, content, remain];
} }
// tslint:enable:newline-per-chained-call informative-docs
/** /**
* LocalTrialJobDetail * LocalTrialJobDetail
...@@ -117,21 +119,21 @@ class LocalConfig { ...@@ -117,21 +119,21 @@ class LocalConfig {
* Local machine training service * Local machine training service
*/ */
class LocalTrainingService implements TrainingService { class LocalTrainingService implements TrainingService {
private eventEmitter: EventEmitter; private readonly eventEmitter: EventEmitter;
private jobMap: Map<string, LocalTrialJobDetail>; private readonly jobMap: Map<string, LocalTrialJobDetail>;
private jobQueue: string[]; private readonly jobQueue: string[];
private initialized: boolean; private initialized: boolean;
private stopping: boolean; private stopping: boolean;
private rootDir!: string; private rootDir!: string;
private trialSequenceId: number; private trialSequenceId: number;
private gpuScheduler!: GPUScheduler; private gpuScheduler!: GPUScheduler;
private occupiedGpuIndexNumMap: Map<number, number>; private readonly occupiedGpuIndexNumMap: Map<number, number>;
private designatedGpuIndices!: Set<number>; private designatedGpuIndices!: Set<number>;
private log: Logger; private readonly log: Logger;
private localTrailConfig?: TrialConfig; private localTrailConfig?: TrialConfig;
private localConfig?: LocalConfig; private localConfig?: LocalConfig;
private isMultiPhase: boolean; private isMultiPhase: boolean;
private jobStreamMap: Map<string, ts.Stream>; private readonly jobStreamMap: Map<string, ts.Stream>;
private maxTrialNumPerGpu: number; private maxTrialNumPerGpu: number;
private useActiveGpu: boolean; private useActiveGpu: boolean;
...@@ -182,7 +184,7 @@ class LocalTrainingService implements TrainingService { ...@@ -182,7 +184,7 @@ class LocalTrainingService implements TrainingService {
return this.getHostJob(trialJobId); return this.getHostJob(trialJobId);
} }
if (trialJob.status === 'RUNNING') { if (trialJob.status === 'RUNNING') {
let alive: boolean = await isAlive(trialJob.pid); const alive: boolean = await isAlive(trialJob.pid);
if (!alive) { if (!alive) {
trialJob.endTime = Date.now(); trialJob.endTime = Date.now();
this.setTrialJobStatus(trialJob, 'FAILED'); this.setTrialJobStatus(trialJob, 'FAILED');
...@@ -276,7 +278,7 @@ class LocalTrainingService implements TrainingService { ...@@ -276,7 +278,7 @@ class LocalTrainingService implements TrainingService {
return Promise.resolve(); return Promise.resolve();
} }
if (trialJob.form.jobType === 'TRIAL') { if (trialJob.form.jobType === 'TRIAL') {
await tkill(trialJob.pid, 'SIGKILL'); tkill(trialJob.pid, 'SIGKILL');
} else if (trialJob.form.jobType === 'HOST') { } else if (trialJob.form.jobType === 'HOST') {
await cpp.exec(`pkill -9 -P ${trialJob.pid}`); await cpp.exec(`pkill -9 -P ${trialJob.pid}`);
} else { } else {
...@@ -290,7 +292,8 @@ class LocalTrainingService implements TrainingService { ...@@ -290,7 +292,8 @@ class LocalTrainingService implements TrainingService {
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
if (!this.initialized) { if (!this.initialized) {
this.rootDir = getExperimentRootDir(); this.rootDir = getExperimentRootDir();
if(!fs.existsSync(this.rootDir)){ // tslint:disable-next-line:non-literal-fs-path
if (!fs.existsSync(this.rootDir)) {
await cpp.exec(`powershell.exe mkdir ${this.rootDir}`); await cpp.exec(`powershell.exe mkdir ${this.rootDir}`);
} }
this.initialized = true; this.initialized = true;
...@@ -299,7 +302,7 @@ class LocalTrainingService implements TrainingService { ...@@ -299,7 +302,7 @@ class LocalTrainingService implements TrainingService {
case TrialConfigMetadataKey.TRIAL_CONFIG: case TrialConfigMetadataKey.TRIAL_CONFIG:
this.localTrailConfig = <TrialConfig>JSON.parse(value); this.localTrailConfig = <TrialConfig>JSON.parse(value);
// Parse trial config failed, throw Error // Parse trial config failed, throw Error
if (!this.localTrailConfig) { if (this.localTrailConfig === undefined) {
throw new Error('trial config parsed failed'); throw new Error('trial config parsed failed');
} }
this.log.info(`required GPU number is ${this.localTrailConfig.gpuNum}`); this.log.info(`required GPU number is ${this.localTrailConfig.gpuNum}`);
...@@ -336,10 +339,10 @@ class LocalTrainingService implements TrainingService { ...@@ -336,10 +339,10 @@ class LocalTrainingService implements TrainingService {
switch (key) { switch (key) {
case TrialConfigMetadataKey.TRIAL_CONFIG: case TrialConfigMetadataKey.TRIAL_CONFIG:
let getResult: Promise<string>; let getResult: Promise<string>;
if (!this.localTrailConfig) { if (this.localTrailConfig === undefined) {
getResult = Promise.reject(new NNIError(NNIErrorNames.NOT_FOUND, `${key} is never set yet`)); getResult = Promise.reject(new NNIError(NNIErrorNames.NOT_FOUND, `${key} is never set yet`));
} else { } else {
getResult = Promise.resolve(!this.localTrailConfig ? '' : JSON.stringify(this.localTrailConfig)); getResult = Promise.resolve(JSON.stringify(this.localTrailConfig));
} }
return getResult; return getResult;
...@@ -366,7 +369,7 @@ class LocalTrainingService implements TrainingService { ...@@ -366,7 +369,7 @@ class LocalTrainingService implements TrainingService {
if (['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'].includes(trialJob.status)) { if (['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'].includes(trialJob.status)) {
if (this.jobStreamMap.has(trialJob.id)) { if (this.jobStreamMap.has(trialJob.id)) {
const stream: ts.Stream | undefined = this.jobStreamMap.get(trialJob.id); const stream: ts.Stream | undefined = this.jobStreamMap.get(trialJob.id);
if (!stream) { if (stream === undefined) {
throw new Error(`Could not find stream in trial ${trialJob.id}`); throw new Error(`Could not find stream in trial ${trialJob.id}`);
} }
stream.destroy(); stream.destroy();
...@@ -376,13 +379,13 @@ class LocalTrainingService implements TrainingService { ...@@ -376,13 +379,13 @@ class LocalTrainingService implements TrainingService {
if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length > 0 && this.gpuScheduler !== undefined) { if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length > 0 && this.gpuScheduler !== undefined) {
if (oldStatus === 'RUNNING' && trialJob.status !== 'RUNNING') { if (oldStatus === 'RUNNING' && trialJob.status !== 'RUNNING') {
for (const index of trialJob.gpuIndices) { for (const index of trialJob.gpuIndices) {
let num: number | undefined = this.occupiedGpuIndexNumMap.get(index); const num: number | undefined = this.occupiedGpuIndexNumMap.get(index);
if(num === undefined) { if (num === undefined) {
throw new Error(`gpu resource schedule error`); throw new Error(`gpu resource schedule error`);
} else if(num === 1) { } else if (num === 1) {
this.occupiedGpuIndexNumMap.delete(index); this.occupiedGpuIndexNumMap.delete(index);
} else { } else {
this.occupiedGpuIndexNumMap.set(index, num - 1) this.occupiedGpuIndexNumMap.set(index, num - 1);
} }
} }
} }
...@@ -424,10 +427,10 @@ class LocalTrainingService implements TrainingService { ...@@ -424,10 +427,10 @@ class LocalTrainingService implements TrainingService {
} }
let selectedGPUIndices: number[] = []; let selectedGPUIndices: number[] = [];
let availableGpuIndices: number[] = this.gpuScheduler.getAvailableGPUIndices(this.useActiveGpu, this.occupiedGpuIndexNumMap); const availableGpuIndices: number[] = this.gpuScheduler.getAvailableGPUIndices(this.useActiveGpu, this.occupiedGpuIndexNumMap);
for(let index of availableGpuIndices) { for (const index of availableGpuIndices) {
let num: number | undefined = this.occupiedGpuIndexNumMap.get(index); const num: number | undefined = this.occupiedGpuIndexNumMap.get(index);
if(num === undefined || num < this.maxTrialNumPerGpu) { if (num === undefined || num < this.maxTrialNumPerGpu) {
selectedGPUIndices.push(index); selectedGPUIndices.push(index);
} }
} }
...@@ -461,11 +464,11 @@ class LocalTrainingService implements TrainingService { ...@@ -461,11 +464,11 @@ class LocalTrainingService implements TrainingService {
private occupyResource(resource: {gpuIndices: number[]}): void { private occupyResource(resource: {gpuIndices: number[]}): void {
if (this.gpuScheduler !== undefined) { if (this.gpuScheduler !== undefined) {
for (const index of resource.gpuIndices) { for (const index of resource.gpuIndices) {
let num: number | undefined = this.occupiedGpuIndexNumMap.get(index); const num: number | undefined = this.occupiedGpuIndexNumMap.get(index);
if(num === undefined) { if (num === undefined) {
this.occupiedGpuIndexNumMap.set(index, 1) this.occupiedGpuIndexNumMap.set(index, 1);
} else { } else {
this.occupiedGpuIndexNumMap.set(index, num + 1) this.occupiedGpuIndexNumMap.set(index, num + 1);
} }
} }
} }
...@@ -498,20 +501,20 @@ class LocalTrainingService implements TrainingService { ...@@ -498,20 +501,20 @@ class LocalTrainingService implements TrainingService {
} }
} }
/**
 * Build the command lines that run the trial and record its outcome.
 *
 * The returned lines redirect the trial command's stderr to `<workingDirectory>/stderr`
 * and write "<exit code> <timestamp in ms>" to `<workingDirectory>/.nni/state`.
 *
 * @param localTrailConfig trial configuration; only `command` is read here
 * @param workingDirectory directory of the trial job
 * @returns script lines (PowerShell on win32, shell syntax otherwise)
 */
private getScript(localTrailConfig: TrialConfig, workingDirectory: string): string[] {
    const stderrPath: string = path.join(workingDirectory, 'stderr');
    const statePath: string = path.join(workingDirectory, '.nni', 'state');

    if (process.platform === 'win32') {
        return [
            `cmd /c ${localTrailConfig.command} 2>${stderrPath}`,
            `$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`,
            // Append the millisecond part so the timestamp is in milliseconds
            `$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`,
            `Write $LASTEXITCODE " " $NOW_DATE | Out-File ${statePath} -NoNewline -encoding utf8`
        ];
    }

    // `%3N` (milliseconds) is a GNU date extension; BSD/macOS date prints it literally,
    // which corrupts the timestamp. Off Linux, fall back to second precision and pad
    // with 999 so the recorded end time is never earlier than the real one.
    const timestampFormat: string = process.platform === 'linux' ? '%s%3N' : '%s999';

    return [
        `eval ${localTrailConfig.command} 2>${stderrPath}`,
        `echo $? \`date +${timestampFormat}\` >${statePath}`
    ];
}
...@@ -519,28 +522,29 @@ class LocalTrainingService implements TrainingService { ...@@ -519,28 +522,29 @@ class LocalTrainingService implements TrainingService {
const trialJobDetail: LocalTrialJobDetail = <LocalTrialJobDetail>this.jobMap.get(trialJobId); const trialJobDetail: LocalTrialJobDetail = <LocalTrialJobDetail>this.jobMap.get(trialJobId);
const variables: { key: string; value: string }[] = this.getEnvironmentVariables(trialJobDetail, resource); const variables: { key: string; value: string }[] = this.getEnvironmentVariables(trialJobDetail, resource);
if (!this.localTrailConfig) { if (this.localTrailConfig === undefined) {
throw new Error('trial config is not initialized'); throw new Error('trial config is not initialized');
} }
const runScriptLines: string[] = []; const runScriptContent: string[] = [];
if (process.platform !== "win32"){ if (process.platform !== 'win32') {
runScriptLines.push('#!/bin/bash'); runScriptContent.push('#!/bin/bash');
} }
runScriptLines.push(`cd ${this.localTrailConfig.codeDir}`); runScriptContent.push(`cd ${this.localTrailConfig.codeDir}`);
for (const variable of variables) { for (const variable of variables) {
runScriptLines.push(setEnvironmentVariable(variable)); runScriptContent.push(setEnvironmentVariable(variable));
} }
const scripts: string[] = this.getScript(this.localTrailConfig, trialJobDetail.workingDirectory); const scripts: string[] = this.getScript(this.localTrailConfig, trialJobDetail.workingDirectory);
scripts.forEach(script => { scripts.forEach((script: string) => {
runScriptLines.push(script); runScriptContent.push(script);
}); });
await execMkdir(trialJobDetail.workingDirectory); await execMkdir(trialJobDetail.workingDirectory);
await execMkdir(path.join(trialJobDetail.workingDirectory, '.nni')); await execMkdir(path.join(trialJobDetail.workingDirectory, '.nni'));
await execNewFile(path.join(trialJobDetail.workingDirectory, '.nni', 'metrics')); await execNewFile(path.join(trialJobDetail.workingDirectory, '.nni', 'metrics'));
const scriptName: string = getScriptName('run'); const scriptName: string = getScriptName('run');
await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, scriptName), runScriptLines.join(getNewLine()), { encoding: 'utf8', mode: 0o777 }); await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, scriptName),
runScriptContent.join(getNewLine()), { encoding: 'utf8', mode: 0o777 });
await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters); await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters);
const trialJobProcess: cp.ChildProcess = execScript(path.join(trialJobDetail.workingDirectory, scriptName)); const trialJobProcess: cp.ChildProcess = runScript(path.join(trialJobDetail.workingDirectory, scriptName));
this.setTrialJobStatus(trialJobDetail, 'RUNNING'); this.setTrialJobStatus(trialJobDetail, 'RUNNING');
trialJobDetail.startTime = Date.now(); trialJobDetail.startTime = Date.now();
trialJobDetail.pid = trialJobProcess.pid; trialJobDetail.pid = trialJobProcess.pid;
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/ */
import * as path from 'path';
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { getExperimentId } from '../../common/experimentStartupInfo'; import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger } from '../../common/log'; import { getLogger } from '../../common/log';
import { unixPathJoin } from '../../common/utils' import { unixPathJoin } from '../../common/utils';
/** /**
* HDFS client utility, including copy file/directory * HDFS client utility, including copy file/directory
...@@ -33,6 +33,7 @@ export namespace HDFSClientUtility { ...@@ -33,6 +33,7 @@ export namespace HDFSClientUtility {
* @param hdfsUserName HDFS user name * @param hdfsUserName HDFS user name
*/ */
function hdfsExpRootDir(hdfsUserName: string): string {
    // Root is '/' followed by the join of <hdfsUserName>, 'nni', 'experiments', <experiment id>.
    // tslint:disable-next-line:prefer-template
    return '/' + unixPathJoin(hdfsUserName, 'nni', 'experiments', getExperimentId());
}
...@@ -50,63 +51,70 @@ export namespace HDFSClientUtility { ...@@ -50,63 +51,70 @@ export namespace HDFSClientUtility {
* @param trialId NNI trial ID * @param trialId NNI trial ID
*/ */
export function getHdfsTrialWorkDir(hdfsUserName: string, trialId: string): string {
    // Trial directory is <experiment root>/trials/<trialId>; no debug output is written.
    const root: string = hdfsExpRootDir(hdfsUserName);

    return unixPathJoin(root, 'trials', trialId);
}
/**
 * Copy a local file to an HDFS path.
 *
 * @param localFilePath local file path (source)
 * @param hdfsFilePath hdfs file path (target)
 * @param hdfsClient hdfs client; must expose `createWriteStream(path)` returning a writable stream
 * @returns promise resolved once the HDFS write stream emits 'finish'. Rejected with
 *          'file not exist!' when the local file is missing, or with the underlying error
 *          when reading the local file or writing the HDFS stream fails.
 */
// tslint:disable: no-unsafe-any non-literal-fs-path no-any
export async function copyFileToHdfs(localFilePath : string, hdfsFilePath : string, hdfsClient : any) : Promise<void> {
    // Detect if local file exist
    // tslint:disable-next-line:non-literal-fs-path
    if (!fs.existsSync(localFilePath)) {
        getLogger()
          .error(`HDFSCientUtility:copyFileToHdfs, ${localFilePath} doesn't exist locally`);

        return Promise.reject('file not exist!');
    }

    return new Promise<void>((resolve: () => void, reject: (reason?: any) => void): void => {
        const localFileStream: fs.ReadStream = fs.createReadStream(localFilePath);
        // A synchronous throw here rejects the promise instead of escaping as an uncaught exception
        const hdfsFileStream: any = hdfsClient.createWriteStream(hdfsFilePath);

        const onError: (err: any) => void = (err: any): void => {
            getLogger()
              .error(`HDFSCientUtility:copyFileToHdfs, copy file failed, err is ${err.message}`);
            reject(err);
        };

        // Without this handler a failed local read would leave the promise pending forever
        localFileStream.on('error', onError);
        hdfsFileStream.on('error', onError);
        hdfsFileStream.on('finish', () => {
            resolve();
        });

        localFileStream.pipe(hdfsFileStream);
    });
}
/** /**
* Recursively copy local directory to hdfs directory * Recursively copy local directory to hdfs directory
* *
* @param localDirectory local directory * @param localDirectory local directory
* @param hdfsDirectory HDFS directory * @param hdfsDirectory HDFS directory
* @param hdfsClient HDFS client * @param hdfsClient HDFS client
*/ */
export async function copyDirectoryToHdfs(localDirectory : string, hdfsDirectory : string, hdfsClient : any) : Promise<void>{ export async function copyDirectoryToHdfs(localDirectory : string, hdfsDirectory : string, hdfsClient : any) : Promise<void> {
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
// TODO: fs.readdirSync doesn't support ~($HOME) // TODO: fs.readdirSync doesn't support ~($HOME)
const fileNameArray: string[] = fs.readdirSync(localDirectory); const fileNameArray: string[] = fs.readdirSync(localDirectory);
for(var fileName of fileNameArray){ for (const fileName of fileNameArray) {
const fullFilePath: string = path.join(localDirectory, fileName); const fullFilePath: string = path.join(localDirectory, fileName);
try { try {
if (fs.lstatSync(fullFilePath).isFile()) { // tslint:disable-next-line:non-literal-fs-path
if (fs.lstatSync(fullFilePath)
.isFile()) {
await copyFileToHdfs(fullFilePath, path.join(hdfsDirectory, fileName), hdfsClient); await copyFileToHdfs(fullFilePath, path.join(hdfsDirectory, fileName), hdfsClient);
} else { } else {
// If filePath is a directory, recuisively copy it to remote directory // If filePath is a directory, recuisively copy it to remote directory
await copyDirectoryToHdfs(fullFilePath, path.join(hdfsDirectory, fileName), hdfsClient); await copyDirectoryToHdfs(fullFilePath, path.join(hdfsDirectory, fileName), hdfsClient);
} }
} catch(error) { } catch (error) {
deferred.reject(error); deferred.reject(error);
} }
} }
...@@ -118,20 +126,20 @@ export namespace HDFSClientUtility { ...@@ -118,20 +126,20 @@ export namespace HDFSClientUtility {
/** /**
* Read content from HDFS file * Read content from HDFS file
* *
* @param hdfsPath HDFS file path * @param hdfsPath HDFS file path
* @param hdfsClient HDFS client * @param hdfsClient HDFS client
*/ */
export async function readFileFromHDFS(hdfsPath : string, hdfsClient :any) : Promise<Buffer> { export async function readFileFromHDFS(hdfsPath : string, hdfsClient : any) : Promise<Buffer> {
const deferred: Deferred<Buffer> = new Deferred<Buffer>(); const deferred: Deferred<Buffer> = new Deferred<Buffer>();
let buffer : Buffer = Buffer.alloc(0); let buffer : Buffer = Buffer.alloc(0);
const exist : boolean = await pathExists(hdfsPath, hdfsClient); const exist : boolean = await pathExists(hdfsPath, hdfsClient);
if(!exist) { if (!exist) {
deferred.reject(`${hdfsPath} doesn't exists`); deferred.reject(`${hdfsPath} doesn't exists`);
} }
const remoteFileStream = hdfsClient.createReadStream(hdfsPath); const remoteFileStream: any = hdfsClient.createReadStream(hdfsPath);
remoteFileStream.on('error', (err : any) => { remoteFileStream.on('error', (err : any) => {
// Reject with the error // Reject with the error
deferred.reject(err); deferred.reject(err);
...@@ -141,8 +149,8 @@ export namespace HDFSClientUtility { ...@@ -141,8 +149,8 @@ export namespace HDFSClientUtility {
// Concat the data chunk to buffer // Concat the data chunk to buffer
buffer = Buffer.concat([buffer, chunk]); buffer = Buffer.concat([buffer, chunk]);
}); });
remoteFileStream.on('finish', function onFinish () { remoteFileStream.on('finish', () => {
// Upload is done, resolve // Upload is done, resolve
deferred.resolve(buffer); deferred.resolve(buffer);
}); });
...@@ -152,36 +160,38 @@ export namespace HDFSClientUtility { ...@@ -152,36 +160,38 @@ export namespace HDFSClientUtility {
/**
 * Check if an HDFS path already exists.
 *
 * @param hdfsPath target path need to check in HDFS
 * @param hdfsClient HDFS client; must expose `exists(path, callback)` where the callback
 *                   receives a single boolean
 * @param timeoutMs how long to wait for the client's answer, in milliseconds (default 5000)
 * @returns promise resolved with the boolean reported by the client; rejected with a
 *          message string when no answer arrives within `timeoutMs`
 */
export async function pathExists(hdfsPath : string, hdfsClient : any, timeoutMs : number = 5000) : Promise<boolean> {
    let timeoutId : ReturnType<typeof setTimeout> | undefined;

    const existsQuery : Promise<boolean> = new Promise<boolean>((resolve : (value : boolean) => void) : void => {
        hdfsClient.exists(hdfsPath, (exist : boolean) => {
            resolve(exist);
        });
    });

    const delayTimeout : Promise<boolean> = new Promise<boolean>(
        (_resolve : (value : boolean) => void, reject : (reason? : any) => void) : void => {
            // Reject the race once the timeout is reached
            timeoutId = setTimeout(() => { reject(`Check HDFS path ${hdfsPath} exists timeout`); }, timeoutMs);
        });

    try {
        return await Promise.race([existsQuery, delayTimeout]);
    } finally {
        // Always release the timer so it cannot keep the event loop alive after an answer
        if (timeoutId !== undefined) {
            clearTimeout(timeoutId);
        }
    }
}
/** /**
* Mkdir in HDFS, use default permission 755 * Mkdir in HDFS, use default permission 755
* *
* @param hdfsPath the path in HDFS. It could be either file or directory * @param hdfsPath the path in HDFS. It could be either file or directory
* @param hdfsClient * @param hdfsClient HDFS client
*/ */
export function mkdir(hdfsPath : string, hdfsClient : any) : Promise<boolean> { export function mkdir(hdfsPath : string, hdfsClient : any) : Promise<boolean> {
const deferred : Deferred<boolean> = new Deferred<boolean>(); const deferred : Deferred<boolean> = new Deferred<boolean>();
hdfsClient.mkdir(hdfsPath, (err : any)=> { hdfsClient.mkdir(hdfsPath, (err : any) => {
if(!err) { if (!err) {
deferred.resolve(true); deferred.resolve(true);
} else { } else {
deferred.reject(err.message); deferred.reject(err.message);
...@@ -193,19 +203,19 @@ export namespace HDFSClientUtility { ...@@ -193,19 +203,19 @@ export namespace HDFSClientUtility {
/** /**
* Read directory contents * Read directory contents
* *
* @param hdfsPath the path in HDFS. It could be either file or directory * @param hdfsPath the path in HDFS. It could be either file or directory
* @param hdfsClient * @param hdfsClient HDFS client
*/ */
export async function readdir(hdfsPath : string, hdfsClient : any) : Promise<string[]> { export async function readdir(hdfsPath : string, hdfsClient : any) : Promise<string[]> {
const deferred : Deferred<string[]> = new Deferred<string[]>(); const deferred : Deferred<string[]> = new Deferred<string[]>();
const exist : boolean = await pathExists(hdfsPath, hdfsClient); const exist : boolean = await pathExists(hdfsPath, hdfsClient);
if(!exist) { if (!exist) {
deferred.reject(`${hdfsPath} doesn't exists`); deferred.reject(`${hdfsPath} doesn't exists`);
} }
hdfsClient.readdir(hdfsPath, (err : any, files : any[] ) => { hdfsClient.readdir(hdfsPath, (err : any, files : any[]) => {
if(err) { if (err) {
deferred.reject(err); deferred.reject(err);
} }
...@@ -218,18 +228,20 @@ export namespace HDFSClientUtility { ...@@ -218,18 +228,20 @@ export namespace HDFSClientUtility {
/**
 * Delete HDFS path
 *
 * @param hdfsPath the path in HDFS. It could be either file or directory
 * @param hdfsClient HDFS client; must expose `unlink(path, recursive, callback)`
 * @param recursive Mark if need to delete recursively
 * @returns promise resolved with true when the unlink callback reports no error;
 *          rejected with `err.message` otherwise
 */
export function deletePath(hdfsPath : string, hdfsClient : any, recursive : boolean = true) : Promise<boolean> {
    // Calling the client inside the executor turns a synchronous throw into a rejection
    return new Promise<boolean>((resolve : (value : boolean) => void, reject : (reason? : any) => void) : void => {
        hdfsClient.unlink(hdfsPath, recursive, (err : any) => {
            if (!err) {
                resolve(true);
            } else {
                reject(err.message);
            }
        });
    });
}
// tslint:enable: no-unsafe-any non-literal-fs-path no-any
} }
...@@ -19,8 +19,11 @@ ...@@ -19,8 +19,11 @@
'use strict'; 'use strict';
import {TrialConfig} from '../common/trialConfig' import {TrialConfig} from '../common/trialConfig';
/**
* Task role for PAI
*/
export class PAITaskRole { export class PAITaskRole {
// Name for the task role // Name for the task role
public readonly name: string; public readonly name: string;
...@@ -36,7 +39,7 @@ export class PAITaskRole { ...@@ -36,7 +39,7 @@ export class PAITaskRole {
public readonly command: string; public readonly command: string;
//Shared memory for one task in the task role //Shared memory for one task in the task role
public readonly shmMB?: number; public readonly shmMB?: number;
/** /**
* Constructor * Constructor
* @param name Name for the task role * @param name Name for the task role
...@@ -46,18 +49,22 @@ export class PAITaskRole { ...@@ -46,18 +49,22 @@ export class PAITaskRole {
* @param gpuNumber GPU number for one task in the task role, no less than 0 * @param gpuNumber GPU number for one task in the task role, no less than 0
* @param command Executable command for tasks in the task role, can not be empty * @param command Executable command for tasks in the task role, can not be empty
*/ */
constructor(name : string, taskNumber : number, cpuNumber : number, memoryMB : number, gpuNumber : number, command : string, shmMB?: number) { constructor(name : string, taskNumber : number, cpuNumber : number, memoryMB : number, gpuNumber : number,
command : string, shmMB?: number) {
this.name = name; this.name = name;
this.taskNumber = taskNumber; this.taskNumber = taskNumber;
this.cpuNumber = cpuNumber; this.cpuNumber = cpuNumber;
this.memoryMB = memoryMB; this.memoryMB = memoryMB;
this.gpuNumber = gpuNumber; this.gpuNumber = gpuNumber;
this.command = command; this.command = command;
this.shmMB = shmMB; this.shmMB = shmMB;
} }
} }
export class PAIJobConfig{ /**
* Trial job configuration submitted to PAI
*/
export class PAIJobConfig {
// Name for the job, need to be unique // Name for the job, need to be unique
public readonly jobName: string; public readonly jobName: string;
// URL pointing to the Docker image for all tasks in the job // URL pointing to the Docker image for all tasks in the job
...@@ -83,8 +90,8 @@ export class PAIJobConfig{ ...@@ -83,8 +90,8 @@ export class PAIJobConfig{
* @param outputDir Output directory on HDFS * @param outputDir Output directory on HDFS
* @param taskRoles List of taskRole, one task role at least * @param taskRoles List of taskRole, one task role at least
*/ */
constructor(jobName: string, image : string, dataDir : string, outputDir : string, codeDir : string, constructor(jobName: string, image : string, dataDir : string, outputDir : string, codeDir : string,
taskRoles : PAITaskRole[], virtualCluster: string) { taskRoles : PAITaskRole[], virtualCluster: string) {
this.jobName = jobName; this.jobName = jobName;
this.image = image; this.image = image;
this.dataDir = dataDir; this.dataDir = dataDir;
...@@ -95,6 +102,9 @@ export class PAIJobConfig{ ...@@ -95,6 +102,9 @@ export class PAIJobConfig{
} }
} }
/**
* PAI cluster configuration
*/
export class PAIClusterConfig { export class PAIClusterConfig {
public readonly userName: string; public readonly userName: string;
public readonly passWord: string; public readonly passWord: string;
...@@ -106,18 +116,21 @@ export class PAIClusterConfig { ...@@ -106,18 +116,21 @@ export class PAIClusterConfig {
* @param passWord password of PAI Cluster * @param passWord password of PAI Cluster
* @param host Host IP of PAI Cluster * @param host Host IP of PAI Cluster
*/ */
constructor(userName: string, passWord : string, host : string){ constructor(userName: string, passWord : string, host : string) {
this.userName = userName; this.userName = userName;
this.passWord = passWord; this.passWord = passWord;
this.host = host; this.host = host;
} }
} }
export class NNIPAITrialConfig extends TrialConfig{ /**
* PAI trial configuration
*/
export class NNIPAITrialConfig extends TrialConfig {
public readonly cpuNum: number; public readonly cpuNum: number;
public readonly memoryMB: number; public readonly memoryMB: number;
public readonly image: string; public readonly image: string;
public readonly dataDir: string; public readonly dataDir: string;
public outputDir: string; public outputDir: string;
//The virtual cluster job runs on. If omitted, the job will run on default virtual cluster //The virtual cluster job runs on. If omitted, the job will run on default virtual cluster
...@@ -125,8 +138,8 @@ export class NNIPAITrialConfig extends TrialConfig{ ...@@ -125,8 +138,8 @@ export class NNIPAITrialConfig extends TrialConfig{
//Shared memory for one task in the task role //Shared memory for one task in the task role
public shmMB?: number; public shmMB?: number;
constructor(command : string, codeDir : string, gpuNum : number, cpuNum: number, memoryMB: number, constructor(command : string, codeDir : string, gpuNum : number, cpuNum: number, memoryMB: number,
image: string, dataDir: string, outputDir: string, virtualCluster?: string, shmMB?: number) { image: string, dataDir: string, outputDir: string, virtualCluster?: string, shmMB?: number) {
super(command, codeDir, gpuNum); super(command, codeDir, gpuNum);
this.cpuNum = cpuNum; this.cpuNum = cpuNum;
this.memoryMB = memoryMB; this.memoryMB = memoryMB;
...@@ -137,4 +150,3 @@ export class NNIPAITrialConfig extends TrialConfig{ ...@@ -137,4 +150,3 @@ export class NNIPAITrialConfig extends TrialConfig{
this.shmMB = shmMB; this.shmMB = shmMB;
} }
} }
...@@ -19,8 +19,11 @@ ...@@ -19,8 +19,11 @@
'use strict'; 'use strict';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from 'common/trainingService'; import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
/**
* PAI trial job detail
*/
export class PAITrialJobDetail implements TrialJobDetail { export class PAITrialJobDetail implements TrialJobDetail {
public id: string; public id: string;
public status: TrialJobStatus; public status: TrialJobStatus;
...@@ -36,8 +39,8 @@ export class PAITrialJobDetail implements TrialJobDetail { ...@@ -36,8 +39,8 @@ export class PAITrialJobDetail implements TrialJobDetail {
public hdfsLogPath: string; public hdfsLogPath: string;
public isEarlyStopped?: boolean; public isEarlyStopped?: boolean;
constructor(id: string, status: TrialJobStatus, paiJobName : string, constructor(id: string, status: TrialJobStatus, paiJobName : string,
submitTime: number, workingDirectory: string, form: JobApplicationForm, sequenceId: number, hdfsLogPath: string) { submitTime: number, workingDirectory: string, form: JobApplicationForm, sequenceId: number, hdfsLogPath: string) {
this.id = id; this.id = id;
this.status = status; this.status = status;
this.paiJobName = paiJobName; this.paiJobName = paiJobName;
...@@ -50,7 +53,7 @@ export class PAITrialJobDetail implements TrialJobDetail { ...@@ -50,7 +53,7 @@ export class PAITrialJobDetail implements TrialJobDetail {
} }
} }
export const PAI_INSTALL_NNI_SHELL_FORMAT: string = export const PAI_INSTALL_NNI_SHELL_FORMAT: string =
`#!/bin/bash `#!/bin/bash
if python3 -c 'import nni' > /dev/null 2>&1; then if python3 -c 'import nni' > /dev/null 2>&1; then
# nni module is already installed, skip # nni module is already installed, skip
...@@ -61,13 +64,15 @@ else ...@@ -61,13 +64,15 @@ else
fi`; fi`;
/**
 * Single-line shell command template for a PAI trial.
 * Placeholders as they appear in the template: {0} NNI_SYS_DIR, {1} NNI_OUTPUT_DIR,
 * {2} NNI_TRIAL_JOB_ID, {3} NNI_EXP_ID, {4} NNI_TRIAL_SEQ_ID, {5} --trial_command,
 * {6} --nnimanager_ip, {7} --nnimanager_port, {8} --pai_hdfs_output_dir, {9} --pai_hdfs_host,
 * {10} --pai_user_name, {11} --nni_hdfs_exp_dir, {12} --nni_manager_version, {13} --log_collection.
 * Each trailing backslash is a template-literal line continuation, so no newline ends up in the value.
 */
export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} \
&& cd $NNI_SYS_DIR && sh install_nni.sh \
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}' \
--pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1' \
--nni_manager_version '{12}' --log_collection '{13}'`;

// HDFS output directory template; {0} is substituted into the host position of the URI
export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`;

// WebHDFS explorer link template; {0} is the host part, {1} the fragment after '#'
// tslint:disable:no-http-string
export const PAI_LOG_PATH_FORMAT: string =
`http://{0}/webhdfs/explorer.html#{1}`;
...@@ -19,13 +19,14 @@ ...@@ -19,13 +19,14 @@
'use strict'; 'use strict';
// tslint:disable-next-line:no-implicit-dependencies
import * as request from 'request'; import * as request from 'request';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { getLogger, Logger } from '../../common/log';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { PAITrialJobDetail } from './paiData'; import { getLogger, Logger } from '../../common/log';
import { PAIClusterConfig } from './paiConfig';
import { TrialJobStatus } from '../../common/trainingService'; import { TrialJobStatus } from '../../common/trainingService';
import { PAIClusterConfig } from './paiConfig';
import { PAITrialJobDetail } from './paiData';
/** /**
* Collector PAI jobs info from PAI cluster, and update pai job status locally * Collector PAI jobs info from PAI cluster, and update pai job status locally
...@@ -43,60 +44,65 @@ export class PAIJobInfoCollector { ...@@ -43,60 +44,65 @@ export class PAIJobInfoCollector {
} }
/**
 * Query the status of every trial job in `trialJobsMap`.
 *
 * Does nothing when either the token or the cluster config is undefined.
 * The per-job queries are started together and awaited as a group.
 *
 * @param paiToken PAI access token
 * @param paiClusterConfig PAI cluster configuration
 * @throws NNIError (NOT_FOUND) when the map holds an undefined entry for a trial job id
 */
public async retrieveTrialStatus(paiToken? : string, paiClusterConfig?: PAIClusterConfig) : Promise<void> {
    if (paiClusterConfig === undefined || paiToken === undefined) {
        // An async method already returns a resolved promise; no explicit Promise.resolve() needed
        return;
    }

    const updatePaiTrialJobs : Promise<void>[] = [];
    for (const [trialJobId, paiTrialJob] of this.trialJobsMap) {
        if (paiTrialJob === undefined) {
            throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`);
        }
        updatePaiTrialJobs.push(this.getSinglePAITrialJobInfo(paiTrialJob, paiToken, paiClusterConfig));
    }

    await Promise.all(updatePaiTrialJobs);
}
private getSinglePAITrialJobInfo(paiTrialJob : PAITrialJobDetail, paiToken : string, paiClusterConfig: PAIClusterConfig) : Promise<void> { private getSinglePAITrialJobInfo(paiTrialJob : PAITrialJobDetail, paiToken : string, paiClusterConfig: PAIClusterConfig)
: Promise<void> {
const deferred : Deferred<void> = new Deferred<void>(); const deferred : Deferred<void> = new Deferred<void>();
if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) { if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) {
deferred.resolve(); deferred.resolve();
return deferred.promise; return deferred.promise;
} }
// Rest call to get PAI job info and update status // Rest call to get PAI job info and update status
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const getJobInfoRequest: request.Options = { const getJobInfoRequest: request.Options = {
// tslint:disable-next-line:no-http-string
uri: `http://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`, uri: `http://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`,
method: 'GET', method: 'GET',
json: true, json: true,
headers: { headers: {
"Content-Type": "application/json", 'Content-Type': 'application/json',
"Authorization": 'Bearer ' + paiToken Authorization: `Bearer ${paiToken}`
} }
}; };
//TODO : pass in request timeout param?
// tslint:disable: no-unsafe-any no-any cyclomatic-complexity
//TODO : pass in request timeout param?
request(getJobInfoRequest, (error: Error, response: request.Response, body: any) => { request(getJobInfoRequest, (error: Error, response: request.Response, body: any) => {
if (error || response.statusCode >= 500) { if ((error !== undefined && error !== null) || response.statusCode >= 500) {
this.log.error(`PAI Training service: get job info for trial ${paiTrialJob.id} from PAI Cluster failed!`); this.log.error(`PAI Training service: get job info for trial ${paiTrialJob.id} from PAI Cluster failed!`);
// Queried PAI job info failed, set job status to UNKNOWN // Queried PAI job info failed, set job status to UNKNOWN
if(paiTrialJob.status === 'WAITING' || paiTrialJob.status === 'RUNNING') { if (paiTrialJob.status === 'WAITING' || paiTrialJob.status === 'RUNNING') {
paiTrialJob.status = 'UNKNOWN'; paiTrialJob.status = 'UNKNOWN';
} }
} else { } else {
if(response.body.jobStatus && response.body.jobStatus.state) { if (response.body.jobStatus && response.body.jobStatus.state) {
switch(response.body.jobStatus.state) { switch (response.body.jobStatus.state) {
case 'WAITING': case 'WAITING':
paiTrialJob.status = 'WAITING'; paiTrialJob.status = 'WAITING';
break; break;
case 'RUNNING': case 'RUNNING':
paiTrialJob.status = 'RUNNING'; paiTrialJob.status = 'RUNNING';
if(!paiTrialJob.startTime) { if (paiTrialJob.startTime === undefined) {
paiTrialJob.startTime = response.body.jobStatus.appLaunchedTime; paiTrialJob.startTime = response.body.jobStatus.appLaunchedTime;
} }
if(!paiTrialJob.url) { if (paiTrialJob.url === undefined) {
paiTrialJob.url = response.body.jobStatus.appTrackingUrl; paiTrialJob.url = response.body.jobStatus.appTrackingUrl;
} }
break; break;
case 'SUCCEEDED': case 'SUCCEEDED':
...@@ -104,30 +110,31 @@ export class PAIJobInfoCollector { ...@@ -104,30 +110,31 @@ export class PAIJobInfoCollector {
break; break;
case 'STOPPED': case 'STOPPED':
if (paiTrialJob.isEarlyStopped !== undefined) { if (paiTrialJob.isEarlyStopped !== undefined) {
paiTrialJob.status = paiTrialJob.isEarlyStopped === true ? paiTrialJob.status = paiTrialJob.isEarlyStopped === true ?
'EARLY_STOPPED' : 'USER_CANCELED'; 'EARLY_STOPPED' : 'USER_CANCELED';
} else { } else {
// if paiTrialJob's isEarlyStopped is undefined, that mean we didn't stop it via cancellation, mark it as SYS_CANCELLED by PAI /* if paiTrialJob's isEarlyStopped is undefined, that mean we didn't stop it via cancellation,
* mark it as SYS_CANCELLED by PAI
*/
paiTrialJob.status = 'SYS_CANCELED'; paiTrialJob.status = 'SYS_CANCELED';
} }
break; break;
case 'FAILED': case 'FAILED':
paiTrialJob.status = 'FAILED'; paiTrialJob.status = 'FAILED';
break; break;
default: default:
paiTrialJob.status = 'UNKNOWN'; paiTrialJob.status = 'UNKNOWN';
break;
} }
// For final job statues, update startTime, endTime and url // For final job statues, update startTime, endTime and url
if(this.finalStatuses.includes(paiTrialJob.status)) { if (this.finalStatuses.includes(paiTrialJob.status)) {
if(!paiTrialJob.startTime) { if (paiTrialJob.startTime === undefined) {
paiTrialJob.startTime = response.body.jobStatus.appLaunchedTime; paiTrialJob.startTime = response.body.jobStatus.appLaunchedTime;
} }
if(!paiTrialJob.endTime) { if (paiTrialJob.endTime === undefined) {
paiTrialJob.endTime = response.body.jobStatus.completedTime; paiTrialJob.endTime = response.body.jobStatus.completedTime;
} }
// Set pai trial job's url to WebHDFS output path // Set pai trial job's url to WebHDFS output path
if(paiTrialJob.hdfsLogPath) { if (paiTrialJob.hdfsLogPath !== undefined) {
paiTrialJob.url += `,${paiTrialJob.hdfsLogPath}`; paiTrialJob.url += `,${paiTrialJob.hdfsLogPath}`;
} }
} }
...@@ -138,4 +145,5 @@ export class PAIJobInfoCollector { ...@@ -138,4 +145,5 @@ export class PAIJobInfoCollector {
return deferred.promise; return deferred.promise;
} }
// tslint:enable: no-unsafe-any no-any
} }
...@@ -19,17 +19,17 @@ ...@@ -19,17 +19,17 @@
'use strict'; 'use strict';
import * as component from '../../common/component';
import { Inject } from 'typescript-ioc'; import { Inject } from 'typescript-ioc';
import * as component from '../../common/component';
import { ClusterJobRestServer } from '../common/clusterJobRestServer';
import { PAITrainingService } from './paiTrainingService'; import { PAITrainingService } from './paiTrainingService';
import { ClusterJobRestServer } from '../common/clusterJobRestServer'
/** /**
* PAI Training service Rest server, provides rest API to support pai job metrics update * PAI Training service Rest server, provides rest API to support pai job metrics update
* *
*/ */
@component.Singleton @component.Singleton
export class PAIJobRestServer extends ClusterJobRestServer{ export class PAIJobRestServer extends ClusterJobRestServer {
@Inject @Inject
private readonly paiTrainingService : PAITrainingService; private readonly paiTrainingService : PAITrainingService;
...@@ -41,6 +41,7 @@ export class PAIJobRestServer extends ClusterJobRestServer{ ...@@ -41,6 +41,7 @@ export class PAIJobRestServer extends ClusterJobRestServer{
this.paiTrainingService = component.get(PAITrainingService); this.paiTrainingService = component.get(PAITrainingService);
} }
// tslint:disable-next-line:no-any
protected handleTrialMetrics(jobId : string, metrics : any[]) : void { protected handleTrialMetrics(jobId : string, metrics : any[]) : void {
// Split metrics array into single metric, then emit // Split metrics array into single metric, then emit
// Warning: If not split metrics into single ones, the behavior will be UNKNOWN // Warning: If not split metrics into single ones, the behavior will be UNKNOWN
...@@ -51,4 +52,4 @@ export class PAIJobRestServer extends ClusterJobRestServer{ ...@@ -51,4 +52,4 @@ export class PAIJobRestServer extends ClusterJobRestServer{
}); });
} }
} }
} }
\ No newline at end of file
/** /**
* Copyright (c) Microsoft Corporation * Copyright (c) Microsoft Corporation
* All rights reserved. * All rights reserved.
...@@ -23,6 +22,7 @@ ...@@ -23,6 +22,7 @@
import * as cpp from 'child-process-promise'; import * as cpp from 'child-process-promise';
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
// tslint:disable-next-line:no-implicit-dependencies
import * as request from 'request'; import * as request from 'request';
import * as component from '../../common/component'; import * as component from '../../common/component';
...@@ -37,18 +37,17 @@ import { ...@@ -37,18 +37,17 @@ import {
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, generateParamFileName, import { delay, generateParamFileName,
getExperimentRootDir, getIPV4Address, getVersion, uniqueString } from '../../common/utils'; getExperimentRootDir, getIPV4Address, getVersion, uniqueString, unixPathJoin } from '../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { validateCodeDir, execMkdir } from '../common/util'; import { execMkdir, validateCodeDir } from '../common/util';
import { unixPathJoin } from '../../common/utils'
import { HDFSClientUtility } from './hdfsClientUtility'; import { HDFSClientUtility } from './hdfsClientUtility';
import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig'; import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig';
import { PAI_LOG_PATH_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAITrialJobDetail } from './paiData'; import { PAI_LOG_PATH_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAITrialJobDetail } from './paiData';
import { PAIJobInfoCollector } from './paiJobInfoCollector'; import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { PAIJobRestServer } from './paiJobRestServer'; import { PAIJobRestServer } from './paiJobRestServer';
const WebHDFS = require('webhdfs'); import * as WebHDFS from 'webhdfs';
/** /**
* Training Service implementation for OpenPAI (Open Platform for AI) * Training Service implementation for OpenPAI (Open Platform for AI)
...@@ -62,13 +61,14 @@ class PAITrainingService implements TrainingService { ...@@ -62,13 +61,14 @@ class PAITrainingService implements TrainingService {
private readonly expRootDir: string; private readonly expRootDir: string;
private paiTrialConfig: NNIPAITrialConfig | undefined; private paiTrialConfig: NNIPAITrialConfig | undefined;
private paiClusterConfig?: PAIClusterConfig; private paiClusterConfig?: PAIClusterConfig;
private jobQueue: string[]; private readonly jobQueue: string[];
private stopping: boolean = false; private stopping: boolean = false;
// tslint:disable-next-line:no-any
private hdfsClient: any; private hdfsClient: any;
private paiToken? : string; private paiToken? : string;
private paiTokenUpdateTime?: number; private paiTokenUpdateTime?: number;
private paiTokenUpdateInterval: number; private readonly paiTokenUpdateInterval: number;
private experimentId! : string; private readonly experimentId! : string;
private readonly paiJobCollector : PAIJobInfoCollector; private readonly paiJobCollector : PAIJobInfoCollector;
private readonly hdfsDirPattern: string; private readonly hdfsDirPattern: string;
private hdfsBaseDir: string | undefined; private hdfsBaseDir: string | undefined;
...@@ -121,13 +121,13 @@ class PAITrainingService implements TrainingService { ...@@ -121,13 +121,13 @@ class PAITrainingService implements TrainingService {
} }
public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> { public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
if (!this.paiClusterConfig) { if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized'); throw new Error('PAI Cluster config is not initialized');
} }
const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (!paiTrialJob) { if (paiTrialJob === undefined) {
return Promise.reject(`trial job ${trialJobId} not found`); return Promise.reject(`trial job ${trialJobId} not found`);
} }
...@@ -144,7 +144,7 @@ class PAITrainingService implements TrainingService { ...@@ -144,7 +144,7 @@ class PAITrainingService implements TrainingService {
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> { public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
const deferred : Deferred<PAITrialJobDetail> = new Deferred<PAITrialJobDetail>(); const deferred : Deferred<PAITrialJobDetail> = new Deferred<PAITrialJobDetail>();
if (!this.hdfsBaseDir) { if (this.hdfsBaseDir === undefined) {
throw new Error('hdfsBaseDir is not initialized'); throw new Error('hdfsBaseDir is not initialized');
} }
...@@ -187,24 +187,26 @@ class PAITrainingService implements TrainingService { ...@@ -187,24 +187,26 @@ class PAITrainingService implements TrainingService {
return false; return false;
} }
// tslint:disable:no-http-string
public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> { public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
const trialJobDetail : PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const trialJobDetail : PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
const deferred : Deferred<void> = new Deferred<void>(); const deferred : Deferred<void> = new Deferred<void>();
if (!trialJobDetail) { if (trialJobDetail === undefined) {
this.log.error(`cancelTrialJob: trial job id ${trialJobId} not found`); this.log.error(`cancelTrialJob: trial job id ${trialJobId} not found`);
return Promise.reject(); return Promise.reject();
} }
if (!this.paiClusterConfig) { if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized'); throw new Error('PAI Cluster config is not initialized');
} }
if (!this.paiToken) { if (this.paiToken === undefined) {
throw new Error('PAI token is not initialized'); throw new Error('PAI token is not initialized');
} }
const stopJobRequest: request.Options = { const stopJobRequest: request.Options = {
uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}/jobs/${trialJobDetail.paiJobName}/executionType`, uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}\
/jobs/${trialJobDetail.paiJobName}/executionType`,
method: 'PUT', method: 'PUT',
json: true, json: true,
body: {value: 'STOP'}, body: {value: 'STOP'},
...@@ -217,10 +219,12 @@ class PAITrainingService implements TrainingService { ...@@ -217,10 +219,12 @@ class PAITrainingService implements TrainingService {
// Set trialjobDetail's early stopped field, to mark the job's cancellation source // Set trialjobDetail's early stopped field, to mark the job's cancellation source
trialJobDetail.isEarlyStopped = isEarlyStopped; trialJobDetail.isEarlyStopped = isEarlyStopped;
// tslint:disable-next-line:no-any
request(stopJobRequest, (error: Error, response: request.Response, body: any) => { request(stopJobRequest, (error: Error, response: request.Response, body: any) => {
if (error || response.statusCode >= 400) { if ((error !== undefined && error !== null) || response.statusCode >= 400) {
this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`); this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
deferred.reject(error ? error.message : `Stop trial failed, http code: ${response.statusCode}`); deferred.reject((error !== undefined && error !== null) ? error.message :
`Stop trial failed, http code: ${response.statusCode}`);
} else { } else {
deferred.resolve(); deferred.resolve();
} }
...@@ -229,6 +233,7 @@ class PAITrainingService implements TrainingService { ...@@ -229,6 +233,7 @@ class PAITrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
// tslint:disable: no-unsafe-any no-any
// tslint:disable-next-line:max-func-body-length // tslint:disable-next-line:max-func-body-length
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
const deferred : Deferred<void> = new Deferred<void>(); const deferred : Deferred<void> = new Deferred<void>();
...@@ -256,47 +261,47 @@ class PAITrainingService implements TrainingService { ...@@ -256,47 +261,47 @@ class PAITrainingService implements TrainingService {
break; break;
case TrialConfigMetadataKey.TRIAL_CONFIG: case TrialConfigMetadataKey.TRIAL_CONFIG:
if (!this.paiClusterConfig) { if (this.paiClusterConfig === undefined) {
this.log.error('pai cluster config is not initialized'); this.log.error('pai cluster config is not initialized');
deferred.reject(new Error('pai cluster config is not initialized')); deferred.reject(new Error('pai cluster config is not initialized'));
break; break;
} }
this.paiTrialConfig = <NNIPAITrialConfig>JSON.parse(value); this.paiTrialConfig = <NNIPAITrialConfig>JSON.parse(value);
//paiTrialConfig.outputDir could be null if it is not set in nnictl //paiTrialConfig.outputDir could be null if it is not set in nnictl
if (this.paiTrialConfig.outputDir === undefined || this.paiTrialConfig.outputDir === null){ if (this.paiTrialConfig.outputDir === undefined || this.paiTrialConfig.outputDir === null) {
this.paiTrialConfig.outputDir = String.Format( this.paiTrialConfig.outputDir = String.Format(
PAI_OUTPUT_DIR_FORMAT, PAI_OUTPUT_DIR_FORMAT,
this.paiClusterConfig.host this.paiClusterConfig.host
).replace(/\r\n|\n|\r/gm, ''); )
.replace(/\r\n|\n|\r/gm, '');
} }
// Validate to make sure codeDir doesn't have too many files // Validate to make sure codeDir doesn't have too many files
try { try {
await validateCodeDir(this.paiTrialConfig.codeDir); await validateCodeDir(this.paiTrialConfig.codeDir);
} catch(error) { } catch (error) {
this.log.error(error); this.log.error(error);
deferred.reject(new Error(error)); deferred.reject(new Error(error));
break; break;
} }
const hdfsDirContent = this.paiTrialConfig.outputDir.match(this.hdfsDirPattern); const hdfsDirContent: any = this.paiTrialConfig.outputDir.match(this.hdfsDirPattern);
if (hdfsDirContent === null) { if (hdfsDirContent === null) {
throw new Error('Trial outputDir format Error'); throw new Error('Trial outputDir format Error');
} }
const groups = hdfsDirContent.groups; const groups: any = hdfsDirContent.groups;
if (groups === undefined) { if (groups === undefined) {
throw new Error('Trial outputDir format Error'); throw new Error('Trial outputDir format Error');
} }
this.hdfsOutputHost = groups.host;
this.hdfsOutputHost = groups['host'];
//TODO: choose to use /${username} as baseDir //TODO: choose to use /${username} as baseDir
this.hdfsBaseDir = groups['baseDir']; this.hdfsBaseDir = groups.baseDir;
if(this.hdfsBaseDir === undefined) { if (this.hdfsBaseDir === undefined) {
this.hdfsBaseDir = '/'; this.hdfsBaseDir = '/';
} }
let dataOutputHdfsClient; let dataOutputHdfsClient: any;
if (this.paiClusterConfig.host === this.hdfsOutputHost && this.hdfsClient) { if (this.paiClusterConfig.host === this.hdfsOutputHost && this.hdfsClient) {
dataOutputHdfsClient = this.hdfsClient; dataOutputHdfsClient = this.hdfsClient;
} else { } else {
...@@ -338,6 +343,7 @@ class PAITrainingService implements TrainingService { ...@@ -338,6 +343,7 @@ class PAITrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
// tslint:enable: no-unsafe-any
public getClusterMetadata(key: string): Promise<string> { public getClusterMetadata(key: string): Promise<string> {
const deferred : Deferred<string> = new Deferred<string>(); const deferred : Deferred<string> = new Deferred<string>();
...@@ -358,6 +364,7 @@ class PAITrainingService implements TrainingService { ...@@ -358,6 +364,7 @@ class PAITrainingService implements TrainingService {
deferred.resolve(); deferred.resolve();
this.log.info('PAI Training service rest server stopped successfully.'); this.log.info('PAI Training service rest server stopped successfully.');
} catch (error) { } catch (error) {
// tslint:disable-next-line: no-unsafe-any
this.log.error(`PAI Training service rest server stopped failed, error: ${error.message}`); this.log.error(`PAI Training service rest server stopped failed, error: ${error.message}`);
deferred.reject(error); deferred.reject(error);
} }
...@@ -374,35 +381,35 @@ class PAITrainingService implements TrainingService { ...@@ -374,35 +381,35 @@ class PAITrainingService implements TrainingService {
const deferred : Deferred<boolean> = new Deferred<boolean>(); const deferred : Deferred<boolean> = new Deferred<boolean>();
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (!trialJobDetail) { if (trialJobDetail === undefined) {
throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`); throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
} }
if (!this.paiClusterConfig) { if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized'); throw new Error('PAI Cluster config is not initialized');
} }
if (!this.paiTrialConfig) { if (this.paiTrialConfig === undefined) {
throw new Error('trial config is not initialized'); throw new Error('trial config is not initialized');
} }
if (!this.paiToken) { if (this.paiToken === undefined) {
throw new Error('PAI token is not initialized'); throw new Error('PAI token is not initialized');
} }
if (!this.hdfsBaseDir) { if (this.hdfsBaseDir === undefined) {
throw new Error('hdfsBaseDir is not initialized'); throw new Error('hdfsBaseDir is not initialized');
} }
if (!this.hdfsOutputHost) { if (this.hdfsOutputHost === undefined) {
throw new Error('hdfsOutputHost is not initialized'); throw new Error('hdfsOutputHost is not initialized');
} }
if (!this.paiRestServerPort) { if (this.paiRestServerPort === undefined) {
const restServer: PAIJobRestServer = component.get(PAIJobRestServer); const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
this.paiRestServerPort = restServer.clusterRestServerPort; this.paiRestServerPort = restServer.clusterRestServerPort;
} }
// Make sure experiment code files is copied from local to HDFS // Make sure experiment code files is copied from local to HDFS
if (this.copyExpCodeDirPromise) { if (this.copyExpCodeDirPromise !== undefined) {
await this.copyExpCodeDirPromise; await this.copyExpCodeDirPromise;
} }
...@@ -420,13 +427,14 @@ class PAITrainingService implements TrainingService { ...@@ -420,13 +427,14 @@ class PAITrainingService implements TrainingService {
// Write file content ( parameter.cfg ) to local tmp folders // Write file content ( parameter.cfg ) to local tmp folders
const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>trialJobDetail.form); const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>trialJobDetail.form);
if (trialForm) { if (trialForm !== undefined) {
await fs.promises.writeFile( await fs.promises.writeFile(
path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)), path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)),
trialForm.hyperParameters.value, { encoding: 'utf8' } trialForm.hyperParameters.value, { encoding: 'utf8' }
); );
} }
// tslint:disable-next-line: strict-boolean-expressions
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address(); const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const version: string = this.versionCheck ? await getVersion() : ''; const version: string = this.versionCheck ? await getVersion() : '';
const nniPaiTrialCommand : string = String.Format( const nniPaiTrialCommand : string = String.Format(
...@@ -446,8 +454,10 @@ class PAITrainingService implements TrainingService { ...@@ -446,8 +454,10 @@ class PAITrainingService implements TrainingService {
HDFSClientUtility.getHdfsExpCodeDir(this.paiClusterConfig.userName), HDFSClientUtility.getHdfsExpCodeDir(this.paiClusterConfig.userName),
version, version,
this.logCollection this.logCollection
).replace(/\r\n|\n|\r/gm, ''); )
.replace(/\r\n|\n|\r/gm, '');
// tslint:disable-next-line:no-console
console.log(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`); console.log(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
const paiTaskRoles : PAITaskRole[] = [ const paiTaskRoles : PAITaskRole[] = [
new PAITaskRole( new PAITaskRole(
...@@ -489,7 +499,10 @@ class PAITrainingService implements TrainingService { ...@@ -489,7 +499,10 @@ class PAITrainingService implements TrainingService {
await HDFSClientUtility.copyDirectoryToHdfs(trialLocalTempFolder, hdfsCodeDir, this.hdfsClient); await HDFSClientUtility.copyDirectoryToHdfs(trialLocalTempFolder, hdfsCodeDir, this.hdfsClient);
} catch (error) { } catch (error) {
this.log.error(`PAI Training service: copy ${this.paiTrialConfig.codeDir} to HDFS ${hdfsCodeDir} failed, error is ${error}`); this.log.error(`PAI Training service: copy ${this.paiTrialConfig.codeDir} to HDFS ${hdfsCodeDir} failed, error is ${error}`);
throw new Error(error.message); trialJobDetail.status = 'FAILED';
deferred.resolve(true);
return deferred.promise;
} }
// Step 3. Submit PAI job via Rest call // Step 3. Submit PAI job via Rest call
...@@ -504,13 +517,14 @@ class PAITrainingService implements TrainingService { ...@@ -504,13 +517,14 @@ class PAITrainingService implements TrainingService {
Authorization: `Bearer ${this.paiToken}` Authorization: `Bearer ${this.paiToken}`
} }
}; };
// tslint:disable:no-any no-unsafe-any
request(submitJobRequest, (error: Error, response: request.Response, body: any) => { request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
if (error || response.statusCode >= 400) { if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage : string = error ? error.message : const errorMessage : string = (error !== undefined && error !== null) ? error.message :
`Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${response.body}`; `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${response.body}`;
this.log.error(errorMessage); this.log.error(errorMessage);
trialJobDetail.status = 'FAILED'; trialJobDetail.status = 'FAILED';
deferred.reject(new Error(errorMessage)); deferred.resolve(true);
} else { } else {
trialJobDetail.submitTime = Date.now(); trialJobDetail.submitTime = Date.now();
deferred.resolve(true); deferred.resolve(true);
...@@ -530,18 +544,18 @@ class PAITrainingService implements TrainingService { ...@@ -530,18 +544,18 @@ class PAITrainingService implements TrainingService {
private async statusCheckingLoop(): Promise<void> { private async statusCheckingLoop(): Promise<void> {
while (!this.stopping) { while (!this.stopping) {
try{ try {
await this.updatePaiToken(); await this.updatePaiToken();
}catch(error){ } catch (error) {
this.log.error(`${error}`); this.log.error(`${error}`);
//only throw error when initlize paiToken first time //only throw error when initlize paiToken first time
if(!this.paiToken) { if (this.paiToken === undefined) {
throw new Error(error); throw new Error(error);
} }
} }
await this.paiJobCollector.retrieveTrialStatus(this.paiToken, this.paiClusterConfig); await this.paiJobCollector.retrieveTrialStatus(this.paiToken, this.paiClusterConfig);
const restServer: PAIJobRestServer = component.get(PAIJobRestServer); const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
if (restServer.getErrorMessage) { if (restServer.getErrorMessage !== undefined) {
throw new Error(restServer.getErrorMessage); throw new Error(restServer.getErrorMessage);
} }
await delay(3000); await delay(3000);
...@@ -572,17 +586,17 @@ class PAITrainingService implements TrainingService { ...@@ -572,17 +586,17 @@ class PAITrainingService implements TrainingService {
const currentTime: number = new Date().getTime(); const currentTime: number = new Date().getTime();
//If pai token initialized and not reach the interval time, do not update //If pai token initialized and not reach the interval time, do not update
if (this.paiTokenUpdateTime && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval){ if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
return Promise.resolve(); return Promise.resolve();
} }
if (!this.paiClusterConfig) { if (this.paiClusterConfig === undefined) {
const paiClusterConfigError: string = `pai cluster config not initialized!`; const paiClusterConfigError: string = `pai cluster config not initialized!`;
this.log.error(`${paiClusterConfigError}`); this.log.error(`${paiClusterConfigError}`);
throw Error(`${paiClusterConfigError}`); throw Error(`${paiClusterConfigError}`);
} }
const authentication_req: request.Options = { const authenticationReq: request.Options = {
uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/token`, uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/token`,
method: 'POST', method: 'POST',
json: true, json: true,
...@@ -592,12 +606,12 @@ class PAITrainingService implements TrainingService { ...@@ -592,12 +606,12 @@ class PAITrainingService implements TrainingService {
} }
}; };
request(authentication_req, (error: Error, response: request.Response, body: any) => { request(authenticationReq, (error: Error, response: request.Response, body: any) => {
if (error) { if (error !== undefined && error !== null) {
this.log.error(`Get PAI token failed: ${error.message}`); this.log.error(`Get PAI token failed: ${error.message}`);
deferred.reject(new Error(`Get PAI token failed: ${error.message}`)); deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
} else { } else {
if (response.statusCode !== 200){ if (response.statusCode !== 200) {
this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`); this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
deferred.reject(new Error(`Get PAI token failed: ${response.body}, please check paiConfig username or password`)); deferred.reject(new Error(`Get PAI token failed: ${response.body}, please check paiConfig username or password`));
} }
...@@ -616,8 +630,9 @@ class PAITrainingService implements TrainingService { ...@@ -616,8 +630,9 @@ class PAITrainingService implements TrainingService {
}); });
return Promise.race([timeoutDelay, deferred.promise]) return Promise.race([timeoutDelay, deferred.promise])
.finally(() => clearTimeout(timeoutId)); .finally(() => { clearTimeout(timeoutId); });
} }
// tslint:enable:no-any no-unsafe-any no-http-string
} }
export { PAITrainingService }; export { PAITrainingService };
...@@ -19,16 +19,20 @@ ...@@ -19,16 +19,20 @@
'use strict'; 'use strict';
import {TrialConfig} from '../common/trialConfig' import {TrialConfig} from '../common/trialConfig';
export class PAITrialConfig extends TrialConfig{ /**
* PAI configuration to run trials
*/
export class PAITrialConfig extends TrialConfig {
public readonly cpuNum: number; public readonly cpuNum: number;
public readonly memoryMB: number; public readonly memoryMB: number;
public readonly image: string; public readonly image: string;
public readonly dataDir: string; public readonly dataDir: string;
public readonly outputDir: string; public readonly outputDir: string;
constructor(command : string, codeDir : string, gpuNum : number, cpuNum: number, memoryMB: number, image: string, dataDir: string, outputDir: string) { constructor(command : string, codeDir : string, gpuNum : number, cpuNum: number, memoryMB: number,
image: string, dataDir: string, outputDir: string) {
super(command, codeDir, gpuNum); super(command, codeDir, gpuNum);
this.cpuNum = cpuNum; this.cpuNum = cpuNum;
this.memoryMB = memoryMB; this.memoryMB = memoryMB;
...@@ -36,4 +40,4 @@ export class PAITrialConfig extends TrialConfig{ ...@@ -36,4 +40,4 @@ export class PAITrialConfig extends TrialConfig{
this.dataDir = dataDir; this.dataDir = dataDir;
this.outputDir = outputDir; this.outputDir = outputDir;
} }
} }
\ No newline at end of file
...@@ -21,10 +21,12 @@ ...@@ -21,10 +21,12 @@
import * as assert from 'assert'; import * as assert from 'assert';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { TrialJobDetail } from '../../common/trainingService';
import { randomSelect } from '../../common/utils'; import { randomSelect } from '../../common/utils';
import { GPUInfo } from '../common/gpuData'; import { GPUInfo } from '../common/gpuData';
import { RemoteMachineTrialJobDetail, parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, ScheduleResultType, SSHClientManager } from './remoteMachineData'; import {
import { TrialJobDetail } from 'common/trainingService'; parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, SSHClientManager
} from './remoteMachineData';
/** /**
* A simple GPU scheduler implementation * A simple GPU scheduler implementation
...@@ -32,7 +34,7 @@ import { TrialJobDetail } from 'common/trainingService'; ...@@ -32,7 +34,7 @@ import { TrialJobDetail } from 'common/trainingService';
export class GPUScheduler { export class GPUScheduler {
private readonly machineSSHClientMap : Map<RemoteMachineMeta, SSHClientManager>; private readonly machineSSHClientMap : Map<RemoteMachineMeta, SSHClientManager>;
private log: Logger = getLogger(); private readonly log: Logger = getLogger();
/** /**
* Constructor * Constructor
...@@ -89,21 +91,21 @@ export class GPUScheduler { ...@@ -89,21 +91,21 @@ export class GPUScheduler {
* remove the job's gpu reservation * remove the job's gpu reservation
*/ */
public removeGpuReservation(trialJobId: string, trialJobMap: Map<string, RemoteMachineTrialJobDetail>): void { public removeGpuReservation(trialJobId: string, trialJobMap: Map<string, RemoteMachineTrialJobDetail>): void {
let trialJobDetail: RemoteMachineTrialJobDetail | undefined = trialJobMap.get(trialJobId); const trialJobDetail: RemoteMachineTrialJobDetail | undefined = trialJobMap.get(trialJobId);
if(trialJobDetail === undefined) { if (trialJobDetail === undefined) {
throw new Error(`could not get trialJobDetail by id ${trialJobId}`); throw new Error(`could not get trialJobDetail by id ${trialJobId}`);
} }
if (trialJobDetail.rmMeta !== undefined && if (trialJobDetail.rmMeta !== undefined &&
trialJobDetail.rmMeta.occupiedGpuIndexMap !== undefined && trialJobDetail.rmMeta.occupiedGpuIndexMap !== undefined &&
trialJobDetail.gpuIndices !== undefined && trialJobDetail.gpuIndices !== undefined &&
trialJobDetail.gpuIndices.length > 0) { trialJobDetail.gpuIndices.length > 0) {
for (const gpuInfo of trialJobDetail.gpuIndices) { for (const gpuInfo of trialJobDetail.gpuIndices) {
let num: number | undefined = trialJobDetail.rmMeta.occupiedGpuIndexMap.get(gpuInfo.index); const num: number | undefined = trialJobDetail.rmMeta.occupiedGpuIndexMap.get(gpuInfo.index);
if(num !== undefined) { if (num !== undefined) {
if(num === 1) { if (num === 1) {
trialJobDetail.rmMeta.occupiedGpuIndexMap.delete(gpuInfo.index); trialJobDetail.rmMeta.occupiedGpuIndexMap.delete(gpuInfo.index);
} else { } else {
trialJobDetail.rmMeta.occupiedGpuIndexMap.set(gpuInfo.index, num - 1) trialJobDetail.rmMeta.occupiedGpuIndexMap.set(gpuInfo.index, num - 1);
} }
} }
} }
...@@ -116,7 +118,6 @@ export class GPUScheduler { ...@@ -116,7 +118,6 @@ export class GPUScheduler {
const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = this.gpuResourceDetection(); const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = this.gpuResourceDetection();
const qualifiedRMs: RemoteMachineMeta[] = []; const qualifiedRMs: RemoteMachineMeta[] = [];
totalResourceMap.forEach((gpuInfos: GPUInfo[], rmMeta: RemoteMachineMeta) => { totalResourceMap.forEach((gpuInfos: GPUInfo[], rmMeta: RemoteMachineMeta) => {
if (gpuInfos !== undefined && gpuInfos.length >= requiredGPUNum) { if (gpuInfos !== undefined && gpuInfos.length >= requiredGPUNum) {
qualifiedRMs.push(rmMeta); qualifiedRMs.push(rmMeta);
} }
...@@ -154,6 +155,7 @@ export class GPUScheduler { ...@@ -154,6 +155,7 @@ export class GPUScheduler {
} }
} }
this.log.debug(`designated gpu indices: ${designatedGpuIndices}`); this.log.debug(`designated gpu indices: ${designatedGpuIndices}`);
// tslint:disable: strict-boolean-expressions
rmMeta.gpuSummary.gpuInfos.forEach((gpuInfo: GPUInfo) => { rmMeta.gpuSummary.gpuInfos.forEach((gpuInfo: GPUInfo) => {
// if the GPU has active process, OR be reserved by a job, // if the GPU has active process, OR be reserved by a job,
// or index not in gpuIndices configuration in machineList, // or index not in gpuIndices configuration in machineList,
...@@ -161,10 +163,10 @@ export class GPUScheduler { ...@@ -161,10 +163,10 @@ export class GPUScheduler {
// We should NOT allocate this GPU // We should NOT allocate this GPU
// if users set useActiveGpu, use the gpu whether there is another activeProcess // if users set useActiveGpu, use the gpu whether there is another activeProcess
if (designatedGpuIndices === undefined || designatedGpuIndices.has(gpuInfo.index)) { if (designatedGpuIndices === undefined || designatedGpuIndices.has(gpuInfo.index)) {
if(rmMeta.occupiedGpuIndexMap !== undefined) { if (rmMeta.occupiedGpuIndexMap !== undefined) {
let num = rmMeta.occupiedGpuIndexMap.get(gpuInfo.index); const num: number | undefined = rmMeta.occupiedGpuIndexMap.get(gpuInfo.index);
let maxTrialNumPerGpu: number = rmMeta.maxTrialNumPerGpu? rmMeta.maxTrialNumPerGpu: 1; const maxTrialNumPerGpu: number = rmMeta.maxTrialNumPerGpu ? rmMeta.maxTrialNumPerGpu : 1;
if((num === undefined && (!rmMeta.useActiveGpu && gpuInfo.activeProcessNum === 0 || rmMeta.useActiveGpu)) || if ((num === undefined && (!rmMeta.useActiveGpu && gpuInfo.activeProcessNum === 0 || rmMeta.useActiveGpu)) ||
(num !== undefined && num < maxTrialNumPerGpu)) { (num !== undefined && num < maxTrialNumPerGpu)) {
availableGPUs.push(gpuInfo); availableGPUs.push(gpuInfo);
} }
...@@ -179,6 +181,7 @@ export class GPUScheduler { ...@@ -179,6 +181,7 @@ export class GPUScheduler {
return totalResourceMap; return totalResourceMap;
} }
// tslint:enable: strict-boolean-expressions
private selectMachine(rmMetas: RemoteMachineMeta[]): RemoteMachineMeta { private selectMachine(rmMetas: RemoteMachineMeta[]): RemoteMachineMeta {
assert(rmMetas !== undefined && rmMetas.length > 0); assert(rmMetas !== undefined && rmMetas.length > 0);
...@@ -196,23 +199,28 @@ export class GPUScheduler { ...@@ -196,23 +199,28 @@ export class GPUScheduler {
assert(gpuInfos.length >= requiredGPUNum); assert(gpuInfos.length >= requiredGPUNum);
const allocatedGPUs: GPUInfo[] = this.selectGPUsForTrial(gpuInfos, requiredGPUNum); const allocatedGPUs: GPUInfo[] = this.selectGPUsForTrial(gpuInfos, requiredGPUNum);
allocatedGPUs.forEach((gpuInfo: GPUInfo) => { allocatedGPUs.forEach((gpuInfo: GPUInfo) => {
if(rmMeta.occupiedGpuIndexMap !== undefined) { if (rmMeta.occupiedGpuIndexMap !== undefined) {
let num = rmMeta.occupiedGpuIndexMap.get(gpuInfo.index); let num: number | undefined = rmMeta.occupiedGpuIndexMap.get(gpuInfo.index);
if(num === undefined) { if (num === undefined) {
num = 0; num = 0;
} }
rmMeta.occupiedGpuIndexMap.set(gpuInfo.index, num + 1); rmMeta.occupiedGpuIndexMap.set(gpuInfo.index, num + 1);
}else { } else {
throw new Error(`Machine ${rmMeta.ip} occupiedGpuIndexMap initialize error!`); throw new Error(`Machine ${rmMeta.ip} occupiedGpuIndexMap initialize error!`);
} }
}); });
trialJobDetail.gpuIndices = allocatedGPUs; trialJobDetail.gpuIndices = allocatedGPUs;
trialJobDetail.rmMeta = rmMeta; trialJobDetail.rmMeta = rmMeta;
return { return {
resultType: ScheduleResultType.SUCCEED, resultType: ScheduleResultType.SUCCEED,
scheduleInfo: { scheduleInfo: {
rmMeta: rmMeta, rmMeta: rmMeta,
cuda_visible_device: allocatedGPUs.map((gpuInfo: GPUInfo) => { return gpuInfo.index; }).join(',') cuda_visible_device: allocatedGPUs
.map((gpuInfo: GPUInfo) => {
return gpuInfo.index;
})
.join(',')
} }
}; };
} }
......
...@@ -23,7 +23,7 @@ import * as fs from 'fs'; ...@@ -23,7 +23,7 @@ import * as fs from 'fs';
import { Client, ConnectConfig } from 'ssh2'; import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService'; import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUSummary, GPUInfo } from '../common/gpuData'; import { GPUInfo, GPUSummary } from '../common/gpuData';
/** /**
* Metadata of remote machine for configuration and status query * Metadata of remote machine for configuration and status query
...@@ -73,7 +73,6 @@ export class RemoteCommandResult { ...@@ -73,7 +73,6 @@ export class RemoteCommandResult {
/** /**
* RemoteMachineTrialJobDetail * RemoteMachineTrialJobDetail
*/ */
// tslint:disable-next-line:max-classes-per-file
export class RemoteMachineTrialJobDetail implements TrialJobDetail { export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public id: string; public id: string;
public status: TrialJobStatus; public status: TrialJobStatus;
...@@ -98,7 +97,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -98,7 +97,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
this.form = form; this.form = form;
this.sequenceId = sequenceId; this.sequenceId = sequenceId;
this.tags = []; this.tags = [];
this.gpuIndices = [] this.gpuIndices = [];
} }
} }
...@@ -112,7 +111,7 @@ export class SSHClient { ...@@ -112,7 +111,7 @@ export class SSHClient {
this.sshClient = sshClient; this.sshClient = sshClient;
this.usedConnectionNumber = usedConnectionNumber; this.usedConnectionNumber = usedConnectionNumber;
} }
public get getSSHClientInstance(): Client { public get getSSHClientInstance(): Client {
return this.sshClient; return this.sshClient;
} }
...@@ -121,17 +120,20 @@ export class SSHClient { ...@@ -121,17 +120,20 @@ export class SSHClient {
return this.usedConnectionNumber; return this.usedConnectionNumber;
} }
/**
 * Increase the usage counter of this connection by one
 */
public addUsedConnectionNumber(): void {
    this.usedConnectionNumber = this.usedConnectionNumber + 1;
}
/**
 * Decrease the usage counter of this connection by one
 */
public minusUsedConnectionNumber(): void {
    this.usedConnectionNumber = this.usedConnectionNumber - 1;
}
} }
/**
* The remote machine ssh client manager
*/
export class SSHClientManager { export class SSHClientManager {
private sshClientArray: SSHClient[]; private readonly sshClientArray: SSHClient[];
private readonly maxTrialNumberPerConnection: number; private readonly maxTrialNumberPerConnection: number;
private readonly rmMeta: RemoteMachineMeta; private readonly rmMeta: RemoteMachineMeta;
constructor(sshClientArray: SSHClient[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) { constructor(sshClientArray: SSHClient[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) {
...@@ -140,122 +142,128 @@ export class SSHClientManager { ...@@ -140,122 +142,128 @@ export class SSHClientManager {
this.maxTrialNumberPerConnection = maxTrialNumberPerConnection; this.maxTrialNumberPerConnection = maxTrialNumberPerConnection;
} }
/**
 * Create a new ssh connection client and initialize it.
 *
 * Authentication comes from rmMeta: the password when one is set, otherwise the private
 * key read from sshKeyPath (with the optional passphrase).
 * Once the connection is ready the client is added to sshClientArray through
 * addNewSSHClient and the returned promise resolves with it.
 *
 * @returns promise resolved with the connected client; rejected when the key file does
 *          not exist, when neither passwd nor sshKeyPath is configured, or when the
 *          connection emits an error
 */
private initNewSSHClient(): Promise<Client> {
    const deferred: Deferred<Client> = new Deferred<Client>();
    const connectConfig: ConnectConfig = {
        host: this.rmMeta.ip,
        port: this.rmMeta.port,
        username: this.rmMeta.username };
    if (this.rmMeta.passwd) {
        connectConfig.password = this.rmMeta.passwd;
    } else if (this.rmMeta.sshKeyPath) {
        if (!fs.existsSync(this.rmMeta.sshKeyPath)) {
            //SSh key path is not a valid file, reject
            deferred.reject(new Error(`${this.rmMeta.sshKeyPath} does not exist.`));

            // Return right away: falling through would make readFileSync throw
            // synchronously on the missing file instead of handing back the rejection.
            return deferred.promise;
        }
        connectConfig.privateKey = fs.readFileSync(this.rmMeta.sshKeyPath, 'utf8');
        connectConfig.passphrase = this.rmMeta.passphrase;
    } else {
        deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));

        // Return right away: without credentials there is nothing to connect with, and
        // a connection opened after the rejection could never reach the caller.
        return deferred.promise;
    }

    const conn: Client = new Client();
    conn.on('ready', () => {
        this.addNewSSHClient(conn);
        deferred.resolve(conn);
    }).on('error', (err: Error) => {
        // SSH connection error, reject with error message
        deferred.reject(new Error(err.message));
    }).connect(connectConfig);

    return deferred.promise;
}
/**
 * Hand out an ssh client that still has spare capacity.
 * The first client in sshClientArray whose usage counter is below
 * maxTrialNumberPerConnection gets its counter increased and is returned;
 * when every client is full, a new connection is created through initNewSSHClient.
 */
public async getAvailableSSHClient(): Promise<Client> {
    for (const candidate of this.sshClientArray) {
        if (candidate.getUsedConnectionNumber < this.maxTrialNumberPerConnection) {
            candidate.addUsedConnectionNumber();

            return candidate.getSSHClientInstance;
        }
    }

    //init a new ssh client if could not get an available one
    return this.initNewSSHClient();
}
/**
 * Wrap a client in an SSHClient entry and append it to sshClientArray
 * @param client SSH Client; its entry starts with a usage counter of 1
 */
public addNewSSHClient(client: Client): void {
    const entry: SSHClient = new SSHClient(client, 1);
    this.sshClientArray.push(entry);
}
/**
 * Return the client held by the first entry of sshClientArray;
 * the first ssh client instance is used for gpu collector and host job
 */
public getFirstSSHClient(): Client {
    const [firstEntry] = this.sshClientArray;

    return firstEntry.getSSHClientInstance;
}
/**
 * Call end() on the client of every entry in sshClientArray
 */
public closeAllSSHClient(): void {
    this.sshClientArray.forEach((entry: SSHClient) => {
        entry.getSSHClientInstance.end();
    });
}
/**
 * Give a connection slot back: decrease the usage counter of the first entry in
 * sshClientArray that holds the given client. Nothing happens when no entry matches.
 * @param client SSH Client; passing undefined throws
 */
public releaseConnection(client: Client | undefined): void {
    if (client === undefined) {
        throw new Error(`could not release a undefined ssh client`);
    }
    const owner: SSHClient | undefined = this.sshClientArray.find(
        (entry: SSHClient) => entry.getSSHClientInstance === client);
    if (owner !== undefined) {
        owner.minusUsedConnectionNumber();
    }
}
}
/**
 * Create a new ssh connection client and initialize it.
 *
 * Authentication comes from rmMeta: the password when one is defined, otherwise the
 * private key read from sshKeyPath (with the optional passphrase).
 * Once the connection is ready the client is added to sshClientArray through
 * addNewSSHClient and the returned promise resolves with it.
 *
 * @returns promise resolved with the connected client; rejected when the key file does
 *          not exist, when neither passwd nor sshKeyPath is configured, or when the
 *          connection emits an error
 */
// tslint:disable:non-literal-fs-path
private initNewSSHClient(): Promise<Client> {
    const deferred: Deferred<Client> = new Deferred<Client>();
    const connectConfig: ConnectConfig = {
        host: this.rmMeta.ip,
        port: this.rmMeta.port,
        username: this.rmMeta.username };
    if (this.rmMeta.passwd !== undefined) {
        connectConfig.password = this.rmMeta.passwd;
    } else if (this.rmMeta.sshKeyPath !== undefined) {
        if (!fs.existsSync(this.rmMeta.sshKeyPath)) {
            //SSh key path is not a valid file, reject
            deferred.reject(new Error(`${this.rmMeta.sshKeyPath} does not exist.`));

            // Return right away: falling through would make readFileSync throw
            // synchronously on the missing file instead of handing back the rejection.
            return deferred.promise;
        }
        const privateKey: string = fs.readFileSync(this.rmMeta.sshKeyPath, 'utf8');

        connectConfig.privateKey = privateKey;
        connectConfig.passphrase = this.rmMeta.passphrase;
    } else {
        deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));

        // Return right away: without credentials there is nothing to connect with, and
        // a connection opened after the rejection could never reach the caller.
        return deferred.promise;
    }

    const conn: Client = new Client();
    conn.on('ready', () => {
        this.addNewSSHClient(conn);
        deferred.resolve(conn);
    })
      .on('error', (err: Error) => {
        // SSH connection error, reject with error message
        deferred.reject(new Error(err.message));
    })
      .connect(connectConfig);

    return deferred.promise;
}
}
/**
 * Result of a scheduling attempt: the result type plus the schedule info,
 * which may be undefined
 */
export type RemoteMachineScheduleResult = {
    scheduleInfo : RemoteMachineScheduleInfo | undefined;
    resultType : ScheduleResultType;
};

/**
 * Chosen machine plus cuda_visible_device, a string the GPU scheduler builds by
 * joining the allocated GPU indices with ','
 */
export type RemoteMachineScheduleInfo = {
    rmMeta : RemoteMachineMeta;
    cuda_visible_device : string;
};
/**
 * Possible outcomes of a scheduling attempt
 */
export enum ScheduleResultType {
    // Schedule succeeded
    SUCCEED = 0,

    // Temporarily, no enough available GPU right now
    TMP_NO_AVAILABLE_GPU = 1,

    // Cannot match requirement even if all GPU are available
    // NOTE(review): the original comment is cut off after "are a"; "available" is presumed - confirm
    REQUIRE_EXCEED_TOTAL = 2
}
export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string = export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string =
`#!/bin/bash `#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} export MULTI_PHASE={5} export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} \
NNI_TRIAL_SEQ_ID={4} export MULTI_PHASE={5}
cd $NNI_SYS_DIR cd $NNI_SYS_DIR
sh install_nni.sh sh install_nni.sh
echo $$ >{6} echo $$ >{6}
python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' --nni_manager_version '{10}' --log_collection '{11}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' \
--nni_manager_version '{10}' --log_collection '{11}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr
echo $? \`date +%s%3N\` >{12}`; echo $? \`date +%s%3N\` >{12}`;
export const HOST_JOB_SHELL_FORMAT: string = export const HOST_JOB_SHELL_FORMAT: string =
......
...@@ -19,17 +19,17 @@ ...@@ -19,17 +19,17 @@
'use strict'; 'use strict';
import * as component from '../../common/component';
import { Inject } from 'typescript-ioc'; import { Inject } from 'typescript-ioc';
import * as component from '../../common/component';
import { ClusterJobRestServer } from '../common/clusterJobRestServer';
import { RemoteMachineTrainingService } from './remoteMachineTrainingService'; import { RemoteMachineTrainingService } from './remoteMachineTrainingService';
import { ClusterJobRestServer } from '../common/clusterJobRestServer'
/** /**
* RemoteMachine Training service Rest server, provides rest RemoteMachine to support remotemachine job metrics update * RemoteMachine Training service Rest server, provides rest RemoteMachine to support remotemachine job metrics update
* *
*/ */
@component.Singleton @component.Singleton
export class RemoteMachineJobRestServer extends ClusterJobRestServer{ export class RemoteMachineJobRestServer extends ClusterJobRestServer {
@Inject @Inject
private readonly remoteMachineTrainingService : RemoteMachineTrainingService; private readonly remoteMachineTrainingService : RemoteMachineTrainingService;
...@@ -41,6 +41,7 @@ export class RemoteMachineJobRestServer extends ClusterJobRestServer{ ...@@ -41,6 +41,7 @@ export class RemoteMachineJobRestServer extends ClusterJobRestServer{
this.remoteMachineTrainingService = component.get(RemoteMachineTrainingService); this.remoteMachineTrainingService = component.get(RemoteMachineTrainingService);
} }
// tslint:disable-next-line:no-any
protected handleTrialMetrics(jobId : string, metrics : any[]) : void { protected handleTrialMetrics(jobId : string, metrics : any[]) : void {
// Split metrics array into single metric, then emit // Split metrics array into single metric, then emit
// Warning: If not split metrics into single ones, the behavior will be UNKNOWN // Warning: If not split metrics into single ones, the behavior will be UNKNOWN
...@@ -51,4 +52,4 @@ export class RemoteMachineJobRestServer extends ClusterJobRestServer{ ...@@ -51,4 +52,4 @@ export class RemoteMachineJobRestServer extends ClusterJobRestServer{
}); });
} }
} }
} }
\ No newline at end of file
...@@ -34,42 +34,45 @@ import { getExperimentId, getInitTrialSequenceId } from '../../common/experiment ...@@ -34,42 +34,45 @@ import { getExperimentId, getInitTrialSequenceId } from '../../common/experiment
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { ObservableTimer } from '../../common/observableTimer'; import { ObservableTimer } from '../../common/observableTimer';
import { import {
HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, NNIManagerIpConfig HostJobApplicationForm, HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString, getJobCancelStatus, getRemoteTmpDir,getIPV4Address, getVersion, unixPathJoin } from '../../common/utils'; import {
import { GPUSummary } from '../common/gpuData'; delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus, getRemoteTmpDir,
getVersion, uniqueString, unixPathJoin
} from '../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { GPU_INFO_COLLECTOR_FORMAT_LINUX, GPUSummary } from '../common/gpuData';
import { TrialConfig } from '../common/trialConfig'; import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { execCopydir, execMkdir, execRemove, validateCodeDir } from '../common/util';
import { GPUScheduler } from './gpuScheduler'; import { GPUScheduler } from './gpuScheduler';
import { import {
HOST_JOB_SHELL_FORMAT, RemoteCommandResult, RemoteMachineMeta, HOST_JOB_SHELL_FORMAT, RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
RemoteMachineScheduleInfo, RemoteMachineScheduleResult, SSHClient, SSHClientManager, RemoteMachineScheduleInfo, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail,
RemoteMachineTrialJobDetail, ScheduleResultType, REMOTEMACHINE_TRIAL_COMMAND_FORMAT ScheduleResultType, SSHClient, SSHClientManager
} from './remoteMachineData'; } from './remoteMachineData';
import { GPU_INFO_COLLECTOR_FORMAT_LINUX } from '../common/gpuData';
import { SSHClientUtility } from './sshClientUtility';
import { validateCodeDir, execRemove, execMkdir, execCopydir } from '../common/util';
import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer'; import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; import { SSHClientUtility } from './sshClientUtility';
/** /**
* Training Service implementation for Remote Machine (Linux) * Training Service implementation for Remote Machine (Linux)
*/ */
@component.Singleton @component.Singleton
class RemoteMachineTrainingService implements TrainingService { class RemoteMachineTrainingService implements TrainingService {
private machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>; //machine ssh client map private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>; //machine ssh client map
private trialSSHClientMap: Map<string, Client>; //trial ssh client map private readonly trialSSHClientMap: Map<string, Client>; //trial ssh client map
private trialJobsMap: Map<string, RemoteMachineTrialJobDetail>; private readonly trialJobsMap: Map<string, RemoteMachineTrialJobDetail>;
private readonly MAX_TRIAL_NUMBER_PER_SSHCONNECTION: number = 5 // every ssh client has a max trial concurrency number private readonly MAX_TRIAL_NUMBER_PER_SSHCONNECTION: number = 5; // every ssh client has a max trial concurrency number
private expRootDir: string; private readonly expRootDir: string;
private remoteExpRootDir: string; private readonly remoteExpRootDir: string;
private trialConfig: TrialConfig | undefined; private trialConfig: TrialConfig | undefined;
private gpuScheduler: GPUScheduler; private readonly gpuScheduler: GPUScheduler;
private jobQueue: string[]; private readonly jobQueue: string[];
private timer: ObservableTimer; private readonly timer: ObservableTimer;
private stopping: boolean = false; private stopping: boolean = false;
private metricsEmitter: EventEmitter; private readonly metricsEmitter: EventEmitter;
private log: Logger; private readonly log: Logger;
private isMultiPhase: boolean = false; private isMultiPhase: boolean = false;
private trialSequenceId: number; private trialSequenceId: number;
private remoteRestServerPort?: number; private remoteRestServerPort?: number;
...@@ -117,7 +120,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -117,7 +120,7 @@ class RemoteMachineTrainingService implements TrainingService {
break; break;
} }
} }
if(restServer.getErrorMessage) { if (restServer.getErrorMessage !== undefined) {
throw new Error(restServer.getErrorMessage); throw new Error(restServer.getErrorMessage);
this.stopping = true; this.stopping = true;
} }
...@@ -125,36 +128,37 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -125,36 +128,37 @@ class RemoteMachineTrainingService implements TrainingService {
} }
this.log.info('Remote machine training service exit.'); this.log.info('Remote machine training service exit.');
} }
/**
 * Reserve an ssh connection for a trial: ask the SSHClientManager of the trial's
 * machine for an available client and store it in trialSSHClientMap under the trial id
 * @param trial remote machine trial job detail; its rmMeta must already be set
 */
public async allocateSSHClientForTrial(trial: RemoteMachineTrialJobDetail): Promise<void> {
    const machine: RemoteMachineMeta | undefined = trial.rmMeta;
    if (machine === undefined) {
        throw new Error(`rmMeta not set in trial ${trial.id}`);
    }
    const manager: SSHClientManager | undefined = this.machineSSHClientMap.get(machine);
    if (manager === undefined) {
        throw new Error(`remoteSSHClient not initialized`);
    }
    const allocated: Client = await manager.getAvailableSSHClient();
    this.trialSSHClientMap.set(trial.id, allocated);
}
/** /**
* If a trial is finished, release the connection resource * If a trial is finished, release the connection resource
* @param trial * @param trial remote machine trial job detail
*/ */
public releaseTrialSSHClient(trial: RemoteMachineTrialJobDetail): void { public releaseTrialSSHClient(trial: RemoteMachineTrialJobDetail): void {
if(!trial.rmMeta) { if (trial.rmMeta === undefined) {
throw new Error(`rmMeta not set in trial ${trial.id}`); throw new Error(`rmMeta not set in trial ${trial.id}`);
} }
let sshClientManager: SSHClientManager | undefined = this.machineSSHClientMap.get(trial.rmMeta); const sshClientManager: SSHClientManager | undefined = this.machineSSHClientMap.get(trial.rmMeta);
if(!sshClientManager) { if (sshClientManager === undefined) {
throw new Error(`sshClientManager not initialized`); throw new Error(`sshClientManager not initialized`);
} }
sshClientManager.releaseConnection(this.trialSSHClientMap.get(trial.id)); sshClientManager.releaseConnection(this.trialSSHClientMap.get(trial.id));
...@@ -167,11 +171,11 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -167,11 +171,11 @@ class RemoteMachineTrainingService implements TrainingService {
const jobs: TrialJobDetail[] = []; const jobs: TrialJobDetail[] = [];
const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>(); const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>();
for (const [key, value] of this.trialJobsMap) { for (const [key, value] of this.trialJobsMap) {
if (value.form.jobType === 'TRIAL') { if (value.form.jobType === 'TRIAL') {
jobs.push(await this.getTrialJob(key)); jobs.push(await this.getTrialJob(key));
} }
}; }
deferred.resolve(jobs); deferred.resolve(jobs);
return deferred.promise; return deferred.promise;
...@@ -183,7 +187,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -183,7 +187,7 @@ class RemoteMachineTrainingService implements TrainingService {
*/ */
public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> { public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (!trialJob) { if (trialJob === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`); throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`);
} }
//TO DO: add another job status, and design new job status change logic //TO DO: add another job status, and design new job status change logic
...@@ -193,7 +197,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -193,7 +197,7 @@ class RemoteMachineTrainingService implements TrainingService {
throw new Error(`rmMeta not set for submitted job ${trialJobId}`); throw new Error(`rmMeta not set for submitted job ${trialJobId}`);
} }
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJob.id); const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJob.id);
if (!sshClient) { if (sshClient === undefined) {
throw new Error(`Invalid job id: ${trialJobId}, cannot find ssh client`); throw new Error(`Invalid job id: ${trialJobId}, cannot find ssh client`);
} }
...@@ -223,8 +227,9 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -223,8 +227,9 @@ class RemoteMachineTrainingService implements TrainingService {
* Submit trial job * Submit trial job
* @param form trial job description form * @param form trial job description form
*/ */
// tslint:disable-next-line:informative-docs
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> { public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
if (!this.trialConfig) { if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized'); throw new Error('trial config is not initialized');
} }
...@@ -275,17 +280,6 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -275,17 +280,6 @@ class RemoteMachineTrainingService implements TrainingService {
return trialJobDetail; return trialJobDetail;
} }
/**
* remove gpu reversion when job is not running
*/
private updateGpuReservation() {
for (const [key, value] of this.trialJobsMap) {
if(!['WAITING', 'RUNNING'].includes(value.status)) {
this.gpuScheduler.removeGpuReservation(key, this.trialJobsMap);
}
};
}
/** /**
* Is multiphase job supported in current training service * Is multiphase job supported in current training service
...@@ -298,10 +292,11 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -298,10 +292,11 @@ class RemoteMachineTrainingService implements TrainingService {
* Cancel trial job * Cancel trial job
* @param trialJobId ID of trial job * @param trialJobId ID of trial job
*/ */
// tslint:disable:informative-docs no-unsafe-any
public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> { public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (!trialJob) { if (trialJob === undefined) {
deferred.reject(); deferred.reject();
throw new Error(`trial job id ${trialJobId} not found`); throw new Error(`trial job id ${trialJobId} not found`);
} }
...@@ -316,7 +311,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -316,7 +311,7 @@ class RemoteMachineTrainingService implements TrainingService {
if (trialJob.rmMeta !== undefined) { if (trialJob.rmMeta !== undefined) {
// If the trial job is already scheduled, check its status and kill the trial process in remote machine // If the trial job is already scheduled, check its status and kill the trial process in remote machine
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJob.id); const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJob.id);
if (!sshClient) { if (sshClient === undefined) {
deferred.reject(); deferred.reject();
throw new Error(`Invalid job id ${trialJobId}, cannot find ssh client`); throw new Error(`Invalid job id ${trialJobId}, cannot find ssh client`);
} }
...@@ -358,20 +353,23 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -358,20 +353,23 @@ class RemoteMachineTrainingService implements TrainingService {
case TrialConfigMetadataKey.TRIAL_CONFIG: case TrialConfigMetadataKey.TRIAL_CONFIG:
const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value); const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value);
// Parse trial config failed, throw Error // Parse trial config failed, throw Error
if (!remoteMachineTrailConfig) { if (remoteMachineTrailConfig === undefined) {
throw new Error('trial config parsed failed'); throw new Error('trial config parsed failed');
} }
// codeDir is not a valid directory, throw Error // codeDir is not a valid directory, throw Error
if (!fs.lstatSync(remoteMachineTrailConfig.codeDir).isDirectory()) { // tslint:disable-next-line:non-literal-fs-path
if (!fs.lstatSync(remoteMachineTrailConfig.codeDir)
.isDirectory()) {
throw new Error(`codeDir ${remoteMachineTrailConfig.codeDir} is not a directory`); throw new Error(`codeDir ${remoteMachineTrailConfig.codeDir} is not a directory`);
} }
// Validate to make sure codeDir doesn't have too many files // Validate to make sure codeDir doesn't have too many files
try { try {
await validateCodeDir(remoteMachineTrailConfig.codeDir); await validateCodeDir(remoteMachineTrailConfig.codeDir);
} catch(error) { } catch (error) {
this.log.error(error); this.log.error(error);
return Promise.reject(new Error(error));
return Promise.reject(new Error(error));
} }
this.trialConfig = remoteMachineTrailConfig; this.trialConfig = remoteMachineTrailConfig;
...@@ -400,60 +398,73 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -400,60 +398,73 @@ class RemoteMachineTrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
/** /**
* cleanup() has a time out of 10s to clean remote connections * cleanup() has a time out of 10s to clean remote connections
*/ */
public async cleanUp(): Promise<void> { public async cleanUp(): Promise<void> {
this.log.info('Stopping remote machine training service...'); this.log.info('Stopping remote machine training service...');
this.stopping = true; this.stopping = true;
await Promise.race([delay(10000), this.cleanupConnections()]); await Promise.race([delay(10000), this.cleanupConnections()]);
} }
/**
* remove gpu reversion when job is not running
*/
private updateGpuReservation(): void {
for (const [key, value] of this.trialJobsMap) {
if (!['WAITING', 'RUNNING'].includes(value.status)) {
this.gpuScheduler.removeGpuReservation(key, this.trialJobsMap);
}
}
}
/** /**
* stop gpu_metric_collector process in remote machine and remove unused scripts * stop gpu_metric_collector process in remote machine and remove unused scripts
*/ */
private async cleanupConnections(): Promise<void> { private async cleanupConnections(): Promise<void> {
try{ try {
for (const [rmMeta, sshClientManager] of this.machineSSHClientMap.entries()) { for (const [rmMeta, sshClientManager] of this.machineSSHClientMap.entries()) {
let jobpidPath: string = unixPathJoin(this.getRemoteScriptsPath(rmMeta.username), 'pid'); const jobpidPath: string = unixPathJoin(this.getRemoteScriptsPath(rmMeta.username), 'pid');
let client: Client | undefined = sshClientManager.getFirstSSHClient(); const client: Client | undefined = sshClientManager.getFirstSSHClient();
if(client) { if (client !== undefined) {
await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, client); await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, client);
await SSHClientUtility.remoteExeCommand(`rm -rf ${this.getRemoteScriptsPath(rmMeta.username)}`, client); await SSHClientUtility.remoteExeCommand(`rm -rf ${this.getRemoteScriptsPath(rmMeta.username)}`, client);
} }
sshClientManager.closeAllSSHClient(); sshClientManager.closeAllSSHClient();
} }
}catch (error) { } catch (error) {
//ignore error, this function is called to cleanup remote connections when experiment is stopping //ignore error, this function is called to cleanup remote connections when experiment is stopping
this.log.error(`Cleanup connection exception, error is ${error.message}`); this.log.error(`Cleanup connection exception, error is ${error.message}`);
} }
return Promise.resolve(); return Promise.resolve();
} }
/** /**
* Generate gpu metric collector directory to store temp gpu metric collector script files * Generate gpu metric collector directory to store temp gpu metric collector script files
*/ */
private getLocalGpuMetricCollectorDir(): string { private getLocalGpuMetricCollectorDir(): string {
let userName: string = path.basename(os.homedir()); //get current user name of os const userName: string = path.basename(os.homedir()); //get current user name of os
return path.join(os.tmpdir(), userName, 'nni', 'scripts'); return path.join(os.tmpdir(), userName, 'nni', 'scripts');
} }
/** /**
* Generate gpu metric collector shell script in local machine, * Generate gpu metric collector shell script in local machine,
* used to run in remote machine, and will be deleted after uploaded from local. * used to run in remote machine, and will be deleted after uploaded from local.
*/ */
private async generateGpuMetricsCollectorScript(userName: string): Promise<void> { private async generateGpuMetricsCollectorScript(userName: string): Promise<void> {
let gpuMetricCollectorScriptFolder : string = this.getLocalGpuMetricCollectorDir(); const gpuMetricCollectorScriptFolder : string = this.getLocalGpuMetricCollectorDir();
await execMkdir(path.join(gpuMetricCollectorScriptFolder, userName)); await execMkdir(path.join(gpuMetricCollectorScriptFolder, userName));
//generate gpu_metrics_collector.sh //generate gpu_metrics_collector.sh
let gpuMetricsCollectorScriptPath: string = path.join(gpuMetricCollectorScriptFolder, userName, 'gpu_metrics_collector.sh'); const gpuMetricsCollectorScriptPath: string = path.join(gpuMetricCollectorScriptFolder, userName, 'gpu_metrics_collector.sh');
const remoteGPUScriptsDir: string = this.getRemoteScriptsPath(userName); // This directory is used to store gpu_metrics and pid created by script // This directory is used to store gpu_metrics and pid created by script
const remoteGPUScriptsDir: string = this.getRemoteScriptsPath(userName);
const gpuMetricsCollectorScriptContent: string = String.Format( const gpuMetricsCollectorScriptContent: string = String.Format(
GPU_INFO_COLLECTOR_FORMAT_LINUX, GPU_INFO_COLLECTOR_FORMAT_LINUX,
remoteGPUScriptsDir, remoteGPUScriptsDir,
unixPathJoin(remoteGPUScriptsDir, 'pid'), unixPathJoin(remoteGPUScriptsDir, 'pid')
); );
await fs.promises.writeFile(gpuMetricsCollectorScriptPath, gpuMetricsCollectorScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(gpuMetricsCollectorScriptPath, gpuMetricsCollectorScriptContent, { encoding: 'utf8' });
} }
...@@ -467,39 +478,44 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -467,39 +478,44 @@ class RemoteMachineTrainingService implements TrainingService {
rmMetaList.forEach(async (rmMeta: RemoteMachineMeta) => { rmMetaList.forEach(async (rmMeta: RemoteMachineMeta) => {
rmMeta.occupiedGpuIndexMap = new Map<number, number>(); rmMeta.occupiedGpuIndexMap = new Map<number, number>();
let sshClientManager: SSHClientManager = new SSHClientManager([], this.MAX_TRIAL_NUMBER_PER_SSHCONNECTION, rmMeta); const sshClientManager: SSHClientManager = new SSHClientManager([], this.MAX_TRIAL_NUMBER_PER_SSHCONNECTION, rmMeta);
let sshClient: Client = await sshClientManager.getAvailableSSHClient(); const sshClient: Client = await sshClientManager.getAvailableSSHClient();
this.machineSSHClientMap.set(rmMeta, sshClientManager); this.machineSSHClientMap.set(rmMeta, sshClientManager);
await this.initRemoteMachineOnConnected(rmMeta, sshClient); await this.initRemoteMachineOnConnected(rmMeta, sshClient);
if (++connectedRMNum === rmMetaList.length) { if (++connectedRMNum === rmMetaList.length) {
deferred.resolve(); deferred.resolve();
} }
}); });
return deferred.promise; return deferred.promise;
} }
private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, conn: Client): Promise<void> { private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, conn: Client): Promise<void> {
// Create root working directory after ssh connection is ready // Create root working directory after ssh connection is ready
await this.generateGpuMetricsCollectorScript(rmMeta.username); //generate gpu script in local machine first, will copy to remote machine later // generate gpu script in local machine first, will copy to remote machine later
await this.generateGpuMetricsCollectorScript(rmMeta.username);
const nniRootDir: string = unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni'); const nniRootDir: string = unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni');
await SSHClientUtility.remoteExeCommand(`mkdir -p ${this.remoteExpRootDir}`, conn); await SSHClientUtility.remoteExeCommand(`mkdir -p ${this.remoteExpRootDir}`, conn);
// Copy NNI scripts to remote expeirment working directory // Copy NNI scripts to remote expeirment working directory
const localGpuScriptCollectorDir: string = this.getLocalGpuMetricCollectorDir(); const localGpuScriptCollectorDir: string = this.getLocalGpuMetricCollectorDir();
const remoteGpuScriptCollectorDir: string = this.getRemoteScriptsPath(rmMeta.username); //the directory to store temp scripts in remote machine // the directory to store temp scripts in remote machine
const remoteGpuScriptCollectorDir: string = this.getRemoteScriptsPath(rmMeta.username);
await SSHClientUtility.remoteExeCommand(`mkdir -p ${remoteGpuScriptCollectorDir}`, conn); await SSHClientUtility.remoteExeCommand(`mkdir -p ${remoteGpuScriptCollectorDir}`, conn);
await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn); await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn);
//copy gpu_metrics_collector.sh to remote //copy gpu_metrics_collector.sh to remote
await SSHClientUtility.copyFileToRemote(path.join(localGpuScriptCollectorDir, rmMeta.username, 'gpu_metrics_collector.sh'), unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh'), conn); await SSHClientUtility.copyFileToRemote(path.join(localGpuScriptCollectorDir, rmMeta.username, 'gpu_metrics_collector.sh'),
unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh'), conn);
//Begin to execute gpu_metrics_collection scripts //Begin to execute gpu_metrics_collection scripts
// tslint:disable-next-line: no-floating-promises
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh')}`, conn); SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh')}`, conn);
this.timer.subscribe( this.timer.subscribe(
async (tick: number) => { async (tick: number) => {
const cmdresult: RemoteCommandResult = await SSHClientUtility.remoteExeCommand( const cmdresult: RemoteCommandResult = await SSHClientUtility.remoteExeCommand(
`tail -n 1 ${unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics')}`, conn); `tail -n 1 ${unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics')}`, conn);
if (cmdresult && cmdresult.stdout) { if (cmdresult !== undefined && cmdresult.stdout !== undefined) {
rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult.stdout); rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult.stdout);
} }
} }
...@@ -509,7 +525,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -509,7 +525,7 @@ class RemoteMachineTrainingService implements TrainingService {
private async prepareTrialJob(trialJobId: string): Promise<boolean> { private async prepareTrialJob(trialJobId: string): Promise<boolean> {
const deferred : Deferred<boolean> = new Deferred<boolean>(); const deferred : Deferred<boolean> = new Deferred<boolean>();
if (!this.trialConfig) { if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized'); throw new Error('trial config is not initialized');
} }
const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
...@@ -519,6 +535,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -519,6 +535,7 @@ class RemoteMachineTrainingService implements TrainingService {
// If job is not WATIING, Don't prepare and resolve true immediately // If job is not WATIING, Don't prepare and resolve true immediately
if (trialJobDetail.status !== 'WAITING') { if (trialJobDetail.status !== 'WAITING') {
deferred.resolve(true); deferred.resolve(true);
return deferred.promise; return deferred.promise;
} }
// get an ssh client from scheduler // get an ssh client from scheduler
...@@ -557,7 +574,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -557,7 +574,7 @@ class RemoteMachineTrainingService implements TrainingService {
private async launchTrialOnScheduledMachine(trialJobId: string, trialWorkingFolder: string, form: TrialJobApplicationForm, private async launchTrialOnScheduledMachine(trialJobId: string, trialWorkingFolder: string, form: TrialJobApplicationForm,
rmScheduleInfo: RemoteMachineScheduleInfo): Promise<void> { rmScheduleInfo: RemoteMachineScheduleInfo): Promise<void> {
if (!this.trialConfig) { if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized'); throw new Error('trial config is not initialized');
} }
const cuda_visible_device: string = rmScheduleInfo.cuda_visible_device; const cuda_visible_device: string = rmScheduleInfo.cuda_visible_device;
...@@ -584,18 +601,19 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -584,18 +601,19 @@ class RemoteMachineTrainingService implements TrainingService {
let command: string; let command: string;
// Set CUDA_VISIBLE_DEVICES environment variable based on cuda_visible_device // Set CUDA_VISIBLE_DEVICES environment variable based on cuda_visible_device
// If no valid cuda_visible_device is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device // If no valid cuda_visible_device is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
if(typeof cuda_visible_device === 'string' && cuda_visible_device.length > 0) { if (typeof cuda_visible_device === 'string' && cuda_visible_device.length > 0) {
command = `CUDA_VISIBLE_DEVICES=${cuda_visible_device} ${this.trialConfig.command}`; command = `CUDA_VISIBLE_DEVICES=${cuda_visible_device} ${this.trialConfig.command}`;
} else { } else {
command = `CUDA_VISIBLE_DEVICES=" " ${this.trialConfig.command}`; command = `CUDA_VISIBLE_DEVICES=" " ${this.trialConfig.command}`;
} }
const nniManagerIp = this.nniManagerIpConfig?this.nniManagerIpConfig.nniManagerIp:getIPV4Address(); // tslint:disable-next-line: strict-boolean-expressions
if(!this.remoteRestServerPort) { const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
if (this.remoteRestServerPort === undefined) {
const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer); const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer);
this.remoteRestServerPort = restServer.clusterRestServerPort; this.remoteRestServerPort = restServer.clusterRestServerPort;
} }
const version = this.versionCheck? await getVersion(): ''; const version: string = this.versionCheck ? await getVersion() : '';
const runScriptTrialContent: string = String.Format( const runScriptTrialContent: string = String.Format(
REMOTEMACHINE_TRIAL_COMMAND_FORMAT, REMOTEMACHINE_TRIAL_COMMAND_FORMAT,
trialWorkingFolder, trialWorkingFolder,
...@@ -611,7 +629,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -611,7 +629,7 @@ class RemoteMachineTrainingService implements TrainingService {
version, version,
this.logCollection, this.logCollection,
unixPathJoin(trialWorkingFolder, '.nni', 'code') unixPathJoin(trialWorkingFolder, '.nni', 'code')
) );
//create tmp trial working folder locally. //create tmp trial working folder locally.
await execMkdir(path.join(trialLocalTempFolder, '.nni')); await execMkdir(path.join(trialLocalTempFolder, '.nni'));
...@@ -627,6 +645,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -627,6 +645,7 @@ class RemoteMachineTrainingService implements TrainingService {
// Copy files in codeDir to remote working directory // Copy files in codeDir to remote working directory
await SSHClientUtility.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, sshClient, this.remoteOS); await SSHClientUtility.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, sshClient, this.remoteOS);
// Execute command in remote machine // Execute command in remote machine
// tslint:disable-next-line: no-floating-promises
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient); SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient);
} }
...@@ -636,7 +655,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -636,7 +655,7 @@ class RemoteMachineTrainingService implements TrainingService {
if (sshClientManager === undefined) { if (sshClientManager === undefined) {
throw new Error('sshClient not found.'); throw new Error('sshClient not found.');
} }
let sshClient: Client = sshClientManager.getFirstSSHClient(); const sshClient: Client = sshClientManager.getFirstSSHClient();
const jobId: string = uniqueString(5); const jobId: string = uniqueString(5);
const localDir: string = path.join(this.expRootDir, 'hostjobs-local', jobId); const localDir: string = path.join(this.expRootDir, 'hostjobs-local', jobId);
const remoteDir: string = this.getHostJobRemoteDir(jobId); const remoteDir: string = this.getHostJobRemoteDir(jobId);
...@@ -648,6 +667,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -648,6 +667,7 @@ class RemoteMachineTrainingService implements TrainingService {
await fs.promises.writeFile(path.join(localDir, 'run.sh'), runScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(localDir, 'run.sh'), runScriptContent, { encoding: 'utf8' });
await SSHClientUtility.copyFileToRemote( await SSHClientUtility.copyFileToRemote(
path.join(localDir, 'run.sh'), unixPathJoin(remoteDir, 'run.sh'), sshClient); path.join(localDir, 'run.sh'), unixPathJoin(remoteDir, 'run.sh'), sshClient);
// tslint:disable-next-line: no-floating-promises
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteDir, 'run.sh')}`, sshClient); SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteDir, 'run.sh')}`, sshClient);
const jobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail( const jobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
...@@ -680,8 +700,9 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -680,8 +700,9 @@ class RemoteMachineTrainingService implements TrainingService {
if (killResult !== 0) { if (killResult !== 0) {
const trailReturnCode: string = await SSHClientUtility.getRemoteFileContent(trialReturnCodeFilePath, sshClient); const trailReturnCode: string = await SSHClientUtility.getRemoteFileContent(trialReturnCodeFilePath, sshClient);
this.log.debug(`trailjob ${trialJob.id} return code: ${trailReturnCode}`); this.log.debug(`trailjob ${trialJob.id} return code: ${trailReturnCode}`);
const match: RegExpMatchArray | null = trailReturnCode.trim().match(/^(\d+)\s+(\d+)$/); const match: RegExpMatchArray | null = trailReturnCode.trim()
if (match) { .match(/^(\d+)\s+(\d+)$/);
if (match !== null) {
const { 1: code, 2: timestamp } = match; const { 1: code, 2: timestamp } = match;
// Update trial job's status based on result code // Update trial job's status based on result code
if (parseInt(code, 10) === 0) { if (parseInt(code, 10) === 0) {
...@@ -709,6 +730,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -709,6 +730,7 @@ class RemoteMachineTrainingService implements TrainingService {
deferred.resolve(trialJob); deferred.resolve(trialJob);
} }
} }
return deferred.promise; return deferred.promise;
} }
...@@ -720,7 +742,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -720,7 +742,7 @@ class RemoteMachineTrainingService implements TrainingService {
return unixPathJoin(this.remoteExpRootDir, 'hostjobs', jobId); return unixPathJoin(this.remoteExpRootDir, 'hostjobs', jobId);
} }
private getRemoteExperimentRootDir(): string{ private getRemoteExperimentRootDir(): string {
return unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni', 'experiments', getExperimentId()); return unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni', 'experiments', getExperimentId());
} }
......
...@@ -21,16 +21,16 @@ ...@@ -21,16 +21,16 @@
import * as assert from 'assert'; import * as assert from 'assert';
import * as cpp from 'child-process-promise'; import * as cpp from 'child-process-promise';
import * as path from 'path';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path';
import { Client, ClientChannel, SFTPWrapper } from 'ssh2'; import { Client, ClientChannel, SFTPWrapper } from 'ssh2';
import * as stream from 'stream'; import * as stream from 'stream';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { uniqueString, getRemoteTmpDir, unixPathJoin } from '../../common/utils'; import { getRemoteTmpDir, uniqueString, unixPathJoin } from '../../common/utils';
import { RemoteCommandResult } from './remoteMachineData';
import { execRemove, tarAdd } from '../common/util'; import { execRemove, tarAdd } from '../common/util';
import { RemoteCommandResult } from './remoteMachineData';
/** /**
* *
...@@ -44,7 +44,8 @@ export namespace SSHClientUtility { ...@@ -44,7 +44,8 @@ export namespace SSHClientUtility {
* @param remoteDirectory remote directory * @param remoteDirectory remote directory
* @param sshClient SSH client * @param sshClient SSH client
*/ */
export async function copyDirectoryToRemote(localDirectory : string, remoteDirectory : string, sshClient : Client, remoteOS: string) : Promise<void> { export async function copyDirectoryToRemote(localDirectory : string, remoteDirectory : string, sshClient : Client, remoteOS: string)
: Promise<void> {
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
const tmpTarName: string = `${uniqueString(10)}.tar.gz`; const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
const localTarPath: string = path.join(os.tmpdir(), tmpTarName); const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
...@@ -75,7 +76,7 @@ export namespace SSHClientUtility { ...@@ -75,7 +76,7 @@ export namespace SSHClientUtility {
assert(sshClient !== undefined); assert(sshClient !== undefined);
const deferred: Deferred<boolean> = new Deferred<boolean>(); const deferred: Deferred<boolean> = new Deferred<boolean>();
sshClient.sftp((err : Error, sftp : SFTPWrapper) => { sshClient.sftp((err : Error, sftp : SFTPWrapper) => {
if (err) { if (err !== undefined && err !== null) {
log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`); log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`);
deferred.reject(err); deferred.reject(err);
...@@ -84,7 +85,7 @@ export namespace SSHClientUtility { ...@@ -84,7 +85,7 @@ export namespace SSHClientUtility {
assert(sftp !== undefined); assert(sftp !== undefined);
sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr : Error) => { sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr : Error) => {
sftp.end(); sftp.end();
if (fastPutErr) { if (fastPutErr !== undefined && fastPutErr !== null) {
deferred.reject(fastPutErr); deferred.reject(fastPutErr);
} else { } else {
deferred.resolve(true); deferred.resolve(true);
...@@ -100,6 +101,7 @@ export namespace SSHClientUtility { ...@@ -100,6 +101,7 @@ export namespace SSHClientUtility {
* @param command the command to execute remotely * @param command the command to execute remotely
* @param client SSH Client * @param client SSH Client
*/ */
// tslint:disable:no-unsafe-any no-any
export function remoteExeCommand(command : string, client : Client): Promise<RemoteCommandResult> { export function remoteExeCommand(command : string, client : Client): Promise<RemoteCommandResult> {
const log: Logger = getLogger(); const log: Logger = getLogger();
log.debug(`remoteExeCommand: command: [${command}]`); log.debug(`remoteExeCommand: command: [${command}]`);
...@@ -109,7 +111,7 @@ export namespace SSHClientUtility { ...@@ -109,7 +111,7 @@ export namespace SSHClientUtility {
let exitCode : number; let exitCode : number;
client.exec(command, (err : Error, channel : ClientChannel) => { client.exec(command, (err : Error, channel : ClientChannel) => {
if (err) { if (err !== undefined && err !== null) {
log.error(`remoteExeCommand: ${err.message}`); log.error(`remoteExeCommand: ${err.message}`);
deferred.reject(err); deferred.reject(err);
...@@ -117,13 +119,14 @@ export namespace SSHClientUtility { ...@@ -117,13 +119,14 @@ export namespace SSHClientUtility {
} }
channel.on('data', (data : any, dataStderr : any) => { channel.on('data', (data : any, dataStderr : any) => {
if (dataStderr) { if (dataStderr !== undefined && dataStderr !== null) {
stderr += data.toString(); stderr += data.toString();
} else { } else {
stdout += data.toString(); stdout += data.toString();
} }
}).on('exit', (code, signal) => { })
exitCode = code as number; .on('exit', (code : any, signal : any) => {
exitCode = <number>code;
deferred.resolve({ deferred.resolve({
stdout : stdout, stdout : stdout,
stderr : stderr, stderr : stderr,
...@@ -138,8 +141,9 @@ export namespace SSHClientUtility { ...@@ -138,8 +141,9 @@ export namespace SSHClientUtility {
export function getRemoteFileContent(filePath: string, sshClient: Client): Promise<string> { export function getRemoteFileContent(filePath: string, sshClient: Client): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>(); const deferred: Deferred<string> = new Deferred<string>();
sshClient.sftp((err: Error, sftp : SFTPWrapper) => { sshClient.sftp((err: Error, sftp : SFTPWrapper) => {
if (err) { if (err !== undefined && err !== null) {
getLogger().error(`getRemoteFileContent: ${err.message}`); getLogger()
.error(`getRemoteFileContent: ${err.message}`);
deferred.reject(new Error(`SFTP error: ${err.message}`)); deferred.reject(new Error(`SFTP error: ${err.message}`));
return; return;
...@@ -150,16 +154,19 @@ export namespace SSHClientUtility { ...@@ -150,16 +154,19 @@ export namespace SSHClientUtility {
let dataBuffer: string = ''; let dataBuffer: string = '';
sftpStream.on('data', (data : Buffer | string) => { sftpStream.on('data', (data : Buffer | string) => {
dataBuffer += data; dataBuffer += data;
}).on('error', (streamErr: Error) => { })
.on('error', (streamErr: Error) => {
sftp.end(); sftp.end();
deferred.reject(new NNIError(NNIErrorNames.NOT_FOUND, streamErr.message)); deferred.reject(new NNIError(NNIErrorNames.NOT_FOUND, streamErr.message));
}).on('end', () => { })
.on('end', () => {
// sftp connection need to be released manually once operation is done // sftp connection need to be released manually once operation is done
sftp.end(); sftp.end();
deferred.resolve(dataBuffer); deferred.resolve(dataBuffer);
}); });
} catch (error) { } catch (error) {
getLogger().error(`getRemoteFileContent: ${error.message}`); getLogger()
.error(`getRemoteFileContent: ${error.message}`);
sftp.end(); sftp.end();
deferred.reject(new Error(`SFTP error: ${error.message}`)); deferred.reject(new Error(`SFTP error: ${error.message}`));
} }
...@@ -167,4 +174,5 @@ export namespace SSHClientUtility { ...@@ -167,4 +174,5 @@ export namespace SSHClientUtility {
return deferred.promise; return deferred.promise;
} }
// tslint:enable:no-unsafe-any no-any
} }
...@@ -37,7 +37,7 @@ describe('WebHDFS', function () { ...@@ -37,7 +37,7 @@ describe('WebHDFS', function () {
{ {
"user": "user1", "user": "user1",
"port": 50070, "port": 50070,
"host": "10.0.0.0" "host": "10.0.0.0"
} }
*/ */
let skip: boolean = false; let skip: boolean = false;
...@@ -45,7 +45,7 @@ describe('WebHDFS', function () { ...@@ -45,7 +45,7 @@ describe('WebHDFS', function () {
let hdfsClient: any; let hdfsClient: any;
try { try {
testHDFSInfo = JSON.parse(fs.readFileSync('../../.vscode/hdfsInfo.json', 'utf8')); testHDFSInfo = JSON.parse(fs.readFileSync('../../.vscode/hdfsInfo.json', 'utf8'));
console.log(testHDFSInfo); console.log(testHDFSInfo);
hdfsClient = WebHDFS.createClient({ hdfsClient = WebHDFS.createClient({
user: testHDFSInfo.user, user: testHDFSInfo.user,
port: testHDFSInfo.port, port: testHDFSInfo.port,
...@@ -120,7 +120,7 @@ describe('WebHDFS', function () { ...@@ -120,7 +120,7 @@ describe('WebHDFS', function () {
chai.expect(actualFileData).to.be.equals(testFileData); chai.expect(actualFileData).to.be.equals(testFileData);
const testHDFSDirPath : string = path.join('/nni_unittest_' + uniqueString(6) + '_dir'); const testHDFSDirPath : string = path.join('/nni_unittest_' + uniqueString(6) + '_dir');
await HDFSClientUtility.copyDirectoryToHdfs(tmpLocalDirectoryPath, testHDFSDirPath, hdfsClient); await HDFSClientUtility.copyDirectoryToHdfs(tmpLocalDirectoryPath, testHDFSDirPath, hdfsClient);
const files : any[] = await HDFSClientUtility.readdir(testHDFSDirPath, hdfsClient); const files : any[] = await HDFSClientUtility.readdir(testHDFSDirPath, hdfsClient);
...@@ -133,7 +133,7 @@ describe('WebHDFS', function () { ...@@ -133,7 +133,7 @@ describe('WebHDFS', function () {
// Cleanup // Cleanup
rmdir(tmpLocalDirectoryPath); rmdir(tmpLocalDirectoryPath);
let deleteRestult : boolean = await HDFSClientUtility.deletePath(testHDFSFilePath, hdfsClient); let deleteRestult : boolean = await HDFSClientUtility.deletePath(testHDFSFilePath, hdfsClient);
chai.expect(deleteRestult).to.be.equals(true); chai.expect(deleteRestult).to.be.equals(true);
......
...@@ -63,7 +63,7 @@ describe('Unit Test for KubeflowTrainingService', () => { ...@@ -63,7 +63,7 @@ describe('Unit Test for KubeflowTrainingService', () => {
if (skip) { if (skip) {
return; return;
} }
kubeflowTrainingService = component.get(KubeflowTrainingService); kubeflowTrainingService = component.get(KubeflowTrainingService);
}); });
afterEach(() => { afterEach(() => {
...@@ -78,6 +78,6 @@ describe('Unit Test for KubeflowTrainingService', () => { ...@@ -78,6 +78,6 @@ describe('Unit Test for KubeflowTrainingService', () => {
return; return;
} }
await kubeflowTrainingService.setClusterMetadata(TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG, testKubeflowConfig), await kubeflowTrainingService.setClusterMetadata(TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG, testKubeflowConfig),
await kubeflowTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, testKubeflowTrialConfig); await kubeflowTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, testKubeflowTrialConfig);
}); });
}); });
\ No newline at end of file
...@@ -63,7 +63,7 @@ describe('Unit Test for LocalTrainingService', () => { ...@@ -63,7 +63,7 @@ describe('Unit Test for LocalTrainingService', () => {
//trial jobs should be empty, since there are no submitted jobs //trial jobs should be empty, since there are no submitted jobs
chai.expect(await localTrainingService.listTrialJobs()).to.be.empty; chai.expect(await localTrainingService.listTrialJobs()).to.be.empty;
}); });
it('setClusterMetadata and getClusterMetadata', async () => { it('setClusterMetadata and getClusterMetadata', async () => {
await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig); await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig);
localTrainingService.getClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG).then((data)=>{ localTrainingService.getClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG).then((data)=>{
...@@ -87,7 +87,7 @@ describe('Unit Test for LocalTrainingService', () => { ...@@ -87,7 +87,7 @@ describe('Unit Test for LocalTrainingService', () => {
await localTrainingService.cancelTrialJob(jobDetail.id); await localTrainingService.cancelTrialJob(jobDetail.id);
chai.expect(jobDetail.status).to.be.equals('USER_CANCELED'); chai.expect(jobDetail.status).to.be.equals('USER_CANCELED');
}).timeout(20000); }).timeout(20000);
it('Read metrics, Add listener, and remove listener', async () => { it('Read metrics, Add listener, and remove listener', async () => {
// set meta data // set meta data
const trialConfig: string = `{\"command\":\"python3 mockedTrial.py\", \"codeDir\":\"${localCodeDir}\",\"gpuNum\":0}` const trialConfig: string = `{\"command\":\"python3 mockedTrial.py\", \"codeDir\":\"${localCodeDir}\",\"gpuNum\":0}`
......
...@@ -89,7 +89,7 @@ describe('Unit Test for PAITrainingService', () => { ...@@ -89,7 +89,7 @@ describe('Unit Test for PAITrainingService', () => {
chai.expect(trialDetail.status).to.be.equals('WAITING'); chai.expect(trialDetail.status).to.be.equals('WAITING');
} catch(error) { } catch(error) {
console.log('Submit job failed:' + error); console.log('Submit job failed:' + error);
chai.assert(error) chai.assert(error)
} }
}); });
}); });
\ No newline at end of file
...@@ -9,7 +9,10 @@ ...@@ -9,7 +9,10 @@
"no-increment-decrement": false, "no-increment-decrement": false,
"promise-function-async": false, "promise-function-async": false,
"no-console": [true, "log"], "no-console": [true, "log"],
"no-multiline-string": false "no-multiline-string": false,
"no-suspicious-comment": false,
"no-backbone-get-set-outside-model": false,
"max-classes-per-file": false
}, },
"rulesDirectory": [], "rulesDirectory": [],
"linterOptions": { "linterOptions": {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment