Unverified Commit 21165b53 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #138 from Microsoft/master

merge master
parents 41a9a598 f10c3311
...@@ -4,13 +4,13 @@ ...@@ -4,13 +4,13 @@
* * * * * *
[![MIT 许可证](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE) [![生成状态](https://msrasrg.visualstudio.com/NNIOpenSource/_apis/build/status/Microsoft.nni)](https://msrasrg.visualstudio.com/NNIOpenSource/_build/latest?definitionId=6) [![问题](https://img.shields.io/github/issues-raw/Microsoft/nni.svg)](https://github.com/Microsoft/nni/issues?q=is%3Aissue+is%3Aopen) [![Bug](https://img.shields.io/github/issues/Microsoft/nni/bug.svg)](https://github.com/Microsoft/nni/issues?q=is%3Aissue+is%3Aopen+label%3Abug) [![拉取请求](https://img.shields.io/github/issues-pr-raw/Microsoft/nni.svg)](https://github.com/Microsoft/nni/pulls?q=is%3Apr+is%3Aopen) [![版本](https://img.shields.io/github/release/Microsoft/nni.svg)](https://github.com/Microsoft/nni/releases) [![进入 https://gitter.im/Microsoft/nni 聊天室提问](https://badges.gitter.im/Microsoft/nni.svg)](https://gitter.im/Microsoft/nni?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![MIT 许可证](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE) [![生成状态](https://msrasrg.visualstudio.com/NNIOpenSource/_apis/build/status/Microsoft.nni)](https://msrasrg.visualstudio.com/NNIOpenSource/_build/latest?definitionId=6) [![问题](https://img.shields.io/github/issues-raw/Microsoft/nni.svg)](https://github.com/Microsoft/nni/issues?q=is%3Aissue+is%3Aopen) [![Bug](https://img.shields.io/github/issues/Microsoft/nni/bug.svg)](https://github.com/Microsoft/nni/issues?q=is%3Aissue+is%3Aopen+label%3Abug) [![拉取请求](https://img.shields.io/github/issues-pr-raw/Microsoft/nni.svg)](https://github.com/Microsoft/nni/pulls?q=is%3Apr+is%3Aopen) [![版本](https://img.shields.io/github/release/Microsoft/nni.svg)](https://github.com/Microsoft/nni/releases) [![进入 https://gitter.im/Microsoft/nni 聊天室提问](https://badges.gitter.im/Microsoft/nni.svg)](https://gitter.im/Microsoft/nni?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![文档状态](https://readthedocs.org/projects/nni/badge/?version=latest)](https://nni.readthedocs.io/en/latest/?badge=latest)
[English](README.md) [English](README.md)
NNI (Neural Network Intelligence) 是自动机器学习(AutoML)的工具包。 它通过多种调优的算法来搜索最好的神经网络结构和(或)超参,并支持单机、本地多机、云等不同的运行环境。 NNI (Neural Network Intelligence) 是自动机器学习(AutoML)的工具包。 它通过多种调优的算法来搜索最好的神经网络结构和(或)超参,并支持单机、本地多机、云等不同的运行环境。
### **NNI [v0.5.1](https://github.com/Microsoft/nni/releases) 已发布!** ### **NNI [v0.5.2](https://github.com/Microsoft/nni/releases) 已发布!**
<p align="center"> <p align="center">
<a href="#nni-v05-has-been-released"><img src="docs/img/overview.svg" /></a> <a href="#nni-v05-has-been-released"><img src="docs/img/overview.svg" /></a>
...@@ -116,7 +116,7 @@ NNI (Neural Network Intelligence) 是自动机器学习(AutoML)的工具包 ...@@ -116,7 +116,7 @@ NNI (Neural Network Intelligence) 是自动机器学习(AutoML)的工具包
*`python >= 3.5` 的环境中运行命令: `git``wget`,确保安装了这两个组件。 *`python >= 3.5` 的环境中运行命令: `git``wget`,确保安装了这两个组件。
```bash ```bash
git clone -b v0.5.1 https://github.com/Microsoft/nni.git git clone -b v0.5.2 https://github.com/Microsoft/nni.git
cd nni cd nni
source install.sh source install.sh
``` ```
...@@ -130,7 +130,7 @@ NNI (Neural Network Intelligence) 是自动机器学习(AutoML)的工具包 ...@@ -130,7 +130,7 @@ NNI (Neural Network Intelligence) 是自动机器学习(AutoML)的工具包
* 通过克隆源代码下载示例。 * 通过克隆源代码下载示例。
```bash ```bash
git clone -b v0.5.1 https://github.com/Microsoft/nni.git git clone -b v0.5.2 https://github.com/Microsoft/nni.git
``` ```
* 运行 mnist 示例。 * 运行 mnist 示例。
......
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
jobs: jobs:
- job: 'version_number_validation'
pool:
vmImage: 'Ubuntu 16.04'
strategy:
matrix:
Python36:
PYTHON_VERSION: '3.6'
steps:
- script: |
echo $(build_version)
if [[ $(build_version) =~ ^v[0-9](.[0-9]){1,3}$ ]]; then
echo 'valid build version $(build_version)'
echo `git describe --tags --abbrev=0`
else
echo 'invalid build version $(build_version)'
exit 1
fi
condition: eq( variables['build_type'], 'prerelease' )
displayName: 'Validate prerelease version number'
- script: |
export BRANCH_TAG=`git describe --tags --abbrev=0`
echo $BRANCH_TAG
if [[ $BRANCH_TAG = $(build_version) && $BRANCH_TAG =~ ^v[0-9](.[0-9]){1,3}$ ]]; then
echo 'Build version match branch tag'
else
echo 'Build version does not match branch tag'
exit 1
fi
condition: eq( variables['build_type'], 'release' )
displayName: 'Validate release version number and branch tag'
- job: 'Build_upload_nni_ubuntu' - job: 'Build_upload_nni_ubuntu'
dependsOn: version_number_validation
condition: succeeded()
pool: pool:
vmImage: 'Ubuntu 16.04' vmImage: 'Ubuntu 16.04'
strategy: strategy:
...@@ -13,21 +67,22 @@ jobs: ...@@ -13,21 +67,22 @@ jobs:
python3 -m pip install --upgrade pip setuptools --user python3 -m pip install --upgrade pip setuptools --user
python3 -m pip install twine --user python3 -m pip install twine --user
displayName: 'Install twine' displayName: 'Install twine'
- script: | - script: |
# NNI build scripts (Makefile) uses branch tag as package version number
# To test this pipeline without impacting nni testpypi/pypi packages, uncomment following git tag command
# git tag v0.0.1
cd deployment/pypi cd deployment/pypi
if [ $(build_type) = 'prerelease' ] if [ $(build_type) = 'prerelease' ]
then then
# NNI build scripts (Makefile) uses branch tag as package version number
git tag $(build_version)
echo 'building prerelease package...' echo 'building prerelease package...'
make version_ts=true build make version_ts=true build
else else
echo 'building release package...' echo 'building release package...'
make build make build
fi fi
condition: eq( variables['upload_package'], 'true' ) condition: eq( variables['upload_package'], 'true')
displayName: 'build nni bdsit_wheel' displayName: 'build nni bdsit_wheel'
- script: | - script: |
cd deployment/pypi cd deployment/pypi
if [ $(build_type) = 'prerelease' ] if [ $(build_type) = 'prerelease' ]
...@@ -38,7 +93,7 @@ jobs: ...@@ -38,7 +93,7 @@ jobs:
echo 'uploading release package to pypi...' echo 'uploading release package to pypi...'
python3 -m twine upload -u $(pypi_user) -p $(pypi_pwd) dist/* python3 -m twine upload -u $(pypi_user) -p $(pypi_pwd) dist/*
fi fi
condition: eq( variables['upload_package'], 'true' ) condition: eq( variables['upload_package'], 'true')
displayName: 'upload nni package to pypi/testpypi' displayName: 'upload nni package to pypi/testpypi'
- script: | - script: |
...@@ -48,7 +103,7 @@ jobs: ...@@ -48,7 +103,7 @@ jobs:
then then
docker login -u $(docker_hub_dev_user) -p $(docker_hub_dev_pwd) docker login -u $(docker_hub_dev_user) -p $(docker_hub_dev_pwd)
export IMG_NAME=$(dev_docker_img) export IMG_NAME=$(dev_docker_img)
export IMG_TAG=`git describe --tags --abbrev=0`.`date +%y%m%d%H%M` export IMG_TAG=`git describe --tags --abbrev=0`.`date -u +%y%m%d%H%M`
echo 'updating docker file for testpyi...' echo 'updating docker file for testpyi...'
sed -ie 's/RUN python3 -m pip --no-cache-dir install nni/RUN python3 -m pip install --user --no-cache-dir --index-url https:\/\/test.pypi.org\/simple --extra-index-url https:\/\/pypi.org\/simple nni/' Dockerfile sed -ie 's/RUN python3 -m pip --no-cache-dir install nni/RUN python3 -m pip install --user --no-cache-dir --index-url https:\/\/test.pypi.org\/simple --extra-index-url https:\/\/pypi.org\/simple nni/' Dockerfile
else else
...@@ -67,10 +122,12 @@ jobs: ...@@ -67,10 +122,12 @@ jobs:
docker push $IMG_NAME:latest docker push $IMG_NAME:latest
fi fi
condition: eq( variables['build_docker_img'], 'true' ) condition: eq( variables['build_docker_img'], 'true')
displayName: 'build and upload nni docker image' displayName: 'build and upload nni docker image'
- job: 'Build_upload_nni_macos' - job: 'Build_upload_nni_macos'
dependsOn: version_number_validation
condition: succeeded()
pool: pool:
vmImage: 'macOS 10.13' vmImage: 'macOS 10.13'
strategy: strategy:
...@@ -82,24 +139,26 @@ jobs: ...@@ -82,24 +139,26 @@ jobs:
python3 -m pip install --upgrade pip setuptools --user python3 -m pip install --upgrade pip setuptools --user
python3 -m pip install twine --user python3 -m pip install twine --user
displayName: 'Install twine' displayName: 'Install twine'
- script: | - script: |
make install-dependencies make install-dependencies
displayName: 'Install nni dependencies' displayName: 'Install nni dependencies'
- script: | - script: |
# NNI build scripts (Makefile) uses branch tag as package version number
# To test this pipeline without impacting nni testpypi/pypi packages, uncomment following git tag command
# git tag v0.0.1
cd deployment/pypi cd deployment/pypi
if [ $(build_type) = 'prerelease' ] if [ $(build_type) = 'prerelease' ]
then then
# NNI build scripts (Makefile) uses branch tag as package version number
git tag $(build_version)
echo 'building prerelease package...' echo 'building prerelease package...'
PATH=$HOME/Library/Python/3.7/bin:$PATH make version_ts=true build PATH=$HOME/Library/Python/3.7/bin:$PATH make version_ts=true build
else else
echo 'building release package...' echo 'building release package...'
PATH=$HOME/Library/Python/3.7/bin:$PATH make build PATH=$HOME/Library/Python/3.7/bin:$PATH make build
fi fi
condition: eq( variables['upload_package'], 'true' ) condition: eq( variables['upload_package'], 'true')
displayName: 'build nni bdsit_wheel' displayName: 'build nni bdsit_wheel'
- script: | - script: |
cd deployment/pypi cd deployment/pypi
if [ $(build_type) = 'prerelease' ] if [ $(build_type) = 'prerelease' ]
...@@ -110,5 +169,5 @@ jobs: ...@@ -110,5 +169,5 @@ jobs:
echo 'uploading release package to pypi...' echo 'uploading release package to pypi...'
python3 -m twine upload -u $(pypi_user) -p $(pypi_pwd) dist/* python3 -m twine upload -u $(pypi_user) -p $(pypi_pwd) dist/*
fi fi
condition: eq( variables['upload_package'], 'true' ) condition: eq( variables['upload_package'], 'true')
displayName: 'upload nni package to pypi/testpypi' displayName: 'upload nni package to pypi/testpypi'
...@@ -11,7 +11,7 @@ else ...@@ -11,7 +11,7 @@ else
$(error platform $(UNAME_S) not supported) $(error platform $(UNAME_S) not supported)
endif endif
TIME_STAMP = $(shell date "+%y%m%d%H%M") TIME_STAMP = $(shell date -u "+%y%m%d%H%M")
NNI_VERSION_VALUE = $(shell git describe --tags --abbrev=0) NNI_VERSION_VALUE = $(shell git describe --tags --abbrev=0)
# To include time stamp in version value, run: # To include time stamp in version value, run:
......
...@@ -56,6 +56,7 @@ setuptools.setup( ...@@ -56,6 +56,7 @@ setuptools.setup(
'nni_gpu_tool': '../../tools/nni_gpu_tool', 'nni_gpu_tool': '../../tools/nni_gpu_tool',
'nni': '../../src/sdk/pynni/nni' 'nni': '../../src/sdk/pynni/nni'
}, },
package_data = {'nni': ['**/requirements.txt']},
python_requires = '>=3.5', python_requires = '>=3.5',
install_requires = [ install_requires = [
'schema', 'schema',
...@@ -81,4 +82,4 @@ setuptools.setup( ...@@ -81,4 +82,4 @@ setuptools.setup(
'nnictl = nni_cmd.nnictl:parse_args' 'nnictl = nni_cmd.nnictl:parse_args'
] ]
} }
) )
\ No newline at end of file
...@@ -54,7 +54,7 @@ Compared with LocalMode and [RemoteMachineMode](RemoteMachineMode.md), trial con ...@@ -54,7 +54,7 @@ Compared with LocalMode and [RemoteMachineMode](RemoteMachineMode.md), trial con
* dataDir * dataDir
* Optional key. It specifies the HDFS data direcotry for trial to download data. The format should be something like hdfs://{your HDFS host}:9000/{your data directory} * Optional key. It specifies the HDFS data direcotry for trial to download data. The format should be something like hdfs://{your HDFS host}:9000/{your data directory}
* outputDir * outputDir
* Optional key. It specifies the HDFS output direcotry for trial. Once the trial is completed (either succeed or fail), trial's stdout, stderr will be copied to this directory by NNI sdk automatically. The format should be something like hdfs://{your HDFS host}:9000/{your output directory} * Optional key. It specifies the HDFS output directory for trial. Once the trial is completed (either succeed or fail), trial's stdout, stderr will be copied to this directory by NNI sdk automatically. The format should be something like hdfs://{your HDFS host}:9000/{your output directory}
Once complete to fill NNI experiment config file and save (for example, save as exp_pai.yml), then run the following command Once complete to fill NNI experiment config file and save (for example, save as exp_pai.yml), then run the following command
``` ```
......
# ChangeLog # ChangeLog
## Release 0.5.2 - 3/4/2019
### Improvements
* Curve fitting assessor performance improvement.
### Documentation
* Chinese version document: https://nni.readthedocs.io/zh/latest/
* Debuggability/serviceability document: https://nni.readthedocs.io/en/latest/HowToDebug.html
* Tuner assessor reference: https://nni.readthedocs.io/en/latest/sdk_reference.html#tuner
### Bug Fixes and Other Changes
* Fix a race condition bug that does not store trial job cancel status correctly.
* Fix search space parsing error when using SMAC tuner.
* Fix cifar10 example broken pipe issue.
* Add unit test cases for nnimanager and local training service.
* Add integration test azure pipelines for remote machine, PAI and kubeflow training services.
* Support Pylon in PAI webhdfs client.
## Release 0.5.1 - 1/31/2018 ## Release 0.5.1 - 1/31/2018
### Improvements ### Improvements
* Making [log directory](https://github.com/Microsoft/nni/blob/v0.5.1/docs/en_US/ExperimentConfig.md) configurable * Making [log directory](https://github.com/Microsoft/nni/blob/v0.5.1/docs/en_US/ExperimentConfig.md) configurable
......
...@@ -30,7 +30,20 @@ NNI 中,有 4 种类型的 Annotation; ...@@ -30,7 +30,20 @@ NNI 中,有 4 种类型的 Annotation;
- **sampling_algo**: 指定搜索空间的采样算法。 可将其换成 NNI 支持的其它采样函数,函数要以 `nni.` 开头。例如,`choice``uniform`,详见 [SearchSpaceSpec](SearchSpaceSpec.md) - **sampling_algo**: 指定搜索空间的采样算法。 可将其换成 NNI 支持的其它采样函数,函数要以 `nni.` 开头。例如,`choice``uniform`,详见 [SearchSpaceSpec](SearchSpaceSpec.md)
- **name**: 将被赋值的变量名称。 注意,此参数应该与下面一行等号左边的值相同。 - **name**: 将被赋值的变量名称。 注意,此参数应该与下面一行等号左边的值相同。
例如: NNI 支持如下 10 种类型来表示搜索空间:
- `@nni.variable(nni.choice(option1,option2,...,optionN),name=variable)` 变量值是选项中的一种,这些变量可以是任意的表达式。
- `@nni.variable(nni.randint(upper),name=variable)` 变量可以是范围 [0, upper) 中的任意整数。
- `@nni.variable(nni.uniform(low, high),name=variable)` 变量值会是 low 和 high 之间均匀分布的某个值。
- `@nni.variable(nni.quniform(low, high, q),name=variable)` 变量值会是 low 和 high 之间均匀分布的某个值,公式为:round(uniform(low, high) / q) * q
- `@nni.variable(nni.loguniform(low, high),name=variable)` 变量值是 exp(uniform(low, high)) 的点,数值以对数均匀分布。
- `@nni.variable(nni.qloguniform(low, high, q),name=variable)` 变量值会是 low 和 high 之间均匀分布的某个值,公式为:round(exp(uniform(low, high)) / q) * q
- `@nni.variable(nni.normal(mu, sigma),name=variable)` 变量值为正态分布的实数值,平均值为 mu,标准方差为 sigma。
- `@nni.variable(nni.qnormal(mu, sigma, q),name=variable)` 变量值分布的公式为: round(normal(mu, sigma) / q) * q
- `@nni.variable(nni.lognormal(mu, sigma),name=variable)` 变量值分布的公式为: exp(normal(mu, sigma))
- `@nni.variable(nni.qlognormal(mu, sigma, q),name=variable)` 变量值分布的公式为: round(exp(normal(mu, sigma)) / q) * q
样例如下:
```python ```python
'''@nni.variable(nni.choice(0.1, 0.01, 0.001), name=learning_rate)''' '''@nni.variable(nni.choice(0.1, 0.01, 0.001), name=learning_rate)'''
...@@ -45,7 +58,7 @@ learning_rate = 0.1 ...@@ -45,7 +58,7 @@ learning_rate = 0.1
**参数** **参数**
- **\*functions**: 可选择的函数。 注意,必须是包括参数的完整函数调用。 例如 `max_pool(hidden_layer, pool_size)` - **functions**: 可选择的函数。 注意,必须是包括参数的完整函数调用。 例如 `max_pool(hidden_layer, pool_size)`
- **name**: 将被替换的函数名称。 - **name**: 将被替换的函数名称。
例如: 例如:
......
...@@ -6,13 +6,10 @@ NNI 提供了先进的调优算法,使用上也很简单。 下面是内置 As ...@@ -6,13 +6,10 @@ NNI 提供了先进的调优算法,使用上也很简单。 下面是内置 As
当前支持的 Assessor: 当前支持的 Assessor:
* [Medianstop(中位数终止)](medianstopAssessor.md) | Assessor | 算法简介 |
* [Curvefitting(曲线拟合)](curvefittingAssessor.md) | --------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [**Medianstop**](#MedianStop) | Medianstop 是一个简单的提前终止算法。 如果 Trial X 的在步骤 S 的最好目标值比所有已完成 Trial 的步骤 S 的中位数值明显要低,这个 Trial 就会被提前停止。 [参考论文](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46180.pdf) |
| Assessor | 算法简介 | | [**Curvefitting**](#Curvefitting) | Curve Fitting Assessor 是一个 LPA (learning, predicting, assessing,即学习、预测、评估) 的算法。 如果预测的 Trial X 在 step S 比性能最好的 Trial 要差,就会提前终止它。 此算法中采用了 12 种曲线来拟合精度曲线。 [参考论文](http://aad.informatik.uni-freiburg.de/papers/15-IJCAI-Extrapolation_of_Learning_Curves.pdf) |
| ------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| __Medianstop__ [(用法)](#MedianStop) | Medianstop 是一个简单的提前终止算法。 如果 Trial X 的在步骤 S 的最好目标值比所有已完成 Trial 的步骤 S 的中位数值明显要低,这个 Trial 就会被提前停止。 [参考论文](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/46180.pdf) |
| __Curvefitting__ [(用法)](#Curvefitting) | Curve Fitting Assessor 是一个 LPA (learning, predicting, assessing,即学习、预测、评估) 的算法。 如果预测的 Trial X 在 step S 比性能最好的 Trial 要差,就会提前终止它。 此算法中采用了 12 种曲线来拟合精度曲线。 [参考论文](http://aad.informatik.uni-freiburg.de/papers/15-IJCAI-Extrapolation_of_Learning_Curves.pdf) |
## 用法 ## 用法
......
...@@ -6,29 +6,18 @@ NNI 提供了先进的调优算法,使用上也很简单。 下面是内置 Tu ...@@ -6,29 +6,18 @@ NNI 提供了先进的调优算法,使用上也很简单。 下面是内置 Tu
当前支持的 Tuner: 当前支持的 Tuner:
* [TPE](hyperoptTuner.md) | Tuner | 算法简介 |
* [Random Search(随机搜索)](hyperoptTuner.md) | ---------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
* [Anneal(退火算法)](hyperoptTuner.md) | [**TPE**](#TPE) | Tree-structured Parzen Estimator (TPE) 是一种 sequential model-based optimization(SMBO,即基于序列模型优化)的方法。 SMBO 方法根据历史指标数据来按顺序构造模型,来估算超参的性能,随后基于此模型来选择新的超参。 [参考论文](https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf) |
* [Naive Evolution(进化算法)](evolutionTuner.md) | [**Random Search**](#Random) | 在超参优化时,随机搜索算法展示了其惊人的简单和效果。 建议当不清楚超参的先验分布时,采用随机搜索作为基准。 [参考论文](http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf) |
* [SMAC](smacTuner.md) | [**Anneal**](#Anneal) | 这种简单的退火算法从先前的采样开始,会越来越靠近发现的最佳点取样。 此算法是随机搜索的简单变体,利用了反应曲面的平滑性。 退火率不是自适应的。 |
* [Batch Tuner(批量调参器)](batchTuner.md) | [**Naive Evolution**](#Evolution) | 朴素进化算法来自于大规模图像分类进化。 它会基于搜索空间随机生成一个种群。 在每一代中,会选择较好的结果,并对其下一代进行一些变异(例如,改动一个超参,增加或减少一层)。 进化算法需要很多次 Trial 才能有效,但它也非常简单,也很容易扩展新功能。 [参考论文](https://arxiv.org/pdf/1703.01041.pdf) |
* [Grid Search(网格搜索)](gridsearchTuner.md) | [**SMAC**](#SMAC) | SMAC 基于 Sequential Model-Based Optimization (SMBO,即序列的基于模型优化方法)。 它会利用使用过的结果好的模型(高斯随机过程模型),并将随机森林引入到 SMBO 中,来处理分类参数。 SMAC 算法包装了 Github 的 SMAC3。 注意:SMAC 需要通过 `nnictl package` 命令来安装。 [参考论文,](https://www.cs.ubc.ca/~hutter/papers/10-TR-SMAC.pdf) [Github 代码库](https://github.com/automl/SMAC3) |
* [Hyperband](hyperbandAdvisor.md) | [**Batch tuner**](#Batch) | Batch Tuner 能让用户简单的提供几组配置(如,超参选项的组合)。 当所有配置都执行完后,Experiment 即结束。 Batch Tuner 仅支持 choice 类型。 |
* [Network Morphism](networkmorphismTuner.md) | [**Grid Search**](#GridSearch) | Grid Search 会穷举定义在搜索空间文件中的所有超参组合。 网格搜索可以使用的类型有 choice, quniform, qloguniform。 quniform 和 qloguniform 中的数值 q 具有特别的含义(不同于搜索空间文档中的说明)。 它表示了在最高值与最低值之间采样的值的数量。 |
* [Metis Tuner](metisTuner.md) | [**Hyperband**](#Hyperband) | Hyperband 试图用有限的资源来探索尽可能多的组合,并发现最好的结果。 它的基本思路是生成大量的配置,并运行少量的步骤来找到有可能好的配置,然后继续训练找到其中更好的配置。 [参考论文](https://arxiv.org/pdf/1603.06560.pdf) |
| [**Network Morphism**](#NetworkMorphism) | Network Morphism 提供了深度学习模型的自动架构搜索功能。 每个子网络都继承于父网络的知识和形态,并变换网络的不同形态,包括深度,宽度,跨层连接(skip-connection)。 然后使用历史的架构和指标,来估计子网络的值。 最后会选择最有希望的模型进行训练。 [参考论文](https://arxiv.org/abs/1806.10282) |
| Tuner | 算法简介 | | [**Metis Tuner**](#MetisTuner) | 大多数调参工具仅仅预测最优配置,而 Metis 的优势在于有两个输出:(a) 最优配置的当前预测结果, 以及 (b) 下一次 Trial 的建议。 它不进行随机取样。 大多数工具假设训练集没有噪声数据,但 Metis 会知道是否需要对某个超参重新采样。 [参考论文](https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/) |
| --------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **TPE** [(用法)](#TPE) | Tree-structured Parzen Estimator (TPE) 是一种 sequential model-based optimization(SMBO,即基于序列模型优化)的方法。 SMBO 方法根据历史指标数据来按顺序构造模型,来估算超参的性能,随后基于此模型来选择新的超参。 [参考论文](https://papers.nips.cc/paper/4443-algorithms-for-hyper-parameter-optimization.pdf) |
| **Random Search** [(用法)](#Random) | 在超参优化时,随机搜索算法展示了其惊人的简单和效果。 建议当不清楚超参的先验分布时,采用随机搜索作为基准。 [参考论文](http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf) |
| **Anneal** [(用法)](#Anneal) | 这种简单的退火算法从先前的采样开始,会越来越靠近发现的最佳点取样。 此算法是随机搜索的简单变体,利用了反应曲面的平滑性。 退火率不是自适应的。 |
| **Naive Evolution** [(用法)](#Evolution) | 朴素进化算法来自于大规模图像分类进化。 它会基于搜索空间随机生成一个种群。 在每一代中,会选择较好的结果,并对其下一代进行一些变异(例如,改动一个超参,增加或减少一层)。 进化算法需要很多次 Trial 才能有效,但它也非常简单,也很容易扩展新功能。 [参考论文](https://arxiv.org/pdf/1703.01041.pdf) |
| **SMAC** [(用法)](#SMAC) | SMAC 基于 Sequential Model-Based Optimization (SMBO,即序列的基于模型优化方法)。 它会利用使用过的结果好的模型(高斯随机过程模型),并将随机森林引入到 SMBO 中,来处理分类参数。 SMAC 算法包装了 Github 的 SMAC3。 注意:SMAC 需要通过 `nnictl package` 命令来安装。 [参考论文,](https://www.cs.ubc.ca/~hutter/papers/10-TR-SMAC.pdf) [Github 代码库](https://github.com/automl/SMAC3) |
| **Batch tuner** [(用法)](#Batch) | Batch Tuner 能让用户简单的提供几组配置(如,超参选项的组合)。 当所有配置都执行完后,Experiment 即结束。 Batch Tuner 仅支持 choice 类型。 |
| **Grid Search** [(用法)](#GridSearch) | Grid Search 会穷举定义在搜索空间文件中的所有超参组合。 网格搜索可以使用的类型有 choice, quniform, qloguniform。 quniform 和 qloguniform 中的数值 q 具有特别的含义(不同于搜索空间文档中的说明)。 它表示了在最高值与最低值之间采样的值的数量。 |
| **Hyperband** [(用法)](#Hyperband) | Hyperband 试图用有限的资源来探索尽可能多的组合,并发现最好的结果。 它的基本思路是生成大量的配置,并运行少量的步骤来找到有可能好的配置,然后继续训练找到其中更好的配置。 [参考论文](https://arxiv.org/pdf/1603.06560.pdf) |
| **Network Morphism** [(用法)](#NetworkMorphism) | Network Morphism 提供了深度学习模型的自动架构搜索功能。 每个子网络都继承于父网络的知识和形态,并变换网络的不同形态,包括深度,宽度,跨层连接(skip-connection)。 然后使用历史的架构和指标,来估计子网络的值。 最后会选择最有希望的模型进行训练。 [参考论文](https://arxiv.org/abs/1806.10282) |
| **Metis Tuner** [(用法)](#MetisTuner) | 大多数调参工具仅仅预测最优配置,而 Metis 的优势在于有两个输出:(a) 最优配置的当前预测结果, 以及 (b) 下一次 Trial 的建议。 它不进行随机取样。 大多数工具假设训练集没有噪声数据,但 Metis 会知道是否需要对某个超参重新采样。 [参考论文](https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/) |
<br /> <br />
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
先决条件:`python >=3.5, git, wget` 先决条件:`python >=3.5, git, wget`
```bash ```bash
git clone -b v0.5.1 https://github.com/Microsoft/nni.git git clone -b v0.5.2 https://github.com/Microsoft/nni.git
cd nni cd nni
./install.sh ./install.sh
``` ```
......
# 更改日志 # 更改日志
## 发布 0.5.2 - 3/4/2019
### 改进
* 提升 Curve fitting Assessor 的性能。
### 文档
* 发布中文文档网站:https://nni.readthedocs.io/zh/latest/
* 调试和维护:https://nni.readthedocs.io/en/latest/HowToDebug.html
* Tuner、Assessor 参考:https://nni.readthedocs.io/en/latest/sdk_reference.html#tuner
### Bug 修复和其它更新
* 修复了在某些极端条件下,不能正确存储任务的取消状态。
* 修复在使用 SMAC Tuner 时,解析搜索空间的错误。
* 修复 CIFAR-10 样例中的 broken pipe 问题。
* 为本地训练服务和 NNI 管理器添加单元测试。
* 为远程服务器、OpenPAI 和 Kubeflow 训练平台在 Azure 中增加集成测试。
* 在 OpenPAI 客户端中支持 Pylon 路径。
## 发布 0.5.1 - 1/31/2018 ## 发布 0.5.1 - 1/31/2018
### 改进 ### 改进
* 可配置[日志目录](ExperimentConfig.md) * [日志目录](https://github.com/Microsoft/nni/blob/v0.5.1/docs/en_US/ExperimentConfig.md)可配置
* 支持[不同级别的日志](ExperimentConfig.md),使其更易于调试。 * 支持[不同级别的日志](https://github.com/Microsoft/nni/blob/v0.5.1/docs/en_US/ExperimentConfig.md),使其更易于调试。
### 文档 ### 文档
...@@ -23,14 +44,14 @@ ...@@ -23,14 +44,14 @@
#### 支持新的 Tuner 和 Assessor #### 支持新的 Tuner 和 Assessor
* 支持[Metis tuner](./Builtin_Tuner.md#MetisTuner) 作为 NNI 的 Tuner。 **在线**超参调优的场景,Metis 算法已经被证明非常有效。 * 支持新的 [Metis Tuner](metisTuner.md)。 对于**在线**超参调优的场景,Metis 算法已经被证明非常有效。
* 支持 [ENAS customized tuner](https://github.com/countif/enas_nni)。由 GitHub 社区用户所贡献。它是神经网络的搜索算法,能够通过强化学习来学习神经网络架构,比 NAS 的性能更好。 * 支持 [ENAS customized tuner](https://github.com/countif/enas_nni)。由 GitHub 社区用户所贡献。它是神经网络的搜索算法,能够通过强化学习来学习神经网络架构,比 NAS 的性能更好。
* 支持 [Curve fitting (曲线拟合)Assessor](./Builtin_Tuner.md#Curvefitting),通过曲线拟合的策略来实现提前终止 Trial。 * 支持 [Curve fitting (曲线拟合)Assessor](curvefittingAssessor.md),通过曲线拟合的策略来实现提前终止 Trial。
* 进一步支持 [Weight Sharing(权重共享)](./AdvancedNAS.md):为 NAS Tuner 通过 NFS 来提供权重共享。 * 进一步支持 [Weight Sharing(权重共享)](./AdvancedNAS.md):为 NAS Tuner 通过 NFS 来提供权重共享。
#### 改进训练平台 #### 改进训练平台
* [FrameworkController 训练服务](./FrameworkControllerMode.md): 支持使用在 Kubernetes 上使用 FrameworkController。 * [FrameworkController 训练平台](./FrameworkControllerMode.md): 支持使用在 Kubernetes 上使用 FrameworkController。
* FrameworkController 是 Kubernetes 上非常通用的控制器(Controller),能用来运行基于各种机器学习框架的分布式作业,如 TensorFlow,Pytorch, MXNet 等。 * FrameworkController 是 Kubernetes 上非常通用的控制器(Controller),能用来运行基于各种机器学习框架的分布式作业,如 TensorFlow,Pytorch, MXNet 等。
* NNI 为作业定义了统一而简单的规范。 * NNI 为作业定义了统一而简单的规范。
* 如何使用 FrameworkController 的 MNIST 样例。 * 如何使用 FrameworkController 的 MNIST 样例。
...@@ -48,12 +69,12 @@ ...@@ -48,12 +69,12 @@
#### 支持新的 Tuner #### 支持新的 Tuner
* 支持新 Tuner [network morphism](./Builtin_Tuner.md#NetworkMorphism) * 支持新 [network morphism](networkmorphismTuner.md) Tuner。
#### 改进训练平台 #### 改进训练平台
*[Kubeflow 训练服务](KubeflowMode.md)的依赖从 kubectl CLI 迁移到 [Kubernetes API](https://kubernetes.io/docs/concepts/overview/kubernetes-api/) 客户端。 *[Kubeflow 训练平台](KubeflowMode.md)的依赖从 kubectl CLI 迁移到 [Kubernetes API](https://kubernetes.io/docs/concepts/overview/kubernetes-api/) 客户端。
* Kubeflow 训练服务支持 [Pytorch-operator](https://github.com/kubeflow/pytorch-operator) * Kubeflow 训练平台支持 [Pytorch-operator](https://github.com/kubeflow/pytorch-operator)
* 改进将本地代码文件上传到 OpenPAI HDFS 的性能。 * 改进将本地代码文件上传到 OpenPAI HDFS 的性能。
* 修复 OpenPAI 在 WEB 界面的 Bug:当 OpenPAI 认证过期后,Web 界面无法更新 Trial 作业的状态。 * 修复 OpenPAI 在 WEB 界面的 Bug:当 OpenPAI 认证过期后,Web 界面无法更新 Trial 作业的状态。
...@@ -82,8 +103,8 @@ ...@@ -82,8 +103,8 @@
* [Kubeflow 训练服务](./KubeflowMode.md) * [Kubeflow 训练服务](./KubeflowMode.md)
* 支持 tf-operator * 支持 tf-operator
* 使用 Kubeflow 的[分布式 Trial 样例](https://github.com/Microsoft/nni/tree/master/examples/trials/mnist-distributed/dist_mnist.py) * 使用 Kubeflow 的[分布式 Trial 样例](https://github.com/Microsoft/nni/tree/master/examples/trials/mnist-distributed/dist_mnist.py)
* [网格搜索 Tuner](Builtin_Tuner.md#GridSearch) * [网格搜索 Tuner](gridsearchTuner.md)
* [Hyperband Tuner](Builtin_Tuner.md#Hyperband) * [Hyperband Tuner](hyperbandAdvisor.md)
* 支持在 MAC 上运行 NNI Experiment * 支持在 MAC 上运行 NNI Experiment
* Web 界面 * Web 界面
* 支持 hyperband Tuner * 支持 hyperband Tuner
...@@ -137,13 +158,13 @@ ...@@ -137,13 +158,13 @@
* float * float
* 包含有 'default' 键值的 dict,'default' 的值必须为 int 或 float。 dict 可以包含任何其它键值对。 * 包含有 'default' 键值的 dict,'default' 的值必须为 int 或 float。 dict 可以包含任何其它键值对。
### 新的内置 Tuner ### 支持新的 Tuner
* **Batch Tuner(批处理调参器)** 会执行所有超参组合,可被用来批量提交 Trial 任务。 * **Batch Tuner(批处理调参器)** 会执行所有超参组合,可被用来批量提交 Trial 任务。
### 新样例 ### 新样例
*的 NNI Docker 映像: *的 NNI Docker 映像:
```bash ```bash
docker pull msranni/nni:latest docker pull msranni/nni:latest
...@@ -166,7 +187,7 @@ ...@@ -166,7 +187,7 @@
* 支持 [OpenPAI](https://github.com/Microsoft/pai) (又称 pai) 训练服务 (参考[这里](./PAIMode.md)来了解如何在 OpenPAI 下提交 NNI 任务) * 支持 [OpenPAI](https://github.com/Microsoft/pai) (又称 pai) 训练服务 (参考[这里](./PAIMode.md)来了解如何在 OpenPAI 下提交 NNI 任务)
* 支持 pai 模式的训练服务。 NNI Trial 可发送至 OpenPAI 集群上运行 * 支持 pai 模式的训练服务。 NNI Trial 可发送至 OpenPAI 集群上运行
* NNI Trial 输出 (包括日志和模型文件) 会被复制到 OpenPAI 的 HDFS 中。 * NNI Trial 输出 (包括日志和模型文件) 会被复制到 OpenPAI 的 HDFS 中。
* 支持 [SMAC](https://www.cs.ubc.ca/~hutter/papers/10-TR-SMAC.pdf) Tuner (参考[这里](Builtin_Tuner.md),了解如何使用 SMAC Tuner) * 支持 [SMAC](https://www.cs.ubc.ca/~hutter/papers/10-TR-SMAC.pdf) Tuner (参考[这里](smacTuner.md),了解如何使用 SMAC Tuner)
* [SMAC](https://www.cs.ubc.ca/~hutter/papers/10-TR-SMAC.pdf) 基于 Sequential Model-Based Optimization (SMBO). 它会利用使用过的结果好的模型(高斯随机过程模型),并将随机森林引入到 SMBO 中,来处理分类参数。 NNI 的 SMAC 通过包装 [SMAC3](https://github.com/automl/SMAC3) 来支持。 * [SMAC](https://www.cs.ubc.ca/~hutter/papers/10-TR-SMAC.pdf) 基于 Sequential Model-Based Optimization (SMBO). 它会利用使用过的结果好的模型(高斯随机过程模型),并将随机森林引入到 SMBO 中,来处理分类参数。 NNI 的 SMAC 通过包装 [SMAC3](https://github.com/automl/SMAC3) 来支持。
* 支持将 NNI 安装在 [conda](https://conda.io/docs/index.html) 和 Python 虚拟环境中。 * 支持将 NNI 安装在 [conda](https://conda.io/docs/index.html) 和 Python 虚拟环境中。
* 其它 * 其它
......
...@@ -53,20 +53,20 @@ ...@@ -53,20 +53,20 @@
* 这表示变量值会类似于 round(loguniform(low, high)) / q) * q * 这表示变量值会类似于 round(loguniform(low, high)) / q) * q
* 适用于值是“平滑”的离散变量,但上下限均有限制。 * 适用于值是“平滑”的离散变量,但上下限均有限制。
* {"_type":"normal","_value":[label, mu, sigma]} * {"_type":"normal","_value":[mu, sigma]}
* 变量值为实数,且为正态分布,均值为 mu,标准方差为 sigma。 优化时,此变量不受约束。 * 变量值为实数,且为正态分布,均值为 mu,标准方差为 sigma。 优化时,此变量不受约束。
* {"_type":"qnormal","_value":[label, mu, sigma, q]} * {"_type":"qnormal","_value":[mu, sigma, q]}
* 这表示变量值会类似于 round(normal(mu, sigma) / q) * q * 这表示变量值会类似于 round(normal(mu, sigma) / q) * q
* 适用于在 mu 周围的离散变量,且没有上下限限制。 * 适用于在 mu 周围的离散变量,且没有上下限限制。
* {"_type":"lognormal","_value":[label, mu, sigma]} * {"_type":"lognormal","_value":[mu, sigma]}
* 变量值为 exp(normal(mu, sigma)) 分布,范围值是对数的正态分布。 当优化时,此变量必须是正数。 * 变量值为 exp(normal(mu, sigma)) 分布,范围值是对数的正态分布。 当优化时,此变量必须是正数。
* {"_type":"qlognormal","_value":[label, mu, sigma, q]} * {"_type":"qlognormal","_value":[mu, sigma, q]}
* 这表示变量值会类似于 round(exp(normal(mu, sigma)) / q) * q * 这表示变量值会类似于 round(exp(normal(mu, sigma)) / q) * q
* 适用于值是“平滑”的离散变量,但某一边有界。 * 适用于值是“平滑”的离散变量,但某一边有界。
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
###################### ######################
.. toctree:: .. toctree::
:maxdepth: 2
安装<Installation> 安装<Installation>
实现 Trial<Trials> 实现 Trial<Trials>
Tuner<tuners> Tuner<tuners>
......
...@@ -13,5 +13,7 @@ Assessor 从 Trial 中接收中间结果,并通过指定的算法决定此 Tri ...@@ -13,5 +13,7 @@ Assessor 从 Trial 中接收中间结果,并通过指定的算法决定此 Tri
与 Tuner 类似,可使用内置的 Assessor,也可以自定义 Assessor。 参考下列教程,获取详细信息: 与 Tuner 类似,可使用内置的 Assessor,也可以自定义 Assessor。 参考下列教程,获取详细信息:
.. toctree:: .. toctree::
内置 Assessor<Builtin_Assessors> :maxdepth: 2
内置 Assessor<builtinAssessor>
自定义 Assessor<Customize_Assessor> 自定义 Assessor<Customize_Assessor>
内置 Assessor
=================
.. toctree::
:maxdepth: 1
介绍<Builtin_Assessors>
Medianstop<medianstopAssessor>
Curvefitting<curvefittingAssessor>
\ No newline at end of file
内置 Tuner
==================
.. toctree::
:maxdepth: 1
介绍<Builtin_Tuner>
TPE<hyperoptTuner>
Random Search<hyperoptTuner>
Anneal<hyperoptTuner>
Naive Evolution<evolutionTuner>
SMAC<smacTuner>
Batch Tuner<batchTuner>
Grid Search<gridsearchTuner>
Hyperband<hyperbandAdvisor>
Network Morphism<networkmorphismTuner>
Metis Tuner<metisTuner>
\ No newline at end of file
...@@ -186,9 +186,9 @@ epub_exclude_files = ['search.html'] ...@@ -186,9 +186,9 @@ epub_exclude_files = ['search.html']
# -- Extension configuration ------------------------------------------------- # -- Extension configuration -------------------------------------------------
github_doc_root = 'https://github.com/Microsoft/nni/tree/master/doc/'
def setup(app): def setup(app):
app.add_config_value('recommonmark_config', { app.add_config_value('recommonmark_config', {
'enable_auto_toc_tree': True, 'enable_eval_rst': True,
}, True) 'enable_auto_toc_tree': False,
app.add_transform(AutoStructify) }, True)
app.add_transform(AutoStructify)
\ No newline at end of file
...@@ -11,6 +11,8 @@ Tuner 从 Trial 接收指标结果,来评估一组超参或网络结构的性 ...@@ -11,6 +11,8 @@ Tuner 从 Trial 接收指标结果,来评估一组超参或网络结构的性
详细信息,参考以下教程: 详细信息,参考以下教程:
.. toctree:: .. toctree::
内置 Tuner<Builtin_Tuner> :maxdepth: 2
内置 Tuner<builtinTuner>
自定义 Tuner<Customize_Tuner> 自定义 Tuner<Customize_Tuner>
自定义 Advisor<Customize_Advisor> 自定义 Advisor<Customize_Advisor>
\ No newline at end of file
...@@ -4,8 +4,9 @@ import argparse ...@@ -4,8 +4,9 @@ import argparse
import logging import logging
import math import math
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
FLAGS = None FLAGS = None
...@@ -152,12 +153,21 @@ def bias_variable(shape): ...@@ -152,12 +153,21 @@ def bias_variable(shape):
return tf.Variable(initial) return tf.Variable(initial)
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params): def main(params):
''' '''
Main function, build mnist network, run and send result to NNI. Main function, build mnist network, run and send result to NNI.
''' '''
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
print('Mnist download data done.') print('Mnist download data done.')
logger.debug('Mnist download data done.') logger.debug('Mnist download data done.')
......
''' '''
mnist.py is an example to show: how to use iterative search space to tune architecture network for mnist. mnist.py is an example to show: how to use iterative search space to tune architecture network for mnist.
''' '''
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse
import codecs
import json
import logging import logging
import math import math
import sys
import tempfile import tempfile
import tensorflow as tf import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
import nni import nni
logger = logging.getLogger('mnist_cascading_search_space') logger = logging.getLogger('mnist_cascading_search_space')
FLAGS = None FLAGS = None
...@@ -95,10 +89,19 @@ class MnistNetwork(object): ...@@ -95,10 +89,19 @@ class MnistNetwork(object):
child_accuracy = tf.equal(tf.argmax(output_layer, 1), tf.argmax(self.y, 1)) child_accuracy = tf.equal(tf.argmax(output_layer, 1), tf.argmax(self.y, 1))
self.accuracy = tf.reduce_mean(tf.cast(child_accuracy, "float")) # add a reduce_mean self.accuracy = tf.reduce_mean(tf.cast(child_accuracy, "float")) # add a reduce_mean
def download_mnist_retry(data_dir, max_num_retries=20):
"""Try to download mnist dataset and avoid errors"""
for _ in range(max_num_retries):
try:
return input_data.read_data_sets(data_dir, one_hot=True)
except tf.errors.AlreadyExistsError:
time.sleep(1)
raise Exception("Failed to download MNIST.")
def main(params): def main(params):
# Import data # Import data
mnist = input_data.read_data_sets(params['data_dir'], one_hot=True) mnist = download_mnist_retry(params['data_dir'])
# Create the model # Create the model
# Build the graph for the deep net # Build the graph for the deep net
mnist_network = MnistNetwork(params) mnist_network = MnistNetwork(params)
...@@ -117,15 +120,15 @@ def main(params): ...@@ -117,15 +120,15 @@ def main(params):
for i in range(params['batch_num']): for i in range(params['batch_num']):
batch = mnist.train.next_batch(params['batch_size']) batch = mnist.train.next_batch(params['batch_size'])
mnist_network.train_step.run(feed_dict={mnist_network.x: batch[0], mnist_network.y: batch[1]}) mnist_network.train_step.run(feed_dict={mnist_network.x: batch[0], mnist_network.y: batch[1]})
if i % 100 == 0: if i % 100 == 0:
train_accuracy = mnist_network.accuracy.eval(feed_dict={ train_accuracy = mnist_network.accuracy.eval(feed_dict={
mnist_network.x: batch[0], mnist_network.y: batch[1]}) mnist_network.x: batch[0], mnist_network.y: batch[1]})
print('step %d, training accuracy %g' % (i, train_accuracy)) print('step %d, training accuracy %g' % (i, train_accuracy))
test_acc = mnist_network.accuracy.eval(feed_dict={ test_acc = mnist_network.accuracy.eval(feed_dict={
mnist_network.x: mnist.test.images, mnist_network.y: mnist.test.labels}) mnist_network.x: mnist.test.images, mnist_network.y: mnist.test.labels})
nni.report_final_result(test_acc) nni.report_final_result(test_acc)
def generate_defualt_params(): def generate_defualt_params():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment