Commit d7920fd2 (unverified), authored by SparkSnail, committed by GitHub
Browse files

Add foreground mode in nnictl (#1956)

parent 26aa1136
...@@ -49,7 +49,7 @@ nnictl support commands: ...@@ -49,7 +49,7 @@ nnictl support commands:
|--config, -c| True| |YAML configure file of the experiment| |--config, -c| True| |YAML configure file of the experiment|
|--port, -p|False| |the port of restful server| |--port, -p|False| |the port of restful server|
|--debug, -d|False||set debug mode| |--debug, -d|False||set debug mode|
|--watch, -w|False||set watch mode| |--foreground, -f|False||set foreground mode, print log content to terminal|
* Examples * Examples
...@@ -98,7 +98,7 @@ Debug mode will disable version check function in Trialkeeper. ...@@ -98,7 +98,7 @@ Debug mode will disable version check function in Trialkeeper.
|id| True| |The id of the experiment you want to resume| |id| True| |The id of the experiment you want to resume|
|--port, -p| False| |Rest port of the experiment you want to resume| |--port, -p| False| |Rest port of the experiment you want to resume|
|--debug, -d|False||set debug mode| |--debug, -d|False||set debug mode|
|--watch, -w|False||set watch mode| |--foreground, -f|False||set foreground mode, print log content to terminal|
* Example * Example
......
...@@ -4,13 +4,11 @@ ...@@ -4,13 +4,11 @@
'use strict'; 'use strict';
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path';
import { Writable } from 'stream'; import { Writable } from 'stream';
import { WritableStreamBuffer } from 'stream-buffers'; import { WritableStreamBuffer } from 'stream-buffers';
import { format } from 'util'; import { format } from 'util';
import * as component from '../common/component'; import * as component from '../common/component';
import { getExperimentStartupInfo, isReadonly } from './experimentStartupInfo'; import { getExperimentStartupInfo, isReadonly } from './experimentStartupInfo';
import { getLogDir } from './utils';
const FATAL: number = 1; const FATAL: number = 1;
const ERROR: number = 2; const ERROR: number = 2;
...@@ -55,23 +53,21 @@ class BufferSerialEmitter { ...@@ -55,23 +53,21 @@ class BufferSerialEmitter {
@component.Singleton @component.Singleton
class Logger { class Logger {
private DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
private level: number = INFO; private level: number = INFO;
private bufferSerialEmitter: BufferSerialEmitter; private bufferSerialEmitter?: BufferSerialEmitter;
private writable: Writable; private writable?: Writable;
private readonly: boolean = false; private readonly: boolean = false;
constructor(fileName?: string) { constructor(fileName?: string) {
let logFile: string | undefined = fileName; const logFile: string | undefined = fileName;
if (logFile === undefined) { if (logFile) {
logFile = this.DEFAULT_LOGFILE; this.writable = fs.createWriteStream(logFile, {
flags: 'a+',
encoding: 'utf8',
autoClose: true
});
this.bufferSerialEmitter = new BufferSerialEmitter(this.writable);
} }
this.writable = fs.createWriteStream(logFile, {
flags: 'a+',
encoding: 'utf8',
autoClose: true
});
this.bufferSerialEmitter = new BufferSerialEmitter(this.writable);
const logLevelName: string = getExperimentStartupInfo() const logLevelName: string = getExperimentStartupInfo()
.getLogLevel(); .getLogLevel();
...@@ -84,7 +80,9 @@ class Logger { ...@@ -84,7 +80,9 @@ class Logger {
} }
public close(): void { public close(): void {
this.writable.destroy(); if (this.writable) {
this.writable.destroy();
}
} }
public trace(...param: any[]): void { public trace(...param: any[]): void {
...@@ -128,12 +126,15 @@ class Logger { ...@@ -128,12 +126,15 @@ class Logger {
*/ */
private log(level: string, param: any[]): void { private log(level: string, param: any[]): void {
if (!this.readonly) { if (!this.readonly) {
const buffer: WritableStreamBuffer = new WritableStreamBuffer(); const logContent = `[${(new Date()).toLocaleString()}] ${level} ${format(param)}\n`;
buffer.write(`[${(new Date()).toLocaleString()}] ${level} `); if (this.writable && this.bufferSerialEmitter) {
buffer.write(format(param)); const buffer: WritableStreamBuffer = new WritableStreamBuffer();
buffer.write('\n'); buffer.write(logContent);
buffer.end(); buffer.end();
this.bufferSerialEmitter.feed(buffer.getContents()); this.bufferSerialEmitter.feed(buffer.getContents());
} else {
console.log(logContent);
}
} }
} }
} }
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import { Container, Scope } from 'typescript-ioc'; import { Container, Scope } from 'typescript-ioc';
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path';
import * as component from './common/component'; import * as component from './common/component';
import { Database, DataStore } from './common/datastore'; import { Database, DataStore } from './common/datastore';
import { setExperimentStartupInfo } from './common/experimentStartupInfo'; import { setExperimentStartupInfo } from './common/experimentStartupInfo';
...@@ -34,7 +35,7 @@ function initStartupInfo( ...@@ -34,7 +35,7 @@ function initStartupInfo(
setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly); setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly);
} }
async function initContainer(platformMode: string, logFileName?: string): Promise<void> { async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
if (platformMode === 'local') { if (platformMode === 'local') {
Container.bind(TrainingService) Container.bind(TrainingService)
.to(LocalTrainingService) .to(LocalTrainingService)
...@@ -71,6 +72,12 @@ async function initContainer(platformMode: string, logFileName?: string): Promis ...@@ -71,6 +72,12 @@ async function initContainer(platformMode: string, logFileName?: string): Promis
Container.bind(DataStore) Container.bind(DataStore)
.to(NNIDataStore) .to(NNIDataStore)
.scope(Scope.Singleton); .scope(Scope.Singleton);
const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
if (foreground) {
logFileName = undefined;
} else if (logFileName === undefined) {
logFileName = DEFAULT_LOGFILE;
}
Container.bind(Logger).provider({ Container.bind(Logger).provider({
get: (): Logger => new Logger(logFileName) get: (): Logger => new Logger(logFileName)
}); });
...@@ -81,7 +88,7 @@ async function initContainer(platformMode: string, logFileName?: string): Promis ...@@ -81,7 +88,7 @@ async function initContainer(platformMode: string, logFileName?: string): Promis
function usage(): void { function usage(): void {
console.info('usage: node main.js --port <port> --mode \ console.info('usage: node main.js --port <port> --mode \
<local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id>'); <local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
} }
const strPort: string = parseArg(['--port', '-p']); const strPort: string = parseArg(['--port', '-p']);
...@@ -90,6 +97,14 @@ if (!strPort || strPort.length === 0) { ...@@ -90,6 +97,14 @@ if (!strPort || strPort.length === 0) {
process.exit(1); process.exit(1);
} }
// Parse the optional --foreground / -f flag.
// Accepted values (case-insensitive): 'true', 'false', or an empty string.
// The empty string is kept valid because the previous substring-based check
// also let it through; presumably that is what parseArg yields when the flag
// is omitted -- TODO confirm against parseArg.
const foregroundArg: string = parseArg(['--foreground', '-f']);
const foregroundValue: string = foregroundArg.toLowerCase();
// The old test `('true' || 'false').includes(...)` collapsed to
// `'true'.includes(...)`: it rejected an explicit 'false' and accepted any
// substring of 'true' (e.g. 'ru'). Compare against an explicit allow-list.
if (!['', 'true', 'false'].includes(foregroundValue)) {
    console.log(`FATAL: foreground property should only be true or false`);
    usage();
    process.exit(1);
}
// Only an explicit 'true' enables foreground mode.
const foreground: boolean = foregroundValue === 'true';
const port: number = parseInt(strPort, 10); const port: number = parseInt(strPort, 10);
const mode: string = parseArg(['--mode', '-m']); const mode: string = parseArg(['--mode', '-m']);
...@@ -138,7 +153,7 @@ initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly); ...@@ -138,7 +153,7 @@ initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly);
mkDirP(getLogDir()) mkDirP(getLogDir())
.then(async () => { .then(async () => {
try { try {
await initContainer(mode); await initContainer(foreground, mode);
const restServer: NNIRestServer = component.get(NNIRestServer); const restServer: NNIRestServer = component.get(NNIRestServer);
await restServer.start(); await restServer.start();
const log: Logger = getLogger(); const log: Logger = getLogger();
...@@ -162,6 +177,15 @@ function getStopSignal(): any { ...@@ -162,6 +177,15 @@ function getStopSignal(): any {
} }
} }
/**
 * Returns the name of the signal used to register the Ctrl-C handler,
 * which is always 'SIGINT'.
 */
function getCtrlCSignal(): any {
    const ctrlCSignal: string = 'SIGINT';

    return ctrlCSignal;
}
// Handler for the Ctrl-C signal: it only writes an info-level log entry and
// performs no shutdown work itself.
// NOTE(review): per the Node.js docs, installing a SIGINT listener removes the
// default exit-on-Ctrl-C behaviour -- confirm that this is intended.
process.on(getCtrlCSignal(), async () => {
    getLogger().info(`Get SIGINT signal!`);
});
process.on(getStopSignal(), async () => { process.on(getStopSignal(), async () => {
const log: Logger = getLogger(); const log: Logger = getLogger();
let hasError: boolean = false; let hasError: boolean = false;
......
...@@ -9,7 +9,7 @@ import random ...@@ -9,7 +9,7 @@ import random
import site import site
import time import time
import tempfile import tempfile
from subprocess import Popen, check_call, CalledProcessError from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT
from nni_annotation import expand_annotations, generate_search_space from nni_annotation import expand_annotations, generate_search_space
from nni.constants import ModuleName, AdvisorModuleName from nni.constants import ModuleName, AdvisorModuleName
from .launcher_utils import validate_all_content from .launcher_utils import validate_all_content
...@@ -20,7 +20,7 @@ from .common_utils import get_yml_content, get_json_content, print_error, print_ ...@@ -20,7 +20,7 @@ from .common_utils import get_yml_content, get_json_content, print_error, print_
detect_port, get_user, get_python_dir detect_port, get_user, get_python_dir
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, PACKAGE_REQUIREMENTS from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, PACKAGE_REQUIREMENTS
from .command_utils import check_output_command, kill_command from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment, set_monitor from .nnictl_utils import update_experiment
def get_log_path(config_file_name): def get_log_path(config_file_name):
'''generate stdout and stderr log path''' '''generate stdout and stderr log path'''
...@@ -78,17 +78,17 @@ def get_nni_installation_path(): ...@@ -78,17 +78,17 @@ def get_nni_installation_path():
print_error('Fail to find nni under python library') print_error('Fail to find nni under python library')
exit(1) exit(1)
def start_rest_server(port, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None): def start_rest_server(args, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
'''Run nni manager process''' '''Run nni manager process'''
if detect_port(port): if detect_port(args.port):
print_error('Port %s is used by another process, please reset the port!\n' \ print_error('Port %s is used by another process, please reset the port!\n' \
'You could use \'nnictl create --help\' to get help information' % port) 'You could use \'nnictl create --help\' to get help information' % args.port)
exit(1) exit(1)
if (platform != 'local') and detect_port(int(port) + 1): if (platform != 'local') and detect_port(int(args.port) + 1):
print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \ print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
'You could set another port to start experiment!\n' \ 'You could set another port to start experiment!\n' \
'You could use \'nnictl create --help\' to get help information' % ((int(port) + 1), (int(port) + 1))) 'You could use \'nnictl create --help\' to get help information' % ((int(args.port) + 1), (int(args.port) + 1)))
exit(1) exit(1)
print_normal('Starting restful server...') print_normal('Starting restful server...')
...@@ -99,7 +99,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None ...@@ -99,7 +99,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
node_command = 'node' node_command = 'node'
if sys.platform == 'win32': if sys.platform == 'win32':
node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe') node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
cmds = [node_command, entry_file, '--port', str(port), '--mode', platform] cmds = [node_command, entry_file, '--port', str(args.port), '--mode', platform]
if mode == 'view': if mode == 'view':
cmds += ['--start_mode', 'resume'] cmds += ['--start_mode', 'resume']
cmds += ['--readonly', 'true'] cmds += ['--readonly', 'true']
...@@ -111,6 +111,8 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None ...@@ -111,6 +111,8 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
cmds += ['--log_level', log_level] cmds += ['--log_level', log_level]
if mode in ['resume', 'view']: if mode in ['resume', 'view']:
cmds += ['--experiment_id', experiment_id] cmds += ['--experiment_id', experiment_id]
if args.foreground:
cmds += ['--foreground', 'true']
stdout_full_path, stderr_full_path = get_log_path(config_file_name) stdout_full_path, stderr_full_path = get_log_path(config_file_name)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
...@@ -120,9 +122,15 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None ...@@ -120,9 +122,15 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
stderr_file.write(log_header) stderr_file.write(log_header)
if sys.platform == 'win32': if sys.platform == 'win32':
from subprocess import CREATE_NEW_PROCESS_GROUP from subprocess import CREATE_NEW_PROCESS_GROUP
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP) if args.foreground:
process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=STDOUT, creationflags=CREATE_NEW_PROCESS_GROUP)
else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP)
else: else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file) if args.foreground:
process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE)
else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
return process, str(time_now) return process, str(time_now)
def set_trial_config(experiment_config, port, config_file_name): def set_trial_config(experiment_config, port, config_file_name):
...@@ -424,7 +432,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -424,7 +432,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True): if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
log_level = 'debug' log_level = 'debug'
# start rest server # start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \ rest_process, start_time = start_rest_server(args, experiment_config['trainingServicePlatform'], \
mode, config_file_name, experiment_id, log_dir, log_level) mode, config_file_name, experiment_id, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid) nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation # Deal with annotation
...@@ -493,8 +501,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -493,8 +501,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
experiment_config['experimentName']) experiment_config['experimentName'])
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
if args.watch: if args.foreground:
set_monitor(True, 3, args.port, rest_process.pid) try:
while True:
log_content = rest_process.stdout.readline().strip().decode('utf-8')
print(log_content)
except KeyboardInterrupt:
kill_command(rest_process.pid)
print_normal('Stopping experiment...')
def create_experiment(args): def create_experiment(args):
'''start a new experiment''' '''start a new experiment'''
......
...@@ -51,7 +51,7 @@ def parse_args(): ...@@ -51,7 +51,7 @@ def parse_args():
parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file') parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server') parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode') parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
parser_start.add_argument('--watch', '-w', action='store_true', help=' set watch mode') parser_start.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
parser_start.set_defaults(func=create_experiment) parser_start.set_defaults(func=create_experiment)
# parse resume command # parse resume command
...@@ -59,7 +59,7 @@ def parse_args(): ...@@ -59,7 +59,7 @@ def parse_args():
parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to resume') parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to resume')
parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server') parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
parser_resume.add_argument('--debug', '-d', action='store_true', help=' set debug mode') parser_resume.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
parser_resume.add_argument('--watch', '-w', action='store_true', help=' set watch mode') parser_resume.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
parser_resume.set_defaults(func=resume_experiment) parser_resume.set_defaults(func=resume_experiment)
# parse view command # parse view command
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment