Unverified commit 69cae211, authored by Chi Song, committed by GitHub

Support Windows as remote node. (#2431)

parent 1180599a
...@@ -2,18 +2,54 @@ ...@@ -2,18 +2,54 @@
NNI can run one experiment on multiple remote machines through SSH, called `remote` mode. It's like a lightweight training platform. In this mode, NNI can be started from your computer, and dispatch trials to remote machines in parallel. NNI can run one experiment on multiple remote machines through SSH, called `remote` mode. It's like a lightweight training platform. In this mode, NNI can be started from your computer, and dispatch trials to remote machines in parallel.
## Remote machine requirements Remote machines can run `Linux`, `Windows 10`, or `Windows Server 2019`.
* It only supports Linux as remote machines, and [linux part in system specification](../Tutorial/InstallationLinux.md) is same as NNI local mode. ## Requirements
* Follow [installation](../Tutorial/InstallationLinux.md) to install NNI on each machine. * Make sure the default environment of remote machines meets requirements of your trial code. If the default environment does not meet the requirements, the setup script can be added into `command` field of NNI config.
* Make sure remote machines meet environment requirements of your trial code. If the default environment does not meet the requirements, the setup script can be added into `command` field of NNI config.
* Make sure remote machines can be accessed through SSH from the machine which runs `nnictl` command. It supports both password and key authentication of SSH. For advanced usages, please refer to [machineList part of configuration](../Tutorial/ExperimentConfig.md). * Make sure remote machines can be accessed through SSH from the machine which runs `nnictl` command. It supports both password and key authentication of SSH. For advanced usages, please refer to [machineList part of configuration](../Tutorial/ExperimentConfig.md).
* Make sure the NNI version on each machine is consistent. * Make sure the NNI version on each machine is consistent.
* If you want to use remote Linux and Windows machines together, make sure the trial command is compatible with both OSes. For example, the default Python 3.x executable is called `python3` on Linux and `python` on Windows.
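The executable-name difference noted in the requirements can be sketched as a small helper. This is only an illustration: `trial_python_executable` and `build_trial_command` are hypothetical names, not part of NNI, and the mapping simply mirrors the `python3`/`python` convention described above.

```python
import sys


def trial_python_executable(platform: str = sys.platform) -> str:
    """Return the usual Python 3 executable name for a platform.

    Hypothetical helper: 'win32' is the value of sys.platform on Windows,
    where the default installation provides `python`; elsewhere `python3`
    is assumed.
    """
    return "python" if platform == "win32" else "python3"


def build_trial_command(script: str, platform: str = sys.platform) -> str:
    """Compose a trial command line for the given platform (illustrative only)."""
    return f"{trial_python_executable(platform)} {script}"
```

A command built this way only helps when every remote machine in one experiment shares the same OS; with mixed Linux and Windows machines, a single command string still has to work on both.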
### Linux
* Follow [installation](../Tutorial/InstallationLinux.md) to install NNI on the remote machine.
### Windows
* Follow [installation](../Tutorial/InstallationWin.md) to install NNI on the remote machine.
* Install and start `OpenSSH Server`.
1. Open `Settings` app on Windows.
2. Click `Apps`, then click `Optional features`.
3. Click `Add a feature`, search and select `OpenSSH Server`, and then click `Install`.
4. Once it's installed, run the commands below to start the service and set it to start automatically.
```bat
sc config sshd start=auto
net start sshd
```
* Make sure the remote account is an administrator, so that it can stop running trials.
* Make sure there is no welcome message beyond the default one, since extra output causes the ssh2 library in Node.js to fail. For example, if you're using Data Science VM on Azure, you need to remove the extra echo commands in `C:\dsvm\tools\setup\welcome.bat`.
Output like the following is fine when opening a new command window.
```text
Microsoft Windows [Version 10.0.17763.1192]
(c) 2018 Microsoft Corporation. All rights reserved.
(py37_default) C:\Users\AzureUser>
```
## Run an experiment ## Run an experiment
e.g. there are three machines, which can be logged in with username and password. e.g. there are three machines, which can be logged in with username and password.
......
# Install on Windows # Install on Windows
## Installation ## Prerequisites
Anaconda or Miniconda is highly recommended to manage multiple Python environments. * Python 3.5 (or above) 64-bit. [Anaconda](https://www.anaconda.com/products/individual) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html) is highly recommended to manage multiple Python environments on Windows.
### Install NNI through pip * If it's a newly installed Python environment, you need to install [Microsoft C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/) to build NNI dependencies such as `scikit-learn`.
Prerequisites: `python 64-bit >= 3.5` ```bat
pip install cython wheel
```
```bash * git for verifying installation.
python -m pip install --upgrade nni
```
### Install NNI through source code ## Install NNI
If you are interested in special or the latest code versions, you can install NNI through source code. In most cases, you can install and upgrade NNI from pip package. It's easy and fast.
Prerequisites: `python 64-bit >=3.5`, `git`, `PowerShell`. If you are interested in special or the latest code versions, you can install NNI through source code.
```bash If you want to contribute to NNI, refer to [setup development environment](SetupNniDeveloperEnvironment.md).
git clone -b v1.5 https://github.com/Microsoft/nni.git
cd nni * From pip package
powershell -ExecutionPolicy Bypass -file install.ps1
``` ```bat
python -m pip install --upgrade nni
```
* From source code
```bat
git clone -b v1.5 https://github.com/Microsoft/nni.git
cd nni
powershell -ExecutionPolicy Bypass -file install.ps1
```
## Verify installation ## Verify installation
The following example is built on TensorFlow 1.x. Make sure **TensorFlow 1.x is used** when running it. The following example is built on TensorFlow 1.x. Make sure **TensorFlow 1.x is used** when running it.
* Download the examples by cloning the source code. * Clone the source code, which contains the examples.
```bash ```bat
git clone -b v1.5 https://github.com/Microsoft/nni.git git clone -b v1.5 https://github.com/Microsoft/nni.git
``` ```
* Run the MNIST example. * Run the MNIST example.
```bash ```bat
nnictl create --config nni\examples\trials\mnist-tfv1\config_windows.yml nnictl create --config nni\examples\trials\mnist-tfv1\config_windows.yml
``` ```
Note: for other examples you need to change trial command `python3` to `python` in each example YAML, if python3 is called through `python` on your machine. Note: If you are familiar with other frameworks, you can choose the corresponding example under `examples\trials`. You need to change the trial command `python3` to `python` in each example YAML, since the default installation provides `python.exe`, not `python3.exe`.
* Wait for the message `INFO: Successfully started experiment!` in the command line. This message indicates that your experiment has been successfully started. You can explore the experiment using the `Web UI url`. * Wait for the message `INFO: Successfully started experiment!` in the command line. This message indicates that your experiment has been successfully started. You can explore the experiment using the `Web UI url`.
...@@ -112,18 +122,20 @@ If there is a stderr file, please check it. Two possible cases are: ...@@ -112,18 +122,20 @@ If there is a stderr file, please check it. Two possible cases are:
* forgetting to install experiment dependencies such as TensorFlow, Keras and so on. * forgetting to install experiment dependencies such as TensorFlow, Keras and so on.
### Fail to use BOHB on Windows ### Fail to use BOHB on Windows
Make sure a C++ 14.0 compiler is installed when trying to run `nnictl package install --name=BOHB` to install the dependencies. Make sure a C++ 14.0 compiler is installed when trying to run `nnictl package install --name=BOHB` to install the dependencies.
### Not supported tuner on Windows ### Not supported tuner on Windows
SMAC is not supported currently; for the specific reason refer to this [GitHub issue](https://github.com/automl/SMAC3/issues/483). SMAC is not supported currently; for the specific reason refer to this [GitHub issue](https://github.com/automl/SMAC3/issues/483).
### Use a Windows server as a remote worker ### Use Windows as a remote worker
Currently, you can't.
Note: Refer to [Remote Machine mode](../TrainingService/RemoteMachineMode.md).
* If an error like `Segmentation fault` is encountered, please refer to the [FAQ](FAQ.md) ### Segmentation fault (core dumped) when installing
Refer to [FAQ](FAQ.md).
## Further reading ## Further reading
......
...@@ -84,7 +84,11 @@ class SendMetrics(keras.callbacks.Callback): ...@@ -84,7 +84,11 @@ class SendMetrics(keras.callbacks.Callback):
Run on end of each epoch Run on end of each epoch
''' '''
LOG.debug(logs) LOG.debug(logs)
nni.report_intermediate_result(logs["val_acc"]) # TensorFlow 2.0 API reference claims the key is `val_acc`, but in fact it's `val_accuracy`
if 'val_acc' in logs:
nni.report_intermediate_result(logs['val_acc'])
else:
nni.report_intermediate_result(logs['val_accuracy'])
def train(args, params): def train(args, params):
''' '''
......
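The key fallback introduced in the hunk above can be expressed as a standalone helper. This is a minimal sketch, not the code from the commit: `pick_val_accuracy` is a hypothetical name, and it raises an explicit `KeyError` when neither key is present, which matches what the `else` branch in the diff would do implicitly.

```python
def pick_val_accuracy(logs):
    """Return the validation accuracy from a Keras `logs` dict.

    Checks `val_acc` first, then `val_accuracy`, in the same order as the
    diff. Hypothetical helper for illustration only.
    """
    logs = logs or {}
    for key in ("val_acc", "val_accuracy"):
        if key in logs:
            return logs[key]
    raise KeyError("neither 'val_acc' nor 'val_accuracy' found in logs")
```

Inside a callback, the result would then be passed to `nni.report_intermediate_result`, as the hunk does.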
...@@ -86,7 +86,11 @@ class SendMetrics(keras.callbacks.Callback): ...@@ -86,7 +86,11 @@ class SendMetrics(keras.callbacks.Callback):
Run on end of each epoch Run on end of each epoch
''' '''
LOG.debug(logs) LOG.debug(logs)
nni.report_intermediate_result(logs["val_acc"]) # TensorFlow 2.0 API reference claims the key is `val_acc`, but in fact it's `val_accuracy`
if 'val_acc' in logs:
nni.report_intermediate_result(logs['val_acc'])
else:
nni.report_intermediate_result(logs['val_accuracy'])
def train(args, params): def train(args, params):
''' '''
......
...@@ -152,7 +152,11 @@ class SendMetrics(keras.callbacks.Callback): ...@@ -152,7 +152,11 @@ class SendMetrics(keras.callbacks.Callback):
if logs is None: if logs is None:
logs = dict() logs = dict()
logger.debug(logs) logger.debug(logs)
nni.report_intermediate_result(logs["val_accuracy"]) # TensorFlow 2.0 API reference claims the key is `val_acc`, but in fact it's `val_accuracy`
if 'val_acc' in logs:
nni.report_intermediate_result(logs['val_acc'])
else:
nni.report_intermediate_result(logs['val_accuracy'])
# Training # Training
......
...@@ -152,9 +152,11 @@ class SendMetrics(keras.callbacks.Callback): ...@@ -152,9 +152,11 @@ class SendMetrics(keras.callbacks.Callback):
if logs is None: if logs is None:
logs = dict() logs = dict()
logger.debug(logs) logger.debug(logs)
# accuracy key for keras 2.2.2: val_acc # TensorFlow 2.0 API reference claims the key is `val_acc`, but in fact it's `val_accuracy`
# for keras 2.3.1: val_accuracy if 'val_acc' in logs:
nni.report_intermediate_result(logs["val_accuracy"]) nni.report_intermediate_result(logs['val_acc'])
else:
nni.report_intermediate_result(logs['val_accuracy'])
# Training # Training
......
...@@ -148,12 +148,13 @@ cmd /c $NNI_YARN ...@@ -148,12 +148,13 @@ cmd /c $NNI_YARN
cmd /c $NNI_YARN build cmd /c $NNI_YARN build
Copy-Item config -Destination .\dist\ -Recurse -Force Copy-Item config -Destination .\dist\ -Recurse -Force
# Building WebUI # Building WebUI
# office-ui-fabric-react needs a longer timeout. 180000 is in ms, i.e. 180 seconds, longer than the default 30 seconds.
cd ..\webui cd ..\webui
cmd /c $NNI_YARN cmd /c $NNI_YARN --network-timeout 180000
cmd /c $NNI_YARN build cmd /c $NNI_YARN build
# Building NasUI # Building NasUI
cd ..\nasui cd ..\nasui
cmd /c $NNI_YARN cmd /c $NNI_YARN --network-timeout 180000
cmd /c $NNI_YARN build cmd /c $NNI_YARN build
cd ..\.. cd ..\..
......
...@@ -22,7 +22,7 @@ import { HyperParameters, TrainingService, TrialJobStatus } from './trainingServ ...@@ -22,7 +22,7 @@ import { HyperParameters, TrainingService, TrialJobStatus } from './trainingServ
function getExperimentRootDir(): string { function getExperimentRootDir(): string {
return getExperimentStartupInfo() return getExperimentStartupInfo()
.getLogDir(); .getLogDir();
} }
function getLogDir(): string { function getLogDir(): string {
...@@ -31,7 +31,7 @@ function getLogDir(): string { ...@@ -31,7 +31,7 @@ function getLogDir(): string {
function getLogLevel(): string { function getLogLevel(): string {
return getExperimentStartupInfo() return getExperimentStartupInfo()
.getLogLevel(); .getLogLevel();
} }
function getDefaultDatabaseDir(): string { function getDefaultDatabaseDir(): string {
...@@ -113,11 +113,16 @@ function uniqueString(len: number): string { ...@@ -113,11 +113,16 @@ function uniqueString(len: number): string {
return String.fromCharCode(...codes); return String.fromCharCode(...codes);
} }
function randomInt(max: number): number {
return Math.floor(Math.random() * max);
}
function randomSelect<T>(a: T[]): T { function randomSelect<T>(a: T[]): T {
assert(a !== undefined); assert(a !== undefined);
return a[Math.floor(Math.random() * a.length)]; return a[Math.floor(Math.random() * a.length)];
} }
function parseArg(names: string[]): string { function parseArg(names: string[]): string {
if (process.argv.length >= 4) { if (process.argv.length >= 4) {
for (let i: number = 2; i < process.argv.length - 1; i++) { for (let i: number = 2; i < process.argv.length - 1; i++) {
...@@ -132,7 +137,7 @@ function parseArg(names: string[]): string { ...@@ -132,7 +137,7 @@ function parseArg(names: string[]): string {
function getCmdPy(): string { function getCmdPy(): string {
let cmd = 'python3'; let cmd = 'python3';
if(process.platform === 'win32'){ if (process.platform === 'win32') {
cmd = 'python'; cmd = 'python';
} }
return cmd; return cmd;
...@@ -160,7 +165,7 @@ function generateParamFileName(hyperParameters: HyperParameters): string { ...@@ -160,7 +165,7 @@ function generateParamFileName(hyperParameters: HyperParameters): string {
assert(hyperParameters.index >= 0); assert(hyperParameters.index >= 0);
let paramFileName: string; let paramFileName: string;
if(hyperParameters.index == 0) { if (hyperParameters.index == 0) {
paramFileName = 'parameter.cfg'; paramFileName = 'parameter.cfg';
} else { } else {
paramFileName = `parameter_${hyperParameters.index}.cfg` paramFileName = `parameter_${hyperParameters.index}.cfg`
...@@ -211,9 +216,9 @@ function getIPV4Address(): string { ...@@ -211,9 +216,9 @@ function getIPV4Address(): string {
return cachedipv4Address; return cachedipv4Address;
} }
if(os.networkInterfaces().eth0) { if (os.networkInterfaces().eth0) {
for(const item of os.networkInterfaces().eth0) { for (const item of os.networkInterfaces().eth0) {
if(item.family === 'IPv4') { if (item.family === 'IPv4') {
cachedipv4Address = item.address; cachedipv4Address = item.address;
return cachedipv4Address; return cachedipv4Address;
} }
...@@ -225,14 +230,6 @@ function getIPV4Address(): string { ...@@ -225,14 +230,6 @@ function getIPV4Address(): string {
throw Error('getIPV4Address() failed because no valid IPv4 address found.') throw Error('getIPV4Address() failed because no valid IPv4 address found.')
} }
function getRemoteTmpDir(osType: string): string {
if (osType == 'linux') {
return '/tmp';
} else {
throw Error(`remote OS ${osType} not supported`);
}
}
/** /**
* Get the status of canceled jobs according to the hint isEarlyStopped * Get the status of canceled jobs according to the hint isEarlyStopped
*/ */
...@@ -245,7 +242,7 @@ function getJobCancelStatus(isEarlyStopped: boolean): TrialJobStatus { ...@@ -245,7 +242,7 @@ function getJobCancelStatus(isEarlyStopped: boolean): TrialJobStatus {
* @param directory directory name * @param directory directory name
*/ */
function countFilesRecursively(directory: string): Promise<number> { function countFilesRecursively(directory: string): Promise<number> {
if(!fs.existsSync(directory)) { if (!fs.existsSync(directory)) {
throw Error(`Directory ${directory} doesn't exist`); throw Error(`Directory ${directory} doesn't exist`);
} }
...@@ -261,13 +258,13 @@ function countFilesRecursively(directory: string): Promise<number> { ...@@ -261,13 +258,13 @@ function countFilesRecursively(directory: string): Promise<number> {
let fileCount: number = -1; let fileCount: number = -1;
let cmd: string; let cmd: string;
if(process.platform === "win32") { if (process.platform === "win32") {
cmd = `powershell "Get-ChildItem -Path ${directory} -Recurse -File | Measure-Object | %{$_.Count}"` cmd = `powershell "Get-ChildItem -Path ${directory} -Recurse -File | Measure-Object | %{$_.Count}"`
} else { } else {
cmd = `find ${directory} -type f | wc -l`; cmd = `find ${directory} -type f | wc -l`;
} }
cpp.exec(cmd).then((result) => { cpp.exec(cmd).then((result) => {
if(result.stdout && parseInt(result.stdout)) { if (result.stdout && parseInt(result.stdout)) {
fileCount = parseInt(result.stdout); fileCount = parseInt(result.stdout);
} }
deferred.resolve(fileCount); deferred.resolve(fileCount);
...@@ -280,20 +277,20 @@ function countFilesRecursively(directory: string): Promise<number> { ...@@ -280,20 +277,20 @@ function countFilesRecursively(directory: string): Promise<number> {
function validateFileName(fileName: string): boolean { function validateFileName(fileName: string): boolean {
const pattern: string = '^[a-z0-9A-Z._-]+$'; const pattern: string = '^[a-z0-9A-Z._-]+$';
const validateResult = fileName.match(pattern); const validateResult = fileName.match(pattern);
if(validateResult) { if (validateResult) {
return true; return true;
} }
return false; return false;
} }
async function validateFileNameRecursively(directory: string): Promise<boolean> { async function validateFileNameRecursively(directory: string): Promise<boolean> {
if(!fs.existsSync(directory)) { if (!fs.existsSync(directory)) {
throw Error(`Directory ${directory} doesn't exist`); throw Error(`Directory ${directory} doesn't exist`);
} }
const fileNameArray: string[] = fs.readdirSync(directory); const fileNameArray: string[] = fs.readdirSync(directory);
let result = true; let result = true;
for(const name of fileNameArray){ for (const name of fileNameArray) {
const fullFilePath: string = path.join(directory, name); const fullFilePath: string = path.join(directory, name);
try { try {
// validate file names and directory names // validate file names and directory names
...@@ -301,14 +298,14 @@ async function validateFileNameRecursively(directory: string): Promise<boolean> ...@@ -301,14 +298,14 @@ async function validateFileNameRecursively(directory: string): Promise<boolean>
if (fs.lstatSync(fullFilePath).isDirectory()) { if (fs.lstatSync(fullFilePath).isDirectory()) {
result = result && await validateFileNameRecursively(fullFilePath); result = result && await validateFileNameRecursively(fullFilePath);
} }
if(!result) { if (!result) {
return Promise.reject(new Error(`file name in ${fullFilePath} is not valid!`)); return Promise.reject(new Error(`file name in ${fullFilePath} is not valid!`));
} }
} catch(error) { } catch (error) {
return Promise.reject(error); return Promise.reject(error);
} }
} }
return Promise.resolve(result); return Promise.resolve(result);
} }
/** /**
...@@ -316,9 +313,9 @@ async function validateFileNameRecursively(directory: string): Promise<boolean> ...@@ -316,9 +313,9 @@ async function validateFileNameRecursively(directory: string): Promise<boolean>
*/ */
async function getVersion(): Promise<string> { async function getVersion(): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>(); const deferred: Deferred<string> = new Deferred<string>();
import(path.join(__dirname, '..', 'package.json')).then((pkg)=>{ import(path.join(__dirname, '..', 'package.json')).then((pkg) => {
deferred.resolve(pkg.version); deferred.resolve(pkg.version);
}).catch((error)=>{ }).catch((error) => {
deferred.reject(error); deferred.reject(error);
}); });
return deferred.promise; return deferred.promise;
...@@ -331,9 +328,9 @@ function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newE ...@@ -331,9 +328,9 @@ function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newE
let cmd: string = command; let cmd: string = command;
let arg: string[] = []; let arg: string[] = [];
let newShell: boolean = true; let newShell: boolean = true;
if(process.platform === "win32"){ if (process.platform === "win32") {
cmd = command.split(" ", 1)[0]; cmd = command.split(" ", 1)[0];
arg = command.substr(cmd.length+1).split(" "); arg = command.substr(cmd.length + 1).split(" ");
newShell = false; newShell = false;
} }
const tunerProc: ChildProcess = spawn(cmd, arg, { const tunerProc: ChildProcess = spawn(cmd, arg, {
...@@ -383,7 +380,7 @@ async function killPid(pid: any): Promise<void> { ...@@ -383,7 +380,7 @@ async function killPid(pid: any): Promise<void> {
if (process.platform === "win32") { if (process.platform === "win32") {
await cpp.exec(`cmd.exe /c taskkill /PID ${pid} /F`); await cpp.exec(`cmd.exe /c taskkill /PID ${pid} /F`);
} }
else{ else {
await cpp.exec(`kill -9 ${pid}`); await cpp.exec(`kill -9 ${pid}`);
} }
} catch (error) { } catch (error) {
...@@ -397,7 +394,7 @@ function getNewLine(): string { ...@@ -397,7 +394,7 @@ function getNewLine(): string {
if (process.platform === "win32") { if (process.platform === "win32") {
return "\r\n"; return "\r\n";
} }
else{ else {
return "\n"; return "\n";
} }
} }
...@@ -412,6 +409,8 @@ function unixPathJoin(...paths: any[]): string { ...@@ -412,6 +409,8 @@ function unixPathJoin(...paths: any[]): string {
return dir; return dir;
} }
export {countFilesRecursively, validateFileNameRecursively, getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, export {
countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin,
mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine }; mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine
};
...@@ -266,7 +266,7 @@ class NNIManager implements Manager { ...@@ -266,7 +266,7 @@ class NNIManager implements Manager {
const delay1: Promise<{}> = new Promise((resolve: Function, reject: Function): void => { const delay1: Promise<{}> = new Promise((resolve: Function, reject: Function): void => {
timeoutId = setTimeout( timeoutId = setTimeout(
() => { reject(new Error('TrainingService setClusterMetadata timeout. Please check your config file.')); }, () => { reject(new Error('TrainingService setClusterMetadata timeout. Please check your config file.')); },
10000); 30000);
}); });
await Promise.race([delay1, this.trainingService.setClusterMetadata(key, value)]).finally(() => { await Promise.race([delay1, this.trainingService.setClusterMetadata(key, value)]).finally(() => {
clearTimeout(timeoutId); clearTimeout(timeoutId);
...@@ -368,7 +368,7 @@ class NNIManager implements Manager { ...@@ -368,7 +368,7 @@ class NNIManager implements Manager {
CUDA_VISIBLE_DEVICES: this.getGpuEnvvarValue() CUDA_VISIBLE_DEVICES: this.getGpuEnvvarValue()
}; };
const newEnv = Object.assign({}, process.env, nniEnv); const newEnv = Object.assign({}, process.env, nniEnv);
const tunerProc: ChildProcess = getTunerProc(command,stdio,newCwd,newEnv); const tunerProc: ChildProcess = getTunerProc(command, stdio, newCwd, newEnv);
this.dispatcherPid = tunerProc.pid; this.dispatcherPid = tunerProc.pid;
this.dispatcher = createDispatcherInterface(tunerProc); this.dispatcher = createDispatcherInterface(tunerProc);
...@@ -436,7 +436,9 @@ class NNIManager implements Manager { ...@@ -436,7 +436,9 @@ class NNIManager implements Manager {
} }
await killPid(this.dispatcherPid); await killPid(this.dispatcherPid);
const trialJobList: TrialJobDetail[] = await this.trainingService.listTrialJobs(); const trialJobList: TrialJobDetail[] = await this.trainingService.listTrialJobs();
// TO DO: to promise all
// DON'T try to cancel trials in parallel; the training service may not handle it well.
// If performance becomes a concern, consider supporting batch cancellation in the training service.
for (const trialJob of trialJobList) { for (const trialJob of trialJobList) {
if (trialJob.status === 'RUNNING' || if (trialJob.status === 'RUNNING' ||
trialJob.status === 'WAITING') { trialJob.status === 'WAITING') {
...@@ -444,7 +446,7 @@ class NNIManager implements Manager { ...@@ -444,7 +446,7 @@ class NNIManager implements Manager {
this.log.info(`cancelTrialJob: ${trialJob.id}`); this.log.info(`cancelTrialJob: ${trialJob.id}`);
await this.trainingService.cancelTrialJob(trialJob.id); await this.trainingService.cancelTrialJob(trialJob.id);
} catch (error) { } catch (error) {
// pid does not exist, do nothing here this.log.debug(`ignorable error on canceling trial ${trialJob.id}. ${error}`);
} }
} }
} }
......
...@@ -174,10 +174,11 @@ export async function tarAdd(tarPath: string, sourcePath: string): Promise<void> ...@@ -174,10 +174,11 @@ export async function tarAdd(tarPath: string, sourcePath: string): Promise<void>
script.push( script.push(
`import os`, `import os`,
`import tarfile`, `import tarfile`,
String.Format(`tar = tarfile.open("{0}","w:gz")\r\nfor root,dir,files in os.walk("{1}"):`, tarFilePath, sourceFilePath), String.Format(`tar = tarfile.open("{0}","w:gz")\r\nroot="{1}"\r\nfor file_path,dir,files in os.walk(root):`, tarFilePath, sourceFilePath),
` for file in files:`, ` for file in files:`,
` fullpath = os.path.join(root,file)`, ` full_path = os.path.join(file_path, file)`,
` tar.add(fullpath, arcname=file)`, ` file = os.path.relpath(full_path, root)`,
` tar.add(full_path, arcname=file)`,
`tar.close()`); `tar.close()`);
await fs.promises.writeFile(path.join(os.tmpdir(), 'tar.py'), script.join(getNewLine()), { encoding: 'utf8', mode: 0o777 }); await fs.promises.writeFile(path.join(os.tmpdir(), 'tar.py'), script.join(getNewLine()), { encoding: 'utf8', mode: 0o777 });
const tarScript: string = path.join(os.tmpdir(), 'tar.py'); const tarScript: string = path.join(os.tmpdir(), 'tar.py');
......
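The effect of the `tarAdd` change above, storing each file under its path relative to the source root instead of its bare file name, can be reproduced in plain Python. The sketch below is a standalone approximation of the generated script, with hypothetical function names (`tar_add`, `demo`); it is not the script NNI writes to disk.

```python
import os
import tarfile
import tempfile


def tar_add(tar_path: str, source_dir: str) -> None:
    """Archive source_dir into tar_path, keeping sub-directory structure."""
    with tarfile.open(tar_path, "w:gz") as tar:
        for dir_path, _dirs, files in os.walk(source_dir):
            for name in files:
                full_path = os.path.join(dir_path, name)
                # Relative path as archive name, so files with the same name
                # in different sub-directories do not collide.
                tar.add(full_path, arcname=os.path.relpath(full_path, source_dir))


def demo() -> list:
    """Build a tiny tree with two same-named files and list the archive."""
    with tempfile.TemporaryDirectory() as tmp:
        src = os.path.join(tmp, "src")
        os.makedirs(os.path.join(src, "sub"))
        for rel in ("a.txt", os.path.join("sub", "a.txt")):
            with open(os.path.join(src, rel), "w") as handle:
                handle.write(rel)
        tar_path = os.path.join(tmp, "out.tar.gz")
        tar_add(tar_path, src)
        with tarfile.open(tar_path) as tar:
            return sorted(tar.getnames())
```

With the previous `arcname=file`, both `a.txt` files in `demo()` would have been stored under the same name at the archive root.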
...@@ -7,6 +7,36 @@ import { OsCommands } from "../osCommands"; ...@@ -7,6 +7,36 @@ import { OsCommands } from "../osCommands";
import { RemoteCommandResult } from "../remoteMachineData"; import { RemoteCommandResult } from "../remoteMachineData";
class LinuxCommands extends OsCommands { class LinuxCommands extends OsCommands {
public getScriptExt(): string {
return "sh";
}
public generateStartScript(workingDirectory: string, trialJobId: string, experimentId: string,
trialSequenceId: string, isMultiPhase: boolean, jobIdFileName: string,
command: string, nniManagerAddress: string, nniManagerPort: number,
nniManagerVersion: string, logCollection: string, exitCodeFile: string,
codeDir: string, cudaVisibleSetting: string): string {
return `#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR=${workingDirectory} NNI_OUTPUT_DIR=${workingDirectory} NNI_TRIAL_JOB_ID=${trialJobId} \
NNI_EXP_ID=${experimentId} NNI_TRIAL_SEQ_ID=${trialSequenceId} NNI_CODE_DIR=${codeDir}
export MULTI_PHASE=${isMultiPhase}
cp -r $NNI_CODE_DIR/. $NNI_SYS_DIR
cd $NNI_SYS_DIR
sh install_nni.sh
python3 -m nni_trial_tool.trial_keeper --trial_command '${cudaVisibleSetting} ${command}' --nnimanager_ip '${nniManagerAddress}' \
--nnimanager_port '${nniManagerPort}' --nni_manager_version '${nniManagerVersion}' \
--job_id_file ${jobIdFileName} \
--log_collection '${logCollection}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr
echo $? \`date +%s%3N\` >${exitCodeFile}`;
}
public generateGpuStatsScript(scriptFolder: string): string {
return `echo $$ > ${scriptFolder}/pid ; METRIC_OUTPUT_DIR=${scriptFolder} python3 -m nni_gpu_tool.gpu_metrics_collector`;
}
public createFolder(folderName: string, sharedFolder: boolean = false): string { public createFolder(folderName: string, sharedFolder: boolean = false): string {
let command; let command;
if (sharedFolder) { if (sharedFolder) {
...@@ -64,7 +94,19 @@ class LinuxCommands extends OsCommands { ...@@ -64,7 +94,19 @@ class LinuxCommands extends OsCommands {
} }
public killChildProcesses(pidFileName: string): string { public killChildProcesses(pidFileName: string): string {
const command = `pkill -P \`cat '${pidFileName}'\``; // Prevent trialkeeper from being killed, so it can save the exit code.
const command = `list_descendants ()
{
local children=$(ps -o pid= --ppid "$1")
for pid in $children
do
list_descendants "$pid"
done
echo "$children"
}
kill $(list_descendants \`cat '${pidFileName}'\`)`
return command; return command;
} }
......
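The `list_descendants` shell function in the hunk above walks the process tree recursively and never includes the root PID itself, which is what keeps trialkeeper alive. A minimal sketch of the same traversal over an in-memory parent-to-children map follows; the map and the function signature are assumptions for illustration, since the real script queries `ps`.

```python
def list_descendants(children_of: dict, pid: int) -> list:
    """Return all descendants of pid, excluding pid itself.

    Mirrors the shell function's order: recurse into each child first,
    then append the direct children.
    """
    result = []
    children = children_of.get(pid, [])
    for child in children:
        result.extend(list_descendants(children_of, child))
    result.extend(children)
    return result
```

For a tree where process 1 has children 2 and 3, and process 2 has child 4, the call `list_descendants({1: [2, 3], 2: [4]}, 1)` yields 4, 2 and 3, never 1.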
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { OsCommands } from "../osCommands";
import { RemoteCommandResult } from "../remoteMachineData";
class WindowsCommands extends OsCommands {
protected pathSpliter: string = '\\';
public getScriptExt(): string {
return "cmd";
}
public generateStartScript(workingDirectory: string, trialJobId: string, experimentId: string,
trialSequenceId: string, isMultiPhase: boolean, jobIdFileName: string,
command: string, nniManagerAddress: string, nniManagerPort: number,
nniManagerVersion: string, logCollection: string, exitCodeFile: string,
codeDir: string, cudaVisibleSetting: string): string {
return `echo off
set NNI_PLATFORM=remote
set NNI_SYS_DIR=${workingDirectory}
set NNI_OUTPUT_DIR=${workingDirectory}
set NNI_TRIAL_JOB_ID=${trialJobId}
set NNI_EXP_ID=${experimentId}
set NNI_TRIAL_SEQ_ID=${trialSequenceId}
set MULTI_PHASE=${isMultiPhase}
set NNI_CODE_DIR=${codeDir}
${cudaVisibleSetting !== "" ? "set " + cudaVisibleSetting : ""}
robocopy /s %NNI_CODE_DIR%/. %NNI_SYS_DIR%
cd %NNI_SYS_DIR%
python -c "import nni" 2>nul
if not %ERRORLEVEL% EQU 0 (
echo installing NNI as exit code of "import nni" is %ERRORLEVEL%
python -m pip install --user --upgrade nni
)
echo starting script
python -m nni_trial_tool.trial_keeper --trial_command "${command}" --nnimanager_ip "${nniManagerAddress}" --nnimanager_port "${nniManagerPort}" --nni_manager_version "${nniManagerVersion}" --log_collection "${logCollection}" --job_id_file ${jobIdFileName} 1>%NNI_OUTPUT_DIR%/trialkeeper_stdout 2>%NNI_OUTPUT_DIR%/trialkeeper_stderr
echo save exit code(%ERRORLEVEL%) and time
echo|set /p="%ERRORLEVEL% " > ${exitCodeFile}
powershell -command "Write (((New-TimeSpan -Start (Get-Date "01/01/1970") -End (Get-Date).ToUniversalTime()).TotalMilliseconds).ToString("0")) | Out-file ${exitCodeFile} -Append -NoNewline -encoding utf8"`;
}
public generateGpuStatsScript(scriptFolder: string): string {
return `powershell -command $env:METRIC_OUTPUT_DIR='${scriptFolder}';$app = Start-Process -FilePath python -NoNewWindow -passthru -ArgumentList '-m nni_gpu_tool.gpu_metrics_collector' -RedirectStandardOutput ${scriptFolder}\\scriptstdout -RedirectStandardError ${scriptFolder}\\scriptstderr;Write $PID ^| Out-File ${scriptFolder}\\pid -NoNewline -encoding utf8;wait-process $app.ID`;
}
public createFolder(folderName: string, sharedFolder: boolean = false): string {
let command;
if (sharedFolder) {
command = `mkdir "${folderName}"\r\nICACLS "${folderName}" /grant "Users":F`;
} else {
command = `mkdir "${folderName}"`;
}
return command;
}
public allowPermission(isRecursive: boolean = false, ...folders: string[]): string {
let commands: string = "";
folders.forEach(folder => {
commands += `ICACLS "${folder}" /grant "Users":F${isRecursive ? " /T" : ""}\r\n`
});
return commands;
}
public removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): string {
let flags = '';
if (isForce || isRecursive) {
flags = `${isRecursive ? ' /s' : ''}${isForce ? ' /q' : ''}`;
}
const command = `rmdir${flags} "${folderName}"`;
return command;
}
public removeFiles(folderName: string, filePattern: string): string {
const files = this.joinPath(folderName, filePattern);
const command = `del "${files}"`;
return command;
}
public readLastLines(fileName: string, lineCount: number = 1): string {
const command = `powershell.exe Get-Content "${fileName}" -Tail ${lineCount}`;
return command;
}
public isProcessAliveCommand(pidFileName: string): string {
const command = `powershell.exe Get-Process -Id (get-content "${pidFileName}") -ErrorAction SilentlyContinue`;
return command;
}
public isProcessAliveProcessOutput(commandResult: RemoteCommandResult): boolean {
let result = true;
if (commandResult.exitCode !== 0) {
result = false;
}
return result;
}
public killChildProcesses(pidFileName: string): string {
const command = `powershell "$ppid=(type ${pidFileName}); function Kill-Tree {Param([int]$subppid);` +
`Get-CimInstance Win32_Process | Where-Object { $_.ParentProcessId -eq $subppid } | ForEach-Object { Kill-Tree $_.ProcessId }; ` +
`if ($subppid -ne $ppid){Stop-Process -Id $subppid}}` +
`kill-tree $ppid"`;
return command;
}
public extractFile(tarFileName: string, targetFolder: string): string {
const command = `tar -xf "${tarFileName}" -C "${targetFolder}"`;
return command;
}
public executeScript(script: string, _isFile: boolean): string {
const command = `${script}`;
return command;
}
}
export { WindowsCommands };
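As a reading aid, the flag composition in `removeFolder` above can be restated as a standalone function. This is a minimal sketch, not part of the commit: `buildRmdirCommand` is a hypothetical name, and the function only builds the command string that would be sent to the remote Windows shell.

```typescript
// Hypothetical standalone restatement of the removeFolder flag logic, for illustration only.
// It builds a cmd.exe "rmdir" command line; it does not execute anything.
function buildRmdirCommand(folderName: string, isRecursive: boolean = false, isForce: boolean = true): string {
    // "/s" removes the directory tree, "/q" suppresses the confirmation prompt.
    const flags = `${isRecursive ? ' /s' : ''}${isForce ? ' /q' : ''}`;
    return `rmdir${flags} "${folderName}"`;
}

console.log(buildRmdirCommand('C:\\temp\\nni', true, true)); // rmdir /s /q "C:\temp\nni"
```

The defaults mirror the signature in the diff: a non-recursive, quiet removal when only the folder name is passed.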
...@@ -8,8 +8,16 @@ import { RemoteCommandResult } from "./remoteMachineData"; ...@@ -8,8 +8,16 @@ import { RemoteCommandResult } from "./remoteMachineData";
abstract class OsCommands { abstract class OsCommands {
protected pathSpliter: string = '/'; protected pathSpliter: string = '/';
protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`); protected multiplePathSpliter: RegExp = new RegExp(`[\\\\/]{2,}`);
protected normalizePath: RegExp = new RegExp(`[\\\\/]`);
public abstract getScriptExt(): string;
public abstract generateStartScript(workingDirectory: string, trialJobId: string, experimentId: string,
trialSequenceId: string, isMultiPhase: boolean, jobIdFileName: string,
command: string, nniManagerAddress: string, nniManagerPort: number,
nniManagerVersion: string, logCollection: string, exitCodeFile: string,
codeDir: string, cudaVisibleSetting: string): string;
public abstract generateGpuStatsScript(scriptFolder: string): string;
public abstract createFolder(folderName: string, sharedFolder: boolean): string; public abstract createFolder(folderName: string, sharedFolder: boolean): string;
public abstract allowPermission(isRecursive: boolean, ...folders: string[]): string; public abstract allowPermission(isRecursive: boolean, ...folders: string[]): string;
public abstract removeFolder(folderName: string, isRecursive: boolean, isForce: boolean): string; public abstract removeFolder(folderName: string, isRecursive: boolean, isForce: boolean): string;
...@@ -26,6 +34,9 @@ abstract class OsCommands { ...@@ -26,6 +34,9 @@ abstract class OsCommands {
if (dir === '') { if (dir === '') {
dir = '.'; dir = '.';
} else { } else {
// normalize
dir = dir.replace(this.normalizePath, this.pathSpliter);
// reduce duplicate ones
dir = dir.replace(this.multiplePathSpliter, this.pathSpliter); dir = dir.replace(this.multiplePathSpliter, this.pathSpliter);
} }
return dir; return dir;
......
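The normalize-then-deduplicate step in `joinPath` can be sketched in isolation. The version below is an assumption-laden illustration, not the code from the commit: `joinPathSketch` is a hypothetical name, the separator is passed in explicitly (assuming the Windows subclass uses a backslash, as the joinPath test expectations suggest), and it uses a single regex with the global flag. The regexes in the diff are constructed without flags, and `String.prototype.replace` with a non-global regex replaces only the first match, so the two may differ on paths with several mixed separators.

```typescript
// Hypothetical sketch: join path segments and collapse every run of "/" or "\"
// into one target separator. Not the implementation from the commit.
function joinPathSketch(pathSpliter: string, ...paths: string[]): string {
    const dir = paths.filter(p => p !== '').join(pathSpliter);
    if (dir === '') {
        return '.';
    }
    // The global flag makes replace() rewrite all separator runs, not just the first one.
    // A replacer function is used so the separator is inserted literally.
    return dir.replace(/[\\/]+/g, () => pathSpliter);
}

console.log(joinPathSketch('\\', '/root/', '\\first')); // \root\first
console.log(joinPathSketch('/', 'root\\', 'first'));    // root/first
```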
...@@ -85,78 +85,82 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -85,78 +85,82 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
* The remote machine executor manager * The remote machine executor manager
*/ */
export class ExecutorManager { export class ExecutorManager {
private readonly executorArray: ShellExecutor[]; private readonly executorMap: Map<string, ShellExecutor> = new Map<string, ShellExecutor>();
private readonly maxTrialNumberPerConnection: number;
private readonly rmMeta: RemoteMachineMeta; private readonly rmMeta: RemoteMachineMeta;
constructor(executorArray: ShellExecutor[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) {
private executors: ShellExecutor[] = [];
constructor(rmMeta: RemoteMachineMeta) {
this.rmMeta = rmMeta; this.rmMeta = rmMeta;
this.executorArray = executorArray;
this.maxTrialNumberPerConnection = maxTrialNumberPerConnection;
} }
/** public async getExecutor(id: string): Promise<ShellExecutor> {
* find a available executor, if no executor available, return a new one let isFound = false;
*/ let executor: ShellExecutor | undefined;
public async getAvailableExecutor(): Promise<ShellExecutor> {
for (const index of this.executorArray.keys()) {
const connectionNumber: number = this.executorArray[index].getUsedConnectionNumber;
if (connectionNumber < this.maxTrialNumberPerConnection) {
this.executorArray[index].addUsedConnectionNumber();
return this.executorArray[index]; // already assigned
if (this.executorMap.has(id)) {
executor = this.executorMap.get(id);
if (executor === undefined) {
throw new Error("executor shouldn't be undefined before return!");
} }
return executor;
} }
//init a new executor if could not get an available one for (const candidateExecutor of this.executors) {
return await this.initNewShellExecutor(); if (candidateExecutor.addUsage()) {
} isFound = true;
executor = candidateExecutor;
break;
}
}
// init a new executor if no free one.
if (!isFound) {
executor = await this.createShellExecutor();
}
/** if (executor === undefined) {
* add a new executor to executorArray throw new Error("executor shouldn't be undefined before set!");
* @param executor ShellExecutor }
*/ this.executorMap.set(id, executor);
public addNewShellExecutor(executor: ShellExecutor): void {
this.executorArray.push(executor);
}
/** return executor;
* first executor instance is used for gpu collector and host job
*/
public getFirstExecutor(): ShellExecutor {
return this.executorArray[0];
} }
/** /**
* close all executors * close all executors
*/ */
public closeAllExecutor(): void { public releaseAllExecutor(): void {
for (const executor of this.executorArray) { this.executorMap.clear();
for (const executor of this.executors) {
executor.close(); executor.close();
} }
this.executors = [];
} }
/** /**
* retrieve resource, minus a number for given executor * retrieve resource, minus a number for given executor
* @param executor executor * @param executor executor
*/ */
public releaseConnection(executor: ShellExecutor | undefined): void { public releaseExecutor(id: string): void {
const executor = this.executorMap.get(id);
if (executor === undefined) { if (executor === undefined) {
throw new Error(`could not release a undefined executor`); throw new Error(`executor for ${id} is not found`);
}
for (const index of this.executorArray.keys()) {
if (this.executorArray[index] === executor) {
this.executorArray[index].minusUsedConnectionNumber();
break;
}
} }
executor.releaseUsage();
this.executorMap.delete(id);
} }
/** /**
* Create a new connection executor and initialize it * Create a new connection executor and initialize it
*/ */
private async initNewShellExecutor(): Promise<ShellExecutor> { private async createShellExecutor(): Promise<ShellExecutor> {
const executor = new ShellExecutor(); const executor = new ShellExecutor();
await executor.initialize(this.rmMeta); await executor.initialize(this.rmMeta);
if (!executor.addUsage()) {
throw new Error("failed to add usage on newly created Executor! It's a weird bug!");
}
this.executors.push(executor);
return executor; return executor;
} }
} }
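The allocation scheme in the new `ExecutorManager` (reuse the executor already assigned to an id, otherwise take any executor with spare capacity, otherwise create a new one) can be sketched without SSH. Everything below is a simplified, synchronous illustration under stated assumptions: `PooledConnection` and `ConnectionPool` are hypothetical stand-ins for `ShellExecutor` and `ExecutorManager`, and the usage limit is a constructor argument rather than the fixed value of 5 seen in the diff.

```typescript
// Hypothetical stand-in for ShellExecutor: it only tracks a usage count.
class PooledConnection {
    private usedCount = 0;
    private readonly maxUsageCount: number;

    constructor(maxUsageCount: number) {
        this.maxUsageCount = maxUsageCount;
    }

    // Returns true if a usage slot was taken, false if the connection is full.
    public addUsage(): boolean {
        if (this.usedCount < this.maxUsageCount) {
            this.usedCount++;
            return true;
        }
        return false;
    }

    // Returns true when no usage remains, i.e. the connection could be closed.
    public releaseUsage(): boolean {
        if (this.usedCount > 0) {
            this.usedCount--;
        }
        return this.usedCount === 0;
    }
}

// Hypothetical stand-in for ExecutorManager.
class ConnectionPool {
    private readonly byId = new Map<string, PooledConnection>();
    private readonly connections: PooledConnection[] = [];
    private readonly maxUsageCount: number;

    constructor(maxUsageCount: number = 5) {
        this.maxUsageCount = maxUsageCount;
    }

    public get(id: string): PooledConnection {
        const assigned = this.byId.get(id);
        if (assigned !== undefined) {
            return assigned; // already assigned to this id
        }
        // find() stops at the first connection whose addUsage() succeeds;
        // a failed addUsage() leaves the counter unchanged.
        let connection = this.connections.find(c => c.addUsage());
        if (connection === undefined) {
            connection = new PooledConnection(this.maxUsageCount);
            connection.addUsage();
            this.connections.push(connection);
        }
        this.byId.set(id, connection);
        return connection;
    }

    public release(id: string): void {
        const connection = this.byId.get(id);
        if (connection === undefined) {
            throw new Error(`executor for ${id} is not found`);
        }
        connection.releaseUsage();
        this.byId.delete(id);
    }

    public get connectionCount(): number {
        return this.connections.length;
    }
}
```

With a limit of 2, ids `t1` and `t2` would share one connection and `t3` would trigger a second one; releasing `t1` frees a slot that the next new id can reuse.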
...@@ -175,22 +179,3 @@ export enum ScheduleResultType { ...@@ -175,22 +179,3 @@ export enum ScheduleResultType {
// Cannot match requirement even if all GPU are a // Cannot match requirement even if all GPU are a
REQUIRE_EXCEED_TOTAL REQUIRE_EXCEED_TOTAL
} }
export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string =
`#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} \
NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} NNI_CODE_DIR={6}
cp -r $NNI_CODE_DIR/. $NNI_SYS_DIR
cd $NNI_SYS_DIR
sh install_nni.sh
echo $$ >{7}
python3 -m nni_trial_tool.trial_keeper --trial_command '{8}' --nnimanager_ip '{9}' --nnimanager_port '{10}' \
--nni_manager_version '{11}' --log_collection '{12}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr
echo $? \`date +%s%3N\` >{13}`;
export const HOST_JOB_SHELL_FORMAT: string =
`#!/bin/bash
cd {0}
echo $$ >{1}
eval {2} >stdout 2>stderr
echo $? \`date +%s%3N\` >{3}`;
...@@ -4,27 +4,39 @@ ...@@ -4,27 +4,39 @@
'use strict'; 'use strict';
import * as assert from 'assert'; import * as assert from 'assert';
import * as fs from 'fs';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path'; import * as path from 'path';
import * as fs from 'fs'; import { Client, ClientChannel, ConnectConfig, SFTPWrapper } from 'ssh2';
import { Client, ClientChannel, SFTPWrapper, ConnectConfig } from 'ssh2';
import { Deferred } from "ts-deferred";
import { RemoteCommandResult, RemoteMachineMeta } from "./remoteMachineData";
import * as stream from 'stream'; import * as stream from 'stream';
import { OsCommands } from "./osCommands"; import { Deferred } from "ts-deferred";
import { LinuxCommands } from "./extends/linuxCommands";
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { uniqueString, randomInt } from '../../common/utils';
import { execRemove, tarAdd } from '../common/util'; import { execRemove, tarAdd } from '../common/util';
import { getRemoteTmpDir, uniqueString, unixPathJoin } from '../../common/utils'; import { LinuxCommands } from "./extends/linuxCommands";
import { WindowsCommands } from './extends/windowsCommands';
import { OsCommands } from "./osCommands";
import { RemoteCommandResult, RemoteMachineMeta } from "./remoteMachineData";
import { NNIError, NNIErrorNames } from '../../common/errors';
class ShellExecutor { class ShellExecutor {
private sshClient: Client = new Client(); public name: string = "";
private osCommands: OsCommands | undefined;
private usedConnectionNumber: number = 0; //count the connection number of every client
protected pathSpliter: string = '/'; private readonly lineBreaker = new RegExp(`[\r\n]+`);
protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`); private readonly maxUsageCount = 5;
private osCommands: OsCommands | undefined;
private usedCount: number = 0; //count the connection number of every client
private readonly sshClient: Client;
private readonly log: Logger;
private tempPath: string = "";
private isWindows: boolean = false;
private channelDefaultOutputs: string[] = [];
constructor() {
this.log = getLogger();
this.sshClient = new Client();
}
public async initialize(rmMeta: RemoteMachineMeta): Promise<void> { public async initialize(rmMeta: RemoteMachineMeta): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
...@@ -33,8 +45,9 @@ class ShellExecutor { ...@@ -33,8 +45,9 @@ class ShellExecutor {
host: rmMeta.ip, host: rmMeta.ip,
port: rmMeta.port, port: rmMeta.port,
username: rmMeta.username, username: rmMeta.username,
tryKeyboard: true tryKeyboard: true,
}; };
this.name = `${rmMeta.username}@${rmMeta.ip}:${rmMeta.port}`;
if (rmMeta.passwd !== undefined) { if (rmMeta.passwd !== undefined) {
connectConfig.password = rmMeta.passwd; connectConfig.password = rmMeta.passwd;
} else if (rmMeta.sshKeyPath !== undefined) { } else if (rmMeta.sshKeyPath !== undefined) {
...@@ -49,20 +62,42 @@ class ShellExecutor { ...@@ -49,20 +62,42 @@ class ShellExecutor {
} else { } else {
deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`)); deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));
} }
this.sshClient.on('ready', async () => { this.sshClient.on('ready', async () => {
// check OS type: windows or else // check OS type: windows or else
const result = await this.execute("ver"); const result = await this.execute("ver");
if (result.exitCode == 0 && result.stdout.search("Windows") > -1) { if (result.exitCode == 0 && result.stdout.search("Windows") > -1) {
// not implement Windows commands yet. this.osCommands = new WindowsCommands();
throw new Error("not implement Windows commands yet."); this.isWindows = true;
// Detect default output and try to remove it under Windows.
// Anaconda, for example, produces this kind of output.
let defaultResult = await this.execute("");
if (defaultResult.stdout !== "") {
deferred.reject(new Error(`The windows remote node shouldn't output welcome message, below content should be removed from the command window! \n` +
`${defaultResult.stdout}`));
}
defaultResult = await this.execute("powershell -command \"\"");
if (defaultResult.stdout !== "") {
this.channelDefaultOutputs.push(defaultResult.stdout);
}
this.log.debug(`set channelDefaultOutput to "${this.channelDefaultOutputs}"`);
// parse temp folder to expand possible environment variables.
const commandResult = await this.execute("echo %TEMP%");
this.tempPath = commandResult.stdout.replace(this.lineBreaker, "");
} else { } else {
this.osCommands = new LinuxCommands(); this.osCommands = new LinuxCommands();
// Getting the tmp path through a Linux command, like "echo /tmp" or "ld -d /tmp", is not stable.
// Sometimes it returns an empty result, so the tmp path is hard-coded here.
this.tempPath = "/tmp";
} }
deferred.resolve(); deferred.resolve();
}).on('error', (err: Error) => { }).on('error', (err: Error) => {
// SSH connection error, reject with error message // SSH connection error, reject with error message
deferred.reject(new Error(err.message)); deferred.reject(new Error(err.message));
}).on("keyboard-interactive", (name, instructions, lang, prompts, finish) => { }).on("keyboard-interactive", (_name, _instructions, _lang, _prompts, finish) => {
finish([rmMeta.passwd]); finish([rmMeta.passwd]);
}).connect(connectConfig); }).connect(connectConfig);
...@@ -73,43 +108,108 @@ class ShellExecutor { ...@@ -73,43 +108,108 @@ class ShellExecutor {
this.sshClient.end(); this.sshClient.end();
} }
public get getUsedConnectionNumber(): number { public addUsage(): boolean {
return this.usedConnectionNumber; let isAddedSuccess = false;
if (this.usedCount < this.maxUsageCount) {
this.usedCount++;
isAddedSuccess = true;
}
return isAddedSuccess;
}
public releaseUsage(): boolean {
let canBeReleased = false;
if (this.usedCount > 0) {
this.usedCount--;
}
if (this.usedCount == 0) {
canBeReleased = true;
}
return canBeReleased;
}
public getScriptName(mainName: string): string {
if (this.osCommands === undefined) {
throw new Error("osCommands must be initialized!");
}
return `${mainName}.${this.osCommands.getScriptExt()}`;
}
public generateStartScript(workingDirectory: string, trialJobId: string, experimentId: string,
trialSequenceId: string, isMultiPhase: boolean,
command: string, nniManagerAddress: string, nniManagerPort: number,
nniManagerVersion: string, logCollection: string, cudaVisibleSetting: string): string {
if (this.osCommands === undefined) {
throw new Error("osCommands must be initialized!");
}
const jobIdFileName = this.joinPath(workingDirectory, '.nni', 'jobpid');
const exitCodeFile = this.joinPath(workingDirectory, '.nni', 'code');
const codeDir = this.getRemoteCodePath(experimentId);
return this.osCommands.generateStartScript(workingDirectory, trialJobId, experimentId,
trialSequenceId, isMultiPhase, jobIdFileName, command,
nniManagerAddress, nniManagerPort, nniManagerVersion,
logCollection, exitCodeFile, codeDir, cudaVisibleSetting);
}
public generateGpuStatsScript(experimentId: string): string {
if (this.osCommands === undefined) {
throw new Error("osCommands must be initialized!");
}
return this.osCommands.generateGpuStatsScript(this.getRemoteScriptsPath(experimentId));
}
public getTempPath(): string {
if (this.tempPath === "") {
throw new Error("tempPath must be initialized!");
}
return this.tempPath;
}
public getRemoteScriptsPath(experimentId: string): string {
return this.joinPath(this.getRemoteExperimentRootDir(experimentId), 'scripts');
}
public getRemoteCodePath(experimentId: string): string {
return this.joinPath(this.getRemoteExperimentRootDir(experimentId), 'nni-code');
} }
public addUsedConnectionNumber(): void { public getRemoteExperimentRootDir(experimentId: string): string {
this.usedConnectionNumber += 1; return this.joinPath(this.tempPath, 'nni', 'experiments', experimentId);
} }
public minusUsedConnectionNumber(): void { public joinPath(...paths: string[]): string {
this.usedConnectionNumber -= 1; if (!this.osCommands) {
throw new Error("osCommands must be initialized!");
}
return this.osCommands.joinPath(...paths);
} }
public async createFolder(folderName: string, sharedFolder: boolean = false): Promise<boolean> { public async createFolder(folderName: string, sharedFolder: boolean = false): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.createFolder(folderName, sharedFolder); const commandText = this.osCommands && this.osCommands.createFolder(folderName, sharedFolder);
const commandResult = await this.execute(commandText); const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0; const result = commandResult.exitCode == 0;
return result; return result;
} }
public async allowPermission(isRecursive: boolean = false, ...folders: string[]): Promise<boolean> { public async allowPermission(isRecursive: boolean = false, ...folders: string[]): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.allowPermission(isRecursive, ...folders); const commandText = this.osCommands && this.osCommands.allowPermission(isRecursive, ...folders);
const commandResult = await this.execute(commandText); const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0; const result = commandResult.exitCode == 0;
return result; return result;
} }
public async removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): Promise<boolean> { public async removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.removeFolder(folderName, isRecursive, isForce); const commandText = this.osCommands && this.osCommands.removeFolder(folderName, isRecursive, isForce);
const commandResult = await this.execute(commandText); const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0; const result = commandResult.exitCode == 0;
return result; return result;
} }
public async removeFiles(folderOrFileName: string, filePattern: string = ""): Promise<boolean> { public async removeFiles(folderOrFileName: string, filePattern: string = ""): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.removeFiles(folderOrFileName, filePattern); const commandText = this.osCommands && this.osCommands.removeFiles(folderOrFileName, filePattern);
const commandResult = await this.execute(commandText); const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0; const result = commandResult.exitCode == 0;
return result; return result;
} }
...@@ -142,10 +242,10 @@ class ShellExecutor { ...@@ -142,10 +242,10 @@ class ShellExecutor {
return commandResult.exitCode == 0; return commandResult.exitCode == 0;
} }
public async executeScript(script: string, isFile: boolean, isInteractive: boolean = false): Promise<boolean> { public async executeScript(script: string, isFile: boolean = false, isInteractive: boolean = false): Promise<RemoteCommandResult> {
const commandText = this.osCommands && this.osCommands.executeScript(script, isFile); const commandText = this.osCommands && this.osCommands.executeScript(script, isFile);
const commandResult = await this.execute(commandText, undefined, isInteractive); const commandResult = await this.execute(commandText, undefined, isInteractive);
return commandResult.exitCode == 0; return commandResult;
} }
/** /**
...@@ -154,13 +254,13 @@ class ShellExecutor { ...@@ -154,13 +254,13 @@ class ShellExecutor {
* @param remoteFilePath the target path in remote machine * @param remoteFilePath the target path in remote machine
*/ */
public async copyFileToRemote(localFilePath: string, remoteFilePath: string): Promise<boolean> { public async copyFileToRemote(localFilePath: string, remoteFilePath: string): Promise<boolean> {
const log: Logger = getLogger(); const commandIndex = randomInt(10000);
log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`); this.log.debug(`copyFileToRemote(${commandIndex}): localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`);
const deferred: Deferred<boolean> = new Deferred<boolean>(); const deferred: Deferred<boolean> = new Deferred<boolean>();
this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => { this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`); this.log.error(`copyFileToRemote(${commandIndex}): ${err}`);
deferred.reject(err); deferred.reject(err);
return; return;
...@@ -169,6 +269,7 @@ class ShellExecutor { ...@@ -169,6 +269,7 @@ class ShellExecutor {
sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr: Error) => { sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr: Error) => {
sftp.end(); sftp.end();
if (fastPutErr !== undefined && fastPutErr !== null) { if (fastPutErr !== undefined && fastPutErr !== null) {
this.log.error(`copyFileToRemote(${commandIndex}) fastPutErr: ${fastPutErr}, ${localFilePath}, ${remoteFilePath}`);
deferred.reject(fastPutErr); deferred.reject(fastPutErr);
} else { } else {
deferred.resolve(true); deferred.resolve(true);
...@@ -183,12 +284,15 @@ class ShellExecutor { ...@@ -183,12 +284,15 @@ class ShellExecutor {
* Copy files and directories in local directory recursively to remote directory * Copy files and directories in local directory recursively to remote directory
* @param localDirectory local directory * @param localDirectory local directory
* @param remoteDirectory remote directory * @param remoteDirectory remote directory
* @param remoteOS the OS of remote machine
*/ */
public async copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, remoteOS: string): Promise<void> { public async copyDirectoryToRemote(localDirectory: string, remoteDirectory: string): Promise<void> {
const tmpSuffix: string = uniqueString(5); const tmpSuffix: string = uniqueString(5);
const localTarPath: string = path.join(os.tmpdir(), `nni_tmp_local_${tmpSuffix}.tar.gz`); const localTarPath: string = path.join(os.tmpdir(), `nni_tmp_local_${tmpSuffix}.tar.gz`);
const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), `nni_tmp_remote_${tmpSuffix}.tar.gz`); if (!this.osCommands) {
throw new Error("osCommands must be initialized!");
}
const remoteTarPath: string = this.osCommands.joinPath(this.tempPath, `nni_tmp_remote_${tmpSuffix}.tar.gz`);
// Create remote directory // Create remote directory
await this.createFolder(remoteDirectory); await this.createFolder(remoteDirectory);
// Compress files in local directory to experiment root directory // Compress files in local directory to experiment root directory
...@@ -202,12 +306,13 @@ class ShellExecutor { ...@@ -202,12 +306,13 @@ class ShellExecutor {
} }
public async getRemoteFileContent(filePath: string): Promise<string> { public async getRemoteFileContent(filePath: string): Promise<string> {
const commandIndex = randomInt(10000);
this.log.debug(`getRemoteFileContent(${commandIndex}): filePath: ${filePath}`);
const deferred: Deferred<string> = new Deferred<string>(); const deferred: Deferred<string> = new Deferred<string>();
this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => { this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
getLogger() this.log.error(`getRemoteFileContent(${commandIndex}) sftp: ${err}`);
.error(`getRemoteFileContent: ${err.message}`); deferred.reject(new Error(`SFTP error: ${err}`));
deferred.reject(new Error(`SFTP error: ${err.message}`));
return; return;
} }
...@@ -228,8 +333,7 @@ class ShellExecutor { ...@@ -228,8 +333,7 @@ class ShellExecutor {
deferred.resolve(dataBuffer); deferred.resolve(dataBuffer);
}); });
} catch (error) { } catch (error) {
getLogger() this.log.error(`getRemoteFileContent(${commandIndex}): ${error.message}`);
.error(`getRemoteFileContent: ${error.message}`);
sftp.end(); sftp.end();
deferred.reject(new Error(`SFTP error: ${error.message}`)); deferred.reject(new Error(`SFTP error: ${error.message}`));
} }
...@@ -239,16 +343,20 @@ class ShellExecutor { ...@@ -239,16 +343,20 @@ class ShellExecutor {
} }
private async execute(command: string | undefined, processOutput: ((input: RemoteCommandResult) => RemoteCommandResult) | undefined = undefined, useShell: boolean = false): Promise<RemoteCommandResult> { private async execute(command: string | undefined, processOutput: ((input: RemoteCommandResult) => RemoteCommandResult) | undefined = undefined, useShell: boolean = false): Promise<RemoteCommandResult> {
const log: Logger = getLogger();
log.debug(`remoteExeCommand: command: [${command}]`);
const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>(); const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
let stdout: string = ''; let stdout: string = '';
let stderr: string = ''; let stderr: string = '';
let exitCode: number; let exitCode: number;
const commandIndex = randomInt(10000);
this.log.debug(`remoteExeCommand(${commandIndex}): [${command}]`);
// Windows always uses a shell, so useShell must be disabled for the command to work.
useShell = useShell && !this.isWindows;
const callback = (err: Error, channel: ClientChannel): void => { const callback = (err: Error, channel: ClientChannel): void => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
log.error(`remoteExeCommand: ${err.message}`); this.log.error(`remoteExeCommand(${commandIndex}): ${err.message}`);
deferred.reject(err); deferred.reject(err);
return; return;
} }
...@@ -258,7 +366,23 @@ class ShellExecutor { ...@@ -258,7 +366,23 @@ class ShellExecutor {
}); });
channel.on('exit', (code: any) => { channel.on('exit', (code: any) => {
exitCode = <number>code; exitCode = <number>code;
log.debug(`remoteExeCommand exit(${exitCode})\nstdout: ${stdout}\nstderr: ${stderr}`);
// remove default output to get stdout correct.
if (this.channelDefaultOutputs.length > 0) {
let modifiedStdout = stdout;
this.channelDefaultOutputs.forEach(defaultOutput => {
if (modifiedStdout.startsWith(defaultOutput)) {
if (modifiedStdout.length > defaultOutput.length) {
modifiedStdout = modifiedStdout.substr(defaultOutput.length);
} else if (modifiedStdout.length === defaultOutput.length) {
modifiedStdout = "";
}
}
});
stdout = modifiedStdout;
}
this.log.debug(`remoteExeCommand(${commandIndex}) exit(${exitCode})\nstdout: ${stdout}\nstderr: ${stderr}`);
let result = { let result = {
stdout: stdout, stdout: stdout,
stderr: stderr, stderr: stderr,
...@@ -270,7 +394,7 @@ class ShellExecutor { ...@@ -270,7 +394,7 @@ class ShellExecutor {
} }
deferred.resolve(result); deferred.resolve(result);
}); });
channel.stderr.on('data', function (data) { channel.stderr.on('data', function (data: any) {
stderr += data; stderr += data;
}); });
......
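The default-output handling added to the `exit` handler strips any recorded shell banner from the front of `stdout`. A minimal standalone sketch of that idea follows; `stripDefaultOutputs` is a hypothetical helper name, and the empty-string guard is an addition of this sketch, not something shown in the diff.

```typescript
// Hypothetical helper: remove each known default output (for example a banner
// printed when a shell starts) from the beginning of stdout.
function stripDefaultOutputs(stdout: string, defaultOutputs: string[]): string {
    let result = stdout;
    for (const defaultOutput of defaultOutputs) {
        // Skip empty entries; only strip when the text really is a prefix.
        if (defaultOutput !== '' && result.startsWith(defaultOutput)) {
            result = result.slice(defaultOutput.length);
        }
    }
    return result;
}

console.log(JSON.stringify(stripDefaultOutputs('banner\r\nhello', ['banner\r\n']))); // "hello"
```

Only a leading match is removed, so a command whose real output happens to contain the banner text later in the stream is left untouched.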
...@@ -8,7 +8,6 @@ import * as chaiAsPromised from 'chai-as-promised'; ...@@ -8,7 +8,6 @@ import * as chaiAsPromised from 'chai-as-promised';
import * as component from '../../../common/component'; import * as component from '../../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils'; import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils';
import { LinuxCommands } from '../extends/linuxCommands'; import { LinuxCommands } from '../extends/linuxCommands';
// import { TrialConfigMetadataKey } from '../trialConfigMetadataKey';
describe('Unit Test for linuxCommands', () => { describe('Unit Test for linuxCommands', () => {
...@@ -88,10 +87,6 @@ describe('Unit Test for linuxCommands', () => { ...@@ -88,10 +87,6 @@ describe('Unit Test for linuxCommands', () => {
)).to.equal(false); )).to.equal(false);
}) })
it('killChildProcesses', async () => {
chai.expect(linuxCommands.killChildProcesses("test")).to.equal("pkill -P `cat 'test'`");
})
it('extractFile', async () => { it('extractFile', async () => {
chai.expect(linuxCommands.extractFile("test.tar", "testfolder")).to.equal("tar -oxzf 'test.tar' -C 'testfolder'"); chai.expect(linuxCommands.extractFile("test.tar", "testfolder")).to.equal("tar -oxzf 'test.tar' -C 'testfolder'");
}) })
......
...@@ -8,29 +8,29 @@ import * as fs from 'fs'; ...@@ -8,29 +8,29 @@ import * as fs from 'fs';
import * as chai from 'chai'; import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised'; import * as chaiAsPromised from 'chai-as-promised';
import { Client } from 'ssh2';
import { ShellExecutor } from '../shellExecutor'; import { ShellExecutor } from '../shellExecutor';
import { prepareUnitTest, cleanupUnitTest } from '../../../common/utils'; import { prepareUnitTest, cleanupUnitTest } from '../../../common/utils';
const LOCALFILE: string = '/tmp/localSshclientUTData'; const LOCALFILE: string = 'localSshUTData';
const REMOTEFILE: string = '/tmp/remoteSshclientUTData'; const REMOTEFILE: string = 'remoteSshUTData';
const REMOTEFOLDER: string = '/tmp/remoteSshclientUTFolder'; const REMOTEFOLDER: string = 'remoteSshUTFolder';
async function copyFile(executor: ShellExecutor): Promise<void> { async function copyFile(executor: ShellExecutor): Promise<void> {
await executor.copyFileToRemote(LOCALFILE, REMOTEFILE); const remoteFullName = executor.joinPath(executor.getTempPath(), REMOTEFILE);
await executor.copyFileToRemote(LOCALFILE, remoteFullName);
} }
async function copyFileToRemoteLoop(executor: ShellExecutor): Promise<void> { async function copyFileToRemoteLoop(executor: ShellExecutor): Promise<void> {
for (let i: number = 0; i < 10; i++) { const remoteFullName = executor.joinPath(executor.getTempPath(), REMOTEFILE);
// console.log(i); for (let i: number = 0; i < 3; i++) {
await executor.copyFileToRemote(LOCALFILE, REMOTEFILE); await executor.copyFileToRemote(LOCALFILE, remoteFullName);
} }
} }
async function getRemoteFileContentLoop(executor: ShellExecutor): Promise<void> { async function getRemoteFileContentLoop(executor: ShellExecutor): Promise<void> {
for (let i: number = 0; i < 10; i++) { const remoteFullName = executor.joinPath(executor.getTempPath(), REMOTEFILE);
// console.log(i); for (let i: number = 0; i < 3; i++) {
await executor.getRemoteFileContent(REMOTEFILE); await executor.getRemoteFileContent(remoteFullName);
} }
} }
...@@ -41,14 +41,16 @@ describe('ShellExecutor test', () => { ...@@ -41,14 +41,16 @@ describe('ShellExecutor test', () => {
rmMeta = JSON.parse(fs.readFileSync('../../.vscode/rminfo.json', 'utf8')); rmMeta = JSON.parse(fs.readFileSync('../../.vscode/rminfo.json', 'utf8'));
console.log(rmMeta); console.log(rmMeta);
} catch (err) { } catch (err) {
console.log(`Please configure rminfo.json to enable remote machine test.${err}`); console.log(`Please configure rminfo.json to enable remote machine test. ${err}`);
skip = true; skip = true;
} }
before(async () => { before(async () => {
chai.should(); chai.should();
chai.use(chaiAsPromised); chai.use(chaiAsPromised);
await cpp.exec(`echo '1234' > ${LOCALFILE}`); if (!fs.existsSync(LOCALFILE)){
await cpp.exec(`echo '1234' > ${LOCALFILE}`);
}
prepareUnitTest(); prepareUnitTest();
}); });
@@ -61,26 +63,27 @@ describe('ShellExecutor test', () => {
         if (skip) {
             return;
         }
-        const shellExecutor: ShellExecutor = new ShellExecutor();
-        await shellExecutor.initialize(rmMeta);
-        let result = await shellExecutor.createFolder(REMOTEFOLDER, false);
+        const executor: ShellExecutor = new ShellExecutor();
+        await executor.initialize(rmMeta);
+        const remoteFullPath = executor.joinPath(executor.getTempPath(), REMOTEFOLDER);
+        let result = await executor.createFolder(remoteFullPath, false);
         chai.expect(result).eq(true);
-        result = await shellExecutor.removeFolder(REMOTEFOLDER);
+        const commandResult = await executor.executeScript("dir");
+        chai.expect(commandResult.exitCode).eq(0);
+        result = await executor.removeFolder(remoteFullPath);
         chai.expect(result).eq(true);
+        await executor.close();
     });
     it('Test ShellExecutor', async () => {
         if (skip) {
             return;
         }
-        const shellExecutor: ShellExecutor = new ShellExecutor();
-        await shellExecutor.initialize(rmMeta);
-        await copyFile(shellExecutor);
-        await Promise.all([
-            copyFileToRemoteLoop(shellExecutor),
-            copyFileToRemoteLoop(shellExecutor),
-            copyFileToRemoteLoop(shellExecutor),
-            getRemoteFileContentLoop(shellExecutor)
-        ]);
+        const executor: ShellExecutor = new ShellExecutor();
+        await executor.initialize(rmMeta);
+        await copyFile(executor);
+        await copyFileToRemoteLoop(executor);
+        await getRemoteFileContentLoop(executor);
+        await executor.close();
     });
 });
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
import * as component from '../../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils';
import { WindowsCommands } from '../extends/windowsCommands';
describe('Unit Test for Windows Commands', () => {
let windowsCommands: WindowsCommands;
before(() => {
chai.should();
chai.use(chaiAsPromised);
prepareUnitTest();
});
after(() => {
cleanupUnitTest();
});
beforeEach(() => {
windowsCommands = component.get(WindowsCommands);
});
afterEach(() => {
});
it('joinPath', async () => {
chai.expect(windowsCommands.joinPath("/root/", "\\first")).to.equal("\\root\\first");
chai.expect(windowsCommands.joinPath("root/", "first")).to.equal("root\\first");
chai.expect(windowsCommands.joinPath("\\root/", "\\first")).to.equal("\\root\\first");
chai.expect(windowsCommands.joinPath("\\root\\", "\\first")).to.equal("\\root\\first");
chai.expect(windowsCommands.joinPath("\\root", "first")).to.equal("\\root\\first");
chai.expect(windowsCommands.joinPath("\\root\\", "first")).to.equal("\\root\\first");
chai.expect(windowsCommands.joinPath("root\\", "first")).to.equal("root\\first");
chai.expect(windowsCommands.joinPath("root\\")).to.equal("root\\");
chai.expect(windowsCommands.joinPath("root")).to.equal("root");
chai.expect(windowsCommands.joinPath(".\\root")).to.equal(".\\root");
chai.expect(windowsCommands.joinPath("")).to.equal(".");
chai.expect(windowsCommands.joinPath("..")).to.equal("..");
})
it('createFolder', async () => {
chai.expect(windowsCommands.createFolder("test")).to.equal("mkdir \"test\"");
chai.expect(windowsCommands.createFolder("test", true)).to.equal("mkdir \"test\"\r\nICACLS \"test\" /grant \"Users\":F");
})
it('allowPermission', async () => {
chai.expect(windowsCommands.allowPermission(true, "test", "test1")).to.equal("ICACLS \"test\" /grant \"Users\":F /T\r\nICACLS \"test1\" /grant \"Users\":F /T\r\n");
chai.expect(windowsCommands.allowPermission(false, "test")).to.equal("ICACLS \"test\" /grant \"Users\":F\r\n");
})
it('removeFolder', async () => {
chai.expect(windowsCommands.removeFolder("test")).to.equal("rmdir /q \"test\"");
chai.expect(windowsCommands.removeFolder("test", true)).to.equal("rmdir /s /q \"test\"");
chai.expect(windowsCommands.removeFolder("test", true, false)).to.equal("rmdir /s \"test\"");
chai.expect(windowsCommands.removeFolder("test", false, false)).to.equal("rmdir \"test\"");
chai.expect(windowsCommands.removeFolder("test", true, true)).to.equal("rmdir /s /q \"test\"");
})
it('removeFiles', async () => {
chai.expect(windowsCommands.removeFiles("test", "*.sh")).to.equal("del \"test\\*.sh\"");
chai.expect(windowsCommands.removeFiles("test", "")).to.equal("del \"test\"");
})
it('readLastLines', async () => {
chai.expect(windowsCommands.readLastLines("test", 3)).to.equal("powershell.exe Get-Content \"test\" -Tail 3");
})
it('isProcessAlive', async () => {
chai.expect(windowsCommands.isProcessAliveCommand("test")).to.equal("powershell.exe Get-Process -Id (get-content \"test\") -ErrorAction SilentlyContinue");
chai.expect(windowsCommands.isProcessAliveProcessOutput(
{
exitCode: 0,
stdout: "",
stderr: ""
}
)).to.equal(true);
chai.expect(windowsCommands.isProcessAliveProcessOutput(
{
exitCode: 10,
stdout: "",
stderr: ""
}
)).to.equal(false);
})
it('extractFile', async () => {
chai.expect(windowsCommands.extractFile("test.tar", "testfolder")).to.equal("tar -xf \"test.tar\" -C \"testfolder\"");
})
it('executeScript', async () => {
chai.expect(windowsCommands.executeScript("test.sh", true)).to.equal("test.sh");
chai.expect(windowsCommands.executeScript("test script'\"", false)).to.equal("test script'\"");
})
});
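The expectations in the test file above pin down the output of `joinPath` and `removeFolder` fairly tightly. As a rough illustration only, the sketch below shows one way two standalone functions could satisfy those same expectations. The names `joinWindowsPath` and `removeFolderCommand`, and the default values `recursive = false` and `force = true`, are assumptions inferred from the test cases; this is not the `WindowsCommands` implementation from the commit.

```typescript
// Hypothetical helpers, written only to mirror the test expectations above.

// Join path parts with a backslash, then collapse any run of '/' or '\'
// into a single backslash. An empty result becomes ".".
function joinWindowsPath(...parts: string[]): string {
    const joined: string = parts.filter((part) => part !== "").join("\\");
    if (joined === "") {
        return ".";
    }
    return joined.replace(/[\\/]+/g, "\\");
}

// Build an rmdir command line. The defaults (non-recursive, quiet) are an
// assumption read off the tests: removeFolder("test") yields `rmdir /q "test"`.
function removeFolderCommand(folder: string, recursive: boolean = false, force: boolean = true): string {
    const flags: string = `${recursive ? " /s" : ""}${force ? " /q" : ""}`;
    return `rmdir${flags} "${folder}"`;
}

// Example calls, using inputs taken from the tests.
const joinedPath: string = joinWindowsPath("/root/", "\\first");
const removeCommand: string = removeFolderCommand("test", true);
```

Note that collapsing every run of separators also rewrites a leading `\\`, so this sketch would not preserve UNC paths such as `\\server\share`; the tests above do not cover that case.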
@@ -13,7 +13,6 @@ assessor:
 trial:
   codeDir: ../../../examples/trials/mnist-annotation
   command: python3 mnist.py --batch_num 10
-  gpuNum: 0
   useAnnotation: true
   multiPhase: false
......