Unverified commit 35c3d169, authored by Ni Hao and committed by GitHub
Browse files

Support prefix url for nnimanager (#3643)


Co-authored-by: Hao Ni <v-nihao@microsoft.com>
parent 9444e275
...@@ -108,7 +108,7 @@ class Experiments: ...@@ -108,7 +108,7 @@ class Experiments:
self.experiments = self.read_file() self.experiments = self.read_file()
def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED', def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
tag=[], pid=None, webuiUrl=[], logDir=''): tag=[], pid=None, webuiUrl=[], logDir='', prefixUrl=None):
'''set {key:value} pairs to self.experiment''' '''set {key:value} pairs to self.experiment'''
with self.lock: with self.lock:
self.experiments = self.read_file() self.experiments = self.read_file()
...@@ -124,6 +124,7 @@ class Experiments: ...@@ -124,6 +124,7 @@ class Experiments:
self.experiments[expId]['pid'] = pid self.experiments[expId]['pid'] = pid
self.experiments[expId]['webuiUrl'] = webuiUrl self.experiments[expId]['webuiUrl'] = webuiUrl
self.experiments[expId]['logDir'] = str(logDir) self.experiments[expId]['logDir'] = str(logDir)
self.experiments[expId]['prefixUrl'] = prefixUrl
self.write_file() self.write_file()
def update_experiment(self, expId, key, value): def update_experiment(self, expId, key, value):
......
...@@ -9,6 +9,7 @@ import string ...@@ -9,6 +9,7 @@ import string
import random import random
import time import time
import tempfile import tempfile
import re
from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT
from nni.experiment.config import ExperimentConfig, convert from nni.experiment.config import ExperimentConfig, convert
from nni.tools.annotation import expand_annotations, generate_search_space from nni.tools.annotation import expand_annotations, generate_search_space
...@@ -16,7 +17,7 @@ from nni.tools.package_utils import get_builtin_module_class_name ...@@ -16,7 +17,7 @@ from nni.tools.package_utils import get_builtin_module_class_name
import nni_node # pylint: disable=import-error import nni_node # pylint: disable=import-error
from .launcher_utils import validate_all_content from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_response from .rest_utils import rest_put, rest_post, check_rest_server, check_response
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, setPrefixUrl, formatURLPath
from .config_utils import Config, Experiments from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \ from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
detect_port, get_user detect_port, get_user
...@@ -43,7 +44,7 @@ def print_log_content(config_file_name): ...@@ -43,7 +44,7 @@ def print_log_content(config_file_name):
print_normal(' Stderr:') print_normal(' Stderr:')
print(check_output_command(stderr_full_path)) print(check_output_command(stderr_full_path))
def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None): def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None, url_prefix=None):
'''Run nni manager process''' '''Run nni manager process'''
if detect_port(port): if detect_port(port):
print_error('Port %s is used by another process, please reset the port!\n' \ print_error('Port %s is used by another process, please reset the port!\n' \
...@@ -81,6 +82,11 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log ...@@ -81,6 +82,11 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
cmds += ['--log_level', log_level] cmds += ['--log_level', log_level]
if foreground: if foreground:
cmds += ['--foreground', 'true'] cmds += ['--foreground', 'true']
if url_prefix:
_validate_prefix_path(url_prefix)
setPrefixUrl(url_prefix)
cmds += ['--url_prefix', url_prefix]
stdout_full_path, stderr_full_path = get_log_path(experiment_id) stdout_full_path, stderr_full_path = get_log_path(experiment_id)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
start_time = time.time() start_time = time.time()
...@@ -384,11 +390,12 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi ...@@ -384,11 +390,12 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
platform = experiment_config['trainingService']['platform'] platform = experiment_config['trainingService']['platform']
rest_process, start_time = start_rest_server(args.port, platform, \ rest_process, start_time = start_rest_server(args.port, platform, \
mode, experiment_id, foreground, log_dir, log_level) mode, experiment_id, foreground, log_dir, log_level, args.url_prefix)
# save experiment information # save experiment information
Experiments().add_experiment(experiment_id, args.port, start_time, Experiments().add_experiment(experiment_id, args.port, start_time,
platform, platform,
experiment_config.get('experimentName', 'N/A'), pid=rest_process.pid, logDir=log_dir) experiment_config.get('experimentName', 'N/A')
, pid=rest_process.pid, logDir=log_dir, prefixUrl=args.url_prefix)
# Deal with annotation # Deal with annotation
if experiment_config.get('useAnnotation'): if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation') path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
...@@ -446,9 +453,9 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi ...@@ -446,9 +453,9 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
raise Exception(ERROR_INFO % 'Restful server stopped!') raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1) exit(1)
if experiment_config.get('nniManagerIp'): if experiment_config.get('nniManagerIp'):
web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))] web_ui_url_list = ['http://{0}:{1}{2}'.format(experiment_config['nniManagerIp'], str(args.port), formatURLPath(args.url_prefix))]
else: else:
web_ui_url_list = get_local_urls(args.port) web_ui_url_list = get_local_urls(args.port, args.url_prefix)
Experiments().update_experiment(experiment_id, 'webuiUrl', web_ui_url_list) Experiments().update_experiment(experiment_id, 'webuiUrl', web_ui_url_list)
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
...@@ -476,6 +483,9 @@ def _validate_v2(config, path): ...@@ -476,6 +483,9 @@ def _validate_v2(config, path):
except Exception as e: except Exception as e:
print_error(f'Config V2 validation failed: {repr(e)}') print_error(f'Config V2 validation failed: {repr(e)}')
def _validate_prefix_path(path):
assert re.match("^[A-Za-z0-9_-]*$", path), "prefix url is invalid."
def create_experiment(args): def create_experiment(args):
'''start a new experiment''' '''start a new experiment'''
experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8)) experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
...@@ -533,6 +543,7 @@ def manage_stopped_experiment(args, mode): ...@@ -533,6 +543,7 @@ def manage_stopped_experiment(args, mode):
print_normal('{0} experiment {1}...'.format(mode, experiment_id)) print_normal('{0} experiment {1}...'.format(mode, experiment_id))
experiment_config = Config(experiment_id, experiments_dict[args.id]['logDir']).get_config() experiment_config = Config(experiment_id, experiments_dict[args.id]['logDir']).get_config()
experiments_config.update_experiment(args.id, 'port', args.port) experiments_config.update_experiment(args.id, 'port', args.port)
args.url_prefix = experiments_dict[args.id]['prefixUrl']
assert 'trainingService' in experiment_config or 'trainingServicePlatform' in experiment_config assert 'trainingService' in experiment_config or 'trainingServicePlatform' in experiment_config
try: try:
if 'trainingServicePlatform' in experiment_config: if 'trainingServicePlatform' in experiment_config:
......
...@@ -54,6 +54,7 @@ def parse_args(): ...@@ -54,6 +54,7 @@ def parse_args():
parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file') parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', type=int, help='the port of restful server') parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', type=int, help='the port of restful server')
parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode') parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
parser_start.add_argument('--url_prefix', '-u', dest='url_prefix', help=' set prefix url')
parser_start.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal') parser_start.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
parser_start.set_defaults(func=create_experiment) parser_start.set_defaults(func=create_experiment)
......
...@@ -24,6 +24,13 @@ TENSORBOARD_API = '/tensorboard' ...@@ -24,6 +24,13 @@ TENSORBOARD_API = '/tensorboard'
METRIC_DATA_API = '/metric-data' METRIC_DATA_API = '/metric-data'
def formatURLPath(path):
    '''Return the url path segment for a prefix: "/<path>", or "" without one.

    None and the empty string both count as "no prefix", so an empty
    prefix does not produce a dangling "/" at the end of generated urls.
    '''
    if not path:
        return ''
    return '/{0}'.format(path)
