Unverified Commit e8de5eb4 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

WebSocket (step 3) - HPO experiment (#4821)

parent 94a276ce
...@@ -9,6 +9,7 @@ import base64 ...@@ -9,6 +9,7 @@ import base64
from .runtime.msg_dispatcher import MsgDispatcher from .runtime.msg_dispatcher import MsgDispatcher
from .runtime.msg_dispatcher_base import MsgDispatcherBase from .runtime.msg_dispatcher_base import MsgDispatcherBase
from .runtime.protocol import connect_websocket
from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance
logger = logging.getLogger('nni.main') logger = logging.getLogger('nni.main')
...@@ -20,6 +21,10 @@ if os.environ.get('COVERAGE_PROCESS_START'): ...@@ -20,6 +21,10 @@ if os.environ.get('COVERAGE_PROCESS_START'):
def main(): def main():
# the url should be "ws://localhost:{port}/tuner" or "ws://localhost:{port}/{url_prefix}/tuner"
ws_url = os.environ['NNI_TUNER_COMMAND_CHANNEL']
connect_websocket(ws_url)
parser = argparse.ArgumentParser(description='Dispatcher command line parser') parser = argparse.ArgumentParser(description='Dispatcher command line parser')
parser.add_argument('--exp_params', type=str, required=True) parser.add_argument('--exp_params', type=str, required=True)
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
......
...@@ -38,8 +38,8 @@ class NniManagerArgs: ...@@ -38,8 +38,8 @@ class NniManagerArgs:
experiments_directory: str # renamed "config.nni_experiments_directory", must be absolute experiments_directory: str # renamed "config.nni_experiments_directory", must be absolute
log_level: str log_level: str
foreground: bool = False foreground: bool = False
url_prefix: Optional[str] = None # leading and trailing "/" must be stripped url_prefix: str | None = None # leading and trailing "/" must be stripped
dispatcher_pipe: Optional[str] = None dispatcher_pipe: str | None = None
def __init__(self, action, exp_id, config, port, debug, foreground, url_prefix): def __init__(self, action, exp_id, config, port, debug, foreground, url_prefix):
self.port = port self.port = port
...@@ -85,6 +85,15 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix): ...@@ -85,6 +85,15 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
if action != 'view' and nni_manager_args.mode in websocket_platforms: if action != 'view' and nni_manager_args.mode in websocket_platforms:
_ensure_port_idle(port + 1, f'{nni_manager_args.mode} requires an additional port') _ensure_port_idle(port + 1, f'{nni_manager_args.mode} requires an additional port')
link = Path(config.experiment_working_directory, '_latest')
try:
if link.exists():
link.unlink()
link.symlink_to(exp_id, target_is_directory=True)
except Exception:
if sys.platform != 'win32':
_logger.warning(f'Failed to create link {link}')
proc = None proc = None
try: try:
_logger.info( _logger.info(
...@@ -118,18 +127,6 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix): ...@@ -118,18 +127,6 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
proc.kill() proc.kill()
raise e raise e
link = Path(config.experiment_working_directory, '_latest')
try:
if sys.version_info >= (3, 8):
link.unlink(missing_ok=True)
else:
if link.exists():
link.unlink()
link.symlink_to(exp_id, target_is_directory=True)
except Exception:
if sys.platform != 'win32':
_logger.warning(f'Failed to create link {link}')
return proc return proc
def _start_rest_server(nni_manager_args, run_mode) -> Popen: def _start_rest_server(nni_manager_args, run_mode) -> Popen:
......
...@@ -3,18 +3,37 @@ ...@@ -3,18 +3,37 @@
# pylint: disable=unused-import # pylint: disable=unused-import
from __future__ import annotations
from .tuner_command_channel.command_type import CommandType from .tuner_command_channel.command_type import CommandType
from .tuner_command_channel.legacy import send, receive from .tuner_command_channel import legacy
from .tuner_command_channel import shim
_use_ws = False
def connect_websocket(url: str):
global _use_ws
_use_ws = True
shim.connect(url)
def send(command: CommandType, data: str) -> None:
if _use_ws:
shim.send(command, data)
else:
legacy.send(command, data)
def receive() -> tuple[CommandType, str] | tuple[None, None]:
if _use_ws:
return shim.receive()
else:
return legacy.receive()
# for unit test compatibility # for unit test compatibility
def _set_in_file(in_file): def _set_in_file(in_file):
from .tuner_command_channel import legacy
legacy._in_file = in_file legacy._in_file = in_file
def _set_out_file(out_file): def _set_out_file(out_file):
from .tuner_command_channel import legacy
legacy._out_file = out_file legacy._out_file = out_file
def _get_out_file(): def _get_out_file():
from .tuner_command_channel import legacy
return legacy._out_file return legacy._out_file
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license. // Licensed under the MIT license.
import { IpcInterface } from './tuner_command_channel/common';
export { IpcInterface } from './tuner_command_channel/common'; export { IpcInterface } from './tuner_command_channel/common';
export { createDispatcherInterface, createDispatcherPipeInterface, encodeCommand } from './tuner_command_channel/legacy'; export { createDispatcherPipeInterface, encodeCommand } from './tuner_command_channel/legacy';
import * as shim from './tuner_command_channel/shim';
let tunerDisabled: boolean = false;
class DummyIpcInterface implements IpcInterface {
public sendCommand(_commandType: string, _content?: string): void { /* empty */ }
public onCommand(_listener: (commandType: string, content: string) => void): void { /* empty */ }
public onError(_listener: (error: Error) => void): void { /* empty */ }
}
export async function createDispatcherInterface(): Promise<IpcInterface> {
if (!tunerDisabled) {
return await shim.createDispatcherInterface();
} else {
return new DummyIpcInterface();
}
}
export namespace UnitTestHelpers {
export function disableTuner(): void {
tunerDisabled = true;
}
}
...@@ -476,6 +476,13 @@ class NNIManager implements Manager { ...@@ -476,6 +476,13 @@ class NNIManager implements Manager {
// TO DO: add CUDA_VISIBLE_DEVICES // TO DO: add CUDA_VISIBLE_DEVICES
const includeIntermediateResultsEnv = !!(this.config.deprecated && this.config.deprecated.includeIntermediateResults); const includeIntermediateResultsEnv = !!(this.config.deprecated && this.config.deprecated.includeIntermediateResults);
let tunerWs: string;
if (globals.args.urlPrefix) {
tunerWs = `ws://localhost:${globals.args.port}/${globals.args.urlPrefix}/tuner`;
} else {
tunerWs = `ws://localhost:${globals.args.port}/tuner`;
}
const nniEnv = { const nniEnv = {
SDK_PROCESS: 'dispatcher', SDK_PROCESS: 'dispatcher',
NNI_MODE: mode, NNI_MODE: mode,
...@@ -483,12 +490,13 @@ class NNIManager implements Manager { ...@@ -483,12 +490,13 @@ class NNIManager implements Manager {
NNI_LOG_DIRECTORY: getLogDir(), NNI_LOG_DIRECTORY: getLogDir(),
NNI_LOG_LEVEL: getLogLevel(), NNI_LOG_LEVEL: getLogLevel(),
NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv, NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv,
NNI_TUNER_COMMAND_CHANNEL: tunerWs,
CUDA_VISIBLE_DEVICES: toCudaVisibleDevices(this.experimentProfile.params.tunerGpuIndices) CUDA_VISIBLE_DEVICES: toCudaVisibleDevices(this.experimentProfile.params.tunerGpuIndices)
}; };
const newEnv = Object.assign({}, process.env, nniEnv); const newEnv = Object.assign({}, process.env, nniEnv);
const tunerProc: ChildProcess = getTunerProc(command, stdio, newCwd, newEnv); const tunerProc: ChildProcess = getTunerProc(command, stdio, newCwd, newEnv);
this.dispatcherPid = tunerProc.pid!; this.dispatcherPid = tunerProc.pid!;
this.dispatcher = await createDispatcherInterface(tunerProc); this.dispatcher = await createDispatcherInterface();
return; return;
} }
......
...@@ -42,7 +42,7 @@ async function runProcess(): Promise<Error | null> { ...@@ -42,7 +42,7 @@ async function runProcess(): Promise<Error | null> {
}); });
// create IPC interface // create IPC interface
const dispatcher: IpcInterface = await createDispatcherInterface(proc); const dispatcher: IpcInterface = await createDispatcherInterface();
dispatcher.onCommand((commandType: string, content: string): void => { dispatcher.onCommand((commandType: string, content: string): void => {
receivedCommands.push({ commandType, content }); receivedCommands.push({ commandType, content });
}); });
......
...@@ -53,7 +53,7 @@ async function startProcess(): Promise<void> { ...@@ -53,7 +53,7 @@ async function startProcess(): Promise<void> {
}); });
// create IPC interface // create IPC interface
dispatcher = await createDispatcherInterface(proc); dispatcher = await createDispatcherInterface();
(<IpcInterface>dispatcher).onCommand((commandType: string, content: string): void => { (<IpcInterface>dispatcher).onCommand((commandType: string, content: string): void => {
console.log(commandType, content); console.log(commandType, content);
}); });
......
...@@ -13,7 +13,7 @@ import { Database, DataStore } from '../../common/datastore'; ...@@ -13,7 +13,7 @@ import { Database, DataStore } from '../../common/datastore';
import { Manager, ExperimentProfile} from '../../common/manager'; import { Manager, ExperimentProfile} from '../../common/manager';
import { ExperimentManager } from '../../common/experimentManager'; import { ExperimentManager } from '../../common/experimentManager';
import { TrainingService } from '../../common/trainingService'; import { TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils'; import { cleanupUnitTest, prepareUnitTest, killPid } from '../../common/utils';
import { NNIExperimentsManager } from 'extensions/experiments_manager'; import { NNIExperimentsManager } from 'extensions/experiments_manager';
import { NNIManager } from '../../core/nnimanager'; import { NNIManager } from '../../core/nnimanager';
import { SqlDB } from '../../core/sqlDatabase'; import { SqlDB } from '../../core/sqlDatabase';
...@@ -22,9 +22,11 @@ import { MockedDataStore } from '../mock/datastore'; ...@@ -22,9 +22,11 @@ import { MockedDataStore } from '../mock/datastore';
import { TensorboardManager } from '../../common/tensorboardManager'; import { TensorboardManager } from '../../common/tensorboardManager';
import { NNITensorboardManager } from 'extensions/nniTensorboardManager'; import { NNITensorboardManager } from 'extensions/nniTensorboardManager';
import * as path from 'path'; import * as path from 'path';
import { UnitTestHelpers } from 'core/ipcInterface';
async function initContainer(): Promise<void> { async function initContainer(): Promise<void> {
prepareUnitTest(); prepareUnitTest();
UnitTestHelpers.disableTuner();
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton); Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton); Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton); Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton);
...@@ -134,8 +136,11 @@ describe('Unit test for nnimanager', function () { ...@@ -134,8 +136,11 @@ describe('Unit test for nnimanager', function () {
}) })
after(async () => { after(async () => {
// FIXME // FIXME: more proper clean up
await (nniManager as any).stopExperimentTopHalf(); const manager: any = nniManager;
await killPid(manager.dispatcherPid);
manager.dispatcherPid = 0;
await manager.stopExperimentTopHalf();
cleanupUnitTest(); cleanupUnitTest();
}) })
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment