amlEnvironmentService.ts 5.67 KB
Newer Older
SparkSnail's avatar
SparkSnail committed
1
2
3
4
5
6
7
8
9
10
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import * as fs from 'fs';
import * as path from 'path';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
11
import { getExperimentRootDir } from '../../../common/utils';
liuzhe-lz's avatar
liuzhe-lz committed
12
import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig';
SparkSnail's avatar
SparkSnail committed
13
import { validateCodeDir } from '../../common/util';
14
import { AMLClient } from '../aml/amlClient';
liuzhe-lz's avatar
liuzhe-lz committed
15
import { AMLEnvironmentInformation } from '../aml/amlConfig';
16
import { EnvironmentInformation, EnvironmentService } from '../environment';
17
18
import { EventEmitter } from "events";
import { AMLCommandChannel } from '../channels/amlCommandChannel';
19
import { SharedStorageService } from '../sharedStorage'
SparkSnail's avatar
SparkSnail committed
20

liuzhe-lz's avatar
liuzhe-lz committed
21
// Type of the value produced by `flattenConfig(config, 'aml')` in AMLEnvironmentService:
// the experiment config combined with the AML training-service fields
// (subscriptionId, resourceGroup, workspaceName, computeTarget, dockerImage, ...).
interface FlattenAmlConfig extends ExperimentConfig, AmlConfig { }
SparkSnail's avatar
SparkSnail committed
22
23

/**
 * Environment service for AML: submits each NNI environment as an AML job,
 * collects job status from the AML cluster and mirrors it onto the local
 * EnvironmentInformation, and stops jobs on request.
 */
@component.Singleton
export class AMLEnvironmentService extends EnvironmentService {

    private readonly log: Logger = getLogger();
    private readonly experimentId: string;
    private readonly experimentRootDir: string;
    private readonly config: FlattenAmlConfig;

    /**
     * @param config Experiment configuration. Its 'aml' section is flattened into
     *               `this.config`, and the trial code directory is passed to validateCodeDir.
     */
    constructor(config: ExperimentConfig) {
        super();
        this.experimentId = getExperimentId();
        this.experimentRootDir = getExperimentRootDir();
        this.config = flattenConfig(config, 'aml');
        validateCodeDir(this.config.trialCodeDirectory);
    }

    public get hasStorageService(): boolean {
        return false;
    }

    public initCommandChannel(eventEmitter: EventEmitter): void {
        this.commandChannel = new AMLCommandChannel(eventEmitter);
    }

    public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation {
        return new AMLEnvironmentInformation(envId, envName);
    }

    public get getName(): string {
        return 'aml';
    }

    /**
     * Queries AML for the latest status of every environment and maps it onto the
     * NNI environment status (WAITING / RUNNING / SUCCEEDED / FAILED / USER_CANCELED / UNKNOWN).
     *
     * All environments are refreshed concurrently, and the returned promise settles only
     * after every refresh has finished. A problem with one environment (no AML client yet,
     * an error from the AML client, or a job reported as FAILED) is logged and does not stop
     * the others from being refreshed. This method itself never rejects.
     */
    public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
        await Promise.all(environments.map(async (environment) => {
            const amlClient = (environment as AMLEnvironmentInformation).amlClient;
            if (!amlClient) {
                // The client is attached at the end of startEnvironment().
                this.log.error('AML client not initialized!');
                return;
            }

            let newStatus: string;
            try {
                newStatus = await amlClient.updateStatus(environment.status);
            } catch (error) {
                this.log.error(`AML: failed to update status of job ${environment.envId}: ${String(error)}`);
                return;
            }

            switch (newStatus.toUpperCase()) {
                case 'WAITING':
                case 'QUEUED':
                    environment.setStatus('WAITING');
                    break;
                case 'RUNNING':
                    environment.setStatus('RUNNING');
                    break;
                case 'COMPLETED':
                case 'SUCCEEDED':
                    environment.setStatus('SUCCEEDED');
                    break;
                case 'FAILED':
                    environment.setStatus('FAILED');
                    this.log.error(`AML: job ${environment.envId} is failed!`);
                    break;
                case 'STOPPED':
                case 'STOPPING':
                    environment.setStatus('USER_CANCELED');
                    break;
                default:
                    environment.setStatus('UNKNOWN');
            }
        }));
    }

    /**
     * Writes an `nni_script.py` launcher for the environment into
     * `<experimentRootDir>/environment-temp`, submits it as an AML job, and attaches
     * the job id, tracking URL and AML client to the environment.
     *
     * Side effect: `environment.command` is replaced by the generated Python launcher source.
     */
    public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
        const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
        const environmentLocalTempFolder = path.join(this.experimentRootDir, "environment-temp");
        // With recursive: true, mkdir does not fail when the folder already exists.
        await fs.promises.mkdir(environmentLocalTempFolder, { recursive: true });

        let shellCommand: string;
        if (amlEnvironment.useSharedStorage) {
            const sharedStorageService = component.get<SharedStorageService>(SharedStorageService);
            const environmentRoot = sharedStorageService.remoteWorkingRoot;
            const remoteMountCommand = sharedStorageService.remoteMountCommand;
            shellCommand = `${remoteMountCommand} && cd ${environmentRoot} && ${amlEnvironment.command}`;
        } else {
            shellCommand = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`;
        }
        // A JSON string literal is also a valid Python string literal. Quotes, backslashes
        // and newlines in the command are therefore escaped, rather than terminating or
        // corrupting the os.system(...) argument as a hand-built '...' literal would.
        amlEnvironment.command = `import os\nos.system(${JSON.stringify(shellCommand)})`;

        if (this.config.deprecated && this.config.deprecated.useActiveGpu !== undefined) {
            amlEnvironment.useActiveGpu = this.config.deprecated.useActiveGpu;
        }
        amlEnvironment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu;

        await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' });
        const amlClient = new AMLClient(
            this.config.subscriptionId,
            this.config.resourceGroup,
            this.config.workspaceName,
            this.experimentId,
            this.config.computeTarget,
            this.config.dockerImage,
            'nni_script.py',
            environmentLocalTempFolder
        );
        amlEnvironment.id = await amlClient.submit();
        this.log.debug('aml: before getTrackingUrl');
        amlEnvironment.trackingUrl = await amlClient.getTrackingUrl();
        this.log.debug('aml: after getTrackingUrl');
        amlEnvironment.amlClient = amlClient;
    }

    /**
     * Requests the AML job backing this environment to stop.
     *
     * @throws Error if the environment has no AML client (it was never started).
     */
    public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
        const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
        const amlClient = amlEnvironment.amlClient;
        if (!amlClient) {
            throw new Error('AML client not initialized!');
        }
        // NOTE(review): stop() is not awaited. If it returns a Promise, its result and any
        // rejection are not observed here. Confirm AMLClient.stop's contract before awaiting it.
        amlClient.stop();
    }
}