Unverified Commit ca99000d authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Designated gpu devices for NNI trial jobs (#991)

* Refactoring local training service
* Designated GPU for local training service
* RemoteMachine designated GPU configuration
parent cbad7acb
...@@ -392,6 +392,13 @@ machineList: ...@@ -392,6 +392,13 @@ machineList:
__image__ set the image to be used in __worker__. __image__ set the image to be used in __worker__.
* __localConfig__
__localConfig__ is applicable only if __trainingServicePlatform__ is set to ```local```; otherwise there should be no __localConfig__ section in the configuration file.
* __gpuIndices__
__gpuIndices__ is used to specify designated GPU devices for NNI. If it is set, only the specified GPU devices are used for NNI trial jobs. A single GPU index or multiple GPU indices can be specified; multiple GPU indices are separated by commas (,), such as ```1``` or ```0,1,3```.
* __machineList__ * __machineList__
__machineList__ should be set if __trainingServicePlatform__ is set to remote, or it should be empty. __machineList__ should be set if __trainingServicePlatform__ is set to remote, or it should be empty.
...@@ -422,6 +429,10 @@ machineList: ...@@ -422,6 +429,10 @@ machineList:
__passphrase__ is used to protect ssh key, which could be empty if users don't have passphrase. __passphrase__ is used to protect ssh key, which could be empty if users don't have passphrase.
* __gpuIndices__
__gpuIndices__ is used to specify designated GPU devices for NNI on this remote machine. If it is set, only the specified GPU devices are used for NNI trial jobs. A single GPU index or multiple GPU indices can be specified; multiple GPU indices are separated by commas (,), such as ```1``` or ```0,1,3```.
* __kubeflowConfig__: * __kubeflowConfig__:
* __operator__ * __operator__
......
...@@ -21,27 +21,29 @@ ...@@ -21,27 +21,29 @@
import { Container, Scope } from 'typescript-ioc'; import { Container, Scope } from 'typescript-ioc';
import * as component from './common/component';
import * as fs from 'fs'; import * as fs from 'fs';
import * as component from './common/component';
import { Database, DataStore } from './common/datastore'; import { Database, DataStore } from './common/datastore';
import { setExperimentStartupInfo } from './common/experimentStartupInfo'; import { setExperimentStartupInfo } from './common/experimentStartupInfo';
import { getLogger, Logger, logLevelNameMap } from './common/log'; import { getLogger, Logger, logLevelNameMap } from './common/log';
import { Manager } from './common/manager'; import { Manager } from './common/manager';
import { TrainingService } from './common/trainingService'; import { TrainingService } from './common/trainingService';
import { parseArg, uniqueString, mkDirP, getLogDir } from './common/utils'; import { getLogDir, mkDirP, parseArg, uniqueString } from './common/utils';
import { NNIDataStore } from './core/nniDataStore'; import { NNIDataStore } from './core/nniDataStore';
import { NNIManager } from './core/nnimanager'; import { NNIManager } from './core/nnimanager';
import { SqlDB } from './core/sqlDatabase'; import { SqlDB } from './core/sqlDatabase';
import { NNIRestServer } from './rest_server/nniRestServer'; import { NNIRestServer } from './rest_server/nniRestServer';
import { LocalTrainingServiceForGPU } from './training_service/local/localTrainingServiceForGPU'; import { FrameworkControllerTrainingService } from './training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService';
import { KubeflowTrainingService } from './training_service/kubernetes/kubeflow/kubeflowTrainingService';
import { LocalTrainingService } from './training_service/local/localTrainingService';
import { PAITrainingService } from './training_service/pai/paiTrainingService';
import { import {
RemoteMachineTrainingService RemoteMachineTrainingService
} from './training_service/remote_machine/remoteMachineTrainingService'; } from './training_service/remote_machine/remoteMachineTrainingService';
import { PAITrainingService } from './training_service/pai/paiTrainingService';
import { KubeflowTrainingService } from './training_service/kubernetes/kubeflow/kubeflowTrainingService';
import { FrameworkControllerTrainingService } from './training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService';
function initStartupInfo(startExpMode: string, resumeExperimentId: string, basePort: number, logDirectory: string, experimentLogLevel: string) { function initStartupInfo(
startExpMode: string, resumeExperimentId: string, basePort: number,
logDirectory: string, experimentLogLevel: string): void {
const createNew: boolean = (startExpMode === 'new'); const createNew: boolean = (startExpMode === 'new');
const expId: string = createNew ? uniqueString(8) : resumeExperimentId; const expId: string = createNew ? uniqueString(8) : resumeExperimentId;
setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel); setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel);
...@@ -49,29 +51,45 @@ function initStartupInfo(startExpMode: string, resumeExperimentId: string, baseP ...@@ -49,29 +51,45 @@ function initStartupInfo(startExpMode: string, resumeExperimentId: string, baseP
async function initContainer(platformMode: string): Promise<void> { async function initContainer(platformMode: string): Promise<void> {
if (platformMode === 'local') { if (platformMode === 'local') {
Container.bind(TrainingService).to(LocalTrainingServiceForGPU).scope(Scope.Singleton); Container.bind(TrainingService)
.to(LocalTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'remote') { } else if (platformMode === 'remote') {
Container.bind(TrainingService).to(RemoteMachineTrainingService).scope(Scope.Singleton); Container.bind(TrainingService)
.to(RemoteMachineTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'pai') { } else if (platformMode === 'pai') {
Container.bind(TrainingService).to(PAITrainingService).scope(Scope.Singleton); Container.bind(TrainingService)
.to(PAITrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'kubeflow') { } else if (platformMode === 'kubeflow') {
Container.bind(TrainingService).to(KubeflowTrainingService).scope(Scope.Singleton); Container.bind(TrainingService)
.to(KubeflowTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'frameworkcontroller') { } else if (platformMode === 'frameworkcontroller') {
Container.bind(TrainingService).to(FrameworkControllerTrainingService).scope(Scope.Singleton); Container.bind(TrainingService)
} .to(FrameworkControllerTrainingService)
else { .scope(Scope.Singleton);
} else {
throw new Error(`Error: unsupported mode: ${mode}`); throw new Error(`Error: unsupported mode: ${mode}`);
} }
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton); Container.bind(Manager)
Container.bind(Database).to(SqlDB).scope(Scope.Singleton); .to(NNIManager)
Container.bind(DataStore).to(NNIDataStore).scope(Scope.Singleton); .scope(Scope.Singleton);
Container.bind(Database)
.to(SqlDB)
.scope(Scope.Singleton);
Container.bind(DataStore)
.to(NNIDataStore)
.scope(Scope.Singleton);
const ds: DataStore = component.get(DataStore); const ds: DataStore = component.get(DataStore);
await ds.init(); await ds.init();
} }
function usage(): void { function usage(): void {
console.info('usage: node main.js --port <port> --mode <local/remote/pai/kubeflow/frameworkcontroller> --start_mode <new/resume> --experiment_id <id>'); console.info('usage: node main.js --port <port> --mode \
<local/remote/pai/kubeflow/frameworkcontroller> --start_mode <new/resume> --experiment_id <id>');
} }
const strPort: string = parseArg(['--port', '-p']); const strPort: string = parseArg(['--port', '-p']);
...@@ -117,7 +135,8 @@ if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) { ...@@ -117,7 +135,8 @@ if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) {
initStartupInfo(startMode, experimentId, port, logDir, logLevel); initStartupInfo(startMode, experimentId, port, logDir, logLevel);
mkDirP(getLogDir()).then(async () => { mkDirP(getLogDir())
.then(async () => {
const log: Logger = getLogger(); const log: Logger = getLogger();
try { try {
await initContainer(mode); await initContainer(mode);
...@@ -128,25 +147,26 @@ mkDirP(getLogDir()).then(async () => { ...@@ -128,25 +147,26 @@ mkDirP(getLogDir()).then(async () => {
log.error(`${err.stack}`); log.error(`${err.stack}`);
throw err; throw err;
} }
}).catch((err: Error) => { })
.catch((err: Error) => {
console.error(`Failed to create log dir: ${err.stack}`); console.error(`Failed to create log dir: ${err.stack}`);
}); });
process.on('SIGTERM', async () => { process.on('SIGTERM', async () => {
const log: Logger = getLogger(); const log: Logger = getLogger();
let hasError: boolean = false; let hasError: boolean = false;
try{ try {
const nniManager: Manager = component.get(Manager); const nniManager: Manager = component.get(Manager);
await nniManager.stopExperiment(); await nniManager.stopExperiment();
const ds: DataStore = component.get(DataStore); const ds: DataStore = component.get(DataStore);
await ds.close(); await ds.close();
const restServer: NNIRestServer = component.get(NNIRestServer); const restServer: NNIRestServer = component.get(NNIRestServer);
await restServer.stop(); await restServer.stop();
}catch(err){ } catch (err) {
hasError = true; hasError = true;
log.error(`${err.stack}`); log.error(`${err.stack}`);
}finally{ } finally {
await log.close(); await log.close();
process.exit(hasError?1:0); process.exit(hasError ? 1 : 0);
} }
}) });
\ No newline at end of file
...@@ -30,8 +30,12 @@ export namespace ValidationSchemas { ...@@ -30,8 +30,12 @@ export namespace ValidationSchemas {
port: joi.number().min(1).max(65535).required(), port: joi.number().min(1).max(65535).required(),
passwd: joi.string(), passwd: joi.string(),
sshKeyPath: joi.string(), sshKeyPath: joi.string(),
passphrase: joi.string() passphrase: joi.string(),
gpuIndices: joi.string()
})), })),
local_config: joi.object({
gpuIndices: joi.string()
}),
trial_config: joi.object({ trial_config: joi.object({
image: joi.string().min(1), image: joi.string().min(1),
codeDir: joi.string().min(1).required(), codeDir: joi.string().min(1).required(),
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
*/ */
export enum TrialConfigMetadataKey { export enum TrialConfigMetadataKey {
MACHINE_LIST = 'machine_list', MACHINE_LIST = 'machine_list',
LOCAL_CONFIG = 'local_config',
TRIAL_CONFIG = 'trial_config', TRIAL_CONFIG = 'trial_config',
EXPERIMENT_ID = 'experimentId', EXPERIMENT_ID = 'experimentId',
MULTI_PHASE = 'multiPhase', MULTI_PHASE = 'multiPhase',
......
...@@ -19,19 +19,18 @@ ...@@ -19,19 +19,18 @@
'use strict'; 'use strict';
import { delay } from '../../common/utils';
import { GPUInfo, GPUSummary } from '../common/gpuData';
import { getLogger, Logger } from '../../common/log';
import * as cp from 'child_process';
import * as cpp from 'child-process-promise'; import * as cpp from 'child-process-promise';
import * as path from 'path'; import * as cp from 'child_process';
import * as os from 'os';
import * as fs from 'fs'; import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import { GPU_INFO_COLLECTOR_FORMAT } from '../common/gpuData' import { getLogger, Logger } from '../../common/log';
import { delay } from '../../common/utils';
import { GPU_INFO_COLLECTOR_FORMAT, GPUInfo, GPUSummary } from '../common/gpuData';
/** /**
* GPUScheduler * GPUScheduler for local training service
*/ */
class GPUScheduler { class GPUScheduler {
...@@ -58,45 +57,55 @@ class GPUScheduler { ...@@ -58,45 +57,55 @@ class GPUScheduler {
} }
} }
/**
 * Prepare and launch the GPU metrics collector script.
 *
 * Creates the collector folder, renders GPU_INFO_COLLECTOR_FORMAT into
 * gpu_metrics_collector.sh inside that folder, and starts the script with bash.
 * The bash process is started through cp.exec and is not awaited.
 */
private async runGpuMetricsCollectorScript(): Promise<void> {
    const collectorFolder: string = this.gpuMetricCollectorScriptFolder;
    await cpp.exec(`mkdir -p ${collectorFolder}`);

    // The format string receives the collector folder and the path of a 'pid' file in it.
    const pidFilePath: string = path.join(collectorFolder, 'pid');
    const scriptContent: string = String.Format(GPU_INFO_COLLECTOR_FORMAT, collectorFolder, pidFilePath);

    // Write gpu_metrics_collector.sh before starting it.
    const scriptPath: string = path.join(collectorFolder, 'gpu_metrics_collector.sh');
    await fs.promises.writeFile(scriptPath, scriptContent, { encoding: 'utf8' });

    cp.exec(`bash ${scriptPath}`);
}
public getAvailableGPUIndices(): number[] { public getAvailableGPUIndices(): number[] {
if (this.gpuSummary !== undefined) { if (this.gpuSummary !== undefined) {
return this.gpuSummary.gpuInfos.filter((info: GPUInfo) => info.activeProcessNum === 0).map((info: GPUInfo) => info.index); return this.gpuSummary.gpuInfos.filter((info: GPUInfo) => info.activeProcessNum === 0)
.map((info: GPUInfo) => info.index);
} }
return []; return [];
} }
public async stop() { public getSystemGpuCount(): number {
if (this.gpuSummary !== undefined) {
return this.gpuSummary.gpuCount;
}
return 0;
}
public async stop(): Promise<void> {
this.stopping = true; this.stopping = true;
try { try {
const pid: string = await fs.promises.readFile(path.join(this.gpuMetricCollectorScriptFolder, 'pid'), 'utf8'); const pid: string = await fs.promises.readFile(path.join(this.gpuMetricCollectorScriptFolder, 'pid'), 'utf8');
await cpp.exec(`pkill -P ${pid}`); await cpp.exec(`pkill -P ${pid}`);
await cpp.exec(`rm -rf ${this.gpuMetricCollectorScriptFolder}`); await cpp.exec(`rm -rf ${this.gpuMetricCollectorScriptFolder}`);
} catch (error){ } catch (error) {
this.log.error(`GPU scheduler error: ${error}`); this.log.error(`GPU scheduler error: ${error}`);
} }
} }
private async updateGPUSummary() { /**
const cmdresult = await cpp.exec(`tail -n 1 ${path.join(this.gpuMetricCollectorScriptFolder, 'gpu_metrics')}`); * Generate gpu metric collector shell script in local machine,
if(cmdresult && cmdresult.stdout) { * used to run in remote machine, and will be deleted after uploaded from local.
*/
private async runGpuMetricsCollectorScript(): Promise<void> {
await cpp.exec(`mkdir -p ${this.gpuMetricCollectorScriptFolder}`);
//generate gpu_metrics_collector.sh
const gpuMetricsCollectorScriptPath: string = path.join(this.gpuMetricCollectorScriptFolder, 'gpu_metrics_collector.sh');
const gpuMetricsCollectorScriptContent: string = String.Format(
GPU_INFO_COLLECTOR_FORMAT,
this.gpuMetricCollectorScriptFolder,
path.join(this.gpuMetricCollectorScriptFolder, 'pid')
);
await fs.promises.writeFile(gpuMetricsCollectorScriptPath, gpuMetricsCollectorScriptContent, { encoding: 'utf8' });
cp.exec(`bash ${gpuMetricsCollectorScriptPath}`);
}
private async updateGPUSummary(): Promise<void> {
const cmdresult: cpp.childProcessPromise.Result =
await cpp.exec(`tail -n 1 ${path.join(this.gpuMetricCollectorScriptFolder, 'gpu_metrics')}`);
if (cmdresult && cmdresult.stdout) {
this.gpuSummary = <GPUSummary>JSON.parse(cmdresult.stdout); this.gpuSummary = <GPUSummary>JSON.parse(cmdresult.stdout);
} else { } else {
this.log.error('Could not get gpu metrics information!'); this.log.error('Could not get gpu metrics information!');
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
'use strict'; 'use strict';
import * as assert from 'assert';
import * as cpp from 'child-process-promise'; import * as cpp from 'child-process-promise';
import * as cp from 'child_process'; import * as cp from 'child_process';
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
...@@ -27,15 +26,16 @@ import * as fs from 'fs'; ...@@ -27,15 +26,16 @@ import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
import * as ts from 'tail-stream'; import * as ts from 'tail-stream';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getLogger, Logger } from '../../common/log';
import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { getInitTrialSequenceId } from '../../common/experimentStartupInfo'; import { getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import { import {
HostJobApplicationForm, JobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm, HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString, getJobCancelStatus } from '../../common/utils'; import { delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, uniqueString } from '../../common/utils';
import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { GPUScheduler } from './gpuScheduler';
const tkill = require('tree-kill'); const tkill = require('tree-kill');
...@@ -46,6 +46,7 @@ const tkill = require('tree-kill'); ...@@ -46,6 +46,7 @@ const tkill = require('tree-kill');
* success: true if the buffer contains at least one complete command; otherwise false * success: true if the buffer contains at least one complete command; otherwise false
* remain: remaining data after the first command * remain: remaining data after the first command
*/ */
// tslint:disable-next-line:informative-docs
function decodeCommand(data: Buffer): [boolean, string, string, Buffer] { function decodeCommand(data: Buffer): [boolean, string, string, Buffer] {
if (data.length < 8) { if (data.length < 8) {
return [false, '', '', data]; return [false, '', '', data];
...@@ -76,8 +77,10 @@ class LocalTrialJobDetail implements TrialJobDetail { ...@@ -76,8 +77,10 @@ class LocalTrialJobDetail implements TrialJobDetail {
public form: JobApplicationForm; public form: JobApplicationForm;
public sequenceId: number; public sequenceId: number;
public pid?: number; public pid?: number;
public gpuIndices?: number[];
constructor(id: string, status: TrialJobStatus, submitTime: number, constructor(
id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: JobApplicationForm, sequenceId: number) { workingDirectory: string, form: JobApplicationForm, sequenceId: number) {
this.id = id; this.id = id;
this.status = status; this.status = status;
...@@ -86,6 +89,19 @@ class LocalTrialJobDetail implements TrialJobDetail { ...@@ -86,6 +89,19 @@ class LocalTrialJobDetail implements TrialJobDetail {
this.form = form; this.form = form;
this.url = `file://localhost:${workingDirectory}`; this.url = `file://localhost:${workingDirectory}`;
this.sequenceId = sequenceId; this.sequenceId = sequenceId;
this.gpuIndices = [];
}
}
/**
* Local training service config.
*
* gpuIndices is kept as the raw string (a single index, or several joined by commas);
* the training service splits it on ',' when the local_config metadata is set.
*/
class LocalConfig {
// Raw GPU index list; left unset when no value is supplied.
public gpuIndices?: string;
constructor(gpuIndices?: string) {
// Assign only when a value was actually passed in.
if (gpuIndices !== undefined) {
this.gpuIndices = gpuIndices;
}
} }
} }
...@@ -100,10 +116,14 @@ class LocalTrainingService implements TrainingService { ...@@ -100,10 +116,14 @@ class LocalTrainingService implements TrainingService {
private stopping: boolean; private stopping: boolean;
private rootDir!: string; private rootDir!: string;
private trialSequenceId: number; private trialSequenceId: number;
protected log: Logger; private gpuScheduler!: GPUScheduler;
protected localTrailConfig?: TrialConfig; private occupiedGpuIndices: Set<number>;
private designatedGpuIndices!: Set<number>;
private log: Logger;
private localTrailConfig?: TrialConfig;
private localConfig?: LocalConfig;
private isMultiPhase: boolean = false; private isMultiPhase: boolean = false;
protected jobStreamMap: Map<string, ts.Stream>; private jobStreamMap: Map<string, ts.Stream>;
constructor() { constructor() {
this.eventEmitter = new EventEmitter(); this.eventEmitter = new EventEmitter();
...@@ -115,26 +135,16 @@ class LocalTrainingService implements TrainingService { ...@@ -115,26 +135,16 @@ class LocalTrainingService implements TrainingService {
this.trialSequenceId = -1; this.trialSequenceId = -1;
this.jobStreamMap = new Map<string, ts.Stream>(); this.jobStreamMap = new Map<string, ts.Stream>();
this.log.info('Construct local machine training service.'); this.log.info('Construct local machine training service.');
this.occupiedGpuIndices = new Set<number>();
} }
public async run(): Promise<void> { public async run(): Promise<void> {
this.log.info('Run local machine training service.'); this.log.info('Run local machine training service.');
while (!this.stopping) { const longRunningTasks: Promise<void>[] = [this.runJobLoop()];
while (this.jobQueue.length !== 0) { if (this.gpuScheduler !== undefined) {
const trialJobId: string = this.jobQueue[0]; longRunningTasks.push(this.gpuScheduler.run());
const trialJobDeatil = this.jobMap.get(trialJobId)
if (trialJobDeatil !== undefined && trialJobDeatil.status === 'WAITING'){
const [success, resource] = this.tryGetAvailableResource();
if (!success) {
break;
}
this.occupyResource(resource);
await this.runTrialJob(trialJobId, resource);
}
this.jobQueue.shift();
}
await delay(5000);
} }
await Promise.all(longRunningTasks);
this.log.info('Local machine training service exit.'); this.log.info('Local machine training service exit.');
} }
...@@ -172,7 +182,8 @@ class LocalTrainingService implements TrainingService { ...@@ -172,7 +182,8 @@ class LocalTrainingService implements TrainingService {
this.setTrialJobStatus(trialJob, 'FAILED'); this.setTrialJobStatus(trialJob, 'FAILED');
try { try {
const state: string = await fs.promises.readFile(path.join(trialJob.workingDirectory, '.nni', 'state'), 'utf8'); const state: string = await fs.promises.readFile(path.join(trialJob.workingDirectory, '.nni', 'state'), 'utf8');
const match: RegExpMatchArray | null = state.trim().match(/^(\d+)\s+(\d+)/); const match: RegExpMatchArray | null = state.trim()
.match(/^(\d+)\s+(\d+)/);
if (match !== null) { if (match !== null) {
const { 1: code, 2: timestamp } = match; const { 1: code, 2: timestamp } = match;
if (parseInt(code, 10) === 0) { if (parseInt(code, 10) === 0) {
...@@ -253,8 +264,9 @@ class LocalTrainingService implements TrainingService { ...@@ -253,8 +264,9 @@ class LocalTrainingService implements TrainingService {
if (trialJob === undefined) { if (trialJob === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, 'Trial job not found'); throw new NNIError(NNIErrorNames.NOT_FOUND, 'Trial job not found');
} }
if (trialJob.pid === undefined){ if (trialJob.pid === undefined) {
this.setTrialJobStatus(trialJob, 'USER_CANCELED'); this.setTrialJobStatus(trialJob, 'USER_CANCELED');
return Promise.resolve(); return Promise.resolve();
} }
if (trialJob.form.jobType === 'TRIAL') { if (trialJob.form.jobType === 'TRIAL') {
...@@ -265,6 +277,7 @@ class LocalTrainingService implements TrainingService { ...@@ -265,6 +277,7 @@ class LocalTrainingService implements TrainingService {
throw new Error(`Job type not supported: ${trialJob.form.jobType}`); throw new Error(`Job type not supported: ${trialJob.form.jobType}`);
} }
this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped)); this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped));
return Promise.resolve(); return Promise.resolve();
} }
...@@ -281,6 +294,21 @@ class LocalTrainingService implements TrainingService { ...@@ -281,6 +294,21 @@ class LocalTrainingService implements TrainingService {
if (!this.localTrailConfig) { if (!this.localTrailConfig) {
throw new Error('trial config parsed failed'); throw new Error('trial config parsed failed');
} }
this.log.info(`required GPU number is ${this.localTrailConfig.gpuNum}`);
if (this.gpuScheduler === undefined && this.localTrailConfig.gpuNum > 0) {
this.gpuScheduler = new GPUScheduler();
}
break;
case TrialConfigMetadataKey.LOCAL_CONFIG:
this.localConfig = <LocalConfig>JSON.parse(value);
this.log.info(`Specified GPU indices: ${this.localConfig.gpuIndices}`);
if (this.localConfig.gpuIndices !== undefined) {
this.designatedGpuIndices = new Set(this.localConfig.gpuIndices.split(',')
.map((x: string) => parseInt(x, 10)));
if (this.designatedGpuIndices.size === 0) {
throw new Error('gpuIndices can not be empty if specified.');
}
}
break; break;
case TrialConfigMetadataKey.MULTI_PHASE: case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True'); this.isMultiPhase = (value === 'true' || value === 'True');
...@@ -298,37 +326,51 @@ class LocalTrainingService implements TrainingService { ...@@ -298,37 +326,51 @@ class LocalTrainingService implements TrainingService {
} else { } else {
getResult = Promise.resolve(!this.localTrailConfig ? '' : JSON.stringify(this.localTrailConfig)); getResult = Promise.resolve(!this.localTrailConfig ? '' : JSON.stringify(this.localTrailConfig));
} }
return getResult; return getResult;
default: default:
return Promise.reject(new NNIError(NNIErrorNames.NOT_FOUND, 'Key not found')); return Promise.reject(new NNIError(NNIErrorNames.NOT_FOUND, 'Key not found'));
} }
} }
public cleanUp(): Promise<void> { public async cleanUp(): Promise<void> {
this.log.info('Stopping local machine training service...'); this.log.info('Stopping local machine training service...');
this.stopping = true; this.stopping = true;
for (const stream of this.jobStreamMap.values()) { for (const stream of this.jobStreamMap.values()) {
stream.destroy(); stream.destroy();
} }
if (this.gpuScheduler !== undefined) {
await this.gpuScheduler.stop();
}
return Promise.resolve(); return Promise.resolve();
} }
protected onTrialJobStatusChanged(trialJob: TrialJobDetail, oldStatus: TrialJobStatus): void { private onTrialJobStatusChanged(trialJob: LocalTrialJobDetail, oldStatus: TrialJobStatus): void {
//if job is not running, destory job stream //if job is not running, destory job stream
if(['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'].includes(trialJob.status)) { if (['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'].includes(trialJob.status)) {
if(this.jobStreamMap.has(trialJob.id)) { if (this.jobStreamMap.has(trialJob.id)) {
const stream = this.jobStreamMap.get(trialJob.id); const stream: ts.Stream | undefined = this.jobStreamMap.get(trialJob.id);
if(!stream) { if (!stream) {
throw new Error(`Could not find stream in trial ${trialJob.id}`); throw new Error(`Could not find stream in trial ${trialJob.id}`);
} }
stream.destroy(); stream.destroy();
this.jobStreamMap.delete(trialJob.id); this.jobStreamMap.delete(trialJob.id);
} }
} }
if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length > 0 && this.gpuScheduler !== undefined) {
if (oldStatus === 'RUNNING' && trialJob.status !== 'RUNNING') {
for (const index of trialJob.gpuIndices) {
this.occupiedGpuIndices.delete(index);
}
}
}
} }
protected getEnvironmentVariables(trialJobDetail: TrialJobDetail, _: {}): { key: string; value: string }[] { private getEnvironmentVariables(
return [ trialJobDetail: TrialJobDetail,
resource?: { gpuIndices: number[] }): { key: string; value: string }[] {
const envVariables: { key: string; value: string }[] = [
{ key: 'NNI_PLATFORM', value: 'local' }, { key: 'NNI_PLATFORM', value: 'local' },
{ key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory }, { key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory },
{ key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id }, { key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id },
...@@ -336,18 +378,85 @@ class LocalTrainingService implements TrainingService { ...@@ -336,18 +378,85 @@ class LocalTrainingService implements TrainingService {
{ key: 'NNI_TRIAL_SEQ_ID', value: trialJobDetail.sequenceId.toString() }, { key: 'NNI_TRIAL_SEQ_ID', value: trialJobDetail.sequenceId.toString() },
{ key: 'MULTI_PHASE', value: this.isMultiPhase.toString() } { key: 'MULTI_PHASE', value: this.isMultiPhase.toString() }
]; ];
if (resource !== undefined && resource.gpuIndices.length > 0) {
envVariables.push({
key: 'CUDA_VISIBLE_DEVICES',
value: this.gpuScheduler === undefined ? '' : resource.gpuIndices.join(',')
});
}
return envVariables;
}
/**
 * Record on the trial job detail the GPU indices carried by the given resource.
 */
private setExtraProperties(trialJobDetail: LocalTrialJobDetail, resource: { gpuIndices: number[] }): void {
    const { gpuIndices } = resource;
    trialJobDetail.gpuIndices = gpuIndices;
}
/**
* Try to select GPU indices for the next trial job.
*
* Returns a [success, resource] pair. resource.gpuIndices is empty when no GPU
* scheduler exists or when too few GPUs are currently selectable; otherwise it
* holds exactly localTrailConfig.gpuNum indices.
* Throws if localTrailConfig has not been set.
*/
private tryGetAvailableResource(): [boolean, { gpuIndices: number[]}] {
if (this.localTrailConfig === undefined) {
throw new Error('localTrailConfig is not initialized!');
}
const resource: { gpuIndices: number[] } = { gpuIndices: [] };
// Without a GPU scheduler there is nothing to reserve: succeed with no GPU indices.
if (this.gpuScheduler === undefined) {
return [true, resource];
}
// Start from the indices the scheduler reports as available, minus the ones already occupied by this service.
let selectedGPUIndices: number[] = this.gpuScheduler.getAvailableGPUIndices()
.filter((index: number) => !this.occupiedGpuIndices.has(index));
// When designated indices are configured, validate them and keep only candidates in that set.
if (this.designatedGpuIndices !== undefined) {
this.checkSpecifiedGpuIndices();
selectedGPUIndices = selectedGPUIndices.filter((index: number) => this.designatedGpuIndices.has(index));
}
// Not enough candidates for the requested gpuNum: report failure with an empty resource.
if (selectedGPUIndices.length < this.localTrailConfig.gpuNum) {
return [false, resource];
}
// Keep only the first gpuNum candidates.
selectedGPUIndices.splice(this.localTrailConfig.gpuNum);
Object.assign(resource, { gpuIndices: selectedGPUIndices });
return [true, resource];
} }
protected setExtraProperties(trialJobDetail: TrialJobDetail, resource: {}): void { private checkSpecifiedGpuIndices(): void {
//abstract const gpuCount: number = this.gpuScheduler.getSystemGpuCount();
if (this.designatedGpuIndices !== undefined) {
for (const index of this.designatedGpuIndices) {
if (index >= gpuCount) {
throw new Error(`Specified GPU index not found: ${index}`);
}
}
}
} }
protected tryGetAvailableResource(): [boolean, {}] { private occupyResource(resource: {gpuIndices: number[]}): void {
return [true, {}]; if (this.gpuScheduler !== undefined) {
for (const index of resource.gpuIndices) {
this.occupiedGpuIndices.add(index);
}
}
} }
protected occupyResource(_: {}): void { private async runJobLoop(): Promise<void> {
//abstract while (!this.stopping) {
while (!this.stopping && this.jobQueue.length !== 0) {
const trialJobId: string = this.jobQueue[0];
const trialJobDeatil: LocalTrialJobDetail | undefined = this.jobMap.get(trialJobId);
if (trialJobDeatil !== undefined && trialJobDeatil.status === 'WAITING') {
const [success, resource] = this.tryGetAvailableResource();
if (!success) {
break;
}
this.occupyResource(resource);
await this.runTrialJob(trialJobId, resource);
}
this.jobQueue.shift();
}
await delay(5000);
}
} }
private setTrialJobStatus(trialJob: LocalTrialJobDetail, newStatus: TrialJobStatus): void { private setTrialJobStatus(trialJob: LocalTrialJobDetail, newStatus: TrialJobStatus): void {
...@@ -358,7 +467,7 @@ class LocalTrainingService implements TrainingService { ...@@ -358,7 +467,7 @@ class LocalTrainingService implements TrainingService {
} }
} }
private async runTrialJob(trialJobId: string, resource: {}): Promise<void> { private async runTrialJob(trialJobId: string, resource: {gpuIndices: number[]}): Promise<void> {
const trialJobDetail: LocalTrialJobDetail = <LocalTrialJobDetail>this.jobMap.get(trialJobId); const trialJobDetail: LocalTrialJobDetail = <LocalTrialJobDetail>this.jobMap.get(trialJobId);
const variables: { key: string; value: string }[] = this.getEnvironmentVariables(trialJobDetail, resource); const variables: { key: string; value: string }[] = this.getEnvironmentVariables(trialJobDetail, resource);
...@@ -380,7 +489,8 @@ class LocalTrainingService implements TrainingService { ...@@ -380,7 +489,8 @@ class LocalTrainingService implements TrainingService {
await cpp.exec(`mkdir -p ${trialJobDetail.workingDirectory}`); await cpp.exec(`mkdir -p ${trialJobDetail.workingDirectory}`);
await cpp.exec(`mkdir -p ${path.join(trialJobDetail.workingDirectory, '.nni')}`); await cpp.exec(`mkdir -p ${path.join(trialJobDetail.workingDirectory, '.nni')}`);
await cpp.exec(`touch ${path.join(trialJobDetail.workingDirectory, '.nni', 'metrics')}`); await cpp.exec(`touch ${path.join(trialJobDetail.workingDirectory, '.nni', 'metrics')}`);
await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, 'run.sh'), runScriptLines.join('\n'), { encoding: 'utf8', mode: 0o777 }); await fs.promises.writeFile(
path.join(trialJobDetail.workingDirectory, 'run.sh'), runScriptLines.join('\n'), { encoding: 'utf8', mode: 0o777 });
await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters); await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters);
const process: cp.ChildProcess = cp.exec(`bash ${path.join(trialJobDetail.workingDirectory, 'run.sh')}`); const process: cp.ChildProcess = cp.exec(`bash ${path.join(trialJobDetail.workingDirectory, 'run.sh')}`);
...@@ -406,7 +516,7 @@ class LocalTrainingService implements TrainingService { ...@@ -406,7 +516,7 @@ class LocalTrainingService implements TrainingService {
buffer = remain; buffer = remain;
} }
}); });
this.jobStreamMap.set(trialJobDetail.id, stream); this.jobStreamMap.set(trialJobDetail.id, stream);
} }
......
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import { TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUScheduler } from './gpuScheduler';
import { LocalTrainingService } from './localTrainingService';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
// Trial job detail extended with the indices of the GPUs assigned to that trial
type LocalTrialJobDetailForGPU = TrialJobDetail & { gpuIndices: number[] };
/**
 * Local training service that assigns GPUs to trial jobs.
 *
 * A GPUScheduler is created the first time a trial config with gpuNum > 0 is received.
 * While a scheduler exists, a trial is only granted resources when at least
 * requiredGPUNum of the indices returned by the scheduler are not already occupied by
 * another trial of this service; the granted indices are exported to the trial through
 * the CUDA_VISIBLE_DEVICES environment variable.
 */
class LocalTrainingServiceForGPU extends LocalTrainingService {
    // Number of GPUs each trial needs, read from the trial config's gpuNum
    private requiredGPUNum!: number;
    // Stays undefined until a trial config requesting at least one GPU is set
    private gpuScheduler!: GPUScheduler;
    // Indices of the GPUs currently handed out to trials by this service.
    // A set (rather than a fixed-length flag array) so that any GPU index can be tracked.
    private readonly occupiedGPUIndices: Set<number>;

    constructor() {
        super();
        this.occupiedGPUIndices = new Set<number>();
    }

    /**
     * Run the training service, together with the GPU scheduler when one exists.
     */
    public async run(): Promise<void> {
        // NOTE(review): the scheduler is only started if it already exists at this point;
        // one created by a later setClusterMetadata call is not started from here.
        if (this.gpuScheduler !== undefined) {
            await Promise.all([
                this.gpuScheduler.run(),
                super.run()
            ]);
        } else {
            await super.run();
        }
    }

    /**
     * Forward the metadata to the base class, then refresh the GPU requirement
     * whenever the trial config is set.
     * @param key metadata key
     * @param value metadata value
     */
    public async setClusterMetadata(key: string, value: string): Promise<void> {
        await super.setClusterMetadata(key, value);
        switch (key) {
            case TrialConfigMetadataKey.TRIAL_CONFIG:
                if (this.localTrailConfig !== undefined) {
                    this.requiredGPUNum = this.localTrailConfig.gpuNum;
                } else {
                    // If no valid trial config is initialized, set requiredGPUNum to 0 as fallback value.
                    this.requiredGPUNum = 0;
                }
                this.log.info(`required GPU number is ${this.requiredGPUNum}`);
                // The scheduler is created at most once, and only when GPUs are requested
                if (this.gpuScheduler === undefined && this.requiredGPUNum > 0) {
                    this.gpuScheduler = new GPUScheduler();
                }
                break;
            default:
        }
    }

    /**
     * Stop the GPU scheduler (if any) before cleaning up the base service.
     */
    public async cleanUp(): Promise<void> {
        if (this.gpuScheduler !== undefined) {
            await this.gpuScheduler.stop();
        }

        return super.cleanUp();
    }

    /**
     * Release the GPUs of a trial as soon as it leaves the RUNNING state.
     * @param trialJob trial whose status changed (already carrying the new status)
     * @param oldStatus status the trial had before the change
     */
    protected onTrialJobStatusChanged(trialJob: LocalTrialJobDetailForGPU, oldStatus: TrialJobStatus): void {
        super.onTrialJobStatusChanged(trialJob, oldStatus);
        if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length !== 0 && this.gpuScheduler !== undefined) {
            if (oldStatus === 'RUNNING' && trialJob.status !== 'RUNNING') {
                for (const index of trialJob.gpuIndices) {
                    this.occupiedGPUIndices.delete(index);
                }
            }
        }
    }

    /**
     * Extend the base environment with CUDA_VISIBLE_DEVICES.
     * @param trialJobDetail trial the environment is built for
     * @param resource resource granted to the trial
     * @returns the base variables plus CUDA_VISIBLE_DEVICES, which is the empty string
     *          when no GPU scheduler exists and the comma-joined GPU indices otherwise
     */
    protected getEnvironmentVariables(
        trialJobDetail: TrialJobDetail,
        resource: { gpuIndices: number[] }): { key: string; value: string }[] {
        const variables: { key: string; value: string }[] = super.getEnvironmentVariables(trialJobDetail, resource);
        variables.push({
            key: 'CUDA_VISIBLE_DEVICES',
            value: this.gpuScheduler === undefined ? '' : resource.gpuIndices.join(',')
        });

        return variables;
    }

    /**
     * Record the granted GPU indices on the trial job detail.
     */
    protected setExtraProperties(trialJobDetail: LocalTrialJobDetailForGPU, resource: { gpuIndices: number[] }): void {
        super.setExtraProperties(trialJobDetail, resource);
        trialJobDetail.gpuIndices = resource.gpuIndices;
    }

    /**
     * Try to reserve requiredGPUNum GPUs on top of whatever the base class grants.
     * @returns [false, resource] when the base class fails or too few GPUs are free;
     *          otherwise [true, resource], with gpuIndices added when a scheduler exists
     */
    protected tryGetAvailableResource(): [boolean, {}] {
        const [success, resource] = super.tryGetAvailableResource();
        if (!success || this.gpuScheduler === undefined) {
            return [success, resource];
        }

        // Keep only the GPUs the scheduler returns that are not occupied by one of our trials
        const freeGPUIndices: number[] = this.gpuScheduler.getAvailableGPUIndices()
            .filter((index: number) => !this.occupiedGPUIndices.has(index));

        if (freeGPUIndices.length < this.requiredGPUNum) {
            return [false, resource];
        }

        Object.assign(resource, { gpuIndices: freeGPUIndices.slice(0, this.requiredGPUNum) });

        return [true, resource];
    }

    /**
     * Mark the GPUs of a granted resource as occupied.
     */
    protected occupyResource(resource: { gpuIndices: number[] }): void {
        super.occupyResource(resource);
        if (this.gpuScheduler !== undefined) {
            for (const index of resource.gpuIndices) {
                this.occupiedGPUIndices.add(index);
            }
        }
    }
}
export { LocalTrainingServiceForGPU };
...@@ -20,11 +20,10 @@ ...@@ -20,11 +20,10 @@
'use strict'; 'use strict';
import * as assert from 'assert'; import * as assert from 'assert';
import { Client } from 'ssh2';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { randomSelect } from '../../common/utils'; import { randomSelect } from '../../common/utils';
import { GPUInfo } from '../common/gpuData'; import { GPUInfo } from '../common/gpuData';
import { RemoteMachineMeta, RemoteMachineScheduleResult, ScheduleResultType, SSHClientManager } from './remoteMachineData'; import { parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, ScheduleResultType, SSHClientManager } from './remoteMachineData';
/** /**
* A simple GPU scheduler implementation * A simple GPU scheduler implementation
...@@ -85,6 +84,20 @@ export class GPUScheduler { ...@@ -85,6 +84,20 @@ export class GPUScheduler {
}; };
} }
/**
 * Release every GPU reservation that the given trial job holds on a remote machine.
 * @param trialJobId id of the trial job whose reservations are dropped
 * @param rmMeta remote machine to clean up; the call is a no-op when it is omitted
 */
public removeGpuReservation(trialJobId: string, rmMeta?: RemoteMachineMeta): void {
    // gpuReservation may be left uninitialized for a machine without GPUs, hence the second check
    if (rmMeta === undefined || rmMeta.gpuReservation === undefined) {
        return;
    }

    for (const [gpuIndex, ownerJobId] of rmMeta.gpuReservation) {
        if (ownerJobId === trialJobId) {
            rmMeta.gpuReservation.delete(gpuIndex);
        }
    }
}
private scheduleGPUHost(requiredGPUNum: number, trialJobId: string): RemoteMachineScheduleResult | undefined { private scheduleGPUHost(requiredGPUNum: number, trialJobId: string): RemoteMachineScheduleResult | undefined {
const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = this.gpuResourceDetection(); const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = this.gpuResourceDetection();
const qualifiedRMs: RemoteMachineMeta[] = []; const qualifiedRMs: RemoteMachineMeta[] = [];
...@@ -120,11 +133,14 @@ export class GPUScheduler { ...@@ -120,11 +133,14 @@ export class GPUScheduler {
if (rmMeta.gpuReservation === undefined) { if (rmMeta.gpuReservation === undefined) {
rmMeta.gpuReservation = new Map<number, string>(); rmMeta.gpuReservation = new Map<number, string>();
} }
const designatedGpuIndices: Set<number> | undefined = parseGpuIndices(rmMeta.gpuIndices);
this.log.debug(`designated gpu indices: ${designatedGpuIndices}`);
rmMeta.gpuSummary.gpuInfos.forEach((gpuInfo: GPUInfo) => { rmMeta.gpuSummary.gpuInfos.forEach((gpuInfo: GPUInfo) => {
// if the GPU has active process, OR be reserved by a job, // if the GPU has active process, OR be reserved by a job,
// or index not in gpuIndices configuration in machineList,
// We should NOT allocate this GPU // We should NOT allocate this GPU
if (gpuInfo.activeProcessNum === 0 && !rmMeta.gpuReservation.has(gpuInfo.index)) { if (gpuInfo.activeProcessNum === 0 && !rmMeta.gpuReservation.has(gpuInfo.index)
&& (designatedGpuIndices === undefined || designatedGpuIndices.has(gpuInfo.index))) {
availableGPUs.push(gpuInfo); availableGPUs.push(gpuInfo);
} }
}); });
...@@ -163,20 +179,5 @@ export class GPUScheduler { ...@@ -163,20 +179,5 @@ export class GPUScheduler {
} }
}; };
} }
/**
 * Remove the GPU reservations held by a trial job on a remote machine.
 * @param trialJobId id of the trial job whose reservations are removed
 * @param rmMeta remote machine to update; nothing is done when it is omitted
 */
public removeGpuReservation(trialJobId: string, rmMeta?: RemoteMachineMeta): void {
    // If remote machine has no GPU, gpuReservation is not initialized, so check if it's undefined
    if (rmMeta !== undefined && rmMeta.gpuReservation !== undefined) {
        rmMeta.gpuReservation.forEach((reserveTrialJobId: string, gpuIndex: number) => {
            // Strict comparison: both operands are job id strings, no coercion wanted
            if (reserveTrialJobId === trialJobId) {
                rmMeta.gpuReservation.delete(gpuIndex);
            }
        });
    }
}
} }
...@@ -19,12 +19,11 @@ ...@@ -19,12 +19,11 @@
'use strict'; 'use strict';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService'; import * as fs from 'fs';
import { GPUSummary } from '../common/gpuData';
import { Client, ConnectConfig } from 'ssh2'; import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import * as fs from 'fs'; import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUSummary } from '../common/gpuData';
/** /**
* Metadata of remote machine for configuration and statuc query * Metadata of remote machine for configuration and statuc query
...@@ -37,11 +36,12 @@ export class RemoteMachineMeta { ...@@ -37,11 +36,12 @@ export class RemoteMachineMeta {
public readonly sshKeyPath?: string; public readonly sshKeyPath?: string;
public readonly passphrase?: string; public readonly passphrase?: string;
public gpuSummary : GPUSummary | undefined; public gpuSummary : GPUSummary | undefined;
/* GPU Reservation info, the key is GPU index, the value is the job id which reserves this GPU*/ // GPU Reservation info, the key is GPU index, the value is the job id which reserves this GPU
public gpuReservation : Map<number, string>; public gpuReservation : Map<number, string>;
public readonly gpuIndices?: string;
constructor(ip : string, port : number, username : string, passwd : string, constructor(ip : string, port : number, username : string, passwd : string,
sshKeyPath : string, passphrase : string) { sshKeyPath: string, passphrase : string, gpuIndices?: string) {
this.ip = ip; this.ip = ip;
this.port = port; this.port = port;
this.username = username; this.username = username;
...@@ -49,6 +49,19 @@ export class RemoteMachineMeta { ...@@ -49,6 +49,19 @@ export class RemoteMachineMeta {
this.sshKeyPath = sshKeyPath; this.sshKeyPath = sshKeyPath;
this.passphrase = passphrase; this.passphrase = passphrase;
this.gpuReservation = new Map<number, string>(); this.gpuReservation = new Map<number, string>();
this.gpuIndices = gpuIndices;
}
}
/**
 * Parse a comma separated list of GPU indices, such as `1` or `0,1,3`.
 *
 * @param gpuIndices raw configuration value; undefined means no GPUs are designated
 * @returns the set of parsed indices, or undefined when gpuIndices is undefined
 * @throws Error when the value is empty/blank or contains an entry that is not a
 *         non-negative integer
 */
export function parseGpuIndices(gpuIndices?: string): Set<number> | undefined {
    if (gpuIndices === undefined) {
        return undefined;
    }

    // ''.split(',') yields [''], never [], so emptiness must be checked on the string itself
    if (gpuIndices.trim() === '') {
        throw new Error('gpuIndices can not be empty if specified.');
    }

    const indices: Set<number> = new Set<number>();
    for (const token of gpuIndices.split(',')) {
        const text: string = token.trim();
        // parseInt alone would turn '' into NaN and accept trailing garbage such as '1abc';
        // a NaN entry would silently prevent every GPU from matching
        if (!/^\d+$/.test(text)) {
            throw new Error(`Invalid GPU index in gpuIndices: '${token}'`);
        }
        indices.add(parseInt(text, 10));
    }

    return indices;
}
......
...@@ -28,7 +28,7 @@ import * as component from '../../common/component'; ...@@ -28,7 +28,7 @@ import * as component from '../../common/component';
import { TrialJobApplicationForm, TrialJobDetail, TrainingService } from '../../common/trainingService'; import { TrialJobApplicationForm, TrialJobDetail, TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, delay, prepareUnitTest } from '../../common/utils'; import { cleanupUnitTest, delay, prepareUnitTest } from '../../common/utils';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { LocalTrainingServiceForGPU } from '../local/localTrainingServiceForGPU'; import { LocalTrainingService } from '../local/localTrainingService';
// TODO: copy mockedTrail.py to local folder // TODO: copy mockedTrail.py to local folder
const localCodeDir: string = tmp.dirSync().name const localCodeDir: string = tmp.dirSync().name
...@@ -38,7 +38,7 @@ fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py') ...@@ -38,7 +38,7 @@ fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py')
describe('Unit Test for LocalTrainingService', () => { describe('Unit Test for LocalTrainingService', () => {
let trialConfig: any = `{"command":"sleep 1h && echo hello","codeDir":"${localCodeDir}","gpuNum":1}` let trialConfig: any = `{"command":"sleep 1h && echo hello","codeDir":"${localCodeDir}","gpuNum":1}`
let localTrainingService: LocalTrainingServiceForGPU; let localTrainingService: LocalTrainingService;
before(() => { before(() => {
chai.should(); chai.should();
...@@ -51,7 +51,7 @@ describe('Unit Test for LocalTrainingService', () => { ...@@ -51,7 +51,7 @@ describe('Unit Test for LocalTrainingService', () => {
}); });
beforeEach(() => { beforeEach(() => {
localTrainingService = component.get(LocalTrainingServiceForGPU); localTrainingService = component.get(LocalTrainingService);
localTrainingService.run(); localTrainingService.run();
}); });
......
...@@ -135,6 +135,9 @@ Optional('assessor'): Or({ ...@@ -135,6 +135,9 @@ Optional('assessor'): Or({
Optional('classArgs'): dict, Optional('classArgs'): dict,
Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999), Optional('gpuNum'): And(int, lambda x: 0 <= x <= 99999),
}), }),
Optional('localConfig'): {
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0))
}
} }
common_trial_schema = { common_trial_schema = {
...@@ -269,13 +272,15 @@ Optional('machineList'):[Or({ ...@@ -269,13 +272,15 @@ Optional('machineList'):[Or({
'ip': str, 'ip': str,
Optional('port'): And(int, lambda x: 0 < x < 65535), Optional('port'): And(int, lambda x: 0 < x < 65535),
'username': str, 'username': str,
'passwd': str 'passwd': str,
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0))
},{ },{
'ip': str, 'ip': str,
Optional('port'): And(int, lambda x: 0 < x < 65535), Optional('port'): And(int, lambda x: 0 < x < 65535),
'username': str, 'username': str,
'sshKeyPath': os.path.exists, 'sshKeyPath': os.path.exists,
Optional('passphrase'): str Optional('passphrase'): str,
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0))
})] })]
} }
......
...@@ -152,6 +152,23 @@ def set_trial_config(experiment_config, port, config_file_name): ...@@ -152,6 +152,23 @@ def set_trial_config(experiment_config, port, config_file_name):
def set_local_config(experiment_config, port, config_file_name): def set_local_config(experiment_config, port, config_file_name):
'''set local configuration''' '''set local configuration'''
#set machine_list
request_data = dict()
if experiment_config.get('localConfig'):
request_data['local_config'] = experiment_config['localConfig']
if request_data['local_config'] and request_data['local_config'].get('gpuIndices') \
and isinstance(request_data['local_config'].get('gpuIndices'), int):
request_data['local_config']['gpuIndices'] = str(request_data['local_config'].get('gpuIndices'))
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
err_message = ''
if not response or not check_response(response):
if response is not None:
err_message = response.text
_, stderr_full_path = get_log_path(config_file_name)
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message
return set_trial_config(experiment_config, port, config_file_name) return set_trial_config(experiment_config, port, config_file_name)
def set_remote_config(experiment_config, port, config_file_name): def set_remote_config(experiment_config, port, config_file_name):
...@@ -159,6 +176,10 @@ def set_remote_config(experiment_config, port, config_file_name): ...@@ -159,6 +176,10 @@ def set_remote_config(experiment_config, port, config_file_name):
#set machine_list #set machine_list
request_data = dict() request_data = dict()
request_data['machine_list'] = experiment_config['machineList'] request_data['machine_list'] = experiment_config['machineList']
if request_data['machine_list']:
for i in range(len(request_data['machine_list'])):
if isinstance(request_data['machine_list'][i].get('gpuIndices'), int):
request_data['machine_list'][i]['gpuIndices'] = str(request_data['machine_list'][i].get('gpuIndices'))
response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT) response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
err_message = '' err_message = ''
if not response or not check_response(response): if not response or not check_response(response):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment