Unverified Commit e29b58a1 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #244 from microsoft/master

merge master
parents e0c2c0eb 4f88be1f
...@@ -173,12 +173,12 @@ install-python-modules: ...@@ -173,12 +173,12 @@ install-python-modules:
dev-install-python-modules: dev-install-python-modules:
#$(_INFO) Installing Python SDK $(_END) #$(_INFO) Installing Python SDK $(_END)
mkdir -p build mkdir -p build
ln -sf ../src/sdk/pynni/nni build/nni ln -sf ../src/sdk/pynni/nni build
ln -sf ../src/sdk/pynni/nnicli build/nnicli ln -sf ../src/sdk/pycli/nnicli build
ln -sf ../tools/nni_annotation build/nni_annotation ln -sf ../tools/nni_annotation build
ln -sf ../tools/nni_cmd build/nni_cmd ln -sf ../tools/nni_cmd build
ln -sf ../tools/nni_trial_tool build/nni_trial_tool ln -sf ../tools/nni_trial_tool build
ln -sf ../tools/nni_gpu_tool build/nni_gpu_tool ln -sf ../tools/nni_gpu_tool build
cp setup.py build/ cp setup.py build/
cp README.md build/ cp README.md build/
sed -ie 's/$(NNI_VERSION_TEMPLATE)/$(NNI_VERSION_VALUE)/' build/setup.py sed -ie 's/$(NNI_VERSION_TEMPLATE)/$(NNI_VERSION_VALUE)/' build/setup.py
...@@ -209,10 +209,12 @@ dev-install-node-modules: ...@@ -209,10 +209,12 @@ dev-install-node-modules:
ln -sf ${PWD}/src/nni_manager/dist $(NNI_PKG_FOLDER) ln -sf ${PWD}/src/nni_manager/dist $(NNI_PKG_FOLDER)
cp src/nni_manager/package.json $(NNI_PKG_FOLDER) cp src/nni_manager/package.json $(NNI_PKG_FOLDER)
sed -ie 's/$(NNI_VERSION_TEMPLATE)/$(NNI_VERSION_VALUE)/' $(NNI_PKG_FOLDER)/package.json sed -ie 's/$(NNI_VERSION_TEMPLATE)/$(NNI_VERSION_VALUE)/' $(NNI_PKG_FOLDER)/package.json
ln -sf ${PWD}/src/nni_manager/node_modules $(NNI_PKG_FOLDER)/node_modules ln -sf ${PWD}/src/nni_manager/node_modules $(NNI_PKG_FOLDER)
ln -sf ${PWD}/src/webui/build $(NNI_PKG_FOLDER)/static ln -sf ${PWD}/src/webui/build -t $(NNI_PKG_FOLDER)
ln -sf ${PWD}/src/nasui/build $(NASUI_PKG_FOLDER)/build mv $(NNI_PKG_FOLDER)/build $(NNI_PKG_FOLDER)/static
ln -sf ${PWD}/src/nasui/server.js $(NASUI_PKG_FOLDER)/server.js mkdir -p $(NASUI_PKG_FOLDER)
ln -sf ${PWD}/src/nasui/build $(NASUI_PKG_FOLDER)
ln -sf ${PWD}/src/nasui/server.js $(NASUI_PKG_FOLDER)
.PHONY: install-scripts .PHONY: install-scripts
install-scripts: install-scripts:
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
**NNI (Neural Network Intelligence)** is a lightweight but powerful toolkit to help users **automate** <a href="docs/en_US/FeatureEngineering/Overview.md">Feature Engineering</a>, <a href="docs/en_US/NAS/Overview.md">Neural Architecture Search</a>, <a href="docs/en_US/Tuner/BuiltinTuner.md">Hyperparameter Tuning</a> and <a href="docs/en_US/Compressor/Overview.md">Model Compression</a>. **NNI (Neural Network Intelligence)** is a lightweight but powerful toolkit to help users **automate** <a href="docs/en_US/FeatureEngineering/Overview.md">Feature Engineering</a>, <a href="docs/en_US/NAS/Overview.md">Neural Architecture Search</a>, <a href="docs/en_US/Tuner/BuiltinTuner.md">Hyperparameter Tuning</a> and <a href="docs/en_US/Compressor/Overview.md">Model Compression</a>.
The tool manages automated machine learning (AutoML) experiments, **dispatches and runs** experiments' trial jobs generated by tuning algorithms to search the best neural architecture and/or hyper-parameters in **different training environments** like <a href="docs/en_US/TrainingService/LocalMode.md">Local Machine</a>, <a href="docs/en_US/TrainingService/RemoteMachineMode.md">Remote Servers</a>, <a href="docs/en_US/TrainingService/PaiMode.md">OpenPAI</a>, <a href="docs/en_US/TrainingService/KubeflowMode.md">Kubeflow</a>, <a href="docs/en_US/TrainingService/FrameworkControllerMode.md">FrameworkController on K8S (AKS etc.)</a> and other cloud options. The tool manages automated machine learning (AutoML) experiments, **dispatches and runs** experiments' trial jobs generated by tuning algorithms to search the best neural architecture and/or hyper-parameters in **different training environments** like <a href="docs/en_US/TrainingService/LocalMode.md">Local Machine</a>, <a href="docs/en_US/TrainingService/RemoteMachineMode.md">Remote Servers</a>, <a href="docs/en_US/TrainingService/PaiMode.md">OpenPAI</a>, <a href="docs/en_US/TrainingService/KubeflowMode.md">Kubeflow</a>, <a href="docs/en_US/TrainingService/FrameworkControllerMode.md">FrameworkController on K8S (AKS etc.)</a>, <a href="docs/en_US/TrainingService/DLTSMode.md">DLWorkspace (aka. DLTS)</a> and other cloud options.
## **Who should consider using NNI** ## **Who should consider using NNI**
...@@ -25,7 +25,7 @@ The tool manages automated machine learning (AutoML) experiments, **dispatches a ...@@ -25,7 +25,7 @@ The tool manages automated machine learning (AutoML) experiments, **dispatches a
* Researchers and data scientists who want to easily **implement and experiment new AutoML algorithms**, may it be: hyperparameter tuning algorithm, neural architect search algorithm or model compression algorithm. * Researchers and data scientists who want to easily **implement and experiment new AutoML algorithms**, may it be: hyperparameter tuning algorithm, neural architect search algorithm or model compression algorithm.
* ML Platform owners who want to **support AutoML in their platform**. * ML Platform owners who want to **support AutoML in their platform**.
### **NNI v1.5 has been released! &nbsp;<a href="#nni-released-reminder"><img width="48" src="docs/img/release_icon.png"></a>** ### **[NNI v1.5 has been released!](https://github.com/microsoft/nni/releases) &nbsp;<a href="#nni-released-reminder"><img width="48" src="docs/img/release_icon.png"></a>**
## **NNI capabilities in a glance** ## **NNI capabilities in a glance**
...@@ -170,6 +170,7 @@ Within the following table, we summarized the current NNI capabilities, we are g ...@@ -170,6 +170,7 @@ Within the following table, we summarized the current NNI capabilities, we are g
<li><a href="docs/en_US/TrainingService/KubeflowMode.md">Kubeflow</a></li> <li><a href="docs/en_US/TrainingService/KubeflowMode.md">Kubeflow</a></li>
<li><a href="docs/en_US/TrainingService/FrameworkControllerMode.md">FrameworkController on K8S (AKS etc.)</a></li> <li><a href="docs/en_US/TrainingService/FrameworkControllerMode.md">FrameworkController on K8S (AKS etc.)</a></li>
</ul> </ul>
<ul><li><a href="docs/en_US/TrainingService/DLTSMode.md">DLWorkspace (aka. DLTS)</a></li>
</ul> </ul>
</td> </td>
</tr> </tr>
...@@ -334,10 +335,15 @@ With authors' permission, we listed a set of NNI usage examples and relevant art ...@@ -334,10 +335,15 @@ With authors' permission, we listed a set of NNI usage examples and relevant art
* **Blog (in Chinese)** - [A summary of NNI new capabilities in 2019](https://mp.weixin.qq.com/s/7_KRT-rRojQbNuJzkjFMuA) by @squirrelsc * **Blog (in Chinese)** - [A summary of NNI new capabilities in 2019](https://mp.weixin.qq.com/s/7_KRT-rRojQbNuJzkjFMuA) by @squirrelsc
## **Feedback** ## **Feedback**
* Discuss on the NNI [Gitter](https://gitter.im/Microsoft/nni?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) in NNI.
* [File an issue](https://github.com/microsoft/nni/issues/new/choose) on GitHub. * [File an issue](https://github.com/microsoft/nni/issues/new/choose) on GitHub.
* Ask a question with NNI tags on [Stack Overflow](https://stackoverflow.com/questions/tagged/nni?sort=Newest&edited=true). * Ask a question with NNI tags on [Stack Overflow](https://stackoverflow.com/questions/tagged/nni?sort=Newest&edited=true).
* Discuss in the NNI [Gitter](https://gitter.im/Microsoft/nni?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) chat room.
Join IM discussion groups:
|Gitter||WeChat|
|----|----|----|
|![image](https://user-images.githubusercontent.com/39592018/80665738-e0574a80-8acc-11ea-91bc-0836dc4cbf89.png)| OR |![image](https://user-images.githubusercontent.com/39592018/80665762-f06f2a00-8acc-11ea-8d22-e461e68e2d9b.png)|
## Related Projects ## Related Projects
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
**NNI (Neural Network Intelligence)** 是一个轻量但强大的工具包,帮助用户**自动**的进行[特征工程](docs/zh_CN/FeatureEngineering/Overview.md)[神经网络架构搜索](docs/zh_CN/NAS/Overview.md)[超参调优](docs/zh_CN/Tuner/BuiltinTuner.md)以及[模型压缩](docs/zh_CN/Compressor/Overview.md) **NNI (Neural Network Intelligence)** 是一个轻量但强大的工具包,帮助用户**自动**的进行[特征工程](docs/zh_CN/FeatureEngineering/Overview.md)[神经网络架构搜索](docs/zh_CN/NAS/Overview.md)[超参调优](docs/zh_CN/Tuner/BuiltinTuner.md)以及[模型压缩](docs/zh_CN/Compressor/Overview.md)
NNI 管理自动机器学习 (AutoML) 的 Experiment,**调度运行**由调优算法生成的 Trial 任务来找到最好的神经网络架构和/或超参,支持**各种训练环境**,如[本机](docs/zh_CN/TrainingService/LocalMode.md)[远程服务器](docs/zh_CN/TrainingService/RemoteMachineMode.md)[OpenPAI](docs/zh_CN/TrainingService/PaiMode.md)[Kubeflow](docs/zh_CN/TrainingService/KubeflowMode.md)[基于 K8S 的 FrameworkController(如,AKS 等)](docs/zh_CN/TrainingService/FrameworkControllerMode.md)以及其它云服务。 NNI 管理自动机器学习 (AutoML) 的 Experiment,**调度运行**由调优算法生成的 Trial 任务来找到最好的神经网络架构和/或超参,支持**各种训练环境**,如[本机](docs/zh_CN/TrainingService/LocalMode.md)[远程服务器](docs/zh_CN/TrainingService/RemoteMachineMode.md)[OpenPAI](docs/zh_CN/TrainingService/PaiMode.md)[Kubeflow](docs/zh_CN/TrainingService/KubeflowMode.md)[基于 K8S 的 FrameworkController(如,AKS 等)](docs/zh_CN/TrainingService/FrameworkControllerMode.md) [DLWorkspace (又称 DLTS)](docs/zh_CN/TrainingService/DLTSMode.md)其它云服务。
## **使用场景** ## **使用场景**
...@@ -19,7 +19,7 @@ NNI 管理自动机器学习 (AutoML) 的 Experiment,**调度运行**由调优 ...@@ -19,7 +19,7 @@ NNI 管理自动机器学习 (AutoML) 的 Experiment,**调度运行**由调优
* 想要更容易**实现或试验新的自动机器学习算法**的研究员或数据科学家,包括:超参调优算法,神经网络搜索算法以及模型压缩算法。 * 想要更容易**实现或试验新的自动机器学习算法**的研究员或数据科学家,包括:超参调优算法,神经网络搜索算法以及模型压缩算法。
* 在机器学习平台中**支持自动机器学习** * 在机器学习平台中**支持自动机器学习**
### **NNI v1.5 已发布! &nbsp;[<img width="48" src="docs/img/release_icon.png" />](#nni-released-reminder)** ### **[NNI v1.5 已发布!](https://github.com/microsoft/nni/releases) &nbsp;[<img width="48" src="docs/img/release_icon.png" />](#nni-released-reminder)**
## **NNI 功能一览** ## **NNI 功能一览**
...@@ -164,6 +164,7 @@ NNI 提供命令行工具以及友好的 WebUI 来管理训练的 Experiment。 ...@@ -164,6 +164,7 @@ NNI 提供命令行工具以及友好的 WebUI 来管理训练的 Experiment。
<li><a href="docs/zh_CN/TrainingService/KubeflowMode.md">Kubeflow</a></li> <li><a href="docs/zh_CN/TrainingService/KubeflowMode.md">Kubeflow</a></li>
<li><a href="docs/zh_CN/TrainingService/FrameworkControllerMode.md">基于 Kubernetes(AKS 等)的 FrameworkController</a></li> <li><a href="docs/zh_CN/TrainingService/FrameworkControllerMode.md">基于 Kubernetes(AKS 等)的 FrameworkController</a></li>
</ul> </ul>
<ul><li><a href="docs/zh_CN/TrainingService/DLTSMode.md">DLWorkspace (又称 DLTS)</a></li>
</ul> </ul>
</td> </td>
</tr> </tr>
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
For debugging NNI source code, your development environment should be under Ubuntu 16.04 (or above) system with python 3 and pip 3 installed, then follow the below steps. For debugging NNI source code, your development environment should be under Ubuntu 16.04 (or above) system with python 3 and pip 3 installed, then follow the below steps.
**1. Clone the source code** ### 1. Clone the source code
Run the command Run the command
...@@ -16,7 +16,7 @@ git clone https://github.com/Microsoft/nni.git ...@@ -16,7 +16,7 @@ git clone https://github.com/Microsoft/nni.git
to clone the source code to clone the source code
**2. Prepare the debug environment and install dependencies** ### 2. Prepare the debug environment and install dependencies
Change directory to the source code folder, then run the command Change directory to the source code folder, then run the command
...@@ -26,7 +26,7 @@ make install-dependencies ...@@ -26,7 +26,7 @@ make install-dependencies
to install the dependent tools for the environment to install the dependent tools for the environment
**3. Build source code** ### 3. Build source code
Run the command Run the command
...@@ -36,7 +36,7 @@ make build ...@@ -36,7 +36,7 @@ make build
to build the source code to build the source code
**4. Install NNI to development environment** ### 4. Install NNI to development environment
Run the command Run the command
...@@ -46,7 +46,7 @@ make dev-install ...@@ -46,7 +46,7 @@ make dev-install
to install the distribution content to development environment, and create cli scripts to install the distribution content to development environment, and create cli scripts
**5. Check if the environment is ready** ### 5. Check if the environment is ready
Now, you can try to start an experiment to check if your environment is ready. Now, you can try to start an experiment to check if your environment is ready.
For example, run the command For example, run the command
...@@ -57,9 +57,21 @@ nnictl create --config ~/nni/examples/trials/mnist-tfv1/config.yml ...@@ -57,9 +57,21 @@ nnictl create --config ~/nni/examples/trials/mnist-tfv1/config.yml
And open WebUI to check if everything is OK And open WebUI to check if everything is OK
**6. Redeploy** ### 6. Redeploy
After changing the code, you may need to redeploy. Whether you do depends on what kind of code was changed.
#### Python
Python changes do not need to be redeployed, but nnictl may need to be restarted.
#### TypeScript
* If `src/nni_manager` is changed, keep `yarn watch` running in that folder. It rebuilds the code as soon as files change.
* If `src/webui` or `src/nasui` is changed, use **step 3** to rebuild code.
The nnictl may need to be restarted.
After the code changes, use **step 3** to rebuild your codes, then the changes will take effect immediately.
--- ---
At last, wish you have a wonderful day. At last, wish you have a wonderful day.
......
...@@ -45,6 +45,7 @@ extensions = [ ...@@ -45,6 +45,7 @@ extensions = [
'sphinx_markdown_tables', 'sphinx_markdown_tables',
'sphinxarg.ext', 'sphinxarg.ext',
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
] ]
# Add mock modules # Add mock modules
......
...@@ -104,7 +104,7 @@ Tuner 有大量的文件、函数和类。 这里简单介绍最重要的文件 ...@@ -104,7 +104,7 @@ Tuner 有大量的文件、函数和类。 这里简单介绍最重要的文件
- `networkmorphism_tuner.py` 是使用 network morphism 算法的 Tuner。 - `networkmorphism_tuner.py` 是使用 network morphism 算法的 Tuner。
- `bayesian.py` 是用来基于已经搜索道德模型来预测未知模型指标的贝叶斯算法。 - `bayesian.py` 是用来基于已经搜索到的模型来预测未知模型指标的贝叶斯算法。
- `graph.py` 是元图数据结构。 类 Graph 表示了模型的神经网络图。 - `graph.py` 是元图数据结构。 类 Graph 表示了模型的神经网络图。
- Graph 从模型中抽取神经网络。 - Graph 从模型中抽取神经网络。
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
这是一个用于 NNI 神经网络架构搜索(NAS)接口的 Tuner。 它使用了 [ppo 算法](https://arxiv.org/abs/1707.06347)。 此实现继承了 [OpenAI 的 ppo2 实现](https://github.com/openai/baselines/tree/master/baselines/ppo2)的主要逻辑,并为 NAS 场景做了适配。 这是一个用于 NNI 神经网络架构搜索(NAS)接口的 Tuner。 它使用了 [ppo 算法](https://arxiv.org/abs/1707.06347)。 此实现继承了 [OpenAI 的 ppo2 实现](https://github.com/openai/baselines/tree/master/baselines/ppo2)的主要逻辑,并为 NAS 场景做了适配。
它能成功调优 [mnist-nas 示例](https://github.com/microsoft/nni/tree/master/examples/trials/mnist-nas),结果如下: mnist-nas 示例已调优,并得到以下结果: **注意:此示例正在重构中,以支持最新的 NAS 接口,完成后会重新发布示例代码。**
![](../../img/ppo_mnist.png) ![](../../img/ppo_mnist.png)
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
要调试 NNI 源代码,需要 Ubuntu 16.04 或更高版本系统的开发环境,并需要安装 Python 3 以及 pip 3,然后遵循以下步骤。 要调试 NNI 源代码,需要 Ubuntu 16.04 或更高版本系统的开发环境,并需要安装 Python 3 以及 pip 3,然后遵循以下步骤。
**1. 克隆源代码** ### 1. 克隆源代码
运行命令 运行命令
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
来克隆源代码 来克隆源代码
**2. 准备调试环境并安装依赖项** ### 2. 准备调试环境并安装依赖项
将目录切换到源码目录,然后运行命令 将目录切换到源码目录,然后运行命令
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
来安装环境的依赖项工具 来安装环境的依赖项工具
**3. 生成源代码** ### 3. 生成源代码
运行命令 运行命令
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
来生成源代码 来生成源代码
**4. 将 NNI 安装到开发环境中** ### 4. 将 NNI 安装到开发环境中
运行命令 运行命令
...@@ -42,7 +42,7 @@ ...@@ -42,7 +42,7 @@
来安装分发内容到开发环境,并创建 cli 脚本 来安装分发内容到开发环境,并创建 cli 脚本
**5. 检查环境是否正确** ### 5. 检查环境是否正确
Trial 启动 Experiment 来检查环境。 例如,运行命令 Trial 启动 Experiment 来检查环境。 例如,运行命令
...@@ -51,9 +51,20 @@ Trial 启动 Experiment 来检查环境。 例如,运行命令 ...@@ -51,9 +51,20 @@ Trial 启动 Experiment 来检查环境。 例如,运行命令
并打开网页界面查看 并打开网页界面查看
**6. 重新部署** ### 6. 重新部署
代码改动后,用**第 3 步**来重新生成代码,改动会立即生效。 代码更改后,可能需要重新部署。 这取决于更改了哪种代码。
#### Python
不需要重新部署,但可能需要重新启动 nnictl。
#### TypeScript
* 如果要更改 `src/nni_manager`,运行 `yarn watch` 可持续编译改动。 它将实时重建代码。
* 如果更改了 `src/webui` 或 `src/nasui`,请使用 **第 3 步** 来重建代码。
可能需要重新启动 nnictl。
* * * * * *
......
...@@ -8,6 +8,8 @@ https://github.com/pytorch/examples/blob/master/mnist/main.py ...@@ -8,6 +8,8 @@ https://github.com/pytorch/examples/blob/master/mnist/main.py
import os import os
import argparse import argparse
import logging import logging
from collections import OrderedDict
import nni import nni
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -26,13 +28,15 @@ class Net(nn.Module): ...@@ -26,13 +28,15 @@ class Net(nn.Module):
def __init__(self, hidden_size): def __init__(self, hidden_size):
super(Net, self).__init__() super(Net, self).__init__()
# two options of conv1 # two options of conv1
self.conv1 = LayerChoice([nn.Conv2d(1, 20, 5, 1), self.conv1 = LayerChoice(OrderedDict([
nn.Conv2d(1, 20, 3, 1)], ("conv5x5", nn.Conv2d(1, 20, 5, 1)),
key='first_conv') ("conv3x3", nn.Conv2d(1, 20, 3, 1))
]), key='first_conv')
# two options of mid_conv # two options of mid_conv
self.mid_conv = LayerChoice([nn.Conv2d(20, 20, 3, 1, padding=1), self.mid_conv = LayerChoice([
nn.Conv2d(20, 20, 5, 1, padding=2)], nn.Conv2d(20, 20, 3, 1, padding=1),
key='mid_conv') nn.Conv2d(20, 20, 5, 1, padding=2)
], key='mid_conv')
self.conv2 = nn.Conv2d(20, 50, 5, 1) self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, hidden_size) self.fc1 = nn.Linear(4*4*50, hidden_size)
self.fc2 = nn.Linear(hidden_size, 10) self.fc2 = nn.Linear(hidden_size, 10)
...@@ -167,7 +171,6 @@ def get_params(): ...@@ -167,7 +171,6 @@ def get_params():
parser.add_argument('--log_interval', type=int, default=1000, metavar='N', parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
help='how many batches to wait before logging training status') help='how many batches to wait before logging training status')
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
return args return args
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -43,17 +45,15 @@ class Node(nn.Module): ...@@ -43,17 +45,15 @@ class Node(nn.Module):
stride = 2 if i < num_downsample_connect else 1 stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i)) choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append( self.ops.append(
mutables.LayerChoice( mutables.LayerChoice(OrderedDict([
[ ("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
ops.PoolBN('max', channels, 3, stride, 1, affine=False), ("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False), ("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False), ("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
ops.SepConv(channels, channels, 3, stride, 1, affine=False), ("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
ops.SepConv(channels, channels, 5, stride, 2, affine=False), ("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), ("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False) ]), key=choice_keys[-1]))
],
key=choice_keys[-1]))
self.drop_path = ops.DropPath() self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id)) self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
......
...@@ -151,6 +151,5 @@ def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"): ...@@ -151,6 +151,5 @@ def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
for k, v in checkpoint["state_dict"].items(): for k, v in checkpoint["state_dict"].items():
if k.startswith("module."): if k.startswith("module."):
k = k[len("module."):] k = k[len("module."):]
k = re.sub(r"^(features.\d+).(\d+)", "\\1.choices.\\2", k)
result[k] = v result[k] = v
return result return result
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
"build": "tsc", "build": "tsc",
"test": "nyc mocha -r ts-node/register -t 15000 --recursive **/*.test.ts --exclude node_modules/**/**/*.test.ts --colors", "test": "nyc mocha -r ts-node/register -t 15000 --recursive **/*.test.ts --exclude node_modules/**/**/*.test.ts --colors",
"start": "node dist/main.js", "start": "node dist/main.js",
"watch": "tsc --watch",
"eslint": "npx eslint ./ --ext .ts" "eslint": "npx eslint ./ --ext .ts"
}, },
"license": "MIT", "license": "MIT",
......
...@@ -38,7 +38,9 @@ class DLTSTrainingService implements TrainingService { ...@@ -38,7 +38,9 @@ class DLTSTrainingService implements TrainingService {
private versionCheck: boolean = true; private versionCheck: boolean = true;
private logCollection: string = 'none'; private logCollection: string = 'none';
private isMultiPhase: boolean = false; private isMultiPhase: boolean = false;
private dltsRestServerHost: string;
private dltsRestServerPort?: number; private dltsRestServerPort?: number;
private jobMode: boolean;
private readonly trialJobsMap: Map<string, DLTSTrialJobDetail>; private readonly trialJobsMap: Map<string, DLTSTrialJobDetail>;
private nniManagerIpConfig?: NNIManagerIpConfig; private nniManagerIpConfig?: NNIManagerIpConfig;
...@@ -51,7 +53,9 @@ class DLTSTrainingService implements TrainingService { ...@@ -51,7 +53,9 @@ class DLTSTrainingService implements TrainingService {
this.trialJobsMap = new Map(); this.trialJobsMap = new Map();
this.jobQueue = []; this.jobQueue = [];
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.log.info('Construct DLTS training service.'); this.dltsRestServerHost = getIPV4Address();
this.jobMode = 'DLTS_JOB_ID' in process.env;
this.log.info(`Construct DLTS training service in ${this.jobMode ? 'job mode' : 'local mode'}.`);
} }
public async run(): Promise<void> { public async run(): Promise<void> {
...@@ -60,12 +64,70 @@ class DLTSTrainingService implements TrainingService { ...@@ -60,12 +64,70 @@ class DLTSTrainingService implements TrainingService {
await restServer.start(); await restServer.start();
restServer.setEnableVersionCheck = this.versionCheck; restServer.setEnableVersionCheck = this.versionCheck;
this.log.info(`DLTS Training service rest server listening on: ${restServer.endPoint}`); this.log.info(`DLTS Training service rest server listening on: ${restServer.endPoint}`);
if (this.jobMode) {
await this.exposeRestServerPort(restServer.clusterRestServerPort);
} else {
this.dltsRestServerPort = restServer.clusterRestServerPort
}
await Promise.all([ await Promise.all([
this.statusCheckingLoop(), this.statusCheckingLoop(),
this.submitJobLoop()]); this.submitJobLoop()]);
this.log.info('DLTS training service exit.'); this.log.info('DLTS training service exit.');
} }
/**
 * Ask the DLTS dashboard to expose the given pod port of the current job and
 * wait until that endpoint is reported as running.
 *
 * The job's endpoint list is polled once per second. While no endpoint with a
 * matching `podPort` exists, a request to create one (named "nni-rest-server")
 * is posted. Once the matching endpoint has status 'running', its `nodeName`
 * and `port` are stored in `dltsRestServerHost` / `dltsRestServerPort` and the
 * method returns.
 *
 * @param port pod-side port of the REST server to expose
 * @throws Error when the cluster config has not been set; rejects with the
 *         underlying error when a dashboard request fails
 */
private async exposeRestServerPort(port: number): Promise<void> {
    if (this.dltsClusterConfig == null) {
        throw Error('Cluster config is not set');
    }
    const { dashboard, cluster, email, password } = this.dltsClusterConfig;
    const jobId = process.env['DLTS_JOB_ID'] + '';
    const uri = `${dashboard}api/clusters/${cluster}/jobs/${jobId}/endpoints`;
    const qs = { email, password };

    // GET the endpoints currently known for this job.
    const fetchEndpoints = () => new Promise((resolve, reject) => {
        request.get(uri, { qs, json: true }, (error, _response, body) => {
            if (error) {
                reject(error);
                return;
            }
            resolve(body);
        });
    });

    // POST a request asking the dashboard to expose `port`.
    const requestExposure = () => new Promise<void>((resolve, reject) => {
        const payload = {
            endpoints: [{
                name: "nni-rest-server",
                podPort: port
            }]
        };
        request.post(uri, { qs, json: true, body: payload }, (error) => {
            if (error) {
                reject(error);
                return;
            }
            resolve();
        });
    });

    for (;;) {
        this.log.debug('Checking endpoints');
        const endpoints = await fetchEndpoints();
        this.log.debug('Endpoints: %o', endpoints);

        if (Array.isArray(endpoints)) {
            const restServerEndpoint = endpoints.find(({ podPort }) => podPort === port);
            if (restServerEndpoint == null) {
                this.log.debug('Exposing %d', port);
                await requestExposure();
            } else if (restServerEndpoint['status'] === 'running') {
                // We get an exposed restserver port
                this.dltsRestServerHost = restServerEndpoint['nodeName'];
                this.dltsRestServerPort = restServerEndpoint['port'];
                return;
            }
        }

        // Wait one second before polling again.
        await new Promise(resolve => setTimeout(resolve, 1000));
    }
}
private async statusCheckingLoop(): Promise<void> { private async statusCheckingLoop(): Promise<void> {
while (!this.stopping) { while (!this.stopping) {
const updateDLTSTrialJobs: Promise<void>[] = []; const updateDLTSTrialJobs: Promise<void>[] = [];
...@@ -400,7 +462,7 @@ class DLTSTrainingService implements TrainingService { ...@@ -400,7 +462,7 @@ class DLTSTrainingService implements TrainingService {
); );
} }
// tslint:disable-next-line: strict-boolean-expressions // tslint:disable-next-line: strict-boolean-expressions
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address(); const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : this.dltsRestServerHost;
const version: string = this.versionCheck ? await getVersion() : ''; const version: string = this.versionCheck ? await getVersion() : '';
const nniDLTSTrialCommand: string = String.Format( const nniDLTSTrialCommand: string = String.Format(
DLTS_TRIAL_COMMAND_FORMAT, DLTS_TRIAL_COMMAND_FORMAT,
......
...@@ -100,12 +100,15 @@ class PAIK8STrainingService extends PAITrainingService { ...@@ -100,12 +100,15 @@ class PAIK8STrainingService extends PAITrainingService {
} }
} }
//TODO: update trial parameters // update trial parameters for multi-phase
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> { public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId); const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) { if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`); throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
} }
// Write file content ( parameter.cfg ) to working folders
await this.writeParameterFile(trialJobDetail.logPath, form.hyperParameters);
return trialJobDetail; return trialJobDetail;
} }
...@@ -230,24 +233,20 @@ class PAIK8STrainingService extends PAITrainingService { ...@@ -230,24 +233,20 @@ class PAIK8STrainingService extends PAITrainingService {
this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort; this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;
// Step 1. Prepare PAI job configuration // Step 1. Prepare PAI job configuration
const trialLocalFolder: string = path.join(this.paiTrialConfig.nniManagerNFSMountPath, this.experimentId, trialJobId);
//create trial local working folder locally. //create trial local working folder locally.
await execMkdir(trialLocalFolder); await execMkdir(trialJobDetail.logPath);
const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT; const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local files // Write NNI installation file to local files
await fs.promises.writeFile(path.join(trialLocalFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialJobDetail.logPath, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local working folders // Write file content ( parameter.cfg ) to local working folders
if (trialJobDetail.form !== undefined) { if (trialJobDetail.form !== undefined) {
await fs.promises.writeFile( await this.writeParameterFile(trialJobDetail.logPath, trialJobDetail.form.hyperParameters);
path.join(trialLocalFolder, generateParamFileName(trialJobDetail.form.hyperParameters)),
trialJobDetail.form.hyperParameters.value, { encoding: 'utf8' }
);
} }
//Copy codeDir files to local working folder //Copy codeDir files to local working folder
await execCopydir(this.paiTrialConfig.codeDir, trialLocalFolder); await execCopydir(this.paiTrialConfig.codeDir, trialJobDetail.logPath);
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address(); const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const version: string = this.versionCheck ? await getVersion() : ''; const version: string = this.versionCheck ? await getVersion() : '';
...@@ -298,6 +297,11 @@ class PAIK8STrainingService extends PAITrainingService { ...@@ -298,6 +297,11 @@ class PAIK8STrainingService extends PAITrainingService {
return deferred.promise; return deferred.promise;
} }
/**
 * Write a trial's hyper-parameter payload into `directory`.
 *
 * The file name is produced by `generateParamFileName(hyperParameters)` and
 * the content is `hyperParameters.value`, encoded as UTF-8.
 *
 * @param directory folder that receives the parameter file
 * @param hyperParameters hyper-parameters whose `value` is written out
 */
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
    const fileName: string = generateParamFileName(hyperParameters);
    await fs.promises.writeFile(path.join(directory, fileName), hyperParameters.value, { encoding: 'utf8' });
}
} }
export { PAIK8STrainingService }; export { PAIK8STrainingService };
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { OsCommands } from "../osCommands";
import { RemoteCommandResult } from "../remoteMachineData";
/**
 * Linux flavour of {@link OsCommands}.
 *
 * Apart from `isProcessAliveProcessOutput`, every method only builds and
 * returns a shell command string; nothing is executed here. Path arguments
 * are wrapped in single quotes in the generated command.
 */
class LinuxCommands extends OsCommands {
    /**
     * Build a `mkdir -p` command.
     *
     * @param folderName folder to create
     * @param sharedFolder when true, `umask 0` is issued first so the new
     *        folder is created without a permission mask
     */
    public createFolder(folderName: string, sharedFolder: boolean = false): string {
        const mkdir = `mkdir -p '${folderName}'`;
        return sharedFolder ? `umask 0; ${mkdir}` : mkdir;
    }

    /**
     * Build a `chmod 777` command for one or more folders.
     *
     * @param isRecursive add `-R` to apply the mode recursively
     * @param folders folders to open up; each is quoted separately
     */
    public allowPermission(isRecursive: boolean = false, ...folders: string[]): string {
        // Joining with "' '" plus the outer quotes yields 'a' 'b' 'c'.
        const folderString = folders.join("' '");
        const recursiveFlag = isRecursive ? '-R ' : '';
        return `chmod 777 ${recursiveFlag}'${folderString}'`;
    }

    /**
     * Build an `rm` command for a folder.
     *
     * @param folderName folder to remove
     * @param isRecursive use `-r`; otherwise `-d` is used when a flag is emitted
     * @param isForce append `f` to the flag
     */
    public removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): string {
        let flags = '';
        if (isForce || isRecursive) {
            flags = `-${isRecursive ? 'r' : 'd'}${isForce ? 'f' : ''} `;
        }
        return `rm ${flags}'${folderName}'`;
    }

    /**
     * Build an `rm` command for `filePattern` inside `folderName`.
     *
     * NOTE(review): the joined path is single-quoted, so the shell will not
     * expand wildcards in `filePattern` — confirm callers pass literal names.
     */
    public removeFiles(folderName: string, filePattern: string): string {
        const files = this.joinPath(folderName, filePattern);
        return `rm '${files}'`;
    }

    /**
     * Build a `tail` command that prints the last `lineCount` lines of a file.
     */
    public readLastLines(fileName: string, lineCount: number = 1): string {
        return `tail -n ${lineCount} '${fileName}'`;
    }

    /**
     * Build a `kill -0` command against the pid stored in `pidFileName`.
     * The pid is read with `cat` via command substitution.
     */
    public isProcessAliveCommand(pidFileName: string): string {
        return `kill -0 \`cat '${pidFileName}'\``;
    }

    /**
     * Interpret the result of the command built by `isProcessAliveCommand`.
     *
     * @returns true only when the command exited with code 0
     */
    public isProcessAliveProcessOutput(commandResult: RemoteCommandResult): boolean {
        return commandResult.exitCode === 0;
    }

    /**
     * Build a `pkill -P` command that targets children of the pid stored in
     * `pidFileName`.
     */
    public killChildProcesses(pidFileName: string): string {
        return `pkill -P \`cat '${pidFileName}'\``;
    }

    /**
     * Build a `tar -oxzf` command extracting `tarFileName` into `targetFolder`.
     */
    public extractFile(tarFileName: string, targetFolder: string): string {
        return `tar -oxzf '${tarFileName}' -C '${targetFolder}'`;
    }

    /**
     * Build a command that runs a script with bash.
     *
     * @param script path of a script file, or inline script text
     * @param isFile true when `script` is a file path; false when it is inline
     *        text to be passed to `bash -c`
     */
    public executeScript(script: string, isFile: boolean): string {
        if (isFile) {
            return `bash '${script}'`;
        }
        // Escape EVERY double quote. A string pattern passed to
        // String.prototype.replace only replaces the first match, which left
        // later quotes unescaped and terminated the bash -c "..." argument early.
        const escaped = script.replace(/"/g, '\\"');
        return `bash -c "${escaped}"`;
    }
}
export { LinuxCommands };
...@@ -8,7 +8,7 @@ import { getLogger, Logger } from '../../common/log'; ...@@ -8,7 +8,7 @@ import { getLogger, Logger } from '../../common/log';
import { randomSelect } from '../../common/utils'; import { randomSelect } from '../../common/utils';
import { GPUInfo } from '../common/gpuData'; import { GPUInfo } from '../common/gpuData';
import { import {
parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, SSHClientManager parseGpuIndices, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail, ScheduleResultType, ExecutorManager
} from './remoteMachineData'; } from './remoteMachineData';
type SCHEDULE_POLICY_NAME = 'random' | 'round-robin'; type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
...@@ -18,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin'; ...@@ -18,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
*/ */
export class GPUScheduler { export class GPUScheduler {
private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>; private readonly machineExecutorMap: Map<RemoteMachineMeta, ExecutorManager>;
private readonly log: Logger = getLogger(); private readonly log: Logger = getLogger();
private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin'; private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin';
private roundRobinIndex: number = 0; private roundRobinIndex: number = 0;
...@@ -26,12 +26,12 @@ export class GPUScheduler { ...@@ -26,12 +26,12 @@ export class GPUScheduler {
/** /**
* Constructor * Constructor
* @param machineSSHClientMap map from remote machine to sshClient * @param machineExecutorMap map from remote machine to executor
*/ */
constructor(machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>) { constructor(machineExecutorMap: Map<RemoteMachineMeta, ExecutorManager>) {
assert(machineSSHClientMap.size > 0); assert(machineExecutorMap.size > 0);
this.machineSSHClientMap = machineSSHClientMap; this.machineExecutorMap = machineExecutorMap;
this.configuredRMs = Array.from(machineSSHClientMap.keys()); this.configuredRMs = Array.from(machineExecutorMap.keys());
} }
/** /**
...@@ -43,7 +43,7 @@ export class GPUScheduler { ...@@ -43,7 +43,7 @@ export class GPUScheduler {
requiredGPUNum = 0; requiredGPUNum = 0;
} }
assert(requiredGPUNum >= 0); assert(requiredGPUNum >= 0);
const allRMs: RemoteMachineMeta[] = Array.from(this.machineSSHClientMap.keys()); const allRMs: RemoteMachineMeta[] = Array.from(this.machineExecutorMap.keys());
assert(allRMs.length > 0); assert(allRMs.length > 0);
// Step 1: Check if required GPU number not exceeds the total GPU number in all machines // Step 1: Check if required GPU number not exceeds the total GPU number in all machines
...@@ -135,7 +135,7 @@ export class GPUScheduler { ...@@ -135,7 +135,7 @@ export class GPUScheduler {
*/ */
private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> { private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> {
const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>(); const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
this.machineSSHClientMap.forEach((sshClientManager: SSHClientManager, rmMeta: RemoteMachineMeta) => { this.machineExecutorMap.forEach((executorManager: ExecutorManager, rmMeta: RemoteMachineMeta) => {
// Assign total GPU count as init available GPU number // Assign total GPU count as init available GPU number
if (rmMeta.gpuSummary !== undefined) { if (rmMeta.gpuSummary !== undefined) {
const availableGPUs: GPUInfo[] = []; const availableGPUs: GPUInfo[] = [];
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { RemoteCommandResult } from "./remoteMachineData";
/**
 * Base class for builders of OS-specific shell command text.
 *
 * Each abstract method returns the text of a command (it does not run anything);
 * `isProcessAliveProcessOutput` interprets the result of a command that was run.
 * `joinPath` is shared by all subclasses.
 */
abstract class OsCommands {

    /** Separator placed between path segments by `joinPath`. */
    protected pathSpliter: string = '/';
    /**
     * Matches a run of two or more separators.
     * The global flag is required so that `joinPath` collapses every such run,
     * not just the first one in the string.
     */
    protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`, 'g');

    public abstract createFolder(folderName: string, sharedFolder: boolean): string;
    public abstract allowPermission(isRecursive: boolean, ...folders: string[]): string;
    public abstract removeFolder(folderName: string, isRecursive: boolean, isForce: boolean): string;
    public abstract removeFiles(folderOrFileName: string, filePattern: string): string;
    public abstract readLastLines(fileName: string, lineCount: number): string;
    public abstract isProcessAliveCommand(pidFileName: string): string;
    public abstract isProcessAliveProcessOutput(result: RemoteCommandResult): boolean;
    public abstract killChildProcesses(pidFileName: string): string;
    public abstract extractFile(tarFileName: string, targetFolder: string): string;
    public abstract executeScript(script: string, isFile: boolean): string;

    /**
     * Join path segments with `pathSpliter`.
     * Empty segments are dropped and every run of repeated separators is collapsed to one.
     * @param paths path segments, in order
     * @returns the joined path, or `.` when no non-empty segment was given
     */
    public joinPath(...paths: string[]): string {
        const joined: string = paths.filter((path: string) => path !== '').join(this.pathSpliter);
        if (joined === '') {
            return '.';
        }

        return joined.replace(this.multiplePathSpliter, this.pathSpliter);
    }
}
export { OsCommands };
...@@ -3,11 +3,9 @@ ...@@ -3,11 +3,9 @@
'use strict'; 'use strict';
import * as fs from 'fs'; import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred';
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUInfo, GPUSummary } from '../common/gpuData'; import { GPUInfo, GPUSummary } from '../common/gpuData';
import { ShellExecutor } from './shellExecutor';
/** /**
* Metadata of remote machine for configuration and status query * Metadata of remote machine for configuration and status query
...@@ -72,7 +70,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -72,7 +70,7 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public gpuIndices: GPUInfo[]; public gpuIndices: GPUInfo[];
constructor(id: string, status: TrialJobStatus, submitTime: number, constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: TrialJobApplicationForm) { workingDirectory: string, form: TrialJobApplicationForm) {
this.id = id; this.id = id;
this.status = status; this.status = status;
this.submitTime = submitTime; this.submitTime = submitTime;
...@@ -84,149 +82,88 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail { ...@@ -84,149 +82,88 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
} }
/** /**
* The remote machine ssh client used for trial and gpu detector * The remote machine executor manager
*/ */
export class SSHClient { export class ExecutorManager {
private readonly sshClient: Client; private readonly executorArray: ShellExecutor[];
private usedConnectionNumber: number; //count the connection number of every client
constructor(sshClient: Client, usedConnectionNumber: number) {
this.sshClient = sshClient;
this.usedConnectionNumber = usedConnectionNumber;
}
public get getSSHClientInstance(): Client {
return this.sshClient;
}
public get getUsedConnectionNumber(): number {
return this.usedConnectionNumber;
}
public addUsedConnectionNumber(): void {
this.usedConnectionNumber += 1;
}
public minusUsedConnectionNumber(): void {
this.usedConnectionNumber -= 1;
}
}
/**
* The remote machine ssh client manager
*/
export class SSHClientManager {
private readonly sshClientArray: SSHClient[];
private readonly maxTrialNumberPerConnection: number; private readonly maxTrialNumberPerConnection: number;
private readonly rmMeta: RemoteMachineMeta; private readonly rmMeta: RemoteMachineMeta;
constructor(sshClientArray: SSHClient[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) { constructor(executorArray: ShellExecutor[], maxTrialNumberPerConnection: number, rmMeta: RemoteMachineMeta) {
this.rmMeta = rmMeta; this.rmMeta = rmMeta;
this.sshClientArray = sshClientArray; this.executorArray = executorArray;
this.maxTrialNumberPerConnection = maxTrialNumberPerConnection; this.maxTrialNumberPerConnection = maxTrialNumberPerConnection;
} }
/** /**
* find an available ssh client in ssh array, if no ssh client available, return undefined * find an available executor, if no executor available, return a new one
*/ */
public async getAvailableSSHClient(): Promise<Client> { public async getAvailableExecutor(): Promise<ShellExecutor> {
const deferred: Deferred<Client> = new Deferred<Client>(); for (const index of this.executorArray.keys()) {
for (const index of this.sshClientArray.keys()) { const connectionNumber: number = this.executorArray[index].getUsedConnectionNumber;
const connectionNumber: number = this.sshClientArray[index].getUsedConnectionNumber;
if (connectionNumber < this.maxTrialNumberPerConnection) { if (connectionNumber < this.maxTrialNumberPerConnection) {
this.sshClientArray[index].addUsedConnectionNumber(); this.executorArray[index].addUsedConnectionNumber();
deferred.resolve(this.sshClientArray[index].getSSHClientInstance);
return deferred.promise; return this.executorArray[index];
} }
} }
//init a new ssh client if could not get an available one //init a new executor if could not get an available one
return this.initNewSSHClient(); return await this.initNewShellExecutor();
} }
/** /**
* add a new ssh client to sshClientArray * add a new executor to executorArray
* @param sshClient SSH Client * @param executor ShellExecutor
*/ */
public addNewSSHClient(client: Client): void { public addNewShellExecutor(executor: ShellExecutor): void {
this.sshClientArray.push(new SSHClient(client, 1)); this.executorArray.push(executor);
} }
/** /**
* first ssh client instance is used for gpu collector and host job * first executor instance is used for gpu collector and host job
*/ */
public getFirstSSHClient(): Client { public getFirstExecutor(): ShellExecutor {
return this.sshClientArray[0].getSSHClientInstance; return this.executorArray[0];
} }
/** /**
* close all of ssh client * close all of executor
*/ */
public closeAllSSHClient(): void { public closeAllExecutor(): void {
for (const sshClient of this.sshClientArray) { for (const executor of this.executorArray) {
sshClient.getSSHClientInstance.end(); executor.close();
} }
} }
/** /**
* retrieve resource, minus a number for given ssh client * retrieve resource, minus a number for given executor
* @param client SSH Client * @param executor executor
*/ */
public releaseConnection(client: Client | undefined): void { public releaseConnection(executor: ShellExecutor | undefined): void {
if (client === undefined) { if (executor === undefined) {
throw new Error(`could not release a undefined ssh client`); throw new Error(`could not release a undefined executor`);
} }
for (const index of this.sshClientArray.keys()) { for (const index of this.executorArray.keys()) {
if (this.sshClientArray[index].getSSHClientInstance === client) { if (this.executorArray[index] === executor) {
this.sshClientArray[index].minusUsedConnectionNumber(); this.executorArray[index].minusUsedConnectionNumber();
break; break;
} }
} }
} }
/** /**
* Create a new ssh connection client and initialize it * Create a new connection executor and initialize it
*/ */
private initNewSSHClient(): Promise<Client> { private async initNewShellExecutor(): Promise<ShellExecutor> {
const deferred: Deferred<Client> = new Deferred<Client>(); const executor = new ShellExecutor();
const conn: Client = new Client(); await executor.initialize(this.rmMeta);
const connectConfig: ConnectConfig = { return executor;
host: this.rmMeta.ip,
port: this.rmMeta.port,
username: this.rmMeta.username,
tryKeyboard: true };
if (this.rmMeta.passwd !== undefined) {
connectConfig.password = this.rmMeta.passwd;
} else if (this.rmMeta.sshKeyPath !== undefined) {
if (!fs.existsSync(this.rmMeta.sshKeyPath)) {
//SSh key path is not a valid file, reject
deferred.reject(new Error(`${this.rmMeta.sshKeyPath} does not exist.`));
}
const privateKey: string = fs.readFileSync(this.rmMeta.sshKeyPath, 'utf8');
connectConfig.privateKey = privateKey;
connectConfig.passphrase = this.rmMeta.passphrase;
} else {
deferred.reject(new Error(`No valid passwd or sshKeyPath is configed.`));
}
conn.on('ready', () => {
this.addNewSSHClient(conn);
deferred.resolve(conn);
})
.on('error', (err: Error) => {
// SSH connection error, reject with error message
deferred.reject(new Error(err.message));
}).on("keyboard-interactive", (name, instructions, lang, prompts, finish) => {
finish([this.rmMeta.passwd]);
})
.connect(connectConfig);
return deferred.promise;
} }
} }
export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType}; export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType };
export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string}; export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string };
export enum ScheduleResultType { export enum ScheduleResultType {
// Schedule succeeded // Schedule succeeded
...@@ -240,7 +177,7 @@ export enum ScheduleResultType { ...@@ -240,7 +177,7 @@ export enum ScheduleResultType {
} }
export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string = export const REMOTEMACHINE_TRIAL_COMMAND_FORMAT: string =
`#!/bin/bash `#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} \ export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} \
NNI_TRIAL_SEQ_ID={4} export MULTI_PHASE={5} NNI_TRIAL_SEQ_ID={4} export MULTI_PHASE={5}
cd $NNI_SYS_DIR cd $NNI_SYS_DIR
...@@ -251,7 +188,7 @@ python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8 ...@@ -251,7 +188,7 @@ python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8
echo $? \`date +%s%3N\` >{12}`; echo $? \`date +%s%3N\` >{12}`;
export const HOST_JOB_SHELL_FORMAT: string = export const HOST_JOB_SHELL_FORMAT: string =
`#!/bin/bash `#!/bin/bash
cd {0} cd {0}
echo $$ >{1} echo $$ >{1}
eval {2} >stdout 2>stderr eval {2} >stdout 2>stderr
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as assert from 'assert';
import * as os from 'os';
import * as path from 'path';
import * as fs from 'fs';
import { Client, ClientChannel, SFTPWrapper, ConnectConfig } from 'ssh2';
import { Deferred } from "ts-deferred";
import { RemoteCommandResult, RemoteMachineMeta } from "./remoteMachineData";
import * as stream from 'stream';
import { OsCommands } from "./osCommands";
import { LinuxCommands } from "./extends/linuxCommands";
import { getLogger, Logger } from '../../common/log';
import { NNIError, NNIErrorNames } from '../../common/errors';
import { execRemove, tarAdd } from '../common/util';
import { getRemoteTmpDir, uniqueString, unixPathJoin } from '../../common/utils';
/**
 * Runs commands and transfers files on a remote machine through one SSH client.
 *
 * Command text is produced by an `OsCommands` implementation that is chosen in
 * `initialize()`; the public methods run that text via `execute()` and translate
 * the result. The class also carries a usage counter that callers increment and
 * decrement to track how many users share this connection.
 */
class ShellExecutor {
    private sshClient: Client = new Client();
    private osCommands: OsCommands | undefined;
    private usedConnectionNumber: number = 0; //count the connection number of every client

    protected pathSpliter: string = '/';
    protected multiplePathSpliter: RegExp = new RegExp(`\\${this.pathSpliter}{2,}`);

    /**
     * Open the SSH connection described by `rmMeta` and select the command set for the remote OS.
     * @param rmMeta remote machine description (ip, port, username and either passwd or sshKeyPath)
     * @returns a promise that resolves once the connection is ready and the OS has been probed
     * @throws Error when no credential is configured, the key file is missing,
     *         the connection fails, or the remote machine reports Windows
     */
    public async initialize(rmMeta: RemoteMachineMeta): Promise<void> {
        const connectConfig: ConnectConfig = {
            host: rmMeta.ip,
            port: rmMeta.port,
            username: rmMeta.username,
            tryKeyboard: true
        };
        // Validate credentials BEFORE touching the client: failing here must not start a connection.
        if (rmMeta.passwd !== undefined) {
            connectConfig.password = rmMeta.passwd;
        } else if (rmMeta.sshKeyPath !== undefined) {
            if (!fs.existsSync(rmMeta.sshKeyPath)) {
                //SSh key path is not a valid file, reject
                throw new Error(`${rmMeta.sshKeyPath} does not exist.`);
            }
            const privateKey: string = fs.readFileSync(rmMeta.sshKeyPath, 'utf8');

            connectConfig.privateKey = privateKey;
            connectConfig.passphrase = rmMeta.passphrase;
        } else {
            throw new Error(`No valid passwd or sshKeyPath is configed.`);
        }

        const deferred: Deferred<void> = new Deferred<void>();
        this.sshClient.on('ready', async () => {
            // The event emitter ignores the promise returned by this handler, so every
            // failure has to be forwarded to `deferred` explicitly or the caller would wait forever.
            try {
                // check OS type: windows or else
                const result = await this.execute("ver");
                if (result.exitCode === 0 && result.stdout.search("Windows") > -1) {
                    // not implement Windows commands yet.
                    throw new Error("not implement Windows commands yet.");
                }
                this.osCommands = new LinuxCommands();
                deferred.resolve();
            } catch (error) {
                deferred.reject(error);
            }
        }).on('error', (err: Error) => {
            // SSH connection error, reject with error message
            deferred.reject(new Error(err.message));
        }).on("keyboard-interactive", (name, instructions, lang, prompts, finish) => {
            finish([rmMeta.passwd]);
        }).connect(connectConfig);

        return deferred.promise;
    }

    /** End the underlying SSH client. */
    public close(): void {
        this.sshClient.end();
    }

    /** Current value of the usage counter. */
    public get getUsedConnectionNumber(): number {
        return this.usedConnectionNumber;
    }

    /** Increase the usage counter by one. */
    public addUsedConnectionNumber(): void {
        this.usedConnectionNumber += 1;
    }

    /** Decrease the usage counter by one. */
    public minusUsedConnectionNumber(): void {
        this.usedConnectionNumber -= 1;
    }

    /**
     * Run the OS-specific "create folder" command.
     * @param folderName folder to create
     * @param sharedFolder forwarded to the command builder
     */
    public async createFolder(folderName: string, sharedFolder: boolean = false): Promise<boolean> {
        const commandText = this.osCommands && this.osCommands.createFolder(folderName, sharedFolder);
        const commandResult = await this.execute(commandText);
        // NOTE(review): `>= 0` is true for every non-negative exit status, so a failing
        // command is still reported as success - confirm whether `=== 0` was intended.
        const result = commandResult.exitCode >= 0;
        return result;
    }

    /**
     * Run the OS-specific "allow permission" command on the given folders.
     * @param isRecursive forwarded to the command builder
     * @param folders folders whose permissions are changed
     */
    public async allowPermission(isRecursive: boolean = false, ...folders: string[]): Promise<boolean> {
        const commandText = this.osCommands && this.osCommands.allowPermission(isRecursive, ...folders);
        const commandResult = await this.execute(commandText);
        // NOTE(review): see createFolder - `>= 0` does not distinguish failure from success.
        const result = commandResult.exitCode >= 0;
        return result;
    }

    /**
     * Run the OS-specific "remove folder" command.
     * @param folderName folder to remove
     * @param isRecursive forwarded to the command builder
     * @param isForce forwarded to the command builder
     */
    public async removeFolder(folderName: string, isRecursive: boolean = false, isForce: boolean = true): Promise<boolean> {
        const commandText = this.osCommands && this.osCommands.removeFolder(folderName, isRecursive, isForce);
        const commandResult = await this.execute(commandText);
        // NOTE(review): see createFolder - `>= 0` does not distinguish failure from success.
        const result = commandResult.exitCode >= 0;
        return result;
    }

    /**
     * Run the OS-specific "remove files" command.
     * @param folderOrFileName folder, or a complete file path when `filePattern` is empty
     * @param filePattern file name or pattern inside the folder
     */
    public async removeFiles(folderOrFileName: string, filePattern: string = ""): Promise<boolean> {
        const commandText = this.osCommands && this.osCommands.removeFiles(folderOrFileName, filePattern);
        const commandResult = await this.execute(commandText);
        // NOTE(review): see createFolder - `>= 0` does not distinguish failure from success.
        const result = commandResult.exitCode >= 0;
        return result;
    }

    /**
     * Read the last lines of a remote file.
     * @param fileName remote file path
     * @param lineCount number of lines requested
     * @returns the command's stdout, or an empty string when there is none
     */
    public async readLastLines(fileName: string, lineCount: number = 1): Promise<string> {
        const commandText = this.osCommands && this.osCommands.readLastLines(fileName, lineCount);
        const commandResult = await this.execute(commandText);
        let result: string = "";
        if (commandResult !== undefined && commandResult.stdout !== undefined && commandResult.stdout.length > 0) {
            result = commandResult.stdout;
        }
        return result;
    }

    /**
     * Check whether the process recorded in a remote PID file is alive.
     * @param pidFileName remote path of the PID file
     * @returns the command set's interpretation of the result; false when no command set is selected
     */
    public async isProcessAlive(pidFileName: string): Promise<boolean> {
        const commandText = this.osCommands && this.osCommands.isProcessAliveCommand(pidFileName);
        const commandResult = await this.execute(commandText);
        const result = this.osCommands && this.osCommands.isProcessAliveProcessOutput(commandResult);
        return result !== undefined ? result : false;
    }

    /**
     * Kill the child processes of the process recorded in a remote PID file.
     * @returns true when the command exits with code 0
     */
    public async killChildProcesses(pidFileName: string): Promise<boolean> {
        const commandText = this.osCommands && this.osCommands.killChildProcesses(pidFileName);
        const commandResult = await this.execute(commandText);
        return commandResult.exitCode === 0;
    }

    /**
     * Extract a remote archive into a remote folder.
     * @returns true when the command exits with code 0
     */
    public async extractFile(tarFileName: string, targetFolder: string): Promise<boolean> {
        const commandText = this.osCommands && this.osCommands.extractFile(tarFileName, targetFolder);
        const commandResult = await this.execute(commandText);
        return commandResult.exitCode === 0;
    }

    /**
     * Run a script on the remote machine.
     * @param script script file path or script text, depending on `isFile`
     * @param isFile forwarded to the command builder
     * @param isInteractive when true the command is sent through an SSH shell instead of exec
     * @returns true when the command exits with code 0
     */
    public async executeScript(script: string, isFile: boolean, isInteractive: boolean = false): Promise<boolean> {
        const commandText = this.osCommands && this.osCommands.executeScript(script, isFile);
        const commandResult = await this.execute(commandText, undefined, isInteractive);
        return commandResult.exitCode === 0;
    }

    /**
     * Copy local file to remote path
     * @param localFilePath the path of local file
     * @param remoteFilePath the target path in remote machine
     * @returns a promise resolving to true on success and rejecting with the SFTP error otherwise
     */
    public async copyFileToRemote(localFilePath: string, remoteFilePath: string): Promise<boolean> {
        const log: Logger = getLogger();
        log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`);

        const deferred: Deferred<boolean> = new Deferred<boolean>();
        this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
            if (err !== undefined && err !== null) {
                log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`);
                deferred.reject(err);

                return;
            }
            assert(sftp !== undefined);
            sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr: Error) => {
                // The SFTP session is ended whether or not the transfer succeeded.
                sftp.end();
                if (fastPutErr !== undefined && fastPutErr !== null) {
                    deferred.reject(fastPutErr);
                } else {
                    deferred.resolve(true);
                }
            });
        });

        return deferred.promise;
    }

    /**
     * Copy files and directories in local directory recursively to remote directory
     * @param localDirectory local directory
     * @param remoteDirectory remote directory
     * @param remoteOS passed to getRemoteTmpDir to locate the remote temporary folder
     */
    public async copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, remoteOS: string): Promise<void> {
        const tmpSuffix: string = uniqueString(5);
        const localTarPath: string = path.join(os.tmpdir(), `nni_tmp_local_${tmpSuffix}.tar.gz`);
        const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), `nni_tmp_remote_${tmpSuffix}.tar.gz`);

        // Compress files in local directory to experiment root directory
        await tarAdd(localTarPath, localDirectory);
        // Copy the compressed file to remoteDirectory and delete it
        try {
            await this.copyFileToRemote(localTarPath, remoteTarPath);
        } finally {
            // Remove the local temp archive even when the upload fails.
            await execRemove(localTarPath);
        }
        // Decompress the remote compressed file in and delete it
        try {
            await this.extractFile(remoteTarPath, remoteDirectory);
        } finally {
            // Remove the remote temp archive even when extraction throws.
            await this.removeFiles(remoteTarPath);
        }
    }

    /**
     * Read the whole content of a remote file over SFTP.
     * @param filePath remote file path
     * @returns a promise resolving to the file content; a read-stream error rejects
     *          with an NNIError named NOT_FOUND, other failures reject with a plain Error
     */
    public async getRemoteFileContent(filePath: string): Promise<string> {
        const deferred: Deferred<string> = new Deferred<string>();
        this.sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
            if (err !== undefined && err !== null) {
                getLogger()
                    .error(`getRemoteFileContent: ${err.message}`);
                deferred.reject(new Error(`SFTP error: ${err.message}`));

                return;
            }
            try {
                const sftpStream: stream.Readable = sftp.createReadStream(filePath);

                let dataBuffer: string = '';
                sftpStream.on('data', (data: Buffer | string) => {
                    dataBuffer += data;
                })
                    .on('error', (streamErr: Error) => {
                        sftp.end();
                        deferred.reject(new NNIError(NNIErrorNames.NOT_FOUND, streamErr.message));
                    })
                    .on('end', () => {
                        // sftp connection need to be released manually once operation is done
                        sftp.end();
                        deferred.resolve(dataBuffer);
                    });
            } catch (error) {
                const message: string = error instanceof Error ? error.message : String(error);
                getLogger()
                    .error(`getRemoteFileContent: ${message}`);
                sftp.end();
                deferred.reject(new Error(`SFTP error: ${message}`));
            }
        });

        return deferred.promise;
    }

    /**
     * Run a command on the remote machine and collect its output.
     * @param command command text; an undefined command is sent to exec as an empty string
     * @param processOutput optional hook that may replace the result before it is returned
     * @param useShell when true the command is written to an SSH shell followed by `exit`,
     *                 otherwise it is run with exec
     * @returns a promise resolving with stdout, stderr and exit code once the channel emits 'exit'
     */
    private async execute(command: string | undefined, processOutput: ((input: RemoteCommandResult) => RemoteCommandResult) | undefined = undefined, useShell: boolean = false): Promise<RemoteCommandResult> {
        const log: Logger = getLogger();
        log.debug(`remoteExeCommand: command: [${command}]`);
        const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
        let stdout: string = '';
        let stderr: string = '';
        let exitCode: number;

        const callback = (err: Error, channel: ClientChannel): void => {
            if (err !== undefined && err !== null) {
                log.error(`remoteExeCommand: ${err.message}`);
                deferred.reject(err);
                return;
            }

            channel.on('data', (data: any) => {
                stdout += data;
            });
            channel.on('exit', (code: any) => {
                exitCode = <number>code;
                log.debug(`remoteExeCommand exit(${exitCode})\nstdout: ${stdout}\nstderr: ${stderr}`);
                let result = {
                    stdout: stdout,
                    stderr: stderr,
                    exitCode: exitCode
                };

                if (processOutput != undefined) {
                    result = processOutput(result);
                }
                deferred.resolve(result);
            });
            channel.stderr.on('data', function (data) {
                stderr += data;
            });

            if (useShell) {
                // In shell mode the command is typed into the session, then the shell is told to exit.
                channel.stdin.write(`${command}\n`);
                channel.end("exit\n");
            }

            return;
        };

        if (useShell) {
            this.sshClient.shell(callback);
        } else {
            this.sshClient.exec(command !== undefined ? command : "", callback);
        }

        return deferred.promise;
    }
}
export { ShellExecutor };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment