You need to sign in or sign up before continuing.
Commit ba8dccd6 authored by suiguoxin's avatar suiguoxin
Browse files

Merge branch 'master' of https://github.com/microsoft/nni

parents 56a1575b 150ee83a
# How to use ga_customer_tuner? # How to use ga_customer_tuner?
This tuner is a customized tuner which only suitable for trial whose code path is "~/nni/examples/trials/ga_squad", This tuner is a customized tuner which only suitable for trial whose code path is "~/nni/examples/trials/ga_squad",
type `cd ~/nni/examples/trials/ga_squad` and check readme.md to get more information for ga_squad trial. type `cd ~/nni/examples/trials/ga_squad` and check readme.md to get more information for ga_squad trial.
# config # config
If you want to use ga_customer_tuner in your experiment, you could set config file as following format: If you want to use ga_customer_tuner in your experiment, you could set config file as following format:
``` ```
......
# How to use ga_customer_tuner? # How to use ga_customer_tuner?
This tuner is a customized tuner which only suitable for trial whose code path is "~/nni/examples/trials/ga_squad", This tuner is a customized tuner which only suitable for trial whose code path is "~/nni/examples/trials/ga_squad",
type `cd ~/nni/examples/trials/ga_squad` and check readme.md to get more information for ga_squad trial. type `cd ~/nni/examples/trials/ga_squad` and check readme.md to get more information for ga_squad trial.
# config # config
If you want to use ga_customer_tuner in your experiment, you could set config file as following format: If you want to use ga_customer_tuner in your experiment, you could set config file as following format:
``` ```
......
...@@ -29,7 +29,7 @@ import { getBasePort } from './experimentStartupInfo'; ...@@ -29,7 +29,7 @@ import { getBasePort } from './experimentStartupInfo';
/** /**
* Abstraction class to create a RestServer * Abstraction class to create a RestServer
* The module who wants to use a RestServer could <b>extends</b> this abstract class * The module who wants to use a RestServer could <b>extends</b> this abstract class
* And implement its own registerRestHandler() function to register routers * And implement its own registerRestHandler() function to register routers
*/ */
export abstract class RestServer { export abstract class RestServer {
...@@ -43,7 +43,7 @@ export abstract class RestServer { ...@@ -43,7 +43,7 @@ export abstract class RestServer {
protected app: express.Application = express(); protected app: express.Application = express();
protected log: Logger = getLogger(); protected log: Logger = getLogger();
protected basePort?: number; protected basePort?: number;
constructor() { constructor() {
this.port = getBasePort(); this.port = getBasePort();
assert(this.port && this.port > 1024); assert(this.port && this.port > 1024);
...@@ -91,9 +91,9 @@ export abstract class RestServer { ...@@ -91,9 +91,9 @@ export abstract class RestServer {
} else { } else {
this.startTask.promise.then( this.startTask.promise.then(
() => { // Started () => { // Started
//Stops the server from accepting new connections and keeps existing connections. //Stops the server from accepting new connections and keeps existing connections.
//This function is asynchronous, the server is finally closed when all connections //This function is asynchronous, the server is finally closed when all connections
//are ended and the server emits a 'close' event. //are ended and the server emits a 'close' event.
//Refer https://nodejs.org/docs/latest/api/net.html#net_server_close_callback //Refer https://nodejs.org/docs/latest/api/net.html#net_server_close_callback
this.server.close().on('close', () => { this.server.close().on('close', () => {
this.log.info('Rest server stopped.'); this.log.info('Rest server stopped.');
......
...@@ -91,6 +91,7 @@ interface TrialJobMetric { ...@@ -91,6 +91,7 @@ interface TrialJobMetric {
* define TrainingServiceError * define TrainingServiceError
*/ */
class TrainingServiceError extends Error { class TrainingServiceError extends Error {
private errCode: number; private errCode: number;
constructor(errorCode: number, errorMessage: string) { constructor(errorCode: number, errorMessage: string) {
...@@ -136,5 +137,3 @@ export { ...@@ -136,5 +137,3 @@ export {
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters, TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
HostJobApplicationForm, JobApplicationForm, JobType, NNIManagerIpConfig HostJobApplicationForm, JobApplicationForm, JobType, NNIManagerIpConfig
}; };
...@@ -167,7 +167,7 @@ function getCmdPy(): string { ...@@ -167,7 +167,7 @@ function getCmdPy(): string {
} }
/** /**
* Generate command line to start automl algorithm(s), * Generate command line to start automl algorithm(s),
* either start advisor or start a process which runs tuner and assessor * either start advisor or start a process which runs tuner and assessor
* @param tuner : For builtin tuner: * @param tuner : For builtin tuner:
* { * {
...@@ -361,11 +361,11 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number): ...@@ -361,11 +361,11 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number):
if(process.platform === "win32") { if(process.platform === "win32") {
cmd = `powershell "Get-ChildItem -Path ${directory} -Recurse -File | Measure-Object | %{$_.Count}"` cmd = `powershell "Get-ChildItem -Path ${directory} -Recurse -File | Measure-Object | %{$_.Count}"`
} else { } else {
cmd = `find ${directory} -type f | wc -l`; cmd = `find ${directory} -type f | wc -l`;
} }
cpp.exec(cmd).then((result) => { cpp.exec(cmd).then((result) => {
if(result.stdout && parseInt(result.stdout)) { if(result.stdout && parseInt(result.stdout)) {
fileCount = parseInt(result.stdout); fileCount = parseInt(result.stdout);
} }
deferred.resolve(fileCount); deferred.resolve(fileCount);
}); });
...@@ -374,6 +374,40 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number): ...@@ -374,6 +374,40 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number):
}); });
} }
function validateFileName(fileName: string): boolean {
let pattern: string = '^[a-z0-9A-Z\.-_]+$';
const validateResult = fileName.match(pattern);
if(validateResult) {
return true;
}
return false;
}
async function validateFileNameRecursively(directory: string): Promise<boolean> {
if(!fs.existsSync(directory)) {
throw Error(`Direcotory ${directory} doesn't exist`);
}
const fileNameArray: string[] = fs.readdirSync(directory);
let result = true;
for(var name of fileNameArray){
const fullFilePath: string = path.join(directory, name);
try {
// validate file names and directory names
result = validateFileName(name);
if (fs.lstatSync(fullFilePath).isDirectory()) {
result = result && await validateFileNameRecursively(fullFilePath);
}
if(!result) {
return Promise.reject(new Error(`file name in ${fullFilePath} is not valid!`));
}
} catch(error) {
return Promise.reject(error);
}
}
return Promise.resolve(result);
}
/** /**
* get the version of current package * get the version of current package
*/ */
...@@ -385,7 +419,7 @@ async function getVersion(): Promise<string> { ...@@ -385,7 +419,7 @@ async function getVersion(): Promise<string> {
deferred.reject(error); deferred.reject(error);
}); });
return deferred.promise; return deferred.promise;
} }
/** /**
* run command as ChildProcess * run command as ChildProcess
...@@ -437,7 +471,7 @@ async function isAlive(pid:any): Promise<boolean> { ...@@ -437,7 +471,7 @@ async function isAlive(pid:any): Promise<boolean> {
} }
/** /**
* kill process * kill process
*/ */
async function killPid(pid:any): Promise<void> { async function killPid(pid:any): Promise<void> {
let deferred : Deferred<void> = new Deferred<void>(); let deferred : Deferred<void> = new Deferred<void>();
...@@ -466,7 +500,7 @@ function getNewLine(): string { ...@@ -466,7 +500,7 @@ function getNewLine(): string {
/** /**
* Use '/' to join path instead of '\' for all kinds of platform * Use '/' to join path instead of '\' for all kinds of platform
* @param path * @param path
*/ */
function unixPathJoin(...paths: any[]): string { function unixPathJoin(...paths: any[]): string {
const dir: string = paths.filter((path: any) => path !== '').join('/'); const dir: string = paths.filter((path: any) => path !== '').join('/');
...@@ -474,6 +508,6 @@ function unixPathJoin(...paths: any[]): string { ...@@ -474,6 +508,6 @@ function unixPathJoin(...paths: any[]): string {
return dir; return dir;
} }
export {countFilesRecursively, getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, export {countFilesRecursively, validateFileNameRecursively, getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin,
mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine }; mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine };
{ {
"kind": "CustomResourceDefinition", "kind": "CustomResourceDefinition",
"spec": { "spec": {
"scope": "Namespaced", "scope": "Namespaced",
"version": "v1", "version": "v1",
"group": "frameworkcontroller.microsoft.com", "group": "frameworkcontroller.microsoft.com",
"names": { "names": {
"kind": "Framework", "kind": "Framework",
"plural": "frameworks", "plural": "frameworks",
"singular": "framework" "singular": "framework"
} }
}, },
"apiVersion": "apiextensions.k8s.io/v1beta1", "apiVersion": "apiextensions.k8s.io/v1beta1",
"metadata": { "metadata": {
"name": "frameworks.frameworkcontroller.microsoft.com" "name": "frameworks.frameworkcontroller.microsoft.com"
} }
......
{ {
"kind": "CustomResourceDefinition", "kind": "CustomResourceDefinition",
"spec": { "spec": {
"scope": "Namespaced", "scope": "Namespaced",
"version": "v1alpha2", "version": "v1alpha2",
"group": "kubeflow.org", "group": "kubeflow.org",
"names": { "names": {
"kind": "PyTorchJob", "kind": "PyTorchJob",
"plural": "pytorchjobs", "plural": "pytorchjobs",
"singular": "pytorchjob" "singular": "pytorchjob"
} }
}, },
"apiVersion": "apiextensions.k8s.io/v1beta1", "apiVersion": "apiextensions.k8s.io/v1beta1",
"metadata": { "metadata": {
"name": "pytorchjobs.kubeflow.org" "name": "pytorchjobs.kubeflow.org"
} }
......
{ {
"kind": "CustomResourceDefinition", "kind": "CustomResourceDefinition",
"spec": { "spec": {
"scope": "Namespaced", "scope": "Namespaced",
"version": "v1beta1", "version": "v1beta1",
"group": "kubeflow.org", "group": "kubeflow.org",
"names": { "names": {
"kind": "PyTorchJob", "kind": "PyTorchJob",
"plural": "pytorchjobs", "plural": "pytorchjobs",
"singular": "pytorchjob" "singular": "pytorchjob"
} }
}, },
"apiVersion": "apiextensions.k8s.io/v1beta1", "apiVersion": "apiextensions.k8s.io/v1beta1",
"metadata": { "metadata": {
"name": "pytorchjobs.kubeflow.org" "name": "pytorchjobs.kubeflow.org"
} }
......
{ {
"kind": "CustomResourceDefinition", "kind": "CustomResourceDefinition",
"spec": { "spec": {
"scope": "Namespaced", "scope": "Namespaced",
"version": "v1alpha2", "version": "v1alpha2",
"group": "kubeflow.org", "group": "kubeflow.org",
"names": { "names": {
"kind": "TFJob", "kind": "TFJob",
"plural": "tfjobs", "plural": "tfjobs",
"singular": "tfjob" "singular": "tfjob"
} }
}, },
"apiVersion": "apiextensions.k8s.io/v1beta1", "apiVersion": "apiextensions.k8s.io/v1beta1",
"metadata": { "metadata": {
"name": "tfjobs.kubeflow.org" "name": "tfjobs.kubeflow.org"
} }
......
{ {
"kind": "CustomResourceDefinition", "kind": "CustomResourceDefinition",
"spec": { "spec": {
"scope": "Namespaced", "scope": "Namespaced",
"version": "v1beta1", "version": "v1beta1",
"group": "kubeflow.org", "group": "kubeflow.org",
"names": { "names": {
"kind": "TFJob", "kind": "TFJob",
"plural": "tfjobs", "plural": "tfjobs",
"singular": "tfjob" "singular": "tfjob"
} }
}, },
"apiVersion": "apiextensions.k8s.io/v1beta1", "apiVersion": "apiextensions.k8s.io/v1beta1",
"metadata": { "metadata": {
"name": "tfjobs.kubeflow.org" "name": "tfjobs.kubeflow.org"
} }
......
...@@ -159,7 +159,7 @@ class NNIManager implements Manager { ...@@ -159,7 +159,7 @@ class NNIManager implements Manager {
if (expParams.logCollection !== undefined) { if (expParams.logCollection !== undefined) {
this.trainingService.setClusterMetadata('log_collection', expParams.logCollection.toString()); this.trainingService.setClusterMetadata('log_collection', expParams.logCollection.toString());
} }
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.advisor, const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.advisor,
expParams.multiPhase, expParams.multiThread); expParams.multiPhase, expParams.multiThread);
this.log.debug(`dispatcher command: ${dispatcherCommand}`); this.log.debug(`dispatcher command: ${dispatcherCommand}`);
...@@ -493,7 +493,7 @@ class NNIManager implements Manager { ...@@ -493,7 +493,7 @@ class NNIManager implements Manager {
// If trialConcurrency does not change, requestTrialNum equals finishedTrialJobNum. // If trialConcurrency does not change, requestTrialNum equals finishedTrialJobNum.
// If trialConcurrency changes, for example, trialConcurrency increases by 2 (trialConcurrencyChange=2), then // If trialConcurrency changes, for example, trialConcurrency increases by 2 (trialConcurrencyChange=2), then
// requestTrialNum equals 2 + finishedTrialJobNum and trialConcurrencyChange becomes 0. // requestTrialNum equals 2 + finishedTrialJobNum and trialConcurrencyChange becomes 0.
// If trialConcurrency changes, for example, trialConcurrency decreases by 4 (trialConcurrencyChange=-4) and // If trialConcurrency changes, for example, trialConcurrency decreases by 4 (trialConcurrencyChange=-4) and
// finishedTrialJobNum is 2, then requestTrialNum becomes -2. No trial will be requested from tuner, // finishedTrialJobNum is 2, then requestTrialNum becomes -2. No trial will be requested from tuner,
// and trialConcurrencyChange becomes -2. // and trialConcurrencyChange becomes -2.
const requestTrialNum: number = this.trialConcurrencyChange + finishedTrialJobNum; const requestTrialNum: number = this.trialConcurrencyChange + finishedTrialJobNum;
......
...@@ -46,11 +46,11 @@ function runProcess(): Promise<Error | null> { ...@@ -46,11 +46,11 @@ function runProcess(): Promise<Error | null> {
if (code !== 0) { if (code !== 0) {
deferred.resolve(new Error(`return code: ${code}`)); deferred.resolve(new Error(`return code: ${code}`));
} else { } else {
let str = proc.stdout.read().toString(); let str = proc.stdout.read().toString();
if(str.search("\r\n")!=-1){ if(str.search("\r\n")!=-1){
sentCommands = str.split("\r\n"); sentCommands = str.split("\r\n");
} }
else{ else{
sentCommands = str.split('\n'); sentCommands = str.split('\n');
} }
deferred.resolve(null); deferred.resolve(null);
...@@ -76,7 +76,7 @@ function runProcess(): Promise<Error | null> { ...@@ -76,7 +76,7 @@ function runProcess(): Promise<Error | null> {
commandTooLong = error; commandTooLong = error;
} }
// Command #4: FE is not tuner/assessor command, test the exception type of send non-valid command // Command #4: FE is not tuner/assessor command, test the exception type of send non-valid command
try { try {
dispatcher.sendCommand('FE', '1'); dispatcher.sendCommand('FE', '1');
} catch (error) { } catch (error) {
......
...@@ -59,10 +59,10 @@ class MockedTrainingService extends TrainingService { ...@@ -59,10 +59,10 @@ class MockedTrainingService extends TrainingService {
}, },
sequenceId: 0 sequenceId: 0
}; };
public listTrialJobs(): Promise<TrialJobDetail[]> { public listTrialJobs(): Promise<TrialJobDetail[]> {
const deferred = new Deferred<TrialJobDetail[]>(); const deferred = new Deferred<TrialJobDetail[]>();
deferred.resolve([this.jobDetail1, this.jobDetail2]); deferred.resolve([this.jobDetail1, this.jobDetail2]);
return deferred.promise; return deferred.promise;
} }
......
...@@ -104,7 +104,7 @@ describe('Unit test for nnimanager', function () { ...@@ -104,7 +104,7 @@ describe('Unit test for nnimanager', function () {
maxSequenceId: 0, maxSequenceId: 0,
revision: 0 revision: 0
} }
before(async () => { before(async () => {
await initContainer(); await initContainer();
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
"@types/chai-as-promised": "^7.1.0", "@types/chai-as-promised": "^7.1.0",
"@types/express": "^4.16.0", "@types/express": "^4.16.0",
"@types/glob": "^7.1.1", "@types/glob": "^7.1.1",
"@types/js-base64": "^2.3.1",
"@types/mocha": "^5.2.5", "@types/mocha": "^5.2.5",
"@types/node": "10.12.18", "@types/node": "10.12.18",
"@types/request": "^2.47.1", "@types/request": "^2.47.1",
......
...@@ -31,7 +31,7 @@ import { createRestHandler } from './restHandler'; ...@@ -31,7 +31,7 @@ import { createRestHandler } from './restHandler';
* NNI Main rest server, provides rest API to support * NNI Main rest server, provides rest API to support
* # nnictl CLI tool * # nnictl CLI tool
* # NNI WebUI * # NNI WebUI
* *
*/ */
@component.Singleton @component.Singleton
export class NNIRestServer extends RestServer { export class NNIRestServer extends RestServer {
......
...@@ -146,7 +146,7 @@ class NNIRestHandler { ...@@ -146,7 +146,7 @@ class NNIRestHandler {
}); });
}); });
} }
private importData(router: Router): void { private importData(router: Router): void {
router.post('/experiment/import-data', (req: Request, res: Response) => { router.post('/experiment/import-data', (req: Request, res: Response) => {
this.nniManager.importData(JSON.stringify(req.body)).then(() => { this.nniManager.importData(JSON.stringify(req.body)).then(() => {
......
...@@ -133,7 +133,7 @@ export namespace ValidationSchemas { ...@@ -133,7 +133,7 @@ export namespace ValidationSchemas {
}) })
}), }),
nni_manager_ip: joi.object({ nni_manager_ip: joi.object({
nniManagerIp: joi.string().min(1) nniManagerIp: joi.string().min(1)
}) })
} }
}; };
......
...@@ -20,22 +20,24 @@ ...@@ -20,22 +20,24 @@
'use strict'; 'use strict';
import * as assert from 'assert'; import * as assert from 'assert';
import { Request, Response, Router } from 'express'; // tslint:disable-next-line:no-implicit-dependencies
import * as bodyParser from 'body-parser'; import * as bodyParser from 'body-parser';
import { Request, Response, Router } from 'express';
import * as fs from 'fs';
import * as path from 'path';
import { Writable } from 'stream';
import { String } from 'typescript-string-operations';
import * as component from '../../common/component'; import * as component from '../../common/component';
import * as fs from 'fs'
import * as path from 'path'
import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo'; import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
import { RestServer } from '../../common/restServer' import { RestServer } from '../../common/restServer';
import { getLogDir } from '../../common/utils'; import { getLogDir } from '../../common/utils';
import { Writable } from 'stream';
/** /**
* Cluster Job Training service Rest server, provides rest API to support Cluster job metrics update * Cluster Job Training service Rest server, provides rest API to support Cluster job metrics update
* *
*/ */
@component.Singleton @component.Singleton
export abstract class ClusterJobRestServer extends RestServer{ export abstract class ClusterJobRestServer extends RestServer {
private readonly API_ROOT_URL: string = '/api/v1/nni-pai'; private readonly API_ROOT_URL: string = '/api/v1/nni-pai';
private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?<metrics>.*?)'`; private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?<metrics>.*?)'`;
...@@ -51,22 +53,23 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -51,22 +53,23 @@ export abstract class ClusterJobRestServer extends RestServer{
constructor() { constructor() {
super(); super();
const basePort: number = getBasePort(); const basePort: number = getBasePort();
assert(basePort && basePort > 1024); assert(basePort !== undefined && basePort > 1024);
this.port = basePort + 1; this.port = basePort + 1;
} }
public get clusterRestServerPort(): number { public get clusterRestServerPort(): number {
if(!this.port) { if (this.port === undefined) {
throw new Error('PAI Rest server port is undefined'); throw new Error('PAI Rest server port is undefined');
} }
return this.port; return this.port;
} }
public get getErrorMessage(): string | undefined{ public get getErrorMessage(): string | undefined {
return this.errorMessage; return this.errorMessage;
} }
public set setEnableVersionCheck(versionCheck: boolean) { public set setEnableVersionCheck(versionCheck: boolean) {
this.enableVersionCheck = versionCheck; this.enableVersionCheck = versionCheck;
} }
...@@ -79,11 +82,15 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -79,11 +82,15 @@ export abstract class ClusterJobRestServer extends RestServer{
this.app.use(this.API_ROOT_URL, this.createRestHandler()); this.app.use(this.API_ROOT_URL, this.createRestHandler());
} }
// Abstract method to handle trial metrics data
// tslint:disable-next-line:no-any
protected abstract handleTrialMetrics(jobId : string, trialMetrics : any[]) : void;
// tslint:disable: no-unsafe-any no-any
private createRestHandler() : Router { private createRestHandler() : Router {
const router: Router = Router(); const router: Router = Router();
// tslint:disable-next-line:typedef router.use((req: Request, res: Response, next: any) => {
router.use((req: Request, res: Response, next) => {
this.log.info(`${req.method}: ${req.url}: body:\n${JSON.stringify(req.body, undefined, 4)}`); this.log.info(`${req.method}: ${req.url}: body:\n${JSON.stringify(req.body, undefined, 4)}`);
res.setHeader('Content-Type', 'application/json'); res.setHeader('Content-Type', 'application/json');
next(); next();
...@@ -92,7 +99,7 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -92,7 +99,7 @@ export abstract class ClusterJobRestServer extends RestServer{
router.post(`/version/${this.expId}/:trialId`, (req: Request, res: Response) => { router.post(`/version/${this.expId}/:trialId`, (req: Request, res: Response) => {
if (this.enableVersionCheck) { if (this.enableVersionCheck) {
try { try {
const checkResultSuccess: boolean = req.body.tag === 'VCSuccess'? true: false; const checkResultSuccess: boolean = req.body.tag === 'VCSuccess' ? true : false;
if (this.versionCheckSuccess !== undefined && this.versionCheckSuccess !== checkResultSuccess) { if (this.versionCheckSuccess !== undefined && this.versionCheckSuccess !== checkResultSuccess) {
this.errorMessage = 'Version check error, version check result is inconsistent!'; this.errorMessage = 'Version check error, version check result is inconsistent!';
this.log.error(this.errorMessage); this.log.error(this.errorMessage);
...@@ -103,7 +110,7 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -103,7 +110,7 @@ export abstract class ClusterJobRestServer extends RestServer{
this.versionCheckSuccess = false; this.versionCheckSuccess = false;
this.errorMessage = req.body.msg; this.errorMessage = req.body.msg;
} }
} catch(err) { } catch (err) {
this.log.error(`json parse metrics error: ${err}`); this.log.error(`json parse metrics error: ${err}`);
res.status(500); res.status(500);
res.send(err.message); res.send(err.message);
...@@ -122,8 +129,7 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -122,8 +129,7 @@ export abstract class ClusterJobRestServer extends RestServer{
this.handleTrialMetrics(req.body.jobId, req.body.metrics); this.handleTrialMetrics(req.body.jobId, req.body.metrics);
res.send(); res.send();
} } catch (err) {
catch(err) {
this.log.error(`json parse metrics error: ${err}`); this.log.error(`json parse metrics error: ${err}`);
res.status(500); res.status(500);
res.send(err.message); res.send(err.message);
...@@ -131,35 +137,37 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -131,35 +137,37 @@ export abstract class ClusterJobRestServer extends RestServer{
}); });
router.post(`/stdout/${this.expId}/:trialId`, (req: Request, res: Response) => { router.post(`/stdout/${this.expId}/:trialId`, (req: Request, res: Response) => {
if(this.enableVersionCheck && !this.versionCheckSuccess && !this.errorMessage) { if (this.enableVersionCheck && (this.versionCheckSuccess === undefined || !this.versionCheckSuccess)
this.errorMessage = `Version check failed, didn't get version check response from trialKeeper, please check your NNI version in ` && this.errorMessage === undefined) {
+ `NNIManager and TrialKeeper!` this.errorMessage = `Version check failed, didn't get version check response from trialKeeper,`
+ ` please check your NNI version in NNIManager and TrialKeeper!`;
} }
const trialLogPath: string = path.join(getLogDir(), `trial_${req.params.trialId}.log`); const trialLogPath: string = path.join(getLogDir(), `trial_${req.params.trialId}.log`);
try { try {
let skipLogging: boolean = false; let skipLogging: boolean = false;
if(req.body.tag === 'trial' && req.body.msg !== undefined) { if (req.body.tag === 'trial' && req.body.msg !== undefined) {
const metricsContent = req.body.msg.match(this.NNI_METRICS_PATTERN); const metricsContent: any = req.body.msg.match(this.NNI_METRICS_PATTERN);
if(metricsContent && metricsContent.groups) { if (metricsContent && metricsContent.groups) {
this.handleTrialMetrics(req.params.trialId, [metricsContent.groups['metrics']]); const key: string = 'metrics';
this.handleTrialMetrics(req.params.trialId, [metricsContent.groups[key]]);
skipLogging = true; skipLogging = true;
} }
} }
if(!skipLogging){ if (!skipLogging) {
// Construct write stream to write remote trial's log into local file // Construct write stream to write remote trial's log into local file
// tslint:disable-next-line:non-literal-fs-path
const writeStream: Writable = fs.createWriteStream(trialLogPath, { const writeStream: Writable = fs.createWriteStream(trialLogPath, {
flags: 'a+', flags: 'a+',
encoding: 'utf8', encoding: 'utf8',
autoClose: true autoClose: true
}); });
writeStream.write(req.body.msg + '\n'); writeStream.write(String.Format('{0}\n', req.body.msg));
writeStream.end(); writeStream.end();
} }
res.send(); res.send();
} } catch (err) {
catch(err) {
this.log.error(`json parse stdout data error: ${err}`); this.log.error(`json parse stdout data error: ${err}`);
res.status(500); res.status(500);
res.send(err.message); res.send(err.message);
...@@ -168,7 +176,5 @@ export abstract class ClusterJobRestServer extends RestServer{ ...@@ -168,7 +176,5 @@ export abstract class ClusterJobRestServer extends RestServer{
return router; return router;
} }
// tslint:enable: no-unsafe-any no-any
/** Abstract method to handle trial metrics data */ }
protected abstract handleTrialMetrics(jobId : string, trialMetrics : any[]) : void;
}
\ No newline at end of file
...@@ -19,12 +19,12 @@ ...@@ -19,12 +19,12 @@
'use strict'; 'use strict';
export const CONTAINER_INSTALL_NNI_SHELL_FORMAT: string = export const CONTAINER_INSTALL_NNI_SHELL_FORMAT: string =
`#!/bin/bash `#!/bin/bash
if python3 -c 'import nni' > /dev/null 2>&1; then if python3 -c 'import nni' > /dev/null 2>&1; then
# nni module is already installed, skip # nni module is already installed, skip
return return
else else
# Install nni # Install nni
python3 -m pip install --user --upgrade nni python3 -m pip install --user --upgrade nni
fi`; fi`;
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment