Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d6febf29
Commit
d6febf29
authored
Jun 25, 2019
by
suiguoxin
Browse files
Merge branch 'master' of
git://github.com/microsoft/nni
parents
77c95479
c2179921
Changes
90
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
324 additions
and
44 deletions
+324
-44
tools/nni_annotation/__init__.py
tools/nni_annotation/__init__.py
+5
-4
tools/nni_annotation/code_generator.py
tools/nni_annotation/code_generator.py
+22
-9
tools/nni_cmd/config_schema.py
tools/nni_cmd/config_schema.py
+2
-1
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+1
-1
tools/nni_cmd/nnictl.py
tools/nni_cmd/nnictl.py
+13
-0
tools/nni_cmd/nnictl_utils.py
tools/nni_cmd/nnictl_utils.py
+171
-6
tools/nni_cmd/ssh_utils.py
tools/nni_cmd/ssh_utils.py
+14
-0
tools/nni_trial_tool/constants.py
tools/nni_trial_tool/constants.py
+3
-1
tools/nni_trial_tool/trial_keeper.py
tools/nni_trial_tool/trial_keeper.py
+87
-20
tools/nni_trial_tool/url_utils.py
tools/nni_trial_tool/url_utils.py
+6
-2
No files found.
tools/nni_annotation/__init__.py
View file @
d6febf29
...
...
@@ -76,11 +76,12 @@ def _generate_file_search_space(path, module):
return
search_space
def
expand_annotations
(
src_dir
,
dst_dir
,
exp_id
=
''
,
trial_id
=
''
):
def
expand_annotations
(
src_dir
,
dst_dir
,
exp_id
=
''
,
trial_id
=
''
,
nas_mode
=
None
):
"""Expand annotations in user code.
Return dst_dir if annotation detected; return src_dir if not.
src_dir: directory path of user code (str)
dst_dir: directory to place generated files (str)
nas_mode: the mode of NAS given that NAS interface is used
"""
if
src_dir
[
-
1
]
==
slash
:
src_dir
=
src_dir
[:
-
1
]
...
...
@@ -108,7 +109,7 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
dst_path
=
os
.
path
.
join
(
dst_subdir
,
file_name
)
if
file_name
.
endswith
(
'.py'
):
if
trial_id
==
''
:
annotated
|=
_expand_file_annotations
(
src_path
,
dst_path
)
annotated
|=
_expand_file_annotations
(
src_path
,
dst_path
,
nas_mode
)
else
:
module
=
package
+
file_name
[:
-
3
]
annotated
|=
_generate_specific_file
(
src_path
,
dst_path
,
exp_id
,
trial_id
,
module
)
...
...
@@ -120,10 +121,10 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
return
dst_dir
if
annotated
else
src_dir
def
_expand_file_annotations
(
src_path
,
dst_path
):
def
_expand_file_annotations
(
src_path
,
dst_path
,
nas_mode
):
with
open
(
src_path
)
as
src
,
open
(
dst_path
,
'w'
)
as
dst
:
try
:
annotated_code
=
code_generator
.
parse
(
src
.
read
())
annotated_code
=
code_generator
.
parse
(
src
.
read
()
,
nas_mode
)
if
annotated_code
is
None
:
shutil
.
copyfile
(
src_path
,
dst_path
)
return
False
...
...
tools/nni_annotation/code_generator.py
View file @
d6febf29
...
...
@@ -21,14 +21,14 @@
import
ast
import
astor
from
nni_cmd.common_utils
import
print_warning
# pylint: disable=unidiomatic-typecheck
def
parse_annotation_mutable_layers
(
code
,
lineno
):
def
parse_annotation_mutable_layers
(
code
,
lineno
,
nas_mode
):
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
code: annotation string (excluding '@')
nas_mode: the mode of NAS
"""
module
=
ast
.
parse
(
code
)
assert
type
(
module
)
is
ast
.
Module
,
'internal error #1'
...
...
@@ -110,6 +110,9 @@ def parse_annotation_mutable_layers(code, lineno):
else
:
target_call_args
.
append
(
ast
.
Dict
(
keys
=
[],
values
=
[]))
target_call_args
.
append
(
ast
.
Num
(
n
=
0
))
target_call_args
.
append
(
ast
.
Str
(
s
=
nas_mode
))
if
nas_mode
in
[
'enas_mode'
,
'oneshot_mode'
]:
target_call_args
.
append
(
ast
.
Name
(
id
=
'tensorflow'
))
target_call
=
ast
.
Call
(
func
=
target_call_attr
,
args
=
target_call_args
,
keywords
=
[])
node
=
ast
.
Assign
(
targets
=
[
layer_output
],
value
=
target_call
)
nodes
.
append
(
node
)
...
...
@@ -277,10 +280,11 @@ class FuncReplacer(ast.NodeTransformer):
class
Transformer
(
ast
.
NodeTransformer
):
"""Transform original code to annotated code"""
def
__init__
(
self
):
def
__init__
(
self
,
nas_mode
=
None
):
self
.
stack
=
[]
self
.
last_line
=
0
self
.
annotated
=
False
self
.
nas_mode
=
nas_mode
def
visit
(
self
,
node
):
if
isinstance
(
node
,
(
ast
.
expr
,
ast
.
stmt
)):
...
...
@@ -316,8 +320,11 @@ class Transformer(ast.NodeTransformer):
return
node
# not an annotation, ignore it
if
string
.
startswith
(
'@nni.get_next_parameter'
):
deprecated_message
=
"'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
print_warning
(
deprecated_message
)
call_node
=
parse_annotation
(
string
[
1
:]).
value
if
call_node
.
args
:
# it is used in enas mode as it needs to retrieve the next subgraph for training
call_attr
=
ast
.
Attribute
(
value
=
ast
.
Name
(
id
=
'nni'
,
ctx
=
ast
.
Load
()),
attr
=
'reload_tensorflow_variables'
,
ctx
=
ast
.
Load
())
return
ast
.
Expr
(
value
=
ast
.
Call
(
func
=
call_attr
,
args
=
call_node
.
args
,
keywords
=
[]))
if
string
.
startswith
(
'@nni.report_intermediate_result'
)
\
or
string
.
startswith
(
'@nni.report_final_result'
)
\
...
...
@@ -325,7 +332,8 @@ class Transformer(ast.NodeTransformer):
return
parse_annotation
(
string
[
1
:])
# expand annotation string to code
if
string
.
startswith
(
'@nni.mutable_layers'
):
return
parse_annotation_mutable_layers
(
string
[
1
:],
node
.
lineno
)
nodes
=
parse_annotation_mutable_layers
(
string
[
1
:],
node
.
lineno
,
self
.
nas_mode
)
return
nodes
if
string
.
startswith
(
'@nni.variable'
)
\
or
string
.
startswith
(
'@nni.function_choice'
):
...
...
@@ -343,17 +351,18 @@ class Transformer(ast.NodeTransformer):
return
node
def
parse
(
code
):
def
parse
(
code
,
nas_mode
=
None
):
"""Annotate user code.
Return annotated code (str) if annotation detected; return None if not.
code: original user code (str)
code: original user code (str),
nas_mode: the mode of NAS given that NAS interface is used
"""
try
:
ast_tree
=
ast
.
parse
(
code
)
except
Exception
:
raise
RuntimeError
(
'Bad Python code'
)
transformer
=
Transformer
()
transformer
=
Transformer
(
nas_mode
)
try
:
transformer
.
visit
(
ast_tree
)
except
AssertionError
as
exc
:
...
...
@@ -369,5 +378,9 @@ def parse(code):
if
type
(
nodes
[
i
])
is
ast
.
ImportFrom
and
nodes
[
i
].
module
==
'__future__'
:
last_future_import
=
i
nodes
.
insert
(
last_future_import
+
1
,
import_nni
)
# enas and oneshot modes for tensorflow need tensorflow module, so we import it here
if
nas_mode
in
[
'enas_mode'
,
'oneshot_mode'
]:
import_tf
=
ast
.
Import
(
names
=
[
ast
.
alias
(
name
=
'tensorflow'
,
asname
=
None
)])
nodes
.
insert
(
last_future_import
+
1
,
import_tf
)
return
astor
.
to_source
(
ast_tree
)
tools/nni_cmd/config_schema.py
View file @
d6febf29
...
...
@@ -196,7 +196,8 @@ common_trial_schema = {
'trial'
:{
'command'
:
setType
(
'command'
,
str
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'gpuNum'
:
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
)
'gpuNum'
:
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'nasMode'
):
setChoice
(
'classic_mode'
,
'enas_mode'
,
'oneshot_mode'
)
}
}
...
...
tools/nni_cmd/launcher.py
View file @
d6febf29
...
...
@@ -377,7 +377,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if
not
os
.
path
.
isdir
(
path
):
os
.
makedirs
(
path
)
path
=
tempfile
.
mkdtemp
(
dir
=
path
)
code_dir
=
expand_annotations
(
experiment_config
[
'trial'
][
'codeDir'
],
path
)
code_dir
=
expand_annotations
(
experiment_config
[
'trial'
][
'codeDir'
],
path
,
nas_mode
=
experiment_config
[
'trial'
][
'nasMode'
]
)
experiment_config
[
'trial'
][
'codeDir'
]
=
code_dir
search_space
=
generate_search_space
(
code_dir
)
experiment_config
[
'searchSpace'
]
=
json
.
dumps
(
search_space
)
...
...
tools/nni_cmd/nnictl.py
View file @
d6febf29
...
...
@@ -121,6 +121,19 @@ def parse_args():
parser_experiment_list
=
parser_experiment_subparsers
.
add_parser
(
'list'
,
help
=
'list all of running experiment ids'
)
parser_experiment_list
.
add_argument
(
'all'
,
nargs
=
'?'
,
help
=
'list all of experiments'
)
parser_experiment_list
.
set_defaults
(
func
=
experiment_list
)
parser_experiment_clean
=
parser_experiment_subparsers
.
add_parser
(
'delete'
,
help
=
'clean up the experiment data'
)
parser_experiment_clean
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
parser_experiment_clean
.
add_argument
(
'--all'
,
action
=
'store_true'
,
default
=
False
,
help
=
'delete all of experiments'
)
parser_experiment_clean
.
set_defaults
(
func
=
experiment_clean
)
#parse experiment command
parser_platform
=
subparsers
.
add_parser
(
'platform'
,
help
=
'get platform information'
)
#add subparsers for parser_experiment
parser_platform_subparsers
=
parser_platform
.
add_subparsers
()
parser_platform_clean
=
parser_platform_subparsers
.
add_parser
(
'clean'
,
help
=
'clean up the platform data'
)
parser_platform_clean
.
add_argument
(
'--config'
,
'-c'
,
required
=
True
,
dest
=
'config'
,
help
=
'the path of yaml config file'
)
parser_platform_clean
.
set_defaults
(
func
=
platform_clean
)
#import tuning data
parser_import_data
=
parser_experiment_subparsers
.
add_parser
(
'import'
,
help
=
'import additional data'
)
parser_import_data
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'the id of experiment'
)
...
...
tools/nni_cmd/nnictl_utils.py
View file @
d6febf29
...
...
@@ -24,6 +24,10 @@ import psutil
import
json
import
datetime
import
time
import
re
from
pathlib
import
Path
from
pyhdfs
import
HdfsClient
,
HdfsFileNotFoundException
import
shutil
from
subprocess
import
call
,
check_output
from
nni_annotation
import
expand_annotations
from
.rest_utils
import
rest_get
,
rest_delete
,
check_rest_server_quick
,
check_response
...
...
@@ -31,8 +35,9 @@ from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_
from
.config_utils
import
Config
,
Experiments
from
.constants
import
NNICTL_HOME_DIR
,
EXPERIMENT_INFORMATION_FORMAT
,
EXPERIMENT_DETAIL_FORMAT
,
\
EXPERIMENT_MONITOR_INFO
,
TRIAL_MONITOR_HEAD
,
TRIAL_MONITOR_CONTENT
,
TRIAL_MONITOR_TAIL
,
REST_TIME_OUT
from
.common_utils
import
print_normal
,
print_error
,
print_warning
,
detect_process
from
.common_utils
import
print_normal
,
print_error
,
print_warning
,
detect_process
,
get_yml_content
from
.command_utils
import
check_output_command
,
kill_command
from
.ssh_utils
import
create_ssh_sftp_client
,
remove_remote_directory
def
get_experiment_time
(
port
):
'''get the startTime and endTime of an experiment'''
...
...
@@ -73,10 +78,11 @@ def update_experiment():
if
status
:
experiment_config
.
update_experiment
(
key
,
'status'
,
status
)
def
check_experiment_id
(
args
):
def
check_experiment_id
(
args
,
update
=
True
):
'''check if the id is valid
'''
update_experiment
()
if
update
:
update_experiment
()
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
if
not
experiment_dict
:
...
...
@@ -170,7 +176,7 @@ def get_config_filename(args):
'''get the file name of config file'''
experiment_id
=
check_experiment_id
(
args
)
if
experiment_id
is
None
:
print_error
(
'Please set
the
experiment id!'
)
print_error
(
'Please set
correct
experiment id!'
)
exit
(
1
)
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
...
...
@@ -180,7 +186,7 @@ def get_experiment_port(args):
'''get the port of experiment'''
experiment_id
=
check_experiment_id
(
args
)
if
experiment_id
is
None
:
print_error
(
'Please set
the
experiment id!'
)
print_error
(
'Please set
correct
experiment id!'
)
exit
(
1
)
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
...
...
@@ -373,6 +379,166 @@ def webui_url(args):
nni_config
=
Config
(
get_config_filename
(
args
))
print_normal
(
'{0} {1}'
.
format
(
'Web UI url:'
,
' '
.
join
(
nni_config
.
get_config
(
'webuiUrl'
))))
def
local_clean
(
directory
):
'''clean up local data'''
print_normal
(
'removing folder {0}'
.
format
(
directory
))
try
:
shutil
.
rmtree
(
directory
)
except
FileNotFoundError
as
err
:
print_error
(
'{0} does not exist!'
.
format
(
directory
))
def
remote_clean
(
machine_list
,
experiment_id
=
None
):
'''clean up remote data'''
for
machine
in
machine_list
:
passwd
=
machine
.
get
(
'passwd'
)
userName
=
machine
.
get
(
'username'
)
host
=
machine
.
get
(
'ip'
)
port
=
machine
.
get
(
'port'
)
if
experiment_id
:
remote_dir
=
'/'
+
'/'
.
join
([
'tmp'
,
'nni'
,
'experiments'
,
experiment_id
])
else
:
remote_dir
=
'/'
+
'/'
.
join
([
'tmp'
,
'nni'
,
'experiments'
])
sftp
=
create_ssh_sftp_client
(
host
,
port
,
userName
,
passwd
)
print_normal
(
'removing folder {0}'
.
format
(
host
+
':'
+
str
(
port
)
+
remote_dir
))
remove_remote_directory
(
sftp
,
remote_dir
)
def
hdfs_clean
(
host
,
user_name
,
output_dir
,
experiment_id
=
None
):
'''clean up hdfs data'''
hdfs_client
=
HdfsClient
(
hosts
=
'{0}:80'
.
format
(
host
),
user_name
=
user_name
,
webhdfs_path
=
'/webhdfs/api/v1'
,
timeout
=
5
)
if
experiment_id
:
full_path
=
'/'
+
'/'
.
join
([
user_name
,
'nni'
,
'experiments'
,
experiment_id
])
else
:
full_path
=
'/'
+
'/'
.
join
([
user_name
,
'nni'
,
'experiments'
])
print_normal
(
'removing folder {0} in hdfs'
.
format
(
full_path
))
hdfs_client
.
delete
(
full_path
,
recursive
=
True
)
if
output_dir
:
pattern
=
re
.
compile
(
'hdfs://(?P<host>([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?'
)
match_result
=
pattern
.
match
(
output_dir
)
if
match_result
:
output_host
=
match_result
.
group
(
'host'
)
output_dir
=
match_result
.
group
(
'baseDir'
)
#check if the host is valid
if
output_host
!=
host
:
print_warning
(
'The host in {0} is not consistent with {1}'
.
format
(
output_dir
,
host
))
else
:
if
experiment_id
:
output_dir
=
output_dir
+
'/'
+
experiment_id
print_normal
(
'removing folder {0} in hdfs'
.
format
(
output_dir
))
hdfs_client
.
delete
(
output_dir
,
recursive
=
True
)
def
experiment_clean
(
args
):
'''clean up the experiment data'''
experiment_id_list
=
[]
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
if
args
.
all
:
experiment_id_list
=
list
(
experiment_dict
.
keys
())
else
:
if
args
.
id
is
None
:
print_error
(
'please set experiment id!'
)
exit
(
1
)
if
args
.
id
not
in
experiment_dict
:
print_error
(
'can not find id {0}!'
.
format
(
args
.
id
))
exit
(
1
)
experiment_id_list
.
append
(
args
.
id
)
while
True
:
print
(
'INFO: This action will delete experiment {0}, and it’s not recoverable.'
.
format
(
' '
.
join
(
experiment_id_list
)))
inputs
=
input
(
'INFO: do you want to continue?[y/N]:'
)
if
not
inputs
.
lower
()
or
inputs
.
lower
()
in
[
'n'
,
'no'
]:
exit
(
0
)
elif
inputs
.
lower
()
not
in
[
'y'
,
'n'
,
'yes'
,
'no'
]:
print_warning
(
'please input Y or N!'
)
else
:
break
for
experiment_id
in
experiment_id_list
:
nni_config
=
Config
(
experiment_dict
[
experiment_id
][
'fileName'
])
platform
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'trainingServicePlatform'
)
experiment_id
=
nni_config
.
get_config
(
'experimentId'
)
if
platform
==
'remote'
:
machine_list
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'machineList'
)
remote_clean
(
machine_list
,
experiment_id
)
elif
platform
==
'pai'
:
host
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'paiConfig'
).
get
(
'host'
)
user_name
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'paiConfig'
).
get
(
'userName'
)
output_dir
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'trial'
).
get
(
'outputDir'
)
hdfs_clean
(
host
,
user_name
,
output_dir
,
experiment_id
)
elif
platform
!=
'local'
:
#TODO: support all platforms
print_warning
(
'platform {0} clean up not supported yet!'
.
format
(
platform
))
exit
(
0
)
#clean local data
home
=
str
(
Path
.
home
())
local_dir
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'logDir'
)
if
not
local_dir
:
local_dir
=
os
.
path
.
join
(
home
,
'nni'
,
'experiments'
,
experiment_id
)
local_clean
(
local_dir
)
experiment_config
=
Experiments
()
print_normal
(
'removing metadata of experiment {0}'
.
format
(
experiment_id
))
experiment_config
.
remove_experiment
(
experiment_id
)
print_normal
(
'Finish!'
)
def
get_platform_dir
(
config_content
):
'''get the dir list to be deleted'''
platform
=
config_content
.
get
(
'trainingServicePlatform'
)
dir_list
=
[]
if
platform
==
'remote'
:
machine_list
=
config_content
.
get
(
'machineList'
)
for
machine
in
machine_list
:
host
=
machine
.
get
(
'ip'
)
port
=
machine
.
get
(
'port'
)
dir_list
.
append
(
host
+
':'
+
str
(
port
)
+
'/tmp/nni/experiments'
)
elif
platform
==
'pai'
:
pai_config
=
config_content
.
get
(
'paiConfig'
)
host
=
config_content
.
get
(
'paiConfig'
).
get
(
'host'
)
user_name
=
config_content
.
get
(
'paiConfig'
).
get
(
'userName'
)
output_dir
=
config_content
.
get
(
'trial'
).
get
(
'outputDir'
)
dir_list
.
append
(
'hdfs://{0}:9000/{1}/nni/experiments'
.
format
(
host
,
user_name
))
if
output_dir
:
dir_list
.
append
(
output_dir
)
return
dir_list
def
platform_clean
(
args
):
'''clean up the experiment data'''
config_path
=
os
.
path
.
abspath
(
args
.
config
)
if
not
os
.
path
.
exists
(
config_path
):
print_error
(
'Please set correct config path!'
)
exit
(
1
)
config_content
=
get_yml_content
(
config_path
)
platform
=
config_content
.
get
(
'trainingServicePlatform'
)
if
platform
not
in
[
'remote'
,
'pai'
]:
print_normal
(
'platform {0} not supported!'
.
format
(
platform
))
exit
(
0
)
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
update_experiment
()
id_list
=
list
(
experiment_dict
.
keys
())
dir_list
=
get_platform_dir
(
config_content
)
if
not
dir_list
:
print_normal
(
'No folder of NNI caches is found!'
)
exit
(
1
)
while
True
:
print_normal
(
'This command will remove below folders of NNI caches. If other users are using experiments on below hosts, it will be broken.'
)
for
dir
in
dir_list
:
print
(
' '
+
dir
)
inputs
=
input
(
'INFO: do you want to continue?[y/N]:'
)
if
not
inputs
.
lower
()
or
inputs
.
lower
()
in
[
'n'
,
'no'
]:
exit
(
0
)
elif
inputs
.
lower
()
not
in
[
'y'
,
'n'
,
'yes'
,
'no'
]:
print_warning
(
'please input Y or N!'
)
else
:
break
if
platform
==
'remote'
:
machine_list
=
config_content
.
get
(
'machineList'
)
for
machine
in
machine_list
:
remote_clean
(
machine_list
,
None
)
elif
platform
==
'pai'
:
pai_config
=
config_content
.
get
(
'paiConfig'
)
host
=
config_content
.
get
(
'paiConfig'
).
get
(
'host'
)
user_name
=
config_content
.
get
(
'paiConfig'
).
get
(
'userName'
)
output_dir
=
config_content
.
get
(
'trial'
).
get
(
'outputDir'
)
hdfs_clean
(
host
,
user_name
,
output_dir
,
None
)
print_normal
(
'Done!'
)
def
experiment_list
(
args
):
'''get the information of all experiments'''
experiment_config
=
Experiments
()
...
...
@@ -393,7 +559,6 @@ def experiment_list(args):
print_warning
(
'There is no experiment running...
\n
You can use
\'
nnictl experiment list all
\'
to list all stopped experiments!'
)
experiment_information
=
""
for
key
in
experiment_id_list
:
experiment_information
+=
(
EXPERIMENT_DETAIL_FORMAT
%
(
key
,
experiment_dict
[
key
][
'status'
],
experiment_dict
[
key
][
'port'
],
\
experiment_dict
[
key
].
get
(
'platform'
),
experiment_dict
[
key
][
'startTime'
],
experiment_dict
[
key
][
'endTime'
]))
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
...
...
tools/nni_cmd/ssh_utils.py
View file @
d6febf29
...
...
@@ -57,3 +57,17 @@ def create_ssh_sftp_client(host_ip, port, username, password):
return
sftp
except
Exception
as
exception
:
print_error
(
'Create ssh client error %s
\n
'
%
exception
)
def
remove_remote_directory
(
sftp
,
directory
):
'''remove a directory in remote machine'''
try
:
files
=
sftp
.
listdir
(
directory
)
for
file
in
files
:
filepath
=
'/'
.
join
([
directory
,
file
])
try
:
sftp
.
remove
(
filepath
)
except
IOError
:
remove_remote_directory
(
sftp
,
filepath
)
sftp
.
rmdir
(
directory
)
except
IOError
as
err
:
print_error
(
err
)
\ No newline at end of file
tools/nni_trial_tool/constants.py
View file @
d6febf29
...
...
@@ -36,6 +36,8 @@ STDERR_FULL_PATH = os.path.join(LOG_DIR, 'stderr')
STDOUT_API
=
'/stdout'
VERSION_API
=
'/version'
PARAMETER_META_API
=
'/parameter-file-meta'
NNI_SYS_DIR
=
os
.
environ
[
'NNI_SYS_DIR'
]
NNI_TRIAL_JOB_ID
=
os
.
environ
[
'NNI_TRIAL_JOB_ID'
]
NNI_EXP_ID
=
os
.
environ
[
'NNI_EXP_ID'
]
\ No newline at end of file
NNI_EXP_ID
=
os
.
environ
[
'NNI_EXP_ID'
]
MULTI_PHASE
=
os
.
environ
[
'MULTI_PHASE'
]
tools/nni_trial_tool/trial_keeper.py
View file @
d6febf29
...
...
@@ -28,18 +28,49 @@ import re
import
sys
import
select
import
json
import
threading
from
pyhdfs
import
HdfsClient
import
pkg_resources
from
.rest_utils
import
rest_post
from
.url_utils
import
gen_send_stdout_url
,
gen_send_version_url
from
.rest_utils
import
rest_post
,
rest_get
from
.url_utils
import
gen_send_stdout_url
,
gen_send_version_url
,
gen_parameter_meta_url
from
.constants
import
HOME_DIR
,
LOG_DIR
,
NNI_PLATFORM
,
STDOUT_FULL_PATH
,
STDERR_FULL_PATH
from
.hdfsClientUtility
import
copyDirectoryToHdfs
,
copyHdfsDirectoryToLocal
from
.constants
import
HOME_DIR
,
LOG_DIR
,
NNI_PLATFORM
,
STDOUT_FULL_PATH
,
STDERR_FULL_PATH
,
\
MULTI_PHASE
,
NNI_TRIAL_JOB_ID
,
NNI_SYS_DIR
,
NNI_EXP_ID
from
.hdfsClientUtility
import
copyDirectoryToHdfs
,
copyHdfsDirectoryToLocal
,
copyHdfsFileToLocal
from
.log_utils
import
LogType
,
nni_log
,
RemoteLogger
,
PipeLogReader
,
StdOutputType
logger
=
logging
.
getLogger
(
'trial_keeper'
)
regular
=
re
.
compile
(
'v?(?P<version>[0-9](\.[0-9]){0,1}).*'
)
_hdfs_client
=
None
def
get_hdfs_client
(
args
):
global
_hdfs_client
if
_hdfs_client
is
not
None
:
return
_hdfs_client
# backward compatibility
hdfs_host
=
None
if
args
.
hdfs_host
:
hdfs_host
=
args
.
hdfs_host
elif
args
.
pai_hdfs_host
:
hdfs_host
=
args
.
pai_hdfs_host
else
:
return
None
if
hdfs_host
is
not
None
and
args
.
nni_hdfs_exp_dir
is
not
None
:
try
:
if
args
.
webhdfs_path
:
_hdfs_client
=
HdfsClient
(
hosts
=
'{0}:80'
.
format
(
hdfs_host
),
user_name
=
args
.
pai_user_name
,
webhdfs_path
=
args
.
webhdfs_path
,
timeout
=
5
)
else
:
# backward compatibility
_hdfs_client
=
HdfsClient
(
hosts
=
'{0}:{1}'
.
format
(
hdfs_host
,
'50070'
),
user_name
=
args
.
pai_user_name
,
timeout
=
5
)
except
Exception
as
e
:
nni_log
(
LogType
.
Error
,
'Create HDFS client error: '
+
str
(
e
))
raise
e
return
_hdfs_client
def
main_loop
(
args
):
'''main loop logic for trial keeper'''
...
...
@@ -52,28 +83,16 @@ def main_loop(args):
# redirect trial keeper's stdout and stderr to syslog
trial_syslogger_stdout
=
RemoteLogger
(
args
.
nnimanager_ip
,
args
.
nnimanager_port
,
'trial'
,
StdOutputType
.
Stdout
,
args
.
log_collection
)
sys
.
stdout
=
sys
.
stderr
=
trial_keeper_syslogger
# backward compatibility
hdfs_host
=
None
hdfs_output_dir
=
None
if
args
.
hdfs_host
:
hdfs_host
=
args
.
hdfs_host
elif
args
.
pai_hdfs_host
:
hdfs_host
=
args
.
pai_hdfs_host
if
args
.
hdfs_output_dir
:
hdfs_output_dir
=
args
.
hdfs_output_dir
elif
args
.
pai_hdfs_output_dir
:
hdfs_output_dir
=
args
.
pai_hdfs_output_dir
if
hdfs_host
is
not
None
and
args
.
nni_hdfs_exp_dir
is
not
None
:
try
:
if
args
.
webhdfs_path
:
hdfs_client
=
HdfsClient
(
hosts
=
'{0}:80'
.
format
(
hdfs_host
),
user_name
=
args
.
pai_user_name
,
webhdfs_path
=
args
.
webhdfs_path
,
timeout
=
5
)
else
:
# backward compatibility
hdfs_client
=
HdfsClient
(
hosts
=
'{0}:{1}'
.
format
(
hdfs_host
,
'50070'
),
user_name
=
args
.
pai_user_name
,
timeout
=
5
)
except
Exception
as
e
:
nni_log
(
LogType
.
Error
,
'Create HDFS client error: '
+
str
(
e
))
raise
e
hdfs_client
=
get_hdfs_client
(
args
)
if
hdfs_client
is
not
None
:
copyHdfsDirectoryToLocal
(
args
.
nni_hdfs_exp_dir
,
os
.
getcwd
(),
hdfs_client
)
# Notice: We don't appoint env, which means subprocess wil inherit current environment and that is expected behavior
...
...
@@ -138,6 +157,52 @@ def check_version(args):
except
AttributeError
as
err
:
nni_log
(
LogType
.
Error
,
err
)
def
is_multi_phase
():
return
MULTI_PHASE
and
(
MULTI_PHASE
in
[
'True'
,
'true'
])
def
download_parameter
(
meta_list
,
args
):
"""
Download parameter file to local working directory.
meta_list format is defined in paiJobRestServer.ts
example meta_list:
[
{"experimentId":"yWFJarYa","trialId":"UpPkl","filePath":"/chec/nni/experiments/yWFJarYa/trials/UpPkl/parameter_1.cfg"},
{"experimentId":"yWFJarYa","trialId":"aIUMA","filePath":"/chec/nni/experiments/yWFJarYa/trials/aIUMA/parameter_1.cfg"}
]
"""
nni_log
(
LogType
.
Debug
,
str
(
meta_list
))
nni_log
(
LogType
.
Debug
,
'NNI_SYS_DIR: {}, trial Id: {}, experiment ID: {}'
.
format
(
NNI_SYS_DIR
,
NNI_TRIAL_JOB_ID
,
NNI_EXP_ID
))
nni_log
(
LogType
.
Debug
,
'NNI_SYS_DIR files: {}'
.
format
(
os
.
listdir
(
NNI_SYS_DIR
)))
for
meta
in
meta_list
:
if
meta
[
'experimentId'
]
==
NNI_EXP_ID
and
meta
[
'trialId'
]
==
NNI_TRIAL_JOB_ID
:
param_fp
=
os
.
path
.
join
(
NNI_SYS_DIR
,
os
.
path
.
basename
(
meta
[
'filePath'
]))
if
not
os
.
path
.
exists
(
param_fp
):
hdfs_client
=
get_hdfs_client
(
args
)
copyHdfsFileToLocal
(
meta
[
'filePath'
],
param_fp
,
hdfs_client
,
override
=
False
)
def
fetch_parameter_file
(
args
):
class
FetchThread
(
threading
.
Thread
):
def
__init__
(
self
,
args
):
super
(
FetchThread
,
self
).
__init__
()
self
.
args
=
args
def
run
(
self
):
uri
=
gen_parameter_meta_url
(
self
.
args
.
nnimanager_ip
,
self
.
args
.
nnimanager_port
)
nni_log
(
LogType
.
Info
,
uri
)
while
True
:
res
=
rest_get
(
uri
,
10
)
nni_log
(
LogType
.
Debug
,
'status code: {}'
.
format
(
res
.
status_code
))
if
res
.
status_code
==
200
:
meta_list
=
res
.
json
()
download_parameter
(
meta_list
,
self
.
args
)
else
:
nni_log
(
LogType
.
Warning
,
'rest response: {}'
.
format
(
str
(
res
)))
time
.
sleep
(
5
)
fetch_file_thread
=
FetchThread
(
args
)
fetch_file_thread
.
start
()
if
__name__
==
'__main__'
:
'''NNI Trial Keeper main function'''
PARSER
=
argparse
.
ArgumentParser
()
...
...
@@ -159,6 +224,8 @@ if __name__ == '__main__':
exit
(
1
)
check_version
(
args
)
try
:
if
NNI_PLATFORM
==
'pai'
and
is_multi_phase
():
fetch_parameter_file
(
args
)
main_loop
(
args
)
except
SystemExit
as
se
:
nni_log
(
LogType
.
Info
,
'NNI trial keeper exit with code {}'
.
format
(
se
.
code
))
...
...
tools/nni_trial_tool/url_utils.py
View file @
d6febf29
...
...
@@ -18,7 +18,7 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from
.constants
import
API_ROOT_URL
,
BASE_URL
,
STDOUT_API
,
NNI_TRIAL_JOB_ID
,
NNI_EXP_ID
,
VERSION_API
from
.constants
import
API_ROOT_URL
,
BASE_URL
,
STDOUT_API
,
NNI_TRIAL_JOB_ID
,
NNI_EXP_ID
,
VERSION_API
,
PARAMETER_META_API
def
gen_send_stdout_url
(
ip
,
port
):
'''Generate send stdout url'''
...
...
@@ -26,4 +26,8 @@ def gen_send_stdout_url(ip, port):
def
gen_send_version_url
(
ip
,
port
):
'''Generate send error url'''
return
'{0}:{1}{2}{3}/{4}/{5}'
.
format
(
BASE_URL
.
format
(
ip
),
port
,
API_ROOT_URL
,
VERSION_API
,
NNI_EXP_ID
,
NNI_TRIAL_JOB_ID
)
\ No newline at end of file
return
'{0}:{1}{2}{3}/{4}/{5}'
.
format
(
BASE_URL
.
format
(
ip
),
port
,
API_ROOT_URL
,
VERSION_API
,
NNI_EXP_ID
,
NNI_TRIAL_JOB_ID
)
def
gen_parameter_meta_url
(
ip
,
port
):
'''Generate send error url'''
return
'{0}:{1}{2}{3}'
.
format
(
BASE_URL
.
format
(
ip
),
port
,
API_ROOT_URL
,
PARAMETER_META_API
)
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment