Unverified Commit 9fb25ccc authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #189 from microsoft/master

merge master
parents 1500458a 7c4bc33b
......@@ -30,7 +30,6 @@
tr{
text-align: center;
color:#212121;
font-family: 'Segoe';
font-size: 14px;
/* background-color: #f2f2f2; */
}
......@@ -38,7 +37,6 @@
padding: 2px;
background-color:white !important;
font-size: 14px;
font-family: 'Segoe';
color: #808080;
border-bottom: 1px solid #d0d0d0;
text-align: center;
......
......@@ -95,6 +95,12 @@
dependencies:
"@types/react" "*"
"@types/react-responsive@^3.0.3":
version "3.0.3"
resolved "https://registry.yarnpkg.com/@types/react-responsive/-/react-responsive-3.0.3.tgz#a31b599c7cfe4135c5cc2f45d0b71df64803b23f"
dependencies:
"@types/react" "*"
"@types/react-router@3.0.15":
version "3.0.15"
resolved "http://registry.npmjs.org/@types/react-router/-/react-router-3.0.15.tgz#b55b0dc5ad8f6fa66b609f0efc390b191381d082"
......@@ -1898,6 +1904,12 @@ copy-descriptor@^0.1.0:
version "0.1.1"
resolved "https://registry.yarnpkg.com/copy-descriptor/-/copy-descriptor-0.1.1.tgz#676f6eb3c39997c2ee1ac3a924fd6124748f578d"
copy-to-clipboard@^3.0.8:
version "3.2.0"
resolved "https://registry.yarnpkg.com/copy-to-clipboard/-/copy-to-clipboard-3.2.0.tgz#d2724a3ccbfed89706fac8a894872c979ac74467"
dependencies:
toggle-selection "^1.0.6"
core-js@^1.0.0:
version "1.2.7"
resolved "https://registry.yarnpkg.com/core-js/-/core-js-1.2.7.tgz#652294c14651db28fa93bd2d5ff2983a4f08c636"
......@@ -2051,6 +2063,10 @@ css-loader@0.28.7:
postcss-value-parser "^3.3.0"
source-list-map "^2.0.0"
css-mediaquery@^0.1.2:
version "0.1.2"
resolved "https://registry.yarnpkg.com/css-mediaquery/-/css-mediaquery-0.1.2.tgz#6a2c37344928618631c54bd33cedd301da18bea0"
css-select@^1.1.0:
version "1.2.0"
resolved "https://registry.yarnpkg.com/css-select/-/css-select-1.2.0.tgz#2b3a110539c5355f1cd8d314623e870b121ec858"
......@@ -3308,11 +3324,11 @@ handle-thing@^1.2.5:
version "1.2.5"
resolved "https://registry.yarnpkg.com/handle-thing/-/handle-thing-1.2.5.tgz#fd7aad726bf1a5fd16dfc29b2f7a6601d27139c4"
handlebars@^4.0.3:
version "4.0.12"
resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.0.12.tgz#2c15c8a96d46da5e266700518ba8cb8d919d5bc5"
handlebars@^4.0.3, handlebars@^4.1.0:
version "4.1.2"
resolved "https://registry.yarnpkg.com/handlebars/-/handlebars-4.1.2.tgz#b6b37c1ced0306b221e094fc7aca3ec23b131b67"
dependencies:
async "^2.5.0"
neo-async "^2.6.0"
optimist "^0.6.1"
source-map "^0.6.1"
optionalDependencies:
......@@ -3577,6 +3593,10 @@ https-browserify@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/https-browserify/-/https-browserify-1.0.0.tgz#ec06c10e0a34c0f2faf199f7fd7fc78fffd03c73"
hyphenate-style-name@^1.0.0:
version "1.0.3"
resolved "https://registry.yarnpkg.com/hyphenate-style-name/-/hyphenate-style-name-1.0.3.tgz#097bb7fa0b8f1a9cf0bd5c734cf95899981a9b48"
iconv-lite@0.4.23:
version "0.4.23"
resolved "https://registry.yarnpkg.com/iconv-lite/-/iconv-lite-0.4.23.tgz#297871f63be507adcfbfca715d0cd0eed84e9a63"
......@@ -4737,7 +4757,7 @@ longest@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/longest/-/longest-1.0.1.tgz#30a0b2da38f73770e8294a0d22e6625ed77d0097"
loose-envify@^1.0.0, loose-envify@^1.1.0, loose-envify@^1.2.0, loose-envify@^1.3.1:
loose-envify@^1.0.0, loose-envify@^1.1.0, loose-envify@^1.2.0, loose-envify@^1.3.1, loose-envify@^1.4.0:
version "1.4.0"
resolved "https://registry.yarnpkg.com/loose-envify/-/loose-envify-1.4.0.tgz#71ee51fa7be4caec1a63839f7e682d8132d30caf"
dependencies:
......@@ -4777,6 +4797,12 @@ makeerror@1.0.x:
dependencies:
tmpl "1.0.x"
map-age-cleaner@^0.1.1:
version "0.1.3"
resolved "https://registry.yarnpkg.com/map-age-cleaner/-/map-age-cleaner-0.1.3.tgz#7d583a7306434c055fe474b0f45078e6e1b4b92a"
dependencies:
p-defer "^1.0.0"
map-cache@^0.2.2:
version "0.2.2"
resolved "https://registry.yarnpkg.com/map-cache/-/map-cache-0.2.2.tgz#c32abd0bd6525d9b051645bb4f26ac5dc98a0dbf"
......@@ -4791,6 +4817,12 @@ map-visit@^1.0.0:
dependencies:
object-visit "^1.0.0"
matchmediaquery@^0.3.0:
version "0.3.0"
resolved "https://registry.yarnpkg.com/matchmediaquery/-/matchmediaquery-0.3.0.tgz#6f672bcdbc44de16825c6917fbcdcfb9b82607b1"
dependencies:
css-mediaquery "^0.1.2"
math-expression-evaluator@^1.2.14:
version "1.2.17"
resolved "https://registry.yarnpkg.com/math-expression-evaluator/-/math-expression-evaluator-1.2.17.tgz#de819fdbcd84dccd8fae59c6aeb79615b9d266ac"
......@@ -4811,11 +4843,13 @@ media-typer@0.3.0:
version "0.3.0"
resolved "https://registry.yarnpkg.com/media-typer/-/media-typer-0.3.0.tgz#8710d7af0aa626f8fffa1ce00168545263255748"
mem@^1.1.0:
version "1.1.0"
resolved "https://registry.yarnpkg.com/mem/-/mem-1.1.0.tgz#5edd52b485ca1d900fe64895505399a0dfa45f76"
mem@^1.1.0, mem@^4.0.0:
version "4.3.0"
resolved "https://registry.yarnpkg.com/mem/-/mem-4.3.0.tgz#461af497bc4ae09608cdb2e60eefb69bff744178"
dependencies:
mimic-fn "^1.0.0"
map-age-cleaner "^0.1.1"
mimic-fn "^2.0.0"
p-is-promise "^2.0.0"
memory-fs@^0.4.0, memory-fs@~0.4.1:
version "0.4.1"
......@@ -4922,6 +4956,10 @@ mimic-fn@^1.0.0:
version "1.2.0"
resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-1.2.0.tgz#820c86a39334640e99516928bd03fca88057d022"
mimic-fn@^2.0.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/mimic-fn/-/mimic-fn-2.1.0.tgz#7ed2c2ccccaf84d3ffcb7a69b57711fc2083401b"
mini-store@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/mini-store/-/mini-store-2.0.0.tgz#0843c048d6942ce55e3e78b1b67fc063022b5488"
......@@ -5008,6 +5046,10 @@ moment@2.x, moment@^2.22.2:
version "2.22.2"
resolved "https://registry.yarnpkg.com/moment/-/moment-2.22.2.tgz#3c257f9839fc0e93ff53149632239eb90783ff66"
monaco-editor@^0.15.1:
version "0.15.6"
resolved "https://registry.yarnpkg.com/monaco-editor/-/monaco-editor-0.15.6.tgz#d63b3b06f86f803464f003b252627c3eb4a09483"
move-concurrently@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/move-concurrently/-/move-concurrently-1.0.1.tgz#be2c005fda32e0b29af1f05d7c4b33214c701f92"
......@@ -5086,6 +5128,10 @@ neo-async@^2.5.0:
version "2.6.0"
resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.0.tgz#b9d15e4d71c6762908654b5183ed38b753340835"
neo-async@^2.6.0:
version "2.6.1"
resolved "https://registry.yarnpkg.com/neo-async/-/neo-async-2.6.1.tgz#ac27ada66167fa8849a6addd837f6b189ad2081c"
next-tick@1:
version "1.0.0"
resolved "https://registry.yarnpkg.com/next-tick/-/next-tick-1.0.0.tgz#ca86d1fe8828169b0120208e3dc8424b9db8342c"
......@@ -5393,10 +5439,18 @@ osenv@^0.1.4:
os-homedir "^1.0.0"
os-tmpdir "^1.0.0"
p-defer@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/p-defer/-/p-defer-1.0.0.tgz#9f6eb182f6c9aa8cd743004a7d4f96b196b0fb0c"
p-finally@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/p-finally/-/p-finally-1.0.0.tgz#3fbcfb15b899a44123b34b6dcc18b724336a2cae"
p-is-promise@^2.0.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/p-is-promise/-/p-is-promise-2.1.0.tgz#918cebaea248a62cf7ffab8e3bca8c5f882fc42e"
p-limit@^1.1.0:
version "1.3.0"
resolved "https://registry.yarnpkg.com/p-limit/-/p-limit-1.3.0.tgz#b86bd5f0c25690911c7590fcbfc2010d54b3ccb8"
......@@ -6198,6 +6252,14 @@ prop-types@15.x, prop-types@^15.5.0, prop-types@^15.5.10, prop-types@^15.5.4, pr
loose-envify "^1.3.1"
object-assign "^4.1.1"
prop-types@^15.6.1:
version "15.7.2"
resolved "https://registry.yarnpkg.com/prop-types/-/prop-types-15.7.2.tgz#52c41e75b8c87e72b9d9360e0206b99dcbffa6c5"
dependencies:
loose-envify "^1.4.0"
object-assign "^4.1.1"
react-is "^16.8.1"
proxy-addr@~2.0.4:
version "2.0.4"
resolved "https://registry.yarnpkg.com/proxy-addr/-/proxy-addr-2.0.4.tgz#ecfc733bf22ff8c6f407fa275327b9ab67e48b93"
......@@ -6741,19 +6803,23 @@ react-dev-utils@^5.0.1:
strip-ansi "3.0.1"
text-table "0.2.0"
react-dom@^16.4.2:
version "16.5.2"
resolved "https://registry.yarnpkg.com/react-dom/-/react-dom-16.5.2.tgz#b69ee47aa20bab5327b2b9d7c1fe2a30f2cfa9d7"
react-dom@^16.7.0-alpha.2:
version "16.8.6"
resolved "https://registry.yarnpkg.com/react-dom/-/react-dom-16.8.6.tgz#71d6303f631e8b0097f56165ef608f051ff6e10f"
dependencies:
loose-envify "^1.1.0"
object-assign "^4.1.1"
prop-types "^15.6.2"
schedule "^0.5.0"
scheduler "^0.13.6"
react-error-overlay@^4.0.1:
version "4.0.1"
resolved "https://registry.yarnpkg.com/react-error-overlay/-/react-error-overlay-4.0.1.tgz#417addb0814a90f3a7082eacba7cee588d00da89"
react-is@^16.8.1:
version "16.8.6"
resolved "https://registry.yarnpkg.com/react-is/-/react-is-16.8.6.tgz#5bbc1e2d29141c9fbdfed456343fe2bc430a6a16"
react-json-tree@^0.11.0:
version "0.11.0"
resolved "https://registry.yarnpkg.com/react-json-tree/-/react-json-tree-0.11.0.tgz#f5b17e83329a9c76ae38be5c04fda3a7fd684a35"
......@@ -6775,6 +6841,22 @@ react-lifecycles-compat@^3.0.2, react-lifecycles-compat@^3.0.4:
version "3.0.4"
resolved "https://registry.yarnpkg.com/react-lifecycles-compat/-/react-lifecycles-compat-3.0.4.tgz#4f1a273afdfc8f3488a8c516bfda78f872352362"
react-monaco-editor@^0.22.0:
version "0.22.0"
resolved "https://registry.yarnpkg.com/react-monaco-editor/-/react-monaco-editor-0.22.0.tgz#2ba4c9557d2e95bb0f097a56f5e5d30598f7f2f9"
dependencies:
"@types/react" "*"
monaco-editor "^0.15.1"
prop-types "^15.6.2"
react-responsive@^7.0.0:
version "7.0.0"
resolved "https://registry.yarnpkg.com/react-responsive/-/react-responsive-7.0.0.tgz#0abde0ccbb50e5e8407e3d61dd4696447e7ebd3c"
dependencies:
hyphenate-style-name "^1.0.0"
matchmediaquery "^0.3.0"
prop-types "^15.6.1"
react-router@3.2.1:
version "3.2.1"
resolved "http://registry.npmjs.org/react-router/-/react-router-3.2.1.tgz#b9a3279962bdfbe684c8bd0482b81ef288f0f244"
......@@ -6848,14 +6930,14 @@ react-slick@~0.23.1:
lodash.debounce "^4.0.8"
resize-observer-polyfill "^1.5.0"
react@^16.4.2:
version "16.5.2"
resolved "https://registry.yarnpkg.com/react/-/react-16.5.2.tgz#19f6b444ed139baa45609eee6dc3d318b3895d42"
react@^16.7.0-alpha.2:
version "16.8.6"
resolved "https://registry.yarnpkg.com/react/-/react-16.8.6.tgz#ad6c3a9614fd3a4e9ef51117f54d888da01f2bbe"
dependencies:
loose-envify "^1.1.0"
object-assign "^4.1.1"
prop-types "^15.6.2"
schedule "^0.5.0"
scheduler "^0.13.6"
read-pkg-up@^1.0.1:
version "1.0.1"
......@@ -7276,10 +7358,11 @@ sax@^1.2.4, sax@~1.2.1:
version "1.2.4"
resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9"
schedule@^0.5.0:
version "0.5.0"
resolved "https://registry.yarnpkg.com/schedule/-/schedule-0.5.0.tgz#c128fffa0b402488b08b55ae74bb9df55cc29cc8"
scheduler@^0.13.6:
version "0.13.6"
resolved "https://registry.yarnpkg.com/scheduler/-/scheduler-0.13.6.tgz#466a4ec332467b31a91b9bf74e5347072e4cd889"
dependencies:
loose-envify "^1.1.0"
object-assign "^4.1.1"
schema-utils@^0.3.0:
......@@ -7976,6 +8059,10 @@ to-regex@^3.0.1, to-regex@^3.0.2:
regex-not "^1.0.2"
safe-regex "^1.1.0"
toggle-selection@^1.0.6:
version "1.0.6"
resolved "https://registry.yarnpkg.com/toggle-selection/-/toggle-selection-1.0.6.tgz#6e45b1263f2017fa0acc7d89d78b15b8bf77da32"
toposort@^1.0.0:
version "1.0.7"
resolved "https://registry.yarnpkg.com/toposort/-/toposort-1.0.7.tgz#2e68442d9f64ec720b8cc89e6443ac6caa950029"
......
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 2
trialConcurrency: 1
searchSpacePath: search_space.json
tuner:
builtinTunerName: GPTuner
classArgs:
optimize_mode: maximize
utility: 'ei'
kappa: 5.0
xi: 0.0
nu: 2.5
alpha: 1e-6
cold_start_num: 10
selection_num_warm_up: 100000
selection_num_starting_points: 250
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ../../../examples/trials/mnist
command: python3 mnist.py --batch_num 100
gpuNum: 0
useAnnotation: false
multiPhase: false
multiThread: false
trainingServicePlatform: local
......@@ -15,6 +15,6 @@
## 问题
* 使用了私有 API 来检测是否 Tuner 和 Assessor 成功结束。
* 使用了私有 API 来检测是否 Tuner 和 Assessor 成功结束。
* RESTful 服务的输出未测试。
* 远程计算机训练服务没有测试。
\ No newline at end of file
......@@ -18,7 +18,7 @@ jobs:
displayName: 'generate config files'
- script: |
cd test
python config_test.py --ts local --local_gpu --exclude smac,bohb
python config_test.py --ts local --local_gpu --exclude smac,bohb,multi_phase_batch,multi_phase_grid
displayName: 'Examples and advanced features tests on local machine'
- script: |
cd test
......
......@@ -31,7 +31,7 @@ jobs:
displayName: 'Built-in tuners / assessors tests'
- script: |
cd test
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts local --local_gpu
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts local --local_gpu --exclude multi_phase_batch,multi_phase_grid
displayName: 'Examples and advanced features tests on local machine'
- script: |
cd test
......
......@@ -65,5 +65,5 @@ jobs:
python --version
python generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) --nni_docker_image $(docker_image) --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
python config_test.py --ts pai --exclude multi_phase,smac,bohb
python config_test.py --ts pai --exclude multi_phase,smac,bohb,multi_phase_batch,multi_phase_grid
displayName: 'Examples and advanced features tests on pai'
\ No newline at end of file
......@@ -75,5 +75,5 @@ jobs:
python3 generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) \
--nni_docker_image $TEST_IMG --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai --exclude multi_phase
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai --exclude multi_phase_batch,multi_phase_grid
displayName: 'integration test'
......@@ -39,7 +39,7 @@ jobs:
cd test
python generate_ts_config.py --ts remote --remote_user $(docker_user) --remote_host $(remote_host) --remote_port $(Get-Content port) --remote_pwd $(docker_pwd) --nni_manager_ip $(nni_manager_ip)
Get-Content training_service.yml
python config_test.py --ts remote --exclude cifar10,smac,bohb
python config_test.py --ts remote --exclude cifar10,smac,bohb,multi_phase_batch,multi_phase_grid
displayName: 'integration test'
- task: SSH@0
inputs:
......
......@@ -52,7 +52,7 @@ jobs:
python3 generate_ts_config.py --ts remote --remote_user $(docker_user) --remote_host $(remote_host) \
--remote_port $(cat port) --remote_pwd $(docker_pwd) --nni_manager_ip $(nni_manager_ip)
cat training_service.yml
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts remote --exclude cifar10
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts remote --exclude cifar10,multi_phase_batch,multi_phase_grid
displayName: 'integration test'
- task: SSH@0
inputs:
......
......@@ -27,7 +27,7 @@ NNI 中,有 4 种类型的 Annotation;
**参数**
- **sampling_algo**: 指定搜索空间的采样算法。 可将其换成 NNI 支持的其它采样函数,函数要以 `nni.` 开头。例如,`choice``uniform`,详见 [SearchSpaceSpec](https://nni.readthedocs.io/zh/latest/SearchSpaceSpec.html)
- **sampling_algo**: 指定搜索空间的采样算法。 可将其换成 NNI 支持的其它采样函数,函数要以 `nni.` 开头。例如,`choice`、`uniform`,详见 [SearchSpaceSpec](https://nni.readthedocs.io/zh/latest/SearchSpaceSpec.html)
- **name**: 将被赋值的变量名称。 注意,此参数应该与下面一行等号左边的值相同。
NNI 支持如下 10 种类型来表示搜索空间:
......
......@@ -76,11 +76,12 @@ def _generate_file_search_space(path, module):
return search_space
def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
def expand_annotations(src_dir, dst_dir, exp_id='', trial_id='', nas_mode=None):
"""Expand annotations in user code.
Return dst_dir if annotation detected; return src_dir if not.
src_dir: directory path of user code (str)
dst_dir: directory to place generated files (str)
nas_mode: the mode of NAS given that NAS interface is used
"""
if src_dir[-1] == slash:
src_dir = src_dir[:-1]
......@@ -108,7 +109,7 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
dst_path = os.path.join(dst_subdir, file_name)
if file_name.endswith('.py'):
if trial_id == '':
annotated |= _expand_file_annotations(src_path, dst_path)
annotated |= _expand_file_annotations(src_path, dst_path, nas_mode)
else:
module = package + file_name[:-3]
annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module)
......@@ -120,10 +121,10 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
return dst_dir if annotated else src_dir
def _expand_file_annotations(src_path, dst_path):
def _expand_file_annotations(src_path, dst_path, nas_mode):
with open(src_path) as src, open(dst_path, 'w') as dst:
try:
annotated_code = code_generator.parse(src.read())
annotated_code = code_generator.parse(src.read(), nas_mode)
if annotated_code is None:
shutil.copyfile(src_path, dst_path)
return False
......
......@@ -21,14 +21,14 @@
import ast
import astor
from nni_cmd.common_utils import print_warning
# pylint: disable=unidiomatic-typecheck
def parse_annotation_mutable_layers(code, lineno):
def parse_annotation_mutable_layers(code, lineno, nas_mode):
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
code: annotation string (excluding '@')
nas_mode: the mode of NAS
"""
module = ast.parse(code)
assert type(module) is ast.Module, 'internal error #1'
......@@ -110,6 +110,9 @@ def parse_annotation_mutable_layers(code, lineno):
else:
target_call_args.append(ast.Dict(keys=[], values=[]))
target_call_args.append(ast.Num(n=0))
target_call_args.append(ast.Str(s=nas_mode))
if nas_mode in ['enas_mode', 'oneshot_mode']:
target_call_args.append(ast.Name(id='tensorflow'))
target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[])
node = ast.Assign(targets=[layer_output], value=target_call)
nodes.append(node)
......@@ -277,10 +280,11 @@ class FuncReplacer(ast.NodeTransformer):
class Transformer(ast.NodeTransformer):
"""Transform original code to annotated code"""
def __init__(self):
def __init__(self, nas_mode=None):
self.stack = []
self.last_line = 0
self.annotated = False
self.nas_mode = nas_mode
def visit(self, node):
if isinstance(node, (ast.expr, ast.stmt)):
......@@ -316,8 +320,11 @@ class Transformer(ast.NodeTransformer):
return node # not an annotation, ignore it
if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
print_warning(deprecated_message)
call_node = parse_annotation(string[1:]).value
if call_node.args:
# it is used in enas mode as it needs to retrieve the next subgraph for training
call_attr = ast.Attribute(value=ast.Name(id='nni', ctx=ast.Load()), attr='reload_tensorflow_variables', ctx=ast.Load())
return ast.Expr(value=ast.Call(func=call_attr, args=call_node.args, keywords=[]))
if string.startswith('@nni.report_intermediate_result') \
or string.startswith('@nni.report_final_result') \
......@@ -325,7 +332,8 @@ class Transformer(ast.NodeTransformer):
return parse_annotation(string[1:]) # expand annotation string to code
if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno)
nodes = parse_annotation_mutable_layers(string[1:], node.lineno, self.nas_mode)
return nodes
if string.startswith('@nni.variable') \
or string.startswith('@nni.function_choice'):
......@@ -343,17 +351,18 @@ class Transformer(ast.NodeTransformer):
return node
def parse(code):
def parse(code, nas_mode=None):
"""Annotate user code.
Return annotated code (str) if annotation detected; return None if not.
code: original user code (str)
code: original user code (str),
nas_mode: the mode of NAS given that NAS interface is used
"""
try:
ast_tree = ast.parse(code)
except Exception:
raise RuntimeError('Bad Python code')
transformer = Transformer()
transformer = Transformer(nas_mode)
try:
transformer.visit(ast_tree)
except AssertionError as exc:
......@@ -369,5 +378,9 @@ def parse(code):
if type(nodes[i]) is ast.ImportFrom and nodes[i].module == '__future__':
last_future_import = i
nodes.insert(last_future_import + 1, import_nni)
# enas and oneshot modes for tensorflow need tensorflow module, so we import it here
if nas_mode in ['enas_mode', 'oneshot_mode']:
import_tf = ast.Import(names=[ast.alias(name='tensorflow', asname=None)])
nodes.insert(last_future_import + 1, import_tf)
return astor.to_source(ast_tree)
......@@ -104,6 +104,21 @@ tuner_schema_dict = {
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
'GPTuner': {
'builtinTunerName': 'GPTuner',
'classArgs': {
Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('utility'): setChoice('utility', 'ei', 'ucb', 'poi'),
Optional('kappa'): setType('kappa', float),
Optional('xi'): setType('xi', float),
Optional('nu'): setType('nu', float),
Optional('alpha'): setType('alpha', float),
Optional('cold_start_num'): setType('cold_start_num', int),
Optional('selection_num_warm_up'): setType('selection_num_warm_up', int),
Optional('selection_num_starting_points'): setType('selection_num_starting_points', int),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
'customized': {
'codeDir': setPathCheck('codeDir'),
'classFileName': setType('classFileName', str),
......@@ -181,7 +196,8 @@ common_trial_schema = {
'trial':{
'command': setType('command', str),
'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999)
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
Optional('nasMode'): setChoice('classic_mode', 'enas_mode', 'oneshot_mode')
}
}
......@@ -199,6 +215,7 @@ pai_trial_schema = {
Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\
error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
Optional('virtualCluster'): setType('virtualCluster', str),
Optional('nasMode'): setChoice('classic_mode', 'enas_mode', 'oneshot_mode')
}
}
......@@ -213,6 +230,7 @@ pai_config_schema = {
kubeflow_trial_schema = {
'trial':{
'codeDir': setPathCheck('codeDir'),
Optional('nasMode'): setChoice('classic_mode', 'enas_mode', 'oneshot_mode'),
Optional('ps'): {
'replicas': setType('replicas', int),
'command': setType('command', str),
......
......@@ -377,7 +377,8 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if not os.path.isdir(path):
os.makedirs(path)
path = tempfile.mkdtemp(dir=path)
code_dir = expand_annotations(experiment_config['trial']['codeDir'], path)
nas_mode = experiment_config['trial'].get('nasMode', 'classic_mode')
code_dir = expand_annotations(experiment_config['trial']['codeDir'], path, nas_mode=nas_mode)
experiment_config['trial']['codeDir'] = code_dir
search_space = generate_search_space(code_dir)
experiment_config['searchSpace'] = json.dumps(search_space)
......
......@@ -119,8 +119,21 @@ def parse_args():
parser_experiment_status.add_argument('id', nargs='?', help='the id of experiment')
parser_experiment_status.set_defaults(func=experiment_status)
parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids')
parser_experiment_list.add_argument('all', nargs='?', help='list all of experiments')
parser_experiment_list.add_argument('--all', action='store_true', default=False, help='list all of experiments')
parser_experiment_list.set_defaults(func=experiment_list)
parser_experiment_clean = parser_experiment_subparsers.add_parser('delete', help='clean up the experiment data')
parser_experiment_clean.add_argument('id', nargs='?', help='the id of experiment')
parser_experiment_clean.add_argument('--all', action='store_true', default=False, help='delete all of experiments')
parser_experiment_clean.set_defaults(func=experiment_clean)
#parse experiment command
parser_platform = subparsers.add_parser('platform', help='get platform information')
#add subparsers for parser_experiment
parser_platform_subparsers = parser_platform.add_subparsers()
parser_platform_clean = parser_platform_subparsers.add_parser('clean', help='clean up the platform data')
parser_platform_clean.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
parser_platform_clean.set_defaults(func=platform_clean)
#import tuning data
parser_import_data = parser_experiment_subparsers.add_parser('import', help='import additional data')
parser_import_data.add_argument('id', nargs='?', help='the id of experiment')
......
......@@ -24,6 +24,10 @@ import psutil
import json
import datetime
import time
import re
from pathlib import Path
from pyhdfs import HdfsClient, HdfsFileNotFoundException
import shutil
from subprocess import call, check_output
from nni_annotation import expand_annotations
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
......@@ -31,8 +35,9 @@ from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_
from .config_utils import Config, Experiments
from .constants import NNICTL_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
from .common_utils import print_normal, print_error, print_warning, detect_process
from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content
from .command_utils import check_output_command, kill_command
from .ssh_utils import create_ssh_sftp_client, remove_remote_directory
def get_experiment_time(port):
'''get the startTime and endTime of an experiment'''
......@@ -73,10 +78,11 @@ def update_experiment():
if status:
experiment_config.update_experiment(key, 'status', status)
def check_experiment_id(args):
def check_experiment_id(args, update=True):
'''check if the id is valid
'''
update_experiment()
if update:
update_experiment()
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict:
......@@ -100,14 +106,14 @@ def check_experiment_id(args):
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1)
elif not running_experiment_list:
print_error('There is no experiment running!')
print_error('There is no experiment running.')
return None
else:
return running_experiment_list[0]
if experiment_dict.get(args.id):
return args.id
else:
print_error('Id not correct!')
print_error('Id not correct.')
return None
def parse_ids(args):
......@@ -145,7 +151,7 @@ def parse_ids(args):
exit(1)
else:
result_list = running_experiment_list
elif args.id == 'all':
elif args.all:
result_list = running_experiment_list
elif args.id.endswith('*'):
for id in running_experiment_list:
......@@ -170,7 +176,7 @@ def get_config_filename(args):
'''get the file name of config file'''
experiment_id = check_experiment_id(args)
if experiment_id is None:
print_error('Please set the experiment id!')
print_error('Please set correct experiment id.')
exit(1)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
......@@ -180,7 +186,7 @@ def get_experiment_port(args):
'''get the port of experiment'''
experiment_id = check_experiment_id(args)
if experiment_id is None:
print_error('Please set the experiment id!')
print_error('Please set correct experiment id.')
exit(1)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
......@@ -229,7 +235,7 @@ def stop_experiment(args):
except Exception as exception:
print_error(exception)
nni_config.set_config('tensorboardPidList', [])
print_normal('Stop experiment success!')
print_normal('Stop experiment success.')
experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
time_now = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
experiment_config.update_experiment(experiment_id, 'endTime', str(time_now))
......@@ -354,10 +360,10 @@ def log_trial(args):
if trial_id_path_dict.get(args.trial_id):
print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
else:
print_error('trial id is not valid!')
print_error('trial id is not valid.')
exit(1)
else:
print_error('please specific the trial id!')
print_error('please specific the trial id.')
exit(1)
else:
for key in trial_id_path_dict:
......@@ -373,16 +379,179 @@ def webui_url(args):
nni_config = Config(get_config_filename(args))
print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl'))))
def local_clean(directory):
    '''Remove a local directory tree.

    directory: path of the folder to delete (str)

    A folder that does not exist is reported through print_error and the
    function returns normally; other errors raised by shutil.rmtree propagate.
    '''
    print_normal('removing folder {0}'.format(directory))
    try:
        shutil.rmtree(directory)
    except FileNotFoundError:
        # Nothing to delete: report it instead of aborting the caller.
        print_error('{0} does not exist.'.format(directory))
def remote_clean(machine_list, experiment_id=None):
    '''Delete NNI experiment folders on each remote machine.

    machine_list: iterable of machine descriptions (dicts); the 'ip', 'port',
        'username' and 'passwd' entries are read from each one
    experiment_id: when set, only /tmp/nni/experiments/<experiment_id> is removed;
        otherwise the whole /tmp/nni/experiments folder is removed
    '''
    for machine in machine_list:
        password = machine.get('passwd')
        user = machine.get('username')
        address = machine.get('ip')
        ssh_port = machine.get('port')
        # Build the remote path from its segments; the experiment id is optional.
        segments = ['tmp', 'nni', 'experiments']
        if experiment_id:
            segments.append(experiment_id)
        target_dir = '/' + '/'.join(segments)
        sftp_client = create_ssh_sftp_client(address, ssh_port, user, password)
        print_normal('removing folder {0}'.format(address + ':' + str(ssh_port) + target_dir))
        remove_remote_directory(sftp_client, target_dir)
def hdfs_clean(host, user_name, output_dir, experiment_id=None):
    '''Delete NNI experiment data stored on HDFS.

    host: address of the HDFS server; WebHDFS is contacted on port 80 of this host
    user_name: HDFS user; data is removed from /<user_name>/nni/experiments
    output_dir: optional trial output location of the form hdfs://<ip>[:port][/path];
        its path is cleaned too, but only when its host equals `host`
    experiment_id: when set, only that experiment's sub-folder is removed in both locations
    '''
    hdfs_client = HdfsClient(hosts='{0}:80'.format(host), user_name=user_name, webhdfs_path='/webhdfs/api/v1', timeout=5)
    path_parts = [user_name, 'nni', 'experiments']
    if experiment_id:
        path_parts.append(experiment_id)
    full_path = '/' + '/'.join(path_parts)
    print_normal('removing folder {0} in hdfs'.format(full_path))
    hdfs_client.delete(full_path, recursive=True)

    if not output_dir:
        return
    # Raw string with escaped dots, so '.' only matches a literal dot of the address.
    pattern = re.compile(r'hdfs://(?P<host>([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?')
    match_result = pattern.match(output_dir)
    if not match_result:
        return
    output_host = match_result.group('host')
    base_dir = match_result.group('baseDir')
    # Only touch the output folder when it lives on the server we are cleaning.
    if output_host != host:
        # Report the full URL: it is the part that carries the mismatching host.
        print_warning('The host in {0} is not consistent with {1}'.format(output_dir, host))
        return
    if base_dir is None:
        # The path part of the URL is optional in the pattern; without it
        # there is no folder that could be deleted.
        print_warning('{0} does not contain a path, skip it.'.format(output_dir))
        return
    if experiment_id:
        base_dir = base_dir + '/' + experiment_id
    print_normal('removing folder {0} in hdfs'.format(base_dir))
    hdfs_client.delete(base_dir, recursive=True)
def experiment_clean(args):
    '''Delete the stored data and the metadata of one experiment, or of all experiments.

    args: parsed command line arguments; `args.all` selects every recorded
        experiment, otherwise `args.id` names the single experiment to delete

    The user is asked for confirmation first. The process exits when the id is
    missing or unknown, when the user declines, or when an experiment uses a
    platform whose cleanup is not supported.
    '''
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    if args.all:
        experiment_id_list = list(experiment_dict.keys())
    else:
        if args.id is None:
            print_error('please set experiment id.')
            exit(1)
        if args.id not in experiment_dict:
            print_error('Cannot find experiment {0}.'.format(args.id))
            exit(1)
        experiment_id_list = [args.id]

    # Ask until a recognisable answer is given; an empty answer means "no".
    while True:
        print('INFO: This action will delete experiment {0}, and it\'s not recoverable.'.format(' '.join(experiment_id_list)))
        answer = input('INFO: do you want to continue?[y/N]:').lower()
        if not answer or answer in ['n', 'no']:
            exit(0)
        if answer in ['y', 'yes']:
            break
        print_warning('please input Y or N.')

    for experiment_id in experiment_id_list:
        nni_config = Config(experiment_dict[experiment_id]['fileName'])
        experiment_content = nni_config.get_config('experimentConfig')
        platform = experiment_content.get('trainingServicePlatform')
        # The id stored in the experiment's own config is the one used for all paths below.
        experiment_id = nni_config.get_config('experimentId')
        if platform == 'remote':
            machine_list = experiment_content.get('machineList')
            remote_clean(machine_list, experiment_id)
        elif platform == 'pai':
            pai_config = experiment_content.get('paiConfig')
            host = pai_config.get('host')
            user_name = pai_config.get('userName')
            output_dir = experiment_content.get('trial').get('outputDir')
            hdfs_clean(host, user_name, output_dir, experiment_id)
        elif platform != 'local':
            #TODO: support all platforms
            # NOTE(review): this ends the whole command, so with --all the
            # remaining experiments are left untouched.
            print_warning('platform {0} clean up not supported yet.'.format(platform))
            exit(0)
        #clean local data
        # A configured logDir replaces the default ~/nni/experiments base folder;
        # the experiment's data sits in the <experiment_id> sub-folder of that base.
        # Removing only the sub-folder keeps other experiments that share the same
        # logDir intact. NOTE(review): confirm against nnimanager's log directory layout.
        base_dir = experiment_content.get('logDir')
        if not base_dir:
            base_dir = os.path.join(str(Path.home()), 'nni', 'experiments')
        local_clean(os.path.join(base_dir, experiment_id))
        # Re-read the experiment records before dropping this experiment's entry.
        experiment_config = Experiments()
        print_normal('removing metadata of experiment {0}'.format(experiment_id))
        experiment_config.remove_experiment(experiment_id)
    print_normal('Done.')
def get_platform_dir(config_content):
    '''Get the list of platform-side locations that a platform clean would delete.

    config_content is the parsed experiment config (a dict-like object).
    For the 'remote' platform one '<ip>:<port>/tmp/nni' entry is produced per
    machine in 'machineList'. For the 'pai' platform the list holds a
    'server: <host>, path: <userName>/nni' entry plus the trial 'outputDir'
    when one is configured. Any other platform yields an empty list.

    Sections that are absent from the config ('machineList', 'paiConfig',
    'trial') contribute nothing instead of raising.
    '''
    platform = config_content.get('trainingServicePlatform')
    dir_list = []
    if platform == 'remote':
        # `or []` lets a config without machineList produce an empty result
        for machine in config_content.get('machineList') or []:
            host = machine.get('ip')
            port = machine.get('port')
            dir_list.append(host + ':' + str(port) + '/tmp/nni')
    elif platform == 'pai':
        pai_config = config_content.get('paiConfig')
        if not pai_config:
            # without a paiConfig section there is no server to describe
            return dir_list
        host = pai_config.get('host')
        user_name = pai_config.get('userName')
        dir_list.append('server: {0}, path: {1}/nni'.format(host, user_name))
        output_dir = (config_content.get('trial') or {}).get('outputDir')
        if output_dir:
            dir_list.append(output_dir)
    return dir_list
def platform_clean(args):
    '''Clean up the NNI cache folders on the platform named by a config file.

    args.config is the path of an experiment config file. The function
    validates the path, reads the platform from the config, lists the folders
    that would be removed, asks the user to confirm, and then delegates to
    remote_clean (platform 'remote') or hdfs_clean (platform 'pai').

    The process is terminated through exit(): status 1 for a bad config path
    or when no folder is found, status 0 for the 'local' platform, an
    unsupported platform, or when the user declines.
    '''
    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        print_error('Please set correct config path.')
        exit(1)
    config_content = get_yml_content(config_path)
    platform = config_content.get('trainingServicePlatform')
    if platform == 'local':
        print_normal('it doesn’t need to clean local platform.')
        exit(0)
    if platform not in ['remote', 'pai']:
        print_normal('platform {0} not supported.'.format(platform))
        exit(0)
    update_experiment()
    dir_list = get_platform_dir(config_content)
    if not dir_list:
        print_normal('No folder of NNI caches is found.')
        exit(1)
    while True:
        print_normal('This command will remove below folders of NNI caches. If other users are using experiments on below hosts, it will be broken.')
        for folder in dir_list:
            print(' ' + folder)
        answer = input('INFO: do you want to continue?[y/N]:').lower()
        # an empty answer takes the default shown in the prompt, which is "N"
        if not answer or answer in ['n', 'no']:
            exit(0)
        elif answer not in ['y', 'n', 'yes', 'no']:
            print_warning('please input Y or N.')
        else:
            break
    if platform == 'remote':
        # remote_clean receives the whole machine list, so it is called once
        # rather than once per machine
        remote_clean(config_content.get('machineList'), None)
    elif platform == 'pai':
        pai_config = config_content.get('paiConfig')
        host = pai_config.get('host')
        user_name = pai_config.get('userName')
        # a config without a trial section simply has no output dir to clean
        output_dir = (config_content.get('trial') or {}).get('outputDir')
        hdfs_clean(host, user_name, output_dir, None)
    print_normal('Done.')
def experiment_list(args):
'''get the information of all experiments'''
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict:
print('There is no experiment running...')
print_normal('Cannot find experiments.')
exit(1)
update_experiment()
experiment_id_list = []
if args.all and args.all == 'all':
if args.all:
for key in experiment_dict.keys():
experiment_id_list.append(key)
else:
......@@ -390,10 +559,9 @@ def experiment_list(args):
if experiment_dict[key]['status'] != 'STOPPED':
experiment_id_list.append(key)
if not experiment_id_list:
print_warning('There is no experiment running...\nYou can use \'nnictl experiment list all\' to list all stopped experiments!')
print_warning('There is no experiment running...\nYou can use \'nnictl experiment list --all\' to list all stopped experiments.')
experiment_information = ""
for key in experiment_id_list:
experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], experiment_dict[key]['port'],\
experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
......
......@@ -57,3 +57,17 @@ def create_ssh_sftp_client(host_ip, port, username, password):
return sftp
except Exception as exception:
print_error('Create ssh client error %s\n' % exception)
def remove_remote_directory(sftp, directory):
    '''Delete `directory` and everything below it through the `sftp` client.

    Every entry returned by sftp.listdir is first passed to sftp.remove; when
    that raises IOError the entry is handled by a recursive call instead.
    Once the entries are processed the directory itself is removed with
    sftp.rmdir. An IOError from listing or from rmdir is reported with
    print_error and not re-raised.
    '''
    try:
        for entry in sftp.listdir(directory):
            target = directory + '/' + entry
            try:
                sftp.remove(target)
            except IOError:
                # remove() refused this path: descend into it instead
                remove_remote_directory(sftp, target)
        sftp.rmdir(directory)
    except IOError as error:
        print_error(error)
\ No newline at end of file
......@@ -67,7 +67,7 @@ def trial_jobs_url(port):
def trial_job_id_url(port, job_id):
    '''Get the trial-jobs URL for one job id.

    Joins BASE_URL, port, API_ROOT_URL and TRIAL_JOBS_API, then appends
    job_id as the last path segment.
    '''
    # job_id follows the slash directly; a ':' in front of it would become
    # part of the requested path
    return '{0}:{1}{2}{3}/{4}'.format(BASE_URL, port, API_ROOT_URL, TRIAL_JOBS_API, job_id)
def export_data_url(port):
......@@ -87,4 +87,4 @@ def get_local_urls(port):
for addr in info:
if AddressFamily.AF_INET == addr.family:
url_list.append('http://{}:{}'.format(addr.address, port))
return url_list
\ No newline at end of file
return url_list
......@@ -51,7 +51,7 @@ def get_hdfs_client(args):
return _hdfs_client
# backward compatibility
hdfs_host = None
hdfs_output_dir = None
if args.hdfs_host:
hdfs_host = args.hdfs_host
elif args.pai_hdfs_host:
......@@ -83,6 +83,8 @@ def main_loop(args):
# redirect trial keeper's stdout and stderr to syslog
trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout, args.log_collection)
sys.stdout = sys.stderr = trial_keeper_syslogger
hdfs_output_dir = None
if args.hdfs_output_dir:
hdfs_output_dir = args.hdfs_output_dir
elif args.pai_hdfs_output_dir:
......@@ -222,7 +224,7 @@ if __name__ == '__main__':
exit(1)
check_version(args)
try:
if is_multi_phase():
if NNI_PLATFORM == 'pai' and is_multi_phase():
fetch_parameter_file(args)
main_loop(args)
except SystemExit as se:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment