Unverified commit 5ab984a4, authored by J-shang, committed by GitHub
Browse files

tensorboard backend (#3454)

parent 6808708d
...@@ -108,6 +108,9 @@ abstract class Manager { ...@@ -108,6 +108,9 @@ abstract class Manager {
public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>; public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>;
public abstract getStatus(): NNIManagerStatus; public abstract getStatus(): NNIManagerStatus;
public abstract getTrialOutputLocalPath(trialJobId: string): Promise<string>;
public abstract fetchTrialOutput(trialJobId: string, subpath: string): Promise<void>;
} }
export { Manager, ExperimentParams, ExperimentProfile, TrialJobStatistics, ProfileUpdateType, NNIManagerStatus, ExperimentStatus, ExperimentStartUpMode }; export { Manager, ExperimentParams, ExperimentProfile, TrialJobStatistics, ProfileUpdateType, NNIManagerStatus, ExperimentStatus, ExperimentStartUpMode };
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
/**
 * Parameters accepted when starting a tensorboard task.
 */
interface TensorboardParams {
// Trial job ids packed into one string; NNITensorboardManager splits this value on ','.
trials: string;
}
/**
 * Lifecycle states a tensorboard task can be in.
 * 'DOWNLOADING_DATA' / 'FAIL_DOWNLOAD_DATA' describe the trial-output fetch step,
 * 'STOPPING' / 'STOPPED' the shutdown of the task, and 'ERROR' a task that failed.
 */
type TensorboardTaskStatus = 'RUNNING' | 'DOWNLOADING_DATA' | 'STOPPING' | 'STOPPED' | 'ERROR' | 'FAIL_DOWNLOAD_DATA';
/**
 * Read-only description of one tensorboard task.
 */
interface TensorboardTaskInfo {
// Identifier of the task.
readonly id: string;
// Current lifecycle state of the task.
readonly status: TensorboardTaskStatus;
// Ids of the trial jobs covered by this task.
readonly trialJobIdList: string[];
// Log directories of those trials (presumably index-aligned with trialJobIdList; verify against the implementation).
readonly trialLogDirectoryList: string[];
// Process id of the tensorboard process, once one has been started.
readonly pid?: number;
// Port of the tensorboard process, kept as a string.
readonly port?: string;
}
/**
 * Abstract contract for managing tensorboard tasks.
 * Every operation is asynchronous; operations on a single task take the task id
 * and resolve to that task's TensorboardTaskInfo.
 */
abstract class TensorboardManager {
// Start a new task for the trials named in `tensorboardParams`.
public abstract startTensorboardTask(tensorboardParams: TensorboardParams): Promise<TensorboardTaskInfo>;
// Look up one task by id.
public abstract getTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskInfo>;
// Update one task by id (NOTE(review): presumably refreshes the task's trial data; confirm in the implementation).
public abstract updateTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskInfo>;
// List all tasks.
public abstract listTensorboardTasks(): Promise<TensorboardTaskInfo[]>;
// Stop one task by id.
public abstract stopTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskInfo>;
// Stop every task.
public abstract stopAllTensorboardTask(): Promise<void>;
// Stop the manager itself.
public abstract stop(): Promise<void>;
}
export {
TensorboardParams, TensorboardTaskStatus, TensorboardTaskInfo, TensorboardManager
}
...@@ -85,6 +85,8 @@ abstract class TrainingService { ...@@ -85,6 +85,8 @@ abstract class TrainingService {
public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>; public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
public abstract setClusterMetadata(key: string, value: string): Promise<void>; public abstract setClusterMetadata(key: string, value: string): Promise<void>;
public abstract getClusterMetadata(key: string): Promise<string>; public abstract getClusterMetadata(key: string): Promise<string>;
public abstract getTrialOutputLocalPath(trialJobId: string): Promise<string>;
public abstract fetchTrialOutput(trialJobId: string, subpath: string): Promise<void>;
public abstract cleanUp(): Promise<void>; public abstract cleanUp(): Promise<void>;
public abstract run(): Promise<void>; public abstract run(): Promise<void>;
} }
......
...@@ -9,6 +9,7 @@ import * as cpp from 'child-process-promise'; ...@@ -9,6 +9,7 @@ import * as cpp from 'child-process-promise';
import * as cp from 'child_process'; import * as cp from 'child_process';
import { ChildProcess, spawn, StdioOptions } from 'child_process'; import { ChildProcess, spawn, StdioOptions } from 'child_process';
import * as fs from 'fs'; import * as fs from 'fs';
import * as net from 'net';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path'; import * as path from 'path';
import * as lockfile from 'lockfile'; import * as lockfile from 'lockfile';
...@@ -340,11 +341,9 @@ async function getVersion(): Promise<string> { ...@@ -340,11 +341,9 @@ async function getVersion(): Promise<string> {
/** /**
* run command as ChildProcess * run command as ChildProcess
*/ */
function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newEnv: any): ChildProcess { function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newEnv: any, newShell: boolean = true, isDetached: boolean = false): ChildProcess {
let cmd: string = command; let cmd: string = command;
let arg: string[] = []; let arg: string[] = [];
let newShell: boolean = true;
let isDetached: boolean = false;
if (process.platform === "win32") { if (process.platform === "win32") {
cmd = command.split(" ", 1)[0]; cmd = command.split(" ", 1)[0];
arg = command.substr(cmd.length + 1).split(" "); arg = command.substr(cmd.length + 1).split(" ");
...@@ -449,8 +448,45 @@ function withLockSync(func: Function, filePath: string, lockOpts: {[key: string] ...@@ -449,8 +448,45 @@ function withLockSync(func: Function, filePath: string, lockOpts: {[key: string]
return result; return result;
} }
/**
 * Probe whether something accepts TCP connections on `host:port`.
 *
 * Resolves `true` as soon as a connection is established, and `false` when the
 * connection attempt errors or has not completed within one second.
 * A synchronous failure while creating the socket rejects the promise.
 */
async function isPortOpen(host: string, port: number): Promise<boolean> {
    return new Promise<boolean>((resolve) => {
        // A throw from createConnection rejects the promise automatically,
        // because it happens inside the promise executor.
        const socket = net.createConnection({ port: port, host: host });

        // Single exit path: stop the timer, drop the socket, report the result.
        const settle = (isOpen: boolean): void => {
            clearTimeout(timer);
            socket.destroy();
            resolve(isOpen);
        };

        const timer = setTimeout(() => settle(false), 1000);
        socket.on('connect', () => settle(true));
        socket.on('error', () => settle(false));
    });
}
/**
 * Find the lowest port in the inclusive range [start, end] that nothing is
 * listening on at `host`, probing the ports one by one with isPortOpen.
 *
 * @throws Error('no more free port') when the range is empty or every port in it is open.
 */
async function getFreePort(host: string, start: number, end: number): Promise<number> {
    for (let candidate = start; candidate <= end; candidate++) {
        const occupied = await isPortOpen(host, candidate);
        if (!occupied) {
            return candidate;
        }
    }
    throw new Error(`no more free port`);
}
export { export {
countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, getExperimentsInfoPath, countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, getExperimentsInfoPath,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, withLockSync, getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, withLockSync, getFreePort, isPortOpen,
mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine
}; };
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as fs from 'fs';
import * as cp from 'child_process';
import * as path from 'path';
import { ChildProcess } from 'child_process';
import * as component from '../common/component';
import { getLogger, Logger } from '../common/log';
import { getTunerProc, isAlive, uniqueString, mkDirPSync, getFreePort } from '../common/utils';
import { Manager } from '../common/manager';
import { TensorboardParams, TensorboardTaskStatus, TensorboardTaskInfo, TensorboardManager } from '../common/tensorboardManager';
/**
 * Mutable record of one tensorboard task.
 * The four constructor arguments become public fields of the same name;
 * `pid` and `port` stay unset until the caller assigns them.
 */
class TensorboardTaskDetail implements TensorboardTaskInfo {
    public pid?: number;
    public port?: string;

    constructor(
        public id: string,
        public status: TensorboardTaskStatus,
        public trialJobIdList: string[],
        public trialLogDirectoryList: string[]
    ) {}
}
/**
 * TensorboardManager implementation that runs one detached `tensorboard`
 * process per task and tracks the tasks in memory.
 *
 * Trial output locations and trial metadata are obtained from the NNI Manager
 * component; the tensorboard version is detected by invoking Python.
 */
class NNITensorboardManager implements TensorboardManager {
    private log: Logger;
    private tensorboardTaskMap: Map<string, TensorboardTaskDetail>;
    private tensorboardVersion: string | undefined;
    private nniManager: Manager;

    constructor() {
        this.log = getLogger();
        this.tensorboardTaskMap = new Map<string, TensorboardTaskDetail>();
        this.setTensorboardVersion();
        this.nniManager = component.get(Manager);
    }

    /**
     * Start a tensorboard process for the trials listed in `tensorboardParams.trials`.
     *
     * @param tensorboardParams `trials` is a comma-separated list of trial job ids;
     *        surrounding whitespace and empty entries are ignored.
     * @returns the newly created task (status 'RUNNING', with pid and port filled in).
     * @throws Error when no trial id is given, tensorboard cannot be found, or a
     *         trial's output path / metadata cannot be resolved.
     */
    public async startTensorboardTask(tensorboardParams: TensorboardParams): Promise<TensorboardTaskDetail> {
        const trialJobIdList: string[] = tensorboardParams.trials
            .split(',')
            .map((trialJobId) => trialJobId.trim())
            .filter((trialJobId) => trialJobId.length > 0);
        // Promise.all keeps input order, so both lists stay index-aligned and
        // deterministic no matter which lookup finishes first.
        const trialLogDirectoryList: string[] = await Promise.all(trialJobIdList.map(async (trialJobId) => {
            const trialTensorboardDataPath = path.join(await this.nniManager.getTrialOutputLocalPath(trialJobId), 'tensorboard');
            mkDirPSync(trialTensorboardDataPath);
            return trialTensorboardDataPath;
        }));
        this.log.info(`tensorboard: ${trialJobIdList} ${trialLogDirectoryList}`);
        return await this.startTensorboardTaskProcess(trialJobIdList, trialLogDirectoryList);
    }

    /**
     * Pick a free port (6006 and up), spawn the detached tensorboard process and
     * register the task. An initial data update is started but not waited for.
     */
    private async startTensorboardTaskProcess(trialJobIdList: string[], trialLogDirectoryList: string[]): Promise<TensorboardTaskDetail> {
        const host = 'localhost';
        const port = await getFreePort(host, 6006, 65535);
        const command = await this.getTensorboardStartCommand(trialJobIdList, trialLogDirectoryList, port);
        this.log.info(`tensorboard start command: ${command}`);
        const tensorboardTask = new TensorboardTaskDetail(uniqueString(5), 'RUNNING', trialJobIdList, trialLogDirectoryList);
        this.tensorboardTaskMap.set(tensorboardTask.id, tensorboardTask);

        const tensorboardProc: ChildProcess = getTunerProc(command, 'ignore', process.cwd(), process.env, true, true);
        tensorboardProc.on('error', async (error) => {
            this.log.error(error);
            try {
                const alive: boolean = await isAlive(tensorboardProc.pid);
                if (alive) {
                    // Negative pid: signal the whole process group of the detached child.
                    process.kill(-tensorboardProc.pid);
                }
            } catch (killError) {
                // An event listener has no caller to reject to; log instead of
                // leaving an unhandled rejection.
                this.log.error(killError);
            }
            this.setTensorboardTaskStatus(tensorboardTask, 'ERROR');
        });
        tensorboardTask.pid = tensorboardProc.pid;
        tensorboardTask.port = `${port}`;
        this.log.info(`tensorboard task id: ${tensorboardTask.id}`);

        // Fire-and-forget on purpose (the caller gets the task immediately), but a
        // failure must be observed, otherwise it becomes an unhandled rejection.
        this.updateTensorboardTask(tensorboardTask.id).catch((error: Error) => {
            this.log.error(`${error.message}`);
        });
        return tensorboardTask;
    }

    /**
     * Build the shell command that launches tensorboard for the given trials.
     * Each trial is passed as `<sequenceId>-<trialJobId>:<real log path>`.
     *
     * @throws Error when tensorboard is not detected, the lists differ in length,
     *         the lists are empty, or a path / trial lookup fails.
     */
    private async getTensorboardStartCommand(trialJobIdList: string[], trialLogDirectoryList: string[], port: number): Promise<string> {
        if (this.tensorboardVersion === undefined) {
            // Detection may have failed at construction time; retry once.
            this.setTensorboardVersion();
            if (this.tensorboardVersion === undefined) {
                throw new Error(`Tensorboard may not installed, if you want to use tensorboard, please check if tensorboard installed.`);
            }
        }
        if (trialJobIdList.length !== trialLogDirectoryList.length) {
            throw new Error('trial list length does not match');
        }
        if (trialJobIdList.length === 0) {
            throw new Error('trial list length is 0');
        }
        // Compare the major version numerically: as strings, '10.0' >= '2.0' is false.
        const majorVersion = parseInt(this.tensorboardVersion.trim().split('.')[0], 10);
        const logdirCmd = majorVersion >= 2 ? '--bind_all --logdir_spec' : '--logdir';
        try {
            const logRealPaths: string[] = [];
            for (let idx = 0; idx < trialJobIdList.length; idx++) {
                const realPath = fs.realpathSync(trialLogDirectoryList[idx]);
                const trialJob = await this.nniManager.getTrialJob(trialJobIdList[idx]);
                logRealPaths.push(`${trialJob.sequenceId}-${trialJobIdList[idx]}:${realPath}`);
            }
            return `tensorboard ${logdirCmd}=${logRealPaths.join(',')} --port=${port}`;
        } catch (error) {
            // Callers only ever see a plain Error carrying the original message.
            throw new Error(`${error.message}`);
        }
    }

    /**
     * Detect the installed tensorboard version by running Python and store it
     * (trimmed) in `tensorboardVersion`. Leaves the field untouched and logs a
     * warning when the command fails.
     */
    private setTensorboardVersion(): void {
        // cmd.exe does not treat single quotes as quoting characters, so the
        // Windows command must use double quotes around the Python snippet.
        const command = process.platform === 'win32'
            ? `python -c "import tensorboard ; print(tensorboard.__version__)"`
            : `python3 -c 'import tensorboard ; print(tensorboard.__version__)'`;
        try {
            // Trim the trailing newline that print() adds.
            const tensorboardVersion = cp.execSync(command).toString().trim();
            if (/^\d+(\.\d+)*/.test(tensorboardVersion)) {
                this.tensorboardVersion = tensorboardVersion;
            }
        } catch (error) {
            this.log.warning(`Tensorboard may not installed, if you want to use tensorboard, please check if tensorboard installed.`);
        }
    }

    /**
     * Look up a task by id. A task that is not 'STOPPED' but whose process is no
     * longer alive is moved to 'ERROR' before being returned.
     *
     * @throws Error('Tensorboard task not found') for an unknown id.
     */
    public async getTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskDetail> {
        const tensorboardTask: TensorboardTaskDetail | undefined = this.tensorboardTaskMap.get(tensorboardTaskId);
        if (tensorboardTask === undefined) {
            throw new Error('Tensorboard task not found');
        }
        if (tensorboardTask.status !== 'STOPPED') {
            const alive: boolean = await isAlive(tensorboardTask.pid);
            if (!alive) {
                this.setTensorboardTaskStatus(tensorboardTask, 'ERROR');
            }
        }
        return tensorboardTask;
    }

    /** Return all tasks currently tracked by this manager. */
    public async listTensorboardTasks(): Promise<TensorboardTaskDetail[]> {
        return Array.from(this.tensorboardTaskMap.values());
    }

    /** Change a task's status and log the transition; a no-op when the status is unchanged. */
    private setTensorboardTaskStatus(tensorboardTask: TensorboardTaskDetail, newStatus: TensorboardTaskStatus): void {
        if (tensorboardTask.status !== newStatus) {
            const oldStatus = tensorboardTask.status;
            tensorboardTask.status = newStatus;
            this.log.info(`tensorboardTask ${tensorboardTask.id} status update: ${oldStatus} to ${tensorboardTask.status}`);
        }
    }

    /** Mark a finished download by moving the task back to 'RUNNING'. */
    private downloadDataFinished(tensorboardTask: TensorboardTaskDetail): void {
        // The task may have been stopped or failed while the download was in
        // flight; only a task still downloading goes back to RUNNING.
        if (tensorboardTask.status === 'DOWNLOADING_DATA') {
            this.setTensorboardTaskStatus(tensorboardTask, 'RUNNING');
        }
    }

    /**
     * Trigger a fetch of the 'tensorboard' output of every trial of the task.
     * Returns immediately with the task in 'DOWNLOADING_DATA'; the status later
     * becomes 'RUNNING' on success or 'FAIL_DOWNLOAD_DATA' on failure.
     *
     * @throws Error when the task is unknown or not in 'RUNNING' / 'FAIL_DOWNLOAD_DATA'.
     */
    public async updateTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskInfo> {
        const tensorboardTask: TensorboardTaskDetail = await this.getTensorboardTask(tensorboardTaskId);
        if (!['RUNNING', 'FAIL_DOWNLOAD_DATA'].includes(tensorboardTask.status)) {
            throw new Error('only tensorboard task with RUNNING or FAIL_DOWNLOAD_DATA can update data');
        }
        this.setTensorboardTaskStatus(tensorboardTask, 'DOWNLOADING_DATA');
        // The callback must return the fetch promise: otherwise Promise.all only
        // sees `undefined` values, resolves at once and never observes failures.
        const downloads = tensorboardTask.trialJobIdList.map((trialJobId) =>
            this.nniManager.fetchTrialOutput(trialJobId, 'tensorboard'));
        Promise.all(downloads).then(() => {
            this.downloadDataFinished(tensorboardTask);
        }).catch((error: Error) => {
            if (tensorboardTask.status === 'DOWNLOADING_DATA') {
                this.setTensorboardTaskStatus(tensorboardTask, 'FAIL_DOWNLOAD_DATA');
            }
            this.log.error(`${error.message}`);
        });
        return tensorboardTask;
    }

    /**
     * Stop one task and return it with its final status.
     *
     * @throws Error when the task is unknown or not in 'RUNNING' / 'FAIL_DOWNLOAD_DATA'.
     */
    public async stopTensorboardTask(tensorboardTaskId: string): Promise<TensorboardTaskInfo> {
        const tensorboardTask = await this.getTensorboardTask(tensorboardTaskId);
        if (!['RUNNING', 'FAIL_DOWNLOAD_DATA'].includes(tensorboardTask.status)) {
            throw new Error('Only RUNNING FAIL_DOWNLOAD_DATA task can be stopped');
        }
        // Await the kill so the returned status is final and a kill failure
        // rejects this call instead of going unhandled.
        await this.killTensorboardTaskProc(tensorboardTask);
        return tensorboardTask;
    }

    /**
     * Kill the process group of a task and drop the task from the map.
     * Tasks already in 'ERROR' or 'STOPPED' are left alone; a task whose process
     * is already gone is marked 'ERROR'.
     */
    private async killTensorboardTaskProc(tensorboardTask: TensorboardTaskDetail): Promise<void> {
        if (['ERROR', 'STOPPED'].includes(tensorboardTask.status)) {
            return;
        }
        const alive: boolean = await isAlive(tensorboardTask.pid);
        if (!alive) {
            this.setTensorboardTaskStatus(tensorboardTask, 'ERROR');
            return;
        }
        this.setTensorboardTaskStatus(tensorboardTask, 'STOPPING');
        if (tensorboardTask.pid) {
            // Negative pid: signal the whole process group of the detached child.
            process.kill(-tensorboardTask.pid);
        }
        this.log.debug(`Tensorboard task ${tensorboardTask.id} stopped.`);
        this.setTensorboardTaskStatus(tensorboardTask, 'STOPPED');
        this.tensorboardTaskMap.delete(tensorboardTask.id);
    }

    /** Kill every tracked task, one after another. */
    public async stopAllTensorboardTask(): Promise<void> {
        this.log.info('Forced stopping all tensorboard task.');
        // Iterate over a snapshot because killing a task removes it from the map.
        for (const task of Array.from(this.tensorboardTaskMap.values())) {
            await this.killTensorboardTaskProc(task);
        }
        this.log.info('All tensorboard task stopped.');
    }

    /** Shut the manager down by stopping all tasks. */
    public async stop(): Promise<void> {
        await this.stopAllTensorboardTask();
        this.log.info('Tensorboard manager stopped.');
    }
}
export {
NNITensorboardManager, TensorboardTaskDetail
};
...@@ -16,6 +16,7 @@ import { ...@@ -16,6 +16,7 @@ import {
NNIManagerStatus, ProfileUpdateType, TrialJobStatistics NNIManagerStatus, ProfileUpdateType, TrialJobStatistics
} from '../common/manager'; } from '../common/manager';
import { ExperimentManager } from '../common/experimentManager'; import { ExperimentManager } from '../common/experimentManager';
import { TensorboardManager } from '../common/tensorboardManager';
import { import {
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType
} from '../common/trainingService'; } from '../common/trainingService';
...@@ -356,6 +357,7 @@ class NNIManager implements Manager { ...@@ -356,6 +357,7 @@ class NNIManager implements Manager {
let hasError: boolean = false; let hasError: boolean = false;
try { try {
await this.experimentManager.stop(); await this.experimentManager.stop();
await component.get<TensorboardManager>(TensorboardManager).stop();
await this.dataStore.close(); await this.dataStore.close();
await component.get<NNIRestServer>(NNIRestServer).stop(); await component.get<NNIRestServer>(NNIRestServer).stop();
} catch (err) { } catch (err) {
...@@ -881,6 +883,14 @@ class NNIManager implements Manager { ...@@ -881,6 +883,14 @@ class NNIManager implements Manager {
return Promise.resolve(chkpDir); return Promise.resolve(chkpDir);
} }
/** Forwards the lookup of a trial's local output path to the training service. */
public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
    const outputPath: string = await this.trainingService.getTrialOutputLocalPath(trialJobId);
    return outputPath;
}
/** Forwards the fetch of a trial's output (optionally a sub-path of it) to the training service. */
public async fetchTrialOutput(trialJobId: string, subpath: string): Promise<void> {
    await this.trainingService.fetchTrialOutput(trialJobId, subpath);
}
} }
export { NNIManager }; export { NNIManager };
...@@ -124,6 +124,14 @@ class MockedTrainingService extends TrainingService { ...@@ -124,6 +124,14 @@ class MockedTrainingService extends TrainingService {
public cleanUp(): Promise<void> { public cleanUp(): Promise<void> {
return Promise.resolve(); return Promise.resolve();
} }
/**
 * Not implemented by the mocked training service.
 * Declared `async` so the failure reaches callers as a rejected promise
 * (catchable with `.catch()`), not as a synchronous throw from a
 * Promise-returning method.
 */
public async getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    throw new MethodNotImplementedError();
}
/**
 * Not implemented by the mocked training service.
 * Declared `async` so the failure reaches callers as a rejected promise
 * (catchable with `.catch()`), not as a synchronous throw from a
 * Promise-returning method.
 */
public async fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    throw new MethodNotImplementedError();
}
} }
export{MockedTrainingService, testTrainingServiceProvider} export{MockedTrainingService, testTrainingServiceProvider}
...@@ -19,6 +19,8 @@ import { NNIManager } from '../nnimanager'; ...@@ -19,6 +19,8 @@ import { NNIManager } from '../nnimanager';
import { SqlDB } from '../sqlDatabase'; import { SqlDB } from '../sqlDatabase';
import { MockedTrainingService } from './mockedTrainingService'; import { MockedTrainingService } from './mockedTrainingService';
import { MockedDataStore } from './mockedDatastore'; import { MockedDataStore } from './mockedDatastore';
import { TensorboardManager } from '../../common/tensorboardManager';
import { NNITensorboardManager } from '../../core/nniTensorboardManager';
import * as path from 'path'; import * as path from 'path';
async function initContainer(): Promise<void> { async function initContainer(): Promise<void> {
...@@ -28,6 +30,7 @@ async function initContainer(): Promise<void> { ...@@ -28,6 +30,7 @@ async function initContainer(): Promise<void> {
Container.bind(Database).to(SqlDB).scope(Scope.Singleton); Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton); Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton);
Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton); Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton);
Container.bind(TensorboardManager).to(NNITensorboardManager).scope(Scope.Singleton);
await component.get<DataStore>(DataStore).init(); await component.get<DataStore>(DataStore).init();
} }
......
...@@ -13,12 +13,14 @@ import { setExperimentStartupInfo } from './common/experimentStartupInfo'; ...@@ -13,12 +13,14 @@ import { setExperimentStartupInfo } from './common/experimentStartupInfo';
import { getLogger, Logger, logLevelNameMap } from './common/log'; import { getLogger, Logger, logLevelNameMap } from './common/log';
import { Manager, ExperimentStartUpMode } from './common/manager'; import { Manager, ExperimentStartUpMode } from './common/manager';
import { ExperimentManager } from './common/experimentManager'; import { ExperimentManager } from './common/experimentManager';
import { TensorboardManager } from './common/tensorboardManager';
import { TrainingService } from './common/trainingService'; import { TrainingService } from './common/trainingService';
import { getLogDir, mkDirP, parseArg } from './common/utils'; import { getLogDir, mkDirP, parseArg } from './common/utils';
import { NNIDataStore } from './core/nniDataStore'; import { NNIDataStore } from './core/nniDataStore';
import { NNIManager } from './core/nnimanager'; import { NNIManager } from './core/nnimanager';
import { SqlDB } from './core/sqlDatabase'; import { SqlDB } from './core/sqlDatabase';
import { NNIExperimentsManager } from './core/nniExperimentsManager'; import { NNIExperimentsManager } from './core/nniExperimentsManager';
import { NNITensorboardManager } from './core/nniTensorboardManager';
import { NNIRestServer } from './rest_server/nniRestServer'; import { NNIRestServer } from './rest_server/nniRestServer';
import { FrameworkControllerTrainingService } from './training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService'; import { FrameworkControllerTrainingService } from './training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService';
import { AdlTrainingService } from './training_service/kubernetes/adl/adlTrainingService'; import { AdlTrainingService } from './training_service/kubernetes/adl/adlTrainingService';
...@@ -76,6 +78,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN ...@@ -76,6 +78,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
Container.bind(ExperimentManager) Container.bind(ExperimentManager)
.to(NNIExperimentsManager) .to(NNIExperimentsManager)
.scope(Scope.Singleton); .scope(Scope.Singleton);
Container.bind(TensorboardManager)
.to(NNITensorboardManager)
.scope(Scope.Singleton);
const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log'); const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
if (foreground) { if (foreground) {
logFileName = undefined; logFileName = undefined;
......
...@@ -13,6 +13,7 @@ import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo'; ...@@ -13,6 +13,7 @@ import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log'; import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager'; import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager';
import { ExperimentManager } from '../common/experimentManager'; import { ExperimentManager } from '../common/experimentManager';
import { TensorboardManager, TensorboardTaskInfo } from '../common/tensorboardManager';
import { ValidationSchemas } from './restValidationSchemas'; import { ValidationSchemas } from './restValidationSchemas';
import { NNIRestServer } from './nniRestServer'; import { NNIRestServer } from './nniRestServer';
import { getVersion } from '../common/utils'; import { getVersion } from '../common/utils';
...@@ -23,11 +24,13 @@ class NNIRestHandler { ...@@ -23,11 +24,13 @@ class NNIRestHandler {
private restServer: NNIRestServer; private restServer: NNIRestServer;
private nniManager: Manager; private nniManager: Manager;
private experimentsManager: ExperimentManager; private experimentsManager: ExperimentManager;
private tensorboardManager: TensorboardManager;
private log: Logger; private log: Logger;
constructor(rs: NNIRestServer) { constructor(rs: NNIRestServer) {
this.nniManager = component.get(Manager); this.nniManager = component.get(Manager);
this.experimentsManager = component.get(ExperimentManager); this.experimentsManager = component.get(ExperimentManager);
this.tensorboardManager = component.get(TensorboardManager);
this.restServer = rs; this.restServer = rs;
this.log = getLogger(); this.log = getLogger();
} }
...@@ -64,6 +67,12 @@ class NNIRestHandler { ...@@ -64,6 +67,12 @@ class NNIRestHandler {
this.getTrialLog(router); this.getTrialLog(router);
this.exportData(router); this.exportData(router);
this.getExperimentsInfo(router); this.getExperimentsInfo(router);
this.startTensorboardTask(router);
this.getTensorboardTask(router);
this.updateTensorboardTask(router);
this.stopTensorboardTask(router);
this.stopAllTensorboardTask(router);
this.listTensorboardTask(router);
this.stop(router); this.stop(router);
// Express-joi-validator configuration // Express-joi-validator configuration
...@@ -318,6 +327,67 @@ class NNIRestHandler { ...@@ -318,6 +327,67 @@ class NNIRestHandler {
}); });
} }
/**
 * Registers POST /tensorboard: starts a tensorboard task from the request body,
 * logs the task and replies with a plain copy of it. Failures are answered
 * through handleError with status code 400.
 */
private startTensorboardTask(router: Router): void {
    const handler = (req: Request, res: Response): void => {
        const onStarted = (taskDetail: TensorboardTaskInfo): void => {
            this.log.info(taskDetail);
            res.send(Object.assign({}, taskDetail));
        };
        const onFailed = (err: Error): void => {
            this.handleError(err, res, false, 400);
        };
        this.tensorboardManager.startTensorboardTask(req.body).then(onStarted).catch(onFailed);
    };
    router.post('/tensorboard', handler);
}
/**
 * Registers GET /tensorboard/:id: replies with a plain copy of the task that
 * has the given id. Failures are answered through handleError.
 */
private getTensorboardTask(router: Router): void {
    const handler = (req: Request, res: Response): void => {
        const onFound = (taskDetail: TensorboardTaskInfo): void => {
            res.send(Object.assign({}, taskDetail));
        };
        const onFailed = (err: Error): void => {
            this.handleError(err, res);
        };
        this.tensorboardManager.getTensorboardTask(req.params.id).then(onFound).catch(onFailed);
    };
    router.get('/tensorboard/:id', handler);
}
/**
 * Registers PUT /tensorboard/:id: asks the manager to update the task with the
 * given id and replies with a plain copy of it. Failures are answered through
 * handleError.
 */
private updateTensorboardTask(router: Router): void {
    const handler = (req: Request, res: Response): void => {
        const onUpdated = (taskDetail: TensorboardTaskInfo): void => {
            res.send(Object.assign({}, taskDetail));
        };
        const onFailed = (err: Error): void => {
            this.handleError(err, res);
        };
        this.tensorboardManager.updateTensorboardTask(req.params.id).then(onUpdated).catch(onFailed);
    };
    router.put('/tensorboard/:id', handler);
}
/**
 * Registers DELETE /tensorboard/:id: stops the task with the given id and
 * replies with a plain copy of it. Failures are answered through handleError.
 */
private stopTensorboardTask(router: Router): void {
    const handler = (req: Request, res: Response): void => {
        const onStopped = (taskDetail: TensorboardTaskInfo): void => {
            res.send(Object.assign({}, taskDetail));
        };
        const onFailed = (err: Error): void => {
            this.handleError(err, res);
        };
        this.tensorboardManager.stopTensorboardTask(req.params.id).then(onStopped).catch(onFailed);
    };
    router.delete('/tensorboard/:id', handler);
}
/**
 * Registers DELETE /tensorboard-tasks: stops every tensorboard task and replies
 * with an empty body. Failures are answered through handleError.
 */
private stopAllTensorboardTask(router: Router): void {
    const handler = (req: Request, res: Response): void => {
        const onAllStopped = (): void => {
            res.send();
        };
        const onFailed = (err: Error): void => {
            this.handleError(err, res);
        };
        this.tensorboardManager.stopAllTensorboardTask().then(onAllStopped).catch(onFailed);
    };
    router.delete('/tensorboard-tasks', handler);
}
/**
 * Registers GET /tensorboard-tasks: replies with the list of all tensorboard
 * tasks. Failures are answered through handleError.
 */
private listTensorboardTask(router: Router): void {
    const handler = (req: Request, res: Response): void => {
        const onListed = (taskDetails: TensorboardTaskInfo[]): void => {
            res.send(taskDetails);
        };
        const onFailed = (err: Error): void => {
            this.handleError(err, res);
        };
        this.tensorboardManager.listTensorboardTasks().then(onListed).catch(onFailed);
    };
    router.get('/tensorboard-tasks', handler);
}
private stop(router: Router): void { private stop(router: Router): void {
router.delete('/experiment', (req: Request, res: Response) => { router.delete('/experiment', (req: Request, res: Response) => {
this.nniManager.stopExperimentTopHalf().then(() => { this.nniManager.stopExperimentTopHalf().then(() => {
......
...@@ -189,4 +189,12 @@ export class MockedNNIManager extends Manager { ...@@ -189,4 +189,12 @@ export class MockedNNIManager extends Manager {
return Promise.resolve([job1, job2]); return Promise.resolve([job1, job2]);
} }
/** Not implemented by the mocked manager: always rejects with MethodNotImplementedError. */
public async getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    return Promise.reject(new MethodNotImplementedError());
}
/** Not implemented by the mocked manager: always rejects with MethodNotImplementedError. */
public async fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    return Promise.reject(new MethodNotImplementedError());
}
} }
...@@ -18,6 +18,8 @@ import { MockedTrainingService } from '../../core/test/mockedTrainingService'; ...@@ -18,6 +18,8 @@ import { MockedTrainingService } from '../../core/test/mockedTrainingService';
import { NNIRestServer } from '../nniRestServer'; import { NNIRestServer } from '../nniRestServer';
import { testManagerProvider } from './mockedNNIManager'; import { testManagerProvider } from './mockedNNIManager';
import { testExperimentManagerProvider } from './mockedExperimentManager'; import { testExperimentManagerProvider } from './mockedExperimentManager';
import { TensorboardManager } from '../../common/tensorboardManager';
import { NNITensorboardManager } from '../../core/nniTensorboardManager';
describe('Unit test for rest server', () => { describe('Unit test for rest server', () => {
...@@ -28,7 +30,8 @@ describe('Unit test for rest server', () => { ...@@ -28,7 +30,8 @@ describe('Unit test for rest server', () => {
Container.bind(Manager).provider(testManagerProvider); Container.bind(Manager).provider(testManagerProvider);
Container.bind(DataStore).to(MockedDataStore); Container.bind(DataStore).to(MockedDataStore);
Container.bind(TrainingService).to(MockedTrainingService); Container.bind(TrainingService).to(MockedTrainingService);
Container.bind(ExperimentManager).provider(testExperimentManagerProvider) Container.bind(ExperimentManager).provider(testExperimentManagerProvider);
Container.bind(TensorboardManager).to(NNITensorboardManager);
const restServer: NNIRestServer = component.get(NNIRestServer); const restServer: NNIRestServer = component.get(NNIRestServer);
restServer.start().then(() => { restServer.start().then(() => {
ROOT_URL = `${restServer.endPoint}/api/v1/nni`; ROOT_URL = `${restServer.endPoint}/api/v1/nni`;
......
...@@ -565,6 +565,14 @@ class DLTSTrainingService implements TrainingService { ...@@ -565,6 +565,14 @@ class DLTSTrainingService implements TrainingService {
public get isMultiPhaseJobSupported(): boolean { public get isMultiPhaseJobSupported(): boolean {
return false; return false;
} }
/**
 * Not implemented for the DLTS training service.
 * Declared `async` so the failure reaches callers as a rejected promise
 * (catchable with `.catch()`), not as a synchronous throw from a
 * Promise-returning method.
 */
public async getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    throw new MethodNotImplementedError();
}
/**
 * Not implemented for the DLTS training service.
 * Declared `async` so the failure reaches callers as a rejected promise
 * (catchable with `.catch()`), not as a synchronous throw from a
 * Promise-returning method.
 */
public async fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    throw new MethodNotImplementedError();
}
} }
export { DLTSTrainingService }; export { DLTSTrainingService };
...@@ -393,5 +393,13 @@ abstract class KubernetesTrainingService { ...@@ -393,5 +393,13 @@ abstract class KubernetesTrainingService {
} }
return Promise.resolve(folderUriInAzure); return Promise.resolve(folderUriInAzure);
} }
/**
 * Not implemented for Kubernetes-based training services.
 * Declared `async` so the failure reaches callers as a rejected promise
 * (catchable with `.catch()`), not as a synchronous throw from a
 * Promise-returning method.
 */
public async getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    throw new MethodNotImplementedError();
}
/**
 * Not implemented for Kubernetes-based training services.
 * Declared `async` so the failure reaches callers as a rejected promise
 * (catchable with `.catch()`), not as a synchronous throw from a
 * Promise-returning method.
 */
public async fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    throw new MethodNotImplementedError();
}
} }
export {KubernetesTrainingService}; export {KubernetesTrainingService};
...@@ -583,6 +583,22 @@ class LocalTrainingService implements TrainingService { ...@@ -583,6 +583,22 @@ class LocalTrainingService implements TrainingService {
const filepath: string = path.join(directory, generateParamFileName(hyperParameters)); const filepath: string = path.join(directory, generateParamFileName(hyperParameters));
await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' }); await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' });
} }
/** Resolves to `<rootDir>/trials/<trialJobId>`, the local directory of the given trial. */
public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
    const trialDirectory: string = path.join(this.rootDir, 'trials', trialJobId);
    return trialDirectory;
}
/**
 * Checks that the trial's local output (or `subpath` inside it, when given)
 * exists on disk. Nothing is copied; the promise resolves when the path exists
 * and rejects with 'Trial local path not exist.' otherwise.
 */
public async fetchTrialOutput(trialJobId: string, subpath: string): Promise<void> {
    const trialRoot: string = await this.getTrialOutputLocalPath(trialJobId);
    const target: string = subpath === undefined ? trialRoot : path.join(trialRoot, subpath);
    if (!fs.existsSync(target)) {
        throw new Error('Trial local path not exist.');
    }
}
} }
export { LocalTrainingService }; export { LocalTrainingService };
...@@ -576,6 +576,14 @@ class PAITrainingService implements TrainingService { ...@@ -576,6 +576,14 @@ class PAITrainingService implements TrainingService {
const filepath: string = path.join(directory, generateParamFileName(hyperParameters)); const filepath: string = path.join(directory, generateParamFileName(hyperParameters));
await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' }); await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' });
} }
/**
 * Not implemented for the PAI training service.
 * Declared `async` so the failure reaches callers as a rejected promise
 * (catchable with `.catch()`), not as a synchronous throw from a
 * Promise-returning method.
 */
public async getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    throw new MethodNotImplementedError();
}
/**
 * Not implemented for the PAI training service.
 * Declared `async` so the failure reaches callers as a rejected promise
 * (catchable with `.catch()`), not as a synchronous throw from a
 * Promise-returning method.
 */
public async fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    throw new MethodNotImplementedError();
}
} }
export { PAITrainingService }; export { PAITrainingService };
...@@ -679,6 +679,14 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -679,6 +679,14 @@ class RemoteMachineTrainingService implements TrainingService {
await executor.copyFileToRemote(localFilepath, executor.joinPath(trialWorkingFolder, fileName)); await executor.copyFileToRemote(localFilepath, executor.joinPath(trialWorkingFolder, fileName));
} }
/**
 * Not implemented for the remote-machine training service.
 * Declared `async` so the failure reaches callers as a rejected promise
 * (catchable with `.catch()`), not as a synchronous throw from a
 * Promise-returning method.
 */
public async getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
    throw new MethodNotImplementedError();
}
/**
 * Not implemented for the remote-machine training service.
 * Declared `async` so the failure reaches callers as a rejected promise
 * (catchable with `.catch()`), not as a synchronous throw from a
 * Promise-returning method.
 */
public async fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
    throw new MethodNotImplementedError();
}
} }
export { RemoteMachineTrainingService }; export { RemoteMachineTrainingService };
...@@ -183,6 +183,20 @@ class RouterTrainingService implements TrainingService { ...@@ -183,6 +183,20 @@ class RouterTrainingService implements TrainingService {
} }
return await this.internalTrainingService.run(); return await this.internalTrainingService.run();
} }
/**
 * Forwards to the wrapped training service.
 * Rejects with "TrainingService is not assigned!" while no service has been set.
 */
public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
    const service = this.internalTrainingService;
    if (service === undefined) {
        throw new Error("TrainingService is not assigned!");
    }
    return service.getTrialOutputLocalPath(trialJobId);
}
/**
 * Forwards to the wrapped training service.
 * Rejects with "TrainingService is not assigned!" while no service has been set.
 */
public async fetchTrialOutput(trialJobId: string, subpath: string): Promise<void> {
    const service = this.internalTrainingService;
    if (service === undefined) {
        throw new Error("TrainingService is not assigned!");
    }
    return service.fetchTrialOutput(trialJobId, subpath);
}
} }
export { RouterTrainingService }; export { RouterTrainingService };
...@@ -941,6 +941,29 @@ class TrialDispatcher implements TrainingService { ...@@ -941,6 +941,29 @@ class TrialDispatcher implements TrainingService {
this.useSharedStorage = true; this.useSharedStorage = true;
return Promise.resolve(); return Promise.resolve();
} }
/**
 * Resolves to `<localWorkingRoot>/trials/<trialJobId>` of the shared storage
 * service. Rejects when shared storage is not in use.
 */
public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
    // TODO: support non shared storage
    if (!this.useSharedStorage) {
        throw new Error('Only support shared storage right now.');
    }
    const sharedStorage = component.get<SharedStorageService>(SharedStorageService);
    return path.join(sharedStorage.localWorkingRoot, 'trials', trialJobId);
}
/**
 * Checks that the trial's local output (or `subpath` inside it, when given)
 * exists on disk. Resolves when the path exists and rejects with
 * 'Trial local path not exist.' otherwise.
 */
public async fetchTrialOutput(trialJobId: string, subpath: string | undefined): Promise<void> {
    // TODO: support non shared storage
    const trialRoot: string = await this.getTrialOutputLocalPath(trialJobId);
    const target: string = subpath === undefined ? trialRoot : path.join(trialRoot, subpath);
    if (!fs.existsSync(target)) {
        throw new Error('Trial local path not exist.');
    }
}
} }
export { TrialDispatcher }; export { TrialDispatcher };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment