Unverified commit bef663b8, authored by liuzhe-lz and committed by GitHub
Browse files

Experiments Manager (step 1) - decouple lock (#4824)

parent f6ec5394
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
 * Abstract contract for the component that tracks experiment metadata.
 * It declares no state or behavior of its own; every member is abstract and
 * must be supplied by a concrete subclass.
 */
abstract class ExperimentManager {
/** Asynchronously produces the experiments information as a JSON value. */
public abstract getExperimentsInfo(): Promise<JSON>;
/** Switches the manager to a different experiments information path. */
public abstract setExperimentPath(newPath: string): void;
/** Records `value` under `key` for the experiment identified by `experimentId`. */
public abstract setExperimentInfo(experimentId: string, key: string, value: any): void;
/** Asynchronously stops the manager. */
public abstract stop(): Promise<void>;
}
export {ExperimentManager};
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
import assert from 'assert/strict'; import assert from 'assert/strict';
import fs from 'fs'; import fs from 'fs';
import os from 'os';
import path from 'path'; import path from 'path';
import type { NniManagerArgs } from './arguments'; import type { NniManagerArgs } from './arguments';
...@@ -17,6 +18,7 @@ import type { NniManagerArgs } from './arguments'; ...@@ -17,6 +18,7 @@ import type { NniManagerArgs } from './arguments';
export interface NniPaths { export interface NniPaths {
readonly experimentRoot: string; readonly experimentRoot: string;
readonly experimentsDirectory: string; readonly experimentsDirectory: string;
readonly experimentsList: string;
readonly logDirectory: string; // contains nni manager and dispatcher log; trial logs are not here readonly logDirectory: string; // contains nni manager and dispatcher log; trial logs are not here
readonly nniManagerLog: string; readonly nniManagerLog: string;
} }
...@@ -35,9 +37,13 @@ export function createPaths(args: NniManagerArgs): NniPaths { ...@@ -35,9 +37,13 @@ export function createPaths(args: NniManagerArgs): NniPaths {
const nniManagerLog = path.join(logDirectory, 'nnimanager.log'); const nniManagerLog = path.join(logDirectory, 'nnimanager.log');
// TODO: this should follow experiments directory config
const experimentsList = path.join(os.homedir(), 'nni-experiments', '.experiment');
return { return {
experimentRoot, experimentRoot,
experimentsDirectory: args.experimentsDirectory, experimentsDirectory: args.experimentsDirectory,
experimentsList,
logDirectory, logDirectory,
nniManagerLog, nniManagerLog,
}; };
......
...@@ -12,16 +12,13 @@ import net from 'net'; ...@@ -12,16 +12,13 @@ import net from 'net';
import os from 'os'; import os from 'os';
import path from 'path'; import path from 'path';
import * as timersPromises from 'timers/promises'; import * as timersPromises from 'timers/promises';
import lockfile from 'lockfile';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { Container } from 'typescript-ioc'; import { Container } from 'typescript-ioc';
import glob from 'glob';
import { Database, DataStore } from './datastore'; import { Database, DataStore } from './datastore';
import globals from './globals'; import globals from './globals';
import { resetGlobals } from './globals/unittest'; // TODO: this file should not contain unittest helpers import { resetGlobals } from './globals/unittest'; // TODO: this file should not contain unittest helpers
import { ExperimentConfig, Manager } from './manager'; import { ExperimentConfig, Manager } from './manager';
import { ExperimentManager } from './experimentManager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService'; import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
function getExperimentRootDir(): string { function getExperimentRootDir(): string {
...@@ -44,10 +41,6 @@ function getCheckpointDir(): string { ...@@ -44,10 +41,6 @@ function getCheckpointDir(): string {
return path.join(getExperimentRootDir(), 'checkpoint'); return path.join(getExperimentRootDir(), 'checkpoint');
} }
/**
 * Location of the experiments information file:
 * `<home directory>/nni-experiments/.experiment`.
 */
function getExperimentsInfoPath(): string {
    const experimentsHome = path.join(os.homedir(), 'nni-experiments');
    return path.join(experimentsHome, '.experiment');
}
async function mkDirP(dirPath: string): Promise<void> { async function mkDirP(dirPath: string): Promise<void> {
await fs.promises.mkdir(dirPath, { recursive: true }); await fs.promises.mkdir(dirPath, { recursive: true });
} }
...@@ -152,7 +145,6 @@ function prepareUnitTest(): void { ...@@ -152,7 +145,6 @@ function prepareUnitTest(): void {
Container.snapshot(DataStore); Container.snapshot(DataStore);
Container.snapshot(TrainingService); Container.snapshot(TrainingService);
Container.snapshot(Manager); Container.snapshot(Manager);
Container.snapshot(ExperimentManager);
resetGlobals(); resetGlobals();
...@@ -173,7 +165,6 @@ function cleanupUnitTest(): void { ...@@ -173,7 +165,6 @@ function cleanupUnitTest(): void {
Container.restore(TrainingService); Container.restore(TrainingService);
Container.restore(DataStore); Container.restore(DataStore);
Container.restore(Database); Container.restore(Database);
Container.restore(ExperimentManager);
} }
let cachedIpv4Address: string | null = null; let cachedIpv4Address: string | null = null;
...@@ -352,27 +343,6 @@ function unixPathJoin(...paths: any[]): string { ...@@ -352,27 +343,6 @@ function unixPathJoin(...paths: any[]): string {
return dir; return dir;
} }
/**
 * Run `func(...args)` synchronously while holding a lock file for `filePath`.
 *
 * The lock file lives next to `filePath` and is named `<basename>.lock.<pid>`.
 * When `lockOpts['stale']` is a number, every `<basename>.lock.*` file is
 * inspected first: if any of them was modified less than `stale` milliseconds
 * ago the call is refused. This includes a lock file left by this same process.
 *
 * The lock is always released, even when `func` throws, so a failing callback
 * cannot leave a lock file behind that blocks later callers.
 *
 * @param func callback executed while the lock is held
 * @param filePath file the lock protects; it only determines the lock file name
 * @param lockOpts options forwarded to `lockfile.lockSync`; `stale` also drives the pre-check
 * @param args arguments forwarded to `func`
 * @returns whatever `func` returns
 * @throws Error with message 'File has been locked.' when an unexpired lock file exists;
 *         also rethrows any error raised by `lockfile` or by `func`
 */
function withLockSync(func: Function, filePath: string, lockOpts: {[key: string]: any}, ...args: any): any {
    const lockName = path.join(path.dirname(filePath), path.basename(filePath) + `.lock.${process.pid}`);
    const stale = lockOpts['stale'];
    if (typeof stale === 'number') {
        const lockPattern = path.join(path.dirname(filePath), path.basename(filePath) + '.lock.*');
        const now = Date.now();
        const locked = glob.sync(lockPattern).some((fileName: string) => {
            return fs.existsSync(fileName) && now - fs.statSync(fileName).mtimeMs < stale;
        });
        if (locked) {
            throw new Error('File has been locked.');
        }
    }
    lockfile.lockSync(lockName, lockOpts);
    try {
        return func(...args);
    } finally {
        // Release on both the normal and the exceptional path.
        lockfile.unlockSync(lockName);
    }
}
async function isPortOpen(host: string, port: number): Promise<boolean> { async function isPortOpen(host: string, port: number): Promise<boolean> {
return new Promise<boolean>((resolve, reject) => { return new Promise<boolean>((resolve, reject) => {
try{ try{
...@@ -416,7 +386,7 @@ export function importModule(modulePath: string): any { ...@@ -416,7 +386,7 @@ export function importModule(modulePath: string): any {
} }
export { export {
countFilesRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, getExperimentsInfoPath, countFilesRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, withLockSync, getFreePort, isPortOpen, getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, getFreePort, isPortOpen,
mkDirP, mkDirPSync, delay, prepareUnitTest, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine mkDirP, mkDirPSync, delay, prepareUnitTest, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine
}; };
...@@ -15,7 +15,7 @@ import { ...@@ -15,7 +15,7 @@ import {
NNIManagerStatus, ProfileUpdateType, TrialJobStatistics NNIManagerStatus, ProfileUpdateType, TrialJobStatistics
} from '../common/manager'; } from '../common/manager';
import { ExperimentConfig, LocalConfig, toSeconds, toCudaVisibleDevices } from '../common/experimentConfig'; import { ExperimentConfig, LocalConfig, toSeconds, toCudaVisibleDevices } from '../common/experimentConfig';
import { ExperimentManager } from '../common/experimentManager'; import { getExperimentsManager } from 'extensions/experiments_manager';
import { TensorboardManager } from '../common/tensorboardManager'; import { TensorboardManager } from '../common/tensorboardManager';
import { import {
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, TrialCommandContent, PlacementConstraint TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, TrialCommandContent, PlacementConstraint
...@@ -33,7 +33,6 @@ import { createDispatcherInterface, IpcInterface } from './ipcInterface'; ...@@ -33,7 +33,6 @@ import { createDispatcherInterface, IpcInterface } from './ipcInterface';
class NNIManager implements Manager { class NNIManager implements Manager {
private trainingService!: TrainingService; private trainingService!: TrainingService;
private dispatcher: IpcInterface | undefined; private dispatcher: IpcInterface | undefined;
private experimentManager: ExperimentManager;
private currSubmittedTrialNum: number; // need to be recovered private currSubmittedTrialNum: number; // need to be recovered
private trialConcurrencyChange: number; // >0: increase, <0: decrease private trialConcurrencyChange: number; // >0: increase, <0: decrease
private log: Logger; private log: Logger;
...@@ -52,7 +51,6 @@ class NNIManager implements Manager { ...@@ -52,7 +51,6 @@ class NNIManager implements Manager {
constructor() { constructor() {
this.currSubmittedTrialNum = 0; this.currSubmittedTrialNum = 0;
this.trialConcurrencyChange = 0; this.trialConcurrencyChange = 0;
this.experimentManager = component.get(ExperimentManager);
this.dispatcherPid = 0; this.dispatcherPid = 0;
this.waitingTrials = []; this.waitingTrials = [];
this.trialJobs = new Map<string, TrialJobDetail>(); this.trialJobs = new Map<string, TrialJobDetail>();
...@@ -347,7 +345,6 @@ class NNIManager implements Manager { ...@@ -347,7 +345,6 @@ class NNIManager implements Manager {
this.setStatus('STOPPED'); this.setStatus('STOPPED');
this.log.info('Experiment stopped.'); this.log.info('Experiment stopped.');
await this.experimentManager.stop();
await component.get<TensorboardManager>(TensorboardManager).stop(); await component.get<TensorboardManager>(TensorboardManager).stop();
await this.dataStore.close(); await this.dataStore.close();
} }
...@@ -882,13 +879,13 @@ class NNIManager implements Manager { ...@@ -882,13 +879,13 @@ class NNIManager implements Manager {
if (status !== this.status.status) { if (status !== this.status.status) {
this.log.info(`Change NNIManager status from: ${this.status.status} to: ${status}`); this.log.info(`Change NNIManager status from: ${this.status.status} to: ${status}`);
this.status.status = status; this.status.status = status;
this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'status', this.status.status); getExperimentsManager().setExperimentInfo(this.experimentProfile.id, 'status', this.status.status);
} }
} }
private setEndtime(): void { private setEndtime(): void {
this.experimentProfile.endTime = Date.now(); this.experimentProfile.endTime = Date.now();
this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'endTime', this.experimentProfile.endTime); getExperimentsManager().setExperimentInfo(this.experimentProfile.id, 'endTime', this.experimentProfile.endTime);
} }
private async createCheckpointDir(): Promise<string> { private async createCheckpointDir(): Promise<string> {
......
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license. // Licensed under the MIT license.
export { NNIExperimentsManager } from './manager'; import { ExperimentsManager } from './manager';
export { ExperimentsManager } from './manager';
/** Shared instance for the whole process; remains `null` until it is first requested. */
let singleton: ExperimentsManager | null = null;

/**
 * Return the shared ExperimentsManager, constructing it on the first call.
 * Later calls hand back the very same object.
 */
export function getExperimentsManager(): ExperimentsManager {
    if (singleton !== null) {
        return singleton;
    }
    const created = new ExperimentsManager();
    singleton = created;
    return created;
}

/**
 * Make sure the shared ExperimentsManager exists.
 * Equivalent to calling getExperimentsManager() and discarding the result.
 */
export function initExperimentsManager(): void {
    void getExperimentsManager();
}
/** Hooks grouped under a name that marks them as unit-test helpers. */
export namespace UnitTestHelpers {
/**
 * Overwrite the module-level `singleton` with `experimentsManager`.
 * The parameter is typed `any`, so no type check is applied to the injected value.
 */
export function setExperimentsManager(experimentsManager: any): void {
singleton = experimentsManager;
}
}
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license. // Licensed under the MIT license.
import assert from 'assert/strict';
import fs from 'fs'; import fs from 'fs';
import os from 'os'; import os from 'os';
import path from 'path'; import path from 'path';
import assert from 'assert'; import * as timersPromises from 'timers/promises';
import { getLogger, Logger } from 'common/log';
import { isAlive, withLockSync, getExperimentsInfoPath, delay } from 'common/utils';
import { ExperimentManager } from 'common/experimentManager';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { getLogger, Logger } from 'common/log';
import globals from 'common/globals';
import { isAlive } from 'common/utils';
import { withLock, withLockNoWait } from './utils';
const logger: Logger = getLogger('experiments_manager');
interface CrashedInfo { interface CrashedInfo {
experimentId: string; experimentId: string;
isCrashed: boolean; isCrashed: boolean;
...@@ -21,19 +26,15 @@ interface FileInfo { ...@@ -21,19 +26,15 @@ interface FileInfo {
mtime: number; mtime: number;
} }
class NNIExperimentsManager implements ExperimentManager { export class ExperimentsManager {
private experimentsPath: string; private profileUpdateTimer: Record<string, NodeJS.Timeout | undefined> = {};
private log: Logger;
private profileUpdateTimer: {[key: string]: any};
constructor() { constructor() {
this.experimentsPath = getExperimentsInfoPath(); globals.shutdown.register('experiments_manager', this.cleanUp.bind(this));
this.log = getLogger('NNIExperimentsManager');
this.profileUpdateTimer = {};
} }
public async getExperimentsInfo(): Promise<JSON> { public async getExperimentsInfo(): Promise<JSON> {
const fileInfo: FileInfo = await this.withLockIterated(this.readExperimentsInfo, 100); const fileInfo: FileInfo = await withLock(globals.paths.experimentsList, () => this.readExperimentsInfo());
const experimentsInformation = JSON.parse(fileInfo.buffer.toString()); const experimentsInformation = JSON.parse(fileInfo.buffer.toString());
const expIdList: Array<string> = Object.keys(experimentsInformation).filter((expId) => { const expIdList: Array<string> = Object.keys(experimentsInformation).filter((expId) => {
return experimentsInformation[expId]['status'] !== 'STOPPED'; return experimentsInformation[expId]['status'] !== 'STOPPED';
...@@ -42,11 +43,13 @@ class NNIExperimentsManager implements ExperimentManager { ...@@ -42,11 +43,13 @@ class NNIExperimentsManager implements ExperimentManager {
return this.checkCrashed(expId, experimentsInformation[expId]['pid']); return this.checkCrashed(expId, experimentsInformation[expId]['pid']);
}))).filter(crashedInfo => crashedInfo.isCrashed); }))).filter(crashedInfo => crashedInfo.isCrashed);
if (updateList.length > 0){ if (updateList.length > 0){
const result = await this.withLockIterated(this.updateAllStatus, 100, updateList.map(crashedInfo => crashedInfo.experimentId), fileInfo.mtime); const result = await withLock(globals.paths.experimentsList, () => {
return this.updateAllStatus(updateList.map(crashedInfo => crashedInfo.experimentId), fileInfo.mtime)
});
if (result !== undefined) { if (result !== undefined) {
return JSON.parse(JSON.stringify(Object.keys(result).map(key=>result[key]))); return JSON.parse(JSON.stringify(Object.keys(result).map(key=>result[key])));
} else { } else {
await delay(500); await timersPromises.setTimeout(500);
return await this.getExperimentsInfo(); return await this.getExperimentsInfo();
} }
} else { } else {
...@@ -54,66 +57,35 @@ class NNIExperimentsManager implements ExperimentManager { ...@@ -54,66 +57,35 @@ class NNIExperimentsManager implements ExperimentManager {
} }
} }
/**
 * Point the manager at a different experiments information file.
 *
 * A leading `~` is replaced by the current user's home directory, and a path
 * that is still relative afterwards is resolved to an absolute one.
 * The final path is logged and stored in `experimentsPath`.
 */
public setExperimentPath(newPath: string): void {
    const expanded = newPath.startsWith('~')
        ? path.join(os.homedir(), newPath.slice(1))
        : newPath;
    const absolute = path.isAbsolute(expanded) ? expanded : path.resolve(expanded);
    this.log.info(`Set new experiment information path: ${absolute}`);
    this.experimentsPath = absolute;
}
public setExperimentInfo(experimentId: string, key: string, value: any): void { public setExperimentInfo(experimentId: string, key: string, value: any): void {
try { try {
if (this.profileUpdateTimer[key] !== undefined) { if (this.profileUpdateTimer[key] !== undefined) {
// if a new call with the same timerId occurs, destroy the unfinished old one // if a new call with the same timerId occurs, destroy the unfinished old one
clearTimeout(this.profileUpdateTimer[key]); clearTimeout(this.profileUpdateTimer[key]!);
this.profileUpdateTimer[key] = undefined; this.profileUpdateTimer[key] = undefined;
} }
this.withLockSync(() => { withLockNoWait(globals.paths.experimentsList, () => {
const experimentsInformation = JSON.parse(fs.readFileSync(this.experimentsPath).toString()); const experimentsInformation = JSON.parse(fs.readFileSync(globals.paths.experimentsList).toString());
assert(experimentId in experimentsInformation, `Experiment Manager: Experiment Id ${experimentId} not found, this should not happen`); assert(experimentId in experimentsInformation, `Experiment Manager: Experiment Id ${experimentId} not found, this should not happen`);
if (value !== undefined) { if (value !== undefined) {
experimentsInformation[experimentId][key] = value; experimentsInformation[experimentId][key] = value;
} else { } else {
delete experimentsInformation[experimentId][key]; delete experimentsInformation[experimentId][key];
} }
fs.writeFileSync(this.experimentsPath, JSON.stringify(experimentsInformation, null, 4)); fs.writeFileSync(globals.paths.experimentsList, JSON.stringify(experimentsInformation, null, 4));
}); });
} catch (err) { } catch (err) {
this.log.error(err); logger.error(err);
this.log.debug(`Experiment Manager: Retry set key value: ${experimentId} {${key}: ${value}}`); logger.debug(`Experiment Manager: Retry set key value: ${experimentId} {${key}: ${value}}`);
if (err.code === 'EEXIST' || err.message === 'File has been locked.') { if (err.code === 'EEXIST' || err.message === 'File has been locked.') {
this.profileUpdateTimer[key] = setTimeout(this.setExperimentInfo.bind(this), 100, experimentId, key, value); this.profileUpdateTimer[key] = setTimeout(() => this.setExperimentInfo(experimentId, key, value), 100);
} }
} }
} }
/**
 * Call `func` under the file lock, retrying while the lock is contended.
 *
 * An attempt is "contended" when it fails with code 'EEXIST' or with the
 * message 'File has been locked.'; any other error is rethrown immediately.
 * Each contended attempt is followed by a 50ms pause. Once the retry budget
 * drops below zero the call fails with 'Lock file out of retries.'.
 */
private async withLockIterated (func: Function, retry: number, ...args: any): Promise<any> {
    for (let remaining = retry; ; remaining -= 1) {
        if (remaining < 0) {
            throw new Error('Lock file out of retries.');
        }
        try {
            return this.withLockSync(func, ...args);
        } catch (err) {
            const contended = err.code === 'EEXIST' || err.message === 'File has been locked.';
            if (!contended) {
                throw err;
            }
        }
        // retry wait is 50ms
        await delay(50);
    }
}
/**
 * Run `func`, bound to this instance, under the lock for `experimentsPath`,
 * using a stale threshold of two seconds.
 */
private withLockSync (func: Function, ...args: any): any {
    const boundFunc = func.bind(this);
    const lockOptions = {stale: 2 * 1000};
    return withLockSync(boundFunc, this.experimentsPath, lockOptions, ...args);
}
private readExperimentsInfo(): FileInfo { private readExperimentsInfo(): FileInfo {
const buffer: Buffer = fs.readFileSync(this.experimentsPath); const buffer: Buffer = fs.readFileSync(globals.paths.experimentsList);
const mtime: number = fs.statSync(this.experimentsPath).mtimeMs; const mtime: number = fs.statSync(globals.paths.experimentsList).mtimeMs;
return {buffer: buffer, mtime: mtime}; return {buffer: buffer, mtime: mtime};
} }
...@@ -123,33 +95,27 @@ class NNIExperimentsManager implements ExperimentManager { ...@@ -123,33 +95,27 @@ class NNIExperimentsManager implements ExperimentManager {
} }
private updateAllStatus(updateList: Array<string>, timestamp: number): {[key: string]: any} | undefined { private updateAllStatus(updateList: Array<string>, timestamp: number): {[key: string]: any} | undefined {
if (timestamp !== fs.statSync(this.experimentsPath).mtimeMs) { if (timestamp !== fs.statSync(globals.paths.experimentsList).mtimeMs) {
return; return;
} else { } else {
const experimentsInformation = JSON.parse(fs.readFileSync(this.experimentsPath).toString()); const experimentsInformation = JSON.parse(fs.readFileSync(globals.paths.experimentsList).toString());
updateList.forEach((expId: string) => { updateList.forEach((expId: string) => {
if (experimentsInformation[expId]) { if (experimentsInformation[expId]) {
experimentsInformation[expId]['status'] = 'STOPPED'; experimentsInformation[expId]['status'] = 'STOPPED';
delete experimentsInformation[expId]['port']; delete experimentsInformation[expId]['port'];
} else { } else {
this.log.error(`Experiment Manager: Experiment Id ${expId} not found, this should not happen`); logger.error(`Experiment Manager: Experiment Id ${expId} not found, this should not happen`);
} }
}); });
fs.writeFileSync(this.experimentsPath, JSON.stringify(experimentsInformation, null, 4)); fs.writeFileSync(globals.paths.experimentsList, JSON.stringify(experimentsInformation, null, 4));
return experimentsInformation; return experimentsInformation;
} }
} }
/**
 * Stop the manager: wait for cleanUp() to finish, logging before and after.
 * A cleanUp() failure is logged via its message and does not reject this call.
 */
public async stop(): Promise<void> {
    this.log.debug('Stopping experiment manager.');
    try {
        await this.cleanUp();
    } catch (err: any) {
        this.log.error(err.message);
    }
    this.log.debug('Experiment manager stopped.');
}
private async cleanUp(): Promise<void> { private async cleanUp(): Promise<void> {
const deferred = new Deferred<void>(); const deferred = new Deferred<void>();
if (this.isUndone()) { if (this.isUndone()) {
this.log.debug('Experiment manager: something undone'); logger.debug('Experiment manager: something undone');
setTimeout(((deferred: Deferred<void>): void => { setTimeout(((deferred: Deferred<void>): void => {
if (this.isUndone()) { if (this.isUndone()) {
deferred.reject(new Error('Still has undone after 5s, forced stop.')); deferred.reject(new Error('Still has undone after 5s, forced stop.'));
...@@ -158,7 +124,7 @@ class NNIExperimentsManager implements ExperimentManager { ...@@ -158,7 +124,7 @@ class NNIExperimentsManager implements ExperimentManager {
} }
}).bind(this), 5 * 1000, deferred); }).bind(this), 5 * 1000, deferred);
} else { } else {
this.log.debug('Experiment manager: all clean up'); logger.debug('Experiment manager: all clean up');
deferred.resolve(); deferred.resolve();
} }
return deferred.promise; return deferred.promise;
...@@ -170,5 +136,3 @@ class NNIExperimentsManager implements ExperimentManager { ...@@ -170,5 +136,3 @@ class NNIExperimentsManager implements ExperimentManager {
}).length > 0; }).length > 0;
} }
} }
export { NNIExperimentsManager };
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import fs from 'fs';
import path from 'path';
import * as timersPromises from 'timers/promises';
import glob from 'glob';
import lockfile from 'lockfile';
import globals from 'common/globals';
/** Lock files modified longer ago than this many milliseconds are ignored as stale. */
const lockStale: number = 2000;
/** Number of attempts `withLock` makes before giving up. */
const retry: number = 100;

/**
 * Run `func` synchronously while holding a lock file for `protectedFile`,
 * failing immediately instead of waiting when the file is already locked.
 *
 * The lock file lives next to `protectedFile` and is named `<basename>.lock.<pid>`.
 * Before locking, every `<basename>.lock.*` file is inspected: if any was
 * modified less than `lockStale` milliseconds ago the call is refused.
 *
 * The lock is always released, even when `func` throws, so a failing callback
 * cannot leave a lock file behind that blocks later callers.
 *
 * @param protectedFile file the lock protects; it only determines the lock file name
 * @param func callback executed while the lock is held
 * @returns whatever `func` returns
 * @throws Error with message 'File has been locked.' when an unexpired lock file exists;
 *         also rethrows any error raised by `lockfile` or by `func`
 */
export function withLockNoWait<T>(protectedFile: string, func: () => T): T {
    const lockName = path.join(path.dirname(protectedFile), path.basename(protectedFile) + `.lock.${process.pid}`);
    const lockPattern = path.join(path.dirname(protectedFile), path.basename(protectedFile) + '.lock.*');
    const now = Date.now();
    const locked = glob.sync(lockPattern).some((fileName: string) => {
        return fs.existsSync(fileName) && now - fs.statSync(fileName).mtimeMs < lockStale;
    });
    if (locked) {
        throw new Error('File has been locked.');
    }
    lockfile.lockSync(lockName, { stale: lockStale });
    try {
        return func();
    } finally {
        // Release on both the normal and the exceptional path.
        lockfile.unlockSync(lockName);
    }
}
/**
 * Run `func` under the lock for `protectedFile`, waiting for the lock if needed.
 *
 * Up to `retry` attempts are made through `withLockNoWait`. An attempt that
 * fails with code 'EEXIST' or message 'File has been locked.' is followed by a
 * 50ms pause; any other error is rethrown at once.
 *
 * @param protectedFile file the lock protects
 * @param func callback executed while the lock is held
 * @returns a promise for whatever `func` returns
 * @throws Error with message 'Lock file out of retries.' when every attempt was contended
 */
export async function withLock<T>(protectedFile: string, func: () => T): Promise<T> {
    let attempt = 0;
    while (attempt < retry) {
        try {
            return withLockNoWait(protectedFile, func);
        } catch (error: any) {
            const contended = error.code === 'EEXIST' || error.message === 'File has been locked.';
            if (!contended) {
                throw error;
            }
        }
        await timersPromises.setTimeout(50);
        attempt += 1;
    }
    throw new Error('Lock file out of retries.');
}
...@@ -27,7 +27,6 @@ import { Container, Scope } from 'typescript-ioc'; ...@@ -27,7 +27,6 @@ import { Container, Scope } from 'typescript-ioc';
import * as component from 'common/component'; import * as component from 'common/component';
import { Database, DataStore } from 'common/datastore'; import { Database, DataStore } from 'common/datastore';
import { ExperimentManager } from 'common/experimentManager';
import globals, { initGlobals } from 'common/globals'; import globals, { initGlobals } from 'common/globals';
import { Logger, getLogger } from 'common/log'; import { Logger, getLogger } from 'common/log';
import { Manager } from 'common/manager'; import { Manager } from 'common/manager';
...@@ -35,7 +34,7 @@ import { TensorboardManager } from 'common/tensorboardManager'; ...@@ -35,7 +34,7 @@ import { TensorboardManager } from 'common/tensorboardManager';
import { NNIDataStore } from 'core/nniDataStore'; import { NNIDataStore } from 'core/nniDataStore';
import { NNIManager } from 'core/nnimanager'; import { NNIManager } from 'core/nnimanager';
import { SqlDB } from 'core/sqlDatabase'; import { SqlDB } from 'core/sqlDatabase';
import { NNIExperimentsManager } from 'extensions/experiments_manager'; import { initExperimentsManager } from 'extensions/experiments_manager';
import { NNITensorboardManager } from 'extensions/nniTensorboardManager'; import { NNITensorboardManager } from 'extensions/nniTensorboardManager';
import { RestServer } from 'rest_server'; import { RestServer } from 'rest_server';
...@@ -49,7 +48,6 @@ async function start(): Promise<void> { ...@@ -49,7 +48,6 @@ async function start(): Promise<void> {
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton); Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton); Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(NNIDataStore).scope(Scope.Singleton); Container.bind(DataStore).to(NNIDataStore).scope(Scope.Singleton);
Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton);
Container.bind(TensorboardManager).to(NNITensorboardManager).scope(Scope.Singleton); Container.bind(TensorboardManager).to(NNITensorboardManager).scope(Scope.Singleton);
const ds: DataStore = component.get(DataStore); const ds: DataStore = component.get(DataStore);
...@@ -58,6 +56,8 @@ async function start(): Promise<void> { ...@@ -58,6 +56,8 @@ async function start(): Promise<void> {
const restServer = new RestServer(globals.args.port, globals.args.urlPrefix); const restServer = new RestServer(globals.args.port, globals.args.urlPrefix);
await restServer.start(); await restServer.start();
initExperimentsManager();
globals.shutdown.notifyInitializeComplete(); globals.shutdown.notifyInitializeComplete();
} }
......
...@@ -11,7 +11,7 @@ import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo'; ...@@ -11,7 +11,7 @@ import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
import globals from 'common/globals'; import globals from 'common/globals';
import { getLogger, Logger } from '../common/log'; import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager'; import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager';
import { ExperimentManager } from '../common/experimentManager'; import { getExperimentsManager } from 'extensions/experiments_manager';
import { TensorboardManager, TensorboardTaskInfo } from '../common/tensorboardManager'; import { TensorboardManager, TensorboardTaskInfo } from '../common/tensorboardManager';
import { ValidationSchemas } from './restValidationSchemas'; import { ValidationSchemas } from './restValidationSchemas';
import { getVersion } from '../common/utils'; import { getVersion } from '../common/utils';
...@@ -24,13 +24,11 @@ import { TrialJobStatus } from '../common/trainingService'; ...@@ -24,13 +24,11 @@ import { TrialJobStatus } from '../common/trainingService';
class NNIRestHandler { class NNIRestHandler {
private nniManager: Manager; private nniManager: Manager;
private experimentsManager: ExperimentManager;
private tensorboardManager: TensorboardManager; private tensorboardManager: TensorboardManager;
private log: Logger; private log: Logger;
constructor() { constructor() {
this.nniManager = component.get(Manager); this.nniManager = component.get(Manager);
this.experimentsManager = component.get(ExperimentManager);
this.tensorboardManager = component.get(TensorboardManager); this.tensorboardManager = component.get(TensorboardManager);
this.log = getLogger('NNIRestHandler'); this.log = getLogger('NNIRestHandler');
} }
...@@ -328,7 +326,7 @@ class NNIRestHandler { ...@@ -328,7 +326,7 @@ class NNIRestHandler {
router.get('/experiment-metadata', (_req: Request, res: Response) => { router.get('/experiment-metadata', (_req: Request, res: Response) => {
Promise.all([ Promise.all([
this.nniManager.getExperimentProfile(), this.nniManager.getExperimentProfile(),
this.experimentsManager.getExperimentsInfo() getExperimentsManager().getExperimentsInfo()
]).then(([profile, experimentInfo]) => { ]).then(([profile, experimentInfo]) => {
for (const info of experimentInfo as any) { for (const info of experimentInfo as any) {
if (info.id === profile.id) { if (info.id === profile.id) {
...@@ -344,7 +342,7 @@ class NNIRestHandler { ...@@ -344,7 +342,7 @@ class NNIRestHandler {
private getExperimentsInfo(router: Router): void { private getExperimentsInfo(router: Router): void {
router.get('/experiments-info', (_req: Request, res: Response) => { router.get('/experiments-info', (_req: Request, res: Response) => {
this.experimentsManager.getExperimentsInfo().then((experimentInfo: JSON) => { getExperimentsManager().getExperimentsInfo().then((experimentInfo: JSON) => {
res.send(JSON.stringify(experimentInfo)); res.send(JSON.stringify(experimentInfo));
}).catch((err: Error) => { }).catch((err: Error) => {
this.handleError(err, res); this.handleError(err, res);
......
...@@ -11,10 +11,8 @@ import { Container, Scope } from 'typescript-ioc'; ...@@ -11,10 +11,8 @@ import { Container, Scope } from 'typescript-ioc';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { Database, DataStore } from '../../common/datastore'; import { Database, DataStore } from '../../common/datastore';
import { Manager, ExperimentProfile} from '../../common/manager'; import { Manager, ExperimentProfile} from '../../common/manager';
import { ExperimentManager } from '../../common/experimentManager';
import { TrainingService } from '../../common/trainingService'; import { TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest, killPid } from '../../common/utils'; import { cleanupUnitTest, prepareUnitTest, killPid } from '../../common/utils';
import { NNIExperimentsManager } from 'extensions/experiments_manager';
import { NNIManager } from '../../core/nnimanager'; import { NNIManager } from '../../core/nnimanager';
import { SqlDB } from '../../core/sqlDatabase'; import { SqlDB } from '../../core/sqlDatabase';
import { MockedTrainingService } from '../mock/trainingService'; import { MockedTrainingService } from '../mock/trainingService';
...@@ -30,7 +28,6 @@ async function initContainer(): Promise<void> { ...@@ -30,7 +28,6 @@ async function initContainer(): Promise<void> {
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton); Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton); Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton); Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton);
Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton);
Container.bind(TensorboardManager).to(NNITensorboardManager).scope(Scope.Singleton); Container.bind(TensorboardManager).to(NNITensorboardManager).scope(Scope.Singleton);
await component.get<DataStore>(DataStore).init(); await component.get<DataStore>(DataStore).init();
} }
...@@ -116,8 +113,6 @@ describe('Unit test for nnimanager', function () { ...@@ -116,8 +113,6 @@ describe('Unit test for nnimanager', function () {
before(async () => { before(async () => {
await initContainer(); await initContainer();
fs.writeFileSync('.experiment.test', JSON.stringify(mockedInfo)); fs.writeFileSync('.experiment.test', JSON.stringify(mockedInfo));
const experimentsManager: ExperimentManager = component.get(ExperimentManager);
experimentsManager.setExperimentPath('.experiment.test');
nniManager = component.get(Manager); nniManager = component.get(Manager);
const expId: string = await nniManager.startExperiment(experimentParams); const expId: string = await nniManager.startExperiment(experimentParams);
......
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license. // Licensed under the MIT license.
'use strict';
import { assert, expect } from 'chai'; import { assert, expect } from 'chai';
import * as fs from 'fs'; import fs from 'fs';
import { Container, Scope } from 'typescript-ioc'; import { Container, Scope } from 'typescript-ioc';
import os from 'os';
import path from 'path';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils'; import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { ExperimentManager } from '../../common/experimentManager'; import { ExperimentsManager } from 'extensions/experiments_manager';
import { NNIExperimentsManager } from 'extensions/experiments_manager'; import globals from 'common/globals/unittest';
let tempDir: string | null = null;
describe('Unit test for experiment manager', function () { let experimentManager: ExperimentsManager;
let experimentManager: NNIExperimentsManager; const mockedInfo = {
const mockedInfo = { "test": {
"test": { "port": 8080,
"port": 8080, "startTime": 1605246730756,
"startTime": 1605246730756, "endTime": "N/A",
"endTime": "N/A", "status": "INITIALIZED",
"status": "INITIALIZED", "platform": "local",
"platform": "local", "experimentName": "testExp",
"experimentName": "testExp", "tag": [], "pid": 11111,
"tag": [], "pid": 11111, "webuiUrl": [],
"webuiUrl": [], "logDir": null
"logDir": null
}
} }
}
describe('Unit test for experiment manager', function () {
before(() => { before(() => {
prepareUnitTest(); prepareUnitTest();
fs.writeFileSync('.experiment.test', JSON.stringify(mockedInfo)); tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'nni-ut-'));
Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton); globals.paths.experimentsList = path.join(tempDir, '.experiment');
experimentManager = component.get(NNIExperimentsManager); fs.writeFileSync(globals.paths.experimentsList, JSON.stringify(mockedInfo));
experimentManager.setExperimentPath('.experiment.test'); experimentManager = new ExperimentsManager();
}); });
after(() => { after(() => {
if (fs.existsSync('.experiment.test')) { if (tempDir !== null) {
fs.unlinkSync('.experiment.test'); fs.rmSync(tempDir, { force: true, recursive: true });
} }
cleanupUnitTest(); cleanupUnitTest();
}); });
it('test getExperimentsInfo', () => { it('test getExperimentsInfo', async () => {
return experimentManager.getExperimentsInfo().then(function (experimentsInfo: {[key: string]: any}) { const experimentsInfo: {[key: string]: any} = await experimentManager.getExperimentsInfo();
new Array(experimentsInfo) for (let idx in experimentsInfo) {
for (let idx in experimentsInfo) { if (experimentsInfo[idx]['id'] === 'test') {
if (experimentsInfo[idx]['id'] === 'test') { expect(experimentsInfo[idx]['status']).to.be.oneOf(['STOPPED', 'ERROR']);
expect(experimentsInfo[idx]['status']).to.be.oneOf(['STOPPED', 'ERROR']); break;
break;
}
} }
}).catch((error) => { }
assert.fail(error);
})
}); });
}); });
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license. // Licensed under the MIT license.
'use strict'; export class MockedExperimentManager {
import { ExperimentManager } from '../../common/experimentManager';
import { Provider } from 'typescript-ioc';
export const testExperimentManagerProvider: Provider = {
get: (): ExperimentManager => { return new mockedeExperimentManager(); }
};
export class mockedeExperimentManager extends ExperimentManager {
public getExperimentsInfo(): Promise<JSON> { public getExperimentsInfo(): Promise<JSON> {
const expInfo = JSON.parse(JSON.stringify({ const expInfo = JSON.parse(JSON.stringify({
"test": { "test": {
......
...@@ -10,16 +10,16 @@ import { Container } from 'typescript-ioc'; ...@@ -10,16 +10,16 @@ import { Container } from 'typescript-ioc';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { DataStore } from '../../common/datastore'; import { DataStore } from '../../common/datastore';
import { ExperimentProfile, Manager } from '../../common/manager'; import { ExperimentProfile, Manager } from '../../common/manager';
import { ExperimentManager } from '../../common/experimentManager'
import { TrainingService } from '../../common/trainingService'; import { TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils'; import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { MockedDataStore } from '../mock/datastore'; import { MockedDataStore } from '../mock/datastore';
import { MockedTrainingService } from '../mock/trainingService'; import { MockedTrainingService } from '../mock/trainingService';
import { RestServer, UnitTestHelpers } from 'rest_server'; import { RestServer, UnitTestHelpers } from 'rest_server';
import { testManagerProvider } from '../mock/nniManager'; import { testManagerProvider } from '../mock/nniManager';
import { testExperimentManagerProvider } from '../mock/experimentManager'; import { MockedExperimentManager } from '../mock/experimentManager';
import { TensorboardManager } from '../../common/tensorboardManager'; import { TensorboardManager } from '../../common/tensorboardManager';
import { MockTensorboardManager } from '../mock/mockTensorboardManager'; import { MockTensorboardManager } from '../mock/mockTensorboardManager';
import { UnitTestHelpers as ExpsMgrHelpers } from 'extensions/experiments_manager';
let restServer: RestServer; let restServer: RestServer;
...@@ -28,11 +28,11 @@ describe('Unit test for rest handler', () => { ...@@ -28,11 +28,11 @@ describe('Unit test for rest handler', () => {
let ROOT_URL: string; let ROOT_URL: string;
before(async () => { before(async () => {
ExpsMgrHelpers.setExperimentsManager(new MockedExperimentManager());
prepareUnitTest(); prepareUnitTest();
Container.bind(Manager).provider(testManagerProvider); Container.bind(Manager).provider(testManagerProvider);
Container.bind(DataStore).to(MockedDataStore); Container.bind(DataStore).to(MockedDataStore);
Container.bind(TrainingService).to(MockedTrainingService); Container.bind(TrainingService).to(MockedTrainingService);
Container.bind(ExperimentManager).provider(testExperimentManagerProvider);
Container.bind(TensorboardManager).to(MockTensorboardManager); Container.bind(TensorboardManager).to(MockTensorboardManager);
restServer = new RestServer(0, ''); restServer = new RestServer(0, '');
await restServer.start(); await restServer.start();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment