// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
Deshui Yu's avatar
Deshui Yu committed
3

4
import assert from 'assert';
Deshui Yu's avatar
Deshui Yu committed
5
import { randomBytes } from 'crypto';
6
7
import cpp from 'child-process-promise';
import cp from 'child_process';
8
import { ChildProcess, spawn, StdioOptions } from 'child_process';
9
10
11
12
13
import dgram from 'dgram';
import fs from 'fs';
import net from 'net';
import os from 'os';
import path from 'path';
liuzhe-lz's avatar
liuzhe-lz committed
14
import * as timersPromises from 'timers/promises';
15
import lockfile from 'lockfile';
Deshui Yu's avatar
Deshui Yu committed
16
17
import { Deferred } from 'ts-deferred';
import { Container } from 'typescript-ioc';
18
import glob from 'glob';
Deshui Yu's avatar
Deshui Yu committed
19
20

import { Database, DataStore } from './datastore';
21
import { getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
22
import { ExperimentConfig, Manager } from './manager';
23
import { ExperimentManager } from './experimentManager';
QuanluZhang's avatar
QuanluZhang committed
24
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
Deshui Yu's avatar
Deshui Yu committed
25

26
/**
 * Get the root directory of the current experiment.
 * @returns the `logDir` field of the experiment startup info
 */
function getExperimentRootDir(): string {
    return getExperimentStartupInfo().logDir;
}

30
/**
 * Get the directory that holds log files of the current experiment.
 * @returns the `log` sub-directory of the experiment root directory
 */
function getLogDir(): string {
    return path.join(getExperimentRootDir(), 'log');
}

34
/**
 * Get the log level of the current experiment.
 * @returns the `logLevel` field of the experiment startup info
 */
function getLogLevel(): string {
    return getExperimentStartupInfo().logLevel;
}

Deshui Yu's avatar
Deshui Yu committed
38
39
40
41
/**
 * Get the default database directory of the current experiment.
 * @returns the `db` sub-directory of the experiment root directory
 */
function getDefaultDatabaseDir(): string {
    const experimentRoot: string = getExperimentRootDir();
    return path.join(experimentRoot, 'db');
}

QuanluZhang's avatar
QuanluZhang committed
42
43
44
45
/**
 * Get the checkpoint directory of the current experiment.
 * @returns the `checkpoint` sub-directory of the experiment root directory
 */
function getCheckpointDir(): string {
    const experimentRoot: string = getExperimentRootDir();
    return path.join(experimentRoot, 'checkpoint');
}

46
47
48
49
/**
 * Get the path of the experiments info file.
 * @returns `<home directory>/nni-experiments/.experiment`
 */
function getExperimentsInfoPath(): string {
    const segments: string[] = [os.homedir(), 'nni-experiments', '.experiment'];
    return path.join(...segments);
}

liuzhe-lz's avatar
liuzhe-lz committed
50
51
/**
 * Create a directory and any missing parent directories (like `mkdir -p`).
 * @param dirPath path of the directory to create
 */
async function mkDirP(dirPath: string): Promise<void> {
    await fs.promises.mkdir(dirPath, { recursive: true });
}

/**
 * Synchronously create a directory and any missing parent directories (like `mkdir -p`).
 * @param dirPath path of the directory to create
 */
function mkDirPSync(dirPath: string): void {
    fs.mkdirSync(dirPath, { recursive: true });
}

liuzhe-lz's avatar
liuzhe-lz committed
58
/** Promise-based sleep: an alias of `setTimeout` from the 'timers/promises' module. */
const delay = timersPromises.setTimeout;
Deshui Yu's avatar
Deshui Yu committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96

/**
 * Convert an index to the character code of an alphanumeric character.
 * Indices 0-25 map to 'a'-'z', 26-51 to 'A'-'Z' and 52-61 to '0'-'9'.
 * @param index index, expected to be in the range [0, 62)
 * @returns the character code (not the character itself)
 */
function charMap(index: number): number {
    if (index < 26) {
        return index + 97;
    } else if (index < 52) {
        return index - 26 + 65;
    } else {
        return index - 52 + 48;
    }
}

/**
 * Generate a random string by length.
 * The first character is always a letter; the remaining ones are letters or digits.
 *
 * Each character is drawn independently from `randomBytes`. Folding all random bytes
 * into a single `number` loses precision beyond 2^53 (biasing the result for lengths
 * of about 9 and above) and overflows to Infinity for long strings.
 *
 * @param len length of string
 * @returns a random string of `len` characters, or an empty string if `len` is not positive
 */
function uniqueString(len: number): string {
    if (!(len > 0)) {
        return '';
    }
    const total: number = Math.ceil(len);
    const chars: string[] = [];
    while (chars.length < total) {
        // Ask for a few spare bytes so that rejected bytes rarely force another round.
        const bytes = randomBytes(total - chars.length + 8);
        for (let i = 0; i < bytes.length && chars.length < total; i++) {
            const base: number = chars.length === 0 ? 52 : 62;
            // Rejection sampling: only accept bytes below the largest multiple of `base`,
            // so that `byte % base` is uniformly distributed.
            if (bytes[i] < 256 - (256 % base)) {
                chars.push(String.fromCharCode(charMap(bytes[i] % base)));
            }
        }
    }

    return chars.join('');
}

97
98
99
100
/**
 * Pick a pseudo-random integer using `Math.random()`.
 * @param max exclusive upper bound
 * @returns `Math.random()` scaled by `max` and rounded down
 */
function randomInt(max: number): number {
    const scaled: number = max * Math.random();
    return Math.floor(scaled);
}

101
102
103
104
105
/**
 * Pick one element of an array at a pseudo-random position.
 * @param a the array to select from; must not be undefined
 * @returns the element at the chosen index
 */
function randomSelect<T>(a: T[]): T {
    assert(a !== undefined);

    const position: number = Math.floor(Math.random() * a.length);
    return a[position];
}
106

Deshui Yu's avatar
Deshui Yu committed
107
108
109
110
111
112
113
114
115
116
117
118
/**
 * Look up the value of a command line option in `process.argv`.
 * Only positions from index 2 up to the second-to-last are checked as option names,
 * so an option in the last position (with no value after it) is not matched.
 * @param names accepted spellings of the option, e.g. `['--log_level', '-ll']`
 * @returns the argument following the first matching name, or '' if there is none
 */
function parseArg(names: string[]): string {
    const argv: string[] = process.argv;
    const position: number = argv.findIndex((arg: string, index: number) => {
        return index >= 2 && index < argv.length - 1 && names.includes(arg);
    });

    return position === -1 ? '' : argv[position + 1];
}

119
/**
 * Get the name of the Python executable for the current platform.
 * @returns 'python' on Windows, 'python3' elsewhere
 */
function getCmdPy(): string {
    return process.platform === 'win32' ? 'python' : 'python3';
}

127
/**
 * Generate command line to start automl algorithm(s),
 * either start advisor or start a process which runs tuner and assessor
 *
 * The parameters are serialized to JSON and base64-encoded, with `searchSpace`
 * removed from a shallow copy first (the caller's object is not modified).
 *
 * @param expParams experiment startup parameters
 * @returns a command of the form `<python> -m nni --exp_params <base64>`
 */
function getMsgDispatcherCommand(expParams: ExperimentConfig): string {
    const clonedParams = Object.assign({}, expParams);
    delete clonedParams.searchSpace;
    const encodedParams: string = Buffer.from(JSON.stringify(clonedParams)).toString('base64');
    return `${getCmdPy()} -m nni --exp_params ${encodedParams}`;
}

140
141
142
143
/**
 * Generate parameter file name based on HyperParameters object
 * @param hyperParameters HyperParameters instance; its `index` must be non-negative
 * @returns 'parameter.cfg' for index 0, otherwise `parameter_<index>.cfg`
 */
function generateParamFileName(hyperParameters: HyperParameters): string {
    assert(hyperParameters !== undefined);
    assert(hyperParameters.index >= 0);

    if (hyperParameters.index === 0) {
        return 'parameter.cfg';
    }
    return `parameter_${hyperParameters.index}.cfg`;
}

Deshui Yu's avatar
Deshui Yu committed
157
158
159
160
161
162
163
164
165
/**
 * Initialize a pseudo experiment environment for unit test.
 * Must be paired with `cleanupUnitTest()`.
 *
 * Takes IoC container snapshots of the service classes, sets the experiment startup
 * info to the 'unittest' experiment, creates the log directory and removes any
 * existing `nni.sqlite` file in the default database directory.
 */
function prepareUnitTest(): void {
    Container.snapshot(Database);
    Container.snapshot(DataStore);
    Container.snapshot(TrainingService);
    Container.snapshot(Manager);
    Container.snapshot(ExperimentManager);

    const logLevel: string = parseArg(['--log_level', '-ll']);

    setExperimentStartupInfo(true, 'unittest', 8080, 'unittest', undefined, logLevel);
    mkDirPSync(getLogDir());

    const sqliteFile: string = path.join(getDefaultDatabaseDir(), 'nni.sqlite');
    try {
        fs.unlinkSync(sqliteFile);
    } catch (err) {
        // file not exists, good
    }
}

/**
 * Clean up unit test pseudo experiment.
 * Must be paired with `prepareUnitTest()`.
 *
 * Restores the IoC container snapshots taken by `prepareUnitTest()` and sets the
 * experiment startup info to the 'unittest' experiment again.
 */
function cleanupUnitTest(): void {
    Container.restore(Manager);
    Container.restore(TrainingService);
    Container.restore(DataStore);
    Container.restore(Database);
    Container.restore(ExperimentManager);

    const logLevel: string = parseArg(['--log_level', '-ll']);
    setExperimentStartupInfo(true, 'unittest', 8080, 'unittest', undefined, logLevel);
}

195
196
/** IPv4 address found by the first successful `getIPV4Address()` call; null until then. */
let cachedIpv4Address: string | null = null;

/**
 * Get IPv4 address of current machine.
 *
 * The result is cached, so the socket probing only happens until the first success.
 *
 * @returns the local address of a UDP socket "connected" to a non-existent target
 * @throws the error of `socket.address()` if the address is still unavailable after all retries
 */
async function getIPV4Address(): Promise<string> {
    if (cachedIpv4Address !== null) {
        return cachedIpv4Address;
    }

    // creates "udp connection" to a non-exist target, and get local address of the connection.
    // since udp is connectionless, this does not send actual packets.
    const socket = dgram.createSocket('udp4');
    try {
        socket.connect(1, '192.0.2.0');
        for (let i = 0; i < 10; i++) {  // wait the system to initialize "connection"
            await timersPromises.setTimeout(1);
            try {
                cachedIpv4Address = socket.address().address;
                return cachedIpv4Address;
            } catch (error) {
                /* retry */
            }
        }

        cachedIpv4Address = socket.address().address;  // if it still fails, throw the error
        return cachedIpv4Address;
    } finally {
        // always release the socket, including when the last attempt throws
        socket.close();
    }
}

QuanluZhang's avatar
QuanluZhang committed
225
226
227
228
229
230
231
/**
 * Map the `isEarlyStopped` hint to the status of a canceled job.
 * @param isEarlyStopped whether the job was stopped early rather than canceled by the user
 * @returns 'EARLY_STOPPED' when the hint is true, 'USER_CANCELED' otherwise
 */
function getJobCancelStatus(isEarlyStopped: boolean): TrialJobStatus {
    if (isEarlyStopped) {
        return 'EARLY_STOPPED';
    }
    return 'USER_CANCELED';
}

232
233
234
235
/**
 * Utility method to calculate file numbers under a directory, recursively.
 * The counting is delegated to `find | wc -l` (PowerShell on Windows).
 *
 * @param directory directory name
 * @returns a promise resolved with the number of files (0 for an empty directory),
 *          or -1 if the command output cannot be parsed; rejected if the command fails
 *          or does not finish within 5 seconds
 * @throws Error synchronously if the directory does not exist
 */
function countFilesRecursively(directory: string): Promise<number> {
    if (!fs.existsSync(directory)) {
        throw Error(`Directory ${directory} doesn't exist`);
    }

    // Quote the path so that spaces and shell metacharacters in it are taken literally.
    let cmd: string;
    if (process.platform === 'win32') {
        const quoted = `'${directory.replace(/'/g, "''")}'`;
        cmd = `powershell "Get-ChildItem -LiteralPath ${quoted} -Recurse -File | Measure-Object | %{$_.Count}"`;
    } else {
        const quoted = `'${directory.replace(/'/g, "'\\''")}'`;
        cmd = `find ${quoted} -type f | wc -l`;
    }

    return new Promise<number>((resolve, reject) => {
        const child = cp.exec(cmd, (error, stdout) => {
            clearTimeout(timeoutId);
            if (error) {
                reject(error);
                return;
            }
            const fileCount: number = parseInt(stdout, 10);
            resolve(Number.isNaN(fileCount) ? -1 : fileCount);
        });
        // Set timeout and reject the promise once reach timeout (5 seconds)
        const timeoutId = setTimeout(() => {
            child.kill();
            reject(new Error(`Timeout: path ${directory} has too many files`));
        }, 5000);
    });
}

269
270
271
272
/**
 * get the version of current package
 * @returns the `version` field of `../package.json` (relative to this module's directory),
 *          or '999.0.0-developing' if that file cannot be loaded
 */
async function getVersion(): Promise<string> {
    try {
        const pkg = await import(path.join(__dirname, '..', 'package.json'));
        return pkg.version;
    } catch (error) {
        // package.json is not available, e.g. when running from a development tree
        return '999.0.0-developing';
    }
}
281

282
283
284
/**
 * run command as ChildProcess
 *
 * On Windows the command is split on spaces into an executable and its arguments and
 * is always spawned detached without a shell, regardless of `newShell` / `isDetached`.
 * Empty tokens (from a missing argument list or repeated spaces) are dropped so that
 * no empty-string argument is passed to the child.
 *
 * @param command command line to run
 * @param stdio stdio configuration passed to `spawn`
 * @param newCwd working directory of the child process
 * @param newEnv environment of the child process
 * @param newShell whether to run the command inside a shell (ignored on Windows)
 * @param isDetached whether to detach the child process (ignored on Windows)
 * @returns the spawned child process
 */
function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newEnv: any, newShell: boolean = true, isDetached: boolean = false): ChildProcess {
    let cmd: string = command;
    let arg: string[] = [];
    if (process.platform === "win32") {
        const tokens: string[] = command.split(" ").filter((token: string) => token !== "");
        cmd = tokens.length > 0 ? tokens[0] : "";
        arg = tokens.slice(1);
        newShell = false;
        isDetached = true;
    }
    const tunerProc: ChildProcess = spawn(cmd, arg, {
        stdio,
        cwd: newCwd,
        env: newEnv,
        shell: newShell,
        detached: isDetached
    });
    return tunerProc;
}

