utils.ts 12.6 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
Deshui Yu's avatar
Deshui Yu committed
3

4
import assert from 'assert';
Deshui Yu's avatar
Deshui Yu committed
5
import { randomBytes } from 'crypto';
6
7
import cpp from 'child-process-promise';
import cp from 'child_process';
8
import { ChildProcess, spawn, StdioOptions } from 'child_process';
9
10
11
12
13
import dgram from 'dgram';
import fs from 'fs';
import net from 'net';
import os from 'os';
import path from 'path';
liuzhe-lz's avatar
liuzhe-lz committed
14
import * as timersPromises from 'timers/promises';
15
import lockfile from 'lockfile';
Deshui Yu's avatar
Deshui Yu committed
16
17
import { Deferred } from 'ts-deferred';
import { Container } from 'typescript-ioc';
18
import glob from 'glob';
Deshui Yu's avatar
Deshui Yu committed
19
20

import { Database, DataStore } from './datastore';
21
import { getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
22
import { ExperimentConfig, Manager } from './manager';
23
import { ExperimentManager } from './experimentManager';
QuanluZhang's avatar
QuanluZhang committed
24
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
Deshui Yu's avatar
Deshui Yu committed
25

26
/**
 * Root directory of the current experiment, as recorded in the experiment startup info (`logDir`).
 */
function getExperimentRootDir(): string {
    const { logDir } = getExperimentStartupInfo();
    return logDir;
}

30
/**
 * Directory holding the experiment's log files: `<experiment root>/log`.
 */
function getLogDir(): string {
    const experimentRoot = getExperimentRootDir();
    return path.join(experimentRoot, 'log');
}

34
/**
 * Log level configured for the current experiment, read from the experiment startup info.
 */
function getLogLevel(): string {
    const startupInfo = getExperimentStartupInfo();
    return startupInfo.logLevel;
}

Deshui Yu's avatar
Deshui Yu committed
38
39
40
41
/**
 * Default database directory of the current experiment: `<experiment root>/db`.
 */
function getDefaultDatabaseDir(): string {
    const experimentRoot = getExperimentRootDir();
    return path.join(experimentRoot, 'db');
}

QuanluZhang's avatar
QuanluZhang committed
42
43
44
45
/**
 * Checkpoint directory of the current experiment: `<experiment root>/checkpoint`.
 */
function getCheckpointDir(): string {
    const experimentRoot = getExperimentRootDir();
    return path.join(experimentRoot, 'checkpoint');
}

46
47
48
49
/**
 * Path of the experiments info file: `~/nni-experiments/.experiment`.
 */
function getExperimentsInfoPath(): string {
    const experimentsHome = path.join(os.homedir(), 'nni-experiments');
    return path.join(experimentsHome, '.experiment');
}

liuzhe-lz's avatar
liuzhe-lz committed
50
51
/**
 * Create a directory and any missing parents, like `mkdir -p`.
 * An already existing directory is not an error.
 *
 * @param dirPath directory to create
 */
async function mkDirP(dirPath: string): Promise<void> {
    const recursive = true;
    await fs.promises.mkdir(dirPath, { recursive });
}

/**
 * Synchronous counterpart of `mkDirP`: create a directory and any missing parents.
 * An already existing directory is not an error.
 *
 * @param dirPath directory to create
 */
function mkDirPSync(dirPath: string): void {
    const recursive = true;
    fs.mkdirSync(dirPath, { recursive });
}

liuzhe-lz's avatar
liuzhe-lz committed
58
// Promise-based sleep: an alias of `setTimeout` from Node's `timers/promises`, so `await delay(ms)` waits `ms` milliseconds.
const delay = timersPromises.setTimeout;
Deshui Yu's avatar
Deshui Yu committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96

/**
 * Map an index in [0, 62) to the character code of an alphanumeric character.
 *
 * 0-25 map to 'a'-'z', 26-51 map to 'A'-'Z', and 52-61 map to '0'-'9'.
 * @param index index in [0, 62)
 * @returns the mapped character code
 */
function charMap(index: number): number {
    if (index < 26) {
        return index + 97;
    } else if (index < 52) {
        return index - 26 + 65;
    } else {
        return index - 52 + 48;
    }
}

/**
 * Generate a random alphanumeric string of the given length.
 *
 * The first character is always a letter (one of 52 symbols); every following character is one of
 * the 62 alphanumerics. Each character is sampled independently from cryptographically random bytes
 * with rejection sampling, so all symbols are equally likely. Packing all bytes into one JS number
 * would not work: a double only holds 53 exact bits, which skews longer strings.
 *
 * @param len length of string; a non-integer is rounded up, and 0 or less yields ''
 * @returns a random string
 */
function uniqueString(len: number): string {
    const targetLength = Math.ceil(len);
    if (!(targetLength > 0)) {
        return '';
    }

    const codes: number[] = [];
    while (codes.length < targetLength) {
        // One random byte per still-missing character; rejected bytes are redrawn on the next pass.
        for (const byte of randomBytes(targetLength - codes.length)) {
            const alphabetSize = codes.length === 0 ? 52 : 62;
            // Bytes at or above the largest multiple of alphabetSize below 256 would bias `byte % alphabetSize`.
            const limit = 256 - (256 % alphabetSize);
            if (byte < limit) {
                codes.push(charMap(byte % alphabetSize));
            }
        }
    }

    return String.fromCharCode(...codes);
}

97
98
99
100
/**
 * Non-cryptographic random integer in [0, max), based on `Math.random()`.
 *
 * @param max exclusive upper bound
 */
function randomInt(max: number): number {
    const scaled = max * Math.random();
    return Math.floor(scaled);
}

101
102
103
104
105
/**
 * Pick a uniformly random element of `a` using `Math.random()`.
 * An empty array yields `undefined` despite the `T` return type.
 *
 * @param a array to pick from; must not be undefined
 */
function randomSelect<T>(a: T[]): T {
    assert(a !== undefined);

    const index = Math.floor(Math.random() * a.length);
    return a[index];
}
106

107
/**
 * Name of the Python executable to invoke: `python` on Windows, `python3` elsewhere.
 */
function getCmdPy(): string {
    return process.platform === 'win32' ? 'python' : 'python3';
}

115
/**
 * Generate the command line that starts the AutoML algorithm(s): either an advisor, or a process
 * running tuner and assessor.
 *
 * The experiment parameters are shallow-copied without `searchSpace`, serialized to JSON,
 * base64-encoded, and passed via `--exp_params` to `<python> -m nni`.
 *
 * @param expParams experiment startup parameters
 * @returns the command line string
 */
function getMsgDispatcherCommand(expParams: ExperimentConfig): string {
    const paramsWithoutSearchSpace = { ...expParams };
    delete paramsWithoutSearchSpace.searchSpace;
    const encodedParams = Buffer.from(JSON.stringify(paramsWithoutSearchSpace)).toString('base64');
    return `${getCmdPy()} -m nni --exp_params ${encodedParams}`;
}

128
129
130
131
/**
 * Generate the parameter file name for a HyperParameters object.
 *
 * Index 0 maps to `parameter.cfg`; any other index `i` maps to `parameter_<i>.cfg`.
 *
 * @param hyperParameters HyperParameters instance; its `index` must be >= 0
 * @returns the parameter file name
 */
function generateParamFileName(hyperParameters: HyperParameters): string {
    assert(hyperParameters !== undefined);
    assert(hyperParameters.index >= 0);

    const { index } = hyperParameters;
    return index === 0 ? 'parameter.cfg' : `parameter_${index}.cfg`;
}

Deshui Yu's avatar
Deshui Yu committed
145
146
147
148
149
150
151
152
153
/**
 * Initialize a pseudo experiment environment for unit test.
 * Must be paired with `cleanupUnitTest()`.
 *
 * This function:
 * - snapshots the IoC bindings of the core services, so tests may rebind them;
 * - installs a fixed 'unittest' experiment startup info;
 * - creates the log directory;
 * - deletes any `nni.sqlite` left in the default database directory.
 */
function prepareUnitTest(): void {
    const servicesToSnapshot = [Database, DataStore, TrainingService, Manager, ExperimentManager];
    for (const service of servicesToSnapshot) {
        Container.snapshot(service);
    }

    setExperimentStartupInfo({
        port: 8080,
        experimentId: 'unittest',
        action: 'create',
        experimentsDirectory: path.join(os.homedir(), 'nni-experiments'),
        logLevel: 'info',
        foreground: false,
        urlPrefix: '',
        mode: 'unittest',
        dispatcherPipe: undefined,
    });

    mkDirPSync(getLogDir());

    const staleDatabase: string = path.join(getDefaultDatabaseDir(), 'nni.sqlite');
    try {
        fs.unlinkSync(staleDatabase);
    } catch (err) {
        // Usually the file simply does not exist yet; any unlink error is deliberately ignored.
    }
}

/**
 * Clean up unit test pseudo experiment.
 * Must be paired with `prepareUnitTest()`.
 *
 * Restores the IoC bindings of the core services that were snapshotted there.
 */
function cleanupUnitTest(): void {
    const servicesToRestore = [Manager, TrainingService, DataStore, Database, ExperimentManager];
    for (const service of servicesToRestore) {
        Container.restore(service);
    }
}

189
190
let cachedIpv4Address: string | null = null;

/**
 * Get the IPv4 address of the current machine: the local address the OS picks for an outbound route.
 *
 * A UDP socket is "connected" to 192.0.2.0 (TEST-NET-1, reserved for documentation) and its local
 * address is read back. UDP is connectionless, so no packet is actually sent. The result is cached
 * for the lifetime of the process.
 *
 * @returns the local IPv4 address
 * @throws if the socket cannot be bound or connected (e.g. no IPv4 route is available)
 */
async function getIPV4Address(): Promise<string> {
    if (cachedIpv4Address !== null) {
        return cachedIpv4Address;
    }

    const socket = dgram.createSocket('udp4');
    try {
        // Bind + connect happen asynchronously; wait until they complete.
        // The 'error' listener is essential: without one, a failed connect is emitted as an
        // unhandled 'error' event, which crashes the process instead of rejecting here.
        await new Promise<void>((resolve, reject) => {
            socket.once('error', reject);
            socket.once('connect', () => resolve());
            socket.connect(1, '192.0.2.0');
        });
        cachedIpv4Address = socket.address().address;
        return cachedIpv4Address;
    } finally {
        // Release the socket on both success and failure.
        socket.close();
    }
}

QuanluZhang's avatar
QuanluZhang committed
219
220
221
222
223
224
225
/**
 * Status to record for a canceled trial job.
 *
 * @param isEarlyStopped true when the cancellation came from early stopping
 * @returns 'EARLY_STOPPED' if early stopped, otherwise 'USER_CANCELED'
 */
function getJobCancelStatus(isEarlyStopped: boolean): TrialJobStatus {
    if (isEarlyStopped) {
        return 'EARLY_STOPPED';
    }
    return 'USER_CANCELED';
}

226
227
228
229
/**
 * Count the regular files under a directory, recursively.
 *
 * Runs `find <dir> -type f | wc -l`, or `Get-ChildItem -Recurse -File | Measure-Object` through
 * PowerShell on Windows. Gives up after 5 seconds.
 *
 * @param directory directory to scan
 * @returns the number of files (0 for an empty tree), or -1 if the command output is not a number
 * @throws Error synchronously if `directory` does not exist. The returned promise rejects if the
 *         command fails or does not finish within 5 seconds.
 */
function countFilesRecursively(directory: string): Promise<number> {
    if (!fs.existsSync(directory)) {
        throw Error(`Directory ${directory} doesn't exist`);
    }

    // Quote the path so directories containing spaces (or shell metacharacters) stay one argument.
    let cmd: string;
    if (process.platform === 'win32') {
        // PowerShell single-quoted literal: an embedded ' is escaped by doubling it.
        const quoted = `'${directory.replace(/'/g, "''")}'`;
        cmd = `powershell "Get-ChildItem -LiteralPath ${quoted} -Recurse -File | Measure-Object | %{$_.Count}"`;
    } else {
        // POSIX single-quoted word: an embedded ' becomes '\'' (close, escaped quote, reopen).
        const quoted = `'${directory.replace(/'/g, `'\\''`)}'`;
        cmd = `find ${quoted} -type f | wc -l`;
    }

    // Failures of the command propagate as rejections instead of being left unhandled.
    const counting: Promise<number> = cpp.exec(cmd).then((result) => {
        const count = parseInt(result.stdout, 10);
        // 0 is a legitimate count, so only unparseable output maps to -1.
        return Number.isNaN(count) ? -1 : count;
    });

    let timeoutId: ReturnType<typeof setTimeout> | undefined;
    const timeout = new Promise<number>((_resolve, reject) => {
        timeoutId = setTimeout(() => {
            reject(new Error(`Timeout: path ${directory} has too many files`));
        }, 5000);
    });

    return Promise.race([counting, timeout]).finally(() => {
        if (timeoutId !== undefined) {
            clearTimeout(timeoutId);
        }
    });
}

263
264
265
266
/**
 * Get the version of the current package from `../package.json` (relative to this file).
 *
 * @returns the `version` field, or '999.0.0-developing' if package.json cannot be loaded
 */
async function getVersion(): Promise<string> {
    const packageJsonPath = path.join(__dirname, '..', 'package.json');
    try {
        const pkg = await import(packageJsonPath);
        return pkg.version;
    } catch (error) {
        return '999.0.0-developing';
    }
}
275

276
277
278
/**
 * Run `command` as a child process.
 *
 * On non-Windows platforms, the whole command string is passed to `spawn` with
 * `shell: newShell` and `detached: isDetached`.
 *
 * On Windows:
 * - the command is split on single spaces into an executable and its arguments, so quoted
 *   arguments containing spaces are not supported;
 * - it is always spawned without a shell and detached, and `newShell`/`isDetached` are ignored.
 *
 * @param command command line to run
 * @param stdio stdio configuration for the child
 * @param newCwd working directory of the child
 * @param newEnv environment of the child
 * @param newShell whether to run through a shell (non-Windows only)
 * @param isDetached whether to detach the child (non-Windows only)
 * @returns the spawned child process
 */
function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newEnv: any, newShell: boolean = true, isDetached: boolean = false): ChildProcess {
    if (process.platform !== 'win32') {
        return spawn(command, [], {
            stdio,
            cwd: newCwd,
            env: newEnv,
            shell: newShell,
            detached: isDetached
        });
    }

    const executable: string = command.split(' ', 1)[0];
    const args: string[] = command.slice(executable.length + 1).split(' ');
    return spawn(executable, args, {
        stdio,
        cwd: newCwd,
        env: newEnv,
        shell: false,
        detached: true
    });
}

/**
 * Judge whether a process with the given pid is alive.
 *
 * On Windows this runs `Get-Process` through PowerShell synchronously; the process counts as alive
 * when the command prints anything. Elsewhere it runs `kill -0 <pid>`, which only probes the process.
 * Any failure of the probe counts as "not alive".
 *
 * NOTE(review): `pid` is interpolated into a shell command; callers presumably pass a numeric pid.
 *
 * @param pid process id
 * @returns true if the probe succeeds
 */
async function isAlive(pid: any): Promise<boolean> {
    if (process.platform === 'win32') {
        try {
            const output = cp.execSync(`powershell.exe Get-Process -Id ${pid} -ErrorAction SilentlyContinue`).toString();
            return output.length > 0;
        } catch (error) {
            return false;
        }
    }

    try {
        await cpp.exec(`kill -0 ${pid}`);
        return true;
    } catch (error) {
        return false;
    }
}

/**
 * Forcefully kill a process.
 *
 * Uses `taskkill /F` on Windows and `kill -9` elsewhere. Failures are ignored; typically the
 * process no longer exists.
 *
 * NOTE(review): `pid` is interpolated into a shell command; callers presumably pass a numeric pid.
 *
 * @param pid process id
 */
async function killPid(pid: any): Promise<void> {
    const killCommand: string = process.platform === 'win32'
        ? `cmd.exe /c taskkill /PID ${pid} /F`
        : `kill -9 ${pid}`;
    try {
        await cpp.exec(killCommand);
    } catch (error) {
        // pid does not exist, do nothing here
    }
}

346
/**
 * Platform line terminator: "\r\n" on Windows, "\n" elsewhere.
 */
function getNewLine(): string {
    return process.platform === 'win32' ? '\r\n' : '\n';
}

355
356
/**
 * Join path segments with '/', regardless of platform.
 *
 * Empty-string segments are skipped; no other normalization is done. If nothing remains, '.' is returned.
 *
 * @param paths path segments to join
 * @returns the joined path
 */
function unixPathJoin(...paths: any[]): string {
    const nonEmptySegments = paths.filter((segment: any) => segment !== '');
    const joined: string = nonEmptySegments.join('/');
    return joined === '' ? '.' : joined;
}

365
366
367
368
369
/**
 * Run `func(...args)` synchronously while holding a per-process lock file next to `filePath`.
 *
 * The lock file is `<filePath>.lock.<pid>`. When `lockOpts.stale` is a number (milliseconds), the
 * call first refuses to proceed if any `<filePath>.lock.*` file was modified less than `stale` ms ago.
 * The lock is released even if `func` throws.
 *
 * @param func function to run under the lock
 * @param filePath file being protected
 * @param lockOpts options passed through to `lockfile.lockSync`
 * @param args arguments forwarded to `func`
 * @returns whatever `func` returns
 * @throws Error('File has been locked.') if a fresh lock file exists; otherwise errors from `lockfile` or `func`
 */
function withLockSync(func: Function, filePath: string, lockOpts: {[key: string]: any}, ...args: any): any {
    const lockName = path.join(path.dirname(filePath), path.basename(filePath) + `.lock.${process.pid}`);
    if (typeof lockOpts['stale'] === 'number') {
        const staleMs: number = lockOpts['stale'];
        const isFresh = (fileName: string): boolean => {
            try {
                return Date.now() - fs.statSync(fileName).mtimeMs < staleMs;
            } catch (error) {
                // The lock file vanished between glob and stat (its owner released it): not a live lock.
                if ((error as NodeJS.ErrnoException).code === 'ENOENT') {
                    return false;
                }
                throw error;
            }
        };
        const lockPattern = path.join(path.dirname(filePath), path.basename(filePath) + '.lock.*');
        const lockFileNames: string[] = glob.sync(lockPattern);
        if (lockFileNames.some(isFresh)) {
            throw new Error('File has been locked.');
        }
    }
    lockfile.lockSync(lockName, lockOpts);
    try {
        return func(...args);
    } finally {
        // Always release; otherwise an exception from `func` would leave the lock behind until it goes stale.
        lockfile.unlockSync(lockName);
    }
}

J-shang's avatar
J-shang committed
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
/**
 * Check whether something accepts TCP connections at `host:port`.
 *
 * Resolves true once a connection is established. Resolves false if the connection errors or
 * nothing answers within 1 second. Rejects only if `net.createConnection` throws synchronously.
 *
 * @param host host to probe
 * @param port port to probe
 */
async function isPortOpen(host: string, port: number): Promise<boolean> {
    return new Promise<boolean>((resolve, reject) => {
        try {
            const socket = net.createConnection(port, host);
            const finish = (open: boolean): void => {
                clearTimeout(timer);
                socket.destroy();
                resolve(open);
            };
            const timer = setTimeout(() => finish(false), 1000);
            socket.on('connect', () => finish(true));
            socket.on('error', () => finish(false));
        } catch (error) {
            reject(error);
        }
    });
}

/**
 * Find the first port in [start, end] on which nothing accepts connections at `host`.
 *
 * Ports are probed in ascending order with `isPortOpen`.
 *
 * @throws Error('no more free port') if every port in the range is in use
 */
async function getFreePort(host: string, start: number, end: number): Promise<number> {
    let port = start;
    // Written as !(port > end) so that a NaN bound behaves like the equivalent recursive scan.
    while (!(port > end)) {
        if (!(await isPortOpen(host, port))) {
            return port;
        }
        port += 1;
    }
    throw new Error(`no more free port`);
}

423
424
425
426
427
/**
 * Load a CommonJS module from a file path.
 *
 * The module's directory is prepended to this module's `module.paths` and stays there for later
 * `require` calls. The module is then required by its base name.
 *
 * @param modulePath path to the module file
 * @returns the module's exports
 */
export function importModule(modulePath: string): any {
    const moduleDir = path.dirname(modulePath);
    const moduleName = path.basename(modulePath);
    module.paths.unshift(moduleDir);
    return require(moduleName);
}

428
export {
    // experiment directories and settings
    getExperimentRootDir, getLogDir, getLogLevel, getDefaultDatabaseDir, getCheckpointDir, getExperimentsInfoPath,
    // filesystem helpers
    mkDirP, mkDirPSync, countFilesRecursively, unixPathJoin, withLockSync, generateParamFileName,
    // randomness
    uniqueString, randomInt, randomSelect,
    // processes and command lines
    getCmdPy, getMsgDispatcherCommand, getTunerProc, isAlive, killPid, getNewLine,
    // networking
    getIPV4Address, getFreePort, isPortOpen,
    // miscellaneous
    delay, getJobCancelStatus, getVersion,
    // unit test support
    prepareUnitTest, cleanupUnitTest
};