Unverified Commit efd433bd authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Refactor NNI manager globals (step 4) - logging (#4705)

parent 3f16a5a0
...@@ -18,6 +18,7 @@ import assert from 'assert/strict'; ...@@ -18,6 +18,7 @@ import assert from 'assert/strict';
import { NniManagerArgs, parseArgs } from './arguments'; import { NniManagerArgs, parseArgs } from './arguments';
import { NniPaths, createPaths } from './paths'; import { NniPaths, createPaths } from './paths';
import { LogStream, initLogStream } from './log_stream';
export { NniManagerArgs, NniPaths }; export { NniManagerArgs, NniPaths };
...@@ -30,6 +31,8 @@ export { NniManagerArgs, NniPaths }; ...@@ -30,6 +31,8 @@ export { NniManagerArgs, NniPaths };
export interface NniGlobals { export interface NniGlobals {
readonly args: NniManagerArgs; readonly args: NniManagerArgs;
readonly paths: NniPaths; readonly paths: NniPaths;
readonly logStream: LogStream;
} }
// give type hint to `global.nni` (copied from SO, dunno how it works) // give type hint to `global.nni` (copied from SO, dunno how it works)
...@@ -53,7 +56,8 @@ export function initGlobals(): void { ...@@ -53,7 +56,8 @@ export function initGlobals(): void {
const args = parseArgs(process.argv.slice(2)); const args = parseArgs(process.argv.slice(2));
const paths = createPaths(args); const paths = createPaths(args);
const logStream = initLogStream(args, paths);
const globals: NniGlobals = { args, paths }; const globals: NniGlobals = { args, paths, logStream };
Object.assign(global.nni, globals); Object.assign(global.nni, globals);
} }
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
* The underlying IO stream of loggers.
*
* Normal modules should not use this directly. Use "common/log.ts" instead.
**/
import fs from 'fs';
import { setTimeout } from 'timers/promises';
import util from 'util';
import type { NniManagerArgs } from './arguments';
import type { NniPaths } from './paths';
export interface LogStream {
writeLine(line: string): void;
writeLineSync(line: string): void;
close(): Promise<void>;
}
const writePromise = util.promisify(fs.write);
class LogStreamImpl implements LogStream {
private buffer: string[] = [];
private flushing: boolean = false;
private logFileFd: number;
private toConsole: boolean;
constructor(logFile: string, toConsole: boolean) {
this.logFileFd = fs.openSync(logFile, 'a');
this.toConsole = toConsole;
}
public writeLine(line: string): void {
this.buffer.push(line);
this.flush();
}
public writeLineSync(line: string): void {
if (this.toConsole) {
console.log(line);
}
fs.writeSync(this.logFileFd, line + '\n');
}
public async close(): Promise<void> {
while (this.flushing) {
await setTimeout();
}
fs.closeSync(this.logFileFd);
this.logFileFd = 2; // stderr
this.toConsole = false;
}
private async flush(): Promise<void> {
if (this.flushing) {
return;
}
this.flushing = true;
while (this.buffer.length > 0) {
const lines = this.buffer.join('\n');
this.buffer.length = 0;
if (this.toConsole) {
console.log(lines);
}
await writePromise(this.logFileFd, lines + '\n');
}
this.flushing = false;
}
}
export function initLogStream(args: NniManagerArgs, paths: NniPaths): LogStream {
return new LogStreamImpl(paths.nniManagerLog, args.foreground);
}
...@@ -18,6 +18,7 @@ import path from 'path'; ...@@ -18,6 +18,7 @@ import path from 'path';
import type { NniManagerArgs } from './arguments'; import type { NniManagerArgs } from './arguments';
import { NniPaths, createPaths } from './paths'; import { NniPaths, createPaths } from './paths';
import type { LogStream } from './log_stream';
// copied from https://www.typescriptlang.org/docs/handbook/2/mapped-types.html // copied from https://www.typescriptlang.org/docs/handbook/2/mapped-types.html
type Mutable<Type> = { type Mutable<Type> = {
...@@ -27,6 +28,9 @@ type Mutable<Type> = { ...@@ -27,6 +28,9 @@ type Mutable<Type> = {
export interface MutableGlobals { export interface MutableGlobals {
args: Mutable<NniManagerArgs>; args: Mutable<NniManagerArgs>;
paths: Mutable<NniPaths>; paths: Mutable<NniPaths>;
logStream: LogStream;
reset(): void;
} }
export function resetGlobals(): void { export function resetGlobals(): void {
...@@ -41,14 +45,19 @@ export function resetGlobals(): void { ...@@ -41,14 +45,19 @@ export function resetGlobals(): void {
mode: 'unittest', mode: 'unittest',
dispatcherPipe: undefined dispatcherPipe: undefined
}; };
const paths = createPaths(args); const paths = createPaths(args);
const logStream = {
writeLine: (_line: string): void => { /* dummy */ },
writeLineSync: (_line: string): void => { /* dummy */ },
close: (): void => { /* dummy */ }
};
const globals = { args, paths }; const globalAsAny = global as any;
if (global.nni === undefined) { const utGlobals = { args, paths, logStream, reset: resetGlobals };
global.nni = globals; if (globalAsAny.nni === undefined) {
globalAsAny.nni = utGlobals;
} else { } else {
Object.assign(global.nni, globals); Object.assign(globalAsAny.nni, utGlobals);
} }
} }
...@@ -61,5 +70,5 @@ if (isUnitTest()) { ...@@ -61,5 +70,5 @@ if (isUnitTest()) {
resetGlobals(); resetGlobals();
} }
const globals: MutableGlobals = global.nni; const globals: MutableGlobals = (global as any).nni;
export default globals; export default globals;
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license. // Licensed under the MIT license.
import fs from 'fs'; /**
import { Writable } from 'stream'; * Python-like logging interface.
import util from 'util'; *
* const logger = getLogger('moduleName');
/* log level constants */ * logger.info('hello', { to: 'world' });
*
* Outputs:
*
* [1970-01-01 00:00:00] INFO (moduleName) hello { to: 'world' }
**/
export const DEBUG = 10; import util from 'util';
export const INFO = 20;
export const WARNING = 30;
export const ERROR = 40;
export const CRITICAL = 50;
export const TRACE = 1; import globals from 'common/globals';
export const FATAL = 50;
const levelNames = new Map<number, string>([ const levelNameToValue = { trace: 0, debug: 10, info: 20, warning: 30, error: 40, critical: 50 } as const;
[CRITICAL, 'CRITICAL'],
[ERROR, 'ERROR'],
[WARNING, 'WARNING'],
[INFO, 'INFO'],
[DEBUG, 'DEBUG'],
[TRACE, 'TRACE'],
]);
/* global_ states */ const loggers: Record<string, Logger> = {};
let logLevel: number = 0; export function getLogger(name: string): Logger {
const loggers = new Map<string, Logger>(); if (loggers[name] === undefined) {
loggers[name] = new Logger(name);
}
return loggers[name];
}
/* major api */ /**
* A special logger prints to stderr when the logging system has problems.
* For modules that are responsible for handling logger errors.
**/
export function getRobustLogger(name: string): Logger {
if (loggers[name] === undefined || !(loggers[name] as RobustLogger).robust) {
loggers[name] = new RobustLogger(name);
}
return loggers[name];
}
export class Logger { export class Logger {
private name: string; protected name: string;
constructor(name: string = 'root') { constructor(name: string) {
this.name = name; this.name = name;
} }
public trace(...args: any[]): void { public trace(...args: any[]): void {
this.log(TRACE, args); this.log(levelNameToValue.trace, 'TRACE', args);
} }
public debug(...args: any[]): void { public debug(...args: any[]): void {
this.log(DEBUG, args); this.log(levelNameToValue.debug, 'DEBUG', args);
} }
public info(...args: any[]): void { public info(...args: any[]): void {
this.log(INFO, args); this.log(levelNameToValue.info, 'INFO', args);
} }
public warning(...args: any[]): void { public warning(...args: any[]): void {
this.log(WARNING, args); this.log(levelNameToValue.warning, 'WARNING', args);
} }
public error(...args: any[]): void { public error(...args: any[]): void {
this.log(ERROR, args); this.log(levelNameToValue.error, 'ERROR', args);
} }
public critical(...args: any[]): void { public critical(...args: any[]): void {
this.log(CRITICAL, args); this.log(levelNameToValue.critical, 'CRITICAL', args);
} }
public fatal(...args: any[]): void { protected log(levelValue: number, levelName: string, args: any[]): void {
this.log(FATAL, args); if (levelValue >= levelNameToValue[globals.args.logLevel]) {
const msg = `[${timestamp()}] ${levelName} (${this.name}) ${formatArgs(args)}`;
globals.logStream.writeLine(msg);
}
} }
}
private log(level: number, args: any[]): void { class RobustLogger extends Logger {
const logFile: Writable | undefined = (global as any).logFile; public readonly robust: boolean = true;
if (level < logLevel) { private errorOccurred: boolean = false;
protected log(levelValue: number, levelName: string, args: any[]): void {
if (this.errorOccurred) {
this.logAfterError(levelName, args);
return; return;
} }
try {
const zeroPad = (num: number): string => num.toString().padStart(2, '0'); if (levelValue >= levelNameToValue[globals.args.logLevel]) {
const now = new Date(); const msg = `[${timestamp()}] ${levelName} (${this.name}) ${formatArgs(args)}`;
const date = now.getFullYear() + '-' + zeroPad(now.getMonth() + 1) + '-' + zeroPad(now.getDate()); globals.logStream.writeLineSync(msg);
const time = zeroPad(now.getHours()) + ':' + zeroPad(now.getMinutes()) + ':' + zeroPad(now.getSeconds());
const datetime = date + ' ' + time;
const levelName = levelNames.has(level) ? levelNames.get(level) : level.toString();
const message = args.map(arg => (typeof arg === 'string' ? arg : util.inspect(arg))).join(' ');
const record = `[${datetime}] ${levelName} (${this.name}) ${message}`;
if (logFile === undefined) {
if (!isUnitTest()) { // be quite for unit test
console.log(record);
} }
} else { } catch (error) {
logFile.write(record + '\n'); this.errorOccurred = true;
console.error('[ERROR] Logger has stopped working:', error);
this.logAfterError(levelName, args);
} }
} }
}
export function getLogger(name: string = 'root'): Logger { private logAfterError(levelName: string, args: any[]): void {
let logger = loggers.get(name); try {
if (logger === undefined) { args = args.map(arg => util.inspect(arg));
logger = new Logger(name); } catch { /* fallback */ }
loggers.set(name, logger); console.error(`[${levelName}] (${this.name})`, ...args);
} }
return logger;
} }
/* management functions */ function timestamp(): string {
const now = new Date();
export function setLogLevel(levelName: string): void { const date = now.getFullYear() + '-' + zeroPad(now.getMonth() + 1) + '-' + zeroPad(now.getDate());
if (levelName) { const time = zeroPad(now.getHours()) + ':' + zeroPad(now.getMinutes()) + ':' + zeroPad(now.getSeconds());
const level = module.exports[levelName.toUpperCase()]; return date + ' ' + time;
if (typeof level === 'number') {
logLevel = level;
} else {
console.log('[ERROR] Bad log level:', levelName);
getLogger('logging').error('Bad log level:', levelName);
}
}
} }
export function startLogging(logPath: string): void { function zeroPad(num: number): string {
(global as any).logFile = fs.createWriteStream(logPath, { return num.toString().padStart(2, '0');
flags: 'a+',
encoding: 'utf8',
autoClose: true
});
} }
export function stopLogging(): void { function formatArgs(args: any[]): string {
if ((global as any).logFile !== undefined) { return args.map(arg => (typeof arg === 'string' ? arg : util.inspect(arg))).join(' ');
(global as any).logFile.end();
(global as any).logFile = undefined;
}
}
/* utilities */
function isUnitTest(): boolean {
const event = process.env['npm_lifecycle_event'] ?? '';
return event.startsWith('test') || event === 'mocha' || event === 'nyc';
} }
...@@ -8,7 +8,8 @@ import * as component from '../common/component'; ...@@ -8,7 +8,8 @@ import * as component from '../common/component';
import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore'; import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore';
import { NNIError } from '../common/errors'; import { NNIError } from '../common/errors';
import { getExperimentId, getDispatcherPipe } from '../common/experimentStartupInfo'; import { getExperimentId, getDispatcherPipe } from '../common/experimentStartupInfo';
import { Logger, getLogger, stopLogging } from '../common/log'; import globals from 'common/globals';
import { Logger, getLogger } from '../common/log';
import { import {
ExperimentProfile, Manager, ExperimentStatus, ExperimentProfile, Manager, ExperimentStatus,
NNIManagerStatus, ProfileUpdateType, TrialJobStatistics NNIManagerStatus, ProfileUpdateType, TrialJobStatistics
...@@ -360,7 +361,7 @@ class NNIManager implements Manager { ...@@ -360,7 +361,7 @@ class NNIManager implements Manager {
hasError = true; hasError = true;
this.log.error(`${err.stack}`); this.log.error(`${err.stack}`);
} finally { } finally {
stopLogging(); await globals.logStream.close();
process.exit(hasError ? 1 : 0); process.exit(hasError ? 1 : 0);
} }
} }
......
...@@ -29,7 +29,7 @@ import * as component from 'common/component'; ...@@ -29,7 +29,7 @@ import * as component from 'common/component';
import { Database, DataStore } from 'common/datastore'; import { Database, DataStore } from 'common/datastore';
import { ExperimentManager } from 'common/experimentManager'; import { ExperimentManager } from 'common/experimentManager';
import globals, { initGlobals } from 'common/globals'; import globals, { initGlobals } from 'common/globals';
import { getLogger, setLogLevel, startLogging } from 'common/log'; import { getLogger } from 'common/log';
import { Manager } from 'common/manager'; import { Manager } from 'common/manager';
import { TensorboardManager } from 'common/tensorboardManager'; import { TensorboardManager } from 'common/tensorboardManager';
import { NNIDataStore } from 'core/nniDataStore'; import { NNIDataStore } from 'core/nniDataStore';
...@@ -72,18 +72,14 @@ process.on('SIGINT', shutdown); ...@@ -72,18 +72,14 @@ process.on('SIGINT', shutdown);
initGlobals(); initGlobals();
// TODO: these should be handled inside globals module
startLogging(globals.paths.nniManagerLog);
setLogLevel(globals.args.logLevel);
start().then(() => { start().then(() => {
getLogger('main').debug('start() returned.'); getLogger('main').debug('start() returned.');
}).catch((error) => { }).catch((error) => {
try { try {
getLogger('main').error('Failed to start:', error); getLogger('main').error('Failed to start:', error);
} catch (loggerError) { } catch (loggerError) {
console.log('Failed to start:', error); console.error('Failed to start:', error);
console.log('Seems logger is faulty:', loggerError); console.error('Seems logger is faulty:', loggerError);
} }
process.exit(1); process.exit(1);
}); });
......
...@@ -114,7 +114,7 @@ function rootRouter(stopCallback: () => Promise<void>): Router { ...@@ -114,7 +114,7 @@ function rootRouter(stopCallback: () => Promise<void>): Router {
router.use(express.json({ limit: '50mb' })); router.use(express.json({ limit: '50mb' }));
/* NNI manager APIs */ /* NNI manager APIs */
router.use('/api/v1/nni', createRestHandler(stopCallback)); router.use('/api/v1/nni', restHandlerFactory(stopCallback));
/* Download log files */ /* Download log files */
// The REST API path "/logs" does not match file system path "/log". // The REST API path "/logs" does not match file system path "/log".
...@@ -149,6 +149,7 @@ function netronProxy(): Router { ...@@ -149,6 +149,7 @@ function netronProxy(): Router {
let webuiPath: string = path.resolve('static'); let webuiPath: string = path.resolve('static');
let netronUrl: string = 'https://netron.app'; let netronUrl: string = 'https://netron.app';
let restHandlerFactory = createRestHandler;
export namespace UnitTestHelpers { export namespace UnitTestHelpers {
export function getPort(server: RestServer): number { export function getPort(server: RestServer): number {
...@@ -162,4 +163,14 @@ export namespace UnitTestHelpers { ...@@ -162,4 +163,14 @@ export namespace UnitTestHelpers {
export function setNetronUrl(mockUrl: string): void { export function setNetronUrl(mockUrl: string): void {
netronUrl = mockUrl; netronUrl = mockUrl;
} }
export function disableNniManager(): void {
restHandlerFactory = (_: any): Router => Router();
}
export function reset(): void {
webuiPath = path.resolve('static');
netronUrl = 'https://netron.app';
restHandlerFactory = createRestHandler;
}
} }
...@@ -100,7 +100,7 @@ class NNIRestHandler { ...@@ -100,7 +100,7 @@ class NNIRestHandler {
// If it's a fatal error, exit process // If it's a fatal error, exit process
if (isFatal) { if (isFatal) {
this.log.fatal(err); this.log.critical(err);
process.exit(1); process.exit(1);
} else { } else {
this.log.error(err); this.log.error(err);
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import assert from 'assert/strict';
import fs from 'fs';
import path from 'path';
import { setTimeout } from 'timers/promises';
import { LogStream, initLogStream } from 'common/globals/log_stream';
import globals from 'common/globals/unittest';
const lines = [ 'hello', '你好' ];
let logStream: LogStream;
let consoleContent: string = '';
/* test cases */
// Test cases will be run twice, the first time in background mode (write to log file only),
// and the second time in foreground mode (write to log file + stdout).
// Write 2 lines and wait 10 ms for it to flush.
async function testWrite(): Promise<void> {
logStream.writeLine(lines[0]);
logStream.writeLine(lines[1]);
await setTimeout(10);
const expected = [ lines[0], lines[1] ].join('\n') + '\n';
const fileContent = fs.readFileSync(globals.paths.nniManagerLog, { encoding: 'utf8' });
assert.equal(fileContent, expected);
assert.equal(consoleContent, globals.args.foreground ? expected : '');
}
// Write 2 lines synchronously. It should not need to flush.
async function testWriteSync(): Promise<void> {
logStream.writeLineSync(lines[0]);
logStream.writeLineSync(lines[1]);
const expected = [ lines[0], lines[1], lines[0], lines[1] ].join('\n') + '\n';
const fileContent = fs.readFileSync(globals.paths.nniManagerLog, { encoding: 'utf8' });
assert.equal(fileContent, expected);
assert.equal(consoleContent, globals.args.foreground ? expected : '');
}
// Write 2 lines and close stream. It should guarantee to flush.
async function testClose(): Promise<void> {
logStream.writeLine(lines[1]);
logStream.writeLine(lines[0]);
await logStream.close();
const expected = [ lines[0], lines[1], lines[0], lines[1], lines[1], lines[0] ].join('\n') + '\n';
const fileContent = fs.readFileSync(globals.paths.nniManagerLog, { encoding: 'utf8' });
assert.equal(fileContent, expected);
assert.equal(consoleContent, globals.args.foreground ? expected : '');
}
/* register test cases */
describe('## globals.log_stream ##', () => {
before(beforeHook);
it('background', () => testWrite());
it('background sync', () => testWriteSync());
it('background close', () => testClose());
it('// switch to foreground', () => { switchForeground(); });
it('foreground', () => testWrite());
it('foreground sync', () => testWriteSync());
it('foreground close', () => testClose());
after(afterHook);
});
/* configure test environment */
const origConsoleLog = console.log;
const tempDir = fs.mkdtempSync('nni-ut-');
function beforeHook() {
console.log = (line => { consoleContent += line + '\n'; });
globals.paths.nniManagerLog = path.join(tempDir, 'nnimanager.log');
globals.args.foreground = false;
logStream = initLogStream(globals.args, globals.paths);
}
function switchForeground() {
logStream.close();
consoleContent = '';
fs.rmSync(globals.paths.nniManagerLog);
globals.args.foreground = true;
logStream = initLogStream(globals.args, globals.paths);
}
function afterHook() {
console.log = origConsoleLog;
fs.rmSync(tempDir, { force: true, recursive: true });
globals.reset();
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import assert from 'assert/strict';
import globals from 'common/globals/unittest';
import { Logger, getLogger, getRobustLogger } from 'common/log';
/* test cases */
function testDebugLevel() {
stream.reset();
globals.args.logLevel = 'debug';
writeLogs1(getLogger('DebugLogger'));
writeLogs2(getLogger('DebugLogger'));
assert.match(stream.content, /DebugLogger/);
assert.match(stream.content, /debug-message/);
assert.match(stream.content, /info-message/);
assert.match(stream.content, /warning-message/);
assert.match(stream.content, /error-message/);
assert.match(stream.content, /critical-message/);
assert.equal(stderr, '');
}
function testWarningLevel() {
stream.reset();
globals.args.logLevel = 'warning';
writeLogs1(getLogger('WarningLogger1'));
writeLogs2(getLogger('WarningLogger2'));
assert.match(stream.content, /WarningLogger1/);
assert.match(stream.content, /WarningLogger2/);
assert.doesNotMatch(stream.content, /debug-message/);
assert.doesNotMatch(stream.content, /info-message/);
assert.match(stream.content, /warning-message/);
assert.match(stream.content, /error-message/);
assert.match(stream.content, /critical-message/);
assert.equal(stderr, '');
}
function testRobust() {
stream.reset();
globals.args.logLevel = 'info';
const logger = getRobustLogger('RobustLogger');
writeLogs1(logger);
stream.error = true;
writeLogs2(logger);
assert.match(stream.content, /RobustLogger/);
assert.doesNotMatch(stream.content, /debug-message/);
assert.match(stream.content, /info-message/);
assert.match(stream.content, /warning-message/);
assert.match(stderr, /stream-error/);
assert.match(stderr, /error-message/);
assert.match(stderr, /critical-message/);
}
/* register */
describe('## logging ##', () => {
before(beforeHook);
it('low log level', testDebugLevel);
it('high log level', testWarningLevel);
it('robust', testRobust);
after(afterHook);
});
/* helpers */
function writeLogs1(logger: Logger) {
logger.debug('debug-message');
logger.info(1, '2', 'info-message', 3);
logger.warning(undefined, [null, 'warning-message']);
}
function writeLogs2(logger: Logger) {
const recursiveObject: any = { 'message': 'error-message' };
recursiveObject.recursive = recursiveObject;
logger.error(recursiveObject);
logger.critical(new Error('critical-message'));
}
class TestLogStream {
public content: string = '';
public error: boolean = false;
reset(): void {
this.content = '';
this.error = false;
}
writeLine(line: string): void {
if (this.error) {
throw new Error('stream-error');
}
this.content += line + '\n';
}
writeLineSync(line: string): void {
if (this.error) {
throw new Error('stream-error');
}
this.content += line + '\n';
}
async close(): Promise<void> {
/* empty */
}
}
/* environment */
const stream = new TestLogStream();
const origConsoleError = console.error;
let stderr: string = '';
async function beforeHook() {
globals.logStream = stream;
console.error = (...args: any[]) => { stderr += args.join(' ') + '\n'; };
}
async function afterHook() {
globals.reset();
console.error = origConsoleError;
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import type { TensorboardManager, TensorboardParams, TensorboardTaskInfo } from 'common/tensorboardManager';
const taskInfo: TensorboardTaskInfo = {
id: 'ID',
status: 'RUNNING',
trialJobIdList: [],
trialLogDirectoryList: [],
pid: undefined,
port: undefined,
};
export class MockTensorboardManager implements TensorboardManager {
public async startTensorboardTask(_tensorboardParams: TensorboardParams): Promise<TensorboardTaskInfo> {
return taskInfo;
}
public async getTensorboardTask(_tensorboardTaskId: string): Promise<TensorboardTaskInfo> {
return taskInfo;
}
public async updateTensorboardTask(_tensorboardTaskId: string): Promise<TensorboardTaskInfo> {
return taskInfo;
}
public async listTensorboardTasks(): Promise<TensorboardTaskInfo[]> {
return [ taskInfo ];
}
public async stopTensorboardTask(_tensorboardTaskId: string): Promise<TensorboardTaskInfo> {
return taskInfo;
}
public async stopAllTensorboardTask(): Promise<void> {
return;
}
public async stop(): Promise<void> {
return;
}
}
...@@ -15,32 +15,29 @@ import { TrainingService } from '../../common/trainingService'; ...@@ -15,32 +15,29 @@ import { TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils'; import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { MockedDataStore } from '../mock/datastore'; import { MockedDataStore } from '../mock/datastore';
import { MockedTrainingService } from '../mock/trainingService'; import { MockedTrainingService } from '../mock/trainingService';
import { RestServer } from '../../rest_server'; import { RestServer, UnitTestHelpers } from 'rest_server';
import { testManagerProvider } from '../mock/nniManager'; import { testManagerProvider } from '../mock/nniManager';
import { testExperimentManagerProvider } from '../mock/experimentManager'; import { testExperimentManagerProvider } from '../mock/experimentManager';
import { TensorboardManager } from '../../common/tensorboardManager'; import { TensorboardManager } from '../../common/tensorboardManager';
import { NNITensorboardManager } from '../../core/nniTensorboardManager'; import { MockTensorboardManager } from '../mock/mockTensorboardManager';
let restServer: RestServer; let restServer: RestServer;
describe('Unit test for rest server', () => { describe('Unit test for rest handler', () => {
let ROOT_URL: string; let ROOT_URL: string;
before((done: Mocha.Done) => { before(async () => {
prepareUnitTest(); prepareUnitTest();
Container.bind(Manager).provider(testManagerProvider); Container.bind(Manager).provider(testManagerProvider);
Container.bind(DataStore).to(MockedDataStore); Container.bind(DataStore).to(MockedDataStore);
Container.bind(TrainingService).to(MockedTrainingService); Container.bind(TrainingService).to(MockedTrainingService);
Container.bind(ExperimentManager).provider(testExperimentManagerProvider); Container.bind(ExperimentManager).provider(testExperimentManagerProvider);
Container.bind(TensorboardManager).to(NNITensorboardManager); Container.bind(TensorboardManager).to(MockTensorboardManager);
restServer = new RestServer(8080, ''); restServer = new RestServer(0, '');
restServer.start().then(() => { await restServer.start();
ROOT_URL = `http://localhost:8080/api/v1/nni`; const port = UnitTestHelpers.getPort(restServer);
done(); ROOT_URL = `http://localhost:${port}/api/v1/nni`;
}).catch((e: Error) => {
assert.fail(`Failed to start rest server: ${e.message}`);
});
}); });
after(() => { after(() => {
...@@ -130,56 +127,4 @@ describe('Unit test for rest server', () => { ...@@ -130,56 +127,4 @@ describe('Unit test for rest server', () => {
} }
}); });
}); });
/* FIXME
it('Test PUT experiment/cluster-metadata bad key', (done: Mocha.Done) => {
const req: request.Options = {
uri: `${ROOT_URL}/experiment/cluster-metadata`,
method: 'PUT',
json: true,
body: {
exception_test_key: 'test'
}
};
request(req, (err: Error, res: request.Response) => {
if (err) {
assert.fail(err.message);
} else {
expect(res.statusCode).to.equal(400);
}
done();
});
});
*/
/* FIXME
it('Test PUT experiment/cluster-metadata', (done: Mocha.Done) => {
const req: request.Options = {
uri: `${ROOT_URL}/experiment/cluster-metadata`,
method: 'PUT',
json: true,
body: {
machine_list: [{
ip: '10.10.10.101',
port: 22,
username: 'test',
passwd: '1234'
}, {
ip: '10.10.10.102',
port: 22,
username: 'test',
passwd: '1234'
}]
}
};
request(req, (err: Error, res: request.Response) => {
if (err) {
assert.fail(err.message);
} else {
expect(res.statusCode).to.equal(200);
}
done();
});
});
*/
}); });
...@@ -7,7 +7,7 @@ import path from 'path'; ...@@ -7,7 +7,7 @@ import path from 'path';
import fetch from 'node-fetch'; import fetch from 'node-fetch';
import globals, { resetGlobals } from 'common/globals/unittest'; import globals from 'common/globals/unittest';
import { RestServer, UnitTestHelpers } from 'rest_server'; import { RestServer, UnitTestHelpers } from 'rest_server';
import * as mock_netron_server from './mock_netron_server'; import * as mock_netron_server from './mock_netron_server';
...@@ -89,6 +89,8 @@ async function testOutsidePrefix(): Promise<void> { ...@@ -89,6 +89,8 @@ async function testOutsidePrefix(): Promise<void> {
/* Register test cases */ /* Register test cases */
describe('## rest_server ##', () => { describe('## rest_server ##', () => {
before(beforeHook);
it('logs', () => testLogs()); it('logs', () => testLogs());
it('netron get', () => testNetronGet()); it('netron get', () => testNetronGet());
it('netron post', () => testNetronPost()); it('netron post', () => testNetronPost());
...@@ -107,30 +109,34 @@ describe('## rest_server ##', () => { ...@@ -107,30 +109,34 @@ describe('## rest_server ##', () => {
it('prefix webui resource', () => testWebuiResource()); it('prefix webui resource', () => testWebuiResource());
it('prefix webui routing', () => testWebuiRouting()); it('prefix webui routing', () => testWebuiRouting());
it('outside prefix', () => testOutsidePrefix()); it('outside prefix', () => testOutsidePrefix());
after(afterHook);
}); });
/* Configure test environment */ /* Configure test environment */
before(async () => { async function beforeHook() {
await configRestServer(); await configRestServer();
const netronPort = await mock_netron_server.start(); const netronPort = await mock_netron_server.start();
netronHost = `localhost:${netronPort}`; netronHost = `localhost:${netronPort}`;
UnitTestHelpers.setNetronUrl('http://' + netronHost); UnitTestHelpers.setNetronUrl('http://' + netronHost);
}); }
after(async () => { async function afterHook() {
await restServer.shutdown(); await restServer.shutdown();
resetGlobals(); globals.reset();
}); UnitTestHelpers.reset();
}
async function configRestServer(urlPrefix?: string) { async function configRestServer(urlPrefix?: string): Promise<void> {
if (restServer !== undefined) { if (restServer !== undefined) {
await restServer.shutdown(); await restServer.shutdown();
} }
globals.paths.logDirectory = path.join(__dirname, 'log'); globals.paths.logDirectory = path.join(__dirname, 'log');
UnitTestHelpers.setWebuiPath(path.join(__dirname, 'static')); UnitTestHelpers.setWebuiPath(path.join(__dirname, 'static'));
UnitTestHelpers.disableNniManager();
restServer = new RestServer(0, urlPrefix ?? ''); restServer = new RestServer(0, urlPrefix ?? '');
await restServer.start(); await restServer.start();
......
...@@ -45,7 +45,7 @@ describe('Unit Test for MountedStorageService', () => { ...@@ -45,7 +45,7 @@ describe('Unit Test for MountedStorageService', () => {
chai.should(); chai.should();
chai.use(chaiAsPromised); chai.use(chaiAsPromised);
prepareUnitTest(); prepareUnitTest();
log = getLogger(); log = getLogger('unittest');
const testRoot = path.dirname(__filename); const testRoot = path.dirname(__filename);
localPath = path.join(testRoot, localPath); localPath = path.join(testRoot, localPath);
......
...@@ -195,7 +195,7 @@ describe('Unit Test for TrialDispatcher', () => { ...@@ -195,7 +195,7 @@ describe('Unit Test for TrialDispatcher', () => {
chai.should(); chai.should();
chai.use(chaiAsPromised); chai.use(chaiAsPromised);
prepareUnitTest(); prepareUnitTest();
log = getLogger(); log = getLogger('unittest');
}); });
after(() => { after(() => {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment