Unverified Commit b1d4c129 authored by fishyds's avatar fishyds Committed by GitHub
Browse files

[PAI training service] Support running multiple PAI experiment (#348)

* Change base image from devel to runtime, to reduce docker image size

* Support running multiple experiment for PAI

* Fix a bug regarding to recuisively reference between paiRestServer and
paiTrainingService
parent 35e0832b
......@@ -26,15 +26,17 @@ import * as component from '../common/component';
class ExperimentStartupInfo {
private experimentId: string = '';
private newExperiment: boolean = true;
private basePort: number = -1;
private initialized: boolean = false;
private initTrialSequenceID: number = 0;
public setStartupInfo(newExperiment: boolean, experimentId: string): void {
public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number): void {
assert(!this.initialized);
assert(experimentId.trim().length > 0);
this.newExperiment = newExperiment;
this.experimentId = experimentId;
this.basePort = basePort;
this.initialized = true;
}
......@@ -44,6 +46,12 @@ class ExperimentStartupInfo {
return this.experimentId;
}
public getBasePort(): number {
assert(this.initialized);
return this.basePort;
}
public isNewExperiment(): boolean {
assert(this.initialized);
......@@ -66,6 +74,10 @@ function getExperimentId(): string {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getExperimentId();
}
function getBasePort(): number {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getBasePort();
}
function isNewExperiment(): boolean {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isNewExperiment();
}
......@@ -78,9 +90,9 @@ function getInitTrialSequenceId(): number {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getInitTrialSequenceId();
}
function setExperimentStartupInfo(newExperiment: boolean, experimentId: string): void {
component.get<ExperimentStartupInfo>(ExperimentStartupInfo).setStartupInfo(newExperiment, experimentId);
function setExperimentStartupInfo(newExperiment: boolean, experimentId: string, basePort: number): void {
component.get<ExperimentStartupInfo>(ExperimentStartupInfo).setStartupInfo(newExperiment, experimentId, basePort);
}
export { ExperimentStartupInfo, getExperimentId, isNewExperiment,
export { ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment,
setExperimentStartupInfo, setInitTrialSequenceId, getInitTrialSequenceId };
......@@ -19,10 +19,12 @@
'use strict';
import * as assert from 'assert';
import * as express from 'express';
import * as http from 'http';
import { Deferred } from 'ts-deferred';
import { getLogger, Logger } from './log';
import { getBasePort } from './experimentStartupInfo';
/**
* Abstraction class to create a RestServer
......@@ -39,13 +41,20 @@ export abstract class RestServer {
protected port?: number;
protected app: express.Application = express();
protected log: Logger = getLogger();
protected basePort?: number;
constructor() {
this.port = getBasePort();
assert(this.port && this.port > 1024);
}
get endPoint(): string {
// tslint:disable-next-line:no-http-string
return `http://${this.hostName}:${this.port}`;
}
public start(port?: number, hostName?: string): Promise<void> {
public start(hostName?: string): Promise<void> {
this.log.info(`RestServer start`);
if (this.startTask !== undefined) {
return this.startTask.promise;
}
......@@ -56,9 +65,8 @@ export abstract class RestServer {
if (hostName) {
this.hostName = hostName;
}
if (port) {
this.port = port;
}
this.log.info(`RestServer base port is ${this.port}`);
this.server = this.app.listen(this.port as number, this.hostName).on('listening', () => {
this.startTask.resolve();
......
......@@ -222,7 +222,7 @@ function prepareUnitTest(): void {
Container.snapshot(TrainingService);
Container.snapshot(Manager);
setExperimentStartupInfo(true, 'unittest');
setExperimentStartupInfo(true, 'unittest', 8080);
mkDirPSync(getLogDir());
const sqliteFile: string = path.join(getDefaultDatabaseDir(), 'nni.sqlite');
......
......@@ -39,10 +39,10 @@ import {
import { PAITrainingService } from './training_service/pai/paiTrainingService'
function initStartupInfo(startExpMode: string, resumeExperimentId: string) {
function initStartupInfo(startExpMode: string, resumeExperimentId: string, basePort: number) {
const createNew: boolean = (startExpMode === 'new');
const expId: string = createNew ? uniqueString(8) : resumeExperimentId;
setExperimentStartupInfo(createNew, expId);
setExperimentStartupInfo(createNew, expId, basePort);
}
async function initContainer(platformMode: string): Promise<void> {
......@@ -93,14 +93,14 @@ if (startMode === 'resume' && experimentId.trim().length < 1) {
process.exit(1);
}
initStartupInfo(startMode, experimentId);
initStartupInfo(startMode, experimentId, port);
mkDirP(getLogDir()).then(async () => {
const log: Logger = getLogger();
try {
await initContainer(mode);
const restServer: NNIRestServer = component.get(NNIRestServer);
await restServer.start(port);
await restServer.start();
log.info(`Rest server listening on: ${restServer.endPoint}`);
} catch (err) {
log.error(`${err.stack}`);
......
......@@ -62,8 +62,8 @@ fi`;
export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3}
&& cd $NNI_SYS_DIR && sh install_nni.sh
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{4}' --nnimanager_ip '{5}' --pai_hdfs_output_dir '{6}'
--pai_hdfs_host '{7}' --pai_user_name {8}`;
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{4}' --nnimanager_ip '{5}' --nnimanager_port '{6}'
--pai_hdfs_output_dir '{7}' --pai_hdfs_host '{8}' --pai_user_name {9}`;
export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`;
......
......@@ -19,9 +19,11 @@
'use strict';
import * as assert from 'assert';
import { Request, Response, Router } from 'express';
import * as bodyParser from 'body-parser';
import * as component from '../../common/component';
import { getBasePort } from '../../common/experimentStartupInfo';
import { getExperimentId } from '../../common/experimentStartupInfo';
import { Inject } from 'typescript-ioc';
import { PAITrainingService } from './paiTrainingService';
......@@ -48,10 +50,20 @@ export class PAIJobRestServer extends RestServer{
*/
constructor() {
super();
this.port = PAIJobRestServer.DEFAULT_PORT;
const basePort: number = getBasePort();
assert(basePort && basePort > 1024);
this.port = basePort + 1; // PAIJobRestServer.DEFAULT_PORT;
this.paiTrainingService = component.get(PAITrainingService);
}
public get paiRestServerPort(): number {
if(!this.port) {
throw new Error('PAI Rest server port is undefined');
}
return this.port;
}
/**
* NNIRestServer's own router registration
*/
......
......@@ -68,6 +68,7 @@ class PAITrainingService implements TrainingService {
private hdfsBaseDir: string | undefined;
private hdfsOutputHost: string | undefined;
private trialSequenceId: number;
private paiRestServerPort?: number;
constructor() {
this.log = getLogger();
......@@ -145,6 +146,11 @@ class PAITrainingService implements TrainingService {
throw new Error('hdfsOutputHost is not initialized');
}
if(!this.paiRestServerPort) {
const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
this.paiRestServerPort = restServer.paiRestServerPort;
}
this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
const trialJobId: string = uniqueString(5);
......@@ -200,6 +206,7 @@ class PAITrainingService implements TrainingService {
this.experimentId,
this.paiTrialConfig.command,
getIPV4Address(),
this.paiRestServerPort,
hdfsOutputDir,
this.hdfsOutputHost,
this.paiClusterConfig.userName
......
......@@ -24,8 +24,6 @@ API_ROOT_URL = '/api/v1/nni-pai'
BASE_URL = 'http://{}'
DEFAULT_REST_PORT = 51189
HOME_DIR = os.path.join(os.environ['HOME'], 'nni')
LOG_DIR = os.environ['NNI_OUTPUT_DIR']
......
......@@ -24,7 +24,7 @@ import os
import re
import requests
from .constants import BASE_URL, DEFAULT_REST_PORT
from .constants import BASE_URL
from .rest_utils import rest_get, rest_post, rest_put, rest_delete
from .url_utils import gen_update_metrics_url
......@@ -40,11 +40,10 @@ class TrialMetricsReader():
'''
Read metrics data from a trial job
'''
def __init__(self, rest_port = DEFAULT_REST_PORT):
def __init__(self):
metrics_base_dir = os.path.join(NNI_SYS_DIR, '.nni')
self.offset_filename = os.path.join(metrics_base_dir, 'metrics_offset')
self.metrics_filename = os.path.join(metrics_base_dir, 'metrics')
self.rest_port = rest_port
if not os.path.exists(metrics_base_dir):
os.makedirs(metrics_base_dir)
......@@ -107,7 +106,7 @@ class TrialMetricsReader():
offset = self._get_offset()
return self._read_all_available_records(offset)
def read_experiment_metrics(nnimanager_ip):
def read_experiment_metrics(nnimanager_ip, nnimanager_port):
'''
Read metrics data for specified trial jobs
'''
......@@ -118,7 +117,7 @@ def read_experiment_metrics(nnimanager_ip):
result['metrics'] = reader.read_trial_metrics()
print('Result metrics is {}'.format(json.dumps(result)))
if len(result['metrics']) > 0:
response = rest_post(gen_update_metrics_url(BASE_URL.format(nnimanager_ip), DEFAULT_REST_PORT, NNI_EXP_ID, NNI_TRIAL_JOB_ID), json.dumps(result), 10)
response = rest_post(gen_update_metrics_url(BASE_URL.format(nnimanager_ip), nnimanager_port, NNI_EXP_ID, NNI_TRIAL_JOB_ID), json.dumps(result), 10)
print('Response code is {}'.format(response.status_code))
except Exception:
#TODO error logging to file
......
......@@ -48,7 +48,7 @@ def main_loop(args):
while True:
retCode = process.poll()
## Read experiment metrics, to avoid missing metrics
read_experiment_metrics(args.nnimanager_ip)
read_experiment_metrics(args.nnimanager_ip, args.nnimanager_port)
if retCode is not None:
print('subprocess terminated. Exit code is {}. Quit'.format(retCode))
......@@ -80,7 +80,8 @@ if __name__ == '__main__':
PARSER = argparse.ArgumentParser()
PARSER.set_defaults(func=trial_keeper_help_info)
PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager IP')
PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP')
PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port')
PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of hdfs')
PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of hdfs')
PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment