amlEnvironmentService.ts 5.87 KB
Newer Older
SparkSnail's avatar
SparkSnail committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import * as fs from 'fs';
import * as path from 'path';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { AMLClusterConfig, AMLTrialConfig } from '../aml/amlConfig';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { AMLEnvironmentInformation } from '../aml/amlConfig';
import { AMLClient } from '../aml/amlClient';
import {
    NNIManagerIpConfig,
} from '../../../common/trainingService';
import { validateCodeDir } from '../../common/util';
import { getExperimentRootDir } from '../../../common/utils';
import { AMLCommandChannel } from '../channels/amlCommandChannel';
import { CommandChannel } from "../commandChannel";
import { EventEmitter } from "events";


/**
 * Environment service that runs NNI environments as AML (Azure Machine Learning) jobs.
 *
 * It wraps each environment's command in a small Python launcher script, submits it
 * through an {@link AMLClient}, maps the AML job status back onto the environment, and
 * stops the job on request.
 */
@component.Singleton
export class AMLEnvironmentService extends EnvironmentService {

    private readonly log: Logger = getLogger();
    public amlClusterConfig: AMLClusterConfig | undefined;
    public amlTrialConfig: AMLTrialConfig | undefined;
    // NOTE(review): amlJobConfig, stopping, versionCheck, isMultiPhase, nniVersion and
    // nniManagerIpConfig are never read inside this class; candidates for removal.
    private amlJobConfig: any;
    private stopping: boolean = false;
    private versionCheck: boolean = true;
    private isMultiPhase: boolean = false;
    private nniVersion?: string;
    private experimentId: string;
    private nniManagerIpConfig?: NNIManagerIpConfig;
    private experimentRootDir: string;

    constructor() {
        super();
        this.experimentId = getExperimentId();
        this.experimentRootDir = getExperimentRootDir();
    }

    /** This service does not provide a storage service; always returns false. */
    public get hasStorageService(): boolean {
        return false;
    }

    /**
     * Creates the command channel used to talk to AML environments.
     *
     * @param commandEmitter emitter that the channel forwards received commands to.
     */
    public getCommandChannel(commandEmitter: EventEmitter): CommandChannel {
        return new AMLCommandChannel(commandEmitter);
    }

    /**
     * Creates an AML-specific environment record.
     *
     * The misspelled method name is kept because it must match the base-class contract.
     */
    public createEnviornmentInfomation(envId: string, envName: string): EnvironmentInformation {
        return new AMLEnvironmentInformation(envId, envName);
    }

    /**
     * Applies one piece of experiment metadata.
     *
     * - `AML_CLUSTER_CONFIG`: parsed and stored as the cluster config.
     * - `TRIAL_CONFIG`: parsed and stored as the trial config, then its `codeDir` is
     *   validated. It is ignored (with an error log) if the cluster config is not set yet.
     * - Any other key is logged at debug level and otherwise ignored.
     *
     * @param key   metadata key (a `TrialConfigMetadataKey` value).
     * @param value JSON-encoded payload for that key.
     */
    public async config(key: string, value: string): Promise<void> {
        switch (key) {
            case TrialConfigMetadataKey.AML_CLUSTER_CONFIG:
                this.amlClusterConfig = <AMLClusterConfig>JSON.parse(value);
                break;

            case TrialConfigMetadataKey.TRIAL_CONFIG: {
                if (this.amlClusterConfig === undefined) {
                    this.log.error('aml cluster config is not initialized');
                    break;
                }
                this.amlTrialConfig = <AMLTrialConfig>JSON.parse(value);
                // Validate to make sure codeDir doesn't have too many files
                await validateCodeDir(this.amlTrialConfig.codeDir);
                break;
            }
            default:
                this.log.debug(`AML not proccessed metadata key: '${key}', value: '${value}'`);
        }
    }

