Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
c785655e
Unverified
Commit
c785655e
authored
Oct 21, 2019
by
SparkSnail
Committed by
GitHub
Oct 21, 2019
Browse files
Merge pull request #207 from microsoft/master
merge master
parents
9fae194a
d6b61e2f
Changes
158
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
175 additions
and
169 deletions
+175
-169
test/naive_test/local.yml
test/naive_test/local.yml
+0
-2
test/pipelines-it-kubeflow.yml
test/pipelines-it-kubeflow.yml
+0
-5
test/pipelines-it-local-windows.yml
test/pipelines-it-local-windows.yml
+1
-1
test/pipelines-it-local.yml
test/pipelines-it-local.yml
+3
-3
test/pipelines-it-pai-windows.yml
test/pipelines-it-pai-windows.yml
+1
-1
test/pipelines-it-pai.yml
test/pipelines-it-pai.yml
+1
-6
test/pipelines-it-remote-windows.yml
test/pipelines-it-remote-windows.yml
+1
-1
test/pipelines-it-remote.yml
test/pipelines-it-remote.yml
+1
-1
tools/nni_annotation/.gitignore
tools/nni_annotation/.gitignore
+1
-0
tools/nni_annotation/code_generator.py
tools/nni_annotation/code_generator.py
+6
-5
tools/nni_annotation/search_space_generator.py
tools/nni_annotation/search_space_generator.py
+1
-2
tools/nni_annotation/specific_code_generator.py
tools/nni_annotation/specific_code_generator.py
+16
-8
tools/nni_annotation/test_annotation.py
tools/nni_annotation/test_annotation.py
+8
-7
tools/nni_cmd/config_schema.py
tools/nni_cmd/config_schema.py
+29
-14
tools/nni_cmd/constants.py
tools/nni_cmd/constants.py
+2
-1
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+82
-101
tools/nni_cmd/nnictl.py
tools/nni_cmd/nnictl.py
+7
-1
tools/nni_cmd/nnictl_utils.py
tools/nni_cmd/nnictl_utils.py
+15
-10
No files found.
test/naive_test/local.yml
View file @
c785655e
...
@@ -14,14 +14,12 @@ tuner:
...
@@ -14,14 +14,12 @@ tuner:
className
:
NaiveTuner
className
:
NaiveTuner
classArgs
:
classArgs
:
optimize_mode
:
maximize
optimize_mode
:
maximize
gpuNum
:
0
assessor
:
assessor
:
codeDir
:
.
codeDir
:
.
classFileName
:
naive_assessor.py
classFileName
:
naive_assessor.py
className
:
NaiveAssessor
className
:
NaiveAssessor
classArgs
:
classArgs
:
optimize_mode
:
maximize
optimize_mode
:
maximize
gpuNum
:
0
trial
:
trial
:
command
:
python3 naive_trial.py
command
:
python3 naive_trial.py
codeDir
:
.
codeDir
:
.
...
...
test/pipelines-it-kubeflow.yml
View file @
c785655e
...
@@ -39,11 +39,6 @@ jobs:
...
@@ -39,11 +39,6 @@ jobs:
displayName
:
'
Install
nni
toolkit
via
source
code'
displayName
:
'
Install
nni
toolkit
via
source
code'
-
script
:
|
-
script
:
|
python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install torch==0.4.1 --user
python3 -m pip install torchvision==0.2.1 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install tensorflow==1.12.0 --user
sudo apt-get install swig -y
sudo apt-get install swig -y
PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
...
...
test/pipelines-it-local-windows.yml
View file @
c785655e
...
@@ -18,7 +18,7 @@ jobs:
...
@@ -18,7 +18,7 @@ jobs:
displayName
:
'
generate
config
files'
displayName
:
'
generate
config
files'
-
script
:
|
-
script
:
|
cd test
cd test
python config_test.py --ts local --local_gpu --exclude smac,bohb
,multi_phase_batch,multi_phase_grid
python config_test.py --ts local --local_gpu --exclude smac,bohb
displayName
:
'
Examples
and
advanced
features
tests
on
local
machine'
displayName
:
'
Examples
and
advanced
features
tests
on
local
machine'
-
script
:
|
-
script
:
|
cd test
cd test
...
...
test/pipelines-it-local.yml
View file @
c785655e
...
@@ -9,8 +9,8 @@ jobs:
...
@@ -9,8 +9,8 @@ jobs:
displayName
:
'
Install
nni
toolkit
via
source
code'
displayName
:
'
Install
nni
toolkit
via
source
code'
-
script
:
|
-
script
:
|
python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install torch==
0.4.1
--user
python3 -m pip install torch==
1.2.0
--user
python3 -m pip install torchvision==0.
2.1
--user
python3 -m pip install torchvision==0.
4.0
--user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install tensorflow-gpu==1.12.0 --user
python3 -m pip install tensorflow-gpu==1.12.0 --user
sudo apt-get install swig -y
sudo apt-get install swig -y
...
@@ -31,7 +31,7 @@ jobs:
...
@@ -31,7 +31,7 @@ jobs:
displayName
:
'
Built-in
tuners
/
assessors
tests'
displayName
:
'
Built-in
tuners
/
assessors
tests'
-
script
:
|
-
script
:
|
cd test
cd test
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts local --local_gpu
--exclude multi_phase_batch,multi_phase_grid
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts local --local_gpu
displayName
:
'
Examples
and
advanced
features
tests
on
local
machine'
displayName
:
'
Examples
and
advanced
features
tests
on
local
machine'
-
script
:
|
-
script
:
|
cd test
cd test
...
...
test/pipelines-it-pai-windows.yml
View file @
c785655e
...
@@ -65,5 +65,5 @@ jobs:
...
@@ -65,5 +65,5 @@ jobs:
python --version
python --version
python generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) --nni_docker_image $(docker_image) --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
python generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) --nni_docker_image $(docker_image) --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
python config_test.py --ts pai --exclude multi_phase,smac,bohb
,multi_phase_batch,multi_phase_grid
python config_test.py --ts pai --exclude multi_phase,smac,bohb
displayName
:
'
Examples
and
advanced
features
tests
on
pai'
displayName
:
'
Examples
and
advanced
features
tests
on
pai'
\ No newline at end of file
test/pipelines-it-pai.yml
View file @
c785655e
...
@@ -39,11 +39,6 @@ jobs:
...
@@ -39,11 +39,6 @@ jobs:
displayName
:
'
Install
nni
toolkit
via
source
code'
displayName
:
'
Install
nni
toolkit
via
source
code'
-
script
:
|
-
script
:
|
python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install torch==0.4.1 --user
python3 -m pip install torchvision==0.2.1 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install tensorflow-gpu==1.12.0 --user
sudo apt-get install swig -y
sudo apt-get install swig -y
PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
...
@@ -76,6 +71,6 @@ jobs:
...
@@ -76,6 +71,6 @@ jobs:
python3 generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) \
python3 generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) \
--nni_docker_image $TEST_IMG --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
--nni_docker_image $TEST_IMG --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai
--exclude multi_phase_batch,multi_phase_grid
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
displayName
:
'
integration
test'
displayName
:
'
integration
test'
test/pipelines-it-remote-windows.yml
View file @
c785655e
...
@@ -39,7 +39,7 @@ jobs:
...
@@ -39,7 +39,7 @@ jobs:
cd test
cd test
python generate_ts_config.py --ts remote --remote_user $(docker_user) --remote_host $(remote_host) --remote_port $(Get-Content port) --remote_pwd $(docker_pwd) --nni_manager_ip $(nni_manager_ip)
python generate_ts_config.py --ts remote --remote_user $(docker_user) --remote_host $(remote_host) --remote_port $(Get-Content port) --remote_pwd $(docker_pwd) --nni_manager_ip $(nni_manager_ip)
Get-Content training_service.yml
Get-Content training_service.yml
python config_test.py --ts remote --exclude cifar10,smac,bohb
,multi_phase_batch,multi_phase_grid
python config_test.py --ts remote --exclude cifar10,smac,bohb
displayName
:
'
integration
test'
displayName
:
'
integration
test'
-
task
:
SSH@0
-
task
:
SSH@0
inputs
:
inputs
:
...
...
test/pipelines-it-remote.yml
View file @
c785655e
...
@@ -53,7 +53,7 @@ jobs:
...
@@ -53,7 +53,7 @@ jobs:
python3 generate_ts_config.py --ts remote --remote_user $(docker_user) --remote_host $(remote_host) \
python3 generate_ts_config.py --ts remote --remote_user $(docker_user) --remote_host $(remote_host) \
--remote_port $(cat port) --remote_pwd $(docker_pwd) --nni_manager_ip $(nni_manager_ip)
--remote_port $(cat port) --remote_pwd $(docker_pwd) --nni_manager_ip $(nni_manager_ip)
cat training_service.yml
cat training_service.yml
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts remote --exclude cifar10
,multi_phase_batch,multi_phase_grid
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts remote --exclude cifar10
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
displayName
:
'
integration
test'
displayName
:
'
integration
test'
-
task
:
SSH@0
-
task
:
SSH@0
...
...
tools/nni_annotation/.gitignore
0 → 100644
View file @
c785655e
_generated
tools/nni_annotation/code_generator.py
View file @
c785655e
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
import
ast
import
ast
import
astor
import
astor
# pylint: disable=unidiomatic-typecheck
# pylint: disable=unidiomatic-typecheck
def
parse_annotation_mutable_layers
(
code
,
lineno
,
nas_mode
):
def
parse_annotation_mutable_layers
(
code
,
lineno
,
nas_mode
):
...
@@ -79,7 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
...
@@ -79,7 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
fields
[
'optional_inputs'
]
=
True
fields
[
'optional_inputs'
]
=
True
elif
k
.
id
==
'optional_input_size'
:
elif
k
.
id
==
'optional_input_size'
:
assert
not
fields
[
'optional_input_size'
],
'Duplicated field: optional_input_size'
assert
not
fields
[
'optional_input_size'
],
'Duplicated field: optional_input_size'
assert
type
(
value
)
is
ast
.
Num
or
type
(
value
)
is
ast
.
List
,
'Value of optional_input_size should be a number or list'
assert
type
(
value
)
is
ast
.
Num
or
type
(
value
)
is
ast
.
List
,
\
'Value of optional_input_size should be a number or list'
optional_input_size
=
value
optional_input_size
=
value
fields
[
'optional_input_size'
]
=
True
fields
[
'optional_input_size'
]
=
True
elif
k
.
id
==
'layer_output'
:
elif
k
.
id
==
'layer_output'
:
...
@@ -118,6 +120,7 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
...
@@ -118,6 +120,7 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
nodes
.
append
(
node
)
nodes
.
append
(
node
)
return
nodes
return
nodes
def
parse_annotation
(
code
):
def
parse_annotation
(
code
):
"""Parse an annotation string.
"""Parse an annotation string.
Return an AST Expr node.
Return an AST Expr node.
...
@@ -198,7 +201,7 @@ def convert_args_to_dict(call, with_lambda=False):
...
@@ -198,7 +201,7 @@ def convert_args_to_dict(call, with_lambda=False):
if
type
(
arg
)
in
[
ast
.
Str
,
ast
.
Num
]:
if
type
(
arg
)
in
[
ast
.
Str
,
ast
.
Num
]:
arg_value
=
arg
arg_value
=
arg
else
:
else
:
# if arg is not a string or a number, we use its source code as the key
# if arg is not a string or a number, we use its source code as the key
arg_value
=
astor
.
to_source
(
arg
).
strip
(
'
\n
"'
)
arg_value
=
astor
.
to_source
(
arg
).
strip
(
'
\n
"'
)
arg_value
=
ast
.
Str
(
str
(
arg_value
))
arg_value
=
ast
.
Str
(
str
(
arg_value
))
arg
=
make_lambda
(
arg
)
if
with_lambda
else
arg
arg
=
make_lambda
(
arg
)
if
with_lambda
else
arg
...
@@ -311,7 +314,6 @@ class Transformer(ast.NodeTransformer):
...
@@ -311,7 +314,6 @@ class Transformer(ast.NodeTransformer):
return
self
.
_visit_children
(
node
)
return
self
.
_visit_children
(
node
)
def
_visit_string
(
self
,
node
):
def
_visit_string
(
self
,
node
):
string
=
node
.
value
.
s
string
=
node
.
value
.
s
if
string
.
startswith
(
'@nni.'
):
if
string
.
startswith
(
'@nni.'
):
...
@@ -325,7 +327,7 @@ class Transformer(ast.NodeTransformer):
...
@@ -325,7 +327,7 @@ class Transformer(ast.NodeTransformer):
call_node
.
args
.
insert
(
0
,
ast
.
Str
(
s
=
self
.
nas_mode
))
call_node
.
args
.
insert
(
0
,
ast
.
Str
(
s
=
self
.
nas_mode
))
return
expr
return
expr
if
string
.
startswith
(
'@nni.report_intermediate_result'
)
\
if
string
.
startswith
(
'@nni.report_intermediate_result'
)
\
or
string
.
startswith
(
'@nni.report_final_result'
)
\
or
string
.
startswith
(
'@nni.report_final_result'
)
\
or
string
.
startswith
(
'@nni.get_next_parameter'
):
or
string
.
startswith
(
'@nni.get_next_parameter'
):
return
parse_annotation
(
string
[
1
:])
# expand annotation string to code
return
parse_annotation
(
string
[
1
:])
# expand annotation string to code
...
@@ -341,7 +343,6 @@ class Transformer(ast.NodeTransformer):
...
@@ -341,7 +343,6 @@ class Transformer(ast.NodeTransformer):
raise
AssertionError
(
'Unexpected annotation function'
)
raise
AssertionError
(
'Unexpected annotation function'
)
def
_visit_children
(
self
,
node
):
def
_visit_children
(
self
,
node
):
self
.
stack
.
append
(
None
)
self
.
stack
.
append
(
None
)
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
...
...
tools/nni_annotation/search_space_generator.py
View file @
c785655e
...
@@ -64,7 +64,6 @@ class SearchSpaceGenerator(ast.NodeTransformer):
...
@@ -64,7 +64,6 @@ class SearchSpaceGenerator(ast.NodeTransformer):
'optional_input_size'
:
args
[
6
].
n
if
isinstance
(
args
[
6
],
ast
.
Num
)
else
[
args
[
6
].
elts
[
0
].
n
,
args
[
6
].
elts
[
1
].
n
]
'optional_input_size'
:
args
[
6
].
n
if
isinstance
(
args
[
6
],
ast
.
Num
)
else
[
args
[
6
].
elts
[
0
].
n
,
args
[
6
].
elts
[
1
].
n
]
}
}
def
visit_Call
(
self
,
node
):
# pylint: disable=invalid-name
def
visit_Call
(
self
,
node
):
# pylint: disable=invalid-name
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
...
@@ -108,7 +107,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
...
@@ -108,7 +107,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
else
:
else
:
# arguments of other functions must be literal number
# arguments of other functions must be literal number
assert
all
(
isinstance
(
ast
.
literal_eval
(
astor
.
to_source
(
arg
)),
numbers
.
Real
)
for
arg
in
node
.
args
),
\
assert
all
(
isinstance
(
ast
.
literal_eval
(
astor
.
to_source
(
arg
)),
numbers
.
Real
)
for
arg
in
node
.
args
),
\
'Smart parameter
\'
s arguments must be number literals'
'Smart parameter
\'
s arguments must be number literals'
args
=
[
ast
.
literal_eval
(
astor
.
to_source
(
arg
))
for
arg
in
node
.
args
]
args
=
[
ast
.
literal_eval
(
astor
.
to_source
(
arg
))
for
arg
in
node
.
args
]
key
=
self
.
module_name
+
'/'
+
name
+
'/'
+
func
key
=
self
.
module_name
+
'/'
+
name
+
'/'
+
func
...
...
tools/nni_annotation/specific_code_generator.py
View file @
c785655e
...
@@ -28,6 +28,7 @@ from nni_cmd.common_utils import print_warning
...
@@ -28,6 +28,7 @@ from nni_cmd.common_utils import print_warning
para_cfg
=
None
para_cfg
=
None
prefix_name
=
None
prefix_name
=
None
def
parse_annotation_mutable_layers
(
code
,
lineno
):
def
parse_annotation_mutable_layers
(
code
,
lineno
):
"""Parse the string of mutable layers in annotation.
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
Return a list of AST Expr nodes
...
@@ -102,6 +103,7 @@ def parse_annotation_mutable_layers(code, lineno):
...
@@ -102,6 +103,7 @@ def parse_annotation_mutable_layers(code, lineno):
nodes
.
append
(
node
)
nodes
.
append
(
node
)
return
nodes
return
nodes
def
parse_annotation
(
code
):
def
parse_annotation
(
code
):
"""Parse an annotation string.
"""Parse an annotation string.
Return an AST Expr node.
Return an AST Expr node.
...
@@ -182,7 +184,7 @@ def convert_args_to_dict(call, with_lambda=False):
...
@@ -182,7 +184,7 @@ def convert_args_to_dict(call, with_lambda=False):
if
type
(
arg
)
in
[
ast
.
Str
,
ast
.
Num
]:
if
type
(
arg
)
in
[
ast
.
Str
,
ast
.
Num
]:
arg_value
=
arg
arg_value
=
arg
else
:
else
:
# if arg is not a string or a number, we use its source code as the key
# if arg is not a string or a number, we use its source code as the key
arg_value
=
astor
.
to_source
(
arg
).
strip
(
'
\n
"'
)
arg_value
=
astor
.
to_source
(
arg
).
strip
(
'
\n
"'
)
arg_value
=
ast
.
Str
(
str
(
arg_value
))
arg_value
=
ast
.
Str
(
str
(
arg_value
))
arg
=
make_lambda
(
arg
)
if
with_lambda
else
arg
arg
=
make_lambda
(
arg
)
if
with_lambda
else
arg
...
@@ -217,7 +219,7 @@ def test_variable_equal(node1, node2):
...
@@ -217,7 +219,7 @@ def test_variable_equal(node1, node2):
if
len
(
node1
)
!=
len
(
node2
):
if
len
(
node1
)
!=
len
(
node2
):
return
False
return
False
return
all
(
test_variable_equal
(
n1
,
n2
)
for
n1
,
n2
in
zip
(
node1
,
node2
))
return
all
(
test_variable_equal
(
n1
,
n2
)
for
n1
,
n2
in
zip
(
node1
,
node2
))
return
node1
==
node2
return
node1
==
node2
...
@@ -294,7 +296,6 @@ class Transformer(ast.NodeTransformer):
...
@@ -294,7 +296,6 @@ class Transformer(ast.NodeTransformer):
return
self
.
_visit_children
(
node
)
return
self
.
_visit_children
(
node
)
def
_visit_string
(
self
,
node
):
def
_visit_string
(
self
,
node
):
string
=
node
.
value
.
s
string
=
node
.
value
.
s
if
string
.
startswith
(
'@nni.'
):
if
string
.
startswith
(
'@nni.'
):
...
@@ -303,19 +304,27 @@ class Transformer(ast.NodeTransformer):
...
@@ -303,19 +304,27 @@ class Transformer(ast.NodeTransformer):
return
node
# not an annotation, ignore it
return
node
# not an annotation, ignore it
if
string
.
startswith
(
'@nni.get_next_parameter'
):
if
string
.
startswith
(
'@nni.get_next_parameter'
):
deprecated_message
=
"'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
deprecated_message
=
"'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. "
\
"Please remove this line in the trial code."
print_warning
(
deprecated_message
)
print_warning
(
deprecated_message
)
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'Get next parameter here...'
)],
keywords
=
[]))
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'Get next parameter here...'
)],
keywords
=
[]))
if
string
.
startswith
(
'@nni.training_update'
):
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'Training update here...'
)],
keywords
=
[]))
if
string
.
startswith
(
'@nni.report_intermediate_result'
):
if
string
.
startswith
(
'@nni.report_intermediate_result'
):
module
=
ast
.
parse
(
string
[
1
:])
module
=
ast
.
parse
(
string
[
1
:])
arg
=
module
.
body
[
0
].
value
.
args
[
0
]
arg
=
module
.
body
[
0
].
value
.
args
[
0
]
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'nni.report_intermediate_result: '
),
arg
],
keywords
=
[]))
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'nni.report_intermediate_result: '
),
arg
],
keywords
=
[]))
if
string
.
startswith
(
'@nni.report_final_result'
):
if
string
.
startswith
(
'@nni.report_final_result'
):
module
=
ast
.
parse
(
string
[
1
:])
module
=
ast
.
parse
(
string
[
1
:])
arg
=
module
.
body
[
0
].
value
.
args
[
0
]
arg
=
module
.
body
[
0
].
value
.
args
[
0
]
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'nni.report_final_result: '
),
arg
],
keywords
=
[]))
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
ast
.
Name
(
id
=
'print'
,
ctx
=
ast
.
Load
()),
args
=
[
ast
.
Str
(
s
=
'nni.report_final_result: '
),
arg
],
keywords
=
[]))
if
string
.
startswith
(
'@nni.mutable_layers'
):
if
string
.
startswith
(
'@nni.mutable_layers'
):
return
parse_annotation_mutable_layers
(
string
[
1
:],
node
.
lineno
)
return
parse_annotation_mutable_layers
(
string
[
1
:],
node
.
lineno
)
...
@@ -327,7 +336,6 @@ class Transformer(ast.NodeTransformer):
...
@@ -327,7 +336,6 @@ class Transformer(ast.NodeTransformer):
raise
AssertionError
(
'Unexpected annotation function'
)
raise
AssertionError
(
'Unexpected annotation function'
)
def
_visit_children
(
self
,
node
):
def
_visit_children
(
self
,
node
):
self
.
stack
.
append
(
None
)
self
.
stack
.
append
(
None
)
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
...
...
tools/nni_annotation/test_annotation.py
View file @
c785655e
...
@@ -39,17 +39,18 @@ class AnnotationTestCase(TestCase):
...
@@ -39,17 +39,18 @@ class AnnotationTestCase(TestCase):
shutil
.
rmtree
(
'_generated'
)
shutil
.
rmtree
(
'_generated'
)
def
test_search_space_generator
(
self
):
def
test_search_space_generator
(
self
):
search_space
=
generate_search_space
(
'testcase/annotated'
)
shutil
.
copytree
(
'testcase/annotated'
,
'_generated/annotated'
)
search_space
=
generate_search_space
(
'_generated/annotated'
)
with
open
(
'testcase/searchspace.json'
)
as
f
:
with
open
(
'testcase/searchspace.json'
)
as
f
:
self
.
assertEqual
(
search_space
,
json
.
load
(
f
))
self
.
assertEqual
(
search_space
,
json
.
load
(
f
))
def
test_code_generator
(
self
):
def
test_code_generator
(
self
):
code_dir
=
expand_annotations
(
'testcase/usercode'
,
'_generated'
,
nas_mode
=
'classic_mode'
)
code_dir
=
expand_annotations
(
'testcase/usercode'
,
'_generated
/usercode
'
,
nas_mode
=
'classic_mode'
)
self
.
assertEqual
(
code_dir
,
'_generated'
)
self
.
assertEqual
(
code_dir
,
'_generated
/usercode
'
)
self
.
_assert_source_equal
(
'testcase/annotated/nas.py'
,
'_generated/nas.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/nas.py'
,
'_generated/
usercode/
nas.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/mnist.py'
,
'_generated/mnist.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/mnist.py'
,
'_generated/
usercode/
mnist.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/dir/simple.py'
,
'_generated/dir/simple.py'
)
self
.
_assert_source_equal
(
'testcase/annotated/dir/simple.py'
,
'_generated/
usercode/
dir/simple.py'
)
with
open
(
'testcase/usercode/nonpy.txt'
)
as
src
,
open
(
'_generated/nonpy.txt'
)
as
dst
:
with
open
(
'testcase/usercode/nonpy.txt'
)
as
src
,
open
(
'_generated/
usercode/
nonpy.txt'
)
as
dst
:
assert
src
.
read
()
==
dst
.
read
()
assert
src
.
read
()
==
dst
.
read
()
def
test_annotation_detecting
(
self
):
def
test_annotation_detecting
(
self
):
...
...
tools/nni_cmd/config_schema.py
View file @
c785655e
...
@@ -76,7 +76,7 @@ tuner_schema_dict = {
...
@@ -76,7 +76,7 @@ tuner_schema_dict = {
'optimize_mode'
:
setChoice
(
'optimize_mode'
,
'maximize'
,
'minimize'
),
'optimize_mode'
:
setChoice
(
'optimize_mode'
,
'maximize'
,
'minimize'
),
},
},
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpu
Num'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpu
Indices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
)
,
error
=
'gpuIndex format error!'
),
},
},
(
'Evolution'
):
{
(
'Evolution'
):
{
'builtinTunerName'
:
setChoice
(
'builtinTunerName'
,
'Evolution'
),
'builtinTunerName'
:
setChoice
(
'builtinTunerName'
,
'Evolution'
),
...
@@ -85,12 +85,12 @@ tuner_schema_dict = {
...
@@ -85,12 +85,12 @@ tuner_schema_dict = {
Optional
(
'population_size'
):
setNumberRange
(
'population_size'
,
int
,
0
,
99999
),
Optional
(
'population_size'
):
setNumberRange
(
'population_size'
,
int
,
0
,
99999
),
},
},
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpu
Num'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpu
Indices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
)
,
error
=
'gpuIndex format error!'
),
},
},
(
'BatchTuner'
,
'GridSearch'
,
'Random'
):
{
(
'BatchTuner'
,
'GridSearch'
,
'Random'
):
{
'builtinTunerName'
:
setChoice
(
'builtinTunerName'
,
'BatchTuner'
,
'GridSearch'
,
'Random'
),
'builtinTunerName'
:
setChoice
(
'builtinTunerName'
,
'BatchTuner'
,
'GridSearch'
,
'Random'
),
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpu
Num'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpu
Indices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
)
,
error
=
'gpuIndex format error!'
),
},
},
'TPE'
:
{
'TPE'
:
{
'builtinTunerName'
:
'TPE'
,
'builtinTunerName'
:
'TPE'
,
...
@@ -100,7 +100,7 @@ tuner_schema_dict = {
...
@@ -100,7 +100,7 @@ tuner_schema_dict = {
Optional
(
'constant_liar_type'
):
setChoice
(
'constant_liar_type'
,
'min'
,
'max'
,
'mean'
)
Optional
(
'constant_liar_type'
):
setChoice
(
'constant_liar_type'
,
'min'
,
'max'
,
'mean'
)
},
},
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpu
Num'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpu
Indices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
)
,
error
=
'gpuIndex format error!'
),
},
},
'NetworkMorphism'
:
{
'NetworkMorphism'
:
{
'builtinTunerName'
:
'NetworkMorphism'
,
'builtinTunerName'
:
'NetworkMorphism'
,
...
@@ -112,7 +112,7 @@ tuner_schema_dict = {
...
@@ -112,7 +112,7 @@ tuner_schema_dict = {
Optional
(
'n_output_node'
):
setType
(
'n_output_node'
,
int
),
Optional
(
'n_output_node'
):
setType
(
'n_output_node'
,
int
),
},
},
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpu
Num'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpu
Indices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
)
,
error
=
'gpuIndex format error!'
),
},
},
'MetisTuner'
:
{
'MetisTuner'
:
{
'builtinTunerName'
:
'MetisTuner'
,
'builtinTunerName'
:
'MetisTuner'
,
...
@@ -124,7 +124,7 @@ tuner_schema_dict = {
...
@@ -124,7 +124,7 @@ tuner_schema_dict = {
Optional
(
'cold_start_num'
):
setType
(
'cold_start_num'
,
int
),
Optional
(
'cold_start_num'
):
setType
(
'cold_start_num'
,
int
),
},
},
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpu
Num'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpu
Indices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
)
,
error
=
'gpuIndex format error!'
),
},
},
'GPTuner'
:
{
'GPTuner'
:
{
'builtinTunerName'
:
'GPTuner'
,
'builtinTunerName'
:
'GPTuner'
,
...
@@ -140,7 +140,25 @@ tuner_schema_dict = {
...
@@ -140,7 +140,25 @@ tuner_schema_dict = {
Optional
(
'selection_num_starting_points'
):
setType
(
'selection_num_starting_points'
,
int
),
Optional
(
'selection_num_starting_points'
):
setType
(
'selection_num_starting_points'
,
int
),
},
},
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpuNum'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpuIndices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
),
error
=
'gpuIndex format error!'
),
},
'PPOTuner'
:
{
'builtinTunerName'
:
'PPOTuner'
,
'classArgs'
:
{
'optimize_mode'
:
setChoice
(
'optimize_mode'
,
'maximize'
,
'minimize'
),
Optional
(
'trials_per_update'
):
setNumberRange
(
'trials_per_update'
,
int
,
0
,
99999
),
Optional
(
'epochs_per_update'
):
setNumberRange
(
'epochs_per_update'
,
int
,
0
,
99999
),
Optional
(
'minibatch_size'
):
setNumberRange
(
'minibatch_size'
,
int
,
0
,
99999
),
Optional
(
'ent_coef'
):
setType
(
'ent_coef'
,
float
),
Optional
(
'lr'
):
setType
(
'lr'
,
float
),
Optional
(
'vf_coef'
):
setType
(
'vf_coef'
,
float
),
Optional
(
'max_grad_norm'
):
setType
(
'max_grad_norm'
,
float
),
Optional
(
'gamma'
):
setType
(
'gamma'
,
float
),
Optional
(
'lam'
):
setType
(
'lam'
,
float
),
Optional
(
'cliprange'
):
setType
(
'cliprange'
,
float
),
},
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpuIndices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
),
error
=
'gpuIndex format error!'
),
},
},
'customized'
:
{
'customized'
:
{
'codeDir'
:
setPathCheck
(
'codeDir'
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
...
@@ -148,7 +166,7 @@ tuner_schema_dict = {
...
@@ -148,7 +166,7 @@ tuner_schema_dict = {
'className'
:
setType
(
'className'
,
str
),
'className'
:
setType
(
'className'
,
str
),
Optional
(
'classArgs'
):
dict
,
Optional
(
'classArgs'
):
dict
,
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'includeIntermediateResults'
):
setType
(
'includeIntermediateResults'
,
bool
),
Optional
(
'gpu
Num'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpu
Indices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
)
,
error
=
'gpuIndex format error!'
),
}
}
}
}
...
@@ -160,7 +178,7 @@ advisor_schema_dict = {
...
@@ -160,7 +178,7 @@ advisor_schema_dict = {
Optional
(
'R'
):
setType
(
'R'
,
int
),
Optional
(
'R'
):
setType
(
'R'
,
int
),
Optional
(
'eta'
):
setType
(
'eta'
,
int
)
Optional
(
'eta'
):
setType
(
'eta'
,
int
)
},
},
Optional
(
'gpu
Num'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpu
Indices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
)
,
error
=
'gpuIndex format error!'
),
},
},
'BOHB'
:{
'BOHB'
:{
'builtinAdvisorName'
:
Or
(
'BOHB'
),
'builtinAdvisorName'
:
Or
(
'BOHB'
),
...
@@ -176,14 +194,14 @@ advisor_schema_dict = {
...
@@ -176,14 +194,14 @@ advisor_schema_dict = {
Optional
(
'bandwidth_factor'
):
setNumberRange
(
'bandwidth_factor'
,
float
,
0
,
9999
),
Optional
(
'bandwidth_factor'
):
setNumberRange
(
'bandwidth_factor'
,
float
,
0
,
9999
),
Optional
(
'min_bandwidth'
):
setNumberRange
(
'min_bandwidth'
,
float
,
0
,
9999
),
Optional
(
'min_bandwidth'
):
setNumberRange
(
'min_bandwidth'
,
float
,
0
,
9999
),
},
},
Optional
(
'gpu
Num'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpu
Indices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
)
,
error
=
'gpuIndex format error!'
),
},
},
'customized'
:{
'customized'
:{
'codeDir'
:
setPathCheck
(
'codeDir'
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'classFileName'
:
setType
(
'classFileName'
,
str
),
'classFileName'
:
setType
(
'classFileName'
,
str
),
'className'
:
setType
(
'className'
,
str
),
'className'
:
setType
(
'className'
,
str
),
Optional
(
'classArgs'
):
dict
,
Optional
(
'classArgs'
):
dict
,
Optional
(
'gpu
Num'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpu
Indices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
)
,
error
=
'gpuIndex format error!'
),
}
}
}
}
...
@@ -194,7 +212,6 @@ assessor_schema_dict = {
...
@@ -194,7 +212,6 @@ assessor_schema_dict = {
Optional
(
'optimize_mode'
):
setChoice
(
'optimize_mode'
,
'maximize'
,
'minimize'
),
Optional
(
'optimize_mode'
):
setChoice
(
'optimize_mode'
,
'maximize'
,
'minimize'
),
Optional
(
'start_step'
):
setNumberRange
(
'start_step'
,
int
,
0
,
9999
),
Optional
(
'start_step'
):
setNumberRange
(
'start_step'
,
int
,
0
,
9999
),
},
},
Optional
(
'gpuNum'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
},
},
'Curvefitting'
:
{
'Curvefitting'
:
{
'builtinAssessorName'
:
'Curvefitting'
,
'builtinAssessorName'
:
'Curvefitting'
,
...
@@ -205,14 +222,12 @@ assessor_schema_dict = {
...
@@ -205,14 +222,12 @@ assessor_schema_dict = {
Optional
(
'threshold'
):
setNumberRange
(
'threshold'
,
float
,
0
,
9999
),
Optional
(
'threshold'
):
setNumberRange
(
'threshold'
,
float
,
0
,
9999
),
Optional
(
'gap'
):
setNumberRange
(
'gap'
,
int
,
1
,
9999
),
Optional
(
'gap'
):
setNumberRange
(
'gap'
,
int
,
1
,
9999
),
},
},
Optional
(
'gpuNum'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
},
},
'customized'
:
{
'customized'
:
{
'codeDir'
:
setPathCheck
(
'codeDir'
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'classFileName'
:
setType
(
'classFileName'
,
str
),
'classFileName'
:
setType
(
'classFileName'
,
str
),
'className'
:
setType
(
'className'
,
str
),
'className'
:
setType
(
'className'
,
str
),
Optional
(
'classArgs'
):
dict
,
Optional
(
'classArgs'
):
dict
,
Optional
(
'gpuNum'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
)
}
}
}
}
...
...
tools/nni_cmd/constants.py
View file @
c785655e
...
@@ -80,7 +80,8 @@ TRIAL_MONITOR_TAIL = '----------------------------------------------------------
...
@@ -80,7 +80,8 @@ TRIAL_MONITOR_TAIL = '----------------------------------------------------------
PACKAGE_REQUIREMENTS
=
{
PACKAGE_REQUIREMENTS
=
{
'SMAC'
:
'smac_tuner'
,
'SMAC'
:
'smac_tuner'
,
'BOHB'
:
'bohb_advisor'
'BOHB'
:
'bohb_advisor'
,
'PPOTuner'
:
'ppo_tuner'
}
}
TUNERS_SUPPORTING_IMPORT_DATA
=
{
TUNERS_SUPPORTING_IMPORT_DATA
=
{
...
...
tools/nni_cmd/launcher.py
View file @
c785655e
...
@@ -118,12 +118,17 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
...
@@ -118,12 +118,17 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
node_command
=
'node'
node_command
=
'node'
if
sys
.
platform
==
'win32'
:
if
sys
.
platform
==
'win32'
:
node_command
=
os
.
path
.
join
(
entry_dir
[:
-
3
],
'Scripts'
,
'node.exe'
)
node_command
=
os
.
path
.
join
(
entry_dir
[:
-
3
],
'Scripts'
,
'node.exe'
)
cmds
=
[
node_command
,
entry_file
,
'--port'
,
str
(
port
),
'--mode'
,
platform
,
'--start_mode'
,
mode
]
cmds
=
[
node_command
,
entry_file
,
'--port'
,
str
(
port
),
'--mode'
,
platform
]
if
mode
==
'view'
:
cmds
+=
[
'--start_mode'
,
'resume'
]
cmds
+=
[
'--readonly'
,
'true'
]
else
:
cmds
+=
[
'--start_mode'
,
mode
]
if
log_dir
is
not
None
:
if
log_dir
is
not
None
:
cmds
+=
[
'--log_dir'
,
log_dir
]
cmds
+=
[
'--log_dir'
,
log_dir
]
if
log_level
is
not
None
:
if
log_level
is
not
None
:
cmds
+=
[
'--log_level'
,
log_level
]
cmds
+=
[
'--log_level'
,
log_level
]
if
mode
==
'resume'
:
if
mode
in
[
'resume'
,
'view'
]
:
cmds
+=
[
'--experiment_id'
,
experiment_id
]
cmds
+=
[
'--experiment_id'
,
experiment_id
]
stdout_full_path
,
stderr_full_path
=
get_log_path
(
config_file_name
)
stdout_full_path
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stdout_full_path
,
'a+'
)
as
stdout_file
,
open
(
stderr_full_path
,
'a+'
)
as
stderr_file
:
with
open
(
stdout_full_path
,
'a+'
)
as
stdout_file
,
open
(
stderr_full_path
,
'a+'
)
as
stderr_file
:
...
@@ -156,7 +161,6 @@ def set_trial_config(experiment_config, port, config_file_name):
...
@@ -156,7 +161,6 @@ def set_trial_config(experiment_config, port, config_file_name):
def
set_local_config
(
experiment_config
,
port
,
config_file_name
):
def
set_local_config
(
experiment_config
,
port
,
config_file_name
):
'''set local configuration'''
'''set local configuration'''
#set machine_list
request_data
=
dict
()
request_data
=
dict
()
if
experiment_config
.
get
(
'localConfig'
):
if
experiment_config
.
get
(
'localConfig'
):
request_data
[
'local_config'
]
=
experiment_config
[
'localConfig'
]
request_data
[
'local_config'
]
=
experiment_config
[
'localConfig'
]
...
@@ -177,7 +181,7 @@ def set_local_config(experiment_config, port, config_file_name):
...
@@ -177,7 +181,7 @@ def set_local_config(experiment_config, port, config_file_name):
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
return
False
,
err_message
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
)
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
)
,
None
def
set_remote_config
(
experiment_config
,
port
,
config_file_name
):
def
set_remote_config
(
experiment_config
,
port
,
config_file_name
):
'''Call setClusterMetadata to pass trial'''
'''Call setClusterMetadata to pass trial'''
...
@@ -296,10 +300,20 @@ def set_experiment(experiment_config, mode, port, config_file_name):
...
@@ -296,10 +300,20 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data
[
'multiThread'
]
=
experiment_config
.
get
(
'multiThread'
)
request_data
[
'multiThread'
]
=
experiment_config
.
get
(
'multiThread'
)
if
experiment_config
.
get
(
'advisor'
):
if
experiment_config
.
get
(
'advisor'
):
request_data
[
'advisor'
]
=
experiment_config
[
'advisor'
]
request_data
[
'advisor'
]
=
experiment_config
[
'advisor'
]
if
request_data
[
'advisor'
].
get
(
'gpuNum'
):
print_error
(
'gpuNum is deprecated, please use gpuIndices instead.'
)
if
request_data
[
'advisor'
].
get
(
'gpuIndices'
)
and
isinstance
(
request_data
[
'advisor'
].
get
(
'gpuIndices'
),
int
):
request_data
[
'advisor'
][
'gpuIndices'
]
=
str
(
request_data
[
'advisor'
].
get
(
'gpuIndices'
))
else
:
else
:
request_data
[
'tuner'
]
=
experiment_config
[
'tuner'
]
request_data
[
'tuner'
]
=
experiment_config
[
'tuner'
]
if
request_data
[
'tuner'
].
get
(
'gpuNum'
):
print_error
(
'gpuNum is deprecated, please use gpuIndices instead.'
)
if
request_data
[
'tuner'
].
get
(
'gpuIndices'
)
and
isinstance
(
request_data
[
'tuner'
].
get
(
'gpuIndices'
),
int
):
request_data
[
'tuner'
][
'gpuIndices'
]
=
str
(
request_data
[
'tuner'
].
get
(
'gpuIndices'
))
if
'assessor'
in
experiment_config
:
if
'assessor'
in
experiment_config
:
request_data
[
'assessor'
]
=
experiment_config
[
'assessor'
]
request_data
[
'assessor'
]
=
experiment_config
[
'assessor'
]
if
request_data
[
'assessor'
].
get
(
'gpuNum'
):
print_error
(
'gpuNum is deprecated, please remove it from your config file.'
)
#debug mode should disable version check
#debug mode should disable version check
if
experiment_config
.
get
(
'debug'
)
is
not
None
:
if
experiment_config
.
get
(
'debug'
)
is
not
None
:
request_data
[
'versionCheck'
]
=
not
experiment_config
.
get
(
'debug'
)
request_data
[
'versionCheck'
]
=
not
experiment_config
.
get
(
'debug'
)
...
@@ -335,7 +349,6 @@ def set_experiment(experiment_config, mode, port, config_file_name):
...
@@ -335,7 +349,6 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{
'key'
:
'frameworkcontroller_config'
,
'value'
:
experiment_config
[
'frameworkcontrollerConfig'
]})
{
'key'
:
'frameworkcontroller_config'
,
'value'
:
experiment_config
[
'frameworkcontrollerConfig'
]})
request_data
[
'clusterMetaData'
].
append
(
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
response
=
rest_post
(
experiment_url
(
port
),
json
.
dumps
(
request_data
),
REST_TIME_OUT
,
show_error
=
True
)
response
=
rest_post
(
experiment_url
(
port
),
json
.
dumps
(
request_data
),
REST_TIME_OUT
,
show_error
=
True
)
if
check_response
(
response
):
if
check_response
(
response
):
return
response
return
response
...
@@ -347,6 +360,33 @@ def set_experiment(experiment_config, mode, port, config_file_name):
...
@@ -347,6 +360,33 @@ def set_experiment(experiment_config, mode, port, config_file_name):
print_error
(
'Setting experiment error, error message is {}'
.
format
(
response
.
text
))
print_error
(
'Setting experiment error, error message is {}'
.
format
(
response
.
text
))
return
None
return
None
def
set_platform_config
(
platform
,
experiment_config
,
port
,
config_file_name
,
rest_process
):
'''call set_cluster_metadata for specific platform'''
print_normal
(
'Setting {0} config...'
.
format
(
platform
))
config_result
,
err_msg
=
None
,
None
if
platform
==
'local'
:
config_result
,
err_msg
=
set_local_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'remote'
:
config_result
,
err_msg
=
set_remote_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'pai'
:
config_result
,
err_msg
=
set_pai_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'kubeflow'
:
config_result
,
err_msg
=
set_kubeflow_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'frameworkcontroller'
:
config_result
,
err_msg
=
set_frameworkcontroller_config
(
experiment_config
,
port
,
config_file_name
)
else
:
raise
Exception
(
ERROR_INFO
%
'Unsupported platform!'
)
exit
(
1
)
if
config_result
:
print_normal
(
'Successfully set {0} config!'
.
format
(
platform
))
else
:
print_error
(
'Failed! Error is: {}'
.
format
(
err_msg
))
try
:
kill_command
(
rest_process
.
pid
)
except
Exception
:
raise
Exception
(
ERROR_INFO
%
'Rest server stopped!'
)
exit
(
1
)
def
launch_experiment
(
args
,
experiment_config
,
mode
,
config_file_name
,
experiment_id
=
None
):
def
launch_experiment
(
args
,
experiment_config
,
mode
,
config_file_name
,
experiment_id
=
None
):
'''follow steps to start rest server and start experiment'''
'''follow steps to start rest server and start experiment'''
nni_config
=
Config
(
config_file_name
)
nni_config
=
Config
(
config_file_name
)
...
@@ -371,8 +411,10 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
...
@@ -371,8 +411,10 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
exit
(
1
)
exit
(
1
)
log_dir
=
experiment_config
[
'logDir'
]
if
experiment_config
.
get
(
'logDir'
)
else
None
log_dir
=
experiment_config
[
'logDir'
]
if
experiment_config
.
get
(
'logDir'
)
else
None
log_level
=
experiment_config
[
'logLevel'
]
if
experiment_config
.
get
(
'logLevel'
)
else
None
log_level
=
experiment_config
[
'logLevel'
]
if
experiment_config
.
get
(
'logLevel'
)
else
None
if
log_level
not
in
[
'trace'
,
'debug'
]
and
(
args
.
debug
or
experiment_config
.
get
(
'debug'
)
is
True
):
#view experiment mode do not need debug function, when view an experiment, there will be no new logs created
log_level
=
'debug'
if
mode
!=
'view'
:
if
log_level
not
in
[
'trace'
,
'debug'
]
and
(
args
.
debug
or
experiment_config
.
get
(
'debug'
)
is
True
):
log_level
=
'debug'
# start rest server
# start rest server
rest_process
,
start_time
=
start_rest_server
(
args
.
port
,
experiment_config
[
'trainingServicePlatform'
],
mode
,
config_file_name
,
experiment_id
,
log_dir
,
log_level
)
rest_process
,
start_time
=
start_rest_server
(
args
.
port
,
experiment_config
[
'trainingServicePlatform'
],
mode
,
config_file_name
,
experiment_id
,
log_dir
,
log_level
)
nni_config
.
set_config
(
'restServerPid'
,
rest_process
.
pid
)
nni_config
.
set_config
(
'restServerPid'
,
rest_process
.
pid
)
...
@@ -406,83 +448,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
...
@@ -406,83 +448,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
except
Exception
:
except
Exception
:
raise
Exception
(
ERROR_INFO
%
'Rest server stopped!'
)
raise
Exception
(
ERROR_INFO
%
'Rest server stopped!'
)
exit
(
1
)
exit
(
1
)
if
mode
!=
'view'
:
# set remote config
# set platform configuration
if
experiment_config
[
'trainingServicePlatform'
]
==
'remote'
:
set_platform_config
(
experiment_config
[
'trainingServicePlatform'
],
experiment_config
,
args
.
port
,
config_file_name
,
rest_process
)
print_normal
(
'Setting remote config...'
)
config_result
,
err_msg
=
set_remote_config
(
experiment_config
,
args
.
port
,
config_file_name
)
if
config_result
:
print_normal
(
'Successfully set remote config!'
)
else
:
print_error
(
'Failed! Error is: {}'
.
format
(
err_msg
))
try
:
kill_command
(
rest_process
.
pid
)
except
Exception
:
raise
Exception
(
ERROR_INFO
%
'Rest server stopped!'
)
exit
(
1
)
# set local config
if
experiment_config
[
'trainingServicePlatform'
]
==
'local'
:
print_normal
(
'Setting local config...'
)
if
set_local_config
(
experiment_config
,
args
.
port
,
config_file_name
):
print_normal
(
'Successfully set local config!'
)
else
:
print_error
(
'Set local config failed!'
)
try
:
kill_command
(
rest_process
.
pid
)
except
Exception
:
raise
Exception
(
ERROR_INFO
%
'Rest server stopped!'
)
exit
(
1
)
#set pai config
if
experiment_config
[
'trainingServicePlatform'
]
==
'pai'
:
print_normal
(
'Setting pai config...'
)
config_result
,
err_msg
=
set_pai_config
(
experiment_config
,
args
.
port
,
config_file_name
)
if
config_result
:
print_normal
(
'Successfully set pai config!'
)
else
:
if
err_msg
:
print_error
(
'Failed! Error is: {}'
.
format
(
err_msg
))
try
:
kill_command
(
rest_process
.
pid
)
except
Exception
:
raise
Exception
(
ERROR_INFO
%
'Restful server stopped!'
)
exit
(
1
)
#set kubeflow config
if
experiment_config
[
'trainingServicePlatform'
]
==
'kubeflow'
:
print_normal
(
'Setting kubeflow config...'
)
config_result
,
err_msg
=
set_kubeflow_config
(
experiment_config
,
args
.
port
,
config_file_name
)
if
config_result
:
print_normal
(
'Successfully set kubeflow config!'
)
else
:
if
err_msg
:
print_error
(
'Failed! Error is: {}'
.
format
(
err_msg
))
try
:
kill_command
(
rest_process
.
pid
)
except
Exception
:
raise
Exception
(
ERROR_INFO
%
'Restful server stopped!'
)
exit
(
1
)
#set frameworkcontroller config
if
experiment_config
[
'trainingServicePlatform'
]
==
'frameworkcontroller'
:
print_normal
(
'Setting frameworkcontroller config...'
)
config_result
,
err_msg
=
set_frameworkcontroller_config
(
experiment_config
,
args
.
port
,
config_file_name
)
if
config_result
:
print_normal
(
'Successfully set frameworkcontroller config!'
)
else
:
if
err_msg
:
print_error
(
'Failed! Error is: {}'
.
format
(
err_msg
))
try
:
kill_command
(
rest_process
.
pid
)
except
Exception
:
raise
Exception
(
ERROR_INFO
%
'Restful server stopped!'
)
exit
(
1
)
# start a new experiment
# start a new experiment
print_normal
(
'Starting experiment...'
)
print_normal
(
'Starting experiment...'
)
# set debug configuration
# set debug configuration
if
experiment_config
.
get
(
'debug'
)
is
None
:
if
mode
!=
'view'
and
experiment_config
.
get
(
'debug'
)
is
None
:
experiment_config
[
'debug'
]
=
args
.
debug
experiment_config
[
'debug'
]
=
args
.
debug
response
=
set_experiment
(
experiment_config
,
mode
,
args
.
port
,
config_file_name
)
response
=
set_experiment
(
experiment_config
,
mode
,
args
.
port
,
config_file_name
)
if
response
:
if
response
:
...
@@ -509,8 +482,23 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
...
@@ -509,8 +482,23 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
print_normal
(
EXPERIMENT_SUCCESS_INFO
%
(
experiment_id
,
' '
.
join
(
web_ui_url_list
)))
print_normal
(
EXPERIMENT_SUCCESS_INFO
%
(
experiment_id
,
' '
.
join
(
web_ui_url_list
)))
def
resume_experiment
(
args
):
def
create_experiment
(
args
):
'''resume an experiment'''
'''start a new experiment'''
config_file_name
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
nni_config
=
Config
(
config_file_name
)
config_path
=
os
.
path
.
abspath
(
args
.
config
)
if
not
os
.
path
.
exists
(
config_path
):
print_error
(
'Please set correct config path!'
)
exit
(
1
)
experiment_config
=
get_yml_content
(
config_path
)
validate_all_content
(
experiment_config
,
config_path
)
nni_config
.
set_config
(
'experimentConfig'
,
experiment_config
)
launch_experiment
(
args
,
experiment_config
,
'new'
,
config_file_name
)
nni_config
.
set_config
(
'restServerPort'
,
args
.
port
)
def
manage_stopped_experiment
(
args
,
mode
):
'''view a stopped experiment'''
update_experiment
()
update_experiment
()
experiment_config
=
Experiments
()
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
...
@@ -518,38 +506,31 @@ def resume_experiment(args):
...
@@ -518,38 +506,31 @@ def resume_experiment(args):
experiment_endTime
=
None
experiment_endTime
=
None
#find the latest stopped experiment
#find the latest stopped experiment
if
not
args
.
id
:
if
not
args
.
id
:
print_error
(
'Please set experiment id!
\n
You could use
\'
nnictl
resume
{id}
\'
to
resume
a stopped experiment!
\n
'
\
print_error
(
'Please set experiment id!
\n
You could use
\'
nnictl
{0}
{id}
\'
to
{0}
a stopped experiment!
\n
'
\
'You could use
\'
nnictl experiment list --all
\'
to show all experiments!'
)
'You could use
\'
nnictl experiment list --all
\'
to show all experiments!'
.
format
(
mode
)
)
exit
(
1
)
exit
(
1
)
else
:
else
:
if
experiment_dict
.
get
(
args
.
id
)
is
None
:
if
experiment_dict
.
get
(
args
.
id
)
is
None
:
print_error
(
'Id %s not exist!'
%
args
.
id
)
print_error
(
'Id %s not exist!'
%
args
.
id
)
exit
(
1
)
exit
(
1
)
if
experiment_dict
[
args
.
id
][
'status'
]
!=
'STOPPED'
:
if
experiment_dict
[
args
.
id
][
'status'
]
!=
'STOPPED'
:
print_error
(
'Only stopped experiments can be
resumed!'
)
print_error
(
'Only stopped experiments can be
{0}ed!'
.
format
(
mode
)
)
exit
(
1
)
exit
(
1
)
experiment_id
=
args
.
id
experiment_id
=
args
.
id
print_normal
(
'
Resuming
experiment
%s
...'
%
experiment_id
)
print_normal
(
'
{0}
experiment
{1}
...'
.
format
(
mode
,
experiment_id
)
)
nni_config
=
Config
(
experiment_dict
[
experiment_id
][
'fileName'
])
nni_config
=
Config
(
experiment_dict
[
experiment_id
][
'fileName'
])
experiment_config
=
nni_config
.
get_config
(
'experimentConfig'
)
experiment_config
=
nni_config
.
get_config
(
'experimentConfig'
)
experiment_id
=
nni_config
.
get_config
(
'experimentId'
)
experiment_id
=
nni_config
.
get_config
(
'experimentId'
)
new_config_file_name
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
new_config_file_name
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
new_nni_config
=
Config
(
new_config_file_name
)
new_nni_config
=
Config
(
new_config_file_name
)
new_nni_config
.
set_config
(
'experimentConfig'
,
experiment_config
)
new_nni_config
.
set_config
(
'experimentConfig'
,
experiment_config
)
launch_experiment
(
args
,
experiment_config
,
'resume'
,
new_config_file_name
,
experiment_id
)
launch_experiment
(
args
,
experiment_config
,
mode
,
new_config_file_name
,
experiment_id
)
new_nni_config
.
set_config
(
'restServerPort'
,
args
.
port
)
new_nni_config
.
set_config
(
'restServerPort'
,
args
.
port
)
def
create_experiment
(
args
):
def
view_experiment
(
args
):
'''start a new experiment'''
'''view a stopped experiment'''
config_file_name
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
manage_stopped_experiment
(
args
,
'view'
)
nni_config
=
Config
(
config_file_name
)
config_path
=
os
.
path
.
abspath
(
args
.
config
)
if
not
os
.
path
.
exists
(
config_path
):
print_error
(
'Please set correct config path!'
)
exit
(
1
)
experiment_config
=
get_yml_content
(
config_path
)
validate_all_content
(
experiment_config
,
config_path
)
nni_config
.
set_config
(
'experimentConfig'
,
experiment
_config
)
def
resume_
experiment
(
args
):
launch_experiment
(
args
,
experiment_config
,
'new'
,
config_file_name
)
'''resume an experiment'''
nni_config
.
set_config
(
'restServerPort'
,
args
.
port
)
manage_stopped_experiment
(
args
,
'resume'
)
\ No newline at end of file
tools/nni_cmd/nnictl.py
View file @
c785655e
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
import
argparse
import
argparse
import
pkg_resources
import
pkg_resources
from
.launcher
import
create_experiment
,
resume_experiment
from
.launcher
import
create_experiment
,
resume_experiment
,
view_experiment
from
.updater
import
update_searchspace
,
update_concurrency
,
update_duration
,
update_trialnum
,
import_data
from
.updater
import
update_searchspace
,
update_concurrency
,
update_duration
,
update_trialnum
,
import_data
from
.nnictl_utils
import
*
from
.nnictl_utils
import
*
from
.package_management
import
*
from
.package_management
import
*
...
@@ -66,6 +66,12 @@ def parse_args():
...
@@ -66,6 +66,12 @@ def parse_args():
parser_resume
.
add_argument
(
'--debug'
,
'-d'
,
action
=
'store_true'
,
help
=
' set debug mode'
)
parser_resume
.
add_argument
(
'--debug'
,
'-d'
,
action
=
'store_true'
,
help
=
' set debug mode'
)
parser_resume
.
set_defaults
(
func
=
resume_experiment
)
parser_resume
.
set_defaults
(
func
=
resume_experiment
)
# parse view command
parser_resume
=
subparsers
.
add_parser
(
'view'
,
help
=
'view a stopped experiment'
)
parser_resume
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'The id of the experiment you want to view'
)
parser_resume
.
add_argument
(
'--port'
,
'-p'
,
default
=
DEFAULT_REST_PORT
,
dest
=
'port'
,
help
=
'the port of restful server'
)
parser_resume
.
set_defaults
(
func
=
view_experiment
)
# parse update command
# parse update command
parser_updater
=
subparsers
.
add_parser
(
'update'
,
help
=
'update the experiment'
)
parser_updater
=
subparsers
.
add_parser
(
'update'
,
help
=
'update the experiment'
)
#add subparsers for parser_updater
#add subparsers for parser_updater
...
...
tools/nni_cmd/nnictl_utils.py
View file @
c785655e
...
@@ -351,6 +351,7 @@ def log_stderr(args):
...
@@ -351,6 +351,7 @@ def log_stderr(args):
def
log_trial
(
args
):
def
log_trial
(
args
):
''''get trial log path'''
''''get trial log path'''
trial_id_path_dict
=
{}
trial_id_path_dict
=
{}
trial_id_list
=
[]
nni_config
=
Config
(
get_config_filename
(
args
))
nni_config
=
Config
(
get_config_filename
(
args
))
rest_port
=
nni_config
.
get_config
(
'restServerPort'
)
rest_port
=
nni_config
.
get_config
(
'restServerPort'
)
rest_pid
=
nni_config
.
get_config
(
'restServerPid'
)
rest_pid
=
nni_config
.
get_config
(
'restServerPid'
)
...
@@ -363,23 +364,27 @@ def log_trial(args):
...
@@ -363,23 +364,27 @@ def log_trial(args):
if
response
and
check_response
(
response
):
if
response
and
check_response
(
response
):
content
=
json
.
loads
(
response
.
text
)
content
=
json
.
loads
(
response
.
text
)
for
trial
in
content
:
for
trial
in
content
:
trial_id_path_dict
[
trial
[
'id'
]]
=
trial
[
'logPath'
]
trial_id_list
.
append
(
trial
.
get
(
'id'
))
if
trial
.
get
(
'logPath'
):
trial_id_path_dict
[
trial
.
get
(
'id'
)]
=
trial
[
'logPath'
]
else
:
else
:
print_error
(
'Restful server is not running...'
)
print_error
(
'Restful server is not running...'
)
exit
(
1
)
exit
(
1
)
if
args
.
id
:
if
args
.
trial_id
:
if
args
.
trial_id
:
if
args
.
trial_id
not
in
trial_id_list
:
if
trial_id_path_dict
.
get
(
args
.
trial_id
):
print_error
(
'Trial id {0} not correct, please check your command!'
.
format
(
args
.
trial_id
))
print_normal
(
'id:'
+
args
.
trial_id
+
' path:'
+
trial_id_path_dict
[
args
.
trial_id
])
exit
(
1
)
else
:
if
trial_id_path_dict
.
get
(
args
.
trial_id
):
print_error
(
'trial id is not valid.'
)
print_normal
(
'id:'
+
args
.
trial_id
+
' path:'
+
trial_id_path_dict
[
args
.
trial_id
])
exit
(
1
)
else
:
else
:
print_error
(
'
please specific the trial id
.'
)
print_error
(
'
Log path is not available yet, please wait..
.'
)
exit
(
1
)
exit
(
1
)
else
:
else
:
print_normal
(
'All of trial log info:'
)
for
key
in
trial_id_path_dict
:
for
key
in
trial_id_path_dict
:
print
(
'id:'
+
key
+
' path:'
+
trial_id_path_dict
[
key
])
print_normal
(
'id:'
+
key
+
' path:'
+
trial_id_path_dict
[
key
])
if
not
trial_id_path_dict
:
print_normal
(
'None'
)
def
get_config
(
args
):
def
get_config
(
args
):
'''get config info'''
'''get config info'''
...
...
Prev
1
…
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment