routerTrainingService.ts 5.46 KB
Newer Older
1
2
3
4
5
6
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import { getLogger, Logger } from '../../common/log';
7
8
import { MethodNotImplementedError } from '../../common/errors';
import { ExperimentConfig, RemoteConfig, OpenpaiConfig } from '../../common/experimentConfig';
9
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService';
10
import { delay } from '../../common/utils';
SparkSnail's avatar
SparkSnail committed
11
import { PAITrainingService } from '../pai/paiTrainingService';
12
import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';
13
14
15
16
17
18
19
20
import { TrialDispatcher } from './trialDispatcher';


/**
 * An intermediate implementation to support reusable training services.
 * The final goal is to support reusable training jobs at a higher level than the training service.
 *
 * Every TrainingService call is forwarded to an internal delegate that is chosen once in construct():
 *  - platform 'remote' without reuseMode  -> RemoteMachineTrainingService
 *  - platform 'openpai' without reuseMode -> PAITrainingService
 *  - anything else (reuse mode, a hybrid list of training services, other platforms) -> TrialDispatcher
 */
class RouterTrainingService implements TrainingService {
    private log!: Logger;
    private internalTrainingService!: TrainingService;

    /**
     * Creates a router and assigns its delegate training service from `config.trainingService`.
     * Use this instead of `new` (the constructor is private); the delegate is assigned before
     * the instance is returned.
     *
     * @param config Experiment configuration; an array `trainingService` is treated as platform 'hybrid'.
     * @returns A router whose delegate training service is already assigned.
     */
    public static async construct(config: ExperimentConfig): Promise<RouterTrainingService> {
        const instance = new RouterTrainingService();
        instance.log = getLogger('RouterTrainingService');

        const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform;
        if (platform === 'remote' && !(config.trainingService as RemoteConfig).reuseMode) {
            instance.internalTrainingService = new RemoteMachineTrainingService(config);
        } else if (platform === 'openpai' && !(config.trainingService as OpenpaiConfig).reuseMode) {
            instance.internalTrainingService = new PAITrainingService(config);
        } else {
            // Reuse mode, hybrid and every other platform go through the trial dispatcher.
            instance.internalTrainingService = await TrialDispatcher.construct(config);
        }
        return instance;
    }

    // eslint-disable-next-line @typescript-eslint/no-empty-function
    private constructor() { }

    /**
     * Returns the delegate training service.
     *
     * @throws Error if the delegate has not been assigned. Inside an async caller this surfaces
     *         as a rejected promise, exactly as the former inline guards did.
     */
    private getInternalTrainingService(): TrainingService {
        if (this.internalTrainingService === undefined) {
            throw new Error("TrainingService is not assigned!");
        }
        return this.internalTrainingService;
    }

    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        return await this.getInternalTrainingService().listTrialJobs();
    }

    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        return await this.getInternalTrainingService().getTrialJob(trialJobId);
    }

    /** Not supported by the router; always rejects with MethodNotImplementedError. */
    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
        throw new MethodNotImplementedError();
    }

    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.getInternalTrainingService().addTrialJobMetricListener(listener);
    }

    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.getInternalTrainingService().removeTrialJobMetricListener(listener);
    }

    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        return await this.getInternalTrainingService().submitTrialJob(form);
    }

    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        return await this.getInternalTrainingService().updateTrialJob(trialJobId, form);
    }

    public async cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void> {
        await this.getInternalTrainingService().cancelTrialJob(trialJobId, isEarlyStopped);
    }

    /** No-op: the router keeps no cluster metadata. */
    public async setClusterMetadata(_key: string, _value: string): Promise<void> { return; }
    /** Always resolves to an empty string: the router keeps no cluster metadata. */
    public async getClusterMetadata(_key: string): Promise<string> { return ''; }

    public async cleanUp(): Promise<void> {
        await this.getInternalTrainingService().cleanUp();
    }

    /**
     * Runs the delegate training service.
     */
    public async run(): Promise<void> {
        // Defensive wait for the delegate. construct() assigns it before returning the
        // instance, so in practice this loop does not iterate.
        while (this.internalTrainingService === undefined) {
            await delay(100);
        }
        return await this.internalTrainingService.run();
    }

    public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
        return await this.getInternalTrainingService().getTrialOutputLocalPath(trialJobId);
    }

    public async fetchTrialOutput(trialJobId: string, subpath: string): Promise<void> {
        return await this.getInternalTrainingService().fetchTrialOutput(trialJobId, subpath);
    }
}

export { RouterTrainingService };