Unverified Commit b1d4c129 authored by fishyds's avatar fishyds Committed by GitHub
Browse files

[PAI training service] Support running multiple PAI experiment (#348)

* Change base image from devel to runtime, to reduce docker image size

* Support running multiple experiment for PAI

* Fix a bug regarding to recuisively reference between paiRestServer and
paiTrainingService
parent 35e0832b
...@@ -26,15 +26,17 @@ import * as component from '../common/component'; ...@@ -26,15 +26,17 @@ import * as component from '../common/component';
class ExperimentStartupInfo { class ExperimentStartupInfo {
private experimentId: string = ''; private experimentId: string = '';
private newExperiment: boolean = true; private newExperiment: boolean = true;
private basePort: number = -1;
private initialized: boolean = false; private initialized: boolean = false;
private initTrialSequenceID: number = 0; private initTrialSequenceID: number = 0;
public setStartupInfo(newExperiment: boolean, experimentId: string): void { public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number): void {
assert(!this.initialized); assert(!this.initialized);
assert(experimentId.trim().length > 0); assert(experimentId.trim().length > 0);
this.newExperiment = newExperiment; this.newExperiment = newExperiment;
this.experimentId = experimentId; this.experimentId = experimentId;
this.basePort = basePort;
this.initialized = true; this.initialized = true;
} }
...@@ -44,6 +46,12 @@ class ExperimentStartupInfo { ...@@ -44,6 +46,12 @@ class ExperimentStartupInfo {
return this.experimentId; return this.experimentId;
} }
public getBasePort(): number {
assert(this.initialized);
return this.basePort;
}
public isNewExperiment(): boolean { public isNewExperiment(): boolean {
assert(this.initialized); assert(this.initialized);
...@@ -66,6 +74,10 @@ function getExperimentId(): string { ...@@ -66,6 +74,10 @@ function getExperimentId(): string {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getExperimentId(); return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getExperimentId();
} }
function getBasePort(): number {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getBasePort();
}
function isNewExperiment(): boolean { function isNewExperiment(): boolean {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isNewExperiment(); return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isNewExperiment();
} }
...@@ -78,9 +90,9 @@ function getInitTrialSequenceId(): number { ...@@ -78,9 +90,9 @@ function getInitTrialSequenceId(): number {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getInitTrialSequenceId(); return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getInitTrialSequenceId();
} }
function setExperimentStartupInfo(newExperiment: boolean, experimentId: string): void { function setExperimentStartupInfo(newExperiment: boolean, experimentId: string, basePort: number): void {
component.get<ExperimentStartupInfo>(ExperimentStartupInfo).setStartupInfo(newExperiment, experimentId); component.get<ExperimentStartupInfo>(ExperimentStartupInfo).setStartupInfo(newExperiment, experimentId, basePort);
} }
export { ExperimentStartupInfo, getExperimentId, isNewExperiment, export { ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment,
setExperimentStartupInfo, setInitTrialSequenceId, getInitTrialSequenceId }; setExperimentStartupInfo, setInitTrialSequenceId, getInitTrialSequenceId };
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
'use strict'; 'use strict';
import * as assert from 'assert';
import * as express from 'express'; import * as express from 'express';
import * as http from 'http'; import * as http from 'http';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { getLogger, Logger } from './log'; import { getLogger, Logger } from './log';
import { getBasePort } from './experimentStartupInfo';
/** /**
* Abstraction class to create a RestServer * Abstraction class to create a RestServer
...@@ -39,13 +41,20 @@ export abstract class RestServer { ...@@ -39,13 +41,20 @@ export abstract class RestServer {
protected port?: number; protected port?: number;
protected app: express.Application = express(); protected app: express.Application = express();
protected log: Logger = getLogger(); protected log: Logger = getLogger();
protected basePort?: number;
constructor() {
this.port = getBasePort();
assert(this.port && this.port > 1024);
}
get endPoint(): string { get endPoint(): string {
// tslint:disable-next-line:no-http-string // tslint:disable-next-line:no-http-string
return `http://${this.hostName}:${this.port}`; return `http://${this.hostName}:${this.port}`;
} }
public start(port?: number, hostName?: string): Promise<void> { public start(hostName?: string): Promise<void> {
this.log.info(`RestServer start`);
if (this.startTask !== undefined) { if (this.startTask !== undefined) {
return this.startTask.promise; return this.startTask.promise;
} }
...@@ -56,9 +65,8 @@ export abstract class RestServer { ...@@ -56,9 +65,8 @@ export abstract class RestServer {
if (hostName) { if (hostName) {
this.hostName = hostName; this.hostName = hostName;
} }
if (port) {
this.port = port; this.log.info(`RestServer base port is ${this.port}`);
}
this.server = this.app.listen(this.port as number, this.hostName).on('listening', () => { this.server = this.app.listen(this.port as number, this.hostName).on('listening', () => {
this.startTask.resolve(); this.startTask.resolve();
......
...@@ -222,7 +222,7 @@ function prepareUnitTest(): void { ...@@ -222,7 +222,7 @@ function prepareUnitTest(): void {
Container.snapshot(TrainingService); Container.snapshot(TrainingService);
Container.snapshot(Manager); Container.snapshot(Manager);
setExperimentStartupInfo(true, 'unittest'); setExperimentStartupInfo(true, 'unittest', 8080);
mkDirPSync(getLogDir()); mkDirPSync(getLogDir());
const sqliteFile: string = path.join(getDefaultDatabaseDir(), 'nni.sqlite'); const sqliteFile: string = path.join(getDefaultDatabaseDir(), 'nni.sqlite');
......
...@@ -39,10 +39,10 @@ import { ...@@ -39,10 +39,10 @@ import {
import { PAITrainingService } from './training_service/pai/paiTrainingService' import { PAITrainingService } from './training_service/pai/paiTrainingService'
function initStartupInfo(startExpMode: string, resumeExperimentId: string) { function initStartupInfo(startExpMode: string, resumeExperimentId: string, basePort: number) {
const createNew: boolean = (startExpMode === 'new'); const createNew: boolean = (startExpMode === 'new');
const expId: string = createNew ? uniqueString(8) : resumeExperimentId; const expId: string = createNew ? uniqueString(8) : resumeExperimentId;
setExperimentStartupInfo(createNew, expId); setExperimentStartupInfo(createNew, expId, basePort);
} }
async function initContainer(platformMode: string): Promise<void> { async function initContainer(platformMode: string): Promise<void> {
...@@ -93,14 +93,14 @@ if (startMode === 'resume' && experimentId.trim().length < 1) { ...@@ -93,14 +93,14 @@ if (startMode === 'resume' && experimentId.trim().length < 1) {
process.exit(1); process.exit(1);
} }
initStartupInfo(startMode, experimentId); initStartupInfo(startMode, experimentId, port);
mkDirP(getLogDir()).then(async () => { mkDirP(getLogDir()).then(async () => {
const log: Logger = getLogger(); const log: Logger = getLogger();
try { try {
await initContainer(mode); await initContainer(mode);
const restServer: NNIRestServer = component.get(NNIRestServer); const restServer: NNIRestServer = component.get(NNIRestServer);
await restServer.start(port); await restServer.start();
log.info(`Rest server listening on: ${restServer.endPoint}`); log.info(`Rest server listening on: ${restServer.endPoint}`);
} catch (err) { } catch (err) {
log.error(`${err.stack}`); log.error(`${err.stack}`);
......
...@@ -62,8 +62,8 @@ fi`; ...@@ -62,8 +62,8 @@ fi`;
export const PAI_TRIAL_COMMAND_FORMAT: string = export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} `export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3}
&& cd $NNI_SYS_DIR && sh install_nni.sh && cd $NNI_SYS_DIR && sh install_nni.sh
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{4}' --nnimanager_ip '{5}' --pai_hdfs_output_dir '{6}' && python3 -m nni_trial_tool.trial_keeper --trial_command '{4}' --nnimanager_ip '{5}' --nnimanager_port '{6}'
--pai_hdfs_host '{7}' --pai_user_name {8}`; --pai_hdfs_output_dir '{7}' --pai_hdfs_host '{8}' --pai_user_name {9}`;
export const PAI_OUTPUT_DIR_FORMAT: string = export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`; `hdfs://{0}:9000/`;
......
...@@ -19,9 +19,11 @@ ...@@ -19,9 +19,11 @@
'use strict'; 'use strict';
import * as assert from 'assert';
import { Request, Response, Router } from 'express'; import { Request, Response, Router } from 'express';
import * as bodyParser from 'body-parser'; import * as bodyParser from 'body-parser';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { getBasePort } from '../../common/experimentStartupInfo';
import { getExperimentId } from '../../common/experimentStartupInfo'; import { getExperimentId } from '../../common/experimentStartupInfo';
import { Inject } from 'typescript-ioc'; import { Inject } from 'typescript-ioc';
import { PAITrainingService } from './paiTrainingService'; import { PAITrainingService } from './paiTrainingService';
...@@ -48,10 +50,20 @@ export class PAIJobRestServer extends RestServer{ ...@@ -48,10 +50,20 @@ export class PAIJobRestServer extends RestServer{
*/ */
constructor() { constructor() {
super(); super();
this.port = PAIJobRestServer.DEFAULT_PORT; const basePort: number = getBasePort();
assert(basePort && basePort > 1024);
this.port = basePort + 1; // PAIJobRestServer.DEFAULT_PORT;
this.paiTrainingService = component.get(PAITrainingService); this.paiTrainingService = component.get(PAITrainingService);
} }
public get paiRestServerPort(): number {
if(!this.port) {
throw new Error('PAI Rest server port is undefined');
}
return this.port;
}
/** /**
* NNIRestServer's own router registration * NNIRestServer's own router registration
*/ */
......
...@@ -68,6 +68,7 @@ class PAITrainingService implements TrainingService { ...@@ -68,6 +68,7 @@ class PAITrainingService implements TrainingService {
private hdfsBaseDir: string | undefined; private hdfsBaseDir: string | undefined;
private hdfsOutputHost: string | undefined; private hdfsOutputHost: string | undefined;
private trialSequenceId: number; private trialSequenceId: number;
private paiRestServerPort?: number;
constructor() { constructor() {
this.log = getLogger(); this.log = getLogger();
...@@ -145,6 +146,11 @@ class PAITrainingService implements TrainingService { ...@@ -145,6 +146,11 @@ class PAITrainingService implements TrainingService {
throw new Error('hdfsOutputHost is not initialized'); throw new Error('hdfsOutputHost is not initialized');
} }
if(!this.paiRestServerPort) {
const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
this.paiRestServerPort = restServer.paiRestServerPort;
}
this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`); this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
const trialJobId: string = uniqueString(5); const trialJobId: string = uniqueString(5);
...@@ -200,6 +206,7 @@ class PAITrainingService implements TrainingService { ...@@ -200,6 +206,7 @@ class PAITrainingService implements TrainingService {
this.experimentId, this.experimentId,
this.paiTrialConfig.command, this.paiTrialConfig.command,
getIPV4Address(), getIPV4Address(),
this.paiRestServerPort,
hdfsOutputDir, hdfsOutputDir,
this.hdfsOutputHost, this.hdfsOutputHost,
this.paiClusterConfig.userName this.paiClusterConfig.userName
......
...@@ -24,8 +24,6 @@ API_ROOT_URL = '/api/v1/nni-pai' ...@@ -24,8 +24,6 @@ API_ROOT_URL = '/api/v1/nni-pai'
BASE_URL = 'http://{}' BASE_URL = 'http://{}'
DEFAULT_REST_PORT = 51189
HOME_DIR = os.path.join(os.environ['HOME'], 'nni') HOME_DIR = os.path.join(os.environ['HOME'], 'nni')
LOG_DIR = os.environ['NNI_OUTPUT_DIR'] LOG_DIR = os.environ['NNI_OUTPUT_DIR']
......
...@@ -24,7 +24,7 @@ import os ...@@ -24,7 +24,7 @@ import os
import re import re
import requests import requests
from .constants import BASE_URL, DEFAULT_REST_PORT from .constants import BASE_URL
from .rest_utils import rest_get, rest_post, rest_put, rest_delete from .rest_utils import rest_get, rest_post, rest_put, rest_delete
from .url_utils import gen_update_metrics_url from .url_utils import gen_update_metrics_url
...@@ -40,11 +40,10 @@ class TrialMetricsReader(): ...@@ -40,11 +40,10 @@ class TrialMetricsReader():
''' '''
Read metrics data from a trial job Read metrics data from a trial job
''' '''
def __init__(self, rest_port = DEFAULT_REST_PORT): def __init__(self):
metrics_base_dir = os.path.join(NNI_SYS_DIR, '.nni') metrics_base_dir = os.path.join(NNI_SYS_DIR, '.nni')
self.offset_filename = os.path.join(metrics_base_dir, 'metrics_offset') self.offset_filename = os.path.join(metrics_base_dir, 'metrics_offset')
self.metrics_filename = os.path.join(metrics_base_dir, 'metrics') self.metrics_filename = os.path.join(metrics_base_dir, 'metrics')
self.rest_port = rest_port
if not os.path.exists(metrics_base_dir): if not os.path.exists(metrics_base_dir):
os.makedirs(metrics_base_dir) os.makedirs(metrics_base_dir)
...@@ -107,7 +106,7 @@ class TrialMetricsReader(): ...@@ -107,7 +106,7 @@ class TrialMetricsReader():
offset = self._get_offset() offset = self._get_offset()
return self._read_all_available_records(offset) return self._read_all_available_records(offset)
def read_experiment_metrics(nnimanager_ip): def read_experiment_metrics(nnimanager_ip, nnimanager_port):
''' '''
Read metrics data for specified trial jobs Read metrics data for specified trial jobs
''' '''
...@@ -118,7 +117,7 @@ def read_experiment_metrics(nnimanager_ip): ...@@ -118,7 +117,7 @@ def read_experiment_metrics(nnimanager_ip):
result['metrics'] = reader.read_trial_metrics() result['metrics'] = reader.read_trial_metrics()
print('Result metrics is {}'.format(json.dumps(result))) print('Result metrics is {}'.format(json.dumps(result)))
if len(result['metrics']) > 0: if len(result['metrics']) > 0:
response = rest_post(gen_update_metrics_url(BASE_URL.format(nnimanager_ip), DEFAULT_REST_PORT, NNI_EXP_ID, NNI_TRIAL_JOB_ID), json.dumps(result), 10) response = rest_post(gen_update_metrics_url(BASE_URL.format(nnimanager_ip), nnimanager_port, NNI_EXP_ID, NNI_TRIAL_JOB_ID), json.dumps(result), 10)
print('Response code is {}'.format(response.status_code)) print('Response code is {}'.format(response.status_code))
except Exception: except Exception:
#TODO error logging to file #TODO error logging to file
......
...@@ -48,7 +48,7 @@ def main_loop(args): ...@@ -48,7 +48,7 @@ def main_loop(args):
while True: while True:
retCode = process.poll() retCode = process.poll()
## Read experiment metrics, to avoid missing metrics ## Read experiment metrics, to avoid missing metrics
read_experiment_metrics(args.nnimanager_ip) read_experiment_metrics(args.nnimanager_ip, args.nnimanager_port)
if retCode is not None: if retCode is not None:
print('subprocess terminated. Exit code is {}. Quit'.format(retCode)) print('subprocess terminated. Exit code is {}. Quit'.format(retCode))
...@@ -80,7 +80,8 @@ if __name__ == '__main__': ...@@ -80,7 +80,8 @@ if __name__ == '__main__':
PARSER = argparse.ArgumentParser() PARSER = argparse.ArgumentParser()
PARSER.set_defaults(func=trial_keeper_help_info) PARSER.set_defaults(func=trial_keeper_help_info)
PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process') PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager IP') PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP')
PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port')
PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of hdfs') PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of hdfs')
PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of hdfs') PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of hdfs')
PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs') PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment