clusterJobRestServer.ts 6.59 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3
4
5
6

'use strict';

import * as assert from 'assert';
7
// tslint:disable-next-line:no-implicit-dependencies
8
import * as bodyParser from 'body-parser';
9
10
11
12
13
import { Request, Response, Router } from 'express';
import * as fs from 'fs';
import * as path from 'path';
import { Writable } from 'stream';
import { String } from 'typescript-string-operations';
14
15
import * as component from '../../common/component';
import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
16
import { RestServer } from '../../common/restServer';
17
import { getExperimentRootDir, mkDirPSync } from '../../common/utils';
18
19
20

/**
 * Cluster Job Training service Rest server, provides rest API to support Cluster job metrics update.
 *
 * All routes are mounted under `API_ROOT_URL` and are scoped by the current experiment id:
 *  - POST /version/<expId>/:trialId         records the version check result reported by the trial side
 *  - POST /update-metrics/<expId>/:trialId  forwards reported metrics to `handleTrialMetrics`
 *  - POST /stdout/<expId>/:trialId          extracts metrics from trial stdout or appends it to a local log
 */
@component.Singleton
export abstract class ClusterJobRestServer extends RestServer {
    private readonly API_ROOT_URL: string = '/api/v1/nni-pai';
    // Pattern searched for in trial stdout lines; the named group `metrics` captures the payload.
    private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?<metrics>.*?)'`;

    private readonly expId: string = getExperimentId();

    private enableVersionCheck: boolean = true; //switch to enable version check
    // undefined until the first version check report has been received
    private versionCheckSuccess: boolean | undefined;
    private errorMessage?: string;

    /**
     * constructor to provide NNIRestServer's own rest property, e.g. port
     *
     * The server port is derived from the experiment base port (base port + 1).
     */
    constructor() {
        super();
        const basePort: number = getBasePort();
        assert(basePort !== undefined && basePort > 1024);

        this.port = basePort + 1;
    }

    /**
     * Root URL under which all handlers of this server are mounted.
     */
    get apiRootUrl(): string {
        return this.API_ROOT_URL;
    }

    /**
     * Port of this rest server.
     *
     * @throws Error when the port has not been set
     */
    public get clusterRestServerPort(): number {
        if (this.port === undefined) {
            throw new Error('PAI Rest server port is undefined');
        }

        return this.port;
    }

    /**
     * Last error message recorded by the version check logic, if any.
     */
    public get getErrorMessage(): string | undefined {
        return this.errorMessage;
    }

    /**
     * Enable or disable the version check performed by the /version and /stdout handlers.
     */
    public set setEnableVersionCheck(versionCheck: boolean) {
        this.enableVersionCheck = versionCheck;
    }

    /**
     * NNIRestServer's own router registration
     */
    protected registerRestHandler(): void {
        this.app.use(bodyParser.json());
        this.app.use(this.API_ROOT_URL, this.createRestHandler());
    }

    // Abstract method to handle trial metrics data
    // tslint:disable-next-line:no-any
    protected abstract handleTrialMetrics(jobId : string, trialMetrics : any[]) : void;

    // tslint:disable: no-unsafe-any no-any
    /**
     * Build the router holding the version, update-metrics and stdout handlers.
     *
     * @returns the configured express Router
     */
    protected createRestHandler() : Router {
        const router: Router = Router();

        // Log every incoming request and default the response content type to JSON.
        router.use((req: Request, res: Response, next: any) => {
            this.log.info(`${req.method}: ${req.url}: body:\n${JSON.stringify(req.body, undefined, 4)}`);
            res.setHeader('Content-Type', 'application/json');
            next();
        });

        router.post(`/version/${this.expId}/:trialId`, (req: Request, res: Response) => {
            if (!this.enableVersionCheck) {
                this.log.info(`Skipping version check!`);
                res.send();

                return;
            }

            try {
                const checkResultSuccess: boolean = req.body.tag === 'VCSuccess';
                if (this.versionCheckSuccess !== undefined && this.versionCheckSuccess !== checkResultSuccess) {
                    // A previous report disagrees with this one.
                    this.errorMessage = 'Version check error, version check result is inconsistent!';
                    this.log.error(this.errorMessage);
                } else if (checkResultSuccess) {
                    this.log.info(`Version check in trialKeeper success!`);
                    this.versionCheckSuccess = true;
                } else {
                    this.versionCheckSuccess = false;
                    this.errorMessage = req.body.msg;
                }
            } catch (err) {
                this.log.error(`json parse metrics error: ${err}`);
                res.status(500);
                res.send(err.message);

                // The response is already finished; sending again would throw
                // "Cannot set headers after they are sent".
                return;
            }

            res.send();
        });

        router.post(`/update-metrics/${this.expId}/:trialId`, (req: Request, res: Response) => {
            try {
                this.log.info(`Get update-metrics request, trial job id is ${req.params.trialId}`);
                this.log.info(`update-metrics body is ${JSON.stringify(req.body)}`);

                this.handleTrialMetrics(req.body.jobId, req.body.metrics);

                res.send();
            } catch (err) {
                this.log.error(`json parse metrics error: ${err}`);
                res.status(500);
                res.send(err.message);
            }
        });

        router.post(`/stdout/${this.expId}/:trialId`, (req: Request, res: Response) => {
            // Stdout arrived but no successful version check has been recorded (either none was
            // received or it failed without a message): record a generic failure message once.
            if (this.enableVersionCheck && !this.versionCheckSuccess && this.errorMessage === undefined) {
                this.errorMessage = `Version check failed, didn't get version check response from trialKeeper,`
                 + ` please check your NNI version in NNIManager and TrialKeeper!`;
            }

            // NOTE(review): trialId comes straight from the request URL and is joined into a
            // filesystem path without validation - confirm callers are trusted or validate it.
            const trialLogDir: string = path.join(getExperimentRootDir(), 'trials', req.params.trialId);
            mkDirPSync(trialLogDir);
            const trialLogPath: string = path.join(trialLogDir, 'stdout_log_collection.log');

            try {
                let skipLogging: boolean = false;
                if (req.body.tag === 'trial' && req.body.msg !== undefined) {
                    // Lines carrying a metrics payload are forwarded as metrics instead of being logged.
                    const metricsContent: any = req.body.msg.match(this.NNI_METRICS_PATTERN);
                    if (metricsContent && metricsContent.groups) {
                        const key: string = 'metrics';
                        this.handleTrialMetrics(req.params.trialId, [metricsContent.groups[key]]);
                        skipLogging = true;
                    }
                }

                if (!skipLogging) {
                    // Construct write stream to write remote trial's log into local file
                    // tslint:disable-next-line:non-literal-fs-path
                    const writeStream: Writable = fs.createWriteStream(trialLogPath, {
                        flags: 'a+',
                        encoding: 'utf8',
                        autoClose: true
                    });

                    // Stream failures are reported asynchronously; without a listener an
                    // 'error' event would be thrown as an uncaught exception.
                    writeStream.on('error', (err: Error) => {
                        this.log.error(`write trial stdout log ${trialLogPath} error: ${err}`);
                    });

                    writeStream.write(String.Format('{0}\n', req.body.msg));
                    writeStream.end();
                }
                res.send();
            } catch (err) {
                this.log.error(`json parse stdout data error: ${err}`);
                res.status(500);
                res.send(err.message);
            }
        });

        return router;
    }
    // tslint:enable: no-unsafe-any no-any
}