Unverified Commit f548d82f authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #250 from microsoft/master

merge master
parents 0a742aff 69cae211
...@@ -2,18 +2,54 @@ ...@@ -2,18 +2,54 @@
NNI can run one experiment on multiple remote machines through SSH, called `remote` mode. It's like a lightweight training platform. In this mode, NNI can be started from your computer, and dispatch trials to remote machines in parallel. NNI can run one experiment on multiple remote machines through SSH, called `remote` mode. It's like a lightweight training platform. In this mode, NNI can be started from your computer, and dispatch trials to remote machines in parallel.
## Remote machine requirements The OS of remote machines supports `Linux`, `Windows 10`, and `Windows Server 2019`.
* It only supports Linux as remote machines, and [linux part in system specification](../Tutorial/InstallationLinux.md) is same as NNI local mode. ## Requirements
* Follow [installation](../Tutorial/InstallationLinux.md) to install NNI on each machine. * Make sure the default environment of remote machines meets requirements of your trial code. If the default environment does not meet the requirements, the setup script can be added into `command` field of NNI config.
* Make sure remote machines meet environment requirements of your trial code. If the default environment does not meet the requirements, the setup script can be added into `command` field of NNI config.
* Make sure remote machines can be accessed through SSH from the machine which runs `nnictl` command. It supports both password and key authentication of SSH. For advanced usages, please refer to [machineList part of configuration](../Tutorial/ExperimentConfig.md). * Make sure remote machines can be accessed through SSH from the machine which runs `nnictl` command. It supports both password and key authentication of SSH. For advanced usages, please refer to [machineList part of configuration](../Tutorial/ExperimentConfig.md).
* Make sure the NNI version on each machine is consistent. * Make sure the NNI version on each machine is consistent.
* Make sure the command of Trial is compatible with remote OSes, if you want to use remote Linux and Windows together. For example, the default python 3.x executable called `python3` on Linux, and `python` on Windows.
### Linux
* Follow [installation](../Tutorial/InstallationLinux.md) to install NNI on the remote machine.
### Windows
* Follow [installation](../Tutorial/InstallationWin.md) to install NNI on the remote machine.
* Install and start `OpenSSH Server`.
1. Open `Settings` app on Windows.
2. Click `Apps`, then click `Optional features`.
3. Click `Add a feature`, search and select `OpenSSH Server`, and then click `Install`.
4. Once it's installed, run below command to start and set to automatic start.
```bat
sc config sshd start=auto
net start sshd
```
* Make sure remote account is administrator, so that it can stop running trials.
* Make sure there is no welcome message more than default, since it causes ssh2 failed in NodeJs. For example, if you're using Data Science VM on Azure, it needs to remove extra echo commands in `C:\dsvm\tools\setup\welcome.bat`.
The output like below is ok, when opening a new command window.
```text
Microsoft Windows [Version 10.0.17763.1192]
(c) 2018 Microsoft Corporation. All rights reserved.
(py37_default) C:\Users\AzureUser>
```
## Run an experiment ## Run an experiment
e.g. there are three machines, which can be logged in with username and password. e.g. there are three machines, which can be logged in with username and password.
......
# Install on Windows # Install on Windows
## Installation ## Prerequires
Anaconda or Miniconda is highly recommended to manage multiple Python environments. * Python 3.5 (or above) 64-bit. [Anaconda](https://www.anaconda.com/products/individual) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html) is highly recommended to manage multiple Python environments on Windows.
### Install NNI through pip * If it's a newly installed Python environment, it needs to install [Microsoft C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/) to support build NNI dependencies like `scikit-learn`.
Prerequisites: `python 64-bit >= 3.5` ```bat
pip install cython wheel
```
```bash * git for verifying installation.
python -m pip install --upgrade nni
```
### Install NNI through source code ## Install NNI
If you are interested in special or the latest code versions, you can install NNI through source code. In most cases, you can install and upgrade NNI from pip package. It's easy and fast.
Prerequisites: `python 64-bit >=3.5`, `git`, `PowerShell`. If you are interested in special or the latest code versions, you can install NNI through source code.
```bash If you want to contribute to NNI, refer to [setup development environment](SetupNniDeveloperEnvironment.md).
git clone -b v1.5 https://github.com/Microsoft/nni.git
cd nni * From pip package
powershell -ExecutionPolicy Bypass -file install.ps1
``` ```bat
python -m pip install --upgrade nni
```
* From source code
```bat
git clone -b v1.5 https://github.com/Microsoft/nni.git
cd nni
powershell -ExecutionPolicy Bypass -file install.ps1
```
## Verify installation ## Verify installation
The following example is built on TensorFlow 1.x. Make sure **TensorFlow 1.x is used** when running it. The following example is built on TensorFlow 1.x. Make sure **TensorFlow 1.x is used** when running it.
* Download the examples via clone the source code. * Clone examples within source code.
```bash ```bat
git clone -b v1.5 https://github.com/Microsoft/nni.git git clone -b v1.5 https://github.com/Microsoft/nni.git
``` ```
* Run the MNIST example. * Run the MNIST example.
```bash ```bat
nnictl create --config nni\examples\trials\mnist-tfv1\config_windows.yml nnictl create --config nni\examples\trials\mnist-tfv1\config_windows.yml
``` ```
Note: for other examples you need to change trial command `python3` to `python` in each example YAML, if python3 is called through `python` on your machine. Note: If you are familiar with other frameworks, you can choose corresponding example under `examples\trials`. It needs to change trial command `python3` to `python` in each example YAML, since default installation has `python.exe`, not `python3.exe` executable.
* Wait for the message `INFO: Successfully started experiment!` in the command line. This message indicates that your experiment has been successfully started. You can explore the experiment using the `Web UI url`. * Wait for the message `INFO: Successfully started experiment!` in the command line. This message indicates that your experiment has been successfully started. You can explore the experiment using the `Web UI url`.
...@@ -112,18 +122,20 @@ If there is a stderr file, please check it. Two possible cases are: ...@@ -112,18 +122,20 @@ If there is a stderr file, please check it. Two possible cases are:
* forgetting to install experiment dependencies such as TensorFlow, Keras and so on. * forgetting to install experiment dependencies such as TensorFlow, Keras and so on.
### Fail to use BOHB on Windows ### Fail to use BOHB on Windows
Make sure a C++ 14.0 compiler is installed when trying to run `nnictl package install --name=BOHB` to install the dependencies. Make sure a C++ 14.0 compiler is installed when trying to run `nnictl package install --name=BOHB` to install the dependencies.
### Not supported tuner on Windows ### Not supported tuner on Windows
SMAC is not supported currently; for the specific reason refer to this [GitHub issue](https://github.com/automl/SMAC3/issues/483). SMAC is not supported currently; for the specific reason refer to this [GitHub issue](https://github.com/automl/SMAC3/issues/483).
### Use a Windows server as a remote worker ### Use Windows as a remote worker
Currently, you can't.
Note: Refer to [Remote Machine mode](../TrainingService/RemoteMachineMode.md).
* If an error like `Segmentation fault` is encountered, please refer to the [FAQ](FAQ.md) ### Segmentation fault (core dumped) when installing
Refer to [FAQ](FAQ.md).
## Further reading ## Further reading
......
...@@ -84,7 +84,11 @@ class SendMetrics(keras.callbacks.Callback): ...@@ -84,7 +84,11 @@ class SendMetrics(keras.callbacks.Callback):
Run on end of each epoch Run on end of each epoch
''' '''
LOG.debug(logs) LOG.debug(logs)
nni.report_intermediate_result(logs["val_acc"]) # TensorFlow 2.0 API reference claims the key is `val_acc`, but in fact it's `val_accuracy`
if 'val_acc' in logs:
nni.report_intermediate_result(logs['val_acc'])
else:
nni.report_intermediate_result(logs['val_accuracy'])
def train(args, params): def train(args, params):
''' '''
......
...@@ -86,7 +86,11 @@ class SendMetrics(keras.callbacks.Callback): ...@@ -86,7 +86,11 @@ class SendMetrics(keras.callbacks.Callback):
Run on end of each epoch Run on end of each epoch
''' '''
LOG.debug(logs) LOG.debug(logs)
nni.report_intermediate_result(logs["val_acc"]) # TensorFlow 2.0 API reference claims the key is `val_acc`, but in fact it's `val_accuracy`
if 'val_acc' in logs:
nni.report_intermediate_result(logs['val_acc'])
else:
nni.report_intermediate_result(logs['val_accuracy'])
def train(args, params): def train(args, params):
''' '''
......
...@@ -152,7 +152,11 @@ class SendMetrics(keras.callbacks.Callback): ...@@ -152,7 +152,11 @@ class SendMetrics(keras.callbacks.Callback):
if logs is None: if logs is None:
logs = dict() logs = dict()
logger.debug(logs) logger.debug(logs)
nni.report_intermediate_result(logs["val_accuracy"]) # TensorFlow 2.0 API reference claims the key is `val_acc`, but in fact it's `val_accuracy`
if 'val_acc' in logs:
nni.report_intermediate_result(logs['val_acc'])
else:
nni.report_intermediate_result(logs['val_accuracy'])
# Training # Training
......
...@@ -152,9 +152,11 @@ class SendMetrics(keras.callbacks.Callback): ...@@ -152,9 +152,11 @@ class SendMetrics(keras.callbacks.Callback):
if logs is None: if logs is None:
logs = dict() logs = dict()
logger.debug(logs) logger.debug(logs)
# accuracy key for keras 2.2.2: val_acc # TensorFlow 2.0 API reference claims the key is `val_acc`, but in fact it's `val_accuracy`
# for keras 2.3.1: val_accuracy if 'val_acc' in logs:
nni.report_intermediate_result(logs["val_accuracy"]) nni.report_intermediate_result(logs['val_acc'])
else:
nni.report_intermediate_result(logs['val_accuracy'])
# Training # Training
......
...@@ -148,12 +148,13 @@ cmd /c $NNI_YARN ...@@ -148,12 +148,13 @@ cmd /c $NNI_YARN
cmd /c $NNI_YARN build cmd /c $NNI_YARN build
Copy-Item config -Destination .\dist\ -Recurse -Force Copy-Item config -Destination .\dist\ -Recurse -Force
# Building WebUI # Building WebUI
# office-ui-fabric-react need longer time. the 180000 is in ms, mean 180 seconds, longer than default 30 seconds.
cd ..\webui cd ..\webui
cmd /c $NNI_YARN cmd /c $NNI_YARN --network-timeout 180000
cmd /c $NNI_YARN build cmd /c $NNI_YARN build
# Building NasUI # Building NasUI
cd ..\nasui cd ..\nasui
cmd /c $NNI_YARN cmd /c $NNI_YARN --network-timeout 180000
cmd /c $NNI_YARN build cmd /c $NNI_YARN build
cd ..\.. cd ..\..
......
...@@ -22,7 +22,7 @@ import { HyperParameters, TrainingService, TrialJobStatus } from './trainingServ ...@@ -22,7 +22,7 @@ import { HyperParameters, TrainingService, TrialJobStatus } from './trainingServ
function getExperimentRootDir(): string { function getExperimentRootDir(): string {
return getExperimentStartupInfo() return getExperimentStartupInfo()
.getLogDir(); .getLogDir();
} }
function getLogDir(): string { function getLogDir(): string {
...@@ -31,7 +31,7 @@ function getLogDir(): string { ...@@ -31,7 +31,7 @@ function getLogDir(): string {
function getLogLevel(): string { function getLogLevel(): string {
return getExperimentStartupInfo() return getExperimentStartupInfo()
.getLogLevel(); .getLogLevel();
} }
function getDefaultDatabaseDir(): string { function getDefaultDatabaseDir(): string {
...@@ -113,11 +113,16 @@ function uniqueString(len: number): string { ...@@ -113,11 +113,16 @@ function uniqueString(len: number): string {
return String.fromCharCode(...codes); return String.fromCharCode(...codes);
} }
function randomInt(max: number): number {
return Math.floor(Math.random() * max);
}
function randomSelect<T>(a: T[]): T { function randomSelect<T>(a: T[]): T {
assert(a !== undefined); assert(a !== undefined);
return a[Math.floor(Math.random() * a.length)]; return a[Math.floor(Math.random() * a.length)];
} }
function parseArg(names: string[]): string { function parseArg(names: string[]): string {
if (process.argv.length >= 4) { if (process.argv.length >= 4) {
for (let i: number = 2; i < process.argv.length - 1; i++) { for (let i: number = 2; i < process.argv.length - 1; i++) {
...@@ -132,7 +137,7 @@ function parseArg(names: string[]): string { ...@@ -132,7 +137,7 @@ function parseArg(names: string[]): string {
function getCmdPy(): string { function getCmdPy(): string {
let cmd = 'python3'; let cmd = 'python3';
if(process.platform === 'win32'){ if (process.platform === 'win32') {
cmd = 'python'; cmd = 'python';
} }
return cmd; return cmd;
...@@ -160,7 +165,7 @@ function generateParamFileName(hyperParameters: HyperParameters): string { ...@@ -160,7 +165,7 @@ function generateParamFileName(hyperParameters: HyperParameters): string {
assert(hyperParameters.index >= 0); assert(hyperParameters.index >= 0);
let paramFileName: string; let paramFileName: string;
if(hyperParameters.index == 0) { if (hyperParameters.index == 0) {
paramFileName = 'parameter.cfg'; paramFileName = 'parameter.cfg';
} else { } else {
paramFileName = `parameter_${hyperParameters.index}.cfg` paramFileName = `parameter_${hyperParameters.index}.cfg`
...@@ -211,9 +216,9 @@ function getIPV4Address(): string { ...@@ -211,9 +216,9 @@ function getIPV4Address(): string {
return cachedipv4Address; return cachedipv4Address;
} }
if(os.networkInterfaces().eth0) { if (os.networkInterfaces().eth0) {
for(const item of os.networkInterfaces().eth0) { for (const item of os.networkInterfaces().eth0) {
if(item.family === 'IPv4') { if (item.family === 'IPv4') {
cachedipv4Address = item.address; cachedipv4Address = item.address;
return cachedipv4Address; return cachedipv4Address;
} }
...@@ -225,14 +230,6 @@ function getIPV4Address(): string { ...@@ -225,14 +230,6 @@ function getIPV4Address(): string {
throw Error('getIPV4Address() failed because no valid IPv4 address found.') throw Error('getIPV4Address() failed because no valid IPv4 address found.')
} }
function getRemoteTmpDir(osType: string): string {
if (osType == 'linux') {
return '/tmp';
} else {
throw Error(`remote OS ${osType} not supported`);
}
}
/** /**
* Get the status of canceled jobs according to the hint isEarlyStopped * Get the status of canceled jobs according to the hint isEarlyStopped
*/ */
...@@ -245,7 +242,7 @@ function getJobCancelStatus(isEarlyStopped: boolean): TrialJobStatus { ...@@ -245,7 +242,7 @@ function getJobCancelStatus(isEarlyStopped: boolean): TrialJobStatus {
* @param directory directory name * @param directory directory name
*/ */
function countFilesRecursively(directory: string): Promise<number> { function countFilesRecursively(directory: string): Promise<number> {
if(!fs.existsSync(directory)) { if (!fs.existsSync(directory)) {
throw Error(`Direcotory ${directory} doesn't exist`); throw Error(`Direcotory ${directory} doesn't exist`);
} }
...@@ -261,13 +258,13 @@ function countFilesRecursively(directory: string): Promise<number> { ...@@ -261,13 +258,13 @@ function countFilesRecursively(directory: string): Promise<number> {
let fileCount: number = -1; let fileCount: number = -1;
let cmd: string; let cmd: string;
if(process.platform === "win32") { if (process.platform === "win32") {
cmd = `powershell "Get-ChildItem -Path ${directory} -Recurse -File | Measure-Object | %{$_.Count}"` cmd = `powershell "Get-ChildItem -Path ${directory} -Recurse -File | Measure-Object | %{$_.Count}"`
} else { } else {
cmd = `find ${directory} -type f | wc -l`; cmd = `find ${directory} -type f | wc -l`;
} }
cpp.exec(cmd).then((result) => { cpp.exec(cmd).then((result) => {
if(result.stdout && parseInt(result.stdout)) { if (result.stdout && parseInt(result.stdout)) {
fileCount = parseInt(result.stdout); fileCount = parseInt(result.stdout);
} }
deferred.resolve(fileCount); deferred.resolve(fileCount);
...@@ -280,20 +277,20 @@ function countFilesRecursively(directory: string): Promise<number> { ...@@ -280,20 +277,20 @@ function countFilesRecursively(directory: string): Promise<number> {
function validateFileName(fileName: string): boolean { function validateFileName(fileName: string): boolean {
const pattern: string = '^[a-z0-9A-Z._-]+$'; const pattern: string = '^[a-z0-9A-Z._-]+$';
const validateResult = fileName.match(pattern); const validateResult = fileName.match(pattern);
if(validateResult) { if (validateResult) {
return true; return true;
} }
return false; return false;
} }
async function validateFileNameRecursively(directory: string): Promise<boolean> { async function validateFileNameRecursively(directory: string): Promise<boolean> {
if(!fs.existsSync(directory)) { if (!fs.existsSync(directory)) {
throw Error(`Direcotory ${directory} doesn't exist`); throw Error(`Direcotory ${directory} doesn't exist`);
} }
const fileNameArray: string[] = fs.readdirSync(directory); const fileNameArray: string[] = fs.readdirSync(directory);
let result = true; let result = true;
for(const name of fileNameArray){ for (const name of fileNameArray) {
const fullFilePath: string = path.join(directory, name); const fullFilePath: string = path.join(directory, name);
try { try {
// validate file names and directory names // validate file names and directory names
...@@ -301,14 +298,14 @@ async function validateFileNameRecursively(directory: string): Promise<boolean> ...@@ -301,14 +298,14 @@ async function validateFileNameRecursively(directory: string): Promise<boolean>
if (fs.lstatSync(fullFilePath).isDirectory()) { if (fs.lstatSync(fullFilePath).isDirectory()) {
result = result && await validateFileNameRecursively(fullFilePath); result = result && await validateFileNameRecursively(fullFilePath);
} }
if(!result) { if (!result) {
return Promise.reject(new Error(`file name in ${fullFilePath} is not valid!`)); return Promise.reject(new Error(`file name in ${fullFilePath} is not valid!`));
} }
} catch(error) { } catch (error) {
return Promise.reject(error); return Promise.reject(error);
} }
} }
return Promise.resolve(result); return Promise.resolve(result);
} }
/** /**
...@@ -316,9 +313,9 @@ async function validateFileNameRecursively(directory: string): Promise<boolean> ...@@ -316,9 +313,9 @@ async function validateFileNameRecursively(directory: string): Promise<boolean>
*/ */
async function getVersion(): Promise<string> { async function getVersion(): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>(); const deferred: Deferred<string> = new Deferred<string>();
import(path.join(__dirname, '..', 'package.json')).then((pkg)=>{ import(path.join(__dirname, '..', 'package.json')).then((pkg) => {
deferred.resolve(pkg.version); deferred.resolve(pkg.version);
}).catch((error)=>{ }).catch((error) => {
deferred.reject(error); deferred.reject(error);
}); });
return deferred.promise; return deferred.promise;
...@@ -331,9 +328,9 @@ function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newE ...@@ -331,9 +328,9 @@ function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newE
let cmd: string = command; let cmd: string = command;
let arg: string[] = []; let arg: string[] = [];
let newShell: boolean = true; let newShell: boolean = true;
if(process.platform === "win32"){ if (process.platform === "win32") {
cmd = command.split(" ", 1)[0]; cmd = command.split(" ", 1)[0];
arg = command.substr(cmd.length+1).split(" "); arg = command.substr(cmd.length + 1).split(" ");
newShell = false; newShell = false;
} }
const tunerProc: ChildProcess = spawn(cmd, arg, { const tunerProc: ChildProcess = spawn(cmd, arg, {
...@@ -383,7 +380,7 @@ async function killPid(pid: any): Promise<void> { ...@@ -383,7 +380,7 @@ async function killPid(pid: any): Promise<void> {
if (process.platform === "win32") { if (process.platform === "win32") {
await cpp.exec(`cmd.exe /c taskkill /PID ${pid} /F`); await cpp.exec(`cmd.exe /c taskkill /PID ${pid} /F`);
} }
else{ else {
await cpp.exec(`kill -9 ${pid}`); await cpp.exec(`kill -9 ${pid}`);
} }
} catch (error) { } catch (error) {
...@@ -397,7 +394,7 @@ function getNewLine(): string { ...@@ -397,7 +394,7 @@ function getNewLine(): string {
if (process.platform === "win32") { if (process.platform === "win32") {
return "\r\n"; return "\r\n";
} }
else{ else {
return "\n"; return "\n";
} }
} }
...@@ -412,6 +409,8 @@ function unixPathJoin(...paths: any[]): string { ...@@ -412,6 +409,8 @@ function unixPathJoin(...paths: any[]): string {
return dir; return dir;
} }
export {countFilesRecursively, validateFileNameRecursively, getRemoteTmpDir, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, export {
countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin,
mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine }; mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine
};
...@@ -266,7 +266,7 @@ class NNIManager implements Manager { ...@@ -266,7 +266,7 @@ class NNIManager implements Manager {
const delay1: Promise<{}> = new Promise((resolve: Function, reject: Function): void => { const delay1: Promise<{}> = new Promise((resolve: Function, reject: Function): void => {
timeoutId = setTimeout( timeoutId = setTimeout(
() => { reject(new Error('TrainingService setClusterMetadata timeout. Please check your config file.')); }, () => { reject(new Error('TrainingService setClusterMetadata timeout. Please check your config file.')); },
10000); 30000);
}); });
await Promise.race([delay1, this.trainingService.setClusterMetadata(key, value)]).finally(() => { await Promise.race([delay1, this.trainingService.setClusterMetadata(key, value)]).finally(() => {
clearTimeout(timeoutId); clearTimeout(timeoutId);
...@@ -368,7 +368,7 @@ class NNIManager implements Manager { ...@@ -368,7 +368,7 @@ class NNIManager implements Manager {
CUDA_VISIBLE_DEVICES: this.getGpuEnvvarValue() CUDA_VISIBLE_DEVICES: this.getGpuEnvvarValue()
}; };
const newEnv = Object.assign({}, process.env, nniEnv); const newEnv = Object.assign({}, process.env, nniEnv);
const tunerProc: ChildProcess = getTunerProc(command,stdio,newCwd,newEnv); const tunerProc: ChildProcess = getTunerProc(command, stdio, newCwd, newEnv);
this.dispatcherPid = tunerProc.pid; this.dispatcherPid = tunerProc.pid;
this.dispatcher = createDispatcherInterface(tunerProc); this.dispatcher = createDispatcherInterface(tunerProc);
...@@ -436,7 +436,9 @@ class NNIManager implements Manager { ...@@ -436,7 +436,9 @@ class NNIManager implements Manager {
} }
await killPid(this.dispatcherPid); await killPid(this.dispatcherPid);
const trialJobList: TrialJobDetail[] = await this.trainingService.listTrialJobs(); const trialJobList: TrialJobDetail[] = await this.trainingService.listTrialJobs();
// TO DO: to promise all
// DON'T try to make it in parallel, the training service may not handle it well.
// If there is performance concern, consider to support batch cancellation on training service.
for (const trialJob of trialJobList) { for (const trialJob of trialJobList) {
if (trialJob.status === 'RUNNING' || if (trialJob.status === 'RUNNING' ||
trialJob.status === 'WAITING') { trialJob.status === 'WAITING') {
...@@ -444,7 +446,7 @@ class NNIManager implements Manager { ...@@ -444,7 +446,7 @@ class NNIManager implements Manager {
this.log.info(`cancelTrialJob: ${trialJob.id}`); this.log.info(`cancelTrialJob: ${trialJob.id}`);
await this.trainingService.cancelTrialJob(trialJob.id); await this.trainingService.cancelTrialJob(trialJob.id);
} catch (error) { } catch (error) {
// pid does not exist, do nothing here this.log.debug(`ignorable error on canceling trial ${trialJob.id}. ${error}`);
} }
} }
} }
......
...@@ -174,10 +174,11 @@ export async function tarAdd(tarPath: string, sourcePath: string): Promise<void> ...@@ -174,10 +174,11 @@ export async function tarAdd(tarPath: string, sourcePath: string): Promise<void>
script.push( script.push(
`import os`, `import os`,
`import tarfile`, `import tarfile`,
String.Format(`tar = tarfile.open("{0}","w:gz")\r\nfor root,dir,files in os.walk("{1}"):`, tarFilePath, sourceFilePath), String.Format(`tar = tarfile.open("{0}","w:gz")\r\nroot="{1}"\r\nfor file_path,dir,files in os.walk(root):`, tarFilePath, sourceFilePath),
` for file in files:`, ` for file in files:`,
` fullpath = os.path.join(root,file)`, ` full_path = os.path.join(file_path, file)`,
` tar.add(fullpath, arcname=file)`, ` file = os.path.relpath(full_path, root)`,
` tar.add(full_path, arcname=file)`,
`tar.close()`); `tar.close()`);
await fs.promises.writeFile(path.join(os.tmpdir(), 'tar.py'), script.join(getNewLine()), { encoding: 'utf8', mode: 0o777 }); await fs.promises.writeFile(path.join(os.tmpdir(), 'tar.py'), script.join(getNewLine()), { encoding: 'utf8', mode: 0o777 });
const tarScript: string = path.join(os.tmpdir(), 'tar.py'); const tarScript: string = path.join(os.tmpdir(), 'tar.py');
......
...@@ -158,7 +158,7 @@ class PAIK8STrainingService extends PAITrainingService { ...@@ -158,7 +158,7 @@ class PAIK8STrainingService extends PAITrainingService {
if (this.paiTrialConfig === undefined) { if (this.paiTrialConfig === undefined) {
throw new Error('trial config is not initialized'); throw new Error('trial config is not initialized');
} }
const containerNFSExpCodeDir = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}/'nni-code`; const containerNFSExpCodeDir = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}/nni-code`;
const containerWorkingDir: string = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}/${trialJobDetail.id}`; const containerWorkingDir: string = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}/${trialJobDetail.id}`;
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address(); const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const nniPaiTrialCommand: string = String.Format( const nniPaiTrialCommand: string = String.Format(
......
...@@ -7,6 +7,36 @@ import { OsCommands } from "../osCommands"; ...@@ -7,6 +7,36 @@ import { OsCommands } from "../osCommands";
import { RemoteCommandResult } from "../remoteMachineData"; import { RemoteCommandResult } from "../remoteMachineData";
class LinuxCommands extends OsCommands { class LinuxCommands extends OsCommands {
public getScriptExt(): string {
return "sh";
}
public generateStartScript(workingDirectory: string, trialJobId: string, experimentId: string,
trialSequenceId: string, isMultiPhase: boolean, jobIdFileName: string,
command: string, nniManagerAddress: string, nniManagerPort: number,
nniManagerVersion: string, logCollection: string, exitCodeFile: string,
codeDir: string, cudaVisibleSetting: string): string {
return `#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR=${workingDirectory} NNI_OUTPUT_DIR=${workingDirectory} NNI_TRIAL_JOB_ID=${trialJobId} \
NNI_EXP_ID=${experimentId} NNI_TRIAL_SEQ_ID=${trialSequenceId} NNI_CODE_DIR=${codeDir}
export MULTI_PHASE=${isMultiPhase}
cp -r $NNI_CODE_DIR/. $NNI_SYS_DIR
cd $NNI_SYS_DIR
sh install_nni.sh
python3 -m nni_trial_tool.trial_keeper --trial_command '${cudaVisibleSetting} ${command}' --nnimanager_ip '${nniManagerAddress}' \
--nnimanager_port '${nniManagerPort}' --nni_manager_version '${nniManagerVersion}' \
--job_id_file ${jobIdFileName} \
--log_collection '${logCollection}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr
echo $? \`date +%s%3N\` >${exitCodeFile}`;
}
public generateGpuStatsScript(scriptFolder: string): string {
return `echo $$ > ${scriptFolder}/pid ; METRIC_OUTPUT_DIR=${scriptFolder} python3 -m nni_gpu_tool.gpu_metrics_collector`;
}
public createFolder(folderName: string, sharedFolder: boolean = false): string { public createFolder(folderName: string, sharedFolder: boolean = false): string {
let command; let command;
if (sharedFolder) { if (sharedFolder) {
...@@ -64,7 +94,19 @@ class LinuxCommands extends OsCommands { ...@@ -64,7 +94,19 @@ class LinuxCommands extends OsCommands {
} }
public killChildProcesses(pidFileName: string): string { public killChildProcesses(pidFileName: string): string {
const command = `pkill -P \`cat '${pidFileName}'\``; // prevent trialkeeper to be killed, so it can save exit code.
const command = `list_descendants ()
{
local children=$(ps -o pid= --ppid "$1")
for pid in $children
do
list_descendants "$pid"
done
echo "$children"
}
kill $(list_descendants \`cat '${pidFileName}'\`)`
return command; return command;
} }
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { OsCommands } from "../osCommands";
import { RemoteCommandResult } from "../remoteMachineData";
class WindowsCommands extends OsCommands {
protected pathSpliter: string = '\\';
public getScriptExt(): string {
return "cmd";
}
public generateStartScript(workingDirectory: string, trialJobId: string, experimentId: string,
trialSequenceId: string, isMultiPhase: boolean, jobIdFileName: string,
command: string, nniManagerAddress: string, nniManagerPort: number,
nniManagerVersion: string, logCollection: string, exitCodeFile: string,
codeDir: string, cudaVisibleSetting: string): string {
return `echo off
set NNI_PLATFORM=remote
set NNI_SYS_DIR=${workingDirectory}
set NNI_OUTPUT_DIR=${workingDirectory}
set NNI_TRIAL_JOB_ID=${trialJobId}
set NNI_EXP_ID=${experimentId}
set NNI_TRIAL_SEQ_ID=${trialSequenceId}
set MULTI_PHASE=${isMultiPhase}
set NNI_CODE_DIR=${codeDir}
${cudaVisibleSetting !== "" ? "set " + cudaVisibleSetting : ""}
robocopy /s %NNI_CODE_DIR%/. %NNI_SYS_DIR%
cd %NNI_SYS_DIR%
python -c "import nni" 2>nul
if not %ERRORLEVEL% EQU 0 (
echo installing NNI as exit code of "import nni" is %ERRORLEVEL%
python -m pip install --user --upgrade nni
)
echo starting script
python -m nni_trial_tool.trial_keeper --trial_command "${command}" --nnimanager_ip "${nniManagerAddress}" --nnimanager_port "${nniManagerPort}" --nni_manager_version "${nniManagerVersion}" --log_collection "${logCollection}" --job_id_file ${jobIdFileName} 1>%NNI_OUTPUT_DIR%/trialkeeper_stdout 2>%NNI_OUTPUT_DIR%/trialkeeper_stderr
echo save exit code(%ERRORLEVEL%) and time
echo|set /p="%ERRORLEVEL% " > ${exitCodeFile}
powershell -command "Write (((New-TimeSpan -Start (Get-Date "01/01/1970") -End (Get-Date).ToUniversalTime()).TotalMilliseconds).ToString("0")) | Out-file ${exitCodeFile} -Append -NoNewline -encoding utf8"`;
}
public generateGpuStatsScript(scriptFolder: string): string {
return `powershell -command $env:METRIC_OUTPUT_DIR='${scriptFolder}';$app = Start-Process -FilePath python -NoNewWindow -passthru -ArgumentList '-m nni_gpu_tool.gpu_metrics_collector' -RedirectStandardOutput ${scriptFolder}\\scriptstdout -RedirectStandardError ${scriptFolder}\\scriptstderr;Write $PID ^| Out-File ${scriptFolder}\\pid -NoNewline -encoding utf8;wait-process $app.ID`;
}
public createFolder(folderName: string, sharedFolder: boolean = false): string {
let command;
if (sharedFolder) {
command = `mkdir "${folderName}"\r\nICACLS "${folderName}" /grant "Users":F`;
} else {
command = `mkdir "${folderName}"`;
}
return command;
}
public allowPermission(isRecursive: boolean = false, ...folders: string[]): string {
let commands: string = "";
folders.forEach(folder => {
commands += `ICACLS "${folder}" /grant "Users":F${isRecursive ? " /T" : ""}\r\n`
});
return commands;
}
public removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): string {
let flags = '';
if (isForce || isRecursive) {
flags = `${isRecursive ? ' /s' : ''}${isForce ? ' /q' : ''}`;
}
const command = `rmdir${flags} "${folderName}"`;
return command;
}
public removeFiles(folderName: string, filePattern: string): string {
const files = this.joinPath(folderName, filePattern);
const command = `del "${files}"`;
return command;
}
public readLastLines(fileName: string, lineCount: number = 1): string {
const command = `powershell.exe Get-Content "${fileName}" -Tail ${lineCount}`;
return command;
}
public isProcessAliveCommand(pidFileName: string): string {
const command = `powershell.exe Get-Process -Id (get-content "${pidFileName}") -ErrorAction SilentlyContinue`;
return command;
}
public isProcessAliveProcessOutput(commandResult: RemoteCommandResult): boolean {
let result = true;
if (commandResult.exitCode !== 0) {
result = false;
}
return result;
}
public killChildProcesses(pidFileName: string): string {
const command = `powershell "$ppid=(type ${pidFileName}); function Kill-Tree {Param([int]$subppid);` +
`Get-CimInstance Win32_Process | Where-Object { $_.ParentProcessId -eq $subppid } | ForEach-Object { Kill-Tree $_.ProcessId }; ` +
`if ($subppid -ne $ppid){Stop-Process -Id $subppid}}` +
`kill-tree $ppid"`;
return command;
}
public extractFile(tarFileName: string, targetFolder: string): string {
const command = `tar -xf "${tarFileName}" -C "${targetFolder}"`;
return command;
}
public executeScript(script: string, _isFile: boolean): string {
const command = `${script}`;
return command;
}
}
export { WindowsCommands };
...@@ -8,8 +8,16 @@ import { RemoteCommandResult } from "./remoteMachineData"; ...@@ -8,8 +8,16 @@ import { RemoteCommandResult } from "./remoteMachineData";
abstract class OsCommands { abstract class OsCommands {
protected pathSpliter: string = '/'; protected pathSpliter: string = '/';
protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`); protected multiplePathSpliter: RegExp = new RegExp(`[\\\\/]{2,}`);
protected normalizePath: RegExp = new RegExp(`[\\\\/]`);
public abstract getScriptExt(): string;
public abstract generateStartScript(workingDirectory: string, trialJobId: string, experimentId: string,
trialSequenceId: string, isMultiPhase: boolean, jobIdFileName: string,
command: string, nniManagerAddress: string, nniManagerPort: number,
nniManagerVersion: string, logCollection: string, exitCodeFile: string,
codeDir: string, cudaVisibleSetting: string): string;
public abstract generateGpuStatsScript(scriptFolder: string): string;
public abstract createFolder(folderName: string, sharedFolder: boolean): string; public abstract createFolder(folderName: string, sharedFolder: boolean): string;
public abstract allowPermission(isRecursive: boolean, ...folders: string[]): string; public abstract allowPermission(isRecursive: boolean, ...folders: string[]): string;
public abstract removeFolder(folderName: string, isRecursive: boolean, isForce: boolean): string; public abstract removeFolder(folderName: string, isRecursive: boolean, isForce: boolean): string;
...@@ -26,6 +34,9 @@ abstract class OsCommands { ...@@ -26,6 +34,9 @@ abstract class OsCommands {
if (dir === '') { if (dir === '') {
dir = '.'; dir = '.';
} else { } else {
// normalize
dir = dir.replace(this.normalizePath, this.pathSpliter);
// reduce duplicate ones
dir = dir.replace(this.multiplePathSpliter, this.pathSpliter); dir = dir.replace(this.multiplePathSpliter, this.pathSpliter);
} }
return dir; return dir;
......
...@@ -85,78 +85,82 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -85,78 +85,82 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
* The remote machine executor manager * The remote machine executor manager
*/ */
export class ExecutorManager { export class ExecutorManager {
private readonly executorArray: ShellExecutor[]; private readonly executorMap: Map<string, ShellExecutor> = new Map<string, ShellExecutor>();
private readonly maxTrialNumberPerConnection: number;
private readonly rmMeta: RemoteMachineMeta; private readonly rmMeta: RemoteMachineMeta;
constructor(executorArray: ShellExecutor[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) {
private executors: ShellExecutor[] = [];
constructor(rmMeta: RemoteMachineMeta) {
this.rmMeta = rmMeta; this.rmMeta = rmMeta;
this.executorArray = executorArray;
this.maxTrialNumberPerConnection = maxTrialNumberPerConnection;
} }
/** public async getExecutor(id: string): Promise<ShellExecutor> {
* find a available executor, if no executor available, return a new one let isFound = false;
*/ let executor: ShellExecutor | undefined;
public async getAvailableExecutor(): Promise<ShellExecutor> {
for (const index of this.executorArray.keys()) {
const connectionNumber: number = this.executorArray[index].getUsedConnectionNumber;
if (connectionNumber < this.maxTrialNumberPerConnection) {
this.executorArray[index].addUsedConnectionNumber();
return this.executorArray[index]; // already assigned
if (this.executorMap.has(id)) {
executor = this.executorMap.get(id);
if (executor === undefined) {
throw new Error("executor shouldn't be undefined before return!");
} }
return executor;
} }
//init a new executor if could not get an available one for (const candidateExecutor of this.executors) {
return await this.initNewShellExecutor(); if (candidateExecutor.addUsage()) {
} isFound = true;
executor = candidateExecutor;
break;
}
}
// init a new executor if no free one.
if (!isFound) {
executor = await this.createShellExecutor();
}
/** if (executor === undefined) {
* add a new executor to executorArray throw new Error("executor shouldn't be undefined before set!");
* @param executor ShellExecutor }
*/ this.executorMap.set(id, executor);
public addNewShellExecutor(executor: ShellExecutor): void {
this.executorArray.push(executor);
}
/** return executor;
* first executor instance is used for gpu collector and host job
*/
public getFirstExecutor(): ShellExecutor {
return this.executorArray[0];
} }
/** /**
* close all of executor * close all of executor
*/ */
public closeAllExecutor(): void { public releaseAllExecutor(): void {
for (const executor of this.executorArray) { this.executorMap.clear();
for (const executor of this.executors) {
executor.close(); executor.close();
} }
this.executors = [];
} }
/** /**
* retrieve resource, minus a number for given executor * retrieve resource, minus a number for given executor
* @param executor executor * @param executor executor
*/ */
public releaseConnection(executor: ShellExecutor | undefined): void { public releaseExecutor(id: string): void {
const executor = this.executorMap.get(id);
if (executor === undefined) { if (executor === undefined) {
throw new Error(`could not release a undefined executor`); throw new Error(`executor for ${id} is not found`);
}
for (const index of this.executorArray.keys()) {
if (this.executorArray[index] === executor) {
this.executorArray[index].minusUsedConnectionNumber();
break;
}
} }
executor.releaseUsage();
this.executorMap.delete(id);
} }
/** /**
* Create a new connection executor and initialize it * Create a new connection executor and initialize it
*/ */
private async initNewShellExecutor(): Promise<ShellExecutor> { private async createShellExecutor(): Promise<ShellExecutor> {
const executor = new ShellExecutor(); const executor = new ShellExecutor();
await executor.initialize(this.rmMeta); await executor.initialize(this.rmMeta);
if (!executor.addUsage()) {
throw new Error("failed to add usage on new created Executor! It's a wired bug!");
}
this.executors.push(executor);
return executor; return executor;
} }
} }
...@@ -175,22 +179,3 @@ export enum ScheduleResultType { ...@@ -175,22 +179,3 @@ export enum ScheduleResultType {
// Cannot match requirement even if all GPU are a // Cannot match requirement even if all GPU are a
REQUIRE_EXCEED_TOTAL REQUIRE_EXCEED_TOTAL
} }
export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string =
`#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} \
NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} NNI_CODE_DIR={6}
cp -r $NNI_CODE_DIR/. $NNI_SYS_DIR
cd $NNI_SYS_DIR
sh install_nni.sh
echo $$ >{7}
python3 -m nni_trial_tool.trial_keeper --trial_command '{8}' --nnimanager_ip '{9}' --nnimanager_port '{10}' \
--nni_manager_version '{11}' --log_collection '{12}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr
echo $? \`date +%s%3N\` >{13}`;
export const HOST_JOB_SHELL_FORMAT: string =
`#!/bin/bash
cd {0}
echo $$ >{1}
eval {2} >stdout 2>stderr
echo $? \`date +%s%3N\` >{3}`;
...@@ -8,7 +8,6 @@ import { EventEmitter } from 'events'; ...@@ -8,7 +8,6 @@ import { EventEmitter } from 'events';
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getExperimentId } from '../../common/experimentStartupInfo'; import { getExperimentId } from '../../common/experimentStartupInfo';
...@@ -19,17 +18,17 @@ import { ...@@ -19,17 +18,17 @@ import {
TrialJobDetail, TrialJobMetric TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { import {
delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus, getRemoteTmpDir, delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus,
getVersion, uniqueString, unixPathJoin getVersion, uniqueString
} from '../../common/utils'; } from '../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { GPUSummary } from '../common/gpuData'; import { GPUSummary } from '../common/gpuData';
import { TrialConfig } from '../common/trialConfig'; import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { execMkdir, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util'; import { execMkdir, validateCodeDir } from '../common/util';
import { GPUScheduler } from './gpuScheduler'; import { GPUScheduler } from './gpuScheduler';
import { import {
REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta, RemoteMachineMeta,
RemoteMachineScheduleInfo, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, RemoteMachineScheduleInfo, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail,
ScheduleResultType, ExecutorManager ScheduleResultType, ExecutorManager
} from './remoteMachineData'; } from './remoteMachineData';
...@@ -41,14 +40,12 @@ import { ShellExecutor } from 'training_service/remote_machine/shellExecutor'; ...@@ -41,14 +40,12 @@ import { ShellExecutor } from 'training_service/remote_machine/shellExecutor';
*/ */
@component.Singleton @component.Singleton
class RemoteMachineTrainingService implements TrainingService { class RemoteMachineTrainingService implements TrainingService {
private readonly initExecutorId = "initConnection";
private readonly machineExecutorManagerMap: Map<RemoteMachineMeta, ExecutorManager>; //machine excutor map private readonly machineExecutorManagerMap: Map<RemoteMachineMeta, ExecutorManager>; //machine excutor map
private readonly machineCopyExpCodeDirPromiseMap: Map<RemoteMachineMeta, Promise<void>>; private readonly machineCopyExpCodeDirPromiseMap: Map<RemoteMachineMeta, Promise<void>>;
private readonly trialExecutorMap: Map<string, ShellExecutor>; //trial excutor map private readonly trialExecutorManagerMap: Map<string, ExecutorManager>; //trial excutor map
private readonly trialJobsMap: Map<string, RemoteMachineTrialJobDetail>; private readonly trialJobsMap: Map<string, RemoteMachineTrialJobDetail>;
private readonly MAX_TRIAL_NUMBER_PER_EXECUTOR: number = 5; // every excutor has a max trial concurrency number
private readonly expRootDir: string; private readonly expRootDir: string;
private readonly remoteExpRootDir: string;
private readonly remoteExpCodeDir: string;
private trialConfig: TrialConfig | undefined; private trialConfig: TrialConfig | undefined;
private gpuScheduler?: GPUScheduler; private gpuScheduler?: GPUScheduler;
private readonly jobQueue: string[]; private readonly jobQueue: string[];
...@@ -57,27 +54,21 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -57,27 +54,21 @@ class RemoteMachineTrainingService implements TrainingService {
private readonly metricsEmitter: EventEmitter; private readonly metricsEmitter: EventEmitter;
private readonly log: Logger; private readonly log: Logger;
private isMultiPhase: boolean = false; private isMultiPhase: boolean = false;
private trialSequenceId: number;
private remoteRestServerPort?: number; private remoteRestServerPort?: number;
private readonly remoteOS: string;
private nniManagerIpConfig?: NNIManagerIpConfig; private nniManagerIpConfig?: NNIManagerIpConfig;
private versionCheck: boolean = true; private versionCheck: boolean = true;
private logCollection: string; private logCollection: string;
constructor(@component.Inject timer: ObservableTimer) { constructor(@component.Inject timer: ObservableTimer) {
this.remoteOS = 'linux';
this.metricsEmitter = new EventEmitter(); this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>(); this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>();
this.trialExecutorMap = new Map<string, ShellExecutor>(); this.trialExecutorManagerMap = new Map<string, ExecutorManager>();
this.machineExecutorManagerMap = new Map<RemoteMachineMeta, ExecutorManager>();
this.machineCopyExpCodeDirPromiseMap = new Map<RemoteMachineMeta, Promise<void>>(); this.machineCopyExpCodeDirPromiseMap = new Map<RemoteMachineMeta, Promise<void>>();
this.machineExecutorManagerMap = new Map<RemoteMachineMeta, ExecutorManager>();
this.jobQueue = []; this.jobQueue = [];
this.expRootDir = getExperimentRootDir(); this.expRootDir = getExperimentRootDir();
this.remoteExpRootDir = this.getRemoteExperimentRootDir();
this.remoteExpCodeDir = unixPathJoin(this.remoteExpRootDir, 'nni-code');
this.timer = timer; this.timer = timer;
this.log = getLogger(); this.log = getLogger();
this.trialSequenceId = -1;
this.logCollection = 'none'; this.logCollection = 'none';
this.log.info('Construct remote machine training service.'); this.log.info('Construct remote machine training service.');
} }
...@@ -110,14 +101,14 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -110,14 +101,14 @@ class RemoteMachineTrainingService implements TrainingService {
} }
await delay(3000); await delay(3000);
} }
this.log.info('Remote machine training service exit.'); this.log.info('RemoteMachineTrainingService run loop exited.');
} }
/** /**
* give trial an executor * give trial an executor
* @param trial remote machine trial job detail * @param trial remote machine trial job detail
*/ */
public async allocateExecutorForTrial(trial: RemoteMachineTrialJobDetail): Promise<void> { public allocateExecutorManagerForTrial(trial: RemoteMachineTrialJobDetail): void {
if (trial.rmMeta === undefined) { if (trial.rmMeta === undefined) {
throw new Error(`rmMeta not set in trial ${trial.id}`); throw new Error(`rmMeta not set in trial ${trial.id}`);
} }
...@@ -125,23 +116,23 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -125,23 +116,23 @@ class RemoteMachineTrainingService implements TrainingService {
if (executorManager === undefined) { if (executorManager === undefined) {
throw new Error(`executorManager not initialized`); throw new Error(`executorManager not initialized`);
} }
const shellExecutor: ShellExecutor = await executorManager.getAvailableExecutor(); this.trialExecutorManagerMap.set(trial.id, executorManager);
this.trialExecutorMap.set(trial.id, shellExecutor);
} }
/** /**
* If a trial is finished, release the connection resource * If a trial is finished, release the connection resource
* @param trial remote machine trial job detail * @param trial remote machine trial job detail
*/ */
public releaseTrialExecutor(trial: RemoteMachineTrialJobDetail): void { public releaseTrialResource(trial: RemoteMachineTrialJobDetail): void {
if (trial.rmMeta === undefined) { if (trial.rmMeta === undefined) {
throw new Error(`rmMeta not set in trial ${trial.id}`); throw new Error(`rmMeta not set in trial ${trial.id}`);
} }
const executorManager: ExecutorManager | undefined = this.machineExecutorManagerMap.get(trial.rmMeta); const executorManager = this.trialExecutorManagerMap.get(trial.id);
if (executorManager === undefined) { if (executorManager === undefined) {
throw new Error(`executorManager not initialized`); throw new Error(`ExecutorManager is not assigned for trial ${trial.id}`);
} }
executorManager.releaseConnection(this.trialExecutorMap.get(trial.id)); // Note, it still keep reference in trialExecutorManagerMap, as there may be following requests from nni manager.
executorManager.releaseExecutor(trial.id);
} }
/** /**
...@@ -174,10 +165,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -174,10 +165,7 @@ class RemoteMachineTrainingService implements TrainingService {
if (trialJob.rmMeta === undefined) { if (trialJob.rmMeta === undefined) {
throw new Error(`rmMeta not set for submitted job ${trialJobId}`); throw new Error(`rmMeta not set for submitted job ${trialJobId}`);
} }
const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJob.id); const executor = await this.getExecutor(trialJob.id);
if (executor === undefined) {
throw new Error(`Invalid job id: ${trialJobId}, cannot find executor`);
}
return this.updateTrialJobStatus(trialJob, executor); return this.updateTrialJobStatus(trialJob, executor);
} else { } else {
...@@ -212,13 +200,12 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -212,13 +200,12 @@ class RemoteMachineTrainingService implements TrainingService {
// Generate trial job id(random) // Generate trial job id(random)
const trialJobId: string = uniqueString(5); const trialJobId: string = uniqueString(5);
const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
const trialJobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail( const trialJobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
trialJobId, trialJobId,
'WAITING', 'WAITING',
Date.now(), Date.now(),
trialWorkingFolder, "unset",
form form
); );
this.jobQueue.push(trialJobId); this.jobQueue.push(trialJobId);
...@@ -268,26 +255,23 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -268,26 +255,23 @@ class RemoteMachineTrainingService implements TrainingService {
// Get executor where the job is running // Get executor where the job is running
if (trialJob.rmMeta !== undefined) { if (trialJob.rmMeta !== undefined) {
// If the trial job is already scheduled, check its status and kill the trial process in remote machine // If the trial job is already scheduled, check its status and kill the trial process in remote machine
const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJob.id); const executor = await this.getExecutor(trialJob.id);
if (executor === undefined) {
throw new Error(`Invalid job id ${trialJobId}, cannot find executor`);
}
if (trialJob.status === 'UNKNOWN') { if (trialJob.status === 'UNKNOWN') {
this.releaseTrialExecutor(trialJob);
trialJob.status = 'USER_CANCELED'; trialJob.status = 'USER_CANCELED';
this.releaseTrialResource(trialJob);
return return
} }
const jobpidPath: string = this.getJobPidPath(trialJob.id); const jobpidPath: string = this.getJobPidPath(executor, trialJob.id);
try { try {
// Mark the toEarlyStop tag here // Mark the toEarlyStop tag here
trialJob.isEarlyStopped = isEarlyStopped; trialJob.isEarlyStopped = isEarlyStopped;
await executor.killChildProcesses(jobpidPath); await executor.killChildProcesses(jobpidPath);
this.releaseTrialExecutor(trialJob); this.releaseTrialResource(trialJob);
} catch (error) { } catch (error) {
// Not handle the error since pkill failed will not impact trial job's current status // Not handle the error since pkill failed will not impact trial job's current status
this.log.error(`remoteTrainingService.cancelTrialJob: ${error.message}`); this.log.error(`remoteTrainingService.cancelTrialJob: ${error}`);
} }
} else { } else {
// Job is not scheduled yet, set status to 'USER_CANCELLED' directly // Job is not scheduled yet, set status to 'USER_CANCELLED' directly
...@@ -329,15 +313,15 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -329,15 +313,15 @@ class RemoteMachineTrainingService implements TrainingService {
await validateCodeDir(remoteMachineTrailConfig.codeDir); await validateCodeDir(remoteMachineTrailConfig.codeDir);
// Copy codeDir to remote machine // Copy codeDir to remote machine
for (const [rmMeta, executorManager] of this.machineExecutorManagerMap.entries()) { for (const [rmMeta, executorManager] of this.machineExecutorManagerMap.entries()) {
const executor: ShellExecutor = await executorManager.getAvailableExecutor(); const executor: ShellExecutor = await executorManager.getExecutor(this.initExecutorId);
if (executor !== undefined) { if (executor !== undefined) {
this.machineCopyExpCodeDirPromiseMap.set( this.machineCopyExpCodeDirPromiseMap.set(
rmMeta, rmMeta,
executor.copyDirectoryToRemote(remoteMachineTrailConfig.codeDir, this.remoteExpCodeDir, this.remoteOS) executor.copyDirectoryToRemote(remoteMachineTrailConfig.codeDir, executor.getRemoteCodePath(getExperimentId()))
); );
} }
} }
} catch (error) { } catch (error) {
this.log.error(error); this.log.error(error);
...@@ -376,7 +360,15 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -376,7 +360,15 @@ class RemoteMachineTrainingService implements TrainingService {
public async cleanUp(): Promise<void> { public async cleanUp(): Promise<void> {
this.log.info('Stopping remote machine training service...'); this.log.info('Stopping remote machine training service...');
this.stopping = true; this.stopping = true;
await Promise.race([delay(10000), this.cleanupConnections()]); await this.cleanupConnections();
}
private async getExecutor(trialId: string): Promise<ShellExecutor> {
const executorManager = this.trialExecutorManagerMap.get(trialId);
if (executorManager === undefined) {
throw new Error(`ExecutorManager is not assigned for trial ${trialId}`);
}
return await executorManager.getExecutor(trialId);
} }
/** /**
...@@ -397,21 +389,19 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -397,21 +389,19 @@ class RemoteMachineTrainingService implements TrainingService {
*/ */
private async cleanupConnections(): Promise<void> { private async cleanupConnections(): Promise<void> {
try { try {
for (const [rmMeta, executorManager] of this.machineExecutorManagerMap.entries()) { for (const executorManager of this.machineExecutorManagerMap.values()) {
const jobpidPath: string = unixPathJoin(this.getRemoteScriptsPath(rmMeta.username), 'pid'); const executor = await executorManager.getExecutor(this.initExecutorId);
const executor: ShellExecutor | undefined = executorManager.getFirstExecutor();
if (executor !== undefined) { if (executor !== undefined) {
await executor.killChildProcesses(jobpidPath); this.log.info(`killing gpu metric collector on ${executor.name}`);
await executor.removeFolder(this.getRemoteScriptsPath(rmMeta.username)); const gpuJobPidPath: string = executor.joinPath(executor.getRemoteScriptsPath(getExperimentId()), 'pid');
await executor.killChildProcesses(gpuJobPidPath);
} }
executorManager.closeAllExecutor(); executorManager.releaseAllExecutor();
} }
} catch (error) { } catch (error) {
//ignore error, this function is called to cleanup remote connections when experiment is stopping //ignore error, this function is called to cleanup remote connections when experiment is stopping
this.log.error(`Cleanup connection exception, error is ${error.message}`); this.log.error(`Cleanup connection exception, error is ${error}`);
} }
return Promise.resolve();
} }
private async setupConnections(machineList: string): Promise<void> { private async setupConnections(machineList: string): Promise<void> {
...@@ -423,10 +413,14 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -423,10 +413,14 @@ class RemoteMachineTrainingService implements TrainingService {
rmMetaList.forEach(async (rmMeta: RemoteMachineMeta) => { rmMetaList.forEach(async (rmMeta: RemoteMachineMeta) => {
rmMeta.occupiedGpuIndexMap = new Map<number, number>(); rmMeta.occupiedGpuIndexMap = new Map<number, number>();
const executorManager: ExecutorManager = new ExecutorManager([], this.MAX_TRIAL_NUMBER_PER_EXECUTOR, rmMeta); const executorManager: ExecutorManager = new ExecutorManager(rmMeta);
const executor: ShellExecutor = await executorManager.getAvailableExecutor(); this.log.info(`connecting to ${rmMeta.username}@${rmMeta.ip}:${rmMeta.port}`);
const executor: ShellExecutor = await executorManager.getExecutor(this.initExecutorId);
this.log.debug(`reached ${executor.name}`);
this.machineExecutorManagerMap.set(rmMeta, executorManager); this.machineExecutorManagerMap.set(rmMeta, executorManager);
this.log.debug(`initializing ${executor.name}`);
await this.initRemoteMachineOnConnected(rmMeta, executor); await this.initRemoteMachineOnConnected(rmMeta, executor);
this.log.info(`connected to ${executor.name}`);
if (++connectedRMNum === rmMetaList.length) { if (++connectedRMNum === rmMetaList.length) {
deferred.resolve(); deferred.resolve();
} }
...@@ -437,27 +431,36 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -437,27 +431,36 @@ class RemoteMachineTrainingService implements TrainingService {
private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise<void> { private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise<void> {
// Create root working directory after executor is ready // Create root working directory after executor is ready
const nniRootDir: string = unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni'); const nniRootDir: string = executor.joinPath(executor.getTempPath(), 'nni');
await executor.createFolder(this.remoteExpRootDir); await executor.createFolder(executor.getRemoteExperimentRootDir(getExperimentId()));
// the directory to store temp scripts in remote machine // the directory to store temp scripts in remote machine
const remoteGpuScriptCollectorDir: string = this.getRemoteScriptsPath(rmMeta.username); const remoteGpuScriptCollectorDir: string = executor.getRemoteScriptsPath(getExperimentId());
// clean up previous result.
await executor.createFolder(remoteGpuScriptCollectorDir, true); await executor.createFolder(remoteGpuScriptCollectorDir, true);
await executor.allowPermission(false, nniRootDir, `${nniRootDir}/*`, `${nniRootDir}/scripts/*`); await executor.allowPermission(false, nniRootDir, `${nniRootDir}/*`, `${nniRootDir}/scripts/*`);
//Begin to execute gpu_metrics_collection scripts //Begin to execute gpu_metrics_collection scripts
const script = getGpuMetricsCollectorBashScriptContent(remoteGpuScriptCollectorDir); const script = executor.generateGpuStatsScript(getExperimentId());
executor.executeScript(script, false, true); executor.executeScript(script, false, true);
// the timer is trigger in 1 second, it causes multiple runs on server.
// So reduce it's freqeunce, only allow one of it run.
const collectingCount: boolean[] = [];
const disposable: Rx.IDisposable = this.timer.subscribe( const disposable: Rx.IDisposable = this.timer.subscribe(
async () => { async () => {
const cmdresult = await executor.readLastLines(unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics')); if (collectingCount.length == 0) {
if (cmdresult !== "") { collectingCount.push(true);
rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult); const cmdresult = await executor.readLastLines(executor.joinPath(remoteGpuScriptCollectorDir, 'gpu_metrics'));
if (rmMeta.gpuSummary.gpuCount === 0) { if (cmdresult !== "") {
this.log.warning(`No GPU found on remote machine ${rmMeta.ip}`); rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult);
this.timer.unsubscribe(disposable); if (rmMeta.gpuSummary.gpuCount === 0) {
this.log.warning(`No GPU found on remote machine ${rmMeta.ip}`);
this.timer.unsubscribe(disposable);
}
} }
collectingCount.pop();
} }
} }
); );
...@@ -492,7 +495,6 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -492,7 +495,6 @@ class RemoteMachineTrainingService implements TrainingService {
} else if (rmScheduleResult.resultType === ScheduleResultType.SUCCEED } else if (rmScheduleResult.resultType === ScheduleResultType.SUCCEED
&& rmScheduleResult.scheduleInfo !== undefined) { && rmScheduleResult.scheduleInfo !== undefined) {
const rmScheduleInfo: RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo; const rmScheduleInfo: RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo;
const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
trialJobDetail.rmMeta = rmScheduleInfo.rmMeta; trialJobDetail.rmMeta = rmScheduleInfo.rmMeta;
const copyExpCodeDirPromise = this.machineCopyExpCodeDirPromiseMap.get(trialJobDetail.rmMeta); const copyExpCodeDirPromise = this.machineCopyExpCodeDirPromiseMap.get(trialJobDetail.rmMeta);
...@@ -500,12 +502,16 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -500,12 +502,16 @@ class RemoteMachineTrainingService implements TrainingService {
await copyExpCodeDirPromise; await copyExpCodeDirPromise;
} }
await this.allocateExecutorForTrial(trialJobDetail); this.allocateExecutorManagerForTrial(trialJobDetail);
const executor = await this.getExecutor(trialJobDetail.id);
trialJobDetail.workingDirectory = executor.joinPath(executor.getRemoteExperimentRootDir(getExperimentId()), 'trials', trialJobDetail.id);
await this.launchTrialOnScheduledMachine( await this.launchTrialOnScheduledMachine(
trialJobId, trialWorkingFolder, trialJobDetail.form, rmScheduleInfo); trialJobId, trialJobDetail.form, rmScheduleInfo);
trialJobDetail.status = 'RUNNING'; trialJobDetail.status = 'RUNNING';
trialJobDetail.url = `file://${rmScheduleInfo.rmMeta.ip}:${trialWorkingFolder}`; trialJobDetail.url = `file://${rmScheduleInfo.rmMeta.ip}:${trialJobDetail.workingDirectory}`;
trialJobDetail.startTime = Date.now(); trialJobDetail.startTime = Date.now();
this.trialJobsMap.set(trialJobId, trialJobDetail); this.trialJobsMap.set(trialJobId, trialJobDetail);
...@@ -520,19 +526,13 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -520,19 +526,13 @@ class RemoteMachineTrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
private async launchTrialOnScheduledMachine(trialJobId: string, trialWorkingFolder: string, form: TrialJobApplicationForm, private async launchTrialOnScheduledMachine(trialJobId: string, form: TrialJobApplicationForm,
rmScheduleInfo: RemoteMachineScheduleInfo): Promise<void> { rmScheduleInfo: RemoteMachineScheduleInfo): Promise<void> {
if (this.trialConfig === undefined) { if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized'); throw new Error('trial config is not initialized');
} }
const cudaVisibleDevice: string = rmScheduleInfo.cudaVisibleDevice; const cudaVisibleDevice: string = rmScheduleInfo.cudaVisibleDevice;
const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJobId); const executor = await this.getExecutor(trialJobId);
if (executor === undefined) {
assert(false, 'ShellExecutor is undefined.');
// for lint
return;
}
const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) { if (trialJobDetail === undefined) {
throw new Error(`Can not get trial job detail for job: ${trialJobId}`); throw new Error(`Can not get trial job detail for job: ${trialJobId}`);
...@@ -540,23 +540,22 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -540,23 +540,22 @@ class RemoteMachineTrainingService implements TrainingService {
const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId);
await executor.createFolder(trialWorkingFolder); await executor.createFolder(executor.joinPath(trialJobDetail.workingDirectory, '.nni'));
await executor.createFolder(unixPathJoin(trialWorkingFolder, '.nni'));
// RemoteMachineRunShellFormat is the run shell format string, // RemoteMachineRunShellFormat is the run shell format string,
// See definition in remoteMachineData.ts // See definition in remoteMachineData.ts
let command: string; let cudaVisible: string;
// Set CUDA_VISIBLE_DEVICES environment variable based on cudaVisibleDevice // Set CUDA_VISIBLE_DEVICES environment variable based on cudaVisibleDevice
// If no valid cudaVisibleDevice is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device // If no valid cudaVisibleDevice is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
// If gpuNum is undefined, will not set CUDA_VISIBLE_DEVICES in script // If gpuNum is undefined, will not set CUDA_VISIBLE_DEVICES in script
if (this.trialConfig.gpuNum === undefined) { if (this.trialConfig.gpuNum === undefined) {
command = this.trialConfig.command; cudaVisible = ""
} else { } else {
if (typeof cudaVisibleDevice === 'string' && cudaVisibleDevice.length > 0) { if (typeof cudaVisibleDevice === 'string' && cudaVisibleDevice.length > 0) {
command = `CUDA_VISIBLE_DEVICES=${cudaVisibleDevice} ${this.trialConfig.command}`; cudaVisible = `CUDA_VISIBLE_DEVICES=${cudaVisibleDevice}`;
} else { } else {
command = `CUDA_VISIBLE_DEVICES=" " ${this.trialConfig.command}`; cudaVisible = `CUDA_VISIBLE_DEVICES=" "`;
} }
} }
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address(); const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
...@@ -565,50 +564,36 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -565,50 +564,36 @@ class RemoteMachineTrainingService implements TrainingService {
this.remoteRestServerPort = restServer.clusterRestServerPort; this.remoteRestServerPort = restServer.clusterRestServerPort;
} }
const version: string = this.versionCheck ? await getVersion() : ''; const version: string = this.versionCheck ? await getVersion() : '';
const runScriptTrialContent: string = String.Format( const runScriptTrialContent: string = executor.generateStartScript(
REMOTEMACHINE_TRIAL_COMMAND_FORMAT, trialJobDetail.workingDirectory,
trialWorkingFolder,
trialWorkingFolder,
trialJobId, trialJobId,
getExperimentId(), getExperimentId(),
trialJobDetail.form.sequenceId.toString(), trialJobDetail.form.sequenceId.toString(),
this.isMultiPhase, this.isMultiPhase,
this.remoteExpCodeDir, this.trialConfig.command,
unixPathJoin(trialWorkingFolder, '.nni', 'jobpid'),
command,
nniManagerIp, nniManagerIp,
this.remoteRestServerPort, this.remoteRestServerPort,
version, version,
this.logCollection, this.logCollection, cudaVisible);
unixPathJoin(trialWorkingFolder, '.nni', 'code')
);
//create tmp trial working folder locally. //create tmp trial working folder locally.
await execMkdir(path.join(trialLocalTempFolder, '.nni')); await execMkdir(path.join(trialLocalTempFolder, '.nni'));
// Write install_nni.sh
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), CONTAINER_INSTALL_NNI_SHELL_FORMAT, { encoding: 'utf8' }); // Write install_nni.sh, it's not used in Windows platform.
await fs.promises.writeFile(path.join(trialLocalTempFolder, executor.getScriptName("install_nni")), CONTAINER_INSTALL_NNI_SHELL_FORMAT, { encoding: 'utf8' });
// Write file content ( run.sh and parameter.cfg ) to local tmp files // Write file content ( run.sh and parameter.cfg ) to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptTrialContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, executor.getScriptName("run")), runScriptTrialContent, { encoding: 'utf8' });
await this.writeParameterFile(trialJobId, form.hyperParameters); await this.writeParameterFile(trialJobId, form.hyperParameters);
// Copy files in codeDir to remote working directory // Copy files in codeDir to remote working directory
await executor.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, this.remoteOS); await executor.copyDirectoryToRemote(trialLocalTempFolder, trialJobDetail.workingDirectory);
// Execute command in remote machine // Execute command in remote machine
executor.executeScript(unixPathJoin(trialWorkingFolder, 'run.sh'), true, true); executor.executeScript(executor.joinPath(trialJobDetail.workingDirectory, executor.getScriptName("run")), true, true);
}
private getRmMetaByHost(host: string): RemoteMachineMeta {
for (const rmMeta of this.machineExecutorManagerMap.keys()) {
if (rmMeta.ip === host) {
return rmMeta;
}
}
throw new Error(`Host not found: ${host}`);
} }
private async updateTrialJobStatus(trialJob: RemoteMachineTrialJobDetail, executor: ShellExecutor): Promise<TrialJobDetail> { private async updateTrialJobStatus(trialJob: RemoteMachineTrialJobDetail, executor: ShellExecutor): Promise<TrialJobDetail> {
const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>(); const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>();
const jobpidPath: string = this.getJobPidPath(trialJob.id); const jobpidPath: string = this.getJobPidPath(executor, trialJob.id);
const trialReturnCodeFilePath: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJob.id, '.nni', 'code'); const trialReturnCodeFilePath: string = executor.joinPath(executor.getRemoteExperimentRootDir(getExperimentId()), 'trials', trialJob.id, '.nni', 'code');
/* eslint-disable require-atomic-updates */ /* eslint-disable require-atomic-updates */
try { try {
const isAlive = await executor.isProcessAlive(jobpidPath); const isAlive = await executor.isProcessAlive(jobpidPath);
...@@ -617,7 +602,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -617,7 +602,7 @@ class RemoteMachineTrainingService implements TrainingService {
const trialReturnCode: string = await executor.getRemoteFileContent(trialReturnCodeFilePath); const trialReturnCode: string = await executor.getRemoteFileContent(trialReturnCodeFilePath);
this.log.debug(`trailjob ${trialJob.id} return code: ${trialReturnCode}`); this.log.debug(`trailjob ${trialJob.id} return code: ${trialReturnCode}`);
const match: RegExpMatchArray | null = trialReturnCode.trim() const match: RegExpMatchArray | null = trialReturnCode.trim()
.match(/^(\d+)\s+(\d+)$/); .match(/^-?(\d+)\s+(\d+)$/);
if (match !== null) { if (match !== null) {
const { 1: code, 2: timestamp } = match; const { 1: code, 2: timestamp } = match;
// Update trial job's status based on result code // Update trial job's status based on result code
...@@ -632,13 +617,13 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -632,13 +617,13 @@ class RemoteMachineTrainingService implements TrainingService {
} }
} }
trialJob.endTime = parseInt(timestamp, 10); trialJob.endTime = parseInt(timestamp, 10);
this.releaseTrialExecutor(trialJob); this.releaseTrialResource(trialJob);
} }
this.log.debug(`trailJob status update: ${trialJob.id}, ${trialJob.status}`); this.log.debug(`trailJob status update: ${trialJob.id}, ${trialJob.status}`);
} }
deferred.resolve(trialJob); deferred.resolve(trialJob);
} catch (error) { } catch (error) {
this.log.error(`Update job status exception, error is ${error.message}`); this.log.debug(`(Ignorable mostly)Update job status exception, error is ${error.message}`);
if (error instanceof NNIError && error.name === NNIErrorNames.NOT_FOUND) { if (error instanceof NNIError && error.name === NNIErrorNames.NOT_FOUND) {
deferred.resolve(trialJob); deferred.resolve(trialJob);
} else { } else {
...@@ -650,45 +635,30 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -650,45 +635,30 @@ class RemoteMachineTrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
private getRemoteScriptsPath(userName: string): string {
return unixPathJoin(getRemoteTmpDir(this.remoteOS), userName, 'nni', 'scripts');
}
private getHostJobRemoteDir(jobId: string): string {
return unixPathJoin(this.remoteExpRootDir, 'hostjobs', jobId);
}
private getRemoteExperimentRootDir(): string {
return unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni', 'experiments', getExperimentId());
}
public get MetricsEmitter(): EventEmitter { public get MetricsEmitter(): EventEmitter {
return this.metricsEmitter; return this.metricsEmitter;
} }
private getJobPidPath(jobId: string): string { private getJobPidPath(executor: ShellExecutor, jobId: string): string {
const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(jobId); const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(jobId);
if (trialJobDetail === undefined) { if (trialJobDetail === undefined) {
throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${jobId}`); throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${jobId}`);
} }
return unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid'); return executor.joinPath(trialJobDetail.workingDirectory, '.nni', 'jobpid');
} }
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> { private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJobId); const executor = await this.getExecutor(trialJobId);
if (executor === undefined) {
throw new Error('ShellExecutor is undefined.');
}
const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId); const trialWorkingFolder: string = executor.joinPath(executor.getRemoteExperimentRootDir(getExperimentId()), 'trials', trialJobId);
const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId);
const fileName: string = generateParamFileName(hyperParameters); const fileName: string = generateParamFileName(hyperParameters);
const localFilepath: string = path.join(trialLocalTempFolder, fileName); const localFilepath: string = path.join(trialLocalTempFolder, fileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' }); await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
await executor.copyFileToRemote(localFilepath, unixPathJoin(trialWorkingFolder, fileName)); await executor.copyFileToRemote(localFilepath, executor.joinPath(trialWorkingFolder, fileName));
} }
} }
......
...@@ -4,27 +4,39 @@ ...@@ -4,27 +4,39 @@
'use strict'; 'use strict';
import * as assert from 'assert'; import * as assert from 'assert';
import * as fs from 'fs';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path'; import * as path from 'path';
import * as fs from 'fs'; import { Client, ClientChannel, ConnectConfig, SFTPWrapper } from 'ssh2';
import { Client, ClientChannel, SFTPWrapper, ConnectConfig } from 'ssh2';
import { Deferred } from "ts-deferred";
import { RemoteCommandResult, RemoteMachineMeta } from "./remoteMachineData";
import * as stream from 'stream'; import * as stream from 'stream';
import { OsCommands } from "./osCommands"; import { Deferred } from "ts-deferred";
import { LinuxCommands } from "./extends/linuxCommands";
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { uniqueString, randomInt } from '../../common/utils';
import { execRemove, tarAdd } from '../common/util'; import { execRemove, tarAdd } from '../common/util';
import { getRemoteTmpDir, uniqueString, unixPathJoin } from '../../common/utils'; import { LinuxCommands } from "./extends/linuxCommands";
import { WindowsCommands } from './extends/windowsCommands';
import { OsCommands } from "./osCommands";
import { RemoteCommandResult, RemoteMachineMeta } from "./remoteMachineData";
import { NNIError, NNIErrorNames } from '../../common/errors';
class ShellExecutor { class ShellExecutor {
private sshClient: Client = new Client(); public name: string = "";
private osCommands: OsCommands | undefined;
private usedConnectionNumber: number = 0; //count the connection number of every client
protected pathSpliter: string = '/'; private readonly lineBreaker = new RegExp(`[\r\n]+`);
protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`); private readonly maxUsageCount = 5;
private osCommands: OsCommands | undefined;
private usedCount: number = 0; //count the connection number of every client
private readonly sshClient: Client;
private readonly log: Logger;
private tempPath: string = "";
private isWindows: boolean = false;
private channelDefaultOutputs: string[] = [];
constructor() {
this.log = getLogger();
this.sshClient = new Client();
}
public async initialize(rmMeta: RemoteMachineMeta): Promise<void> { public async initialize(rmMeta: RemoteMachineMeta): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
...@@ -33,8 +45,9 @@ class ShellExecutor { ...@@ -33,8 +45,9 @@ class ShellExecutor {
host: rmMeta.ip, host: rmMeta.ip,
port: rmMeta.port, port: rmMeta.port,
username: rmMeta.username, username: rmMeta.username,
tryKeyboard: true tryKeyboard: true,
}; };
this.name = `${rmMeta.username}@${rmMeta.ip}:${rmMeta.port}`;
if (rmMeta.passwd !== undefined) { if (rmMeta.passwd !== undefined) {
connectConfig.password = rmMeta.passwd; connectConfig.password = rmMeta.passwd;
} else if (rmMeta.sshKeyPath !== undefined) { } else if (rmMeta.sshKeyPath !== undefined) {
...@@ -49,20 +62,42 @@ class ShellExecutor { ...@@ -49,20 +62,42 @@ class ShellExecutor {
} else { } else {
deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`)); deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));
} }
this.sshClient.on('ready', async () => { this.sshClient.on('ready', async () => {
// check OS type: windows or else // check OS type: windows or else
const result = await this.execute("ver"); const result = await this.execute("ver");
if (result.exitCode == 0 && result.stdout.search("Windows") > -1) { if (result.exitCode == 0 && result.stdout.search("Windows") > -1) {
// not implement Windows commands yet. this.osCommands = new WindowsCommands();
throw new Error("not implement Windows commands yet."); this.isWindows = true;
// detect default output and trying to remove it under windows.
// Anaconda has this kind of output.
let defaultResult = await this.execute("");
if (defaultResult.stdout !== "") {
deferred.reject(new Error(`The windows remote node shouldn't output welcome message, below content should be removed from the command window! \n` +
`${defaultResult.stdout}`));
}
defaultResult = await this.execute("powershell -command \"\"");
if (defaultResult.stdout !== "") {
this.channelDefaultOutputs.push(defaultResult.stdout);
}
this.log.debug(`set channelDefaultOutput to "${this.channelDefaultOutputs}"`);
// parse temp folder to expand possible environment variables.
const commandResult = await this.execute("echo %TEMP%");
this.tempPath = commandResult.stdout.replace(this.lineBreaker, "");
} else { } else {
this.osCommands = new LinuxCommands(); this.osCommands = new LinuxCommands();
// it's not stable to get tmp path by Linux command, like "echo /tmp" or "ld -d /tmp".
// Sometime it returns empty back, so hard code tmp path here.
this.tempPath = "/tmp";
} }
deferred.resolve(); deferred.resolve();
}).on('error', (err: Error) => { }).on('error', (err: Error) => {
// SSH connection error, reject with error message // SSH connection error, reject with error message
deferred.reject(new Error(err.message)); deferred.reject(new Error(err.message));
}).on("keyboard-interactive", (name, instructions, lang, prompts, finish) => { }).on("keyboard-interactive", (_name, _instructions, _lang, _prompts, finish) => {
finish([rmMeta.passwd]); finish([rmMeta.passwd]);
}).connect(connectConfig); }).connect(connectConfig);
...@@ -73,43 +108,108 @@ class ShellExecutor { ...@@ -73,43 +108,108 @@ class ShellExecutor {
this.sshClient.end(); this.sshClient.end();
} }
public get getUsedConnectionNumber(): number { public addUsage(): boolean {
return this.usedConnectionNumber; let isAddedSuccess = false;
if (this.usedCount < this.maxUsageCount) {
this.usedCount++;
isAddedSuccess = true;
}
return isAddedSuccess;
}
public releaseUsage(): boolean {
let canBeReleased = false;
if (this.usedCount > 0) {
this.usedCount--;
}
if (this.usedCount == 0) {
canBeReleased = true;
}
return canBeReleased;
}
public getScriptName(mainName: string): string {
if (this.osCommands === undefined) {
throw new Error("osCommands must be initialized!");
}
return `${mainName}.${this.osCommands.getScriptExt()}`;
}
public generateStartScript(workingDirectory: string, trialJobId: string, experimentId: string,
trialSequenceId: string, isMultiPhase: boolean,
command: string, nniManagerAddress: string, nniManagerPort: number,
nniManagerVersion: string, logCollection: string, cudaVisibleSetting: string): string {
if (this.osCommands === undefined) {
throw new Error("osCommands must be initialized!");
}
const jobIdFileName = this.joinPath(workingDirectory, '.nni', 'jobpid');
const exitCodeFile = this.joinPath(workingDirectory, '.nni', 'code');
const codeDir = this.getRemoteCodePath(experimentId);
return this.osCommands.generateStartScript(workingDirectory, trialJobId, experimentId,
trialSequenceId, isMultiPhase, jobIdFileName, command,
nniManagerAddress, nniManagerPort, nniManagerVersion,
logCollection, exitCodeFile, codeDir, cudaVisibleSetting);
}
public generateGpuStatsScript(experimentId: string): string {
if (this.osCommands === undefined) {
throw new Error("osCommands must be initialized!");
}
return this.osCommands.generateGpuStatsScript(this.getRemoteScriptsPath(experimentId));
}
public getTempPath(): string {
if (this.tempPath === "") {
throw new Error("tempPath must be initialized!");
}
return this.tempPath;
}
public getRemoteScriptsPath(experimentId: string): string {
return this.joinPath(this.getRemoteExperimentRootDir(experimentId), 'scripts');
}
public getRemoteCodePath(experimentId: string): string {
return this.joinPath(this.getRemoteExperimentRootDir(experimentId), 'nni-code');
} }
public addUsedConnectionNumber(): void { public getRemoteExperimentRootDir(experimentId: string): string {
this.usedConnectionNumber += 1; return this.joinPath(this.tempPath, 'nni', 'experiments', experimentId);
} }
public minusUsedConnectionNumber(): void { public joinPath(...paths: string[]): string {
this.usedConnectionNumber -= 1; if (!this.osCommands) {
throw new Error("osCommands must be initialized!");
}
return this.osCommands.joinPath(...paths);
} }
public async createFolder(folderName: string, sharedFolder: boolean = false): Promise<boolean> { public async createFolder(folderName: string, sharedFolder: boolean = false): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.createFolder(folderName, sharedFolder); const commandText = this.osCommands && this.osCommands.createFolder(folderName, sharedFolder);
const commandResult = await this.execute(commandText); const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0; const result = commandResult.exitCode == 0;
return result; return result;
} }
public async allowPermission(isRecursive: boolean = false, ...folders: string[]): Promise<boolean> { public async allowPermission(isRecursive: boolean = false, ...folders: string[]): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.allowPermission(isRecursive, ...folders); const commandText = this.osCommands && this.osCommands.allowPermission(isRecursive, ...folders);
const commandResult = await this.execute(commandText); const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0; const result = commandResult.exitCode == 0;
return result; return result;
} }
public async removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): Promise<boolean> { public async removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.removeFolder(folderName, isRecursive, isForce); const commandText = this.osCommands && this.osCommands.removeFolder(folderName, isRecursive, isForce);
const commandResult = await this.execute(commandText); const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0; const result = commandResult.exitCode == 0;
return result; return result;
} }
public async removeFiles(folderOrFileName: string, filePattern: string = ""): Promise<boolean> { public async removeFiles(folderOrFileName: string, filePattern: string = ""): Promise<boolean> {
const commandText = this.osCommands && this.osCommands.removeFiles(folderOrFileName, filePattern); const commandText = this.osCommands && this.osCommands.removeFiles(folderOrFileName, filePattern);
const commandResult = await this.execute(commandText); const commandResult = await this.execute(commandText);
const result = commandResult.exitCode >= 0; const result = commandResult.exitCode == 0;
return result; return result;
} }
...@@ -142,10 +242,10 @@ class ShellExecutor { ...@@ -142,10 +242,10 @@ class ShellExecutor {
return commandResult.exitCode == 0; return commandResult.exitCode == 0;
} }
public async executeScript(script: string, isFile: boolean, isInteractive: boolean = false): Promise<boolean> { public async executeScript(script: string, isFile: boolean = false, isInteractive: boolean = false): Promise<RemoteCommandResult> {
const commandText = this.osCommands && this.osCommands.executeScript(script, isFile); const commandText = this.osCommands && this.osCommands.executeScript(script, isFile);
const commandResult = await this.execute(commandText, undefined, isInteractive); const commandResult = await this.execute(commandText, undefined, isInteractive);
return commandResult.exitCode == 0; return commandResult;
} }
/** /**
...@@ -154,13 +254,13 @@ class ShellExecutor { ...@@ -154,13 +254,13 @@ class ShellExecutor {
* @param remoteFilePath the target path in remote machine * @param remoteFilePath the target path in remote machine
*/ */
public async copyFileToRemote(localFilePath: string, remoteFilePath: string): Promise<boolean> { public async copyFileToRemote(localFilePath: string, remoteFilePath: string): Promise<boolean> {
const log: Logger = getLogger(); const commandIndex = randomInt(10000);
log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`); this.log.debug(`copyFileToRemote(${commandIndex}): localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`);
const deferred: Deferred<boolean> = new Deferred<boolean>(); const deferred: Deferred<boolean> = new Deferred<boolean>();
this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => { this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`); this.log.error(`copyFileToRemote(${commandIndex}): ${err}`);
deferred.reject(err); deferred.reject(err);
return; return;
...@@ -169,6 +269,7 @@ class ShellExecutor { ...@@ -169,6 +269,7 @@ class ShellExecutor {
sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr: Error) => { sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr: Error) => {
sftp.end(); sftp.end();
if (fastPutErr !== undefined && fastPutErr !== null) { if (fastPutErr !== undefined && fastPutErr !== null) {
this.log.error(`copyFileToRemote(${commandIndex}) fastPutErr: ${fastPutErr}, ${localFilePath}, ${remoteFilePath}`);
deferred.reject(fastPutErr); deferred.reject(fastPutErr);
} else { } else {
deferred.resolve(true); deferred.resolve(true);
...@@ -183,12 +284,15 @@ class ShellExecutor { ...@@ -183,12 +284,15 @@ class ShellExecutor {
* Copy files and directories in local directory recursively to remote directory * Copy files and directories in local directory recursively to remote directory
* @param localDirectory local diretory * @param localDirectory local diretory
* @param remoteDirectory remote directory * @param remoteDirectory remote directory
* @param remoteOS the OS of remote machine
*/ */
public async copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, remoteOS: string): Promise<void> { public async copyDirectoryToRemote(localDirectory: string, remoteDirectory: string): Promise<void> {
const tmpSuffix: string = uniqueString(5); const tmpSuffix: string = uniqueString(5);
const localTarPath: string = path.join(os.tmpdir(), `nni_tmp_local_${tmpSuffix}.tar.gz`); const localTarPath: string = path.join(os.tmpdir(), `nni_tmp_local_${tmpSuffix}.tar.gz`);
const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), `nni_tmp_remote_${tmpSuffix}.tar.gz`); if (!this.osCommands) {
throw new Error("osCommands must be initialized!");
}
const remoteTarPath: string = this.osCommands.joinPath(this.tempPath, `nni_tmp_remote_${tmpSuffix}.tar.gz`);
// Create remote directory // Create remote directory
await this.createFolder(remoteDirectory); await this.createFolder(remoteDirectory);
// Compress files in local directory to experiment root directory // Compress files in local directory to experiment root directory
...@@ -202,12 +306,13 @@ class ShellExecutor { ...@@ -202,12 +306,13 @@ class ShellExecutor {
} }
public async getRemoteFileContent(filePath: string): Promise<string> { public async getRemoteFileContent(filePath: string): Promise<string> {
const commandIndex = randomInt(10000);
this.log.debug(`getRemoteFileContent(${commandIndex}): filePath: ${filePath}`);
const deferred: Deferred<string> = new Deferred<string>(); const deferred: Deferred<string> = new Deferred<string>();
this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => { this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
getLogger() this.log.error(`getRemoteFileContent(${commandIndex}) sftp: ${err}`);
.error(`getRemoteFileContent: ${err.message}`); deferred.reject(new Error(`SFTP error: ${err}`));
deferred.reject(new Error(`SFTP error: ${err.message}`));
return; return;
} }
...@@ -228,8 +333,7 @@ class ShellExecutor { ...@@ -228,8 +333,7 @@ class ShellExecutor {
deferred.resolve(dataBuffer); deferred.resolve(dataBuffer);
}); });
} catch (error) { } catch (error) {
getLogger() this.log.error(`getRemoteFileContent(${commandIndex}): ${error.message}`);
.error(`getRemoteFileContent: ${error.message}`);
sftp.end(); sftp.end();
deferred.reject(new Error(`SFTP error: ${error.message}`)); deferred.reject(new Error(`SFTP error: ${error.message}`));
} }
...@@ -239,16 +343,20 @@ class ShellExecutor { ...@@ -239,16 +343,20 @@ class ShellExecutor {
} }
private async execute(command: string | undefined, processOutput: ((input: RemoteCommandResult) => RemoteCommandResult) | undefined = undefined, useShell: boolean = false): Promise<RemoteCommandResult> { private async execute(command: string | undefined, processOutput: ((input: RemoteCommandResult) => RemoteCommandResult) | undefined = undefined, useShell: boolean = false): Promise<RemoteCommandResult> {
const log: Logger = getLogger();
log.debug(`remoteExeCommand: command: [${command}]`);
const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>(); const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
let stdout: string = ''; let stdout: string = '';
let stderr: string = ''; let stderr: string = '';
let exitCode: number; let exitCode: number;
const commandIndex = randomInt(10000);
this.log.debug(`remoteExeCommand(${commandIndex}): [${command}]`);
// Windows always uses shell, and it needs to disable to get it works.
useShell = useShell && !this.isWindows;
const callback = (err: Error, channel: ClientChannel): void => { const callback = (err: Error, channel: ClientChannel): void => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
log.error(`remoteExeCommand: ${err.message}`); this.log.error(`remoteExeCommand(${commandIndex}): ${err.message}`);
deferred.reject(err); deferred.reject(err);
return; return;
} }
...@@ -258,7 +366,23 @@ class ShellExecutor { ...@@ -258,7 +366,23 @@ class ShellExecutor {
}); });
channel.on('exit', (code: any) => { channel.on('exit', (code: any) => {
exitCode = <number>code; exitCode = <number>code;
log.debug(`remoteExeCommand exit(${exitCode})\nstdout: ${stdout}\nstderr: ${stderr}`);
// remove default output to get stdout correct.
if (this.channelDefaultOutputs.length > 0) {
let modifiedStdout = stdout;
this.channelDefaultOutputs.forEach(defaultOutput => {
if (modifiedStdout.startsWith(defaultOutput)) {
if (modifiedStdout.length > defaultOutput.length) {
modifiedStdout = modifiedStdout.substr(defaultOutput.length);
} else if (modifiedStdout.length === defaultOutput.length) {
modifiedStdout = "";
}
}
});
stdout = modifiedStdout;
}
this.log.debug(`remoteExeCommand(${commandIndex}) exit(${exitCode})\nstdout: ${stdout}\nstderr: ${stderr}`);
let result = { let result = {
stdout: stdout, stdout: stdout,
stderr: stderr, stderr: stderr,
...@@ -270,7 +394,7 @@ class ShellExecutor { ...@@ -270,7 +394,7 @@ class ShellExecutor {
} }
deferred.resolve(result); deferred.resolve(result);
}); });
channel.stderr.on('data', function (data) { channel.stderr.on('data', function (data: any) {
stderr += data; stderr += data;
}); });
......
...@@ -8,7 +8,6 @@ import * as chaiAsPromised from 'chai-as-promised'; ...@@ -8,7 +8,6 @@ import * as chaiAsPromised from 'chai-as-promised';
import * as component from '../../../common/component'; import * as component from '../../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils'; import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils';
import { LinuxCommands } from '../extends/linuxCommands'; import { LinuxCommands } from '../extends/linuxCommands';
// import { TrialConfigMetadataKey } from '../trialConfigMetadataKey';
describe('Unit Test for linuxCommands', () => { describe('Unit Test for linuxCommands', () => {
...@@ -88,10 +87,6 @@ describe('Unit Test for linuxCommands', () => { ...@@ -88,10 +87,6 @@ describe('Unit Test for linuxCommands', () => {
)).to.equal(false); )).to.equal(false);
}) })
it('killChildProcesses', async () => {
chai.expect(linuxCommands.killChildProcesses("test")).to.equal("pkill -P `cat 'test'`");
})
it('extractFile', async () => { it('extractFile', async () => {
chai.expect(linuxCommands.extractFile("test.tar", "testfolder")).to.equal("tar -oxzf 'test.tar' -C 'testfolder'"); chai.expect(linuxCommands.extractFile("test.tar", "testfolder")).to.equal("tar -oxzf 'test.tar' -C 'testfolder'");
}) })
......
...@@ -8,29 +8,29 @@ import * as fs from 'fs'; ...@@ -8,29 +8,29 @@ import * as fs from 'fs';
import * as chai from 'chai'; import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised'; import * as chaiAsPromised from 'chai-as-promised';
import { Client } from 'ssh2';
import { ShellExecutor } from '../shellExecutor'; import { ShellExecutor } from '../shellExecutor';
import { prepareUnitTest, cleanupUnitTest } from '../../../common/utils'; import { prepareUnitTest, cleanupUnitTest } from '../../../common/utils';
const LOCALFILE: string = '/tmp/localSshclientUTData'; const LOCALFILE: string = 'localSshUTData';
const REMOTEFILE: string = '/tmp/remoteSshclientUTData'; const REMOTEFILE: string = 'remoteSshUTData';
const REMOTEFOLDER: string = '/tmp/remoteSshclientUTFolder'; const REMOTEFOLDER: string = 'remoteSshUTFolder';
async function copyFile(executor: ShellExecutor): Promise<void> { async function copyFile(executor: ShellExecutor): Promise<void> {
await executor.copyFileToRemote(LOCALFILE, REMOTEFILE); const remoteFullName = executor.joinPath(executor.getTempPath(), REMOTEFILE);
await executor.copyFileToRemote(LOCALFILE, remoteFullName);
} }
async function copyFileToRemoteLoop(executor: ShellExecutor): Promise<void> { async function copyFileToRemoteLoop(executor: ShellExecutor): Promise<void> {
for (let i: number = 0; i < 10; i++) { const remoteFullName = executor.joinPath(executor.getTempPath(), REMOTEFILE);
// console.log(i); for (let i: number = 0; i < 3; i++) {
await executor.copyFileToRemote(LOCALFILE, REMOTEFILE); await executor.copyFileToRemote(LOCALFILE, remoteFullName);
} }
} }
async function getRemoteFileContentLoop(executor: ShellExecutor): Promise<void> { async function getRemoteFileContentLoop(executor: ShellExecutor): Promise<void> {
for (let i: number = 0; i < 10; i++) { const remoteFullName = executor.joinPath(executor.getTempPath(), REMOTEFILE);
// console.log(i); for (let i: number = 0; i < 3; i++) {
await executor.getRemoteFileContent(REMOTEFILE); await executor.getRemoteFileContent(remoteFullName);
} }
} }
...@@ -41,14 +41,16 @@ describe('ShellExecutor test', () => { ...@@ -41,14 +41,16 @@ describe('ShellExecutor test', () => {
rmMeta = JSON.parse(fs.readFileSync('../../.vscode/rminfo.json', 'utf8')); rmMeta = JSON.parse(fs.readFileSync('../../.vscode/rminfo.json', 'utf8'));
console.log(rmMeta); console.log(rmMeta);
} catch (err) { } catch (err) {
console.log(`Please configure rminfo.json to enable remote machine test.${err}`); console.log(`Please configure rminfo.json to enable remote machine test. ${err}`);
skip = true; skip = true;
} }
before(async () => { before(async () => {
chai.should(); chai.should();
chai.use(chaiAsPromised); chai.use(chaiAsPromised);
await cpp.exec(`echo '1234' > ${LOCALFILE}`); if (!fs.existsSync(LOCALFILE)){
await cpp.exec(`echo '1234' > ${LOCALFILE}`);
}
prepareUnitTest(); prepareUnitTest();
}); });
...@@ -61,26 +63,27 @@ describe('ShellExecutor test', () => { ...@@ -61,26 +63,27 @@ describe('ShellExecutor test', () => {
if (skip) { if (skip) {
return; return;
} }
const shellExecutor: ShellExecutor = new ShellExecutor(); const executor: ShellExecutor = new ShellExecutor();
await shellExecutor.initialize(rmMeta); await executor.initialize(rmMeta);
let result = await shellExecutor.createFolder(REMOTEFOLDER, false); const remoteFullPath = executor.joinPath(executor.getTempPath(), REMOTEFOLDER);
let result = await executor.createFolder(remoteFullPath, false);
chai.expect(result).eq(true); chai.expect(result).eq(true);
result = await shellExecutor.removeFolder(REMOTEFOLDER); const commandResult = await executor.executeScript("dir");
chai.expect(commandResult.exitCode).eq(0);
result = await executor.removeFolder(remoteFullPath);
chai.expect(result).eq(true); chai.expect(result).eq(true);
await executor.close();
}); });
it('Test ShellExecutor', async () => { it('Test ShellExecutor', async () => {
if (skip) { if (skip) {
return; return;
} }
const shellExecutor: ShellExecutor = new ShellExecutor(); const executor: ShellExecutor = new ShellExecutor();
await shellExecutor.initialize(rmMeta); await executor.initialize(rmMeta);
await copyFile(shellExecutor); await copyFile(executor);
await Promise.all([ await copyFileToRemoteLoop(executor);
copyFileToRemoteLoop(shellExecutor), await getRemoteFileContentLoop(executor);
copyFileToRemoteLoop(shellExecutor), await executor.close();
copyFileToRemoteLoop(shellExecutor),
getRemoteFileContentLoop(shellExecutor)
]);
}); });
}); });
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
import * as component from '../../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils';
import { WindowsCommands } from '../extends/windowsCommands';
describe('Unit Test for Windows Commands', () => {
let windowsCommands: WindowsCommands
before(() => {
chai.should();
chai.use(chaiAsPromised);
prepareUnitTest();
});
after(() => {
cleanupUnitTest();
});
beforeEach(() => {
windowsCommands = component.get(WindowsCommands);
});
afterEach(() => {
});
it('joinPath', async () => {
chai.expect(windowsCommands.joinPath("/root/", "\\first")).to.equal("\\root\\first");
chai.expect(windowsCommands.joinPath("root/", "first")).to.equal("root\\first");
chai.expect(windowsCommands.joinPath("\\root/", "\\first")).to.equal("\\root\\first");
chai.expect(windowsCommands.joinPath("\\root\\", "\\first")).to.equal("\\root\\first");
chai.expect(windowsCommands.joinPath("\\root", "first")).to.equal("\\root\\first");
chai.expect(windowsCommands.joinPath("\\root\\", "first")).to.equal("\\root\\first");
chai.expect(windowsCommands.joinPath("root\\", "first")).to.equal("root\\first");
chai.expect(windowsCommands.joinPath("root\\")).to.equal("root\\");
chai.expect(windowsCommands.joinPath("root")).to.equal("root");
chai.expect(windowsCommands.joinPath(".\\root")).to.equal(".\\root");
chai.expect(windowsCommands.joinPath("")).to.equal(".");
chai.expect(windowsCommands.joinPath("..")).to.equal("..");
})
it('createFolder', async () => {
chai.expect(windowsCommands.createFolder("test")).to.equal("mkdir \"test\"");
chai.expect(windowsCommands.createFolder("test", true)).to.equal("mkdir \"test\"\r\nICACLS \"test\" /grant \"Users\":F");
})
it('allowPermission', async () => {
chai.expect(windowsCommands.allowPermission(true, "test", "test1")).to.equal("ICACLS \"test\" /grant \"Users\":F /T\r\nICACLS \"test1\" /grant \"Users\":F /T\r\n");
chai.expect(windowsCommands.allowPermission(false, "test")).to.equal("ICACLS \"test\" /grant \"Users\":F\r\n");
})
it('removeFolder', async () => {
chai.expect(windowsCommands.removeFolder("test")).to.equal("rmdir /q \"test\"");
chai.expect(windowsCommands.removeFolder("test", true)).to.equal("rmdir /s /q \"test\"");
chai.expect(windowsCommands.removeFolder("test", true, false)).to.equal("rmdir /s \"test\"");
chai.expect(windowsCommands.removeFolder("test", false, false)).to.equal("rmdir \"test\"");
chai.expect(windowsCommands.removeFolder("test", true, true)).to.equal("rmdir /s /q \"test\"");
})
it('removeFiles', async () => {
chai.expect(windowsCommands.removeFiles("test", "*.sh")).to.equal("del \"test\\*.sh\"");
chai.expect(windowsCommands.removeFiles("test", "")).to.equal("del \"test\"");
})
it('readLastLines', async () => {
chai.expect(windowsCommands.readLastLines("test", 3)).to.equal("powershell.exe Get-Content \"test\" -Tail 3");
})
it('isProcessAlive', async () => {
chai.expect(windowsCommands.isProcessAliveCommand("test")).to.equal("powershell.exe Get-Process -Id (get-content \"test\") -ErrorAction SilentlyContinue");
chai.expect(windowsCommands.isProcessAliveProcessOutput(
{
exitCode: 0,
stdout: "",
stderr: ""
}
)).to.equal(true);
chai.expect(windowsCommands.isProcessAliveProcessOutput(
{
exitCode: 10,
stdout: "",
stderr: ""
}
)).to.equal(false);
})
it('extractFile', async () => {
chai.expect(windowsCommands.extractFile("test.tar", "testfolder")).to.equal("tar -xf \"test.tar\" -C \"testfolder\"");
})
it('executeScript', async () => {
chai.expect(windowsCommands.executeScript("test.sh", true)).to.equal("test.sh");
chai.expect(windowsCommands.executeScript("test script'\"", false)).to.equal("test script'\"");
})
});
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment