Unverified Commit 9f4485c1 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #130 from Microsoft/master

merge master
parents 1ee97350 51fbf695
...@@ -71,6 +71,7 @@ interface TrialJobDetail { ...@@ -71,6 +71,7 @@ interface TrialJobDetail {
readonly workingDirectory: string; readonly workingDirectory: string;
readonly form: JobApplicationForm; readonly form: JobApplicationForm;
readonly sequenceId: number; readonly sequenceId: number;
isEarlyStopped?: boolean;
} }
interface HostJobDetail { interface HostJobDetail {
......
...@@ -99,7 +99,25 @@ class MockedDataStore implements DataStore { ...@@ -99,7 +99,25 @@ class MockedDataStore implements DataStore {
private dbTrialJobs: SimpleDb = new SimpleDb('trial_jobs', './trial_jobs.json'); private dbTrialJobs: SimpleDb = new SimpleDb('trial_jobs', './trial_jobs.json');
private dbMetrics: SimpleDb = new SimpleDb('metrics', './metrics.json'); private dbMetrics: SimpleDb = new SimpleDb('metrics', './metrics.json');
trailJob1 = {
event: 'ADD_CUSTOMIZED',
timestamp: Date.now(),
trialJobId: "4321",
data: ''
}
metrics1 = {
timestamp: Date.now(),
trialJobId: '4321',
parameterId: 'param1',
type: 'CUSTOM',
sequence: 21,
data: ''
}
init(): Promise<void> { init(): Promise<void> {
this.dbTrialJobs.saveData(this.trailJob1);
this.dbMetrics.saveData(this.metrics1);
return Promise.resolve(); return Promise.resolve();
} }
......
...@@ -19,25 +19,27 @@ ...@@ -19,25 +19,27 @@
'use strict'; 'use strict';
import * as os from 'os';
import { assert, expect } from 'chai'; import { assert, expect } from 'chai';
import { Container, Scope } from 'typescript-ioc'; import { Container, Scope } from 'typescript-ioc';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { Database, DataStore } from '../../common/datastore'; import { Database, DataStore } from '../../common/datastore';
import { Manager } from '../../common/manager'; import { Manager, ExperimentProfile} from '../../common/manager';
import { TrainingService } from '../../common/trainingService'; import { TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils'; import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { NNIDataStore } from '../nniDataStore'; import { NNIDataStore } from '../nniDataStore';
import { NNIManager } from '../nnimanager'; import { NNIManager } from '../nnimanager';
import { SqlDB } from '../sqlDatabase'; import { SqlDB } from '../sqlDatabase';
import { MockedTrainingService } from './mockedTrainingService'; import { MockedTrainingService } from './mockedTrainingService';
import { MockedDataStore } from './mockedDatastore';
async function initContainer(): Promise<void> { async function initContainer(): Promise<void> {
prepareUnitTest(); prepareUnitTest();
Container.bind(TrainingService).to(MockedTrainingService).scope(Scope.Singleton); Container.bind(TrainingService).to(MockedTrainingService).scope(Scope.Singleton);
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton); Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton); Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(NNIDataStore).scope(Scope.Singleton); Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton);
await component.get<DataStore>(DataStore).init(); await component.get<DataStore>(DataStore).init();
} }
...@@ -51,9 +53,9 @@ describe('Unit test for nnimanager', function () { ...@@ -51,9 +53,9 @@ describe('Unit test for nnimanager', function () {
let experimentParams = { let experimentParams = {
authorName: 'zql', authorName: 'zql',
experimentName: 'naive_experiment', experimentName: 'naive_experiment',
trialConcurrency: 2, trialConcurrency: 3,
maxExecDuration: 5, maxExecDuration: 5,
maxTrialNum: 2, maxTrialNum: 3,
trainingServicePlatform: 'local', trainingServicePlatform: 'local',
searchSpace: '{"x":1}', searchSpace: '{"x":1}',
tuner: { tuner: {
...@@ -71,36 +73,74 @@ describe('Unit test for nnimanager', function () { ...@@ -71,36 +73,74 @@ describe('Unit test for nnimanager', function () {
} }
} }
let updateExperimentParams = {
authorName: '',
experimentName: 'another_experiment',
trialConcurrency: 2,
maxExecDuration: 6,
maxTrialNum: 2,
trainingServicePlatform: 'local',
searchSpace: '{"y":2}',
tuner: {
className: 'TPE',
classArgs: {
optimize_mode: 'maximize'
},
checkpointDir: '',
gpuNum: 0
},
assessor: {
className: 'Medianstop',
checkpointDir: '',
gpuNum: 1
}
}
let experimentProfile = {
params: updateExperimentParams,
id: 'test',
execDuration: 0,
maxSequenceId: 0,
revision: 0
}
before(async () => { before(async () => {
await initContainer(); await initContainer();
nniManager = component.get(Manager); nniManager = component.get(Manager);
const expId: string = await nniManager.startExperiment(experimentParams); const expId: string = await nniManager.startExperiment(experimentParams);
assert(expId); assert.strictEqual(expId, 'unittest');
}); })
after(async () => { after(async () => {
await nniManager.stopExperiment(); await setTimeout(() => {nniManager.stopExperiment()},15000);
cleanupUnitTest(); cleanupUnitTest();
}) })
it('test resumeExperiment', () => {
//TODO: add resume experiment unit test
it('test addCustomizedTrialJob', () => {
return nniManager.addCustomizedTrialJob('hyperParams').then(() => {
}).catch((error) => {
assert.fail(error);
}) })
})
it('test listTrialJobs', () => { it('test listTrialJobs', () => {
//FIXME: not implemented return nniManager.listTrialJobs().then(function (trialjobdetails) {
//return nniManager.listTrialJobs().then(function (trialJobDetails) { expect(trialjobdetails.length).to.be.equal(2);
// expect(trialJobDetails.length).to.be.equal(2); }).catch((error) => {
//}).catch(function (error) { assert.fail(error);
// assert.fail(error); })
//})
}) })
it('test getTrialJob valid', () => { it('test getTrialJob valid', () => {
//query a exist id //query a exist id
return nniManager.getTrialJob('1234').then(function (trialJobDetail) { return nniManager.getTrialJob('1234').then(function (trialJobDetail) {
expect(trialJobDetail.id).to.be.equal('1234'); expect(trialJobDetail.id).to.be.equal('1234');
}).catch(function (error) { }).catch((error) => {
assert.fail(error); assert.fail(error);
}) })
}) })
...@@ -132,7 +172,6 @@ describe('Unit test for nnimanager', function () { ...@@ -132,7 +172,6 @@ describe('Unit test for nnimanager', function () {
}) })
}) })
//TODO: complete ut
it('test cancelTrialJobByUser', () => { it('test cancelTrialJobByUser', () => {
return nniManager.cancelTrialJobByUser('1234').then(() => { return nniManager.cancelTrialJobByUser('1234').then(() => {
...@@ -141,11 +180,112 @@ describe('Unit test for nnimanager', function () { ...@@ -141,11 +180,112 @@ describe('Unit test for nnimanager', function () {
}) })
}) })
it('test addCustomizedTrialJob', () => { it('test getExperimentProfile', () => {
return nniManager.addCustomizedTrialJob('hyperParams').then(() => { return nniManager.getExperimentProfile().then((experimentProfile) => {
expect(experimentProfile.id).to.be.equal('unittest');
expect(experimentProfile.logDir).to.be.equal(os.homedir()+'/nni/experiments/unittest');
}).catch((error) => {
assert.fail(error);
})
})
it('test updateExperimentProfile TRIAL_CONCURRENCY', () => {
return nniManager.updateExperimentProfile(experimentProfile, 'TRIAL_CONCURRENCY').then(() => {
nniManager.getExperimentProfile().then((updateProfile) => {
expect(updateProfile.params.trialConcurrency).to.be.equal(2);
});
}).catch((error) => {
assert.fail(error);
})
})
it('test updateExperimentProfile MAX_EXEC_DURATION', () => {
return nniManager.updateExperimentProfile(experimentProfile, 'MAX_EXEC_DURATION').then(() => {
nniManager.getExperimentProfile().then((updateProfile) => {
expect(updateProfile.params.maxExecDuration).to.be.equal(6);
});
}).catch((error) => {
assert.fail(error);
})
})
it('test updateExperimentProfile SEARCH_SPACE', () => {
return nniManager.updateExperimentProfile(experimentProfile, 'SEARCH_SPACE').then(() => {
nniManager.getExperimentProfile().then((updateProfile) => {
expect(updateProfile.params.searchSpace).to.be.equal('{"y":2}');
});
}).catch((error) => {
assert.fail(error);
})
})
it('test updateExperimentProfile MAX_TRIAL_NUM', () => {
return nniManager.updateExperimentProfile(experimentProfile, 'MAX_TRIAL_NUM').then(() => {
nniManager.getExperimentProfile().then((updateProfile) => {
expect(updateProfile.params.maxTrialNum).to.be.equal(2);
});
}).catch((error) => { }).catch((error) => {
assert.fail(error); assert.fail(error);
}) })
}) })
it('test getStatus', () => {
assert.strictEqual(nniManager.getStatus().status,'RUNNING');
})
it('test getMetricData with trialJobId', () => {
//query a exist trialJobId
return nniManager.getMetricData('4321', 'CUSTOM').then((metricData) => {
expect(metricData.length).to.be.equal(1);
expect(metricData[0].trialJobId).to.be.equal('4321');
expect(metricData[0].parameterId).to.be.equal('param1');
}).catch((error) => {
assert.fail(error);
})
})
it('test getMetricData with invalid trialJobId', () => {
//query an invalid trialJobId
return nniManager.getMetricData('43210', 'CUSTOM').then((metricData) => {
assert.fail();
}).catch((error) => {
})
})
it('test getTrialJobStatistics', () => {
// get 3 trial jobs (init, addCustomizedTrialJob, cancelTrialJobByUser)
return nniManager.getTrialJobStatistics().then(function (trialJobStatistics) {
expect(trialJobStatistics.length).to.be.equal(2);
if (trialJobStatistics[0].trialJobStatus === 'WAITING') {
expect(trialJobStatistics[0].trialJobNumber).to.be.equal(2);
expect(trialJobStatistics[1].trialJobNumber).to.be.equal(1);
}
else {
expect(trialJobStatistics[1].trialJobNumber).to.be.equal(2);
expect(trialJobStatistics[0].trialJobNumber).to.be.equal(1);
}
}).catch((error) => {
assert.fail(error);
})
})
it('test addCustomizedTrialJob reach maxTrialNum', () => {
// test currSubmittedTrialNum reach maxTrialNum
return nniManager.addCustomizedTrialJob('hyperParam').then(() => {
nniManager.getTrialJobStatistics().then(function (trialJobStatistics) {
if (trialJobStatistics[0].trialJobStatus === 'WAITING')
expect(trialJobStatistics[0].trialJobNumber).to.be.equal(2);
else
expect(trialJobStatistics[1].trialJobNumber).to.be.equal(2);
})
}).catch((error) => {
assert.fail(error);
})
})
it('test resumeExperiment', async () => {
//TODO: add resume experiment unit test
})
}) })
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"scripts": { "scripts": {
"postbuild": "cp -rf config ./dist/", "postbuild": "cp -rf config ./dist/",
"build": "tsc", "build": "tsc",
"test": "nyc mocha -r ts-node/register -t 15000 --recursive **/*.test.ts --exclude node_modules/**/**/*.test.ts --exclude core/test/nnimanager.test.ts --colors", "test": "nyc mocha -r ts-node/register -t 15000 --recursive **/*.test.ts --exclude node_modules/**/**/*.test.ts --colors",
"start": "node dist/main.js", "start": "node dist/main.js",
"tslint": "tslint -p ." "tslint": "tslint -p ."
}, },
......
...@@ -46,6 +46,7 @@ export class KubeflowJobInfoCollector extends KubernetesJobInfoCollector{ ...@@ -46,6 +46,7 @@ export class KubeflowJobInfoCollector extends KubernetesJobInfoCollector{
try { try {
kubernetesJobInfo = await kubernetesCRDClient.getKubernetesJob(kubernetesTrialJob.kubernetesJobName); kubernetesJobInfo = await kubernetesCRDClient.getKubernetesJob(kubernetesTrialJob.kubernetesJobName);
} catch(error) { } catch(error) {
// Notice: it maynot be a 'real' error since cancel trial job can also cause getKubernetesJob failed.
this.log.error(`Get job ${kubernetesTrialJob.kubernetesJobName} info failed, error is ${error}`); this.log.error(`Get job ${kubernetesTrialJob.kubernetesJobName} info failed, error is ${error}`);
//This is not treat as a error status //This is not treat as a error status
return Promise.resolve(); return Promise.resolve();
......
...@@ -255,7 +255,7 @@ class LocalTrainingService implements TrainingService { ...@@ -255,7 +255,7 @@ class LocalTrainingService implements TrainingService {
} }
if (trialJob.pid === undefined){ if (trialJob.pid === undefined){
this.setTrialJobStatus(trialJob, 'USER_CANCELED'); this.setTrialJobStatus(trialJob, 'USER_CANCELED');
return; return Promise.resolve();
} }
if (trialJob.form.jobType === 'TRIAL') { if (trialJob.form.jobType === 'TRIAL') {
await tkill(trialJob.pid, 'SIGKILL'); await tkill(trialJob.pid, 'SIGKILL');
...@@ -265,6 +265,7 @@ class LocalTrainingService implements TrainingService { ...@@ -265,6 +265,7 @@ class LocalTrainingService implements TrainingService {
throw new Error(`Job type not supported: ${trialJob.form.jobType}`); throw new Error(`Job type not supported: ${trialJob.form.jobType}`);
} }
this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped)); this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped));
return Promise.resolve();
} }
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
......
...@@ -34,6 +34,7 @@ export class PAITrialJobDetail implements TrialJobDetail { ...@@ -34,6 +34,7 @@ export class PAITrialJobDetail implements TrialJobDetail {
public form: JobApplicationForm; public form: JobApplicationForm;
public sequenceId: number; public sequenceId: number;
public hdfsLogPath: string; public hdfsLogPath: string;
public isEarlyStopped?: boolean;
constructor(id: string, status: TrialJobStatus, paiJobName : string, constructor(id: string, status: TrialJobStatus, paiJobName : string,
submitTime: number, workingDirectory: string, form: JobApplicationForm, sequenceId: number, hdfsLogPath: string) { submitTime: number, workingDirectory: string, form: JobApplicationForm, sequenceId: number, hdfsLogPath: string) {
...@@ -63,7 +64,7 @@ export const PAI_TRIAL_COMMAND_FORMAT: string = ...@@ -63,7 +64,7 @@ export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} `export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4}
&& cd $NNI_SYS_DIR && sh install_nni.sh && cd $NNI_SYS_DIR && sh install_nni.sh
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}' && python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}'
--pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}'`; --pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1'`;
export const PAI_OUTPUT_DIR_FORMAT: string = export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`; `hdfs://{0}:9000/`;
......
...@@ -103,8 +103,12 @@ export class PAIJobInfoCollector { ...@@ -103,8 +103,12 @@ export class PAIJobInfoCollector {
paiTrialJob.status = 'SUCCEEDED'; paiTrialJob.status = 'SUCCEEDED';
break; break;
case 'STOPPED': case 'STOPPED':
if (paiTrialJob.status !== 'EARLY_STOPPED') { if (paiTrialJob.isEarlyStopped !== undefined) {
paiTrialJob.status = 'USER_CANCELED'; paiTrialJob.status = paiTrialJob.isEarlyStopped === true ?
'EARLY_STOPPED' : 'USER_CANCELED';
} else {
// if paiTrialJob's isEarlyStopped is undefined, that mean we didn't stop it via cancellation, mark it as SYS_CANCELLED by PAI
paiTrialJob.status = 'SYS_CANCELED';
} }
break; break;
case 'FAILED': case 'FAILED':
......
...@@ -324,14 +324,15 @@ class PAITrainingService implements TrainingService { ...@@ -324,14 +324,15 @@ class PAITrainingService implements TrainingService {
"Authorization": 'Bearer ' + this.paiToken "Authorization": 'Bearer ' + this.paiToken
} }
}; };
// Set trialjobDetail's early stopped field, to mark the job's cancellation source
trialJobDetail.isEarlyStopped = isEarlyStopped;
request(stopJobRequest, (error: Error, response: request.Response, body: any) => { request(stopJobRequest, (error: Error, response: request.Response, body: any) => {
if (error || response.statusCode >= 400) { if (error || response.statusCode >= 400) {
this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`); this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
deferred.reject(error ? error.message : 'Stop trial failed, http code: ' + response.statusCode); deferred.reject(error ? error.message : 'Stop trial failed, http code: ' + response.statusCode);
} else { } else {
if (isEarlyStopped) {
trialJobDetail.status = 'EARLY_STOPPED';
}
deferred.resolve(); deferred.resolve();
} }
}); });
......
...@@ -80,6 +80,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -80,6 +80,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public form: JobApplicationForm; public form: JobApplicationForm;
public sequenceId: number; public sequenceId: number;
public rmMeta?: RemoteMachineMeta; public rmMeta?: RemoteMachineMeta;
public isEarlyStopped?: boolean;
constructor(id: string, status: TrialJobStatus, submitTime: number, constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: JobApplicationForm, sequenceId: number) { workingDirectory: string, form: JobApplicationForm, sequenceId: number) {
...@@ -114,7 +115,7 @@ export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={ ...@@ -114,7 +115,7 @@ export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={
cd $NNI_SYS_DIR cd $NNI_SYS_DIR
sh install_nni.sh sh install_nni.sh
echo $$ >{6} echo $$ >{6}
python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr
echo $? \`date +%s%3N\` >{10}`; echo $? \`date +%s%3N\` >{10}`;
export const HOST_JOB_SHELL_FORMAT: string = export const HOST_JOB_SHELL_FORMAT: string =
......
...@@ -48,7 +48,7 @@ import { ...@@ -48,7 +48,7 @@ import {
GPU_COLLECTOR_FORMAT GPU_COLLECTOR_FORMAT
} from './remoteMachineData'; } from './remoteMachineData';
import { SSHClientUtility } from './sshClientUtility'; import { SSHClientUtility } from './sshClientUtility';
import { validateCodeDir} from '../common/util'; import { validateCodeDir } from '../common/util';
import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer'; import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { mkDirP } from '../../common/utils'; import { mkDirP } from '../../common/utils';
...@@ -279,8 +279,9 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -279,8 +279,9 @@ class RemoteMachineTrainingService implements TrainingService {
const jobpidPath: string = this.getJobPidPath(trialJob.id); const jobpidPath: string = this.getJobPidPath(trialJob.id);
try { try {
// Mark the toEarlyStop tag here
trialJob.isEarlyStopped = isEarlyStopped;
await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, sshClient); await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, sshClient);
trialJob.status = getJobCancelStatus(isEarlyStopped);
} catch (error) { } catch (error) {
// Not handle the error since pkill failed will not impact trial job's current status // Not handle the error since pkill failed will not impact trial job's current status
this.log.error(`remoteTrainingService.cancelTrialJob: ${error.message}`); this.log.error(`remoteTrainingService.cancelTrialJob: ${error.message}`);
...@@ -482,6 +483,11 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -482,6 +483,11 @@ class RemoteMachineTrainingService implements TrainingService {
if (trialJobDetail === undefined) { if (trialJobDetail === undefined) {
throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${trialJobId}`); throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${trialJobId}`);
} }
// If job is not WATIING, Don't prepare and resolve true immediately
if (trialJobDetail.status !== 'WAITING') {
deferred.resolve(true);
return deferred.promise;
}
// get an ssh client from scheduler // get an ssh client from scheduler
const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobId); const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobId);
if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) { if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) {
...@@ -640,7 +646,12 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -640,7 +646,12 @@ class RemoteMachineTrainingService implements TrainingService {
if (parseInt(code, 10) === 0) { if (parseInt(code, 10) === 0) {
trialJob.status = 'SUCCEEDED'; trialJob.status = 'SUCCEEDED';
} else { } else {
// isEarlyStopped is never set, mean it's not cancelled by NNI, so if the process's exit code >0, mark it as FAILED
if (trialJob.isEarlyStopped === undefined) {
trialJob.status = 'FAILED'; trialJob.status = 'FAILED';
} else {
trialJob.status = getJobCancelStatus(trialJob.isEarlyStopped);
}
} }
trialJob.endTime = parseInt(timestamp, 10); trialJob.endTime = parseInt(timestamp, 10);
} }
......
...@@ -19,14 +19,106 @@ ...@@ -19,14 +19,106 @@
'use strict'; 'use strict';
import { TrainingService } from '../../common/trainingService'; import * as assert from 'assert';
import { LocalTrainingService } from '../local/localTrainingService'; import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
import * as fs from 'fs';
import * as tmp from 'tmp';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { TrialJobApplicationForm, TrialJobDetail, TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, delay, prepareUnitTest } from '../../common/utils';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { LocalTrainingServiceForGPU } from '../local/localTrainingServiceForGPU';
// TODO: copy mockedTrail.py to local folder
const localCodeDir: string = tmp.dirSync().name
const mockedTrialPath: string = './training_service/test/mockedTrial.py'
fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py')
describe('Unit Test for LocalTrainingService', () => { describe('Unit Test for LocalTrainingService', () => {
let trainingService: TrainingService let trialConfig: any = `{"command":"sleep 1h && echo hello","codeDir":"${localCodeDir}","gpuNum":1}`
let localTrainingService: LocalTrainingServiceForGPU;
before(() => {
chai.should();
chai.use(chaiAsPromised);
prepareUnitTest();
});
after(() => {
cleanupUnitTest();
});
beforeEach(() => {
localTrainingService = component.get(LocalTrainingServiceForGPU);
localTrainingService.run();
});
afterEach(() => {
localTrainingService.cleanUp();
});
it('List empty trial jobs', async () => {
//trial jobs should be empty, since there are no submitted jobs
chai.expect(await localTrainingService.listTrialJobs()).to.be.empty;
});
it('setClusterMetadata and getClusterMetadata', async () => {
await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig);
localTrainingService.getClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG).then((data)=>{
chai.expect(data).to.be.equals(trialConfig);
});
});
it('Submit job and Cancel job', async () => {
await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig);
// submit job
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
const jobDetail: TrialJobDetail = await localTrainingService.submitTrialJob(form);
chai.expect(jobDetail.status).to.be.equals('WAITING');
await localTrainingService.cancelTrialJob(jobDetail.id);
chai.expect(jobDetail.status).to.be.equals('USER_CANCELED');
}).timeout(20000);
it('Read metrics, Add listener, and remove listener', async () => {
// set meta data
const trialConfig: string = `{\"command\":\"python3 mockedTrial.py\", \"codeDir\":\"${localCodeDir}\",\"gpuNum\":0}`
await localTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig);
// submit job
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
const jobDetail: TrialJobDetail = await localTrainingService.submitTrialJob(form);
chai.expect(jobDetail.status).to.be.equals('WAITING');
localTrainingService.listTrialJobs().then((jobList)=>{
chai.expect(jobList.length).to.be.equals(1);
});
// Add metrics listeners
const listener1 = function f1(metric: any) {
chai.expect(metric.id).to.be.equals(jobDetail.id);
}
localTrainingService.addTrialJobMetricListener(listener1);
// Wait to collect metric
await delay(1000);
await localTrainingService.cancelTrialJob(jobDetail.id);
localTrainingService.removeTrialJobMetricListener(listener1);
}).timeout(20000);
beforeEach(async () => { it('Test multiphaseSupported', () => {
trainingService = component.get(LocalTrainingService); chai.expect(localTrainingService.isMultiPhaseJobSupported).to.be.equals(true)
}) })
}); });
\ No newline at end of file
...@@ -182,6 +182,7 @@ class SlideBar extends React.Component<{}, SliderState> { ...@@ -182,6 +182,7 @@ class SlideBar extends React.Component<{}, SliderState> {
render() { render() {
const { version, menuVisible } = this.state; const { version, menuVisible } = this.state;
const feed = `https://github.com/Microsoft/nni/issues/new?labels=${version}`;
const menu = ( const menu = (
<Menu onClick={this.handleMenuClick}> <Menu onClick={this.handleMenuClick}>
<Menu.Item key="1">Experiment Parameters</Menu.Item> <Menu.Item key="1">Experiment Parameters</Menu.Item>
...@@ -221,7 +222,7 @@ class SlideBar extends React.Component<{}, SliderState> { ...@@ -221,7 +222,7 @@ class SlideBar extends React.Component<{}, SliderState> {
Download <Icon type="down" /> Download <Icon type="down" />
</a> </a>
</Dropdown> </Dropdown>
<a href="https://github.com/Microsoft/nni/issues/new?labels=v0.5.1" target="_blank"> <a href={feed} target="_blank">
<img <img
src={require('../static/img/icon/issue.png')} src={require('../static/img/icon/issue.png')}
alt="NNI github issue" alt="NNI github issue"
......
...@@ -38,7 +38,6 @@ def gen_new_config(config_file, training_service='local'): ...@@ -38,7 +38,6 @@ def gen_new_config(config_file, training_service='local'):
new_config_file = config_file + '.tmp' new_config_file = config_file + '.tmp'
ts = get_yml_content('training_service.yml')[training_service] ts = get_yml_content('training_service.yml')[training_service]
print(config)
print(ts) print(ts)
# hack for kubeflow trial config # hack for kubeflow trial config
...@@ -64,7 +63,6 @@ def run_test(config_file, training_service, local_gpu=False): ...@@ -64,7 +63,6 @@ def run_test(config_file, training_service, local_gpu=False):
return return
try: try:
print('Testing %s...' % config_file)
proc = subprocess.run(['nnictl', 'create', '--config', new_config_file]) proc = subprocess.run(['nnictl', 'create', '--config', new_config_file])
assert proc.returncode == 0, '`nnictl create` failed with code %d' % proc.returncode assert proc.returncode == 0, '`nnictl create` failed with code %d' % proc.returncode
...@@ -109,8 +107,10 @@ def run(args): ...@@ -109,8 +107,10 @@ def run(args):
try: try:
# sleep 5 seconds here, to make sure previous stopped exp has enough time to exit to avoid port conflict # sleep 5 seconds here, to make sure previous stopped exp has enough time to exit to avoid port conflict
time.sleep(5) time.sleep(5)
print(GREEN + 'Testing:' + config_file + CLEAR)
begin_time = time.time()
run_test(config_file, args.ts, args.local_gpu) run_test(config_file, args.ts, args.local_gpu)
print(GREEN + 'Test %s: TEST PASS' % (config_file) + CLEAR) print(GREEN + 'Test %s: TEST PASS IN %d mins' % (config_file, (time.time() - begin_time)/60) + CLEAR)
except Exception as error: except Exception as error:
print(RED + 'Test %s: TEST FAIL' % (config_file) + CLEAR) print(RED + 'Test %s: TEST FAIL' % (config_file) + CLEAR)
print('%r' % error) print('%r' % error)
......
jobs: jobs:
- job: 'integration_test_kubeflow' - job: 'integration_test_kubeflow'
timeoutInMinutes: 0
pool: 'NNI CI KUBE CLI' pool: 'NNI CI KUBE CLI'
variables: variables:
......
jobs: jobs:
- job: 'integration_test_pai' - job: 'integration_test_pai'
timeoutInMinutes: 0
pool: 'NNI CI PAI CLI' pool: 'NNI CI PAI CLI'
variables: variables:
......
...@@ -168,7 +168,9 @@ def set_remote_config(experiment_config, port, config_file_name): ...@@ -168,7 +168,9 @@ def set_remote_config(experiment_config, port, config_file_name):
with open(stderr_full_path, 'a+') as fout: with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))) fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message return False, err_message
result, message = setNNIManagerIp(experiment_config, port, config_file_name)
if not result:
return result, message
#set trial_config #set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message return set_trial_config(experiment_config, port, config_file_name), err_message
......
...@@ -48,10 +48,25 @@ def main_loop(args): ...@@ -48,10 +48,25 @@ def main_loop(args):
# redirect trial keeper's stdout and stderr to syslog # redirect trial keeper's stdout and stderr to syslog
trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout) trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout)
sys.stdout = sys.stderr = trial_keeper_syslogger sys.stdout = sys.stderr = trial_keeper_syslogger
# backward compatibility
if args.pai_hdfs_host is not None and args.nni_hdfs_exp_dir is not None: hdfs_host = None
hdfs_output_dir = None
if args.hdfs_host:
hdfs_host = args.hdfs_host
elif args.pai_hdfs_host:
hdfs_host = args.pai_hdfs_host
if args.hdfs_output_dir:
hdfs_output_dir = args.hdfs_output_dir
elif args.pai_hdfs_output_dir:
hdfs_output_dir = args.pai_hdfs_output_dir
if hdfs_host is not None and args.nni_hdfs_exp_dir is not None:
try: try:
hdfs_client = HdfsClient(hosts='{0}:{1}'.format(args.pai_hdfs_host, '50070'), user_name=args.pai_user_name, timeout=5) if args.webhdfs_path:
hdfs_client = HdfsClient(hosts='{0}:80'.format(hdfs_host), user_name=args.pai_user_name, webhdfs_path=args.webhdfs_path, timeout=5)
else:
# backward compatibility
hdfs_client = HdfsClient(hosts='{0}:{1}'.format(hdfs_host, '50070'), user_name=args.pai_user_name, timeout=5)
except Exception as e: except Exception as e:
nni_log(LogType.Error, 'Create HDFS client error: ' + str(e)) nni_log(LogType.Error, 'Create HDFS client error: ' + str(e))
raise e raise e
...@@ -67,14 +82,14 @@ def main_loop(args): ...@@ -67,14 +82,14 @@ def main_loop(args):
# child worker process exits and all stdout data is read # child worker process exits and all stdout data is read
if retCode is not None and log_pipe_stdout.set_process_exit() and log_pipe_stdout.is_read_completed == True: if retCode is not None and log_pipe_stdout.set_process_exit() and log_pipe_stdout.is_read_completed == True:
nni_log(LogType.Info, 'subprocess terminated. Exit code is {}. Quit'.format(retCode)) nni_log(LogType.Info, 'subprocess terminated. Exit code is {}. Quit'.format(retCode))
if args.pai_hdfs_output_dir is not None: if hdfs_output_dir is not None:
# Copy local directory to hdfs for OpenPAI # Copy local directory to hdfs for OpenPAI
nni_local_output_dir = os.environ['NNI_OUTPUT_DIR'] nni_local_output_dir = os.environ['NNI_OUTPUT_DIR']
try: try:
if copyDirectoryToHdfs(nni_local_output_dir, args.pai_hdfs_output_dir, hdfs_client): if copyDirectoryToHdfs(nni_local_output_dir, hdfs_output_dir, hdfs_client):
nni_log(LogType.Info, 'copy directory from {0} to {1} success!'.format(nni_local_output_dir, args.pai_hdfs_output_dir)) nni_log(LogType.Info, 'copy directory from {0} to {1} success!'.format(nni_local_output_dir, hdfs_output_dir))
else: else:
nni_log(LogType.Info, 'copy directory from {0} to {1} failed!'.format(nni_local_output_dir, args.pai_hdfs_output_dir)) nni_log(LogType.Info, 'copy directory from {0} to {1} failed!'.format(nni_local_output_dir, hdfs_output_dir))
except Exception as e: except Exception as e:
nni_log(LogType.Error, 'HDFS copy directory got exception: ' + str(e)) nni_log(LogType.Error, 'HDFS copy directory got exception: ' + str(e))
raise e raise e
...@@ -95,10 +110,13 @@ if __name__ == '__main__': ...@@ -95,10 +110,13 @@ if __name__ == '__main__':
PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process') PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP') PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP')
PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port') PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port')
PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of hdfs') PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of pai_hdfs') # backward compatibility
PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of hdfs') PARSER.add_argument('--hdfs_output_dir', type=str, help='the output dir of hdfs')
PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of pai_hdfs') # backward compatibility
PARSER.add_argument('--hdfs_host', type=str, help='the host of hdfs')
PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs') PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')
PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs') PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs')
PARSER.add_argument('--webhdfs_path', type=str, help='the webhdfs path used in webhdfs URL')
args, unknown = PARSER.parse_known_args() args, unknown = PARSER.parse_known_args()
if args.trial_command is None: if args.trial_command is None:
exit(1) exit(1)
......
...@@ -12,7 +12,7 @@ setuptools.setup( ...@@ -12,7 +12,7 @@ setuptools.setup(
'psutil', 'psutil',
'astor', 'astor',
'schema', 'schema',
'pyhdfs' 'PythonWebHDFS'
], ],
author = 'Microsoft NNI Team', author = 'Microsoft NNI Team',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment