Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
e9040c9b
Unverified
Commit
e9040c9b
authored
Jul 03, 2019
by
chicm-ms
Committed by
GitHub
Jul 03, 2019
Browse files
Merge pull request #23 from microsoft/master
pull code
parents
256f27af
ed63175c
Changes
108
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
260 additions
and
32 deletions
+260
-32
tools/nni_annotation/__init__.py
tools/nni_annotation/__init__.py
+5
-4
tools/nni_annotation/code_generator.py
tools/nni_annotation/code_generator.py
+22
-9
tools/nni_cmd/config_schema.py
tools/nni_cmd/config_schema.py
+19
-1
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+2
-1
tools/nni_cmd/nnictl.py
tools/nni_cmd/nnictl.py
+14
-1
tools/nni_cmd/nnictl_utils.py
tools/nni_cmd/nnictl_utils.py
+183
-15
tools/nni_cmd/ssh_utils.py
tools/nni_cmd/ssh_utils.py
+14
-0
tools/nni_trial_tool/trial_keeper.py
tools/nni_trial_tool/trial_keeper.py
+1
-1
No files found.
tools/nni_annotation/__init__.py
View file @
e9040c9b
...
@@ -76,11 +76,12 @@ def _generate_file_search_space(path, module):
...
@@ -76,11 +76,12 @@ def _generate_file_search_space(path, module):
return
search_space
return
search_space
def
expand_annotations
(
src_dir
,
dst_dir
,
exp_id
=
''
,
trial_id
=
''
):
def
expand_annotations
(
src_dir
,
dst_dir
,
exp_id
=
''
,
trial_id
=
''
,
nas_mode
=
None
):
"""Expand annotations in user code.
"""Expand annotations in user code.
Return dst_dir if annotation detected; return src_dir if not.
Return dst_dir if annotation detected; return src_dir if not.
src_dir: directory path of user code (str)
src_dir: directory path of user code (str)
dst_dir: directory to place generated files (str)
dst_dir: directory to place generated files (str)
nas_mode: the mode of NAS given that NAS interface is used
"""
"""
if
src_dir
[
-
1
]
==
slash
:
if
src_dir
[
-
1
]
==
slash
:
src_dir
=
src_dir
[:
-
1
]
src_dir
=
src_dir
[:
-
1
]
...
@@ -108,7 +109,7 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
...
@@ -108,7 +109,7 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
dst_path
=
os
.
path
.
join
(
dst_subdir
,
file_name
)
dst_path
=
os
.
path
.
join
(
dst_subdir
,
file_name
)
if
file_name
.
endswith
(
'.py'
):
if
file_name
.
endswith
(
'.py'
):
if
trial_id
==
''
:
if
trial_id
==
''
:
annotated
|=
_expand_file_annotations
(
src_path
,
dst_path
)
annotated
|=
_expand_file_annotations
(
src_path
,
dst_path
,
nas_mode
)
else
:
else
:
module
=
package
+
file_name
[:
-
3
]
module
=
package
+
file_name
[:
-
3
]
annotated
|=
_generate_specific_file
(
src_path
,
dst_path
,
exp_id
,
trial_id
,
module
)
annotated
|=
_generate_specific_file
(
src_path
,
dst_path
,
exp_id
,
trial_id
,
module
)
...
@@ -120,10 +121,10 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
...
@@ -120,10 +121,10 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
return
dst_dir
if
annotated
else
src_dir
return
dst_dir
if
annotated
else
src_dir
def
_expand_file_annotations
(
src_path
,
dst_path
):
def
_expand_file_annotations
(
src_path
,
dst_path
,
nas_mode
):
with
open
(
src_path
)
as
src
,
open
(
dst_path
,
'w'
)
as
dst
:
with
open
(
src_path
)
as
src
,
open
(
dst_path
,
'w'
)
as
dst
:
try
:
try
:
annotated_code
=
code_generator
.
parse
(
src
.
read
())
annotated_code
=
code_generator
.
parse
(
src
.
read
()
,
nas_mode
)
if
annotated_code
is
None
:
if
annotated_code
is
None
:
shutil
.
copyfile
(
src_path
,
dst_path
)
shutil
.
copyfile
(
src_path
,
dst_path
)
return
False
return
False
...
...
tools/nni_annotation/code_generator.py
View file @
e9040c9b
...
@@ -21,14 +21,14 @@
...
@@ -21,14 +21,14 @@
import
ast
import
ast
import
astor
import
astor
from
nni_cmd.common_utils
import
print_warning
# pylint: disable=unidiomatic-typecheck
# pylint: disable=unidiomatic-typecheck
def
parse_annotation_mutable_layers
(
code
,
lineno
):
def
parse_annotation_mutable_layers
(
code
,
lineno
,
nas_mode
):
"""Parse the string of mutable layers in annotation.
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
Return a list of AST Expr nodes
code: annotation string (excluding '@')
code: annotation string (excluding '@')
nas_mode: the mode of NAS
"""
"""
module
=
ast
.
parse
(
code
)
module
=
ast
.
parse
(
code
)
assert
type
(
module
)
is
ast
.
Module
,
'internal error #1'
assert
type
(
module
)
is
ast
.
Module
,
'internal error #1'
...
@@ -110,6 +110,9 @@ def parse_annotation_mutable_layers(code, lineno):
...
@@ -110,6 +110,9 @@ def parse_annotation_mutable_layers(code, lineno):
else
:
else
:
target_call_args
.
append
(
ast
.
Dict
(
keys
=
[],
values
=
[]))
target_call_args
.
append
(
ast
.
Dict
(
keys
=
[],
values
=
[]))
target_call_args
.
append
(
ast
.
Num
(
n
=
0
))
target_call_args
.
append
(
ast
.
Num
(
n
=
0
))
target_call_args
.
append
(
ast
.
Str
(
s
=
nas_mode
))
if
nas_mode
in
[
'enas_mode'
,
'oneshot_mode'
]:
target_call_args
.
append
(
ast
.
Name
(
id
=
'tensorflow'
))
target_call
=
ast
.
Call
(
func
=
target_call_attr
,
args
=
target_call_args
,
keywords
=
[])
target_call
=
ast
.
Call
(
func
=
target_call_attr
,
args
=
target_call_args
,
keywords
=
[])
node
=
ast
.
Assign
(
targets
=
[
layer_output
],
value
=
target_call
)
node
=
ast
.
Assign
(
targets
=
[
layer_output
],
value
=
target_call
)
nodes
.
append
(
node
)
nodes
.
append
(
node
)
...
@@ -277,10 +280,11 @@ class FuncReplacer(ast.NodeTransformer):
...
@@ -277,10 +280,11 @@ class FuncReplacer(ast.NodeTransformer):
class
Transformer
(
ast
.
NodeTransformer
):
class
Transformer
(
ast
.
NodeTransformer
):
"""Transform original code to annotated code"""
"""Transform original code to annotated code"""
def
__init__
(
self
):
def
__init__
(
self
,
nas_mode
=
None
):
self
.
stack
=
[]
self
.
stack
=
[]
self
.
last_line
=
0
self
.
last_line
=
0
self
.
annotated
=
False
self
.
annotated
=
False
self
.
nas_mode
=
nas_mode
def
visit
(
self
,
node
):
def
visit
(
self
,
node
):
if
isinstance
(
node
,
(
ast
.
expr
,
ast
.
stmt
)):
if
isinstance
(
node
,
(
ast
.
expr
,
ast
.
stmt
)):
...
@@ -316,8 +320,11 @@ class Transformer(ast.NodeTransformer):
...
@@ -316,8 +320,11 @@ class Transformer(ast.NodeTransformer):
return
node
# not an annotation, ignore it
return
node
# not an annotation, ignore it
if
string
.
startswith
(
'@nni.get_next_parameter'
):
if
string
.
startswith
(
'@nni.get_next_parameter'
):
deprecated_message
=
"'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
call_node
=
parse_annotation
(
string
[
1
:]).
value
print_warning
(
deprecated_message
)
if
call_node
.
args
:
# it is used in enas mode as it needs to retrieve the next subgraph for training
call_attr
=
ast
.
Attribute
(
value
=
ast
.
Name
(
id
=
'nni'
,
ctx
=
ast
.
Load
()),
attr
=
'reload_tensorflow_variables'
,
ctx
=
ast
.
Load
())
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
call_attr
,
args
=
call_node
.
args
,
keywords
=
[]))
if
string
.
startswith
(
'@nni.report_intermediate_result'
)
\
if
string
.
startswith
(
'@nni.report_intermediate_result'
)
\
or
string
.
startswith
(
'@nni.report_final_result'
)
\
or
string
.
startswith
(
'@nni.report_final_result'
)
\
...
@@ -325,7 +332,8 @@ class Transformer(ast.NodeTransformer):
...
@@ -325,7 +332,8 @@ class Transformer(ast.NodeTransformer):
return
parse_annotation
(
string
[
1
:])
# expand annotation string to code
return
parse_annotation
(
string
[
1
:])
# expand annotation string to code
if
string
.
startswith
(
'@nni.mutable_layers'
):
if
string
.
startswith
(
'@nni.mutable_layers'
):
return
parse_annotation_mutable_layers
(
string
[
1
:],
node
.
lineno
)
nodes
=
parse_annotation_mutable_layers
(
string
[
1
:],
node
.
lineno
,
self
.
nas_mode
)
return
nodes
if
string
.
startswith
(
'@nni.variable'
)
\
if
string
.
startswith
(
'@nni.variable'
)
\
or
string
.
startswith
(
'@nni.function_choice'
):
or
string
.
startswith
(
'@nni.function_choice'
):
...
@@ -343,17 +351,18 @@ class Transformer(ast.NodeTransformer):
...
@@ -343,17 +351,18 @@ class Transformer(ast.NodeTransformer):
return
node
return
node
def
parse
(
code
):
def
parse
(
code
,
nas_mode
=
None
):
"""Annotate user code.
"""Annotate user code.
Return annotated code (str) if annotation detected; return None if not.
Return annotated code (str) if annotation detected; return None if not.
code: original user code (str)
code: original user code (str),
nas_mode: the mode of NAS given that NAS interface is used
"""
"""
try
:
try
:
ast_tree
=
ast
.
parse
(
code
)
ast_tree
=
ast
.
parse
(
code
)
except
Exception
:
except
Exception
:
raise
RuntimeError
(
'Bad Python code'
)
raise
RuntimeError
(
'Bad Python code'
)
transformer
=
Transformer
()
transformer
=
Transformer
(
nas_mode
)
try
:
try
:
transformer
.
visit
(
ast_tree
)
transformer
.
visit
(
ast_tree
)
except
AssertionError
as
exc
:
except
AssertionError
as
exc
:
...
@@ -369,5 +378,9 @@ def parse(code):
...
@@ -369,5 +378,9 @@ def parse(code):
if
type
(
nodes
[
i
])
is
ast
.
ImportFrom
and
nodes
[
i
].
module
==
'__future__'
:
if
type
(
nodes
[
i
])
is
ast
.
ImportFrom
and
nodes
[
i
].
module
==
'__future__'
:
last_future_import
=
i
last_future_import
=
i
nodes
.
insert
(
last_future_import
+
1
,
import_nni
)
nodes
.
insert
(
last_future_import
+
1
,
import_nni
)
# enas and oneshot modes for tensorflow need tensorflow module, so we import it here
if
nas_mode
in
[
'enas_mode'
,
'oneshot_mode'
]:
import_tf
=
ast
.
Import
(
names
=
[
ast
.
alias
(
name
=
'tensorflow'
,
asname
=
None
)])
nodes
.
insert
(
last_future_import
+
1
,
import_tf
)
return
astor
.
to_source
(
ast_tree
)
return
astor
.
to_source
(
ast_tree
)
tools/nni_cmd/config_schema.py
View file @
e9040c9b
...
@@ -104,6 +104,21 @@ tuner_schema_dict = {
...
@@ -104,6 +104,21 @@ tuner_schema_dict = {
},
},
Optional
(
'gpuNum'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'gpuNum'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
},
},
'GPTuner'
:
{
'builtinTunerName'
:
'GPTuner'
,
'classArgs'
:
{
Optional
(
'optimize_mode'
):
setChoice
(
'optimize_mode'
,
'maximize'
,
'minimize'
),
Optional
(
'utility'
):
setChoice
(
'utility'
,
'ei'
,
'ucb'
,
'poi'
),
Optional
(
'kappa'
):
setType
(
'kappa'
,
float
),
Optional
(
'xi'
):
setType
(
'xi'
,
float
),
Optional
(
'nu'
):
setType
(
'nu'
,
float
),
Optional
(
'alpha'
):
setType
(
'alpha'
,
float
),
Optional
(
'cold_start_num'
):
setType
(
'cold_start_num'
,
int
),
Optional
(
'selection_num_warm_up'
):
setType
(
'selection_num_warm_up'
,
int
),
Optional
(
'selection_num_starting_points'
):
setType
(
'selection_num_starting_points'
,
int
),
},
Optional
(
'gpuNum'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
},
'customized'
:
{
'customized'
:
{
'codeDir'
:
setPathCheck
(
'codeDir'
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'classFileName'
:
setType
(
'classFileName'
,
str
),
'classFileName'
:
setType
(
'classFileName'
,
str
),
...
@@ -181,7 +196,8 @@ common_trial_schema = {
...
@@ -181,7 +196,8 @@ common_trial_schema = {
'trial'
:{
'trial'
:{
'command'
:
setType
(
'command'
,
str
),
'command'
:
setType
(
'command'
,
str
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'gpuNum'
:
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
)
'gpuNum'
:
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'nasMode'
):
setChoice
(
'classic_mode'
,
'enas_mode'
,
'oneshot_mode'
)
}
}
}
}
...
@@ -199,6 +215,7 @@ pai_trial_schema = {
...
@@ -199,6 +215,7 @@ pai_trial_schema = {
Optional
(
'outputDir'
):
And
(
Regex
(
r
'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'
),
\
Optional
(
'outputDir'
):
And
(
Regex
(
r
'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'
),
\
error
=
'ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'
),
error
=
'ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'
),
Optional
(
'virtualCluster'
):
setType
(
'virtualCluster'
,
str
),
Optional
(
'virtualCluster'
):
setType
(
'virtualCluster'
,
str
),
Optional
(
'nasMode'
):
setChoice
(
'classic_mode'
,
'enas_mode'
,
'oneshot_mode'
)
}
}
}
}
...
@@ -213,6 +230,7 @@ pai_config_schema = {
...
@@ -213,6 +230,7 @@ pai_config_schema = {
kubeflow_trial_schema
=
{
kubeflow_trial_schema
=
{
'trial'
:{
'trial'
:{
'codeDir'
:
setPathCheck
(
'codeDir'
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
Optional
(
'nasMode'
):
setChoice
(
'classic_mode'
,
'enas_mode'
,
'oneshot_mode'
),
Optional
(
'ps'
):
{
Optional
(
'ps'
):
{
'replicas'
:
setType
(
'replicas'
,
int
),
'replicas'
:
setType
(
'replicas'
,
int
),
'command'
:
setType
(
'command'
,
str
),
'command'
:
setType
(
'command'
,
str
),
...
...
tools/nni_cmd/launcher.py
View file @
e9040c9b
...
@@ -377,7 +377,8 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
...
@@ -377,7 +377,8 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if
not
os
.
path
.
isdir
(
path
):
if
not
os
.
path
.
isdir
(
path
):
os
.
makedirs
(
path
)
os
.
makedirs
(
path
)
path
=
tempfile
.
mkdtemp
(
dir
=
path
)
path
=
tempfile
.
mkdtemp
(
dir
=
path
)
code_dir
=
expand_annotations
(
experiment_config
[
'trial'
][
'codeDir'
],
path
)
nas_mode
=
experiment_config
[
'trial'
].
get
(
'nasMode'
,
'classic_mode'
)
code_dir
=
expand_annotations
(
experiment_config
[
'trial'
][
'codeDir'
],
path
,
nas_mode
=
nas_mode
)
experiment_config
[
'trial'
][
'codeDir'
]
=
code_dir
experiment_config
[
'trial'
][
'codeDir'
]
=
code_dir
search_space
=
generate_search_space
(
code_dir
)
search_space
=
generate_search_space
(
code_dir
)
experiment_config
[
'searchSpace'
]
=
json
.
dumps
(
search_space
)
experiment_config
[
'searchSpace'
]
=
json
.
dumps
(
search_space
)
...
...
tools/nni_cmd/nnictl.py
View file @
e9040c9b
...
@@ -119,8 +119,21 @@ def parse_args():
...
@@ -119,8 +119,21 @@ def parse_args():
parser_experiment_status
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
parser_experiment_status
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
parser_experiment_status
.
set_defaults
(
func
=
experiment_status
)
parser_experiment_status
.
set_defaults
(
func
=
experiment_status
)
parser_experiment_list
=
parser_experiment_subparsers
.
add_parser
(
'list'
,
help
=
'list all of running experiment ids'
)
parser_experiment_list
=
parser_experiment_subparsers
.
add_parser
(
'list'
,
help
=
'list all of running experiment ids'
)
parser_experiment_list
.
add_argument
(
'all'
,
nargs
=
'?'
,
help
=
'list all of experiments'
)
parser_experiment_list
.
add_argument
(
'
--
all'
,
action
=
'store_true'
,
default
=
False
,
help
=
'list all of experiments'
)
parser_experiment_list
.
set_defaults
(
func
=
experiment_list
)
parser_experiment_list
.
set_defaults
(
func
=
experiment_list
)
parser_experiment_clean
=
parser_experiment_subparsers
.
add_parser
(
'delete'
,
help
=
'clean up the experiment data'
)
parser_experiment_clean
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
parser_experiment_clean
.
add_argument
(
'--all'
,
action
=
'store_true'
,
default
=
False
,
help
=
'delete all of experiments'
)
parser_experiment_clean
.
set_defaults
(
func
=
experiment_clean
)
#parse experiment command
parser_platform
=
subparsers
.
add_parser
(
'platform'
,
help
=
'get platform information'
)
#add subparsers for parser_experiment
parser_platform_subparsers
=
parser_platform
.
add_subparsers
()
parser_platform_clean
=
parser_platform_subparsers
.
add_parser
(
'clean'
,
help
=
'clean up the platform data'
)
parser_platform_clean
.
add_argument
(
'--config'
,
'-c'
,
required
=
True
,
dest
=
'config'
,
help
=
'the path of yaml config file'
)
parser_platform_clean
.
set_defaults
(
func
=
platform_clean
)
#import tuning data
#import tuning data
parser_import_data
=
parser_experiment_subparsers
.
add_parser
(
'import'
,
help
=
'import additional data'
)
parser_import_data
=
parser_experiment_subparsers
.
add_parser
(
'import'
,
help
=
'import additional data'
)
parser_import_data
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
parser_import_data
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
...
...
tools/nni_cmd/nnictl_utils.py
View file @
e9040c9b
...
@@ -24,6 +24,10 @@ import psutil
...
@@ -24,6 +24,10 @@ import psutil
import
json
import
json
import
datetime
import
datetime
import
time
import
time
import
re
from
pathlib
import
Path
from
pyhdfs
import
HdfsClient
,
HdfsFileNotFoundException
import
shutil
from
subprocess
import
call
,
check_output
from
subprocess
import
call
,
check_output
from
nni_annotation
import
expand_annotations
from
nni_annotation
import
expand_annotations
from
.rest_utils
import
rest_get
,
rest_delete
,
check_rest_server_quick
,
check_response
from
.rest_utils
import
rest_get
,
rest_delete
,
check_rest_server_quick
,
check_response
...
@@ -31,8 +35,9 @@ from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_
...
@@ -31,8 +35,9 @@ from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_
from
.config_utils
import
Config
,
Experiments
from
.config_utils
import
Config
,
Experiments
from
.constants
import
NNICTL_HOME_DIR
,
EXPERIMENT_INFORMATION_FORMAT
,
EXPERIMENT_DETAIL_FORMAT
,
\
from
.constants
import
NNICTL_HOME_DIR
,
EXPERIMENT_INFORMATION_FORMAT
,
EXPERIMENT_DETAIL_FORMAT
,
\
EXPERIMENT_MONITOR_INFO
,
TRIAL_MONITOR_HEAD
,
TRIAL_MONITOR_CONTENT
,
TRIAL_MONITOR_TAIL
,
REST_TIME_OUT
EXPERIMENT_MONITOR_INFO
,
TRIAL_MONITOR_HEAD
,
TRIAL_MONITOR_CONTENT
,
TRIAL_MONITOR_TAIL
,
REST_TIME_OUT
from
.common_utils
import
print_normal
,
print_error
,
print_warning
,
detect_process
from
.common_utils
import
print_normal
,
print_error
,
print_warning
,
detect_process
,
get_yml_content
from
.command_utils
import
check_output_command
,
kill_command
from
.command_utils
import
check_output_command
,
kill_command
from
.ssh_utils
import
create_ssh_sftp_client
,
remove_remote_directory
def
get_experiment_time
(
port
):
def
get_experiment_time
(
port
):
'''get the startTime and endTime of an experiment'''
'''get the startTime and endTime of an experiment'''
...
@@ -73,10 +78,11 @@ def update_experiment():
...
@@ -73,10 +78,11 @@ def update_experiment():
if
status
:
if
status
:
experiment_config
.
update_experiment
(
key
,
'status'
,
status
)
experiment_config
.
update_experiment
(
key
,
'status'
,
status
)
def
check_experiment_id
(
args
):
def
check_experiment_id
(
args
,
update
=
True
):
'''check if the id is valid
'''check if the id is valid
'''
'''
update_experiment
()
if
update
:
update_experiment
()
experiment_config
=
Experiments
()
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
if
not
experiment_dict
:
if
not
experiment_dict
:
...
@@ -100,14 +106,14 @@ def check_experiment_id(args):
...
@@ -100,14 +106,14 @@ def check_experiment_id(args):
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
exit
(
1
)
exit
(
1
)
elif
not
running_experiment_list
:
elif
not
running_experiment_list
:
print_error
(
'There is no experiment running
!
'
)
print_error
(
'There is no experiment running
.
'
)
return
None
return
None
else
:
else
:
return
running_experiment_list
[
0
]
return
running_experiment_list
[
0
]
if
experiment_dict
.
get
(
args
.
id
):
if
experiment_dict
.
get
(
args
.
id
):
return
args
.
id
return
args
.
id
else
:
else
:
print_error
(
'Id not correct
!
'
)
print_error
(
'Id not correct
.
'
)
return
None
return
None
def
parse_ids
(
args
):
def
parse_ids
(
args
):
...
@@ -145,7 +151,7 @@ def parse_ids(args):
...
@@ -145,7 +151,7 @@ def parse_ids(args):
exit
(
1
)
exit
(
1
)
else
:
else
:
result_list
=
running_experiment_list
result_list
=
running_experiment_list
elif
args
.
id
==
'
all
'
:
elif
args
.
all
:
result_list
=
running_experiment_list
result_list
=
running_experiment_list
elif
args
.
id
.
endswith
(
'*'
):
elif
args
.
id
.
endswith
(
'*'
):
for
id
in
running_experiment_list
:
for
id
in
running_experiment_list
:
...
@@ -170,7 +176,7 @@ def get_config_filename(args):
...
@@ -170,7 +176,7 @@ def get_config_filename(args):
'''get the file name of config file'''
'''get the file name of config file'''
experiment_id
=
check_experiment_id
(
args
)
experiment_id
=
check_experiment_id
(
args
)
if
experiment_id
is
None
:
if
experiment_id
is
None
:
print_error
(
'Please set
the
experiment id
!
'
)
print_error
(
'Please set
correct
experiment id
.
'
)
exit
(
1
)
exit
(
1
)
experiment_config
=
Experiments
()
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
...
@@ -180,7 +186,7 @@ def get_experiment_port(args):
...
@@ -180,7 +186,7 @@ def get_experiment_port(args):
'''get the port of experiment'''
'''get the port of experiment'''
experiment_id
=
check_experiment_id
(
args
)
experiment_id
=
check_experiment_id
(
args
)
if
experiment_id
is
None
:
if
experiment_id
is
None
:
print_error
(
'Please set
the
experiment id
!
'
)
print_error
(
'Please set
correct
experiment id
.
'
)
exit
(
1
)
exit
(
1
)
experiment_config
=
Experiments
()
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
...
@@ -229,7 +235,7 @@ def stop_experiment(args):
...
@@ -229,7 +235,7 @@ def stop_experiment(args):
except
Exception
as
exception
:
except
Exception
as
exception
:
print_error
(
exception
)
print_error
(
exception
)
nni_config
.
set_config
(
'tensorboardPidList'
,
[])
nni_config
.
set_config
(
'tensorboardPidList'
,
[])
print_normal
(
'Stop experiment success
!
'
)
print_normal
(
'Stop experiment success
.
'
)
experiment_config
.
update_experiment
(
experiment_id
,
'status'
,
'STOPPED'
)
experiment_config
.
update_experiment
(
experiment_id
,
'status'
,
'STOPPED'
)
time_now
=
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
time_now
=
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
experiment_config
.
update_experiment
(
experiment_id
,
'endTime'
,
str
(
time_now
))
experiment_config
.
update_experiment
(
experiment_id
,
'endTime'
,
str
(
time_now
))
...
@@ -354,10 +360,10 @@ def log_trial(args):
...
@@ -354,10 +360,10 @@ def log_trial(args):
if
trial_id_path_dict
.
get
(
args
.
trial_id
):
if
trial_id_path_dict
.
get
(
args
.
trial_id
):
print_normal
(
'id:'
+
args
.
trial_id
+
' path:'
+
trial_id_path_dict
[
args
.
trial_id
])
print_normal
(
'id:'
+
args
.
trial_id
+
' path:'
+
trial_id_path_dict
[
args
.
trial_id
])
else
:
else
:
print_error
(
'trial id is not valid
!
'
)
print_error
(
'trial id is not valid
.
'
)
exit
(
1
)
exit
(
1
)
else
:
else
:
print_error
(
'please specific the trial id
!
'
)
print_error
(
'please specific the trial id
.
'
)
exit
(
1
)
exit
(
1
)
else
:
else
:
for
key
in
trial_id_path_dict
:
for
key
in
trial_id_path_dict
:
...
@@ -373,16 +379,179 @@ def webui_url(args):
...
@@ -373,16 +379,179 @@ def webui_url(args):
nni_config
=
Config
(
get_config_filename
(
args
))
nni_config
=
Config
(
get_config_filename
(
args
))
print_normal
(
'{0} {1}'
.
format
(
'Web UI url:'
,
' '
.
join
(
nni_config
.
get_config
(
'webuiUrl'
))))
print_normal
(
'{0} {1}'
.
format
(
'Web UI url:'
,
' '
.
join
(
nni_config
.
get_config
(
'webuiUrl'
))))
def
local_clean
(
directory
):
'''clean up local data'''
print_normal
(
'removing folder {0}'
.
format
(
directory
))
try
:
shutil
.
rmtree
(
directory
)
except
FileNotFoundError
as
err
:
print_error
(
'{0} does not exist.'
.
format
(
directory
))
def
remote_clean
(
machine_list
,
experiment_id
=
None
):
'''clean up remote data'''
for
machine
in
machine_list
:
passwd
=
machine
.
get
(
'passwd'
)
userName
=
machine
.
get
(
'username'
)
host
=
machine
.
get
(
'ip'
)
port
=
machine
.
get
(
'port'
)
if
experiment_id
:
remote_dir
=
'/'
+
'/'
.
join
([
'tmp'
,
'nni'
,
'experiments'
,
experiment_id
])
else
:
remote_dir
=
'/'
+
'/'
.
join
([
'tmp'
,
'nni'
,
'experiments'
])
sftp
=
create_ssh_sftp_client
(
host
,
port
,
userName
,
passwd
)
print_normal
(
'removing folder {0}'
.
format
(
host
+
':'
+
str
(
port
)
+
remote_dir
))
remove_remote_directory
(
sftp
,
remote_dir
)
def
hdfs_clean
(
host
,
user_name
,
output_dir
,
experiment_id
=
None
):
'''clean up hdfs data'''
hdfs_client
=
HdfsClient
(
hosts
=
'{0}:80'
.
format
(
host
),
user_name
=
user_name
,
webhdfs_path
=
'/webhdfs/api/v1'
,
timeout
=
5
)
if
experiment_id
:
full_path
=
'/'
+
'/'
.
join
([
user_name
,
'nni'
,
'experiments'
,
experiment_id
])
else
:
full_path
=
'/'
+
'/'
.
join
([
user_name
,
'nni'
,
'experiments'
])
print_normal
(
'removing folder {0} in hdfs'
.
format
(
full_path
))
hdfs_client
.
delete
(
full_path
,
recursive
=
True
)
if
output_dir
:
pattern
=
re
.
compile
(
'hdfs://(?P<host>([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?'
)
match_result
=
pattern
.
match
(
output_dir
)
if
match_result
:
output_host
=
match_result
.
group
(
'host'
)
output_dir
=
match_result
.
group
(
'baseDir'
)
#check if the host is valid
if
output_host
!=
host
:
print_warning
(
'The host in {0} is not consistent with {1}'
.
format
(
output_dir
,
host
))
else
:
if
experiment_id
:
output_dir
=
output_dir
+
'/'
+
experiment_id
print_normal
(
'removing folder {0} in hdfs'
.
format
(
output_dir
))
hdfs_client
.
delete
(
output_dir
,
recursive
=
True
)
def
experiment_clean
(
args
):
'''clean up the experiment data'''
experiment_id_list
=
[]
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
if
args
.
all
:
experiment_id_list
=
list
(
experiment_dict
.
keys
())
else
:
if
args
.
id
is
None
:
print_error
(
'please set experiment id.'
)
exit
(
1
)
if
args
.
id
not
in
experiment_dict
:
print_error
(
'Cannot find experiment {0}.'
.
format
(
args
.
id
))
exit
(
1
)
experiment_id_list
.
append
(
args
.
id
)
while
True
:
print
(
'INFO: This action will delete experiment {0}, and it
\'
s not recoverable.'
.
format
(
' '
.
join
(
experiment_id_list
)))
inputs
=
input
(
'INFO: do you want to continue?[y/N]:'
)
if
not
inputs
.
lower
()
or
inputs
.
lower
()
in
[
'n'
,
'no'
]:
exit
(
0
)
elif
inputs
.
lower
()
not
in
[
'y'
,
'n'
,
'yes'
,
'no'
]:
print_warning
(
'please input Y or N.'
)
else
:
break
for
experiment_id
in
experiment_id_list
:
nni_config
=
Config
(
experiment_dict
[
experiment_id
][
'fileName'
])
platform
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'trainingServicePlatform'
)
experiment_id
=
nni_config
.
get_config
(
'experimentId'
)
if
platform
==
'remote'
:
machine_list
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'machineList'
)
remote_clean
(
machine_list
,
experiment_id
)
elif
platform
==
'pai'
:
host
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'paiConfig'
).
get
(
'host'
)
user_name
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'paiConfig'
).
get
(
'userName'
)
output_dir
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'trial'
).
get
(
'outputDir'
)
hdfs_clean
(
host
,
user_name
,
output_dir
,
experiment_id
)
elif
platform
!=
'local'
:
#TODO: support all platforms
print_warning
(
'platform {0} clean up not supported yet.'
.
format
(
platform
))
exit
(
0
)
#clean local data
home
=
str
(
Path
.
home
())
local_dir
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'logDir'
)
if
not
local_dir
:
local_dir
=
os
.
path
.
join
(
home
,
'nni'
,
'experiments'
,
experiment_id
)
local_clean
(
local_dir
)
experiment_config
=
Experiments
()
print_normal
(
'removing metadata of experiment {0}'
.
format
(
experiment_id
))
experiment_config
.
remove_experiment
(
experiment_id
)
print_normal
(
'Done.'
)
def
get_platform_dir
(
config_content
):
'''get the dir list to be deleted'''
platform
=
config_content
.
get
(
'trainingServicePlatform'
)
dir_list
=
[]
if
platform
==
'remote'
:
machine_list
=
config_content
.
get
(
'machineList'
)
for
machine
in
machine_list
:
host
=
machine
.
get
(
'ip'
)
port
=
machine
.
get
(
'port'
)
dir_list
.
append
(
host
+
':'
+
str
(
port
)
+
'/tmp/nni'
)
elif
platform
==
'pai'
:
pai_config
=
config_content
.
get
(
'paiConfig'
)
host
=
config_content
.
get
(
'paiConfig'
).
get
(
'host'
)
user_name
=
config_content
.
get
(
'paiConfig'
).
get
(
'userName'
)
output_dir
=
config_content
.
get
(
'trial'
).
get
(
'outputDir'
)
dir_list
.
append
(
'server: {0}, path: {1}/nni'
.
format
(
host
,
user_name
))
if
output_dir
:
dir_list
.
append
(
output_dir
)
return
dir_list
def
platform_clean
(
args
):
'''clean up the experiment data'''
config_path
=
os
.
path
.
abspath
(
args
.
config
)
if
not
os
.
path
.
exists
(
config_path
):
print_error
(
'Please set correct config path.'
)
exit
(
1
)
config_content
=
get_yml_content
(
config_path
)
platform
=
config_content
.
get
(
'trainingServicePlatform'
)
if
platform
==
'local'
:
print_normal
(
'it doesn’t need to clean local platform.'
)
exit
(
0
)
if
platform
not
in
[
'remote'
,
'pai'
]:
print_normal
(
'platform {0} not supported.'
.
format
(
platform
))
exit
(
0
)
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
update_experiment
()
id_list
=
list
(
experiment_dict
.
keys
())
dir_list
=
get_platform_dir
(
config_content
)
if
not
dir_list
:
print_normal
(
'No folder of NNI caches is found.'
)
exit
(
1
)
while
True
:
print_normal
(
'This command will remove below folders of NNI caches. If other users are using experiments on below hosts, it will be broken.'
)
for
dir
in
dir_list
:
print
(
' '
+
dir
)
inputs
=
input
(
'INFO: do you want to continue?[y/N]:'
)
if
not
inputs
.
lower
()
or
inputs
.
lower
()
in
[
'n'
,
'no'
]:
exit
(
0
)
elif
inputs
.
lower
()
not
in
[
'y'
,
'n'
,
'yes'
,
'no'
]:
print_warning
(
'please input Y or N.'
)
else
:
break
if
platform
==
'remote'
:
machine_list
=
config_content
.
get
(
'machineList'
)
for
machine
in
machine_list
:
remote_clean
(
machine_list
,
None
)
elif
platform
==
'pai'
:
pai_config
=
config_content
.
get
(
'paiConfig'
)
host
=
config_content
.
get
(
'paiConfig'
).
get
(
'host'
)
user_name
=
config_content
.
get
(
'paiConfig'
).
get
(
'userName'
)
output_dir
=
config_content
.
get
(
'trial'
).
get
(
'outputDir'
)
hdfs_clean
(
host
,
user_name
,
output_dir
,
None
)
print_normal
(
'Done.'
)
def
experiment_list
(
args
):
def
experiment_list
(
args
):
'''get the information of all experiments'''
'''get the information of all experiments'''
experiment_config
=
Experiments
()
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
if
not
experiment_dict
:
if
not
experiment_dict
:
print
(
'There is no
experiment
running..
.'
)
print
_normal
(
'Cannot find
experiment
s
.'
)
exit
(
1
)
exit
(
1
)
update_experiment
()
update_experiment
()
experiment_id_list
=
[]
experiment_id_list
=
[]
if
args
.
all
and
args
.
all
==
'all'
:
if
args
.
all
:
for
key
in
experiment_dict
.
keys
():
for
key
in
experiment_dict
.
keys
():
experiment_id_list
.
append
(
key
)
experiment_id_list
.
append
(
key
)
else
:
else
:
...
@@ -390,10 +559,9 @@ def experiment_list(args):
...
@@ -390,10 +559,9 @@ def experiment_list(args):
if
experiment_dict
[
key
][
'status'
]
!=
'STOPPED'
:
if
experiment_dict
[
key
][
'status'
]
!=
'STOPPED'
:
experiment_id_list
.
append
(
key
)
experiment_id_list
.
append
(
key
)
if
not
experiment_id_list
:
if
not
experiment_id_list
:
print_warning
(
'There is no experiment running...
\n
You can use
\'
nnictl experiment list all
\'
to list all stopped experiments
!
'
)
print_warning
(
'There is no experiment running...
\n
You can use
\'
nnictl experiment list
--
all
\'
to list all stopped experiments
.
'
)
experiment_information
=
""
experiment_information
=
""
for
key
in
experiment_id_list
:
for
key
in
experiment_id_list
:
experiment_information
+=
(
EXPERIMENT_DETAIL_FORMAT
%
(
key
,
experiment_dict
[
key
][
'status'
],
experiment_dict
[
key
][
'port'
],
\
experiment_information
+=
(
EXPERIMENT_DETAIL_FORMAT
%
(
key
,
experiment_dict
[
key
][
'status'
],
experiment_dict
[
key
][
'port'
],
\
experiment_dict
[
key
].
get
(
'platform'
),
experiment_dict
[
key
][
'startTime'
],
experiment_dict
[
key
][
'endTime'
]))
experiment_dict
[
key
].
get
(
'platform'
),
experiment_dict
[
key
][
'startTime'
],
experiment_dict
[
key
][
'endTime'
]))
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
...
...
tools/nni_cmd/ssh_utils.py
View file @
e9040c9b
...
@@ -57,3 +57,17 @@ def create_ssh_sftp_client(host_ip, port, username, password):
...
@@ -57,3 +57,17 @@ def create_ssh_sftp_client(host_ip, port, username, password):
return
sftp
return
sftp
except
Exception
as
exception
:
except
Exception
as
exception
:
print_error
(
'Create ssh client error %s
\n
'
%
exception
)
print_error
(
'Create ssh client error %s
\n
'
%
exception
)
def
remove_remote_directory
(
sftp
,
directory
):
'''remove a directory in remote machine'''
try
:
files
=
sftp
.
listdir
(
directory
)
for
file
in
files
:
filepath
=
'/'
.
join
([
directory
,
file
])
try
:
sftp
.
remove
(
filepath
)
except
IOError
:
remove_remote_directory
(
sftp
,
filepath
)
sftp
.
rmdir
(
directory
)
except
IOError
as
err
:
print_error
(
err
)
\ No newline at end of file
tools/nni_trial_tool/trial_keeper.py
View file @
e9040c9b
...
@@ -224,7 +224,7 @@ if __name__ == '__main__':
...
@@ -224,7 +224,7 @@ if __name__ == '__main__':
exit
(
1
)
exit
(
1
)
check_version
(
args
)
check_version
(
args
)
try
:
try
:
if
is_multi_phase
():
if
NNI_PLATFORM
==
'pai'
and
is_multi_phase
():
fetch_parameter_file
(
args
)
fetch_parameter_file
(
args
)
main_loop
(
args
)
main_loop
(
args
)
except
SystemExit
as
se
:
except
SystemExit
as
se
:
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment