"vscode:/vscode.git/clone" did not exist on "4fb81e008c44f73c64c3dc8f9937dd4c1c6d20d8"
Unverified Commit cbb63c5b authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

WebSocket (step 4) - NAS experiment (#4825)

parent e8de5eb4
...@@ -14,6 +14,7 @@ from typing import Any ...@@ -14,6 +14,7 @@ from typing import Any
import colorama import colorama
import psutil import psutil
from typing_extensions import Literal
import nni.runtime.log import nni.runtime.log
...@@ -78,7 +79,7 @@ class Experiment: ...@@ -78,7 +79,7 @@ class Experiment:
self.id: str = management.generate_experiment_id() self.id: str = management.generate_experiment_id()
self.port: int | None = None self.port: int | None = None
self._proc: Popen | psutil.Process | None = None self._proc: Popen | psutil.Process | None = None
self._action = 'create' self._action: Literal['create', 'resume', 'view'] = 'create'
self.url_prefix: str | None = None self.url_prefix: str | None = None
if isinstance(config_or_platform, (str, list)): if isinstance(config_or_platform, (str, list)):
......
...@@ -13,18 +13,19 @@ import socket ...@@ -13,18 +13,19 @@ import socket
from subprocess import Popen from subprocess import Popen
import sys import sys
import time import time
from typing import Optional, Tuple, List, Any from typing import Any, TYPE_CHECKING, cast
import colorama import colorama
from typing_extensions import Literal
import nni.runtime.protocol
from .config import ExperimentConfig from .config import ExperimentConfig
from .pipe import Pipe
from . import rest from . import rest
from ..tools.nnictl.config_utils import Experiments, Config from ..tools.nnictl.config_utils import Experiments, Config
from ..tools.nnictl.nnictl_utils import update_experiment from ..tools.nnictl.nnictl_utils import update_experiment
if TYPE_CHECKING:
from .experiment import RunMode
_logger = logging.getLogger('nni.experiment') _logger = logging.getLogger('nni.experiment')
@dataclass(init=False) @dataclass(init=False)
...@@ -32,16 +33,24 @@ class NniManagerArgs: ...@@ -32,16 +33,24 @@ class NniManagerArgs:
# argv sent to "ts/nni_manager/main.js" # argv sent to "ts/nni_manager/main.js"
port: int port: int
experiment_id: int experiment_id: str
action: str # 'new', 'resume', 'view' action: Literal['create', 'resume', 'view']
mode: str # training service platform, to be removed mode: str # training service platform, to be removed
experiments_directory: str # renamed "config.nni_experiments_directory", must be absolute experiments_directory: str # renamed "config.nni_experiments_directory", must be absolute
log_level: str log_level: str
foreground: bool = False foreground: bool = False
url_prefix: str | None = None # leading and trailing "/" must be stripped url_prefix: str | None = None # leading and trailing "/" must be stripped
dispatcher_pipe: str | None = None tuner_command_channel: str | None = None
def __init__(self, action, exp_id, config, port, debug, foreground, url_prefix): def __init__(self,
action: Literal['create', 'resume', 'view'],
exp_id: str,
config: ExperimentConfig,
port: int,
debug: bool,
foreground: bool,
url_prefix: str | None,
tuner_command_channel: str | None):
self.port = port self.port = port
self.experiment_id = exp_id self.experiment_id = exp_id
self.action = action self.action = action
...@@ -49,18 +58,19 @@ class NniManagerArgs: ...@@ -49,18 +58,19 @@ class NniManagerArgs:
self.url_prefix = url_prefix self.url_prefix = url_prefix
# config field name "experiment_working_directory" is a mistake # config field name "experiment_working_directory" is a mistake
# see "ts/nni_manager/common/globals/arguments.ts" for details # see "ts/nni_manager/common/globals/arguments.ts" for details
self.experiments_directory = config.experiment_working_directory self.experiments_directory = cast(str, config.experiment_working_directory)
self.tuner_command_channel = tuner_command_channel
if isinstance(config.training_service, list): if isinstance(config.training_service, list):
self.mode = 'hybrid' self.mode = 'hybrid'
else: else:
self.mode = config.training_service.platform self.mode = config.training_service.platform
self.log_level = config.log_level self.log_level = cast(str, config.log_level)
if debug and self.log_level not in ['debug', 'trace']: if debug and self.log_level not in ['debug', 'trace']:
self.log_level = 'debug' self.log_level = 'debug'
def to_command_line_args(self): def to_command_line_args(self) -> list[str]:
# reformat fields to meet yargs library's format # reformat fields to meet yargs library's format
# see "ts/nni_manager/common/globals/arguments.ts" for details # see "ts/nni_manager/common/globals/arguments.ts" for details
ret = [] ret = []
...@@ -74,11 +84,21 @@ class NniManagerArgs: ...@@ -74,11 +84,21 @@ class NniManagerArgs:
ret.append(str(value)) ret.append(str(value))
return ret return ret
def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix): def start_experiment(
action: Literal['create', 'resume', 'view'],
exp_id: str,
config: ExperimentConfig,
port: int,
debug: bool,
run_mode: RunMode,
url_prefix: str | None,
tuner_command_channel: str | None = None,
tags: list[str] = []) -> Popen:
foreground = run_mode.value == 'foreground' foreground = run_mode.value == 'foreground'
if url_prefix is not None: if url_prefix is not None:
url_prefix = url_prefix.strip('/') url_prefix = url_prefix.strip('/')
nni_manager_args = NniManagerArgs(action, exp_id, config, port, debug, foreground, url_prefix) nni_manager_args = NniManagerArgs(action, exp_id, config, port, debug, foreground, url_prefix, tuner_command_channel)
_ensure_port_idle(port) _ensure_port_idle(port)
websocket_platforms = ['hybrid', 'remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl'] websocket_platforms = ['hybrid', 'remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']
...@@ -112,8 +132,8 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix): ...@@ -112,8 +132,8 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
nni_manager_args.mode, nni_manager_args.mode,
config.experiment_name, config.experiment_name,
pid=proc.pid, pid=proc.pid,
logDir=config.experiment_working_directory, logDir=cast(str, config.experiment_working_directory),
tag=[], tag=tags,
prefixUrl=url_prefix prefixUrl=url_prefix
) )
...@@ -129,12 +149,12 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix): ...@@ -129,12 +149,12 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
return proc return proc
def _start_rest_server(nni_manager_args, run_mode) -> Popen: def _start_rest_server(nni_manager_args: NniManagerArgs, run_mode: RunMode) -> Popen:
import nni_node import nni_node
node_dir = Path(nni_node.__path__[0]) # type: ignore node_dir = Path(nni_node.__path__[0]) # type: ignore
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node')) node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js') main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js] cmd = [node, '--max-old-space-size=4096', '--trace-uncaught', main_js]
cmd += nni_manager_args.to_command_line_args() cmd += nni_manager_args.to_command_line_args()
if run_mode.value == 'detach': if run_mode.value == 'detach':
...@@ -157,45 +177,7 @@ def _start_rest_server(nni_manager_args, run_mode) -> Popen: ...@@ -157,45 +177,7 @@ def _start_rest_server(nni_manager_args, run_mode) -> Popen:
return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp) # type: ignore return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp) # type: ignore
def start_experiment_retiarii(exp_id, config, port, debug): def _ensure_port_idle(port: int, message: str | None = None) -> None:
pipe = None
proc = None
config.validate(initialized_tuner=True)
_ensure_port_idle(port)
if isinstance(config.training_service, list): # hybrid training service
_ensure_port_idle(port + 1, 'Hybrid training service requires an additional port')
elif config.training_service.platform in ['remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']:
_ensure_port_idle(port + 1, f'{config.training_service.platform} requires an additional port')
try:
_logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
pipe = Pipe(exp_id)
start_time, proc = _start_rest_server_retiarii(config, port, debug, exp_id, pipe.path)
_logger.info('Connecting IPC pipe...')
pipe_file = pipe.connect()
nni.runtime.protocol._set_in_file(pipe_file)
nni.runtime.protocol._set_out_file(pipe_file)
_logger.info('Starting web server...')
_check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
_save_experiment_information(exp_id, port, start_time, platform,
config.experiment_name, proc.pid, config.experiment_working_directory, ['retiarii'])
_logger.info('Setting up...')
rest.post(port, '/experiment', config.json())
return proc, pipe
except Exception as e:
_logger.error('Create experiment failed')
if proc is not None:
with contextlib.suppress(Exception):
proc.kill()
if pipe is not None:
with contextlib.suppress(Exception):
pipe.close()
raise e
def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
sock = socket.socket() sock = socket.socket()
if sock.connect_ex(('localhost', port)) == 0: if sock.connect_ex(('localhost', port)) == 0:
sock.close() sock.close()
...@@ -203,48 +185,7 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None: ...@@ -203,48 +185,7 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
raise RuntimeError(f'Port {port} is not idle {message}') raise RuntimeError(f'Port {port} is not idle {message}')
def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, def _check_rest_server(port: int, retry: int = 3, url_prefix: str | None = None) -> None:
pipe_path: str, mode: str = 'create') -> Tuple[int, Popen]:
if isinstance(config.training_service, list):
ts = 'hybrid'
else:
ts = config.training_service.platform
if ts == 'openpai':
ts = 'pai'
args = {
'port': port,
'mode': ts,
'experiment_id': experiment_id,
'action': mode,
'experiments_directory': config.experiment_working_directory,
'log_level': 'debug' if debug else 'info'
}
if pipe_path is not None:
args['dispatcher_pipe'] = pipe_path
import nni_node
node_dir = Path(nni_node.__path__[0]) # type: ignore
node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
main_js = str(node_dir / 'main.js')
cmd = [node, '--max-old-space-size=4096', main_js]
for arg_key, arg_value in args.items():
cmd.append('--' + arg_key.replace('_', '-'))
cmd.append(str(arg_value))
if sys.platform == 'win32':
from subprocess import CREATE_NEW_PROCESS_GROUP
proc = Popen(cmd, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
else:
if pipe_path is None:
import os
proc = Popen(cmd, cwd=node_dir, preexec_fn=os.setpgrp)
else:
proc = Popen(cmd, cwd=node_dir)
return int(time.time() * 1000), proc
def _check_rest_server(port: int, retry: int = 3, url_prefix: Optional[str] = None) -> None:
for i in range(retry): for i in range(retry):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
rest.get(port, '/check-status', url_prefix) rest.get(port, '/check-status', url_prefix)
...@@ -256,7 +197,7 @@ def _check_rest_server(port: int, retry: int = 3, url_prefix: Optional[str] = No ...@@ -256,7 +197,7 @@ def _check_rest_server(port: int, retry: int = 3, url_prefix: Optional[str] = No
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str,
name: str, pid: int, logDir: str, tag: List[Any]) -> None: name: str, pid: int, logDir: str, tag: list[Any]) -> None:
experiments_config = Experiments() experiments_config = Experiments()
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir, tag=tag) experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir, tag=tag)
......
...@@ -19,12 +19,12 @@ import torch ...@@ -19,12 +19,12 @@ import torch
import torch.nn as nn import torch.nn as nn
import nni.runtime.log import nni.runtime.log
from nni.common.device import GPUDevice from nni.common.device import GPUDevice
from nni.experiment import Experiment, launcher, management, rest from nni.experiment import Experiment, RunMode, launcher, management, rest
from nni.experiment.config import utils from nni.experiment.config import utils
from nni.experiment.config.base import ConfigBase from nni.experiment.config.base import ConfigBase
from nni.experiment.config.training_service import TrainingServiceConfig from nni.experiment.config.training_service import TrainingServiceConfig
from nni.experiment.config.training_services import RemoteConfig from nni.experiment.config.training_services import RemoteConfig
from nni.experiment.pipe import Pipe from nni.runtime.protocol import connect_websocket
from nni.tools.nnictl.command_utils import kill_command from nni.tools.nnictl.command_utils import kill_command
from ..codegen import model_to_pytorch_script from ..codegen import model_to_pytorch_script
...@@ -64,7 +64,7 @@ class RetiariiExeConfig(ConfigBase): ...@@ -64,7 +64,7 @@ class RetiariiExeConfig(ConfigBase):
batch_waiting_time: Optional[int] = None batch_waiting_time: Optional[int] = None
nni_manager_ip: Optional[str] = None nni_manager_ip: Optional[str] = None
debug: bool = False debug: bool = False
log_level: Optional[str] = None log_level: str = 'info'
experiment_working_directory: utils.PathLike = '~/nni-experiments' experiment_working_directory: utils.PathLike = '~/nni-experiments'
# remove configuration of tuner/assessor/advisor # remove configuration of tuner/assessor/advisor
training_service: TrainingServiceConfig training_service: TrainingServiceConfig
...@@ -279,7 +279,6 @@ class RetiariiExperiment(Experiment): ...@@ -279,7 +279,6 @@ class RetiariiExperiment(Experiment):
self._dispatcher = cast(RetiariiAdvisor, None) self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread: Optional[Thread] = None self._dispatcher_thread: Optional[Thread] = None
self._proc: Optional[Popen] = None self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
self.url_prefix = None self.url_prefix = None
...@@ -354,9 +353,11 @@ class RetiariiExperiment(Experiment): ...@@ -354,9 +353,11 @@ class RetiariiExperiment(Experiment):
log_dir = Path.home() / f'nni-experiments/{self.id}/log' log_dir = Path.home() / f'nni-experiments/{self.id}/log'
nni.runtime.log.start_experiment_log(self.id, log_dir, debug) nni.runtime.log.start_experiment_log(self.id, log_dir, debug)
self._proc, self._pipe = launcher.start_experiment_retiarii(self.id, self.config, port, debug) ws_url = f'ws://localhost:{port}/tuner'
self._proc = launcher.start_experiment('create', self.id, self.config, port, debug, # type: ignore
RunMode.Background, None, ws_url, ['retiarii'])
assert self._proc is not None assert self._proc is not None
assert self._pipe is not None connect_websocket(ws_url)
self.port = port # port will be None if start up failed self.port = port # port will be None if start up failed
...@@ -474,13 +475,9 @@ class RetiariiExperiment(Experiment): ...@@ -474,13 +475,9 @@ class RetiariiExperiment(Experiment):
_logger.warning('Cannot gracefully stop experiment, killing NNI process...') _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
kill_command(self._proc.pid) kill_command(self._proc.pid)
if self._pipe is not None:
self._pipe.close()
self.id = cast(str, None) self.id = cast(str, None)
self.port = cast(int, None) self.port = cast(int, None)
self._proc = None self._proc = None
self._pipe = None
self._dispatcher = cast(RetiariiAdvisor, None) self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread = None self._dispatcher_thread = None
_logger.info('Experiment stopped') _logger.info('Experiment stopped')
......
...@@ -10,7 +10,6 @@ export class ExperimentStartupInfo { ...@@ -10,7 +10,6 @@ export class ExperimentStartupInfo {
public logDir: string = globals.paths.experimentRoot; public logDir: string = globals.paths.experimentRoot;
public logLevel: string = globals.args.logLevel; public logLevel: string = globals.args.logLevel;
public readonly: boolean = (globals.args.action === 'view'); public readonly: boolean = (globals.args.action === 'view');
public dispatcherPipe: string | null = globals.args.dispatcherPipe ?? null;
public platform: string = globals.args.mode as string; public platform: string = globals.args.mode as string;
public urlprefix: string = globals.args.urlPrefix; public urlprefix: string = globals.args.urlPrefix;
...@@ -42,7 +41,3 @@ export function getPlatform(): string { ...@@ -42,7 +41,3 @@ export function getPlatform(): string {
export function isReadonly(): boolean { export function isReadonly(): boolean {
return globals.args.action === 'view'; return globals.args.action === 'view';
} }
export function getDispatcherPipe(): string | null {
return globals.args.dispatcherPipe ?? null;
}
...@@ -28,10 +28,10 @@ export interface NniManagerArgs { ...@@ -28,10 +28,10 @@ export interface NniManagerArgs {
readonly logLevel: 'critical' | 'error' | 'warning' | 'info' | 'debug'; readonly logLevel: 'critical' | 'error' | 'warning' | 'info' | 'debug';
readonly foreground: boolean; readonly foreground: boolean;
readonly urlPrefix: string; // leading and trailing "/" must be stripped readonly urlPrefix: string; // leading and trailing "/" must be stripped
readonly tunerCommandChannel: string | null;
// these are planned to be removed // these are planned to be removed
readonly mode: string; readonly mode: string;
readonly dispatcherPipe: string | undefined;
} }
export function parseArgs(rawArgs: string[]): NniManagerArgs { export function parseArgs(rawArgs: string[]): NniManagerArgs {
...@@ -44,9 +44,6 @@ export function parseArgs(rawArgs: string[]): NniManagerArgs { ...@@ -44,9 +44,6 @@ export function parseArgs(rawArgs: string[]): NniManagerArgs {
argsAsAny[key] = (parsedArgs as any)[key]; argsAsAny[key] = (parsedArgs as any)[key];
assert(!Number.isNaN(argsAsAny[key]), `Command line arg --${key} is not a number`); assert(!Number.isNaN(argsAsAny[key]), `Command line arg --${key} is not a number`);
} }
if (argsAsAny.dispatcherPipe === '') {
argsAsAny.dispatcherPipe = undefined;
}
const args: NniManagerArgs = argsAsAny; const args: NniManagerArgs = argsAsAny;
const prefixErrMsg = `Command line arg --url-prefix "${args.urlPrefix}" is not stripped`; const prefixErrMsg = `Command line arg --url-prefix "${args.urlPrefix}" is not stripped`;
...@@ -84,12 +81,12 @@ const yargsOptions = { ...@@ -84,12 +81,12 @@ const yargsOptions = {
default: '', default: '',
type: 'string' type: 'string'
}, },
tunerCommandChannel: {
mode: { default: null,
default: '',
type: 'string' type: 'string'
}, },
dispatcherPipe: {
mode: {
default: '', default: '',
type: 'string' type: 'string'
} }
......
...@@ -48,8 +48,8 @@ export function resetGlobals(): void { ...@@ -48,8 +48,8 @@ export function resetGlobals(): void {
logLevel: 'info', logLevel: 'info',
foreground: false, foreground: false,
urlPrefix: '', urlPrefix: '',
mode: 'unittest', tunerCommandChannel: null,
dispatcherPipe: undefined mode: 'unittest'
}; };
const paths = createPaths(args); const paths = createPaths(args);
const logStream = { const logStream = {
......
...@@ -3,17 +3,10 @@ ...@@ -3,17 +3,10 @@
import { IpcInterface } from './tuner_command_channel/common'; import { IpcInterface } from './tuner_command_channel/common';
export { IpcInterface } from './tuner_command_channel/common'; export { IpcInterface } from './tuner_command_channel/common';
export { createDispatcherPipeInterface, encodeCommand } from './tuner_command_channel/legacy';
import * as shim from './tuner_command_channel/shim'; import * as shim from './tuner_command_channel/shim';
let tunerDisabled: boolean = false; let tunerDisabled: boolean = false;
class DummyIpcInterface implements IpcInterface {
public sendCommand(_commandType: string, _content?: string): void { /* empty */ }
public onCommand(_listener: (commandType: string, content: string) => void): void { /* empty */ }
public onError(_listener: (error: Error) => void): void { /* empty */ }
}
export async function createDispatcherInterface(): Promise<IpcInterface> { export async function createDispatcherInterface(): Promise<IpcInterface> {
if (!tunerDisabled) { if (!tunerDisabled) {
return await shim.createDispatcherInterface(); return await shim.createDispatcherInterface();
...@@ -22,6 +15,19 @@ export async function createDispatcherInterface(): Promise<IpcInterface> { ...@@ -22,6 +15,19 @@ export async function createDispatcherInterface(): Promise<IpcInterface> {
} }
} }
export function encodeCommand(commandType: string, content: string): Buffer {
const contentBuffer: Buffer = Buffer.from(content);
const contentLengthBuffer: Buffer = Buffer.from(contentBuffer.length.toString().padStart(14, '0'));
return Buffer.concat([Buffer.from(commandType), contentLengthBuffer, contentBuffer]);
}
class DummyIpcInterface implements IpcInterface {
public async init(): Promise<void> { /* empty */ }
public sendCommand(_commandType: string, _content?: string): void { /* empty */ }
public onCommand(_listener: (commandType: string, content: string) => void): void { /* empty */ }
public onError(_listener: (error: Error) => void): void { /* empty */ }
}
export namespace UnitTestHelpers { export namespace UnitTestHelpers {
export function disableTuner(): void { export function disableTuner(): void {
tunerDisabled = true; tunerDisabled = true;
......
...@@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred'; ...@@ -7,7 +7,7 @@ import { Deferred } from 'ts-deferred';
import * as component from '../common/component'; import * as component from '../common/component';
import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore'; import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore';
import { NNIError } from '../common/errors'; import { NNIError } from '../common/errors';
import { getExperimentId, getDispatcherPipe } from '../common/experimentStartupInfo'; import { getExperimentId } from '../common/experimentStartupInfo';
import globals from 'common/globals'; import globals from 'common/globals';
import { Logger, getLogger } from '../common/log'; import { Logger, getLogger } from '../common/log';
import { import {
...@@ -25,7 +25,7 @@ import { ...@@ -25,7 +25,7 @@ import {
INITIALIZE, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, PING, INITIALIZE, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, PING,
REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE, IMPORT_DATA REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE, IMPORT_DATA
} from './commands'; } from './commands';
import { createDispatcherInterface, createDispatcherPipeInterface, IpcInterface } from './ipcInterface'; import { createDispatcherInterface, IpcInterface } from './ipcInterface';
/** /**
* NNIManager which implements Manager interface * NNIManager which implements Manager interface
...@@ -71,11 +71,6 @@ class NNIManager implements Manager { ...@@ -71,11 +71,6 @@ class NNIManager implements Manager {
}); });
}; };
const pipe = getDispatcherPipe();
if (pipe !== null) {
this.dispatcher = createDispatcherPipeInterface(pipe);
}
globals.shutdown.register('NniManager', this.stopExperiment.bind(this)); globals.shutdown.register('NniManager', this.stopExperiment.bind(this));
} }
...@@ -466,7 +461,22 @@ class NNIManager implements Manager { ...@@ -466,7 +461,22 @@ class NNIManager implements Manager {
if (this.dispatcher !== undefined) { if (this.dispatcher !== undefined) {
return; return;
} }
const stdio: StdioOptions = ['ignore', process.stdout, process.stderr, 'pipe', 'pipe'];
let tunerWs: string;
if (globals.args.urlPrefix) {
tunerWs = `ws://localhost:${globals.args.port}/${globals.args.urlPrefix}/tuner`;
} else {
tunerWs = `ws://localhost:${globals.args.port}/tuner`;
}
if (globals.args.tunerCommandChannel) {
// TODO: this will become configurable after refactoring rest handler interface
assert.equal(tunerWs, globals.args.tunerCommandChannel);
this.dispatcher = await createDispatcherInterface();
return;
}
const stdio: StdioOptions = ['ignore', process.stdout, process.stderr];
let newCwd: string; let newCwd: string;
if (cwd === undefined || cwd === '') { if (cwd === undefined || cwd === '') {
newCwd = getLogDir(); newCwd = getLogDir();
...@@ -476,13 +486,6 @@ class NNIManager implements Manager { ...@@ -476,13 +486,6 @@ class NNIManager implements Manager {
// TO DO: add CUDA_VISIBLE_DEVICES // TO DO: add CUDA_VISIBLE_DEVICES
const includeIntermediateResultsEnv = !!(this.config.deprecated && this.config.deprecated.includeIntermediateResults); const includeIntermediateResultsEnv = !!(this.config.deprecated && this.config.deprecated.includeIntermediateResults);
let tunerWs: string;
if (globals.args.urlPrefix) {
tunerWs = `ws://localhost:${globals.args.port}/${globals.args.urlPrefix}/tuner`;
} else {
tunerWs = `ws://localhost:${globals.args.port}/tuner`;
}
const nniEnv = { const nniEnv = {
SDK_PROCESS: 'dispatcher', SDK_PROCESS: 'dispatcher',
NNI_MODE: mode, NNI_MODE: mode,
...@@ -713,6 +716,7 @@ class NNIManager implements Manager { ...@@ -713,6 +716,7 @@ class NNIManager implements Manager {
private async run(): Promise<void> { private async run(): Promise<void> {
assert(this.dispatcher !== undefined); assert(this.dispatcher !== undefined);
await this.dispatcher.init();
this.addEventListeners(); this.addEventListeners();
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// Licensed under the MIT license. // Licensed under the MIT license.
export interface IpcInterface { export interface IpcInterface {
init(): Promise<void>;
sendCommand(commandType: string, content?: string): void; sendCommand(commandType: string, content?: string): void;
onCommand(listener: (commandType: string, content: string) => void): void; onCommand(listener: (commandType: string, content: string) => void): void;
onError(listener: (error: Error) => void): void; onError(listener: (error: Error) => void): void;
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import assert from 'assert';
import { ChildProcess } from 'child_process';
import { EventEmitter } from 'events';
import net from 'net';
import { Readable, Writable } from 'stream';
import { NNIError } from '../../common/errors';
import { getLogger, Logger } from '../../common/log';
import { getLogDir } from '../../common/utils';
import * as CommandType from '../commands';
import type { IpcInterface } from './common';
const ipcOutgoingFd: number = 3;
const ipcIncomingFd: number = 4;
/**
* Encode a command
* @param commandType a command type defined in 'core/commands'
* @param content payload of the command
* @returns binary command data
*/
function encodeCommand(commandType: string, content: string): Buffer {
const contentBuffer: Buffer = Buffer.from(content);
const contentLengthBuffer: Buffer = Buffer.from(contentBuffer.length.toString().padStart(14, '0'));
return Buffer.concat([Buffer.from(commandType), contentLengthBuffer, contentBuffer]);
}
/**
* Decode a command
* @param Buffer binary incoming data
* @returns a tuple of (success, commandType, content, remain)
* success: true if the buffer contains at least one complete command; otherwise false
* remain: remaining data after the first command
*/
function decodeCommand(data: Buffer): [boolean, string, string, Buffer] {
if (data.length < 8) {
return [false, '', '', data];
}
const commandType: string = data.slice(0, 2).toString();
const contentLength: number = parseInt(data.slice(2, 16).toString(), 10);
if (data.length < contentLength + 16) {
return [false, '', '', data];
}
const content: string = data.slice(16, contentLength + 16).toString();
const remain: Buffer = data.slice(contentLength + 16);
return [true, commandType, content, remain];
}
class LegacyIpcInterface implements IpcInterface {
private acceptCommandTypes: Set<string>;
private outgoingStream: Writable;
private incomingStream: Readable;
private eventEmitter: EventEmitter;
private readBuffer: Buffer;
private logger: Logger = getLogger('IpcInterface');
/**
* Construct a IPC proxy
* @param proc the process to wrap
* @param acceptCommandTypes set of accepted commands for this process
*/
constructor(outStream: Writable, inStream: Readable, acceptCommandTypes: Set<string>) {
this.acceptCommandTypes = acceptCommandTypes;
this.outgoingStream = outStream;
this.incomingStream = inStream;
this.eventEmitter = new EventEmitter();
this.readBuffer = Buffer.alloc(0);
this.incomingStream.on('data', (data: Buffer) => { this.receive(data); });
this.incomingStream.on('error', (error: Error) => { this.eventEmitter.emit('error', error); });
this.outgoingStream.on('error', (error: Error) => { this.eventEmitter.emit('error', error); });
}
/**
* Send a command to process
* @param commandType: a command type defined in 'core/commands'
* @param content: payload of command
*/
public sendCommand(commandType: string, content: string = ''): void {
this.logger.debug(`ipcInterface command type: [${commandType}], content:[${content}]`);
assert.ok(this.acceptCommandTypes.has(commandType));
try {
const data: Buffer = encodeCommand(commandType, content);
if (!this.outgoingStream.write(data)) {
this.logger.warning('Commands jammed in buffer!');
}
} catch (err) {
throw NNIError.FromError(
err,
`Dispatcher Error, please check this dispatcher log file for more detailed information: ${getLogDir()}/dispatcher.log . `
);
}
}
/**
* Add a command listener
* @param listener the listener callback
*/
public onCommand(listener: (commandType: string, content: string) => void): void {
this.eventEmitter.on('command', listener);
}
public onError(listener: (error: Error) => void): void {
this.eventEmitter.on('error', listener);
}
/**
* Deal with incoming data from process
* Invoke listeners for each complete command received, save incomplete command to buffer
* @param data binary incoming data
*/
private receive(data: Buffer): void {
this.readBuffer = Buffer.concat([this.readBuffer, data]);
while (this.readBuffer.length > 0) {
const [success, commandType, content, remain] = decodeCommand(this.readBuffer);
if (!success) {
break;
}
assert.ok(this.acceptCommandTypes.has(commandType));
this.eventEmitter.emit('command', commandType, content);
this.readBuffer = remain;
}
}
}
/**
* Create IPC proxy for tuner process
* @param process_ the tuner process
*/
async function createDispatcherInterface(process: ChildProcess): Promise<IpcInterface> {
const outStream = <Writable>process.stdio[ipcOutgoingFd];
const inStream = <Readable>process.stdio[ipcIncomingFd];
return new LegacyIpcInterface(outStream, inStream, new Set([...CommandType.TUNER_COMMANDS, ...CommandType.ASSESSOR_COMMANDS]));
}
function createDispatcherPipeInterface(pipePath: string): IpcInterface {
const client = net.createConnection(pipePath);
return new LegacyIpcInterface(client, client, new Set([...CommandType.TUNER_COMMANDS, ...CommandType.ASSESSOR_COMMANDS]));
}
export { createDispatcherInterface, createDispatcherPipeInterface, encodeCommand, decodeCommand };
...@@ -5,9 +5,7 @@ import type { IpcInterface } from './common'; ...@@ -5,9 +5,7 @@ import type { IpcInterface } from './common';
import { WebSocketChannel, getWebSocketChannel } from './websocket_channel'; import { WebSocketChannel, getWebSocketChannel } from './websocket_channel';
export async function createDispatcherInterface(): Promise<IpcInterface> { export async function createDispatcherInterface(): Promise<IpcInterface> {
const ipcInterface = new WsIpcInterface(); return new WsIpcInterface();
await ipcInterface.init();
return ipcInterface;
} }
class WsIpcInterface implements IpcInterface { class WsIpcInterface implements IpcInterface {
......
...@@ -14,9 +14,9 @@ const expected = { ...@@ -14,9 +14,9 @@ const expected = {
logLevel: 'error', logLevel: 'error',
foreground: false, foreground: false,
urlPrefix: '', urlPrefix: '',
tunerCommandChannel: null,
mode: '', mode: '',
dispatcherPipe: undefined,
}; };
function testGoodShort(): void { function testGoodShort(): void {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment