Unverified Commit e29b58a1 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #244 from microsoft/master

merge master
parents e0c2c0eb 4f88be1f
......@@ -173,12 +173,12 @@ install-python-modules:
dev-install-python-modules:
#$(_INFO) Installing Python SDK $(_END)
mkdir -p build
ln -sf ../src/sdk/pynni/nni build/nni
ln -sf ../src/sdk/pynni/nnicli build/nnicli
ln -sf ../tools/nni_annotation build/nni_annotation
ln -sf ../tools/nni_cmd build/nni_cmd
ln -sf ../tools/nni_trial_tool build/nni_trial_tool
ln -sf ../tools/nni_gpu_tool build/nni_gpu_tool
ln -sf ../src/sdk/pynni/nni build
ln -sf ../src/sdk/pycli/nnicli build
ln -sf ../tools/nni_annotation build
ln -sf ../tools/nni_cmd build
ln -sf ../tools/nni_trial_tool build
ln -sf ../tools/nni_gpu_tool build
cp setup.py build/
cp README.md build/
sed -ie 's/$(NNI_VERSION_TEMPLATE)/$(NNI_VERSION_VALUE)/' build/setup.py
......@@ -209,10 +209,12 @@ dev-install-node-modules:
ln -sf ${PWD}/src/nni_manager/dist $(NNI_PKG_FOLDER)
cp src/nni_manager/package.json $(NNI_PKG_FOLDER)
sed -ie 's/$(NNI_VERSION_TEMPLATE)/$(NNI_VERSION_VALUE)/' $(NNI_PKG_FOLDER)/package.json
ln -sf ${PWD}/src/nni_manager/node_modules $(NNI_PKG_FOLDER)/node_modules
ln -sf ${PWD}/src/webui/build $(NNI_PKG_FOLDER)/static
ln -sf ${PWD}/src/nasui/build $(NASUI_PKG_FOLDER)/build
ln -sf ${PWD}/src/nasui/server.js $(NASUI_PKG_FOLDER)/server.js
ln -sf ${PWD}/src/nni_manager/node_modules $(NNI_PKG_FOLDER)
ln -sf ${PWD}/src/webui/build -t $(NNI_PKG_FOLDER)
mv $(NNI_PKG_FOLDER)/build $(NNI_PKG_FOLDER)/static
mkdir -p $(NASUI_PKG_FOLDER)
ln -sf ${PWD}/src/nasui/build $(NASUI_PKG_FOLDER)
ln -sf ${PWD}/src/nasui/server.js $(NASUI_PKG_FOLDER)
.PHONY: install-scripts
install-scripts:
......
......@@ -16,7 +16,7 @@
**NNI (Neural Network Intelligence)** is a lightweight but powerful toolkit to help users **automate** <a href="docs/en_US/FeatureEngineering/Overview.md">Feature Engineering</a>, <a href="docs/en_US/NAS/Overview.md">Neural Architecture Search</a>, <a href="docs/en_US/Tuner/BuiltinTuner.md">Hyperparameter Tuning</a> and <a href="docs/en_US/Compressor/Overview.md">Model Compression</a>.
The tool manages automated machine learning (AutoML) experiments, **dispatches and runs** experiments' trial jobs generated by tuning algorithms to search the best neural architecture and/or hyper-parameters in **different training environments** like <a href="docs/en_US/TrainingService/LocalMode.md">Local Machine</a>, <a href="docs/en_US/TrainingService/RemoteMachineMode.md">Remote Servers</a>, <a href="docs/en_US/TrainingService/PaiMode.md">OpenPAI</a>, <a href="docs/en_US/TrainingService/KubeflowMode.md">Kubeflow</a>, <a href="docs/en_US/TrainingService/FrameworkControllerMode.md">FrameworkController on K8S (AKS etc.)</a> and other cloud options.
The tool manages automated machine learning (AutoML) experiments, **dispatches and runs** experiments' trial jobs generated by tuning algorithms to search the best neural architecture and/or hyper-parameters in **different training environments** like <a href="docs/en_US/TrainingService/LocalMode.md">Local Machine</a>, <a href="docs/en_US/TrainingService/RemoteMachineMode.md">Remote Servers</a>, <a href="docs/en_US/TrainingService/PaiMode.md">OpenPAI</a>, <a href="docs/en_US/TrainingService/KubeflowMode.md">Kubeflow</a>, <a href="docs/en_US/TrainingService/FrameworkControllerMode.md">FrameworkController on K8S (AKS etc.)</a>, <a href="docs/en_US/TrainingService/DLTSMode.md">DLWorkspace (aka. DLTS)</a> and other cloud options.
## **Who should consider using NNI**
......@@ -25,7 +25,7 @@ The tool manages automated machine learning (AutoML) experiments, **dispatches a
* Researchers and data scientists who want to easily **implement and experiment with new AutoML algorithms**, be it a hyperparameter tuning algorithm, a neural architecture search algorithm or a model compression algorithm.
* ML Platform owners who want to **support AutoML in their platform**.
### **NNI v1.5 has been released! &nbsp;<a href="#nni-released-reminder"><img width="48" src="docs/img/release_icon.png"></a>**
### **[NNI v1.5 has been released!](https://github.com/microsoft/nni/releases) &nbsp;<a href="#nni-released-reminder"><img width="48" src="docs/img/release_icon.png"></a>**
## **NNI capabilities in a glance**
......@@ -170,6 +170,7 @@ Within the following table, we summarized the current NNI capabilities, we are g
<li><a href="docs/en_US/TrainingService/KubeflowMode.md">Kubeflow</a></li>
<li><a href="docs/en_US/TrainingService/FrameworkControllerMode.md">FrameworkController on K8S (AKS etc.)</a></li>
</ul>
<ul><li><a href="docs/en_US/TrainingService/DLTSMode.md">DLWorkspace (aka. DLTS)</a></li>
</ul>
</td>
</tr>
......@@ -334,10 +335,15 @@ With authors' permission, we listed a set of NNI usage examples and relevant art
* **Blog (in Chinese)** - [A summary of NNI new capabilities in 2019](https://mp.weixin.qq.com/s/7_KRT-rRojQbNuJzkjFMuA) by @squirrelsc
## **Feedback**
* Discuss on the NNI [Gitter](https://gitter.im/Microsoft/nni?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) in NNI.
* [File an issue](https://github.com/microsoft/nni/issues/new/choose) on GitHub.
* Ask a question with NNI tags on [Stack Overflow](https://stackoverflow.com/questions/tagged/nni?sort=Newest&edited=true).
* Discuss on the NNI [Gitter](https://gitter.im/Microsoft/nni?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge).
Join IM discussion groups:
|Gitter||WeChat|
|----|----|----|
|![image](https://user-images.githubusercontent.com/39592018/80665738-e0574a80-8acc-11ea-91bc-0836dc4cbf89.png)| OR |![image](https://user-images.githubusercontent.com/39592018/80665762-f06f2a00-8acc-11ea-8d22-e461e68e2d9b.png)|
## Related Projects
......
......@@ -10,7 +10,7 @@
**NNI (Neural Network Intelligence)** 是一个轻量但强大的工具包,帮助用户**自动**的进行[特征工程](docs/zh_CN/FeatureEngineering/Overview.md)[神经网络架构搜索](docs/zh_CN/NAS/Overview.md)[超参调优](docs/zh_CN/Tuner/BuiltinTuner.md)以及[模型压缩](docs/zh_CN/Compressor/Overview.md)
NNI 管理自动机器学习 (AutoML) 的 Experiment,**调度运行**由调优算法生成的 Trial 任务来找到最好的神经网络架构和/或超参,支持**各种训练环境**,如[本机](docs/zh_CN/TrainingService/LocalMode.md)[远程服务器](docs/zh_CN/TrainingService/RemoteMachineMode.md)[OpenPAI](docs/zh_CN/TrainingService/PaiMode.md)[Kubeflow](docs/zh_CN/TrainingService/KubeflowMode.md)[基于 K8S 的 FrameworkController(如,AKS 等)](docs/zh_CN/TrainingService/FrameworkControllerMode.md)以及其它云服务。
NNI 管理自动机器学习 (AutoML) 的 Experiment,**调度运行**由调优算法生成的 Trial 任务来找到最好的神经网络架构和/或超参,支持**各种训练环境**,如[本机](docs/zh_CN/TrainingService/LocalMode.md)、[远程服务器](docs/zh_CN/TrainingService/RemoteMachineMode.md)、[OpenPAI](docs/zh_CN/TrainingService/PaiMode.md)、[Kubeflow](docs/zh_CN/TrainingService/KubeflowMode.md)、[基于 K8S 的 FrameworkController(如,AKS 等)](docs/zh_CN/TrainingService/FrameworkControllerMode.md)、[DLWorkspace (又称 DLTS)](docs/zh_CN/TrainingService/DLTSMode.md)以及其它云服务。
## **使用场景**
......@@ -19,7 +19,7 @@ NNI 管理自动机器学习 (AutoML) 的 Experiment,**调度运行**由调优
* 想要更容易**实现或试验新的自动机器学习算法**的研究员或数据科学家,包括:超参调优算法,神经网络搜索算法以及模型压缩算法。
* 在机器学习平台中**支持自动机器学习**
### **NNI v1.5 已发布! &nbsp;[<img width="48" src="docs/img/release_icon.png" />](#nni-released-reminder)**
### **[NNI v1.5 已发布!](https://github.com/microsoft/nni/releases) &nbsp;[<img width="48" src="docs/img/release_icon.png" />](#nni-released-reminder)**
## **NNI 功能一览**
......@@ -164,6 +164,7 @@ NNI 提供命令行工具以及友好的 WebUI 来管理训练的 Experiment。
<li><a href="docs/zh_CN/TrainingService/KubeflowMode.md">Kubeflow</a></li>
<li><a href="docs/zh_CN/TrainingService/FrameworkControllerMode.md">基于 Kubernetes(AKS 等)的 FrameworkController</a></li>
</ul>
<ul><li><a href="docs/zh_CN/TrainingService/DLTSMode.md">DLWorkspace (又称 DLTS)</a></li>
</ul>
</td>
</tr>
......
......@@ -6,7 +6,7 @@
For debugging NNI source code, your development environment should be under Ubuntu 16.04 (or above) system with python 3 and pip 3 installed, then follow the below steps.
**1. Clone the source code**
### 1. Clone the source code
Run the command
......@@ -16,7 +16,7 @@ git clone https://github.com/Microsoft/nni.git
to clone the source code
**2. Prepare the debug environment and install dependencies**
### 2. Prepare the debug environment and install dependencies
Change directory to the source code folder, then run the command
......@@ -26,7 +26,7 @@ make install-dependencies
to install the dependent tools for the environment
**3. Build source code**
### 3. Build source code
Run the command
......@@ -36,7 +36,7 @@ make build
to build the source code
**4. Install NNI to development environment**
### 4. Install NNI to development environment
Run the command
......@@ -46,7 +46,7 @@ make dev-install
to install the distribution content to development environment, and create cli scripts
**5. Check if the environment is ready**
### 5. Check if the environment is ready
Now, you can try to start an experiment to check if your environment is ready.
For example, run the command
......@@ -57,9 +57,21 @@ nnictl create --config ~/nni/examples/trials/mnist-tfv1/config.yml
And open WebUI to check if everything is OK
**6. Redeploy**
### 6. Redeploy
After changing the code, you may need to redeploy, depending on which kind of code was changed.
#### Python
No redeployment is needed, but nnictl may need to be restarted.
#### TypeScript
* If `src/nni_manager` is changed, run `yarn watch` in that folder and keep it running. It rebuilds the code as soon as files change.
* If `src/webui` or `src/nasui` is changed, use **step 3** to rebuild code.
The nnictl may need to be restarted.
After the code changes, use **step 3** to rebuild your codes, then the changes will take effect immediately.
---
Finally, we wish you a wonderful day.
......
......@@ -45,6 +45,7 @@ extensions = [
'sphinx_markdown_tables',
'sphinxarg.ext',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
]
# Add mock modules
......
......@@ -104,7 +104,7 @@ Tuner 有大量的文件、函数和类。 这里简单介绍最重要的文件
- `networkmorphism_tuner.py` 是使用 network morphism 算法的 Tuner。
- `bayesian.py` 是用来基于已经搜索道德模型来预测未知模型指标的贝叶斯算法。
- `bayesian.py` 是用来基于已经搜索到的模型来预测未知模型指标的贝叶斯算法。
- `graph.py` 是元图数据结构。 类 Graph 表示了模型的神经网络图。
- Graph 从模型中抽取神经网络。
......
......@@ -4,7 +4,7 @@
这是一个用于 NNI 神经网络架构搜索(NAS)接口的 Tuner。 它使用了 [ppo 算法](https://arxiv.org/abs/1707.06347)。 此实现继承了 [OpenAI 的 ppo2 实现](https://github.com/openai/baselines/tree/master/baselines/ppo2)的主要逻辑,并为 NAS 场景做了适配。
它能成功调优 [mnist-nas 示例](https://github.com/microsoft/nni/tree/master/examples/trials/mnist-nas),结果如下:
mnist-nas 示例已调优,并得到以下结果: **注意:此示例正在重构中,以支持最新的 NAS 接口,完成后会重新发布示例代码。**
![](../../img/ppo_mnist.png)
......
......@@ -6,7 +6,7 @@
要调试 NNI 源代码,需要 Ubuntu 16.04 或更高版本系统的开发环境,并需要安装 Python 3 以及 pip 3,然后遵循以下步骤。
**1. 克隆源代码**
### 1. 克隆源代码
运行命令
......@@ -15,7 +15,7 @@
来克隆源代码
**2. 准备调试环境并安装依赖项**
### 2. 准备调试环境并安装依赖项
将目录切换到源码目录,然后运行命令
......@@ -24,7 +24,7 @@
来安装环境的依赖项工具
**3. 生成源代码**
### 3. 生成源代码
运行命令
......@@ -33,7 +33,7 @@
来生成源代码
**4. 将 NNI 安装到开发环境中**
### 4. 将 NNI 安装到开发环境中
运行命令
......@@ -42,7 +42,7 @@
来安装分发内容到开发环境,并创建 cli 脚本
**5. 检查环境是否正确**
### 5. 检查环境是否正确
现在,可以尝试启动 Experiment 来检查环境。 例如,运行命令
......@@ -51,9 +51,20 @@ Trial 启动 Experiment 来检查环境。 例如,运行命令
并打开网页界面查看
**6. 重新部署**
### 6. 重新部署
代码改动后,用**第 3 步**来重新生成代码,改动会立即生效。
代码更改后,可能需要重新部署。 这取决于更改了哪种代码。
#### Python
不需要重新部署,但可能需要重新启动 nnictl。
#### TypeScript
* 如果要更改 `src/nni_manager`,运行 `yarn watch` 可持续编译改动。 它将实时重建代码。
* 如果更改了 `src/webui` 或 `src/nasui`,请使用 **第 3 步** 来重建代码。
可能需要重新启动 nnictl。
* * *
......
......@@ -8,6 +8,8 @@ https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import argparse
import logging
from collections import OrderedDict
import nni
import torch
import torch.nn as nn
......@@ -26,13 +28,15 @@ class Net(nn.Module):
def __init__(self, hidden_size):
super(Net, self).__init__()
# two options of conv1
self.conv1 = LayerChoice([nn.Conv2d(1, 20, 5, 1),
nn.Conv2d(1, 20, 3, 1)],
key='first_conv')
self.conv1 = LayerChoice(OrderedDict([
("conv5x5", nn.Conv2d(1, 20, 5, 1)),
("conv3x3", nn.Conv2d(1, 20, 3, 1))
]), key='first_conv')
# two options of mid_conv
self.mid_conv = LayerChoice([nn.Conv2d(20, 20, 3, 1, padding=1),
nn.Conv2d(20, 20, 5, 1, padding=2)],
key='mid_conv')
self.mid_conv = LayerChoice([
nn.Conv2d(20, 20, 3, 1, padding=1),
nn.Conv2d(20, 20, 5, 1, padding=2)
], key='mid_conv')
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, hidden_size)
self.fc2 = nn.Linear(hidden_size, 10)
......@@ -167,7 +171,6 @@ def get_params():
parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
help='how many batches to wait before logging training status')
args, _ = parser.parse_known_args()
return args
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict
import torch
import torch.nn as nn
......@@ -43,17 +45,15 @@ class Node(nn.Module):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(
[
ops.PoolBN('max', channels, 3, stride, 1, affine=False),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
],
key=choice_keys[-1]))
mutables.LayerChoice(OrderedDict([
("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
]), key=choice_keys[-1]))
self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
......
......@@ -151,6 +151,5 @@ def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
for k, v in checkpoint["state_dict"].items():
if k.startswith("module."):
k = k[len("module."):]
k = re.sub(r"^(features.\d+).(\d+)", "\\1.choices.\\2", k)
result[k] = v
return result
......@@ -6,6 +6,7 @@
"build": "tsc",
"test": "nyc mocha -r ts-node/register -t 15000 --recursive **/*.test.ts --exclude node_modules/**/**/*.test.ts --colors",
"start": "node dist/main.js",
"watch": "tsc --watch",
"eslint": "npx eslint ./ --ext .ts"
},
"license": "MIT",
......
......@@ -38,7 +38,9 @@ class DLTSTrainingService implements TrainingService {
private versionCheck: boolean = true;
private logCollection: string = 'none';
private isMultiPhase: boolean = false;
private dltsRestServerHost: string;
private dltsRestServerPort?: number;
private jobMode: boolean;
private readonly trialJobsMap: Map<string, DLTSTrialJobDetail>;
private nniManagerIpConfig?: NNIManagerIpConfig;
......@@ -51,7 +53,9 @@ class DLTSTrainingService implements TrainingService {
this.trialJobsMap = new Map();
this.jobQueue = [];
this.experimentId = getExperimentId();
this.log.info('Construct DLTS training service.');
this.dltsRestServerHost = getIPV4Address();
this.jobMode = 'DLTS_JOB_ID' in process.env;
this.log.info(`Construct DLTS training service in ${this.jobMode ? 'job mode' : 'local mode'}.`);
}
public async run(): Promise<void> {
......@@ -60,12 +64,70 @@ class DLTSTrainingService implements TrainingService {
await restServer.start();
restServer.setEnableVersionCheck = this.versionCheck;
this.log.info(`DLTS Training service rest server listening on: ${restServer.endPoint}`);
if (this.jobMode) {
await this.exposeRestServerPort(restServer.clusterRestServerPort);
} else {
this.dltsRestServerPort = restServer.clusterRestServerPort
}
await Promise.all([
this.statusCheckingLoop(),
this.submitJobLoop()]);
this.log.info('DLTS training service exit.');
}
/**
 * Make the NNI REST server port reachable from outside the current DLTS job.
 *
 * Polls `<dashboard>api/clusters/<cluster>/jobs/<DLTS_JOB_ID>/endpoints` once per
 * second. While no endpoint with `podPort === port` is listed, a request to create
 * one (named "nni-rest-server") is posted. Once that endpoint reports
 * `status === 'running'`, its `nodeName` and `port` are stored in
 * `dltsRestServerHost` / `dltsRestServerPort` and the method returns.
 *
 * The loop also ends, without setting host/port, when `this.stopping` becomes
 * true, so that stopping the service is not blocked by an endpoint that never
 * reaches the running state.
 *
 * @param port port inside the pod on which the REST server is listening
 * @throws Error when the cluster config has not been set
 */
private async exposeRestServerPort(port: number): Promise<void> {
    if (this.dltsClusterConfig == null) {
        throw Error('Cluster config is not set');
    }
    const { dashboard, cluster, email, password } = this.dltsClusterConfig;
    const jobId = process.env['DLTS_JOB_ID'] + '';
    const uri = `${dashboard}api/clusters/${cluster}/jobs/${jobId}/endpoints`;
    // Credentials are passed as query-string parameters on every request.
    const qs = { email, password };

    while (!this.stopping) {
        this.log.debug('Checking endpoints');
        const endpoints = await new Promise((resolve, reject) => {
            request.get(uri, { qs, json: true }, function (error, response, body) {
                if (error) {
                    reject(error);
                } else {
                    resolve(body);
                }
            });
        });
        this.log.debug('Endpoints: %o', endpoints);

        // Anything that is not an array (e.g. an error payload) is ignored and retried.
        if (Array.isArray(endpoints)) {
            const restServerEndpoint = endpoints.find(({ podPort }) => podPort === port);
            if (restServerEndpoint == null) {
                // Not requested yet: ask the dashboard to expose the port.
                this.log.debug('Exposing %d', port);
                await new Promise<void>((resolve, reject) => {
                    request.post(uri, {
                        qs,
                        json: true,
                        body: {
                            endpoints: [{
                                name: "nni-rest-server",
                                podPort: port
                            }]
                        }
                    }, function (error) {
                        if (error) {
                            reject(error);
                        } else {
                            resolve();
                        }
                    });
                });
            } else if (restServerEndpoint['status'] === 'running') {
                // We get an exposed restserver port
                this.dltsRestServerHost = restServerEndpoint['nodeName'];
                this.dltsRestServerPort = restServerEndpoint['port'];
                return;
            }
        }

        // Wait one second before polling again.
        await new Promise(resolve => setTimeout(resolve, 1000));
    }
}
private async statusCheckingLoop(): Promise<void> {
while (!this.stopping) {
const updateDLTSTrialJobs: Promise<void>[] = [];
......@@ -400,7 +462,7 @@ class DLTSTrainingService implements TrainingService {
);
}
// tslint:disable-next-line: strict-boolean-expressions
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : this.dltsRestServerHost;
const version: string = this.versionCheck ? await getVersion() : '';
const nniDLTSTrialCommand: string = String.Format(
DLTS_TRIAL_COMMAND_FORMAT,
......
......@@ -100,12 +100,15 @@ class PAIK8STrainingService extends PAITrainingService {
}
}
//TODO: update trial parameters
// update trial parameters for multi-phase
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
// Write file content ( parameter.cfg ) to working folders
await this.writeParameterFile(trialJobDetail.logPath, form.hyperParameters);
return trialJobDetail;
}
......@@ -230,24 +233,20 @@ class PAIK8STrainingService extends PAITrainingService {
this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;
// Step 1. Prepare PAI job configuration
const trialLocalFolder: string = path.join(this.paiTrialConfig.nniManagerNFSMountPath, this.experimentId, trialJobId);
//create trial local working folder locally.
await execMkdir(trialLocalFolder);
await execMkdir(trialJobDetail.logPath);
const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local files
await fs.promises.writeFile(path.join(trialLocalFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
await fs.promises.writeFile(path.join(trialJobDetail.logPath, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local working folders
if (trialJobDetail.form !== undefined) {
await fs.promises.writeFile(
path.join(trialLocalFolder, generateParamFileName(trialJobDetail.form.hyperParameters)),
trialJobDetail.form.hyperParameters.value, { encoding: 'utf8' }
);
await this.writeParameterFile(trialJobDetail.logPath, trialJobDetail.form.hyperParameters);
}
//Copy codeDir files to local working folder
await execCopydir(this.paiTrialConfig.codeDir, trialLocalFolder);
await execCopydir(this.paiTrialConfig.codeDir, trialJobDetail.logPath);
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const version: string = this.versionCheck ? await getVersion() : '';
......@@ -298,6 +297,11 @@ class PAIK8STrainingService extends PAITrainingService {
return deferred.promise;
}
/**
 * Write the hyper-parameter payload into `directory`.
 *
 * The file name comes from `generateParamFileName(hyperParameters)` and the
 * content is `hyperParameters.value`, encoded as UTF-8.
 *
 * @param directory folder that receives the parameter file
 * @param hyperParameters hyper-parameters to persist
 */
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
    const paramFileName = generateParamFileName(hyperParameters);
    const targetPath: string = path.join(directory, paramFileName);
    await fs.promises.writeFile(targetPath, hyperParameters.value, { encoding: 'utf8' });
}
}
export { PAIK8STrainingService };
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { OsCommands } from "../osCommands";
import { RemoteCommandResult } from "../remoteMachineData";
/**
 * Linux implementation of OsCommands.
 *
 * Apart from isProcessAliveProcessOutput, every method only builds and returns
 * a shell command string; nothing is executed here. Paths are wrapped in single
 * quotes; single quotes inside a path are not escaped.
 */
class LinuxCommands extends OsCommands {

    /**
     * Command that creates a folder, including missing parents (`mkdir -p`).
     * @param folderName folder to create
     * @param sharedFolder when true, `umask 0` is issued before `mkdir`
     */
    public createFolder(folderName: string, sharedFolder: boolean = false): string {
        let command;
        if (sharedFolder) {
            command = `umask 0; mkdir -p '${folderName}'`;
        } else {
            command = `mkdir -p '${folderName}'`;
        }
        return command;
    }

    /**
     * Command that sets mode 777 on the given folders.
     * @param isRecursive when true, `-R` is added
     * @param folders folders to change; each one ends up individually single-quoted
     */
    public allowPermission(isRecursive: boolean = false, ...folders: string[]): string {
        // Joining with "' '" and wrapping the result in quotes quotes every folder separately.
        const folderString = folders.join("' '");
        let command;

        if (isRecursive) {
            command = `chmod 777 -R '${folderString}'`;
        } else {
            command = `chmod 777 '${folderString}'`;
        }
        return command;
    }

    /**
     * Command that removes a folder with `rm`.
     * @param folderName folder to remove
     * @param isRecursive selects `-r` instead of `-d` when any flag is emitted
     * @param isForce adds `f` to the flags
     */
    public removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): string {
        let flags = '';
        // No flag at all is emitted when both switches are false.
        if (isForce || isRecursive) {
            flags = `-${isRecursive ? 'r' : 'd'}${isForce ? 'f' : ''} `;
        }

        const command = `rm ${flags}'${folderName}'`;
        return command;
    }

    /**
     * Command that removes the path formed by joining folderName and filePattern.
     * NOTE(review): the joined path is single-quoted, so the shell will presumably
     * not expand wildcards in filePattern - confirm that callers expect this.
     */
    public removeFiles(folderName: string, filePattern: string): string {
        const files = this.joinPath(folderName, filePattern);
        const command = `rm '${files}'`;
        return command;
    }

    /**
     * Command that prints the last `lineCount` lines of a file (`tail -n`).
     */
    public readLastLines(fileName: string, lineCount: number = 1): string {
        const command = `tail -n ${lineCount} '${fileName}'`;
        return command;
    }

    /**
     * Command that sends signal 0 to the pid read from `pidFileName`.
     */
    public isProcessAliveCommand(pidFileName: string): string {
        const command = `kill -0 \`cat '${pidFileName}'\``;
        return command;
    }

    /**
     * Interpret the result of isProcessAliveCommand.
     * @returns true only when the command's exit code is 0
     */
    public isProcessAliveProcessOutput(commandResult: RemoteCommandResult): boolean {
        let result = true;
        if (commandResult.exitCode !== 0) {
            result = false;
        }
        return result;
    }

    /**
     * Command that runs `pkill -P` with the pid read from `pidFileName`.
     */
    public killChildProcesses(pidFileName: string): string {
        const command = `pkill -P \`cat '${pidFileName}'\``;
        return command;
    }

    /**
     * Command that extracts a gzip-compressed tar file into `targetFolder`.
     */
    public extractFile(tarFileName: string, targetFolder: string): string {
        const command = `tar -oxzf '${tarFileName}' -C '${targetFolder}'`;
        return command;
    }

    /**
     * Command that runs a script with bash.
     * @param script path of a script file, or inline script text
     * @param isFile true when `script` is a file path; false to run it via `bash -c`
     */
    public executeScript(script: string, isFile: boolean): string {
        let command: string;
        if (isFile) {
            command = `bash '${script}'`;
        } else {
            // Escape EVERY double quote (a string pattern would replace only the first),
            // otherwise a later quote ends the surrounding "..." early.
            // Only double quotes are escaped; other shell metacharacters are passed through.
            const escaped = script.replace(/"/g, '\\"');
            command = `bash -c "${escaped}"`;
        }
        return command;
    }
}

export { LinuxCommands };
......@@ -8,7 +8,7 @@ import { getLogger, Logger } from '../../common/log';
import { randomSelect } from '../../common/utils';
import { GPUInfo } from '../common/gpuData';
import {
parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, SSHClientManager
parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, ExecutorManager
} from './remoteMachineData';
type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
......@@ -18,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
*/
export class GPUScheduler {
private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>;
private readonly machineExecutorMap: Map<RemoteMachineMeta, ExecutorManager>;
private readonly log: Logger = getLogger();
private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin';
private roundRobinIndex: number = 0;
......@@ -26,12 +26,12 @@ export class GPUScheduler {
/**
* Constructor
* @param machineSSHClientMap map from remote machine to sshClient
* @param machineExecutorMap map from remote machine to executor
*/
constructor(machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>) {
assert(machineSSHClientMap.size > 0);
this.machineSSHClientMap = machineSSHClientMap;
this.configuredRMs = Array.from(machineSSHClientMap.keys());
constructor(machineExecutorMap: Map<RemoteMachineMeta, ExecutorManager>) {
assert(machineExecutorMap.size > 0);
this.machineExecutorMap = machineExecutorMap;
this.configuredRMs = Array.from(machineExecutorMap.keys());
}
/**
......@@ -43,7 +43,7 @@ export class GPUScheduler {
requiredGPUNum = 0;
}
assert(requiredGPUNum >= 0);
const allRMs: RemoteMachineMeta[] = Array.from(this.machineSSHClientMap.keys());
const allRMs: RemoteMachineMeta[] = Array.from(this.machineExecutorMap.keys());
assert(allRMs.length > 0);
// Step 1: Check if required GPU number not exceeds the total GPU number in all machines
......@@ -135,7 +135,7 @@ export class GPUScheduler {
*/
private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> {
const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
this.machineSSHClientMap.forEach((sshClientManager: SSHClientManager, rmMeta: RemoteMachineMeta) => {
this.machineExecutorMap.forEach((executorManager: ExecutorManager, rmMeta: RemoteMachineMeta) => {
// Assign total GPU count as init available GPU number
if (rmMeta.gpuSummary !== undefined) {
const availableGPUs: GPUInfo[] = [];
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { RemoteCommandResult } from "./remoteMachineData";
/**
 * Base class for building OS-specific shell command strings.
 *
 * Subclasses implement the abstract command builders; this class provides the
 * shared path-joining helper.
 */
abstract class OsCommands {

    /** Separator placed between path segments by joinPath. */
    protected pathSpliter: string = '/';
    /**
     * Matches a run of two or more separators. The `g` flag is required so that
     * joinPath collapses every such run, not just the first one.
     */
    protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`, 'g');

    public abstract createFolder(folderName: string, sharedFolder: boolean): string;
    public abstract allowPermission(isRecursive: boolean, ...folders: string[]): string;
    public abstract removeFolder(folderName: string, isRecursive: boolean, isForce: boolean): string;
    public abstract removeFiles(folderOrFileName: string, filePattern: string): string;
    public abstract readLastLines(fileName: string, lineCount: number): string;
    public abstract isProcessAliveCommand(pidFileName: string): string;
    public abstract isProcessAliveProcessOutput(result: RemoteCommandResult): boolean;
    public abstract killChildProcesses(pidFileName: string): string;
    public abstract extractFile(tarFileName: string, targetFolder: string): string;
    public abstract executeScript(script: string, isFile: boolean): string;

    /**
     * Join path segments with `pathSpliter`.
     *
     * Empty segments are dropped and repeated separators are collapsed into one.
     *
     * @param paths path segments, in order
     * @returns the joined path, or '.' when there is no non-empty segment
     */
    public joinPath(...paths: string[]): string {
        let dir: string = paths.filter((path: string) => path !== '').join(this.pathSpliter);
        if (dir === '') {
            dir = '.';
        } else {
            dir = dir.replace(this.multiplePathSpliter, this.pathSpliter);
        }
        return dir;
    }
}

export { OsCommands };
......@@ -3,11 +3,9 @@
'use strict';
import * as fs from 'fs';
import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred';
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUInfo, GPUSummary } from '../common/gpuData';
import { ShellExecutor } from './shellExecutor';
/**
* Metadata of remote machine for configuration and status query
......@@ -72,7 +70,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public gpuIndices: GPUInfo[];
constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: TrialJobApplicationForm) {
workingDirectory: string, form: TrialJobApplicationForm) {
this.id = id;
this.status = status;
this.submitTime = submitTime;
......@@ -84,149 +82,88 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
}
/**
* The remote machine ssh client used for trial and gpu detector
* The remote machine executor manager
*/
export class SSHClient {
private readonly sshClient: Client;
private usedConnectionNumber: number; //count the connection number of every client
constructor(sshClient: Client, usedConnectionNumber: number) {
this.sshClient = sshClient;
this.usedConnectionNumber = usedConnectionNumber;
}
public get getSSHClientInstance(): Client {
return this.sshClient;
}
public get getUsedConnectionNumber(): number {
return this.usedConnectionNumber;
}
public addUsedConnectionNumber(): void {
this.usedConnectionNumber += 1;
}
public minusUsedConnectionNumber(): void {
this.usedConnectionNumber -= 1;
}
}
/**
* The remote machine ssh client manager
*/
export class SSHClientManager {
private readonly sshClientArray: SSHClient[];
export class ExecutorManager {
private readonly executorArray: ShellExecutor[];
private readonly maxTrialNumberPerConnection: number;
private readonly rmMeta: RemoteMachineMeta;
constructor(sshClientArray: SSHClient[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) {
constructor(executorArray: ShellExecutor[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) {
this.rmMeta = rmMeta;
this.sshClientArray = sshClientArray;
this.executorArray = executorArray;
this.maxTrialNumberPerConnection = maxTrialNumberPerConnection;
}
/**
* find a available ssh client in ssh array, if no ssh client available, return undefined
* find a available executor, if no executor available, return a new one
*/
public async getAvailableSSHClient(): Promise<Client> {
const deferred: Deferred<Client> = new Deferred<Client>();
for (const index of this.sshClientArray.keys()) {
const connectionNumber: number = this.sshClientArray[index].getUsedConnectionNumber;
public async getAvailableExecutor(): Promise<ShellExecutor> {
for (const index of this.executorArray.keys()) {
const connectionNumber: number = this.executorArray[index].getUsedConnectionNumber;
if (connectionNumber < this.maxTrialNumberPerConnection) {
this.sshClientArray[index].addUsedConnectionNumber();
deferred.resolve(this.sshClientArray[index].getSSHClientInstance);
this.executorArray[index].addUsedConnectionNumber();
return deferred.promise;
return this.executorArray[index];
}
}
//init a new ssh client if could not get an available one
return this.initNewSSHClient();
//init a new executor if could not get an available one
return await this.initNewShellExecutor();
}
/**
* add a new ssh client to sshClientArray
* @param sshClient SSH Client
* add a new executor to executorArray
* @param executor ShellExecutor
*/
public addNewSSHClient(client: Client): void {
this.sshClientArray.push(new SSHClient(client, 1));
public addNewShellExecutor(executor: ShellExecutor): void {
this.executorArray.push(executor);
}
/**
* first ssh client instance is used for gpu collector and host job
* first executor instance is used for gpu collector and host job
*/
public getFirstSSHClient(): Client {
return this.sshClientArray[0].getSSHClientInstance;
public getFirstExecutor(): ShellExecutor {
return this.executorArray[0];
}
/**
* close all of ssh client
* close all of executor
*/
public closeAllSSHClient(): void {
for (const sshClient of this.sshClientArray) {
sshClient.getSSHClientInstance.end();
public closeAllExecutor(): void {
for (const executor of this.executorArray) {
executor.close();
}
}
/**
 * retrieve resource, minus a number for given executor
 * @param executor executor whose usage counter should be decremented
 * @throws Error when executor is undefined
 */
public releaseConnection(executor: ShellExecutor | undefined): void {
    if (executor === undefined) {
        throw new Error(`could not release a undefined executor`);
    }
    // Executors are matched by identity; an executor that is not in the pool is ignored.
    for (const candidate of this.executorArray) {
        if (candidate === executor) {
            candidate.minusUsedConnectionNumber();
            break;
        }
    }
}
/**
 * Create a new connection executor and initialize it
 *
 * The executor is connected with this.rmMeta, counted as used once (for the
 * caller that requested it) and appended to executorArray, so that later
 * getAvailableExecutor / getFirstExecutor / releaseConnection / closeAllExecutor
 * calls can see it.
 * @returns the initialized executor
 */
private async initNewShellExecutor(): Promise<ShellExecutor> {
    const executor = new ShellExecutor();
    await executor.initialize(this.rmMeta);
    // Register only after a successful connect; a failed initialize() rejects above
    // and leaves the pool untouched.
    executor.addUsedConnectionNumber();
    this.executorArray.push(executor);
    return executor;
}
}
// Outcome of a scheduling attempt: the chosen machine info (if any) plus the result type.
export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType };
// Machine picked for a trial together with the cudaVisibleDevice string assigned to it.
export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string };
export enum ScheduleResultType {
// Schedule succeeded
......@@ -240,7 +177,7 @@ export enum ScheduleResultType {
}
export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string =
`#!/bin/bash
`#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} \
NNI_TRIAL_SEQ_ID={4} export MULTI_PHASE={5}
cd $NNI_SYS_DIR
......@@ -251,7 +188,7 @@ python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8
echo $? \`date +%s%3N\` >{12}`;
export const HOST_JOB_SHELL_FORMAT: string =
`#!/bin/bash
`#!/bin/bash
cd {0}
echo $$ >{1}
eval {2} >stdout 2>stderr
......
......@@ -7,7 +7,6 @@ import * as assert from 'assert';
import { EventEmitter } from 'events';
import * as fs from 'fs';
import * as path from 'path';
import { Client } from 'ssh2';
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import * as component from '../../common/component';
......@@ -30,22 +29,22 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { execCopydir, execMkdir, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util';
import { GPUScheduler } from './gpuScheduler';
import {
RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
RemoteMachineScheduleInfo, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail,
ScheduleResultType, SSHClientManager
ScheduleResultType, ExecutorManager
} from './remoteMachineData';
import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
import { SSHClientUtility } from './sshClientUtility';
import { ShellExecutor } from 'training_service/remote_machine/shellExecutor';
/**
* Training Service implementation for Remote Machine (Linux)
*/
@component.Singleton
class RemoteMachineTrainingService implements TrainingService {
private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>; //machine ssh client map
private readonly trialSSHClientMap: Map<string, Client>; //trial ssh client map
private readonly machineExecutorManagerMap: Map<RemoteMachineMeta, ExecutorManager>; //machine excutor map
private readonly trialExecutorMap: Map<string, ShellExecutor>; //trial excutor map
private readonly trialJobsMap: Map<string, RemoteMachineTrialJobDetail>;
private readonly MAX_TRIAL_NUMBER_PER_SSHCONNECTION: number = 5; // every ssh client has a max trial concurrency number
private readonly MAX_TRIAL_NUMBER_PER_EXECUTOR: number = 5; // every excutor has a max trial concurrency number
private readonly expRootDir: string;
private readonly remoteExpRootDir: string;
private trialConfig: TrialConfig | undefined;
......@@ -67,8 +66,8 @@ class RemoteMachineTrainingService implements TrainingService {
this.remoteOS = 'linux';
this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>();
this.trialSSHClientMap = new Map<string, Client>();
this.machineSSHClientMap = new Map<RemoteMachineMeta, SSHClientManager>();
this.trialExecutorMap = new Map<string, ShellExecutor>();
this.machineExecutorManagerMap = new Map<RemoteMachineMeta, ExecutorManager>();
this.jobQueue = [];
this.expRootDir = getExperimentRootDir();
this.remoteExpRootDir = this.getRemoteExperimentRootDir();
......@@ -111,38 +110,34 @@ class RemoteMachineTrainingService implements TrainingService {
}
/**
 * give trial an executor
 * Looks up the ExecutorManager registered for the trial's machine, asks it for an
 * available executor and records it in trialExecutorMap under the trial id.
 * @param trial remote machine trial job detail
 * @throws Error when trial.rmMeta is unset or no ExecutorManager exists for that machine
 */
public async allocateExecutorForTrial(trial: RemoteMachineTrialJobDetail): Promise<void> {
    if (trial.rmMeta === undefined) {
        throw new Error(`rmMeta not set in trial ${trial.id}`);
    }
    const executorManager: ExecutorManager | undefined = this.machineExecutorManagerMap.get(trial.rmMeta);
    if (executorManager === undefined) {
        throw new Error(`executorManager not initialized`);
    }
    const shellExecutor: ShellExecutor = await executorManager.getAvailableExecutor();
    this.trialExecutorMap.set(trial.id, shellExecutor);
}
/**
 * If a trial is finished, release the connection resource
 * The executor recorded for the trial stays in trialExecutorMap; only the
 * manager's usage counter is released.
 * @param trial remote machine trial job detail
 * @throws Error when trial.rmMeta is unset or no ExecutorManager exists for that machine
 */
public releaseTrialExecutor(trial: RemoteMachineTrialJobDetail): void {
    if (trial.rmMeta === undefined) {
        throw new Error(`rmMeta not set in trial ${trial.id}`);
    }
    const executorManager: ExecutorManager | undefined = this.machineExecutorManagerMap.get(trial.rmMeta);
    if (executorManager === undefined) {
        throw new Error(`executorManager not initialized`);
    }
    executorManager.releaseConnection(this.trialExecutorMap.get(trial.id));
}
/**
......@@ -152,7 +147,7 @@ class RemoteMachineTrainingService implements TrainingService {
const jobs: TrialJobDetail[] = [];
const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>();
for (const [key, value] of this.trialJobsMap) {
for (const [key,] of this.trialJobsMap) {
jobs.push(await this.getTrialJob(key));
}
deferred.resolve(jobs);
......@@ -171,16 +166,16 @@ class RemoteMachineTrainingService implements TrainingService {
}
//TO DO: add another job status, and design new job status change logic
if (trialJob.status === 'RUNNING' || trialJob.status === 'UNKNOWN') {
// Get ssh client where the job is running
// Get executor where the job is running
if (trialJob.rmMeta === undefined) {
throw new Error(`rmMeta not set for submitted job ${trialJobId}`);
}
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJob.id);
if (sshClient === undefined) {
throw new Error(`Invalid job id: ${trialJobId}, cannot find ssh client`);
const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJob.id);
if (executor === undefined) {
throw new Error(`Invalid job id: ${trialJobId}, cannot find executor`);
}
return this.updateTrialJobStatus(trialJob, sshClient);
return this.updateTrialJobStatus(trialJob, executor);
} else {
return trialJob;
}
......@@ -255,10 +250,8 @@ class RemoteMachineTrainingService implements TrainingService {
* @param trialJobId ID of trial job
*/
public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJob === undefined) {
deferred.reject();
throw new Error(`trial job id ${trialJobId} not found`);
}
......@@ -268,17 +261,16 @@ class RemoteMachineTrainingService implements TrainingService {
this.jobQueue.splice(index, 1);
}
// Get ssh client where the job is running
// Get executor where the job is running
if (trialJob.rmMeta !== undefined) {
// If the trial job is already scheduled, check its status and kill the trial process in remote machine
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJob.id);
if (sshClient === undefined) {
deferred.reject();
throw new Error(`Invalid job id ${trialJobId}, cannot find ssh client`);
const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJob.id);
if (executor === undefined) {
throw new Error(`Invalid job id ${trialJobId}, cannot find executor`);
}
if (trialJob.status === 'UNKNOWN') {
this.releaseTrialSSHClient(trialJob);
this.releaseTrialExecutor(trialJob);
trialJob.status = 'USER_CANCELED';
return
}
......@@ -287,8 +279,8 @@ class RemoteMachineTrainingService implements TrainingService {
try {
// Mark the toEarlyStop tag here
trialJob.isEarlyStopped = isEarlyStopped;
await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, sshClient);
this.releaseTrialSSHClient(trialJob);
await executor.killChildProcesses(jobpidPath);
this.releaseTrialExecutor(trialJob);
} catch (error) {
// Not handle the error since pkill failed will not impact trial job's current status
this.log.error(`remoteTrainingService.cancelTrialJob: ${error.message}`);
......@@ -303,7 +295,7 @@ class RemoteMachineTrainingService implements TrainingService {
/**
* Set cluster metadata
* @param key metadata key
* //1. MACHINE_LIST -- create ssh client connect of machine list
* //1. MACHINE_LIST -- create executor of machine list
* //2. TRIAL_CONFIG -- trial configuration
* @param value metadata value
*/
......@@ -314,7 +306,7 @@ class RemoteMachineTrainingService implements TrainingService {
break;
case TrialConfigMetadataKey.MACHINE_LIST:
await this.setupConnections(value);
this.gpuScheduler = new GPUScheduler(this.machineSSHClientMap);
this.gpuScheduler = new GPUScheduler(this.machineExecutorManagerMap);
break;
case TrialConfigMetadataKey.TRIAL_CONFIG: {
const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value);
......@@ -324,7 +316,7 @@ class RemoteMachineTrainingService implements TrainingService {
}
// codeDir is not a valid directory, throw Error
if (!fs.lstatSync(remoteMachineTrailConfig.codeDir)
.isDirectory()) {
.isDirectory()) {
throw new Error(`codeDir ${remoteMachineTrailConfig.codeDir} is not a directory`);
}
......@@ -359,10 +351,8 @@ class RemoteMachineTrainingService implements TrainingService {
* Get cluster metadata
* @param key metadata key
*/
public async getClusterMetadata(key: string): Promise<string> {
    // No metadata is exposed by this training service: every key resolves to an
    // empty string (the key argument is accepted but not inspected).
    return "";
}
/**
......@@ -392,14 +382,14 @@ class RemoteMachineTrainingService implements TrainingService {
*/
private async cleanupConnections(): Promise<void> {
try {
for (const [rmMeta, sshClientManager] of this.machineSSHClientMap.entries()) {
for (const [rmMeta, executorManager] of this.machineExecutorManagerMap.entries()) {
const jobpidPath: string = unixPathJoin(this.getRemoteScriptsPath(rmMeta.username), 'pid');
const client: Client | undefined = sshClientManager.getFirstSSHClient();
if (client !== undefined) {
await SSHClientUtility.remoteExeCommand(`pkill -P \`cat ${jobpidPath}\``, client);
await SSHClientUtility.remoteExeCommand(`rm -rf ${this.getRemoteScriptsPath(rmMeta.username)}`, client);
const executor: ShellExecutor | undefined = executorManager.getFirstExecutor();
if (executor !== undefined) {
await executor.killChildProcesses(jobpidPath);
await executor.removeFolder(this.getRemoteScriptsPath(rmMeta.username));
}
sshClientManager.closeAllSSHClient();
executorManager.closeAllExecutor();
}
} catch (error) {
//ignore error, this function is called to cleanup remote connections when experiment is stopping
......@@ -418,10 +408,10 @@ class RemoteMachineTrainingService implements TrainingService {
rmMetaList.forEach(async (rmMeta: RemoteMachineMeta) => {
rmMeta.occupiedGpuIndexMap = new Map<number, number>();
const sshClientManager: SSHClientManager = new SSHClientManager([], this.MAX_TRIAL_NUMBER_PER_SSHCONNECTION, rmMeta);
const sshClient: Client = await sshClientManager.getAvailableSSHClient();
this.machineSSHClientMap.set(rmMeta, sshClientManager);
await this.initRemoteMachineOnConnected(rmMeta, sshClient);
const executorManager: ExecutorManager = new ExecutorManager([], this.MAX_TRIAL_NUMBER_PER_EXECUTOR, rmMeta);
const executor: ShellExecutor = await executorManager.getAvailableExecutor();
this.machineExecutorManagerMap.set(rmMeta, executorManager);
await this.initRemoteMachineOnConnected(rmMeta, executor);
if (++connectedRMNum === rmMetaList.length) {
deferred.resolve();
}
......@@ -430,26 +420,25 @@ class RemoteMachineTrainingService implements TrainingService {
return deferred.promise;
}
private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, conn: Client): Promise<void> {
// Create root working directory after ssh connection is ready
private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise<void> {
// Create root working directory after executor is ready
const nniRootDir: string = unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni');
await SSHClientUtility.remoteExeCommand(`mkdir -p ${this.remoteExpRootDir}`, conn);
await executor.createFolder(this.remoteExpRootDir);
// the directory to store temp scripts in remote machine
const remoteGpuScriptCollectorDir: string = this.getRemoteScriptsPath(rmMeta.username);
await SSHClientUtility.remoteExeCommand(`(umask 0 ; mkdir -p ${remoteGpuScriptCollectorDir})`, conn);
await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn);
await executor.createFolder(remoteGpuScriptCollectorDir, true);
await executor.allowPermission(false, nniRootDir, `${nniRootDir}/*`, `${nniRootDir}/scripts/*`);
//Begin to execute gpu_metrics_collection scripts
const script = getGpuMetricsCollectorBashScriptContent(remoteGpuScriptCollectorDir);
SSHClientUtility.remoteExeCommand(`bash -c '${script}'`, conn);
executor.executeScript(script, false, true);
const disposable: Rx.IDisposable = this.timer.subscribe(
async (tick: number) => {
const cmdresult: RemoteCommandResult = await SSHClientUtility.remoteExeCommand(
`tail -n 1 ${unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics')}`, conn);
if (cmdresult !== undefined && cmdresult.stdout !== undefined && cmdresult.stdout.length > 0) {
rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult.stdout);
async () => {
const cmdresult = await executor.readLastLines(unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics'));
if (cmdresult !== "") {
rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult);
if (rmMeta.gpuSummary.gpuCount === 0) {
this.log.warning(`No GPU found on remote machine ${rmMeta.ip}`);
this.timer.unsubscribe(disposable);
......@@ -478,7 +467,7 @@ class RemoteMachineTrainingService implements TrainingService {
return deferred.promise;
}
// get an ssh client from scheduler
// get an executor from scheduler
const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobDetail);
if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) {
const errorMessage: string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`;
......@@ -492,7 +481,7 @@ class RemoteMachineTrainingService implements TrainingService {
trialJobDetail.rmMeta = rmScheduleInfo.rmMeta;
await this.allocateSSHClientForTrial(trialJobDetail);
await this.allocateExecutorForTrial(trialJobDetail);
await this.launchTrialOnScheduledMachine(
trialJobId, trialWorkingFolder, trialJobDetail.form, rmScheduleInfo);
......@@ -513,14 +502,14 @@ class RemoteMachineTrainingService implements TrainingService {
}
private async launchTrialOnScheduledMachine(trialJobId: string, trialWorkingFolder: string, form: TrialJobApplicationForm,
rmScheduleInfo: RemoteMachineScheduleInfo): Promise<void> {
rmScheduleInfo: RemoteMachineScheduleInfo): Promise<void> {
if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized');
}
const cudaVisibleDevice: string = rmScheduleInfo.cudaVisibleDevice;
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
if (sshClient === undefined) {
assert(false, 'sshClient is undefined.');
const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJobId);
if (executor === undefined) {
assert(false, 'ShellExecutor is undefined.');
// for lint
return;
......@@ -532,8 +521,8 @@ class RemoteMachineTrainingService implements TrainingService {
const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId);
await SSHClientUtility.remoteExeCommand(`mkdir -p ${trialWorkingFolder}`, sshClient);
await SSHClientUtility.remoteExeCommand(`mkdir -p ${unixPathJoin(trialWorkingFolder, '.nni')}`, sshClient);
await executor.createFolder(trialWorkingFolder);
await executor.createFolder(unixPathJoin(trialWorkingFolder, '.nni'));
// RemoteMachineRunShellFormat is the run shell format string,
// See definition in remoteMachineData.ts
......@@ -586,13 +575,13 @@ class RemoteMachineTrainingService implements TrainingService {
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptTrialContent, { encoding: 'utf8' });
await this.writeParameterFile(trialJobId, form.hyperParameters);
// Copy files in codeDir to remote working directory
await SSHClientUtility.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, sshClient, this.remoteOS);
await executor.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, this.remoteOS);
// Execute command in remote machine
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient);
executor.executeScript(unixPathJoin(trialWorkingFolder, 'run.sh'), true, true);
}
private getRmMetaByHost(host: string): RemoteMachineMeta {
for (const [rmMeta, client] of this.machineSSHClientMap.entries()) {
for (const rmMeta of this.machineExecutorManagerMap.keys()) {
if (rmMeta.ip === host) {
return rmMeta;
}
......@@ -600,19 +589,19 @@ class RemoteMachineTrainingService implements TrainingService {
throw new Error(`Host not found: ${host}`);
}
private async updateTrialJobStatus(trialJob: RemoteMachineTrialJobDetail, sshClient: Client): Promise<TrialJobDetail> {
private async updateTrialJobStatus(trialJob: RemoteMachineTrialJobDetail, executor: ShellExecutor): Promise<TrialJobDetail> {
const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>();
const jobpidPath: string = this.getJobPidPath(trialJob.id);
const trialReturnCodeFilePath: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJob.id, '.nni', 'code');
/* eslint-disable require-atomic-updates */
try {
const killResult: number = (await SSHClientUtility.remoteExeCommand(`kill -0 \`cat ${jobpidPath}\``, sshClient)).exitCode;
const isAlive = await executor.isProcessAlive(jobpidPath);
// if the process of jobpid is not alive any more
if (killResult !== 0) {
const trailReturnCode: string = await SSHClientUtility.getRemoteFileContent(trialReturnCodeFilePath, sshClient);
this.log.debug(`trailjob ${trialJob.id} return code: ${trailReturnCode}`);
const match: RegExpMatchArray | null = trailReturnCode.trim()
.match(/^(\d+)\s+(\d+)$/);
if (!isAlive) {
const trialReturnCode: string = await executor.getRemoteFileContent(trialReturnCodeFilePath);
this.log.debug(`trailjob ${trialJob.id} return code: ${trialReturnCode}`);
const match: RegExpMatchArray | null = trialReturnCode.trim()
.match(/^(\d+)\s+(\d+)$/);
if (match !== null) {
const { 1: code, 2: timestamp } = match;
// Update trial job's status based on result code
......@@ -627,7 +616,7 @@ class RemoteMachineTrainingService implements TrainingService {
}
}
trialJob.endTime = parseInt(timestamp, 10);
this.releaseTrialSSHClient(trialJob);
this.releaseTrialExecutor(trialJob);
}
this.log.debug(`trailJob status update: ${trialJob.id}, ${trialJob.status}`);
}
......@@ -671,9 +660,9 @@ class RemoteMachineTrainingService implements TrainingService {
}
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
if (sshClient === undefined) {
throw new Error('sshClient is undefined.');
const executor: ShellExecutor | undefined = this.trialExecutorMap.get(trialJobId);
if (executor === undefined) {
throw new Error('ShellExecutor is undefined.');
}
const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
......@@ -683,7 +672,7 @@ class RemoteMachineTrainingService implements TrainingService {
const localFilepath: string = path.join(trialLocalTempFolder, fileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
await SSHClientUtility.copyFileToRemote(localFilepath, unixPathJoin(trialWorkingFolder, fileName), sshClient);
await executor.copyFileToRemote(localFilepath, unixPathJoin(trialWorkingFolder, fileName));
}
}
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as assert from 'assert';
import * as os from 'os';
import * as path from 'path';
import * as fs from 'fs';
import { Client, ClientChannel, SFTPWrapper, ConnectConfig } from 'ssh2';
import { Deferred } from "ts-deferred";
import { RemoteCommandResult, RemoteMachineMeta } from "./remoteMachineData";
import * as stream from 'stream';
import { OsCommands } from "./osCommands";
import { LinuxCommands } from "./extends/linuxCommands";
import { getLogger, Logger } from '../../common/log';
import { NNIError, NNIErrorNames } from '../../common/errors';
import { execRemove, tarAdd } from '../common/util';
import { getRemoteTmpDir, uniqueString, unixPathJoin } from '../../common/utils';
class ShellExecutor {
// ssh2 client owned by this executor; exec, shell and sftp calls in this class all go through it
private sshClient: Client = new Client();
// OS-specific command builder; remains undefined until initialize() assigns it in the 'ready' handler
private osCommands: OsCommands | undefined;
private usedConnectionNumber: number = 0; //count the connection number of every client
// NOTE(review): presumably the remote path separator - not referenced by any method visible in this class
protected pathSpliter: string = '/';
// Regular expression matching two or more consecutive pathSpliter characters
protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`);
/**
 * Open the SSH connection described by rmMeta and select the OS command set.
 *
 * Authentication uses rmMeta.passwd when set, otherwise the private key at
 * rmMeta.sshKeyPath (with rmMeta.passphrase). Once connected, `ver` is run to
 * detect a Windows host, which is not supported yet.
 * @param rmMeta connection information of the remote machine
 * @returns promise resolved once the connection is ready and osCommands is set;
 *          rejected on missing/invalid credentials, connection errors, a failed
 *          OS probe, or a Windows remote host
 */
public async initialize(rmMeta: RemoteMachineMeta): Promise<void> {
    const deferred: Deferred<void> = new Deferred<void>();

    const connectConfig: ConnectConfig = {
        host: rmMeta.ip,
        port: rmMeta.port,
        username: rmMeta.username,
        tryKeyboard: true
    };
    if (rmMeta.passwd !== undefined) {
        connectConfig.password = rmMeta.passwd;
    } else if (rmMeta.sshKeyPath !== undefined) {
        if (!fs.existsSync(rmMeta.sshKeyPath)) {
            //SSh key path is not a valid file, reject
            deferred.reject(new Error(`${rmMeta.sshKeyPath} does not exist.`));
            // Stop here: reading the missing key or connecting would only fail again
            return deferred.promise;
        }
        const privateKey: string = fs.readFileSync(rmMeta.sshKeyPath, 'utf8');

        connectConfig.privateKey = privateKey;
        connectConfig.passphrase = rmMeta.passphrase;
    } else {
        deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));
        // Do not attempt a connection without any credential
        return deferred.promise;
    }

    this.sshClient.on('ready', async () => {
        // Errors inside this async handler would otherwise become unhandled
        // rejections and leave the returned promise pending forever.
        try {
            // check OS type: windows or else
            const result = await this.execute("ver");
            if (result.exitCode == 0 && result.stdout.search("Windows") > -1) {
                // not implement Windows commands yet.
                throw new Error("not implement Windows commands yet.");
            }
            this.osCommands = new LinuxCommands();
            deferred.resolve();
        } catch (error) {
            // The executor is unusable; drop the connection instead of leaking it
            this.sshClient.end();
            deferred.reject(error);
        }
    }).on('error', (err: Error) => {
        // SSH connection error, reject with error message
        deferred.reject(new Error(err.message));
    }).on("keyboard-interactive", (name, instructions, lang, prompts, finish) => {
        finish([rmMeta.passwd]);
    }).connect(connectConfig);

    return deferred.promise;
}
// Close this executor by calling end() on its ssh2 client.
public close(): void {
this.sshClient.end();
}
// Property getter (read as `executor.getUsedConnectionNumber`, without parentheses)
// returning the current value of the usage counter.
public get getUsedConnectionNumber(): number {
return this.usedConnectionNumber;
}
// Increase the usage counter of this executor by one.
public addUsedConnectionNumber(): void {
    this.usedConnectionNumber++;
}
// Decrease the usage counter of this executor by one (no lower bound is enforced).
public minusUsedConnectionNumber(): void {
    this.usedConnectionNumber--;
}
/**
 * Create a folder on the remote machine.
 * @param folderName folder path handed to osCommands.createFolder
 * @param sharedFolder forwarded unchanged to osCommands.createFolder
 * @returns true only when the remote command exits with status 0
 */
public async createFolder(folderName: string, sharedFolder: boolean = false): Promise<boolean> {
    const commandText = this.osCommands && this.osCommands.createFolder(folderName, sharedFolder);
    const commandResult = await this.execute(commandText);
    // A non-zero exit status means the command failed; `>= 0` accepted every status.
    return commandResult.exitCode === 0;
}
/**
 * Run the permission command built by osCommands.allowPermission on the remote machine.
 * @param isRecursive forwarded unchanged to osCommands.allowPermission
 * @param folders paths the permission command is applied to
 * @returns true only when the remote command exits with status 0
 */
public async allowPermission(isRecursive: boolean = false, ...folders: string[]): Promise<boolean> {
    const commandText = this.osCommands && this.osCommands.allowPermission(isRecursive, ...folders);
    const commandResult = await this.execute(commandText);
    // A non-zero exit status means the command failed; `>= 0` accepted every status.
    return commandResult.exitCode === 0;
}
/**
 * Remove a folder on the remote machine.
 * @param folderName folder path handed to osCommands.removeFolder
 * @param isRecursive forwarded unchanged to osCommands.removeFolder
 * @param isForce forwarded unchanged to osCommands.removeFolder
 * @returns true only when the remote command exits with status 0
 */
public async removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): Promise<boolean> {
    const commandText = this.osCommands && this.osCommands.removeFolder(folderName, isRecursive, isForce);
    const commandResult = await this.execute(commandText);
    // A non-zero exit status means the command failed; `>= 0` accepted every status.
    return commandResult.exitCode === 0;
}
/**
 * Remove files on the remote machine.
 * @param folderOrFileName path handed to osCommands.removeFiles
 * @param filePattern forwarded unchanged to osCommands.removeFiles (empty by default)
 * @returns true only when the remote command exits with status 0
 */
public async removeFiles(folderOrFileName: string, filePattern: string = ""): Promise<boolean> {
    const commandText = this.osCommands && this.osCommands.removeFiles(folderOrFileName, filePattern);
    const commandResult = await this.execute(commandText);
    // A non-zero exit status means the command failed; `>= 0` accepted every status.
    return commandResult.exitCode === 0;
}
/**
 * Read the trailing lines of a remote file.
 * @param fileName remote file handed to osCommands.readLastLines
 * @param lineCount number of lines requested (defaults to 1)
 * @returns the command's stdout, or "" when there is no output
 */
public async readLastLines(fileName: string, lineCount: number = 1): Promise<string> {
    const tailCommand = this.osCommands && this.osCommands.readLastLines(fileName, lineCount);
    const outcome = await this.execute(tailCommand);
    const hasOutput: boolean = outcome !== undefined && outcome.stdout !== undefined && outcome.stdout.length > 0;
    return hasOutput ? outcome.stdout : "";
}
/**
 * Check whether the process recorded in a remote pid file is alive.
 * The command comes from osCommands.isProcessAliveCommand and its result is
 * interpreted by osCommands.isProcessAliveProcessOutput.
 * @param pidFileName remote pid file path
 * @returns the interpreted result, or false when no interpretation is available
 */
public async isProcessAlive(pidFileName: string): Promise<boolean> {
    const probeCommand = this.osCommands && this.osCommands.isProcessAliveCommand(pidFileName);
    const outcome = await this.execute(probeCommand);
    const alive = this.osCommands ? this.osCommands.isProcessAliveProcessOutput(outcome) : undefined;
    if (alive === undefined) {
        return false;
    }
    return alive;
}
/**
 * Run the command built by osCommands.killChildProcesses for the given pid file.
 * @param pidFileName remote pid file path
 * @returns true when the remote command exits with status 0
 */
public async killChildProcesses(pidFileName: string): Promise<boolean> {
    const killCommand = this.osCommands && this.osCommands.killChildProcesses(pidFileName);
    const { exitCode } = await this.execute(killCommand);
    return exitCode == 0;
}
/**
 * Extract a remote archive using the command built by osCommands.extractFile.
 * @param tarFileName remote archive path
 * @param targetFolder remote folder to extract into
 * @returns true when the remote command exits with status 0
 */
public async extractFile(tarFileName: string, targetFolder: string): Promise<boolean> {
    const extractCommand = this.osCommands && this.osCommands.extractFile(tarFileName, targetFolder);
    const { exitCode } = await this.execute(extractCommand);
    return exitCode == 0;
}
/**
 * Execute a script on the remote machine.
 * @param script script content or script path, as interpreted by osCommands.executeScript
 * @param isFile forwarded unchanged to osCommands.executeScript
 * @param isInteractive when true the command is sent through a shell channel instead of exec
 * @returns true when the remote command exits with status 0
 */
public async executeScript(script: string, isFile: boolean, isInteractive: boolean = false): Promise<boolean> {
    const scriptCommand = this.osCommands && this.osCommands.executeScript(script, isFile);
    const { exitCode } = await this.execute(scriptCommand, undefined, isInteractive);
    return exitCode == 0;
}
/**
 * Upload a single local file to the remote machine over SFTP.
 * @param localFilePath the path of local file
 * @param remoteFilePath the target path in remote machine
 * @returns promise resolved with true after the upload; rejected with the SFTP
 *          session error or the fastPut error
 */
public async copyFileToRemote(localFilePath: string, remoteFilePath: string): Promise<boolean> {
    const logger: Logger = getLogger();
    logger.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`);

    return new Promise<boolean>((resolve, reject) => {
        this.sshClient.sftp((sessionErr: Error, sftp: SFTPWrapper) => {
            const sessionFailed: boolean = sessionErr !== undefined && sessionErr !== null;
            if (sessionFailed) {
                logger.error(`copyFileToRemote: ${sessionErr.message}, ${localFilePath}, ${remoteFilePath}`);
                reject(sessionErr);
                return;
            }
            assert(sftp !== undefined);
            sftp.fastPut(localFilePath, remoteFilePath, (uploadErr: Error) => {
                // The sftp session is ended whether or not the upload succeeded
                sftp.end();
                if (uploadErr === undefined || uploadErr === null) {
                    resolve(true);
                } else {
                    reject(uploadErr);
                }
            });
        });
    });
}
/**
 * Copy files and directories in local directory recursively to remote directory.
 * The directory is packed into a temporary local tar.gz, uploaded to the remote
 * temp directory, extracted into remoteDirectory, and both archives are removed.
 * @param localDirectory local directory
 * @param remoteDirectory remote directory
 * @param remoteOS passed to getRemoteTmpDir() to locate the remote temp directory
 */
public async copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, remoteOS: string): Promise<void> {
    const tmpSuffix: string = uniqueString(5);
    const localTarPath: string = path.join(os.tmpdir(), `nni_tmp_local_${tmpSuffix}.tar.gz`);
    const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), `nni_tmp_remote_${tmpSuffix}.tar.gz`);

    // Compress files in local directory to experiment root directory
    await tarAdd(localTarPath, localDirectory);
    try {
        // Copy the compressed file to remoteDirectory and delete it
        await this.copyFileToRemote(localTarPath, remoteTarPath);
    } finally {
        // Remove the local archive even when the upload fails, so it does not pile up in the temp dir
        await execRemove(localTarPath);
    }
    // Decompress the remote compressed file in and delete it
    const extracted: boolean = await this.extractFile(remoteTarPath, remoteDirectory);
    if (!extracted) {
        // Control flow is unchanged (the archive is still cleaned up); the failure is only surfaced in the log
        getLogger().error(`copyDirectoryToRemote: failed to extract ${remoteTarPath} into ${remoteDirectory}`);
    }
    await this.removeFiles(remoteTarPath);
}
/**
 * Read a whole remote file over SFTP and return it as a string.
 * @param filePath remote file path
 * @returns promise resolved with the concatenated stream data; rejected with
 *          `SFTP error: ...` when the session or stream setup fails, or with an
 *          NNIError(NOT_FOUND) when the read stream emits an error
 */
public async getRemoteFileContent(filePath: string): Promise<string> {
    return new Promise<string>((resolve, reject) => {
        this.sshClient.sftp((sessionErr: Error, sftp: SFTPWrapper) => {
            if (sessionErr !== undefined && sessionErr !== null) {
                getLogger()
                    .error(`getRemoteFileContent: ${sessionErr.message}`);
                reject(new Error(`SFTP error: ${sessionErr.message}`));
                return;
            }
            try {
                let content: string = '';
                const reader: stream.Readable = sftp.createReadStream(filePath);
                reader.on('data', (chunk: Buffer | string) => {
                    content += chunk;
                });
                reader.on('error', (readErr: Error) => {
                    sftp.end();
                    reject(new NNIError(NNIErrorNames.NOT_FOUND, readErr.message));
                });
                reader.on('end', () => {
                    // sftp connection need to be released manually once operation is done
                    sftp.end();
                    resolve(content);
                });
            } catch (error) {
                getLogger()
                    .error(`getRemoteFileContent: ${error.message}`);
                sftp.end();
                reject(new Error(`SFTP error: ${error.message}`));
            }
        });
    });
}
/**
 * Run a command on the remote machine and collect its output.
 * @param command command text; undefined is sent to exec as an empty string
 * @param processOutput optional hook that may replace the result before it is returned
 * @param useShell when true the command is written to a shell channel followed by
 *                 `exit`; otherwise it is run through exec
 * @returns promise resolved on the channel's 'exit' event with stdout, stderr and
 *          exit code; rejected when the channel cannot be opened
 */
private async execute(command: string | undefined, processOutput: ((input: RemoteCommandResult) => RemoteCommandResult) | undefined = undefined, useShell: boolean = false): Promise<RemoteCommandResult> {
    const logger: Logger = getLogger();
    logger.debug(`remoteExeCommand: command: [${command}]`);

    return new Promise<RemoteCommandResult>((resolve, reject) => {
        let stdoutText: string = '';
        let stderrText: string = '';

        const onChannel = (openErr: Error, channel: ClientChannel): void => {
            if (openErr !== undefined && openErr !== null) {
                logger.error(`remoteExeCommand: ${openErr.message}`);
                reject(openErr);
                return;
            }

            // Accumulate both output streams until the remote process exits
            channel.on('data', (chunk: any) => {
                stdoutText += chunk;
            });
            channel.stderr.on('data', (chunk: any) => {
                stderrText += chunk;
            });
            channel.on('exit', (code: any) => {
                const exitCode: number = <number>code;
                logger.debug(`remoteExeCommand exit(${exitCode})\nstdout: ${stdoutText}\nstderr: ${stderrText}`);
                const rawResult = {
                    stdout: stdoutText,
                    stderr: stderrText,
                    exitCode: exitCode
                };
                resolve(processOutput != undefined ? processOutput(rawResult) : rawResult);
            });

            if (useShell) {
                // Shell channels take the command on stdin; `exit` makes the shell terminate afterwards
                channel.stdin.write(`${command}\n`);
                channel.end("exit\n");
            }
        };

        if (useShell) {
            this.sshClient.shell(onChannel);
        } else {
            this.sshClient.exec(command === undefined ? "" : command, onChannel);
        }
    });
}
}
export { ShellExecutor };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment