routerTrainingService.ts 8.91 KB
Newer Older
1
2
3
4
5
6
7
8
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import { Container, Scope } from 'typescript-ioc';
import * as component from '../../common/component';
import { getLogger, Logger } from '../../common/log';
9
10
import { MethodNotImplementedError } from '../../common/errors'
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType } from '../../common/trainingService';
11
12
13
14
import { delay } from '../../common/utils';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { PAIClusterConfig } from '../pai/paiConfig';
import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService';
15
import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';
16
17
18
import { MountedStorageService } from './storages/mountedStorageService';
import { StorageService } from './storageService';
import { TrialDispatcher } from './trialDispatcher';
19
import { RemoteConfig } from './remote/remoteConfig';
20
import { HeterogenousConfig } from './heterogenous/heterogenousConfig';
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49


/**
 * Routes every TrainingService call to an internal training service.
 *
 * This is an intermediate implementation to support reusable training services;
 * the final goal is to support reusable training jobs at a higher level than the
 * training service. The internal service is chosen by setClusterMetadata() from the
 * first recognised platform-configuration key it receives and is not changed afterwards.
 */
@component.Singleton
class RouterTrainingService implements TrainingService {
    protected readonly log: Logger;
    // Undefined until setClusterMetadata() receives a recognised platform config key.
    private internalTrainingService: TrainingService | undefined;

    constructor() {
        this.log = getLogger();
    }

    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        return await this.getAssignedService().listTrialJobs();
    }

    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        return await this.getAssignedService().getTrialJob(trialJobId);
    }

    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
        throw new MethodNotImplementedError();
    }

    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.getAssignedService().addTrialJobMetricListener(listener);
    }

    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.getAssignedService().removeTrialJobMetricListener(listener);
    }

    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        return await this.getAssignedService().submitTrialJob(form);
    }

    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        return await this.getAssignedService().updateTrialJob(trialJobId, form);
    }

    public get isMultiPhaseJobSupported(): boolean {
        return this.getAssignedService().isMultiPhaseJobSupported;
    }

    public async cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void> {
        await this.getAssignedService().cancelTrialJob(trialJobId, isEarlyStopped);
    }

    /**
     * Forwards a cluster metadata entry to the internal training service, first
     * selecting that service if none has been chosen yet.
     *
     * @param key - metadata key; platform config keys (hybrid, local, PAI, AML, remote)
     *              decide which internal service is used.
     * @param value - metadata value, JSON text for the platform config keys.
     * @throws Error if no internal service has been selected and `key` is not a
     *         recognised platform config key.
     */
    public async setClusterMetadata(key: string, value: string): Promise<void> {
        if (this.internalTrainingService === undefined) {
            await this.selectTrainingService(key, value);
        }
        if (this.internalTrainingService === undefined) {
            throw new Error("internalTrainingService not initialized!");
        }
        await this.internalTrainingService.setClusterMetadata(key, value);
    }

    public async getClusterMetadata(key: string): Promise<string> {
        return await this.getAssignedService().getClusterMetadata(key);
    }

    public async cleanUp(): Promise<void> {
        await this.getAssignedService().cleanUp();
    }

    public async run(): Promise<void> {
        // The internal training service is assigned by setClusterMetadata()
        // (e.g. after the platform config arrives); poll until it is.
        while (this.internalTrainingService === undefined) {
            await delay(100);
        }
        return await this.internalTrainingService.run();
    }

    /** Returns the internal training service, throwing if it has not been assigned yet. */
    private getAssignedService(): TrainingService {
        if (this.internalTrainingService === undefined) {
            throw new Error("TrainingService is not assigned!");
        }
        return this.internalTrainingService;
    }

    /** Routes to the TrialDispatcher component and returns it. */
    private useTrialDispatcher(): TrainingService {
        this.internalTrainingService = component.get(TrialDispatcher);
        if (this.internalTrainingService === undefined) {
            throw new Error("internalTrainingService not initialized!");
        }
        return this.internalTrainingService;
    }

    /** Binds StorageService to a singleton MountedStorageService (the only storage supported for now). */
    private bindMountedStorage(): void {
        Container.bind(StorageService)
            .to(MountedStorageService)
            .scope(Scope.Singleton);
    }

    /**
     * Chooses the internal training service from a platform config key.
     * Keys that are not platform configs leave the service unassigned.
     * Config JSON is parsed before routing, so an unparsable config does not
     * leave the router committed to a service.
     */
    private async selectTrainingService(key: string, value: string): Promise<void> {
        switch (key) {
            // Need to refactor configuration, remove hybrid_config field in the future
            case TrialConfigMetadataKey.HYBRID_CONFIG: {
                const heterogenousConfig = <HeterogenousConfig>JSON.parse(value);
                const dispatcher = this.useTrialDispatcher();
                // Initialize storageService for pai, only support singleton for now, need refactor
                if (heterogenousConfig.trainingServicePlatforms.includes('pai')) {
                    this.bindMountedStorage();
                }
                await dispatcher.setClusterMetadata('platform_list',
                    heterogenousConfig.trainingServicePlatforms.join(','));
                break;
            }
            case TrialConfigMetadataKey.LOCAL_CONFIG: {
                await this.useTrialDispatcher().setClusterMetadata('platform_list', 'local');
                break;
            }
            case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: {
                const config = <PAIClusterConfig>JSON.parse(value);
                if (config.reuse === true) {
                    this.log.info(`reuse flag enabled, use EnvironmentManager.`);
                    const dispatcher = this.useTrialDispatcher();
                    // TODO to support other storages later.
                    this.bindMountedStorage();
                    await dispatcher.setClusterMetadata('platform_list', 'pai');
                } else {
                    // Only the key is logged: the config value may contain credentials.
                    this.log.debug(`reuse flag disabled for metadata key: ${key}, use PAIK8STrainingService.`);
                    this.internalTrainingService = component.get(PAIK8STrainingService);
                }
                break;
            }
            case TrialConfigMetadataKey.AML_CLUSTER_CONFIG: {
                await this.useTrialDispatcher().setClusterMetadata('platform_list', 'aml');
                break;
            }
            case TrialConfigMetadataKey.REMOTE_CONFIG: {
                const config = <RemoteConfig>JSON.parse(value);
                if (config.reuse === true) {
                    this.log.info(`reuse flag enabled, use EnvironmentManager.`);
                    await this.useTrialDispatcher().setClusterMetadata('platform_list', 'remote');
                } else {
                    // Only the key is logged: the config value may contain credentials.
                    this.log.debug(`reuse flag disabled for metadata key: ${key}, use RemoteMachineTrainingService.`);
                    this.internalTrainingService = component.get(RemoteMachineTrainingService);
                }
                break;
            }
            default:
                break;
        }
    }
}

export { RouterTrainingService };