/**
 * judge whether the process is alive
 *
 * Uses signal 0 of `process.kill()`, which only tests for the existence of the process
 * and works on every platform, so no shell command is built from `pid`.
 *
 * @param pid process id (a number or a numeric string)
 * @returns true if the process exists; false if it does not or if `pid` is not a positive integer
 */
async function isAlive(pid: any): Promise<boolean> {
    const target: number = Number(pid);
    if (!Number.isInteger(target) || target <= 0) {
        return false;
    }
    try {
        process.kill(target, 0);
        return true;
    } catch (error) {
        // EPERM means the process exists but this user is not allowed to signal it
        return (error as { code?: string }).code === 'EPERM';
    }
}

/**
 * kill process
 *
 * Sends SIGKILL through `process.kill()` (unconditional termination on every platform),
 * so no shell command is built from `pid`. Failures are ignored.
 *
 * @param pid process id (a number or a numeric string); anything that is not a positive
 *            integer is ignored, which also prevents signalling a whole process group
 */
async function killPid(pid: any): Promise<void> {
    const target: number = Number(pid);
    if (!Number.isInteger(target) || target <= 0) {
        return;
    }
    try {
        process.kill(target, 'SIGKILL');
    } catch (error) {
        // pid does not exist, do nothing here
    }
}

352
/**
 * Get the line separator of the current platform.
 * @returns "\r\n" on Windows, "\n" elsewhere
 */
function getNewLine(): string {
    return process.platform === "win32" ? "\r\n" : "\n";
}

361
362
/**
 * Use '/' to join path instead of '\' for all kinds of platform
 * @param paths path segments; empty-string segments are skipped
 * @returns the segments joined with '/', or '.' if nothing is left to join
 */
function unixPathJoin(...paths: any[]): string {
    const dir: string = paths.filter((segment: any) => segment !== '').join('/');
    return dir === '' ? '.' : dir;
}

371
372
373
374
375
/**
 * lock a file sync
 *
 * Runs `func(...args)` while holding a lock file named `<filePath>.lock.<pid>` and
 * releases the lock afterwards, also when `func` throws.
 *
 * @param func function to run while the lock is held
 * @param filePath path of the file to protect
 * @param lockOpts options passed to `lockfile.lockSync`; if `stale` is a number, any
 *                 `<filePath>.lock.*` file modified less than `stale` milliseconds ago
 *                 counts as a live lock
 * @param args arguments passed to `func`
 * @returns the return value of `func`
 * @throws Error 'File has been locked.' if a live lock of any process is found
 */
function withLockSync(func: Function, filePath: string, lockOpts: {[key: string]: any}, ...args: any): any {
    const lockName = path.join(path.dirname(filePath), path.basename(filePath) + `.lock.${process.pid}`);
    if (typeof lockOpts['stale'] === 'number') {
        const lockPath = path.join(path.dirname(filePath), path.basename(filePath) + '.lock.*');
        const lockFileNames: string[] = glob.sync(lockPath);
        const locked: boolean = lockFileNames.some((fileName) => {
            return fs.existsSync(fileName) && Date.now() - fs.statSync(fileName).mtimeMs < lockOpts['stale'];
        });
        if (locked) {
            throw new Error('File has been locked.');
        }
    }
    lockfile.lockSync(lockName, lockOpts);
    try {
        return func(...args);
    } finally {
        // release the lock even if func throws, otherwise the lock file is left behind
        lockfile.unlockSync(lockName);
    }
}

J-shang's avatar
J-shang committed
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
/**
 * Check whether a TCP connection to `host:port` can be established within one second.
 * @param host host name or address to connect to
 * @param port TCP port to connect to
 * @returns true if the connection succeeds; false on a connection error or after 1 second;
 *          rejects if creating the connection throws
 */
async function isPortOpen(host: string, port: number): Promise<boolean> {
    return new Promise<boolean>((resolve) => {
        // an exception thrown here rejects the promise
        const socket = net.createConnection(port, host);

        const finish = (isOpen: boolean): void => {
            clearTimeout(timer);
            socket.destroy();
            resolve(isOpen);
        };
        const timer = setTimeout(() => finish(false), 1000);

        socket.on('connect', () => finish(true));
        socket.on('error', () => finish(false));
    });
}

/**
 * Find the first port in [start, end] that `isPortOpen()` reports as not open.
 * @param host host to probe
 * @param start first port to try
 * @param end last port to try (inclusive)
 * @returns the first port that is not open
 * @throws Error 'no more free port' if every port in the range is open or the range is empty
 */
async function getFreePort(host: string, start: number, end: number): Promise<number> {
    for (let port = start; port <= end; port++) {
        const occupied: boolean = await isPortOpen(host, port);
        if (!occupied) {
            return port;
        }
    }
    throw new Error(`no more free port`);
}

429
430
431
432
433
/**
 * Load a module given its path, using `require`.
 * @param modulePath path of the module; its directory part is used for lookup and
 *                   its base name is what gets required
 * @returns whatever `require` returns for the module
 */
export function importModule(modulePath: string): any {
    // Put the module's directory at the front of this module's lookup paths so that
    // requiring the bare base name finds it. The entry is never removed afterwards.
    module.paths.unshift(path.dirname(modulePath));
    return require(path.basename(modulePath));
}

434
export {
    countFilesRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, getExperimentsInfoPath,
    getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, withLockSync, getFreePort, isPortOpen,
    mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine
};