Unverified Commit e8de5eb4 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

WebSocket (step 3) - HPO experiment (#4821)

parent 94a276ce
......@@ -9,6 +9,7 @@ import base64
from .runtime.msg_dispatcher import MsgDispatcher
from .runtime.msg_dispatcher_base import MsgDispatcherBase
from .runtime.protocol import connect_websocket
from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance
logger = logging.getLogger('nni.main')
......@@ -20,6 +21,10 @@ if os.environ.get('COVERAGE_PROCESS_START'):
def main():
# the url should be "ws://localhost:{port}/tuner" or "ws://localhost:{port}/{url_prefix}/tuner"
ws_url = os.environ['NNI_TUNER_COMMAND_CHANNEL']
connect_websocket(ws_url)
parser = argparse.ArgumentParser(description='Dispatcher command line parser')
parser.add_argument('--exp_params', type=str, required=True)
args, _ = parser.parse_known_args()
......
......@@ -38,8 +38,8 @@ class NniManagerArgs:
experiments_directory: str # renamed "config.nni_experiments_directory", must be absolute
log_level: str
foreground: bool = False
url_prefix: Optional[str] = None # leading and trailing "/" must be stripped
dispatcher_pipe: Optional[str] = None
url_prefix: str | None = None # leading and trailing "/" must be stripped
dispatcher_pipe: str | None = None
def __init__(self, action, exp_id, config, port, debug, foreground, url_prefix):
self.port = port
......@@ -85,6 +85,15 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
if action != 'view' and nni_manager_args.mode in websocket_platforms:
_ensure_port_idle(port + 1, f'{nni_manager_args.mode} requires an additional port')
link = Path(config.experiment_working_directory, '_latest')
try:
if link.exists():
link.unlink()
link.symlink_to(exp_id, target_is_directory=True)
except Exception:
if sys.platform != 'win32':
_logger.warning(f'Failed to create link {link}')
proc = None
try:
_logger.info(
......@@ -118,18 +127,6 @@ def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
proc.kill()
raise e
link = Path(config.experiment_working_directory, '_latest')
try:
if sys.version_info >= (3, 8):
link.unlink(missing_ok=True)
else:
if link.exists():
link.unlink()
link.symlink_to(exp_id, target_is_directory=True)
except Exception:
if sys.platform != 'win32':
_logger.warning(f'Failed to create link {link}')
return proc
def _start_rest_server(nni_manager_args, run_mode) -> Popen:
......
......@@ -3,18 +3,37 @@
# pylint: disable=unused-import
from __future__ import annotations
from .tuner_command_channel.command_type import CommandType
from .tuner_command_channel.legacy import send, receive
from .tuner_command_channel import legacy
from .tuner_command_channel import shim
_use_ws = False
def connect_websocket(url: str):
    """
    Switch this module to the WebSocket channel and connect it.

    Raises the module-level ``_use_ws`` flag, which ``send()`` and ``receive()``
    consult to pick between ``shim`` and ``legacy``, then calls ``shim.connect(url)``.
    The flag is set before the connection attempt is made.
    """
    global _use_ws
    _use_ws = True
    shim.connect(url)
def send(command: CommandType, data: str) -> None:
    """
    Send ``command`` with payload ``data`` over the active channel.

    The ``shim`` module is used once ``_use_ws`` has been set;
    otherwise the call is delegated to ``legacy``.
    """
    channel = shim if _use_ws else legacy
    channel.send(command, data)
def receive() -> tuple[CommandType, str] | tuple[None, None]:
    """
    Receive one command from the active channel.

    Delegates to ``shim.receive()`` when ``_use_ws`` is set and to
    ``legacy.receive()`` otherwise, returning whatever the delegate returns.
    """
    channel = shim if _use_ws else legacy
    return channel.receive()
# for unit test compatibility
def _set_in_file(in_file):
    """Unit-test hook: overwrite the legacy channel's ``_in_file`` attribute."""
    from .tuner_command_channel import legacy as legacy_channel
    legacy_channel._in_file = in_file
def _set_out_file(out_file):
    """Unit-test hook: overwrite the legacy channel's ``_out_file`` attribute."""
    from .tuner_command_channel import legacy as legacy_channel
    legacy_channel._out_file = out_file
def _get_out_file():
    """Unit-test hook: return the legacy channel's current ``_out_file`` attribute."""
    from .tuner_command_channel import legacy as legacy_channel
    return legacy_channel._out_file
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { IpcInterface } from './tuner_command_channel/common';
export { IpcInterface } from './tuner_command_channel/common';
export { createDispatcherInterface, createDispatcherPipeInterface, encodeCommand } from './tuner_command_channel/legacy';
export { createDispatcherPipeInterface, encodeCommand } from './tuner_command_channel/legacy';
import * as shim from './tuner_command_channel/shim';
let tunerDisabled: boolean = false;
/**
 * An `IpcInterface` implementation whose methods all do nothing.
 *
 * Commands passed to it are dropped and registered listeners are never invoked.
 */
class DummyIpcInterface implements IpcInterface {
    /** Accepts a command and discards it. */
    public sendCommand(_commandType: string, _content?: string): void {
        /* empty */
    }

    /** Accepts a command listener and discards it; the listener is never called. */
    public onCommand(_listener: (commandType: string, content: string) => void): void {
        /* empty */
    }

    /** Accepts an error listener and discards it; the listener is never called. */
    public onError(_listener: (error: Error) => void): void {
        /* empty */
    }
}
/**
 * Creates the interface used to talk to the dispatcher.
 *
 * When the tuner has been disabled through `UnitTestHelpers.disableTuner()`,
 * a `DummyIpcInterface` is returned; otherwise creation is delegated to
 * `shim.createDispatcherInterface()`.
 *
 * @returns The dispatcher interface, or a no-op stand-in when the tuner is disabled.
 */
export async function createDispatcherInterface(): Promise<IpcInterface> {
    if (tunerDisabled) {
        return new DummyIpcInterface();
    }
    const channel = await shim.createDispatcherInterface();
    return channel;
}
/**
 * Helpers exposed for unit tests.
 */
export namespace UnitTestHelpers {
/**
 * Sets the module-level `tunerDisabled` flag to `true`.
 *
 * After this call, `createDispatcherInterface()` in this module returns a
 * `DummyIpcInterface` instead of delegating to the shim.
 * NOTE(review): nothing visible here resets the flag — confirm tests do not need to re-enable it.
 */
export function disableTuner(): void {
tunerDisabled = true;
}
}
......@@ -476,6 +476,13 @@ class NNIManager implements Manager {
// TO DO: add CUDA_VISIBLE_DEVICES
const includeIntermediateResultsEnv = !!(this.config.deprecated && this.config.deprecated.includeIntermediateResults);
let tunerWs: string;
if (globals.args.urlPrefix) {
tunerWs = `ws://localhost:${globals.args.port}/${globals.args.urlPrefix}/tuner`;
} else {
tunerWs = `ws://localhost:${globals.args.port}/tuner`;
}
const nniEnv = {
SDK_PROCESS: 'dispatcher',
NNI_MODE: mode,
......@@ -483,12 +490,13 @@ class NNIManager implements Manager {
NNI_LOG_DIRECTORY: getLogDir(),
NNI_LOG_LEVEL: getLogLevel(),
NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv,
NNI_TUNER_COMMAND_CHANNEL: tunerWs,
CUDA_VISIBLE_DEVICES: toCudaVisibleDevices(this.experimentProfile.params.tunerGpuIndices)
};
const newEnv = Object.assign({}, process.env, nniEnv);
const tunerProc: ChildProcess = getTunerProc(command, stdio, newCwd, newEnv);
this.dispatcherPid = tunerProc.pid!;
this.dispatcher = await createDispatcherInterface(tunerProc);
this.dispatcher = await createDispatcherInterface();
return;
}
......
......@@ -42,7 +42,7 @@ async function runProcess(): Promise<Error | null> {
});
// create IPC interface
const dispatcher: IpcInterface = await createDispatcherInterface(proc);
const dispatcher: IpcInterface = await createDispatcherInterface();
dispatcher.onCommand((commandType: string, content: string): void => {
receivedCommands.push({ commandType, content });
});
......
......@@ -53,7 +53,7 @@ async function startProcess(): Promise<void> {
});
// create IPC interface
dispatcher = await createDispatcherInterface(proc);
dispatcher = await createDispatcherInterface();
(<IpcInterface>dispatcher).onCommand((commandType: string, content: string): void => {
console.log(commandType, content);
});
......
......@@ -13,7 +13,7 @@ import { Database, DataStore } from '../../common/datastore';
import { Manager, ExperimentProfile} from '../../common/manager';
import { ExperimentManager } from '../../common/experimentManager';
import { TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { cleanupUnitTest, prepareUnitTest, killPid } from '../../common/utils';
import { NNIExperimentsManager } from 'extensions/experiments_manager';
import { NNIManager } from '../../core/nnimanager';
import { SqlDB } from '../../core/sqlDatabase';
......@@ -22,9 +22,11 @@ import { MockedDataStore } from '../mock/datastore';
import { TensorboardManager } from '../../common/tensorboardManager';
import { NNITensorboardManager } from 'extensions/nniTensorboardManager';
import * as path from 'path';
import { UnitTestHelpers } from 'core/ipcInterface';
async function initContainer(): Promise<void> {
prepareUnitTest();
UnitTestHelpers.disableTuner();
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton);
......@@ -134,8 +136,11 @@ describe('Unit test for nnimanager', function () {
})
after(async () => {
// FIXME
await (nniManager as any).stopExperimentTopHalf();
// FIXME: more proper clean up
const manager: any = nniManager;
await killPid(manager.dispatcherPid);
manager.dispatcherPid = 0;
await manager.stopExperimentTopHalf();
cleanupUnitTest();
})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment