Unverified Commit c785655e authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #207 from microsoft/master

merge master
parents 9fae194a d6b61e2f
...@@ -25,9 +25,9 @@ import * as path from 'path'; ...@@ -25,9 +25,9 @@ import * as path from 'path';
import * as component from '../common/component'; import * as component from '../common/component';
import { DataStore, MetricDataRecord, TrialJobInfo } from '../common/datastore'; import { DataStore, MetricDataRecord, TrialJobInfo } from '../common/datastore';
import { NNIError, NNIErrorNames } from '../common/errors'; import { NNIError, NNIErrorNames } from '../common/errors';
import { isNewExperiment } from '../common/experimentStartupInfo'; import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log'; import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, Manager, TrialJobStatistics} from '../common/manager'; import { ExperimentProfile, Manager, TrialJobStatistics, ExperimentStartUpMode } from '../common/manager';
import { ValidationSchemas } from './restValidationSchemas'; import { ValidationSchemas } from './restValidationSchemas';
import { NNIRestServer } from './nniRestServer'; import { NNIRestServer } from './nniRestServer';
import { getVersion } from '../common/utils'; import { getVersion } from '../common/utils';
...@@ -72,6 +72,8 @@ class NNIRestHandler { ...@@ -72,6 +72,8 @@ class NNIRestHandler {
this.addTrialJob(router); this.addTrialJob(router);
this.cancelTrialJob(router); this.cancelTrialJob(router);
this.getMetricData(router); this.getMetricData(router);
this.getMetricDataByRange(router);
this.getLatestMetricData(router);
this.exportData(router); this.exportData(router);
// Express-joi-validator configuration // Express-joi-validator configuration
...@@ -86,11 +88,11 @@ class NNIRestHandler { ...@@ -86,11 +88,11 @@ class NNIRestHandler {
return router; return router;
} }
private handle_error(err: Error, res: Response, isFatal: boolean = false): void { private handle_error(err: Error, res: Response, isFatal: boolean = false, errorCode: number = 500): void {
if (err instanceof NNIError && err.name === NNIErrorNames.NOT_FOUND) { if (err instanceof NNIError && err.name === NNIErrorNames.NOT_FOUND) {
res.status(404); res.status(404);
} else { } else {
res.status(500); res.status(errorCode);
} }
res.send({ res.send({
error: err.message error: err.message
...@@ -169,13 +171,13 @@ class NNIRestHandler { ...@@ -169,13 +171,13 @@ class NNIRestHandler {
this.handle_error(err, res); this.handle_error(err, res);
}); });
} else { } else {
this.nniManager.resumeExperiment().then(() => { this.nniManager.resumeExperiment(isReadonly()).then(() => {
res.send(); res.send();
}).catch((err: Error) => { }).catch((err: Error) => {
// Resume experiment is a step of initialization, so any exception thrown is a fatal // Resume experiment is a step of initialization, so any exception thrown is a fatal
this.handle_error(err, res); this.handle_error(err, res);
}); });
} }
}); });
} }
...@@ -193,18 +195,18 @@ class NNIRestHandler { ...@@ -193,18 +195,18 @@ class NNIRestHandler {
router.put( router.put(
'/experiment/cluster-metadata', expressJoi(ValidationSchemas.SETCLUSTERMETADATA), '/experiment/cluster-metadata', expressJoi(ValidationSchemas.SETCLUSTERMETADATA),
async (req: Request, res: Response) => { async (req: Request, res: Response) => {
// tslint:disable-next-line:no-any // tslint:disable-next-line:no-any
const metadata: any = req.body; const metadata: any = req.body;
const keys: string[] = Object.keys(metadata); const keys: string[] = Object.keys(metadata);
try { try {
for (const key of keys) { for (const key of keys) {
await this.nniManager.setClusterMetadata(key, JSON.stringify(metadata[key])); await this.nniManager.setClusterMetadata(key, JSON.stringify(metadata[key]));
}
res.send();
} catch (err) {
// setClusterMetata is a step of initialization, so any exception thrown is a fatal
this.handle_error(NNIError.FromError(err), res, true);
} }
res.send();
} catch (err) {
// setClusterMetata is a step of initialization, so any exception thrown is a fatal
this.handle_error(NNIError.FromError(err), res, true);
}
}); });
} }
...@@ -262,6 +264,28 @@ class NNIRestHandler { ...@@ -262,6 +264,28 @@ class NNIRestHandler {
}); });
} }
private getMetricDataByRange(router: Router): void {
router.get('/metric-data-range/:min_seq_id/:max_seq_id', async (req: Request, res: Response) => {
const minSeqId = Number(req.params.min_seq_id);
const maxSeqId = Number(req.params.max_seq_id);
this.nniManager.getMetricDataByRange(minSeqId, maxSeqId).then((metricsData: MetricDataRecord[]) => {
res.send(metricsData);
}).catch((err: Error) => {
this.handle_error(err, res);
});
});
}
private getLatestMetricData(router: Router): void {
router.get('/metric-data-latest/', async (req: Request, res: Response) => {
this.nniManager.getLatestMetricData().then((metricsData: MetricDataRecord[]) => {
res.send(metricsData);
}).catch((err: Error) => {
this.handle_error(err, res);
});
});
}
private exportData(router: Router): void { private exportData(router: Router): void {
router.get('/export-data', (req: Request, res: Response) => { router.get('/export-data', (req: Request, res: Response) => {
this.nniManager.exportData().then((exportedData: string) => { this.nniManager.exportData().then((exportedData: string) => {
......
...@@ -170,18 +170,18 @@ export namespace ValidationSchemas { ...@@ -170,18 +170,18 @@ export namespace ValidationSchemas {
classFileName: joi.string(), classFileName: joi.string(),
className: joi.string(), className: joi.string(),
classArgs: joi.any(), classArgs: joi.any(),
gpuNum: joi.number().min(0), checkpointDir: joi.string().allow(''),
checkpointDir: joi.string().allow('') gpuIndices: joi.string()
}), }),
tuner: joi.object({ tuner: joi.object({
builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner'), builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner', 'PPOTuner'),
codeDir: joi.string(), codeDir: joi.string(),
classFileName: joi.string(), classFileName: joi.string(),
className: joi.string(), className: joi.string(),
classArgs: joi.any(), classArgs: joi.any(),
gpuNum: joi.number().min(0),
checkpointDir: joi.string().allow(''), checkpointDir: joi.string().allow(''),
includeIntermediateResults: joi.boolean() includeIntermediateResults: joi.boolean(),
gpuIndices: joi.string()
}), }),
assessor: joi.object({ assessor: joi.object({
builtinAssessorName: joi.string().valid('Medianstop', 'Curvefitting'), builtinAssessorName: joi.string().valid('Medianstop', 'Curvefitting'),
...@@ -189,7 +189,6 @@ export namespace ValidationSchemas { ...@@ -189,7 +189,6 @@ export namespace ValidationSchemas {
classFileName: joi.string(), classFileName: joi.string(),
className: joi.string(), className: joi.string(),
classArgs: joi.any(), classArgs: joi.any(),
gpuNum: joi.number().min(0),
checkpointDir: joi.string().allow('') checkpointDir: joi.string().allow('')
}), }),
clusterMetaData: joi.array().items(joi.object({ clusterMetaData: joi.array().items(joi.object({
...@@ -210,7 +209,7 @@ export namespace ValidationSchemas { ...@@ -210,7 +209,7 @@ export namespace ValidationSchemas {
startTime: joi.number(), startTime: joi.number(),
endTime: joi.number(), endTime: joi.number(),
logDir: joi.string(), logDir: joi.string(),
maxSequenceId: joi.number() nextSequenceId: joi.number()
} }
}; };
} }
...@@ -85,9 +85,9 @@ export class MockedNNIManager extends Manager { ...@@ -85,9 +85,9 @@ export class MockedNNIManager extends Manager {
// tslint:disable-next-line:no-http-string // tslint:disable-next-line:no-http-string
url: 'http://test', url: 'http://test',
workingDirectory: '/tmp/mocked', workingDirectory: '/tmp/mocked',
sequenceId: 0,
form: { form: {
jobType: 'TRIAL' sequenceId: 0,
hyperParameters: { value: '', index: 0 }
} }
}; };
deferred.resolve(jobDetail); deferred.resolve(jobDetail);
...@@ -129,6 +129,12 @@ export class MockedNNIManager extends Manager { ...@@ -129,6 +129,12 @@ export class MockedNNIManager extends Manager {
public getMetricData(trialJobId: string, metricType: MetricType): Promise<MetricDataRecord[]> { public getMetricData(trialJobId: string, metricType: MetricType): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError(); throw new MethodNotImplementedError();
} }
public getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError();
}
public getLatestMetricData(): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError();
}
public getExperimentProfile(): Promise<ExperimentProfile> { public getExperimentProfile(): Promise<ExperimentProfile> {
const profile: ExperimentProfile = { const profile: ExperimentProfile = {
params: { params: {
...@@ -148,7 +154,7 @@ export class MockedNNIManager extends Manager { ...@@ -148,7 +154,7 @@ export class MockedNNIManager extends Manager {
execDuration: 0, execDuration: 0,
startTime: Date.now(), startTime: Date.now(),
endTime: Date.now(), endTime: Date.now(),
maxSequenceId: 0, nextSequenceId: 0,
revision: 0 revision: 0
}; };
......
...@@ -156,7 +156,7 @@ export async function execRemove(directory: string): Promise<void> { ...@@ -156,7 +156,7 @@ export async function execRemove(directory: string): Promise<void> {
*/ */
export async function execKill(pid: string): Promise<void> { export async function execKill(pid: string): Promise<void> {
if (process.platform === 'win32') { if (process.platform === 'win32') {
await cpp.exec(`cmd /c taskkill /PID ${pid} /T /F`); await cpp.exec(`cmd.exe /c taskkill /PID ${pid} /T /F`);
} else { } else {
await cpp.exec(`pkill -P ${pid}`); await cpp.exec(`pkill -P ${pid}`);
} }
......
...@@ -25,7 +25,7 @@ import * as path from 'path'; ...@@ -25,7 +25,7 @@ import * as path from 'path';
import * as component from '../../../common/component'; import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo'; import { getExperimentId } from '../../../common/experimentStartupInfo';
import { import {
JobApplicationForm, NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
} from '../../../common/trainingService'; } from '../../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../../common/utils'; import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
...@@ -55,7 +55,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -55,7 +55,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
super(); super();
this.fcJobInfoCollector = new FrameworkControllerJobInfoCollector(this.trialJobsMap); this.fcJobInfoCollector = new FrameworkControllerJobInfoCollector(this.trialJobsMap);
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.nextTrialSequenceId = -1;
} }
public async run(): Promise<void> { public async run(): Promise<void> {
...@@ -77,7 +76,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -77,7 +76,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
} }
} }
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> { public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.fcClusterConfig === undefined) { if (this.fcClusterConfig === undefined) {
throw new Error('frameworkcontrollerClusterConfig is not initialized'); throw new Error('frameworkcontrollerClusterConfig is not initialized');
} }
...@@ -91,14 +90,13 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -91,14 +90,13 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
} }
const trialJobId: string = uniqueString(5); const trialJobId: string = uniqueString(5);
const curTrialSequenceId: number = this.generateSequenceId();
// Set trial's NFS working folder // Set trial's NFS working folder
const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId); const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId);
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
const frameworkcontrollerJobName: string = `nniexp${this.experimentId}trial${trialJobId}`.toLowerCase(); const frameworkcontrollerJobName: string = `nniexp${this.experimentId}trial${trialJobId}`.toLowerCase();
//Generate the port used for taskRole //Generate the port used for taskRole
this.generateContainerPort(); this.generateContainerPort();
await this.prepareRunScript(trialLocalTempFolder, curTrialSequenceId, trialJobId, trialWorkingFolder, form); await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, form);
//upload code files //upload code files
const trialJobOutputUrl: string = await this.uploadCodeFiles(trialJobId, trialLocalTempFolder); const trialJobOutputUrl: string = await this.uploadCodeFiles(trialJobId, trialLocalTempFolder);
...@@ -113,7 +111,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -113,7 +111,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
trialWorkingFolder, trialWorkingFolder,
form, form,
frameworkcontrollerJobName, frameworkcontrollerJobName,
curTrialSequenceId,
trialJobOutputUrl trialJobOutputUrl
); );
...@@ -248,8 +245,8 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -248,8 +245,8 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
return `${portScript} . /mnt/frameworkbarrier/injector.sh && ${command}`; return `${portScript} . /mnt/frameworkbarrier/injector.sh && ${command}`;
} }
private async prepareRunScript(trialLocalTempFolder: string, curTrialSequenceId: number, trialJobId: string, private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string,
trialWorkingFolder: string, form: JobApplicationForm): Promise<void> { trialWorkingFolder: string, form: TrialJobApplicationForm): Promise<void> {
if (this.fcTrialConfig === undefined) { if (this.fcTrialConfig === undefined) {
throw new Error('frameworkcontroller trial config is not initialized'); throw new Error('frameworkcontroller trial config is not initialized');
} }
...@@ -264,16 +261,16 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -264,16 +261,16 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
for (const taskRole of this.fcTrialConfig.taskRoles) { for (const taskRole of this.fcTrialConfig.taskRoles) {
const runScriptContent: string = const runScriptContent: string =
await this.generateRunScript('frameworkcontroller', trialJobId, trialWorkingFolder, await this.generateRunScript('frameworkcontroller', trialJobId, trialWorkingFolder,
this.generateCommandScript(taskRole.command), curTrialSequenceId.toString(), this.generateCommandScript(taskRole.command), form.sequenceId.toString(),
taskRole.name, taskRole.gpuNum); taskRole.name, taskRole.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, `run_${taskRole.name}.sh`), runScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, `run_${taskRole.name}.sh`), runScriptContent, { encoding: 'utf8' });
} }
// Write file content ( parameter.cfg ) to local tmp folders // Write file content ( parameter.cfg ) to local tmp folders
const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>form); const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>form);
if (trialForm !== undefined && trialForm.hyperParameters !== undefined) { if (form !== undefined) {
await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)), await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(form.hyperParameters)),
trialForm.hyperParameters.value, { encoding: 'utf8' }); form.hyperParameters.value, { encoding: 'utf8' });
} }
} }
......
...@@ -27,7 +27,7 @@ import * as component from '../../../common/component'; ...@@ -27,7 +27,7 @@ import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo'; import { getExperimentId } from '../../../common/experimentStartupInfo';
import { import {
JobApplicationForm, NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
} from '../../../common/trainingService'; } from '../../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../../common/utils'; import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
...@@ -59,7 +59,6 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -59,7 +59,6 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
super(); super();
this.kubeflowJobInfoCollector = new KubeflowJobInfoCollector(this.trialJobsMap); this.kubeflowJobInfoCollector = new KubeflowJobInfoCollector(this.trialJobsMap);
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.nextTrialSequenceId = -1;
this.log.info('Construct Kubeflow training service.'); this.log.info('Construct Kubeflow training service.');
} }
...@@ -84,7 +83,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -84,7 +83,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
this.log.info('Kubeflow training service exit.'); this.log.info('Kubeflow training service exit.');
} }
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> { public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.kubernetesCRDClient === undefined) { if (this.kubernetesCRDClient === undefined) {
throw new Error('Kubeflow job operator client is undefined'); throw new Error('Kubeflow job operator client is undefined');
} }
...@@ -96,10 +95,9 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -96,10 +95,9 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
const trialJobId: string = uniqueString(5); const trialJobId: string = uniqueString(5);
const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId); const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId);
const kubeflowJobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase(); const kubeflowJobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase();
const curTrialSequenceId: number = this.generateSequenceId();
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
//prepare the runscript //prepare the runscript
await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, curTrialSequenceId, form); await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, form);
//upload files to sotrage //upload files to sotrage
const trialJobOutputUrl: string = await this.uploadCodeFiles(trialJobId, trialLocalTempFolder); const trialJobOutputUrl: string = await this.uploadCodeFiles(trialJobId, trialLocalTempFolder);
let initStatus: TrialJobStatus = 'WAITING'; let initStatus: TrialJobStatus = 'WAITING';
...@@ -113,7 +111,6 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -113,7 +111,6 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
trialWorkingFolder, trialWorkingFolder,
form, form,
kubeflowJobName, kubeflowJobName,
curTrialSequenceId,
trialJobOutputUrl trialJobOutputUrl
); );
...@@ -236,8 +233,8 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -236,8 +233,8 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
return Promise.resolve(trialJobOutputUrl); return Promise.resolve(trialJobOutputUrl);
} }
private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string, trialWorkingFolder: string, curTrialSequenceId: number, private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string, trialWorkingFolder: string,
form: JobApplicationForm): Promise<void> { form: TrialJobApplicationForm): Promise<void> {
if (this.kubeflowClusterConfig === undefined) { if (this.kubeflowClusterConfig === undefined) {
throw new Error('Kubeflow Cluster config is not initialized'); throw new Error('Kubeflow Cluster config is not initialized');
} }
...@@ -262,7 +259,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -262,7 +259,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
if (kubeflowTrialConfig.worker !== undefined) { if (kubeflowTrialConfig.worker !== undefined) {
const workerRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder, const workerRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
kubeflowTrialConfig.worker.command, kubeflowTrialConfig.worker.command,
curTrialSequenceId.toString(), 'worker', form.sequenceId.toString(), 'worker',
kubeflowTrialConfig.worker.gpuNum); kubeflowTrialConfig.worker.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_worker.sh'), workerRunScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_worker.sh'), workerRunScriptContent, { encoding: 'utf8' });
} }
...@@ -272,7 +269,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -272,7 +269,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
if (tensorflowTrialConfig.ps !== undefined) { if (tensorflowTrialConfig.ps !== undefined) {
const psRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder, const psRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
tensorflowTrialConfig.ps.command, tensorflowTrialConfig.ps.command,
curTrialSequenceId.toString(), form.sequenceId.toString(),
'ps', tensorflowTrialConfig.ps.gpuNum); 'ps', tensorflowTrialConfig.ps.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_ps.sh'), psRunScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_ps.sh'), psRunScriptContent, { encoding: 'utf8' });
} }
...@@ -281,16 +278,15 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -281,16 +278,15 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
if (pytorchTrialConfig.master !== undefined) { if (pytorchTrialConfig.master !== undefined) {
const masterRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder, const masterRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
pytorchTrialConfig.master.command, pytorchTrialConfig.master.command,
curTrialSequenceId.toString(), 'master', form.sequenceId.toString(), 'master',
pytorchTrialConfig.master.gpuNum); pytorchTrialConfig.master.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_master.sh'), masterRunScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_master.sh'), masterRunScriptContent, { encoding: 'utf8' });
} }
} }
// Write file content ( parameter.cfg ) to local tmp folders // Write file content ( parameter.cfg ) to local tmp folders
const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>form); if (form !== undefined) {
if (trialForm !== undefined && trialForm.hyperParameters !== undefined) { await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(form.hyperParameters)),
await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)), form.hyperParameters.value, { encoding: 'utf8' });
trialForm.hyperParameters.value, { encoding: 'utf8' });
} }
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
'use strict'; 'use strict';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService'; import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
/** /**
* KubeflowTrialJobDetail * KubeflowTrialJobDetail
...@@ -33,21 +33,19 @@ export class KubernetesTrialJobDetail implements TrialJobDetail { ...@@ -33,21 +33,19 @@ export class KubernetesTrialJobDetail implements TrialJobDetail {
public tags?: string[]; public tags?: string[];
public url?: string; public url?: string;
public workingDirectory: string; public workingDirectory: string;
public form: JobApplicationForm; public form: TrialJobApplicationForm;
public kubernetesJobName: string; public kubernetesJobName: string;
public sequenceId: number;
public queryJobFailedCount: number; public queryJobFailedCount: number;
constructor(id: string, status: TrialJobStatus, submitTime: number, constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: JobApplicationForm, workingDirectory: string, form: TrialJobApplicationForm,
kubernetesJobName: string, sequenceId: number, url: string) { kubernetesJobName: string, url: string) {
this.id = id; this.id = id;
this.status = status; this.status = status;
this.submitTime = submitTime; this.submitTime = submitTime;
this.workingDirectory = workingDirectory; this.workingDirectory = workingDirectory;
this.form = form; this.form = form;
this.kubernetesJobName = kubernetesJobName; this.kubernetesJobName = kubernetesJobName;
this.sequenceId = sequenceId;
this.tags = []; this.tags = [];
this.queryJobFailedCount = 0; this.queryJobFailedCount = 0;
this.url = url; this.url = url;
......
...@@ -26,7 +26,7 @@ import * as azureStorage from 'azure-storage'; ...@@ -26,7 +26,7 @@ import * as azureStorage from 'azure-storage';
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
import { Base64 } from 'js-base64'; import { Base64 } from 'js-base64';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo'; import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { import {
NNIManagerIpConfig, TrialJobDetail, TrialJobMetric NNIManagerIpConfig, TrialJobDetail, TrialJobMetric
...@@ -53,7 +53,6 @@ abstract class KubernetesTrainingService { ...@@ -53,7 +53,6 @@ abstract class KubernetesTrainingService {
protected readonly trialLocalNFSTempFolder: string; protected readonly trialLocalNFSTempFolder: string;
protected stopping: boolean = false; protected stopping: boolean = false;
protected experimentId! : string; protected experimentId! : string;
protected nextTrialSequenceId: number;
protected kubernetesRestServerPort?: number; protected kubernetesRestServerPort?: number;
protected readonly CONTAINER_MOUNT_PATH: string; protected readonly CONTAINER_MOUNT_PATH: string;
protected azureStorageClient?: azureStorage.FileService; protected azureStorageClient?: azureStorage.FileService;
...@@ -74,7 +73,6 @@ abstract class KubernetesTrainingService { ...@@ -74,7 +73,6 @@ abstract class KubernetesTrainingService {
this.trialJobsMap = new Map<string, KubernetesTrialJobDetail>(); this.trialJobsMap = new Map<string, KubernetesTrialJobDetail>();
this.trialLocalNFSTempFolder = path.join(getExperimentRootDir(), 'trials-nfs-tmp'); this.trialLocalNFSTempFolder = path.join(getExperimentRootDir(), 'trials-nfs-tmp');
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.nextTrialSequenceId = -1;
this.CONTAINER_MOUNT_PATH = '/tmp/mount'; this.CONTAINER_MOUNT_PATH = '/tmp/mount';
this.genericK8sClient = new GeneralK8sClient(); this.genericK8sClient = new GeneralK8sClient();
this.logCollection = 'none'; this.logCollection = 'none';
...@@ -93,9 +91,7 @@ abstract class KubernetesTrainingService { ...@@ -93,9 +91,7 @@ abstract class KubernetesTrainingService {
const jobs: TrialJobDetail[] = []; const jobs: TrialJobDetail[] = [];
for (const [key, value] of this.trialJobsMap) { for (const [key, value] of this.trialJobsMap) {
if (value.form.jobType === 'TRIAL') { jobs.push(await this.getTrialJob(key));
jobs.push(await this.getTrialJob(key));
}
} }
return Promise.resolve(jobs); return Promise.resolve(jobs);
...@@ -222,14 +218,6 @@ abstract class KubernetesTrainingService { ...@@ -222,14 +218,6 @@ abstract class KubernetesTrainingService {
return Promise.resolve(); return Promise.resolve();
} }
protected generateSequenceId(): number {
if (this.nextTrialSequenceId === -1) {
this.nextTrialSequenceId = getInitTrialSequenceId();
}
return this.nextTrialSequenceId++;
}
// tslint:disable: no-unsafe-any no-any // tslint:disable: no-unsafe-any no-any
protected async createAzureStorage(vaultName: string, valutKeyName: string, accountName: string, azureShare: string): Promise<void> { protected async createAzureStorage(vaultName: string, valutKeyName: string, accountName: string, azureShare: string): Promise<void> {
try { try {
......
...@@ -26,10 +26,10 @@ import * as path from 'path'; ...@@ -26,10 +26,10 @@ import * as path from 'path';
import * as ts from 'tail-stream'; import * as ts from 'tail-stream';
import * as tkill from 'tree-kill'; import * as tkill from 'tree-kill';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo'; import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { import {
HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService'; } from '../../common/trainingService';
import { import {
...@@ -76,21 +76,19 @@ class LocalTrialJobDetail implements TrialJobDetail { ...@@ -76,21 +76,19 @@ class LocalTrialJobDetail implements TrialJobDetail {
public tags?: string[]; public tags?: string[];
public url?: string; public url?: string;
public workingDirectory: string; public workingDirectory: string;
public form: JobApplicationForm; public form: TrialJobApplicationForm;
public sequenceId: number;
public pid?: number; public pid?: number;
public gpuIndices?: number[]; public gpuIndices?: number[];
constructor( constructor(
id: string, status: TrialJobStatus, submitTime: number, id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: JobApplicationForm, sequenceId: number) { workingDirectory: string, form: TrialJobApplicationForm) {
this.id = id; this.id = id;
this.status = status; this.status = status;
this.submitTime = submitTime; this.submitTime = submitTime;
this.workingDirectory = workingDirectory; this.workingDirectory = workingDirectory;
this.form = form; this.form = form;
this.url = `file://localhost:${workingDirectory}`; this.url = `file://localhost:${workingDirectory}`;
this.sequenceId = sequenceId;
this.gpuIndices = []; this.gpuIndices = [];
} }
} }
...@@ -125,7 +123,6 @@ class LocalTrainingService implements TrainingService { ...@@ -125,7 +123,6 @@ class LocalTrainingService implements TrainingService {
private initialized: boolean; private initialized: boolean;
private stopping: boolean; private stopping: boolean;
private rootDir!: string; private rootDir!: string;
private trialSequenceId: number;
private readonly experimentId! : string; private readonly experimentId! : string;
private gpuScheduler!: GPUScheduler; private gpuScheduler!: GPUScheduler;
private readonly occupiedGpuIndexNumMap: Map<number, number>; private readonly occupiedGpuIndexNumMap: Map<number, number>;
...@@ -145,7 +142,6 @@ class LocalTrainingService implements TrainingService { ...@@ -145,7 +142,6 @@ class LocalTrainingService implements TrainingService {
this.initialized = false; this.initialized = false;
this.stopping = false; this.stopping = false;
this.log = getLogger(); this.log = getLogger();
this.trialSequenceId = -1;
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.jobStreamMap = new Map<string, ts.Stream>(); this.jobStreamMap = new Map<string, ts.Stream>();
this.log.info('Construct local machine training service.'); this.log.info('Construct local machine training service.');
...@@ -169,9 +165,7 @@ class LocalTrainingService implements TrainingService { ...@@ -169,9 +165,7 @@ class LocalTrainingService implements TrainingService {
const jobs: TrialJobDetail[] = []; const jobs: TrialJobDetail[] = [];
for (const key of this.jobMap.keys()) { for (const key of this.jobMap.keys()) {
const trialJob: TrialJobDetail = await this.getTrialJob(key); const trialJob: TrialJobDetail = await this.getTrialJob(key);
if (trialJob.form.jobType === 'TRIAL') { jobs.push(trialJob);
jobs.push(trialJob);
}
} }
return jobs; return jobs;
...@@ -182,9 +176,6 @@ class LocalTrainingService implements TrainingService { ...@@ -182,9 +176,6 @@ class LocalTrainingService implements TrainingService {
if (trialJob === undefined) { if (trialJob === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, 'Trial job not found'); throw new NNIError(NNIErrorNames.NOT_FOUND, 'Trial job not found');
} }
if (trialJob.form.jobType === 'HOST') {
return this.getHostJob(trialJobId);
}
if (trialJob.status === 'RUNNING') { if (trialJob.status === 'RUNNING') {
const alive: boolean = await isAlive(trialJob.pid); const alive: boolean = await isAlive(trialJob.pid);
if (!alive) { if (!alive) {
...@@ -219,28 +210,21 @@ class LocalTrainingService implements TrainingService { ...@@ -219,28 +210,21 @@ class LocalTrainingService implements TrainingService {
this.eventEmitter.off('metric', listener); this.eventEmitter.off('metric', listener);
} }
public submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> { public submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (form.jobType === 'HOST') { const trialJobId: string = uniqueString(5);
return this.runHostJob(<HostJobApplicationForm>form); const trialJobDetail: LocalTrialJobDetail = new LocalTrialJobDetail(
} else if (form.jobType === 'TRIAL') { trialJobId,
const trialJobId: string = uniqueString(5); 'WAITING',
const trialJobDetail: LocalTrialJobDetail = new LocalTrialJobDetail( Date.now(),
trialJobId, path.join(this.rootDir, 'trials', trialJobId),
'WAITING', form
Date.now(), );
path.join(this.rootDir, 'trials', trialJobId), this.jobQueue.push(trialJobId);
form, this.jobMap.set(trialJobId, trialJobDetail);
this.generateSequenceId()
); this.log.debug(`submitTrialJob: return: ${JSON.stringify(trialJobDetail)} `);
this.jobQueue.push(trialJobId);
this.jobMap.set(trialJobId, trialJobDetail); return Promise.resolve(trialJobDetail);
this.log.debug(`submitTrialJob: return: ${JSON.stringify(trialJobDetail)} `);
return Promise.resolve(trialJobDetail);
} else {
return Promise.reject(new Error(`Job form not supported: ${JSON.stringify(form)}`));
}
} }
/** /**
...@@ -248,16 +232,12 @@ class LocalTrainingService implements TrainingService { ...@@ -248,16 +232,12 @@ class LocalTrainingService implements TrainingService {
* @param trialJobId trial job id * @param trialJobId trial job id
* @param form job application form * @param form job application form
*/ */
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> { public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.jobMap.get(trialJobId); const trialJobDetail: undefined | TrialJobDetail = this.jobMap.get(trialJobId);
if (trialJobDetail === undefined) { if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`); throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
} }
if (form.jobType === 'TRIAL') { await this.writeParameterFile(trialJobDetail.workingDirectory, form.hyperParameters);
await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>form).hyperParameters);
} else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
}
return trialJobDetail; return trialJobDetail;
} }
...@@ -279,13 +259,7 @@ class LocalTrainingService implements TrainingService { ...@@ -279,13 +259,7 @@ class LocalTrainingService implements TrainingService {
return Promise.resolve(); return Promise.resolve();
} }
if (trialJob.form.jobType === 'TRIAL') { tkill(trialJob.pid, 'SIGKILL');
tkill(trialJob.pid, 'SIGKILL');
} else if (trialJob.form.jobType === 'HOST') {
await cpp.exec(`pkill -9 -P ${trialJob.pid}`);
} else {
throw new Error(`Job type not supported: ${trialJob.form.jobType}`);
}
this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped)); this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped));
return Promise.resolve(); return Promise.resolve();
...@@ -409,7 +383,7 @@ class LocalTrainingService implements TrainingService { ...@@ -409,7 +383,7 @@ class LocalTrainingService implements TrainingService {
{ key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory }, { key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory },
{ key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id }, { key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id },
{ key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory }, { key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory },
{ key: 'NNI_TRIAL_SEQ_ID', value: trialJobDetail.sequenceId.toString() }, { key: 'NNI_TRIAL_SEQ_ID', value: trialJobDetail.form.sequenceId.toString() },
{ key: 'MULTI_PHASE', value: this.isMultiPhase.toString() } { key: 'MULTI_PHASE', value: this.isMultiPhase.toString() }
]; ];
if (gpuNum !== undefined) { if (gpuNum !== undefined) {
...@@ -516,7 +490,7 @@ class LocalTrainingService implements TrainingService { ...@@ -516,7 +490,7 @@ class LocalTrainingService implements TrainingService {
const script: string[] = []; const script: string[] = [];
if (process.platform === 'win32') { if (process.platform === 'win32') {
script.push( script.push(
`cmd /c ${localTrialConfig.command} 2>${path.join(workingDirectory, 'stderr')}`, `cmd.exe /c ${localTrialConfig.command} 2>${path.join(workingDirectory, 'stderr')}`,
`$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`, `$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`,
`$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`, `$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`,
`Write $LASTEXITCODE " " $NOW_DATE | Out-File ${path.join(workingDirectory, '.nni', 'state')} -NoNewline -encoding utf8`); `Write $LASTEXITCODE " " $NOW_DATE | Out-File ${path.join(workingDirectory, '.nni', 'state')} -NoNewline -encoding utf8`);
...@@ -562,7 +536,7 @@ class LocalTrainingService implements TrainingService { ...@@ -562,7 +536,7 @@ class LocalTrainingService implements TrainingService {
const scriptName: string = getScriptName('run'); const scriptName: string = getScriptName('run');
await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, scriptName), await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, scriptName),
runScriptContent.join(getNewLine()), { encoding: 'utf8', mode: 0o777 }); runScriptContent.join(getNewLine()), { encoding: 'utf8', mode: 0o777 });
await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters); await this.writeParameterFile(trialJobDetail.workingDirectory, trialJobDetail.form.hyperParameters);
const trialJobProcess: cp.ChildProcess = runScript(path.join(trialJobDetail.workingDirectory, scriptName)); const trialJobProcess: cp.ChildProcess = runScript(path.join(trialJobDetail.workingDirectory, scriptName));
this.setTrialJobStatus(trialJobDetail, 'RUNNING'); this.setTrialJobStatus(trialJobDetail, 'RUNNING');
trialJobDetail.startTime = Date.now(); trialJobDetail.startTime = Date.now();
...@@ -589,60 +563,10 @@ class LocalTrainingService implements TrainingService { ...@@ -589,60 +563,10 @@ class LocalTrainingService implements TrainingService {
this.jobStreamMap.set(trialJobDetail.id, stream); this.jobStreamMap.set(trialJobDetail.id, stream);
} }
private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> {
const jobId: string = uniqueString(5);
const workDir: string = path.join(this.rootDir, 'hostjobs', jobId);
await cpp.exec(`mkdir -p ${workDir}`);
const wrappedCmd: string = `cd ${workDir} && ${form.cmd}>stdout 2>stderr`;
this.log.debug(`runHostJob: command: ${wrappedCmd}`);
const process: cp.ChildProcess = cp.exec(wrappedCmd);
const jobDetail: LocalTrialJobDetail = {
id: jobId,
status: 'RUNNING',
submitTime: Date.now(),
workingDirectory: workDir,
form: form,
sequenceId: this.generateSequenceId(),
pid: process.pid
};
this.jobMap.set(jobId, jobDetail);
this.log.debug(`runHostJob: return: ${JSON.stringify(jobDetail)} `);
return jobDetail;
}
private async getHostJob(jobId: string): Promise<TrialJobDetail> {
const jobDetail: LocalTrialJobDetail | undefined = this.jobMap.get(jobId);
if (jobDetail === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, `Host Job not found: ${jobId}`);
}
try {
await cpp.exec(`kill -0 ${jobDetail.pid}`);
return jobDetail;
} catch (error) {
if (error instanceof Error) {
this.log.debug(`getHostJob: error: ${error.message}`);
this.jobMap.delete(jobId);
throw new NNIError(NNIErrorNames.NOT_FOUND, `Host Job not found: ${error.message}`);
} else {
throw error;
}
}
}
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> { private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
const filepath: string = path.join(directory, generateParamFileName(hyperParameters)); const filepath: string = path.join(directory, generateParamFileName(hyperParameters));
await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' }); await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' });
} }
private generateSequenceId(): number {
if (this.trialSequenceId === -1) {
this.trialSequenceId = getInitTrialSequenceId();
}
return this.trialSequenceId++;
}
} }
export { LocalTrainingService }; export { LocalTrainingService };
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
'use strict'; 'use strict';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService'; import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
/** /**
* PAI trial job detail * PAI trial job detail
...@@ -34,20 +34,18 @@ export class PAITrialJobDetail implements TrialJobDetail { ...@@ -34,20 +34,18 @@ export class PAITrialJobDetail implements TrialJobDetail {
public tags?: string[]; public tags?: string[];
public url?: string; public url?: string;
public workingDirectory: string; public workingDirectory: string;
public form: JobApplicationForm; public form: TrialJobApplicationForm;
public sequenceId: number;
public hdfsLogPath: string; public hdfsLogPath: string;
public isEarlyStopped?: boolean; public isEarlyStopped?: boolean;
constructor(id: string, status: TrialJobStatus, paiJobName : string, constructor(id: string, status: TrialJobStatus, paiJobName : string,
submitTime: number, workingDirectory: string, form: JobApplicationForm, sequenceId: number, hdfsLogPath: string) { submitTime: number, workingDirectory: string, form: TrialJobApplicationForm, hdfsLogPath: string) {
this.id = id; this.id = id;
this.status = status; this.status = status;
this.paiJobName = paiJobName; this.paiJobName = paiJobName;
this.submitTime = submitTime; this.submitTime = submitTime;
this.workingDirectory = workingDirectory; this.workingDirectory = workingDirectory;
this.form = form; this.form = form;
this.sequenceId = sequenceId;
this.tags = []; this.tags = [];
this.hdfsLogPath = hdfsLogPath; this.hdfsLogPath = hdfsLogPath;
} }
......
...@@ -30,10 +30,10 @@ import { EventEmitter } from 'events'; ...@@ -30,10 +30,10 @@ import { EventEmitter } from 'events';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import { MethodNotImplementedError } from '../../common/errors'; import { MethodNotImplementedError } from '../../common/errors';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo'; import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { import {
HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService, HyperParameters, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, generateParamFileName, import { delay, generateParamFileName,
...@@ -70,7 +70,6 @@ class PAITrainingService implements TrainingService { ...@@ -70,7 +70,6 @@ class PAITrainingService implements TrainingService {
private readonly paiTokenUpdateInterval: number; private readonly paiTokenUpdateInterval: number;
private readonly experimentId! : string; private readonly experimentId! : string;
private readonly paiJobCollector : PAIJobInfoCollector; private readonly paiJobCollector : PAIJobInfoCollector;
private nextTrialSequenceId: number;
private paiRestServerPort?: number; private paiRestServerPort?: number;
private nniManagerIpConfig?: NNIManagerIpConfig; private nniManagerIpConfig?: NNIManagerIpConfig;
private copyExpCodeDirPromise?: Promise<void>; private copyExpCodeDirPromise?: Promise<void>;
...@@ -90,7 +89,6 @@ class PAITrainingService implements TrainingService { ...@@ -90,7 +89,6 @@ class PAITrainingService implements TrainingService {
this.expRootDir = path.join('/nni', 'experiments', getExperimentId()); this.expRootDir = path.join('/nni', 'experiments', getExperimentId());
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap); this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
this.nextTrialSequenceId = -1;
this.paiTokenUpdateInterval = 7200000; //2hours this.paiTokenUpdateInterval = 7200000; //2hours
this.logCollection = 'none'; this.logCollection = 'none';
this.log.info('Construct OpenPAI training service.'); this.log.info('Construct OpenPAI training service.');
...@@ -112,9 +110,7 @@ class PAITrainingService implements TrainingService { ...@@ -112,9 +110,7 @@ class PAITrainingService implements TrainingService {
const jobs: TrialJobDetail[] = []; const jobs: TrialJobDetail[] = [];
for (const [key, value] of this.trialJobsMap) { for (const [key, value] of this.trialJobsMap) {
if (value.form.jobType === 'TRIAL') { jobs.push(await this.getTrialJob(key));
jobs.push(await this.getTrialJob(key));
}
} }
return Promise.resolve(jobs); return Promise.resolve(jobs);
...@@ -142,7 +138,7 @@ class PAITrainingService implements TrainingService { ...@@ -142,7 +138,7 @@ class PAITrainingService implements TrainingService {
this.metricsEmitter.off('metric', listener); this.metricsEmitter.off('metric', listener);
} }
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> { public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.paiClusterConfig === undefined) { if (this.paiClusterConfig === undefined) {
throw new Error(`paiClusterConfig not initialized!`); throw new Error(`paiClusterConfig not initialized!`);
} }
...@@ -151,7 +147,6 @@ class PAITrainingService implements TrainingService { ...@@ -151,7 +147,6 @@ class PAITrainingService implements TrainingService {
this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`); this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
const trialJobId: string = uniqueString(5); const trialJobId: string = uniqueString(5);
const trialSequenceId: number = this.generateSequenceId();
//TODO: use HDFS working folder instead //TODO: use HDFS working folder instead
const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId); const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`; const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
...@@ -171,7 +166,6 @@ class PAITrainingService implements TrainingService { ...@@ -171,7 +166,6 @@ class PAITrainingService implements TrainingService {
Date.now(), Date.now(),
trialWorkingFolder, trialWorkingFolder,
form, form,
trialSequenceId,
hdfsLogPath); hdfsLogPath);
this.trialJobsMap.set(trialJobId, trialJobDetail); this.trialJobsMap.set(trialJobId, trialJobDetail);
...@@ -181,16 +175,12 @@ class PAITrainingService implements TrainingService { ...@@ -181,16 +175,12 @@ class PAITrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> { public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId); const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) { if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`); throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
} }
if (form.jobType === 'TRIAL') { await this.writeParameterFile(trialJobId, form.hyperParameters);
await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters);
} else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
}
return trialJobDetail; return trialJobDetail;
} }
...@@ -397,11 +387,10 @@ class PAITrainingService implements TrainingService { ...@@ -397,11 +387,10 @@ class PAITrainingService implements TrainingService {
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local tmp folders // Write file content ( parameter.cfg ) to local tmp folders
const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>trialJobDetail.form); if (trialJobDetail.form !== undefined) {
if (trialForm !== undefined) {
await fs.promises.writeFile( await fs.promises.writeFile(
path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)), path.join(trialLocalTempFolder, generateParamFileName(trialJobDetail.form.hyperParameters)),
trialForm.hyperParameters.value, { encoding: 'utf8' } trialJobDetail.form.hyperParameters.value, { encoding: 'utf8' }
); );
} }
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId); const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
...@@ -416,7 +405,7 @@ class PAITrainingService implements TrainingService { ...@@ -416,7 +405,7 @@ class PAITrainingService implements TrainingService {
`$PWD/${trialJobId}/nnioutput`, `$PWD/${trialJobId}/nnioutput`,
trialJobId, trialJobId,
this.experimentId, this.experimentId,
trialJobDetail.sequenceId, trialJobDetail.form.sequenceId,
this.isMultiPhase, this.isMultiPhase,
this.paiTrialConfig.command, this.paiTrialConfig.command,
nniManagerIp, nniManagerIp,
...@@ -507,14 +496,6 @@ class PAITrainingService implements TrainingService { ...@@ -507,14 +496,6 @@ class PAITrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
private generateSequenceId(): number {
if (this.nextTrialSequenceId === -1) {
this.nextTrialSequenceId = getInitTrialSequenceId();
}
return this.nextTrialSequenceId++;
}
private async statusCheckingLoop(): Promise<void> { private async statusCheckingLoop(): Promise<void> {
while (!this.stopping) { while (!this.stopping) {
try { try {
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
import * as fs from 'fs'; import * as fs from 'fs';
import { Client, ConnectConfig } from 'ssh2'; import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService'; import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUInfo, GPUSummary } from '../common/gpuData'; import { GPUInfo, GPUSummary } from '../common/gpuData';
/** /**
...@@ -82,20 +82,18 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -82,20 +82,18 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public tags?: string[]; public tags?: string[];
public url?: string; public url?: string;
public workingDirectory: string; public workingDirectory: string;
public form: JobApplicationForm; public form: TrialJobApplicationForm;
public sequenceId: number;
public rmMeta?: RemoteMachineMeta; public rmMeta?: RemoteMachineMeta;
public isEarlyStopped?: boolean; public isEarlyStopped?: boolean;
public gpuIndices: GPUInfo[]; public gpuIndices: GPUInfo[];
constructor(id: string, status: TrialJobStatus, submitTime: number, constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: JobApplicationForm, sequenceId: number) { workingDirectory: string, form: TrialJobApplicationForm) {
this.id = id; this.id = id;
this.status = status; this.status = status;
this.submitTime = submitTime; this.submitTime = submitTime;
this.workingDirectory = workingDirectory; this.workingDirectory = workingDirectory;
this.form = form; this.form = form;
this.sequenceId = sequenceId;
this.tags = []; this.tags = [];
this.gpuIndices = []; this.gpuIndices = [];
} }
......
...@@ -30,11 +30,11 @@ import { Deferred } from 'ts-deferred'; ...@@ -30,11 +30,11 @@ import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo'; import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { ObservableTimer } from '../../common/observableTimer'; import { ObservableTimer } from '../../common/observableTimer';
import { import {
HostJobApplicationForm, HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService, TrialJobApplicationForm, HyperParameters, NNIManagerIpConfig, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { import {
...@@ -172,9 +172,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -172,9 +172,7 @@ class RemoteMachineTrainingService implements TrainingService {
const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>(); const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>();
for (const [key, value] of this.trialJobsMap) { for (const [key, value] of this.trialJobsMap) {
if (value.form.jobType === 'TRIAL') { jobs.push(await this.getTrialJob(key));
jobs.push(await this.getTrialJob(key));
}
} }
deferred.resolve(jobs); deferred.resolve(jobs);
...@@ -228,33 +226,26 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -228,33 +226,26 @@ class RemoteMachineTrainingService implements TrainingService {
* @param form trial job description form * @param form trial job description form
*/ */
// tslint:disable-next-line:informative-docs // tslint:disable-next-line:informative-docs
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> { public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.trialConfig === undefined) { if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized'); throw new Error('trial config is not initialized');
} }
if (form.jobType === 'HOST') { // Generate trial job id(random)
return this.runHostJob(<HostJobApplicationForm>form); const trialJobId: string = uniqueString(5);
} else if (form.jobType === 'TRIAL') { const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
// Generate trial job id(random)
const trialJobId: string = uniqueString(5);
const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
const trialJobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail( const trialJobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
trialJobId, trialJobId,
'WAITING', 'WAITING',
Date.now(), Date.now(),
trialWorkingFolder, trialWorkingFolder,
form, form
this.generateSequenceId() );
); this.jobQueue.push(trialJobId);
this.jobQueue.push(trialJobId); this.trialJobsMap.set(trialJobId, trialJobDetail);
this.trialJobsMap.set(trialJobId, trialJobDetail);
return Promise.resolve(trialJobDetail); return Promise.resolve(trialJobDetail);
} else {
return Promise.reject(new Error(`Job form not supported: ${JSON.stringify(form)}, jobType should be HOST or TRIAL.`));
}
} }
/** /**
...@@ -262,20 +253,16 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -262,20 +253,16 @@ class RemoteMachineTrainingService implements TrainingService {
* @param trialJobId trial job id * @param trialJobId trial job id
* @param form job application form * @param form job application form
*/ */
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> { public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId); const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) { if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`); throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
} }
if (form.jobType === 'TRIAL') { const rmMeta: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>trialJobDetail).rmMeta;
const rmMeta: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>trialJobDetail).rmMeta; if (rmMeta !== undefined) {
if (rmMeta !== undefined) { await this.writeParameterFile(trialJobId, form.hyperParameters, rmMeta);
await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters, rmMeta);
} else {
throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
}
} else { } else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`); throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
} }
return trialJobDetail; return trialJobDetail;
...@@ -558,7 +545,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -558,7 +545,7 @@ class RemoteMachineTrainingService implements TrainingService {
await this.allocateSSHClientForTrial(trialJobDetail); await this.allocateSSHClientForTrial(trialJobDetail);
await this.launchTrialOnScheduledMachine( await this.launchTrialOnScheduledMachine(
trialJobId, trialWorkingFolder, <TrialJobApplicationForm>trialJobDetail.form, rmScheduleInfo); trialJobId, trialWorkingFolder, trialJobDetail.form, rmScheduleInfo);
trialJobDetail.status = 'RUNNING'; trialJobDetail.status = 'RUNNING';
trialJobDetail.url = `file://${rmScheduleInfo.rmMeta.ip}:${trialWorkingFolder}`; trialJobDetail.url = `file://${rmScheduleInfo.rmMeta.ip}:${trialWorkingFolder}`;
...@@ -628,7 +615,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -628,7 +615,7 @@ class RemoteMachineTrainingService implements TrainingService {
trialWorkingFolder, trialWorkingFolder,
trialJobId, trialJobId,
getExperimentId(), getExperimentId(),
trialJobDetail.sequenceId.toString(), trialJobDetail.form.sequenceId.toString(),
this.isMultiPhase, this.isMultiPhase,
unixPathJoin(trialWorkingFolder, '.nni', 'jobpid'), unixPathJoin(trialWorkingFolder, '.nni', 'jobpid'),
command, command,
...@@ -657,38 +644,6 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -657,38 +644,6 @@ class RemoteMachineTrainingService implements TrainingService {
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient); SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient);
} }
private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> {
const rmMeta: RemoteMachineMeta = this.getRmMetaByHost(form.host);
const sshClientManager: SSHClientManager | undefined = this.machineSSHClientMap.get(rmMeta);
if (sshClientManager === undefined) {
throw new Error('sshClient not found.');
}
const sshClient: Client = sshClientManager.getFirstSSHClient();
const jobId: string = uniqueString(5);
const localDir: string = path.join(this.expRootDir, 'hostjobs-local', jobId);
const remoteDir: string = this.getHostJobRemoteDir(jobId);
await cpp.exec(`mkdir -p ${localDir}`);
await SSHClientUtility.remoteExeCommand(`mkdir -p ${remoteDir}`, sshClient);
const runScriptContent: string = String.Format(
HOST_JOB_SHELL_FORMAT, remoteDir, path.join(remoteDir, 'jobpid'), form.cmd, path.join(remoteDir, 'code')
);
await fs.promises.writeFile(path.join(localDir, 'run.sh'), runScriptContent, { encoding: 'utf8' });
await SSHClientUtility.copyFileToRemote(
path.join(localDir, 'run.sh'), unixPathJoin(remoteDir, 'run.sh'), sshClient);
// tslint:disable-next-line: no-floating-promises
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteDir, 'run.sh')}`, sshClient);
const jobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
jobId, 'RUNNING', Date.now(), remoteDir, form, this.generateSequenceId()
);
jobDetail.rmMeta = rmMeta;
jobDetail.startTime = Date.now();
this.trialJobsMap.set(jobId, jobDetail);
this.log.debug(`runHostJob: return: ${JSON.stringify(jobDetail)} `);
return jobDetail;
}
private getRmMetaByHost(host: string): RemoteMachineMeta { private getRmMetaByHost(host: string): RemoteMachineMeta {
for (const [rmMeta, client] of this.machineSSHClientMap.entries()) { for (const [rmMeta, client] of this.machineSSHClientMap.entries()) {
if (rmMeta.ip === host) { if (rmMeta.ip === host) {
...@@ -765,13 +720,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -765,13 +720,7 @@ class RemoteMachineTrainingService implements TrainingService {
} }
let jobpidPath: string; let jobpidPath: string;
if (trialJobDetail.form.jobType === 'TRIAL') { jobpidPath = unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
jobpidPath = unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
} else if (trialJobDetail.form.jobType === 'HOST') {
jobpidPath = unixPathJoin(this.getHostJobRemoteDir(jobId), 'jobpid');
} else {
throw new Error(`Job type not supported: ${trialJobDetail.form.jobType}`);
}
return jobpidPath; return jobpidPath;
} }
...@@ -791,14 +740,6 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -791,14 +740,6 @@ class RemoteMachineTrainingService implements TrainingService {
await SSHClientUtility.copyFileToRemote(localFilepath, unixPathJoin(trialWorkingFolder, fileName), sshClient); await SSHClientUtility.copyFileToRemote(localFilepath, unixPathJoin(trialWorkingFolder, fileName), sshClient);
} }
private generateSequenceId(): number {
if (this.trialSequenceId === -1) {
this.trialSequenceId = getInitTrialSequenceId();
}
return this.trialSequenceId++;
}
} }
export { RemoteMachineTrainingService }; export { RemoteMachineTrainingService };
...@@ -76,7 +76,7 @@ describe('Unit Test for LocalTrainingService', () => { ...@@ -76,7 +76,7 @@ describe('Unit Test for LocalTrainingService', () => {
// submit job // submit job
const form: TrialJobApplicationForm = { const form: TrialJobApplicationForm = {
jobType: 'TRIAL', sequenceId: 0,
hyperParameters: { hyperParameters: {
value: 'mock hyperparameters', value: 'mock hyperparameters',
index: 0 index: 0
...@@ -95,7 +95,7 @@ describe('Unit Test for LocalTrainingService', () => { ...@@ -95,7 +95,7 @@ describe('Unit Test for LocalTrainingService', () => {
// submit job // submit job
const form: TrialJobApplicationForm = { const form: TrialJobApplicationForm = {
jobType: 'TRIAL', sequenceId: 0,
hyperParameters: { hyperParameters: {
value: 'mock hyperparameters', value: 'mock hyperparameters',
index: 0 index: 0
...@@ -121,4 +121,4 @@ describe('Unit Test for LocalTrainingService', () => { ...@@ -121,4 +121,4 @@ describe('Unit Test for LocalTrainingService', () => {
it('Test multiphaseSupported', () => { it('Test multiphaseSupported', () => {
chai.expect(localTrainingService.isMultiPhaseJobSupported).to.be.equals(true) chai.expect(localTrainingService.isMultiPhaseJobSupported).to.be.equals(true)
}) })
}); });
\ No newline at end of file
...@@ -24,6 +24,7 @@ import * as chaiAsPromised from 'chai-as-promised'; ...@@ -24,6 +24,7 @@ import * as chaiAsPromised from 'chai-as-promised';
import * as fs from 'fs'; import * as fs from 'fs';
import * as tmp from 'tmp'; import * as tmp from 'tmp';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { TrialJobApplicationForm } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils'; import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { PAITrainingService } from '../pai/paiTrainingService'; import { PAITrainingService } from '../pai/paiTrainingService';
...@@ -84,12 +85,16 @@ describe('Unit Test for PAITrainingService', () => { ...@@ -84,12 +85,16 @@ describe('Unit Test for PAITrainingService', () => {
console.log(`paiCluster is ${paiCluster}`) console.log(`paiCluster is ${paiCluster}`)
await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.PAI_CLUSTER_CONFIG, paiCluster); await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.PAI_CLUSTER_CONFIG, paiCluster);
await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, paiTrialConfig); await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, paiTrialConfig);
const form: TrialJobApplicationForm = {
sequenceId: 0,
hyperParameters: { value: '', index: 0 }
};
try { try {
const trialDetail = await paiTrainingService.submitTrialJob({jobType : 'TRIAL'}); const trialDetail = await paiTrainingService.submitTrialJob(form);
chai.expect(trialDetail.status).to.be.equals('WAITING'); chai.expect(trialDetail.status).to.be.equals('WAITING');
} catch(error) { } catch(error) {
console.log('Submit job failed:' + error); console.log('Submit job failed:' + error);
chai.assert(error) chai.assert(error)
} }
}); });
}); });
\ No newline at end of file
...@@ -99,11 +99,11 @@ describe('Unit Test for RemoteMachineTrainingService', () => { ...@@ -99,11 +99,11 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
await remoteMachineTrainingService.setClusterMetadata( await remoteMachineTrainingService.setClusterMetadata(
TrialConfigMetadataKey.TRIAL_CONFIG, `{"command":"sleep 1h && echo ","codeDir":"${localCodeDir}","gpuNum":1}`); TrialConfigMetadataKey.TRIAL_CONFIG, `{"command":"sleep 1h && echo ","codeDir":"${localCodeDir}","gpuNum":1}`);
const form: TrialJobApplicationForm = { const form: TrialJobApplicationForm = {
jobType: 'TRIAL', sequenceId: 0,
hyperParameters: { hyperParameters: {
value: 'mock hyperparameters', value: 'mock hyperparameters',
index: 0 index: 0
} }
}; };
const trialJob = await remoteMachineTrainingService.submitTrialJob(form); const trialJob = await remoteMachineTrainingService.submitTrialJob(form);
...@@ -137,7 +137,7 @@ describe('Unit Test for RemoteMachineTrainingService', () => { ...@@ -137,7 +137,7 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
// submit job // submit job
const form: TrialJobApplicationForm = { const form: TrialJobApplicationForm = {
jobType: 'TRIAL', sequenceId: 0,
hyperParameters: { hyperParameters: {
value: 'mock hyperparameters', value: 'mock hyperparameters',
index: 0 index: 0
......
from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .builtin_pruners import *
from .builtin_quantizers import *
import logging
import tensorflow as tf
from .compressor import Pruner
__all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ]
_logger = logging.getLogger(__name__)
class LevelPruner(Pruner):
def __init__(self, config_list):
"""
config_list: supported keys:
- sparsity
"""
super().__init__(config_list)
def calc_mask(self, weight, config, **kwargs):
threshold = tf.contrib.distributions.percentile(tf.abs(weight), config['sparsity'] * 100)
return tf.cast(tf.math.greater(tf.abs(weight), threshold), weight.dtype)
class AGP_Pruner(Pruner):
"""An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity.
Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf
"""
def __init__(self, config_list):
"""
config_list: supported keys:
- initial_sparsity
- final_sparsity: you should make sure initial_sparsity <= final_sparsity
- start_epoch: start epoch numer begin update mask
- end_epoch: end epoch number stop update mask
- frequency: if you want update every 2 epoch, you can set it 2
"""
super().__init__(config_list)
self.now_epoch = tf.Variable(0)
self.assign_handler = []
def calc_mask(self, weight, config, **kwargs):
target_sparsity = self.compute_target_sparsity(config)
threshold = tf.contrib.distributions.percentile(weight, target_sparsity * 100)
# stop gradient in case gradient change the mask
mask = tf.stop_gradient(tf.cast(tf.math.greater(weight, threshold), weight.dtype))
self.assign_handler.append(tf.assign(weight, weight * mask))
return mask
def compute_target_sparsity(self, config):
end_epoch = config.get('end_epoch', 1)
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
final_sparsity = config.get('final_sparsity', 0)
initial_sparsity = config.get('initial_sparsity', 0)
if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
_logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
return final_sparsity
now_epoch = tf.minimum(self.now_epoch, tf.constant(end_epoch))
span = int(((end_epoch - start_epoch-1)//freq)*freq)
assert span > 0
base = tf.cast(now_epoch - start_epoch, tf.float32) / span
target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity)*
(tf.pow(1.0 - base, 3)))
return target_sparsity
def update_epoch(self, epoch, sess):
sess.run(self.assign_handler)
sess.run(tf.assign(self.now_epoch, int(epoch)))
class SensitivityPruner(Pruner):
"""Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
https://arxiv.org/pdf/1506.02626v3.pdf
I.e.: "The pruning threshold is chosen as a quality parameter multiplied
by the standard deviation of a layers weights."
"""
def __init__(self, config_list):
"""
config_list: supported keys
- sparsity: chosen pruning sparsity
"""
super().__init__(config_list)
self.layer_mask = {}
self.assign_handler = []
def calc_mask(self, weight, config, op_name, **kwargs):
target_sparsity = config['sparsity'] * tf.math.reduce_std(weight)
mask = tf.get_variable(op_name + '_mask', initializer=tf.ones(weight.shape), trainable=False)
self.layer_mask[op_name] = mask
weight_assign_handler = tf.assign(weight, mask*weight)
# use control_dependencies so that weight_assign_handler will be executed before mask_update_handler
with tf.control_dependencies([weight_assign_handler]):
threshold = tf.contrib.distributions.percentile(weight, target_sparsity * 100)
# stop gradient in case gradient change the mask
new_mask = tf.stop_gradient(tf.cast(tf.math.greater(weight, threshold), weight.dtype))
mask_update_handler = tf.assign(mask, new_mask)
self.assign_handler.append(mask_update_handler)
return mask
def update_epoch(self, epoch, sess):
sess.run(self.assign_handler)
import logging
import tensorflow as tf
from .compressor import Quantizer
__all__ = [ 'NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer' ]
_logger = logging.getLogger(__name__)
class NaiveQuantizer(Quantizer):
"""quantize weight to 8 bits
"""
def __init__(self, config_list):
super().__init__(config_list)
self.layer_scale = { }
def quantize_weight(self, weight, config, op_name, **kwargs):
new_scale = tf.reduce_max(tf.abs(weight)) / 127
scale = tf.maximum(self.layer_scale.get(op_name, tf.constant(0.0)), new_scale)
self.layer_scale[op_name] = scale
orig_type = weight.dtype
return tf.cast(tf.cast(weight / scale, tf.int8), orig_type) * scale
class QAT_Quantizer(Quantizer):
"""Quantizer using the DoReFa scheme, as defined in:
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""
def __init__(self, config_list):
"""
config_list: supported keys:
- q_bits
"""
super().__init__(config_list)
def quantize_weight(self, weight, config, **kwargs):
a = tf.stop_gradient(tf.reduce_min(weight))
b = tf.stop_gradient(tf.reduce_max(weight))
n = tf.cast(2 ** config['q_bits'], tf.float32)
scale = b-a/(n-1)
# use gradient_override_map to change round to idetity for gradient
with tf.get_default_graph().gradient_override_map({'Round': 'Identity'}):
qw = tf.round((weight-a)/scale)*scale +a
return qw
class DoReFaQuantizer(Quantizer):
"""Quantizer using the DoReFa scheme, as defined in:
Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
(https://arxiv.org/abs/1606.06160)
"""
def __init__(self, config_list):
"""
config_list: supported keys:
- q_bits
"""
super().__init__(config_list)
def quantize_weight(self, weight, config, **kwargs):
a = tf.math.tanh(weight)
b = a/(2*tf.reduce_max(tf.abs(weight))) + 0.5
scale = pow(2, config['q_bits'] - 1)
# use gradient_override_map to change round to idetity for gradient
with tf.get_default_graph().gradient_override_map({'Round': 'Identity'}):
qw = tf.round(b*scale)/scale
r_qw = 2 * qw - 1
return r_qw
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment