Unverified Commit 290824c1 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Refactor NNI manager globals (step 5) - shutdown (#4706)

parent 18962129
......@@ -17,8 +17,9 @@
import assert from 'assert/strict';
import { NniManagerArgs, parseArgs } from './arguments';
import { NniPaths, createPaths } from './paths';
import { LogStream, initLogStream } from './log_stream';
import { NniPaths, createPaths } from './paths';
import { ShutdownManager } from './shutdown';
export { NniManagerArgs, NniPaths };
......@@ -31,6 +32,7 @@ export { NniManagerArgs, NniPaths };
export interface NniGlobals {
readonly args: NniManagerArgs;
readonly paths: NniPaths;
readonly shutdown: ShutdownManager;
readonly logStream: LogStream;
}
......@@ -57,7 +59,8 @@ export function initGlobals(): void {
const args = parseArgs(process.argv.slice(2));
const paths = createPaths(args);
const logStream = initLogStream(args, paths);
const shutdown = new ShutdownManager();
const globals: NniGlobals = { args, paths, logStream };
const globals: NniGlobals = { args, paths, logStream, shutdown };
Object.assign(global.nni, globals);
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
* Shutdown manager.
*
* Each standalone module should register its clean up method with:
*
* globals.shutdown.register('MyModule', async () => { this.cleanUp(); });
*
* If a module is a children of another module (for example NNIDataStore is a children of NNIManager),
* it should not register shutdown callback on its own,
* instead the parent module should take care of its destruction.
*
* Upon shutdown, all callbacks will be invoked *concurrently*. No guarantee on order.
*
* A module can request for shutdown when unrecoverable error occurs:
*
* try {
* this.doSomethingMustSuccess();
* } catch (error) {
* globals.shutdown.criticalError('MyModule', error);
* }
*
* Note that when a module invokes `criticalError()`, its own registered callback will not get called.
*
* When editting this module, keep robustness in mind.
* Bugs in this module can easily swallow logs and make it difficult to reproduce users' issue.
**/
import { Logger, getRobustLogger } from 'common/log';
const logger: Logger = getRobustLogger('ShutdownManager');
export class ShutdownManager {
private processStatus: 'initializing' | 'running' | 'stopping' = 'initializing';
private modules: Map<string, () => Promise<void>> = new Map();
private hasError: boolean = false;
public register(moduleName: string, shutdownCallback: () => Promise<void>): void {
if (this.modules.has(moduleName)) {
logger.error(`Module ${moduleName} has registered twice.`, new Error().stack);
}
this.modules.set(moduleName, shutdownCallback);
}
public initiate(reason: string): void {
if (this.processStatus === 'stopping') {
logger.warning('initiate() invoked but already stopping:', reason);
} else {
logger.info('Initiate shutdown:', reason);
this.shutdown();
}
}
public criticalError(moduleName: string, error: Error): void {
logger.critical(`Critical error ocurred in module ${moduleName}:`, error);
this.hasError = true;
if (this.processStatus === 'initializing') {
logger.error('Starting failed.');
process.exit(1);
} else if (this.processStatus !== 'stopping') {
this.modules.delete(moduleName);
this.shutdown();
}
}
public notifyInitializeComplete(): void {
if (this.processStatus === 'initializing') {
this.processStatus = 'running';
} else {
logger.error('notifyInitializeComplete() invoked in status', this.processStatus);
}
}
private shutdown(): void {
this.processStatus = 'stopping';
const promises = Array.from(this.modules).map(async ([moduleName, callback]) => {
try {
await callback();
} catch (error) {
logger.error(`Error during shutting down ${moduleName}:`, error);
this.hasError = true;
}
this.modules.delete(moduleName);
});
const timeoutTimer = setTimeout(async () => {
try {
logger.error('Following modules failed to shut down in time:', this.modules.keys());
await global.nni.logStream.close();
} finally {
process.exit(1);
}
}, shutdownTimeout);
Promise.all(promises).then(async () => {
try {
clearTimeout(timeoutTimer);
logger.info('Shutdown complete.');
await global.nni.logStream.close();
} finally {
process.exit(this.hasError ? 1 : 0);
}
});
}
}
let shutdownTimeout: number = 60_000;
export namespace UnitTestHelpers {
export function setShutdownTimeout(ms: number): void {
shutdownTimeout = ms;
}
}
......@@ -7,10 +7,16 @@
*
* Use this module to replace NNI globals with mocked values:
*
* import 'common/globals/unittest';
*
* Or:
*
* import globals from 'common/globals/unittest';
*
* You can then edit these mocked globals and the injection will be visible to all modules.
* Remember to invoke `resetGlobals()` in "after()" hook if you do so.
*
* Attention: TypeScript will remove "unused" import statements. Use the first format when "globals" is never used.
**/
import os from 'os';
......@@ -49,11 +55,14 @@ export function resetGlobals(): void {
const logStream = {
writeLine: (_line: string): void => { /* dummy */ },
writeLineSync: (_line: string): void => { /* dummy */ },
close: (): void => { /* dummy */ }
close: async (): Promise<void> => { /* dummy */ }
};
const shutdown = {
register: (..._: any): void => { /* dummy */ },
};
const globalAsAny = global as any;
const utGlobals = { args, paths, logStream, reset: resetGlobals };
const utGlobals = { args, paths, logStream, shutdown, reset: resetGlobals };
if (globalAsAny.nni === undefined) {
globalAsAny.nni = utGlobals;
} else {
......
......@@ -10,6 +10,9 @@
* Outputs:
*
* [1970-01-01 00:00:00] INFO (moduleName) hello { to: 'world' }
*
* Loggers use `util.inspect()` to format values,
* so objects will be smartly stringified and exceptions will include stack trace.
**/
import util from 'util';
......
......@@ -36,9 +36,6 @@ interface NNIManagerStatus {
abstract class Manager {
public abstract startExperiment(experimentConfig: ExperimentConfig): Promise<string>;
public abstract resumeExperiment(readonly: boolean): Promise<void>;
public abstract stopExperiment(): Promise<void>;
public abstract stopExperimentTopHalf(): Promise<void>;
public abstract stopExperimentBottomHalf(): Promise<void>;
public abstract getExperimentProfile(): Promise<ExperimentProfile>;
public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>;
public abstract importData(data: string): Promise<void>;
......
......@@ -26,7 +26,6 @@ import {
REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE, IMPORT_DATA
} from './commands';
import { createDispatcherInterface, createDispatcherPipeInterface, IpcInterface } from './ipcInterface';
import { RestServer } from '../rest_server';
/**
* NNIManager which implements Manager interface
......@@ -76,6 +75,8 @@ class NNIManager implements Manager {
if (pipe !== null) {
this.dispatcher = createDispatcherPipeInterface(pipe);
}
globals.shutdown.register('NniManager', this.stopExperiment.bind(this));
}
public updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void> {
......@@ -294,12 +295,12 @@ class NNIManager implements Manager {
return this.dataStore.getTrialJobStatistics();
}
public async stopExperiment(): Promise<void> {
private async stopExperiment(): Promise<void> {
await this.stopExperimentTopHalf();
await this.stopExperimentBottomHalf();
}
public async stopExperimentTopHalf(): Promise<void> {
private async stopExperimentTopHalf(): Promise<void> {
this.setStatus('STOPPING');
this.log.info('Stopping experiment, cleaning up ...');
......@@ -323,7 +324,7 @@ class NNIManager implements Manager {
this.dispatcher = undefined;
}
public async stopExperimentBottomHalf(): Promise<void> {
private async stopExperimentBottomHalf(): Promise<void> {
try {
const trialJobList: TrialJobDetail[] = await this.trainingService.listTrialJobs();
......@@ -351,19 +352,9 @@ class NNIManager implements Manager {
this.setStatus('STOPPED');
this.log.info('Experiment stopped.');
let hasError: boolean = false;
try {
await this.experimentManager.stop();
await component.get<TensorboardManager>(TensorboardManager).stop();
await this.dataStore.close();
await component.get<RestServer>(RestServer).shutdown();
} catch (err) {
hasError = true;
this.log.error(`${err.stack}`);
} finally {
await globals.logStream.close();
process.exit(hasError ? 1 : 0);
}
await this.experimentManager.stop();
await component.get<TensorboardManager>(TensorboardManager).stop();
await this.dataStore.close();
}
public async getMetricData(trialJobId?: string, metricType?: MetricType): Promise<MetricDataRecord[]> {
......
......@@ -29,7 +29,7 @@ import * as component from 'common/component';
import { Database, DataStore } from 'common/datastore';
import { ExperimentManager } from 'common/experimentManager';
import globals, { initGlobals } from 'common/globals';
import { getLogger } from 'common/log';
import { Logger, getLogger } from 'common/log';
import { Manager } from 'common/manager';
import { TensorboardManager } from 'common/tensorboardManager';
import { NNIDataStore } from 'core/nniDataStore';
......@@ -41,8 +41,10 @@ import { RestServer } from 'rest_server';
import path from 'path';
const logger: Logger = getLogger('main');
async function start(): Promise<void> {
getLogger('main').info('Start NNI manager');
logger.info('Start NNI manager');
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
......@@ -55,28 +57,26 @@ async function start(): Promise<void> {
const restServer = new RestServer(globals.args.port, globals.args.urlPrefix);
await restServer.start();
}
function shutdown(): void {
(component.get(Manager) as Manager).stopExperiment();
globals.shutdown.notifyInitializeComplete();
}
// Register callbacks to free training service resources on unexpected shutdown.
// A graceful stop should use REST API,
// because interrupts can cause strange behaviors in children processes.
process.on('SIGTERM', shutdown);
process.on('SIGBREAK', shutdown);
process.on('SIGINT', shutdown);
process.on('SIGTERM', () => { globals.shutdown.initiate('SIGTERM'); });
process.on('SIGBREAK', () => { globals.shutdown.initiate('SIGBREAK'); });
process.on('SIGINT', () => { globals.shutdown.initiate('SIGINT'); });
/* main */
initGlobals();
start().then(() => {
getLogger('main').debug('start() returned.');
logger.debug('start() returned.');
}).catch((error) => {
try {
getLogger('main').error('Failed to start:', error);
logger.error('Failed to start:', error);
} catch (loggerError) {
console.error('Failed to start:', error);
console.error('Seems logger is faulty:', loggerError);
......
......@@ -14,11 +14,10 @@
* Remember to update them if the values are changed, or if this file is moved.
*
* TODO:
* 1. Add a global function to handle critical error.
* 2. Refactor ClusterJobRestServer to an express-ws application so it doesn't require extra port.
* 3. Provide public API to register express app, so this can be decoupled with other modules' implementation.
* 4. Refactor NNIRestHandler. It's a mess.
* 5. Deal with log path mismatch between REST API and file system.
* 1. Refactor ClusterJobRestServer to an express-ws application so it doesn't require extra port.
* 2. Provide public API to register express app, so this can be decoupled with other modules' implementation.
* 3. Refactor NNIRestHandler. It's a mess.
* 4. Deal with log path mismatch between REST API and file system.
**/
import assert from 'assert/strict';
......@@ -34,6 +33,8 @@ import globals from 'common/globals';
import { Logger, getLogger } from 'common/log';
import { createRestHandler } from './restHandler';
const logger: Logger = getLogger('RestServer');
/**
* The singleton REST server that dispatches web UI and `Experiment` requests.
*
......@@ -44,24 +45,22 @@ export class RestServer {
private port: number;
private urlPrefix: string;
private server: Server | null = null;
private logger: Logger = getLogger('RestServer');
constructor(port: number, urlPrefix: string) {
assert(!urlPrefix.startsWith('/') && !urlPrefix.endsWith('/'));
this.port = port;
this.urlPrefix = urlPrefix;
globals.shutdown.register('RestServer', this.shutdown.bind(this));
}
// The promise is resolved when it's ready to serve requests.
// This worth nothing for now,
// but for example if we connect to tuner using WebSocket then it must be launched after promise resolved.
public start(): Promise<void> {
this.logger.info(`Starting REST server at port ${this.port}, URL prefix: "/${this.urlPrefix}"`);
logger.info(`Starting REST server at port ${this.port}, URL prefix: "/${this.urlPrefix}"`);
const app = express();
// FIXME: We should have a global handler for critical errors.
// `shutdown()` is not a callback and should not be passed to NNIRestHandler.
app.use('/' + this.urlPrefix, rootRouter(this.shutdown.bind(this)));
app.use('/' + this.urlPrefix, rootRouter());
app.all('*', (_req: Request, res: Response) => { res.status(404).send(`Outside prefix "/${this.urlPrefix}"`); });
this.server = app.listen(this.port);
......@@ -70,31 +69,22 @@ export class RestServer {
if (this.port === 0) { // Currently for unit test, can be public feature in future.
this.port = (<AddressInfo>this.server!.address()).port;
}
this.logger.info('REST server started.');
logger.info('REST server started.');
deferred.resolve();
});
// FIXME: Use global handler. The event can be emitted after listening.
this.server.on('error', (error: Error) => {
this.logger.error('REST server error:', error);
deferred.reject(error);
});
this.server.on('error', (error: Error) => { globals.shutdown.criticalError('RestServer', error); });
return deferred.promise;
}
public shutdown(): Promise<void> {
this.logger.info('Stopping REST server.');
logger.info('Stopping REST server.');
if (this.server === null) {
this.logger.warning('REST server is not running.');
logger.warning('REST server is not running.');
return Promise.resolve();
}
const deferred = new Deferred<void>();
this.server.close(() => {
this.logger.info('REST server stopped.');
deferred.resolve();
});
// FIXME: Use global handler. It should be aware of shutting down event and swallow errors in this stage.
this.server.on('error', (error: Error) => {
this.logger.error('REST server error:', error);
logger.info('REST server stopped.');
deferred.resolve();
});
return deferred.promise;
......@@ -109,12 +99,12 @@ export class RestServer {
*
* In fact experiments management should have a separate prefix and module.
**/
function rootRouter(stopCallback: () => Promise<void>): Router {
function rootRouter(): Router {
const router = Router();
router.use(express.json({ limit: '50mb' }));
/* NNI manager APIs */
router.use('/api/v1/nni', restHandlerFactory(stopCallback));
router.use('/api/v1/nni', restHandlerFactory());
/* Download log files */
// The REST API path "/logs" does not match file system path "/log".
......@@ -165,7 +155,7 @@ export namespace UnitTestHelpers {
}
export function disableNniManager(): void {
restHandlerFactory = (_: any): Router => Router();
restHandlerFactory = (): Router => Router();
}
export function reset(): void {
......
......@@ -8,6 +8,7 @@ import * as component from '../common/component';
import { DataStore, MetricDataRecord, TrialJobInfo } from '../common/datastore';
import { NNIError, NNIErrorNames } from '../common/errors';
import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
import globals from 'common/globals';
import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager';
import { ExperimentManager } from '../common/experimentManager';
......@@ -22,17 +23,15 @@ import { TrialJobStatus } from '../common/trainingService';
//const expressJoi = require('express-joi-validator');
class NNIRestHandler {
private stopCallback: () => Promise<void>;
private nniManager: Manager;
private experimentsManager: ExperimentManager;
private tensorboardManager: TensorboardManager;
private log: Logger;
constructor(stopCallback: () => Promise<void>) {
constructor() {
this.nniManager = component.get(Manager);
this.experimentsManager = component.get(ExperimentManager);
this.tensorboardManager = component.get(TensorboardManager);
this.stopCallback = stopCallback;
this.log = getLogger('NNIRestHandler');
}
......@@ -124,7 +123,7 @@ class NNIRestHandler {
this.handleError(err, res);
this.log.error(err.message);
this.log.error(`Datastore initialize failed, stopping rest server...`);
await this.stopCallback();
globals.shutdown.criticalError('RestHandler', err);
});
});
}
......@@ -416,10 +415,8 @@ class NNIRestHandler {
private stop(router: Router): void {
router.delete('/experiment', (_req: Request, res: Response) => {
this.nniManager.stopExperimentTopHalf().then(() => {
res.send();
this.nniManager.stopExperimentBottomHalf();
});
res.send();
globals.shutdown.initiate('REST request');
});
}
......@@ -433,8 +430,6 @@ class NNIRestHandler {
}
}
export function createRestHandler(stopCallback: () => Promise<void>): Router {
const handler: NNIRestHandler = new NNIRestHandler(stopCallback);
return handler.createRestHandler();
export function createRestHandler(): Router {
return new NNIRestHandler().createRestHandler();
}
......@@ -3,6 +3,7 @@
import assert from 'assert/strict';
import fs from 'fs';
import os from 'os';
import path from 'path';
import { setTimeout } from 'timers/promises';
......@@ -74,7 +75,7 @@ describe('## globals.log_stream ##', () => {
/* configure test environment */
const origConsoleLog = console.log;
const tempDir = fs.mkdtempSync('nni-ut-');
const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'nni-ut-'));
function beforeHook() {
console.log = (line => { consoleContent += line + '\n'; });
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import assert from 'assert/strict';
import { setTimeout } from 'timers/promises';
import { Deferred } from 'ts-deferred';
import 'common/globals/unittest';
import { ShutdownManager, UnitTestHelpers } from 'common/globals/shutdown';
/* environment */
UnitTestHelpers.setShutdownTimeout(10);
let shutdown: ShutdownManager = new ShutdownManager();
let callbackCount: number[] = [ 0, 0 ];
let exitCode: number | null = null;
/* test cases */
// Test a normal shutdown.
// Each callback should be invoked once.
async function testShutdown(): Promise<void> {
shutdown.register('ModuleA', async () => { callbackCount[0] += 1; });
shutdown.register('ModuleB', async () => { callbackCount[1] += 1; });
shutdown.initiate('unittest');
await setTimeout(10);
assert.deepEqual(callbackCount, [ 1, 1 ]);
assert.equal(exitCode, 0);
}
// Test a shutdown caused by critical error.
// The faulty module's callback will not be invoked by design.
async function testError(): Promise<void> {
shutdown.notifyInitializeComplete();
shutdown.register('ModuleA', async () => { callbackCount[0] += 1; });
shutdown.register('ModuleB', async () => { callbackCount[1] += 1; });
shutdown.criticalError('ModuleA', new Error('test critical error'));
await setTimeout(10);
assert.deepEqual(callbackCount, [ 0, 1 ]);
assert.equal(exitCode, 1);
}
// Test a shutdown caused by critical error in initializing phase.
// Current implementation does not invoke callbacks in this case, so the timeout is 0.
// If you have modified shutdown logic and this case failed, check the timeout.
async function testInitError(): Promise<void> {
shutdown.register('ModuleA', async () => { callbackCount[0] += 1; });
shutdown.criticalError('ModuleA', new Error('test init error'));
await setTimeout();
assert.equal(exitCode, 1);
}
// Simulate an error inside shutdown callback.
async function testCallbackError(): Promise<void> {
shutdown.notifyInitializeComplete();
shutdown.register('ModuleA', async () => { callbackCount[0] += 1; });
shutdown.register('ModuleB', async () => {
callbackCount[1] += 1;
throw new Error('Module B callback error');
});
shutdown.initiate('unittest');
await setTimeout(10);
assert.deepEqual(callbackCount, [ 1, 1 ]);
assert.equal(exitCode, 1);
}
// Simulate unresponsive shutdown callback.
// Pay attention that timeout handler does not explicitly cancel shutdown callback
// because in real world it terminates the process.
// But in mocked environment process.exit() is overwritten so the callback will eventually finish,
// and it can cause another process.exit().
// Make sure not to recover mocked process.exit() before the callback finish.
async function testTimeout(): Promise<void> {
const deferred = new Deferred<void>();
shutdown.register('ModuleA', async () => { callbackCount[0] += 1; });
shutdown.register('ModuleB', async () => {
await setTimeout(30); // we have set timeout to 10 ms so this times out
callbackCount[1] += 1;
deferred.resolve();
});
shutdown.initiate('unittest');
await setTimeout(20);
assert.deepEqual(callbackCount, [ 1, 0 ]);
assert.equal(exitCode, 1);
// if we don't await, process.exit() will be recovered and it will terminate testing.
await deferred.promise;
}
/* register */
describe('## globals.shutdown ##', () => {
before(beforeHook);
beforeEach(beforeEachHook);
it('normal', testShutdown);
it('on error', testError);
it('on init fail', testInitError);
it('callback raise error', testCallbackError);
it('timeout', testTimeout);
after(afterHook);
});
/* hooks */
const origProcessExit = process.exit;
function beforeHook() {
process.exit = ((code: number) => { exitCode = code; }) as any;
}
function beforeEachHook() {
shutdown = new ShutdownManager();
callbackCount = [ 0, 0 ];
exitCode = null;
}
function afterHook() {
process.exit = origProcessExit;
}
......@@ -8,6 +8,8 @@ import { Logger, getLogger, getRobustLogger } from 'common/log';
/* test cases */
// Write a log message in different format for each level.
// Checks the log stream contains all messages.
function testDebugLevel() {
stream.reset();
globals.args.logLevel = 'debug';
......@@ -26,6 +28,8 @@ function testDebugLevel() {
assert.equal(stderr, '');
}
// Write a log message in different format for each level.
// Check logs below specified log level are filtered.
function testWarningLevel() {
stream.reset();
globals.args.logLevel = 'warning';
......@@ -45,6 +49,8 @@ function testWarningLevel() {
assert.equal(stderr, '');
}
// Write some logs; simulate an error in log stream; then write other logs.
// Check logs after the error are written to stderr.
function testRobust() {
stream.reset();
globals.args.logLevel = 'info';
......
......@@ -135,7 +135,7 @@ describe('Unit test for nnimanager', function () {
after(async () => {
// FIXME
await nniManager.stopExperimentTopHalf();
await (nniManager as any).stopExperimentTopHalf();
cleanupUnitTest();
})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment