clusterJobRestServer.ts 7.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/**
 * Copyright (c) Microsoft Corporation
 * All rights reserved.
 *
 * MIT License
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
 * documentation files (the "Software"), to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
 * to permit persons to whom the Software is furnished to do so, subject to the following conditions:
 * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
 * BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

'use strict';

import * as assert from 'assert';
23
// tslint:disable-next-line:no-implicit-dependencies
24
import * as bodyParser from 'body-parser';
25
26
27
28
29
import { Request, Response, Router } from 'express';
import * as fs from 'fs';
import * as path from 'path';
import { Writable } from 'stream';
import { String } from 'typescript-string-operations';
30
31
import * as component from '../../common/component';
import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
32
import { RestServer } from '../../common/restServer';
33
import { getLogDir } from '../../common/utils';
34
35
36

/**
 * Cluster Job Training service Rest server, provides rest API to support Cluster job metrics update.
 *
 * The server listens on `getBasePort() + 1` and mounts its routes under `/api/v1/nni-pai`:
 *  - POST /version/<expId>/:trialId         version check result reported by the trial keeper
 *  - POST /update-metrics/<expId>/:trialId  trial metrics in the request body
 *  - POST /stdout/<expId>/:trialId          trial stdout lines (may embed metrics)
 *
 * Subclasses implement handleTrialMetrics() to consume the metrics.
 */
@component.Singleton
export abstract class ClusterJobRestServer extends RestServer {
    private readonly API_ROOT_URL: string = '/api/v1/nni-pai';
    // Matched against stdout messages; the named group `metrics` captures the payload.
    private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?<metrics>.*?)'`;

    private readonly expId: string = getExperimentId();

    private enableVersionCheck: boolean = true; //switch to enable version check
    // undefined until the first /version request with version check enabled has been processed
    private versionCheckSuccess: boolean | undefined;
    private errorMessage?: string;

    /**
     * constructor to provide NNIRestServer's own rest property, e.g. port
     */
    constructor() {
        super();
        const basePort: number = getBasePort();
        assert(basePort !== undefined && basePort > 1024);

        this.port = basePort + 1;
    }

    /**
     * Port this rest server uses.
     * @throws Error when the port is undefined
     */
    public get clusterRestServerPort(): number {
        if (this.port === undefined) {
            throw new Error('PAI Rest server port is undefined');
        }

        return this.port;
    }

    /**
     * Error message recorded by the version check or stdout handlers, undefined when none was recorded.
     */
    public get getErrorMessage(): string | undefined {
        return this.errorMessage;
    }

    /**
     * Enable or disable the version check performed by the /version and /stdout handlers.
     */
    public set setEnableVersionCheck(versionCheck: boolean) {
        this.enableVersionCheck = versionCheck;
    }

    /**
     * NNIRestServer's own router registration
     */
    protected registerRestHandler(): void {
        this.app.use(bodyParser.json());
        this.app.use(this.API_ROOT_URL, this.createRestHandler());
    }

    // Abstract method to handle trial metrics data
    // tslint:disable-next-line:no-any
    protected abstract handleTrialMetrics(jobId : string, trialMetrics : any[]) : void;

    // tslint:disable: no-unsafe-any no-any
    /**
     * Build the router holding the /version, /update-metrics and /stdout handlers.
     */
    private createRestHandler() : Router {
        const router: Router = Router();

        // Log every request and mark every response as JSON
        router.use((req: Request, res: Response, next: any) => {
            this.log.info(`${req.method}: ${req.url}: body:\n${JSON.stringify(req.body, undefined, 4)}`);
            res.setHeader('Content-Type', 'application/json');
            next();
        });

        router.post(`/version/${this.expId}/:trialId`, (req: Request, res: Response) => {
            if (this.enableVersionCheck) {
                try {
                    const checkResultSuccess: boolean = req.body.tag === 'VCSuccess';
                    if (this.versionCheckSuccess !== undefined && this.versionCheckSuccess !== checkResultSuccess) {
                        // A previous report disagrees with this one
                        this.errorMessage = 'Version check error, version check result is inconsistent!';
                        this.log.error(this.errorMessage);
                    } else if (checkResultSuccess) {
                        this.log.info(`Version check in trialKeeper success!`);
                        this.versionCheckSuccess = true;
                    } else {
                        this.versionCheckSuccess = false;
                        this.errorMessage = req.body.msg;
                    }
                } catch (err) {
                    this.log.error(`json parse metrics error: ${err}`);
                    res.status(500);
                    res.send(err.message);

                    // The response is already sent; falling through to res.send() below
                    // would attempt a second response and throw.
                    return;
                }
            } else {
                this.log.info(`Skipping version check!`);
            }
            res.send();
        });

        router.post(`/update-metrics/${this.expId}/:trialId`, (req: Request, res: Response) => {
            try {
                this.log.info(`Get update-metrics request, trial job id is ${req.params.trialId}`);
                this.log.info(`update-metrics body is ${JSON.stringify(req.body)}`);

                this.handleTrialMetrics(req.body.jobId, req.body.metrics);

                res.send();
            } catch (err) {
                this.log.error(`json parse metrics error: ${err}`);
                res.status(500);
                res.send(err.message);
            }
        });

        router.post(`/stdout/${this.expId}/:trialId`, (req: Request, res: Response) => {
            // Stdout arrived without a successful version check and nothing has been recorded yet
            if (this.enableVersionCheck && this.versionCheckSuccess !== true
            && this.errorMessage === undefined) {
                this.errorMessage = `Version check failed, didn't get version check response from trialKeeper,`
                 + ` please check your NNI version in NNIManager and TrialKeeper!`;
            }

            // trialId comes from the (URL-decoded) request path; basename() strips any directory
            // components so the log file always stays inside the log directory.
            const trialLogPath: string = path.join(getLogDir(), path.basename(`trial_${req.params.trialId}.log`));
            try {
                let skipLogging: boolean = false;
                if (req.body.tag === 'trial' && req.body.msg !== undefined) {
                    const metricsContent: any = req.body.msg.match(this.NNI_METRICS_PATTERN);
                    if (metricsContent && metricsContent.groups) {
                        const key: string = 'metrics';
                        this.handleTrialMetrics(req.params.trialId, [metricsContent.groups[key]]);
                        // Metrics lines are consumed here and not written to the trial log
                        skipLogging = true;
                    }
                }

                if (!skipLogging) {
                    // Construct write stream to write remote trial's log into local file
                    // tslint:disable-next-line:non-literal-fs-path
                    const writeStream: Writable = fs.createWriteStream(trialLogPath, {
                        flags: 'a+',
                        encoding: 'utf8',
                        autoClose: true
                    });

                    // Stream failures are reported asynchronously; without a listener the
                    // 'error' event would be unhandled and terminate the process.
                    writeStream.on('error', (streamErr: Error) => {
                        this.log.error(`write trial log ${trialLogPath} error: ${streamErr}`);
                    });

                    writeStream.write(String.Format('{0}\n', req.body.msg));
                    writeStream.end();
                }
                res.send();
            } catch (err) {
                this.log.error(`json parse stdout data error: ${err}`);
                res.status(500);
                res.send(err.message);
            }
        });

        return router;
    }
    // tslint:enable: no-unsafe-any no-any
}