    /**
     * Polls the AML job of every given environment and records terminal states.
     *
     * WAITING, QUEUED and RUNNING leave the environment untouched, because RUNNING is
     * reported by the runner itself. COMPLETED and SUCCEEDED map to `SUCCEEDED`, FAILED
     * maps to `FAILED`, and STOPPED and STOPPING map to `USER_CANCELED`. Any other status
     * maps to `UNKNOWN`.
     *
     * The returned promise settles only after every environment has been polled.
     *
     * @throws Error if an environment has no AML client attached, meaning
     *   `startEnvironment` has not completed for it.
     */
    public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
        // `forEach(async ...)` would return before any poll finished. A failure there
        // would also become an unhandled rejection instead of rejecting this promise.
        await Promise.all(environments.map(async (environment) => {
            const amlClient = (environment as AMLEnvironmentInformation).amlClient;
            if (!amlClient) {
                throw new Error('AML client not initialized!');
            }
            const status = await amlClient.updateStatus(environment.status);
            switch (status.toUpperCase()) {
                case 'WAITING':
                case 'RUNNING':
                case 'QUEUED':
                    // RUNNING status is set by runner, and ignore waiting status
                    break;
                case 'COMPLETED':
                case 'SUCCEEDED':
                    environment.setFinalStatus('SUCCEEDED');
                    break;
                case 'FAILED':
                    environment.setFinalStatus('FAILED');
                    break;
                case 'STOPPED':
                case 'STOPPING':
                    environment.setFinalStatus('USER_CANCELED');
                    break;
                default:
                    environment.setFinalStatus('UNKNOWN');
            }
        }));
    }

    /**
     * Submits an environment as an AML job.
     *
     * The environment's shell command is wrapped in a Python launcher (`nni_script.py`)
     * that calls `os.system`. The launcher is written into the local `environment-temp`
     * folder, and that folder is submitted through a new {@link AMLClient}. On success,
     * the environment's `command` holds the launcher source, and its `id`, `trackingUrl`
     * and `amlClient` are filled in.
     *
     * @throws Error if the cluster or trial config has not been set via `config`.
     */
    public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
        if (this.amlClusterConfig === undefined) {
            throw new Error('AML Cluster config is not initialized');
        }
        if (this.amlTrialConfig === undefined) {
            throw new Error('AML trial config is not initialized');
        }
        const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
        // NOTE(review): confirm whether getExperimentRootDir() already includes the
        // experiment id. If it does, the id is nested twice here. Left unchanged because
        // other components may rely on this exact location.
        const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp");
        // writeFile does not create missing parent directories. mkdir with
        // `recursive: true` is a no-op when the folder already exists.
        await fs.promises.mkdir(environmentLocalTempFolder, { recursive: true });
        // JSON.stringify produces a double-quoted literal with quotes, backslashes and
        // control characters escaped, and Python also accepts it as a str literal.
        // Pasting the raw command into '...' broke on any single quote or backslash.
        amlEnvironment.command = `import os\nos.system(${JSON.stringify(amlEnvironment.command)})`;
        await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' });
        const amlClient = new AMLClient(
            this.amlClusterConfig.subscriptionId,
            this.amlClusterConfig.resourceGroup,
            this.amlClusterConfig.workspaceName,
            this.experimentId,
            this.amlTrialConfig.computeTarget,
            this.amlTrialConfig.image,
            'nni_script.py',
            environmentLocalTempFolder
        );
        amlEnvironment.id = await amlClient.submit();
        amlEnvironment.trackingUrl = await amlClient.getTrackingUrl();
        amlEnvironment.amlClient = amlClient;
    }

    /**
     * Asks AML to stop the job backing the given environment.
     *
     * @throws Error if the environment has no AML client attached.
     */
    public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
        const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
        const amlClient = amlEnvironment.amlClient;
        if (!amlClient) {
            throw new Error('AML client not initialized!');
        }
        // Awaited so a returned promise (or its rejection) is not dropped. This is
        // harmless if stop() turns out to be synchronous.
        await amlClient.stop();
    }
}