routerTrainingService.ts 5.73 KB
Newer Older
1
2
3
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

4
5
6
7
8
import { getLogger, Logger } from 'common/log';
import { MethodNotImplementedError } from 'common/errors';
import { ExperimentConfig, RemoteConfig, OpenpaiConfig, KubeflowConfig } from 'common/experimentConfig';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from 'common/trainingService';
import { delay } from 'common/utils';
SparkSnail's avatar
SparkSnail committed
9
import { PAITrainingService } from '../pai/paiTrainingService';
10
import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';
11
import { KubeflowTrainingService } from '../kubernetes/kubeflow/kubeflowTrainingService';
12
13
14
15
16
17
18
19
import { TrialDispatcher } from './trialDispatcher';


/**
 * Interim adapter that supports reusable training services.
 *
 * The router owns exactly one inner {@link TrainingService}, selected once in
 * {@link RouterTrainingService.construct}, and forwards trial-job calls to it.
 * The long-term goal is to support reusable training jobs at a level above the
 * training service, at which point this indirection can go away.
 */
class RouterTrainingService implements TrainingService {
    private log!: Logger;
    private internalTrainingService!: TrainingService;

    /**
     * Builds a router and selects the inner training service from `config`.
     *
     * A platform-specific service is used only when the platform is `remote`,
     * `openpai` or `kubeflow` AND its `reuseMode` is exactly `false`. Every
     * other case (including an array-valued `config.trainingService`, which is
     * labelled `hybrid` here) is served by a `TrialDispatcher`.
     *
     * @param config experiment configuration to route on
     * @returns a router whose inner service is already assigned
     */
    public static async construct(config: ExperimentConfig): Promise<RouterTrainingService> {
        const instance = new RouterTrainingService();
        instance.log = getLogger('RouterTrainingService');

        const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform;

        if (platform === 'remote' && (config.trainingService as RemoteConfig).reuseMode === false) {
            instance.internalTrainingService = new RemoteMachineTrainingService(config);
        } else if (platform === 'openpai' && (config.trainingService as OpenpaiConfig).reuseMode === false) {
            instance.internalTrainingService = new PAITrainingService(config);
        } else if (platform === 'kubeflow' && (config.trainingService as KubeflowConfig).reuseMode === false) {
            instance.internalTrainingService = new KubeflowTrainingService();
        } else {
            // Reuse mode (explicit or by omission) and hybrid setups share the dispatcher.
            instance.internalTrainingService = await TrialDispatcher.construct(config);
        }

        return instance;
    }

    // Instances are only created through construct().
    // eslint-disable-next-line @typescript-eslint/no-empty-function
    private constructor() { }

    /**
     * Returns the inner training service.
     *
     * @throws Error with message "TrainingService is not assigned!" when no
     *         inner service has been set on this instance.
     */
    private requireInternalService(): TrainingService {
        if (this.internalTrainingService === undefined) {
            throw new Error("TrainingService is not assigned!");
        }
        return this.internalTrainingService;
    }

    /** Forwards to the inner service; rejects if no inner service is assigned. */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        return await this.requireInternalService().listTrialJobs();
    }

    /** Forwards to the inner service; rejects if no inner service is assigned. */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        return await this.requireInternalService().getTrialJob(trialJobId);
    }

    /** Not available through the router: always rejects with `MethodNotImplementedError`. */
    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
        throw new MethodNotImplementedError();
    }

    /** Registers `listener` on the inner service; throws synchronously if none is assigned. */
    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.requireInternalService().addTrialJobMetricListener(listener);
    }

    /** Unregisters `listener` from the inner service; throws synchronously if none is assigned. */
    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.requireInternalService().removeTrialJobMetricListener(listener);
    }

    /** Forwards to the inner service; rejects if no inner service is assigned. */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        return await this.requireInternalService().submitTrialJob(form);
    }

    /** Forwards to the inner service; rejects if no inner service is assigned. */
    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        return await this.requireInternalService().updateTrialJob(trialJobId, form);
    }

    /** Forwards to the inner service; rejects if no inner service is assigned. */
    public async cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean | undefined): Promise<void> {
        await this.requireInternalService().cancelTrialJob(trialJobId, isEarlyStopped);
    }

    /** No-op at the router level: resolves without recording or forwarding anything. */
    public async setClusterMetadata(_key: string, _value: string): Promise<void> { return; }

    /** Always resolves to the empty string; nothing is forwarded to the inner service. */
    public async getClusterMetadata(_key: string): Promise<string> { return ''; }

    /** Forwards to the inner service; rejects if no inner service is assigned. */
    public async cleanUp(): Promise<void> {
        await this.requireInternalService().cleanUp();
    }

    /**
     * Runs the inner service.
     *
     * Unlike the other methods this one does not fail when the inner service
     * is missing: it re-checks after each `delay(100)` until one appears.
     * `construct()` assigns the service in every branch before returning, so
     * for instances obtained that way the loop body is never entered.
     */
    public async run(): Promise<void> {
        while (this.internalTrainingService === undefined) {
            await delay(100);
        }
        return await this.internalTrainingService.run();
    }

    /** Forwards to the inner service; rejects if no inner service is assigned. */
    public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
        return this.requireInternalService().getTrialOutputLocalPath(trialJobId);
    }

    /** Forwards to the inner service; rejects if no inner service is assigned. */
    public async fetchTrialOutput(trialJobId: string, subpath: string): Promise<void> {
        return this.requireInternalService().fetchTrialOutput(trialJobId, subpath);
    }
}

129
export { RouterTrainingService };