Commit ba8dccd6 authored by suiguoxin's avatar suiguoxin
Browse files

Merge branch 'master' of https://github.com/microsoft/nni

parents 56a1575b 150ee83a
......@@ -91,6 +91,7 @@ interface TrialJobMetric {
* define TrainingServiceError
*/
class TrainingServiceError extends Error {
private errCode: number;
constructor(errorCode: number, errorMessage: string) {
......@@ -136,5 +137,3 @@ export {
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
HostJobApplicationForm, JobApplicationForm, JobType, NNIManagerIpConfig
};
......@@ -374,6 +374,40 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number):
});
}
/**
 * Check whether a file or directory name consists solely of allowed characters.
 *
 * Allowed characters are ASCII letters, digits, dot (`.`), underscore (`_`)
 * and hyphen (`-`). The name must contain at least one character.
 *
 * @param fileName a single path component (not a full path)
 * @returns true when every character of `fileName` is allowed, false otherwise
 */
function validateFileName(fileName: string): boolean {
    // The hyphen is placed last inside the character class so it is a literal
    // '-' and not a range operator. Written as `.-_` it would form the range
    // U+002E..U+005F, which accepts characters such as '/', ';', '<', '>' and
    // '?' while rejecting '-' itself.
    const pattern: RegExp = /^[a-zA-Z0-9._-]+$/;

    return pattern.test(fileName);
}
/**
 * Validate the name of every entry under a directory, descending into
 * sub-directories.
 *
 * Each entry name (file or directory) is checked with `validateFileName`.
 * The walk stops at the first invalid name.
 *
 * @param directory path of the directory to walk
 * @returns a promise resolving to true when every entry name is valid
 *          (an empty directory resolves to true)
 * @throws (as a rejected promise) Error when `directory` does not exist,
 *         when an entry name is invalid, or when a file-system call fails
 */
async function validateFileNameRecursively(directory: string): Promise<boolean> {
    if (!fs.existsSync(directory)) {
        throw new Error(`Directory ${directory} doesn't exist`);
    }

    for (const name of fs.readdirSync(directory)) {
        const fullFilePath: string = path.join(directory, name);

        // Validate the entry's own name first; this covers files and directories alike.
        if (!validateFileName(name)) {
            throw new Error(`file name in ${fullFilePath} is not valid!`);
        }

        // lstat does not follow symbolic links, so only real directories are descended into.
        if (fs.lstatSync(fullFilePath).isDirectory()) {
            // A nested failure rejects with its own, more specific error and propagates as-is.
            await validateFileNameRecursively(fullFilePath);
        }
    }

    return true;
}
/**
* get the version of current package
*/
......@@ -474,6 +508,6 @@ function unixPathJoin(...paths: any[]): string {
return dir;
}
export {countFilesRecursively, getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir,
export {countFilesRecursively, validateFileNameRecursively, getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin,
mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine };
......@@ -33,6 +33,7 @@
"@types/chai-as-promised": "^7.1.0",
"@types/express": "^4.16.0",
"@types/glob": "^7.1.1",
"@types/js-base64": "^2.3.1",
"@types/mocha": "^5.2.5",
"@types/node": "10.12.18",
"@types/request": "^2.47.1",
......
......@@ -20,22 +20,24 @@
'use strict';
import * as assert from 'assert';
import { Request, Response, Router } from 'express';
// tslint:disable-next-line:no-implicit-dependencies
import * as bodyParser from 'body-parser';
import { Request, Response, Router } from 'express';
import * as fs from 'fs';
import * as path from 'path';
import { Writable } from 'stream';
import { String } from 'typescript-string-operations';
import * as component from '../../common/component';
import * as fs from 'fs'
import * as path from 'path'
import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
import { RestServer } from '../../common/restServer'
import { RestServer } from '../../common/restServer';
import { getLogDir } from '../../common/utils';
import { Writable } from 'stream';
/**
* Cluster Job Training service Rest server, provides rest API to support Cluster job metrics update
*
*/
@component.Singleton
export abstract class ClusterJobRestServer extends RestServer{
export abstract class ClusterJobRestServer extends RestServer {
private readonly API_ROOT_URL: string = '/api/v1/nni-pai';
private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?<metrics>.*?)'`;
......@@ -51,19 +53,20 @@ export abstract class ClusterJobRestServer extends RestServer{
constructor() {
super();
const basePort: number = getBasePort();
assert(basePort && basePort > 1024);
assert(basePort !== undefined && basePort > 1024);
this.port = basePort + 1;
}
public get clusterRestServerPort(): number {
if(!this.port) {
if (this.port === undefined) {
throw new Error('PAI Rest server port is undefined');
}
return this.port;
}
public get getErrorMessage(): string | undefined{
public get getErrorMessage(): string | undefined {
return this.errorMessage;
}
......@@ -79,11 +82,15 @@ export abstract class ClusterJobRestServer extends RestServer{
this.app.use(this.API_ROOT_URL, this.createRestHandler());
}
// Abstract method to handle trial metrics data
// tslint:disable-next-line:no-any
protected abstract handleTrialMetrics(jobId : string, trialMetrics : any[]) : void;
// tslint:disable: no-unsafe-any no-any
private createRestHandler() : Router {
const router: Router = Router();
// tslint:disable-next-line:typedef
router.use((req: Request, res: Response, next) => {
router.use((req: Request, res: Response, next: any) => {
this.log.info(`${req.method}: ${req.url}: body:\n${JSON.stringify(req.body, undefined, 4)}`);
res.setHeader('Content-Type', 'application/json');
next();
......@@ -92,7 +99,7 @@ export abstract class ClusterJobRestServer extends RestServer{
router.post(`/version/${this.expId}/:trialId`, (req: Request, res: Response) => {
if (this.enableVersionCheck) {
try {
const checkResultSuccess: boolean = req.body.tag === 'VCSuccess'? true: false;
const checkResultSuccess: boolean = req.body.tag === 'VCSuccess' ? true : false;
if (this.versionCheckSuccess !== undefined && this.versionCheckSuccess !== checkResultSuccess) {
this.errorMessage = 'Version check error, version check result is inconsistent!';
this.log.error(this.errorMessage);
......@@ -103,7 +110,7 @@ export abstract class ClusterJobRestServer extends RestServer{
this.versionCheckSuccess = false;
this.errorMessage = req.body.msg;
}
} catch(err) {
} catch (err) {
this.log.error(`json parse metrics error: ${err}`);
res.status(500);
res.send(err.message);
......@@ -122,8 +129,7 @@ export abstract class ClusterJobRestServer extends RestServer{
this.handleTrialMetrics(req.body.jobId, req.body.metrics);
res.send();
}
catch(err) {
} catch (err) {
this.log.error(`json parse metrics error: ${err}`);
res.status(500);
res.send(err.message);
......@@ -131,35 +137,37 @@ export abstract class ClusterJobRestServer extends RestServer{
});
router.post(`/stdout/${this.expId}/:trialId`, (req: Request, res: Response) => {
if(this.enableVersionCheck && !this.versionCheckSuccess && !this.errorMessage) {
this.errorMessage = `Version check failed, didn't get version check response from trialKeeper, please check your NNI version in `
+ `NNIManager and TrialKeeper!`
if (this.enableVersionCheck && (this.versionCheckSuccess === undefined || !this.versionCheckSuccess)
&& this.errorMessage === undefined) {
this.errorMessage = `Version check failed, didn't get version check response from trialKeeper,`
+ ` please check your NNI version in NNIManager and TrialKeeper!`;
}
const trialLogPath: string = path.join(getLogDir(), `trial_${req.params.trialId}.log`);
try {
let skipLogging: boolean = false;
if(req.body.tag === 'trial' && req.body.msg !== undefined) {
const metricsContent = req.body.msg.match(this.NNI_METRICS_PATTERN);
if(metricsContent && metricsContent.groups) {
this.handleTrialMetrics(req.params.trialId, [metricsContent.groups['metrics']]);
if (req.body.tag === 'trial' && req.body.msg !== undefined) {
const metricsContent: any = req.body.msg.match(this.NNI_METRICS_PATTERN);
if (metricsContent && metricsContent.groups) {
const key: string = 'metrics';
this.handleTrialMetrics(req.params.trialId, [metricsContent.groups[key]]);
skipLogging = true;
}
}
if(!skipLogging){
if (!skipLogging) {
// Construct write stream to write remote trial's log into local file
// tslint:disable-next-line:non-literal-fs-path
const writeStream: Writable = fs.createWriteStream(trialLogPath, {
flags: 'a+',
encoding: 'utf8',
autoClose: true
});
writeStream.write(req.body.msg + '\n');
writeStream.write(String.Format('{0}\n', req.body.msg));
writeStream.end();
}
res.send();
}
catch(err) {
} catch (err) {
this.log.error(`json parse stdout data error: ${err}`);
res.status(500);
res.send(err.message);
......@@ -168,7 +176,5 @@ export abstract class ClusterJobRestServer extends RestServer{
return router;
}
/** Abstract method to handle trial metrics data */
protected abstract handleTrialMetrics(jobId : string, trialMetrics : any[]) : void;
// tslint:enable: no-unsafe-any no-any
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment