clusterJobRestServer.ts 6.24 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
5
import assert from 'assert';
import bodyParser from 'body-parser';
6
import { Request, Response, Router } from 'express';
7
8
import fs from 'fs';
import path from 'path';
9
10
import { Writable } from 'stream';
import { String } from 'typescript-string-operations';
11
12
13
14
import * as component from 'common/component';
import { getBasePort, getExperimentId } from 'common/experimentStartupInfo';
import { RestServer } from 'common/restServer';
import { getExperimentRootDir, mkDirPSync } from 'common/utils';
15
16
17

/**
 * Cluster job training service REST server. Provides the REST API through which
 * remote trials report version-check results, metrics, and stdout log lines.
 *
 * All routes are mounted under `/api/v1/nni-pai` and are scoped to the current
 * experiment id. Subclasses decide what to do with reported metrics by
 * implementing `handleTrialMetrics`.
 */
@component.Singleton
export abstract class ClusterJobRestServer extends RestServer {
    private readonly API_ROOT_URL: string = '/api/v1/nni-pai';
    // Captures the text between `NNISDK_MEb'` and the next `'` in a trial stdout line.
    private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?<metrics>.*?)'`;

    private readonly expId: string = getExperimentId();

    private enableVersionCheck: boolean = true; // switch to enable version check
    // Stays undefined until the first version-check report is received.
    private versionCheckSuccess: boolean | undefined;
    private errorMessage?: string;

    /**
     * Sets this server's port to the experiment's base port + 1.
     */
    constructor() {
        super();
        const basePort: number = getBasePort();
        assert(basePort !== undefined && basePort > 1024);

        this.port = basePort + 1;
    }

    /** Root URL under which all routes of this server are mounted. */
    get apiRootUrl(): string {
        return this.API_ROOT_URL;
    }

    /**
     * Port this REST server listens on.
     * @throws Error if the port is undefined.
     */
    public get clusterRestServerPort(): number {
        if (this.port === undefined) {
            throw new Error('PAI Rest server port is undefined');
        }

        return this.port;
    }

    /** Error message recorded by the version check, or undefined if none was recorded. */
    public get getErrorMessage(): string | undefined {
        return this.errorMessage;
    }

    /** Turns the trial version check on or off. */
    public set setEnableVersionCheck(versionCheck: boolean) {
        this.enableVersionCheck = versionCheck;
    }

    /**
     * NNIRestServer's own router registration: parses JSON request bodies and
     * mounts this server's routes under the API root URL.
     */
    protected registerRestHandler(): void {
        this.app.use(bodyParser.json());
        this.app.use(this.API_ROOT_URL, this.createRestHandler());
    }

    /**
     * Handles metrics reported for a trial job.
     *
     * @param jobId Trial job id: taken from the request body for `/update-metrics`,
     *   and from the URL for metrics found in `/stdout` lines.
     * @param trialMetrics Raw metric payloads as received from the trial.
     */
    protected abstract handleTrialMetrics(jobId: string, trialMetrics: any[]): void;

    /**
     * Builds the router serving `/version`, `/update-metrics` and `/stdout`,
     * each scoped to the current experiment id and a `:trialId` URL parameter.
     */
    protected createRestHandler(): Router {
        const router: Router = Router();

        // Log every request and mark all responses as JSON.
        router.use((req: Request, res: Response, next: any) => {
            this.log.info(`${req.method}: ${req.url}: body:`, req.body);
            res.setHeader('Content-Type', 'application/json');
            next();
        });

        router.post(`/version/${this.expId}/:trialId`, (req: Request, res: Response) => {
            if (!this.enableVersionCheck) {
                this.log.info(`Skipping version check!`);
                res.send();
                return;
            }
            try {
                const checkResultSuccess: boolean = req.body.tag === 'VCSuccess';
                if (this.versionCheckSuccess !== undefined && this.versionCheckSuccess !== checkResultSuccess) {
                    // This report disagrees with an earlier one.
                    this.errorMessage = 'Version check error, version check result is inconsistent!';
                    this.log.error(this.errorMessage);
                } else if (checkResultSuccess) {
                    this.log.info(`Version check in trialKeeper success!`);
                    this.versionCheckSuccess = true;
                } else {
                    this.versionCheckSuccess = false;
                    this.errorMessage = req.body.msg;
                }
            } catch (err) {
                // Return after the error reply so the response is sent exactly once.
                this.sendError(res, 'version check request error', err);
                return;
            }
            res.send();
        });

        router.post(`/update-metrics/${this.expId}/:trialId`, (req: Request, res: Response) => {
            try {
                this.log.info(`Get update-metrics request, trial job id is ${req.params['trialId']}`);
                this.log.info('update-metrics body is', req.body);

                this.handleTrialMetrics(req.body.jobId, req.body.metrics);

                res.send();
            } catch (err) {
                this.sendError(res, 'json parse metrics error', err);
            }
        });

        router.post(`/stdout/${this.expId}/:trialId`, (req: Request, res: Response) => {
            // No successful version check recorded and no error set yet: record a failure.
            if (this.enableVersionCheck && this.versionCheckSuccess !== true && this.errorMessage === undefined) {
                this.errorMessage = `Version check failed, didn't get version check response from trialKeeper,`
                    + ` please check your NNI version in NNIManager and TrialKeeper!`;
            }

            const trialId: string = req.params['trialId'];
            // trialId comes from the URL and becomes a directory name; reject values
            // that could escape the experiment's `trials` directory.
            if (!ClusterJobRestServer.isSafePathSegment(trialId)) {
                this.log.error(`Invalid trial id in stdout request: ${trialId}`);
                res.status(400);
                res.send('Invalid trial id');
                return;
            }

            try {
                const trialLogDir: string = path.join(getExperimentRootDir(), 'trials', trialId);
                mkDirPSync(trialLogDir);
                const trialLogPath: string = path.join(trialLogDir, 'stdout_log_collection.log');

                // Stdout lines carrying SDK metrics are forwarded as metrics instead of being logged.
                if (req.body.tag === 'trial' && req.body.msg !== undefined) {
                    const metricsContent: any = req.body.msg.match(this.NNI_METRICS_PATTERN);
                    if (metricsContent && metricsContent.groups) {
                        this.handleTrialMetrics(trialId, [metricsContent.groups['metrics']]);
                        res.send();
                        return;
                    }
                }

                this.appendTrialLog(trialLogPath, req.body.msg);
                res.send();
            } catch (err) {
                this.sendError(res, 'json parse stdout data error', err);
            }
        });

        return router;
    }

    /**
     * Logs `err` prefixed with `context` and replies with HTTP 500 carrying the
     * error message (or the stringified value if `err` is not an Error).
     */
    private sendError(res: Response, context: string, err: unknown): void {
        this.log.error(`${context}: ${err}`);
        res.status(500);
        res.send(err instanceof Error ? err.message : `${err}`);
    }

    /**
     * Appends `msg` followed by a newline to the file at `logPath`, creating the
     * file if needed. The write completes asynchronously; stream failures are
     * logged rather than left as an unhandled 'error' event.
     */
    private appendTrialLog(logPath: string, msg: unknown): void {
        // Construct write stream to write remote trial's log into local file
        const writeStream: Writable = fs.createWriteStream(logPath, {
            flags: 'a+',
            encoding: 'utf8',
            autoClose: true
        });
        writeStream.on('error', (err: Error) => {
            this.log.error(`Failed to write trial log ${logPath}: ${err}`);
        });
        writeStream.write(String.Format('{0}\n', msg));
        writeStream.end();
    }

    /**
     * Returns true if `segment` is usable as a single path component: non-empty,
     * free of `/` and `\`, and not `.` or `..`.
     */
    private static isSafePathSegment(segment: string): boolean {
        return segment !== undefined
            && segment.length > 0
            && !/[\\/]/.test(segment)
            && segment !== '.'
            && segment !== '..';
    }
}