amlEnvironmentService.ts 5.66 KB
Newer Older
SparkSnail's avatar
SparkSnail committed
1
2
3
4
5
6
7
8
9
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import * as fs from 'fs';
import * as path from 'path';
import * as component from '../../../common/component';
import { getLogger, Logger } from '../../../common/log';
liuzhe-lz's avatar
liuzhe-lz committed
10
import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig';
11
import { ExperimentStartupInfo } from '../../../common/experimentStartupInfo';
SparkSnail's avatar
SparkSnail committed
12
import { validateCodeDir } from '../../common/util';
13
import { AMLClient } from '../aml/amlClient';
liuzhe-lz's avatar
liuzhe-lz committed
14
import { AMLEnvironmentInformation } from '../aml/amlConfig';
15
import { EnvironmentInformation, EnvironmentService } from '../environment';
16
17
import { EventEmitter } from "events";
import { AMLCommandChannel } from '../channels/amlCommandChannel';
18
import { SharedStorageService } from '../sharedStorage'
SparkSnail's avatar
SparkSnail committed
19

liuzhe-lz's avatar
liuzhe-lz committed
20
/** Single config shape exposing both the common ExperimentConfig fields and the AML-specific AmlConfig fields. */
interface FlattenAmlConfig extends ExperimentConfig, AmlConfig { }
SparkSnail's avatar
SparkSnail committed
21
22

/**
 * Environment service for AML: submits environments as AML jobs, collects job
 * status from the AML cluster, and updates the local environment status to match.
 */
@component.Singleton
export class AMLEnvironmentService extends EnvironmentService {

    private readonly log: Logger = getLogger('AMLEnvironmentService');
    private readonly experimentId: string;
    private readonly experimentRootDir: string;
    private readonly config: FlattenAmlConfig;

    /**
     * @param config experiment configuration; flattened with the 'aml' training-service section
     * @param info startup info providing the experiment id and the log directory
     */
    constructor(config: ExperimentConfig, info: ExperimentStartupInfo) {
        super();
        this.experimentId = info.experimentId;
        this.experimentRootDir = info.logDir;
        this.config = flattenConfig(config, 'aml');
        validateCodeDir(this.config.trialCodeDirectory);
    }

    /** This service reports that it has no storage service. */
    public get hasStorageService(): boolean {
        return false;
    }

    /** Installs an AMLCommandChannel bound to the given emitter as this service's command channel. */
    public initCommandChannel(eventEmitter: EventEmitter): void {
        this.commandChannel = new AMLCommandChannel(eventEmitter);
    }

    /** Creates the AML-specific environment record for the given id and name. */
    public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation {
        return new AMLEnvironmentInformation(envId, envName);
    }

    /** Name of this environment service. */
    public get getName(): string {
        return 'aml';
    }

    /**
     * Refreshes the status of every given environment from its AML client.
     *
     * All refreshes run concurrently and are awaited, so statuses are up to date
     * when the returned promise resolves. A failure for one environment (missing
     * client, failed job, client error) is logged and does not prevent the others
     * from being refreshed; the returned promise itself does not reject.
     *
     * @param environments environments to refresh; each is treated as an AMLEnvironmentInformation
     */
    public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
        await Promise.all(environments.map(async (environment) => {
            try {
                await this.refreshAmlEnvironmentStatus(environment);
            } catch (error) {
                const reason = error instanceof Error ? error.message : String(error);
                this.log.error(`Failed to refresh status of environment ${environment.envId}: ${reason}`);
            }
        }));
    }

    /**
     * Queries the AML client of one environment and maps the reported job state
     * onto the environment status.
     *
     * @throws Error when the environment has no AML client, or when the job is reported as failed
     *         (the status is set to 'FAILED' before throwing)
     */
    private async refreshAmlEnvironmentStatus(environment: EnvironmentInformation): Promise<void> {
        const amlClient = (environment as AMLEnvironmentInformation).amlClient;
        if (!amlClient) {
            throw new Error('AML client not initialized!');
        }
        const newStatus = await amlClient.updateStatus(environment.status);
        switch (newStatus.toUpperCase()) {
            case 'WAITING':
            case 'QUEUED':
                environment.setStatus('WAITING');
                break;
            case 'RUNNING':
                environment.setStatus('RUNNING');
                break;
            case 'COMPLETED':
            case 'SUCCEEDED':
                environment.setStatus('SUCCEEDED');
                break;
            case 'FAILED':
                environment.setStatus('FAILED');
                throw new Error(`AML: job ${environment.envId} is failed!`);
            case 'STOPPED':
            case 'STOPPING':
                environment.setStatus('USER_CANCELED');
                break;
            default:
                environment.setStatus('UNKNOWN');
        }
    }

    /**
     * Prepares the launch script for an environment, submits it through a new
     * AMLClient, and stores the returned id, tracking URL and client on the environment.
     *
     * @param environment environment to start; treated as an AMLEnvironmentInformation
     */
    public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
        const amlEnvironment = environment as AMLEnvironmentInformation;
        const environmentLocalTempFolder = path.join(this.experimentRootDir, "environment-temp");
        // A recursive mkdir succeeds when the directory already exists, so no prior existence check is needed.
        await fs.promises.mkdir(environmentLocalTempFolder, { recursive: true });

        if (amlEnvironment.useSharedStorage) {
            // Mount the shared storage and run from its remote working root; double quotes are backslash-escaped.
            const sharedStorage = component.get<SharedStorageService>(SharedStorageService);
            const environmentRoot = sharedStorage.remoteWorkingRoot;
            const remoteMountCommand = sharedStorage.remoteMountCommand;
            amlEnvironment.command = `${remoteMountCommand} && cd ${environmentRoot} && ${amlEnvironment.command}`.replace(/"/g, `\\"`);
        } else {
            // Move the envs folder under outputs and run the command from there.
            amlEnvironment.command = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`;
        }
        // The shell command is wrapped in a small Python script that runs it via os.system.
        // NOTE(review): the command is interpolated into a single-quoted Python literal, so a command
        // containing a single quote would produce an invalid script - confirm callers never pass one.
        amlEnvironment.command = `import os\nos.system('${amlEnvironment.command}')`;

        if (this.config.deprecated && this.config.deprecated.useActiveGpu !== undefined) {
            amlEnvironment.useActiveGpu = this.config.deprecated.useActiveGpu;
        }
        amlEnvironment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu;

        await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' });
        const amlClient = new AMLClient(
            this.config.subscriptionId,
            this.config.resourceGroup,
            this.config.workspaceName,
            this.experimentId,
            this.config.computeTarget,
            this.config.dockerImage,
            'nni_script.py',
            environmentLocalTempFolder
        );
        amlEnvironment.id = await amlClient.submit();
        this.log.debug('aml: before getTrackingUrl');
        amlEnvironment.trackingUrl = await amlClient.getTrackingUrl();
        this.log.debug('aml: after getTrackingUrl');
        // Keep the client on the environment so later status refreshes and stop requests can use it.
        amlEnvironment.amlClient = amlClient;
    }

    /**
     * Requests the AML client of the given environment to stop.
     *
     * @param environment environment to stop; treated as an AMLEnvironmentInformation
     * @throws Error when the environment has no AML client
     */
    public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
        const amlClient = (environment as AMLEnvironmentInformation).amlClient;
        if (!amlClient) {
            throw new Error('AML client not initialized!');
        }
        // NOTE(review): the result of stop() is not awaited; if it returns a promise, its outcome
        // is not observed here - confirm whether fire-and-forget is intended.
        amlClient.stop();
    }
}