def setPrefixUrl(prefix_path):
    '''Rebind the module-level API_ROOT_URL to the formatted url prefix.

    The new value is whatever formatURLPath returns for prefix_path.
    '''
    global API_ROOT_URL
    root_url = formatURLPath(prefix_path)
    API_ROOT_URL = root_url
def metric_data_url(port):
    '''Build the metric-data url for the rest server listening on port.'''
    pieces = (BASE_URL, port, API_ROOT_URL, METRIC_DATA_API)
    return '{0}:{1}{2}{3}'.format(*pieces)
...@@ -60,7 +67,7 @@ def trial_job_id_url(port, job_id): ...@@ -60,7 +67,7 @@ def trial_job_id_url(port, job_id):
def export_data_url(port):
    '''Build the export-data url for the rest server listening on port.'''
    # The placeholder count must match the four arguments: an extra "{4}"
    # makes str.format raise IndexError on every call.
    return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPORT_DATA_API)
def tensorboard_url(port): def tensorboard_url(port):
...@@ -68,11 +75,11 @@ def tensorboard_url(port): ...@@ -68,11 +75,11 @@ def tensorboard_url(port):
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TENSORBOARD_API) return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, TENSORBOARD_API)
def get_local_urls(port, prefix=None):
    '''List the http urls of this machine for the given port.

    One url is produced for every AF_INET address reported by
    psutil.net_if_addrs(). prefix is optional so existing
    get_local_urls(port) calls keep working; when given, it is appended
    to each url through formatURLPath.
    '''
    # The suffix is the same for every address, so format it once.
    suffix = formatURLPath(prefix)
    url_list = []
    for addresses in psutil.net_if_addrs().values():
        for addr in addresses:
            if addr.family == socket.AddressFamily.AF_INET:
                url_list.append('http://{0}:{1}{2}'.format(addr.address, port, suffix))
    return url_list
...@@ -10,6 +10,8 @@ import * as component from '../common/component'; ...@@ -10,6 +10,8 @@ import * as component from '../common/component';
@component.Singleton @component.Singleton
class ExperimentStartupInfo { class ExperimentStartupInfo {
private readonly API_ROOT_URL: string = '/api/v1/nni';
private experimentId: string = ''; private experimentId: string = '';
private newExperiment: boolean = true; private newExperiment: boolean = true;
private basePort: number = -1; private basePort: number = -1;
...@@ -19,8 +21,9 @@ class ExperimentStartupInfo { ...@@ -19,8 +21,9 @@ class ExperimentStartupInfo {
private readonly: boolean = false; private readonly: boolean = false;
private dispatcherPipe: string | null = null; private dispatcherPipe: string | null = null;
private platform: string = ''; private platform: string = '';
private urlprefix: string = '';
public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string): void { public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string, urlprefix?: string): void {
assert(!this.initialized); assert(!this.initialized);
assert(experimentId.trim().length > 0); assert(experimentId.trim().length > 0);
this.newExperiment = newExperiment; this.newExperiment = newExperiment;
...@@ -46,6 +49,10 @@ class ExperimentStartupInfo { ...@@ -46,6 +49,10 @@ class ExperimentStartupInfo {
if (dispatcherPipe != undefined && dispatcherPipe.length > 0) { if (dispatcherPipe != undefined && dispatcherPipe.length > 0) {
this.dispatcherPipe = dispatcherPipe; this.dispatcherPipe = dispatcherPipe;
} }
if(urlprefix != undefined && urlprefix.length > 0){
this.urlprefix = urlprefix;
}
} }
public getExperimentId(): string { public getExperimentId(): string {
...@@ -94,6 +101,11 @@ class ExperimentStartupInfo { ...@@ -94,6 +101,11 @@ class ExperimentStartupInfo {
assert(this.initialized); assert(this.initialized);
return this.dispatcherPipe; return this.dispatcherPipe;
} }
/**
 * Root path of the REST API: the built-in API_ROOT_URL when no url prefix
 * was stored, otherwise "/<urlprefix>".
 */
public getAPIRootUrl(): string {
    assert(this.initialized);
    if (this.urlprefix == '') {
        return this.API_ROOT_URL;
    }
    return `/${this.urlprefix}`;
}
} }
function getExperimentId(): string { function getExperimentId(): string {
...@@ -117,9 +129,9 @@ function getExperimentStartupInfo(): ExperimentStartupInfo { ...@@ -117,9 +129,9 @@ function getExperimentStartupInfo(): ExperimentStartupInfo {
} }
/** Forward the startup parameters to the ExperimentStartupInfo instance held by the component container. */
function setExperimentStartupInfo(
    newExperiment: boolean, experimentId: string, basePort: number, platform: string,
    logDir?: string, logLevel?: string, readonly?: boolean, dispatcherPipe?: string, urlprefix?: string): void {
    const startupInfo: ExperimentStartupInfo = component.get<ExperimentStartupInfo>(ExperimentStartupInfo);
    startupInfo.setStartupInfo(newExperiment, experimentId, basePort, platform, logDir, logLevel, readonly, dispatcherPipe, urlprefix);
}
function isReadonly(): boolean { function isReadonly(): boolean {
...@@ -130,7 +142,11 @@ function getDispatcherPipe(): string | null { ...@@ -130,7 +142,11 @@ function getDispatcherPipe(): string | null {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getDispatcherPipe(); return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getDispatcherPipe();
} }
/** Module-level shortcut for ExperimentStartupInfo.getAPIRootUrl(). */
function getAPIRootUrl(): string {
    const startupInfo: ExperimentStartupInfo = component.get<ExperimentStartupInfo>(ExperimentStartupInfo);
    return startupInfo.getAPIRootUrl();
}
export { export {
ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, getPlatform, getExperimentStartupInfo, ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, getPlatform, getExperimentStartupInfo,
setExperimentStartupInfo, isReadonly, getDispatcherPipe setExperimentStartupInfo, isReadonly, getDispatcherPipe, getAPIRootUrl
}; };
...@@ -25,9 +25,9 @@ import { NNIRestServer } from './rest_server/nniRestServer'; ...@@ -25,9 +25,9 @@ import { NNIRestServer } from './rest_server/nniRestServer';
/** Record the startup parameters; the experiment is flagged as new only when startExpMode is NEW. */
function initStartupInfo(
    startExpMode: string, experimentId: string, basePort: number, platform: string,
    logDirectory: string, experimentLogLevel: string, readonly: boolean, dispatcherPipe: string,
    urlprefix: string): void {
    const isNewExperiment: boolean = startExpMode === ExperimentStartUpMode.NEW;
    setExperimentStartupInfo(
        isNewExperiment, experimentId, basePort, platform,
        logDirectory, experimentLogLevel, readonly, dispatcherPipe, urlprefix);
}
async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> { async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
...@@ -122,7 +122,9 @@ const readonly = readonlyArg.toLowerCase() == 'true' ? true : false; ...@@ -122,7 +122,9 @@ const readonly = readonlyArg.toLowerCase() == 'true' ? true : false;
const dispatcherPipe: string = parseArg(['--dispatcher_pipe']); const dispatcherPipe: string = parseArg(['--dispatcher_pipe']);
initStartupInfo(startMode, experimentId, port, mode, logDir, logLevel, readonly, dispatcherPipe); const urlPrefix: string = parseArg(['--url_prefix']);
initStartupInfo(startMode, experimentId, port, mode, logDir, logLevel, readonly, dispatcherPipe, urlPrefix);
mkDirP(getLogDir()) mkDirP(getLogDir())
.then(async () => { .then(async () => {
......
...@@ -10,6 +10,7 @@ import * as component from '../common/component'; ...@@ -10,6 +10,7 @@ import * as component from '../common/component';
import { RestServer } from '../common/restServer' import { RestServer } from '../common/restServer'
import { getLogDir } from '../common/utils'; import { getLogDir } from '../common/utils';
import { createRestHandler } from './restHandler'; import { createRestHandler } from './restHandler';
import { getAPIRootUrl } from '../common/experimentStartupInfo';
/** /**
* NNI Main rest server, provides rest API to support * NNI Main rest server, provides rest API to support
...@@ -19,14 +20,15 @@ import { createRestHandler } from './restHandler'; ...@@ -19,14 +20,15 @@ import { createRestHandler } from './restHandler';
*/ */
@component.Singleton @component.Singleton
export class NNIRestServer extends RestServer { export class NNIRestServer extends RestServer {
private readonly API_ROOT_URL: string = '/api/v1/nni';
private readonly LOGS_ROOT_URL: string = '/logs'; private readonly LOGS_ROOT_URL: string = '/logs';
protected API_ROOT_URL: string = '/api/v1/nni';
/** /**
* constructor to provide NNIRestServer's own rest property, e.g. port * constructor to provide NNIRestServer's own rest property, e.g. port
*/ */
constructor() { constructor() {
super(); super();
this.API_ROOT_URL = getAPIRootUrl();
} }
/** /**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment