Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
e9f137f0
Unverified
Commit
e9f137f0
authored
Feb 09, 2020
by
QuanluZhang
Committed by
GitHub
Feb 09, 2020
Browse files
merge from master (#2019)
parent
f7cf3ea5
Changes
107
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
106 additions
and
37 deletions
+106
-37
tools/nni_cmd/config_schema.py
tools/nni_cmd/config_schema.py
+19
-16
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+27
-13
tools/nni_cmd/launcher_utils.py
tools/nni_cmd/launcher_utils.py
+45
-1
tools/nni_cmd/nnictl.py
tools/nni_cmd/nnictl.py
+2
-2
tools/nni_cmd/nnictl_utils.py
tools/nni_cmd/nnictl_utils.py
+3
-1
tools/nni_cmd/ssh_utils.py
tools/nni_cmd/ssh_utils.py
+6
-2
tools/nni_cmd/tensorboard_utils.py
tools/nni_cmd/tensorboard_utils.py
+4
-2
No files found.
tools/nni_cmd/config_schema.py
View file @
e9f137f0
...
@@ -271,16 +271,17 @@ pai_yarn_config_schema = {
...
@@ -271,16 +271,17 @@ pai_yarn_config_schema = {
pai_trial_schema
=
{
pai_trial_schema
=
{
'trial'
:{
'trial'
:{
'command'
:
setType
(
'command'
,
str
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'gpuNum'
:
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
'cpuNum'
:
setNumberRange
(
'cpuNum'
,
int
,
0
,
99999
),
'memoryMB'
:
setType
(
'memoryMB'
,
int
),
'image'
:
setType
(
'image'
,
str
),
Optional
(
'virtualCluster'
):
setType
(
'virtualCluster'
,
str
),
'nniManagerNFSMountPath'
:
setPathCheck
(
'nniManagerNFSMountPath'
),
'nniManagerNFSMountPath'
:
setPathCheck
(
'nniManagerNFSMountPath'
),
'containerNFSMountPath'
:
setType
(
'containerNFSMountPath'
,
str
),
'containerNFSMountPath'
:
setType
(
'containerNFSMountPath'
,
str
),
'paiStoragePlugin'
:
setType
(
'paiStoragePlugin'
,
str
)
'command'
:
setType
(
'command'
,
str
),
Optional
(
'gpuNum'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'cpuNum'
):
setNumberRange
(
'cpuNum'
,
int
,
0
,
99999
),
Optional
(
'memoryMB'
):
setType
(
'memoryMB'
,
int
),
Optional
(
'image'
):
setType
(
'image'
,
str
),
Optional
(
'virtualCluster'
):
setType
(
'virtualCluster'
,
str
),
Optional
(
'paiStoragePlugin'
):
setType
(
'paiStoragePlugin'
,
str
),
Optional
(
'paiConfigPath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'paiConfigPath'
)
}
}
}
}
...
@@ -407,15 +408,8 @@ frameworkcontroller_config_schema = {
...
@@ -407,15 +408,8 @@ frameworkcontroller_config_schema = {
}
}
machine_list_schema
=
{
machine_list_schema
=
{
Optional
(
'machineList'
):[
Or
({
Optional
(
'machineList'
):[
Or
(
'ip'
:
setType
(
'ip'
,
str
),
{
Optional
(
'port'
):
setNumberRange
(
'port'
,
int
,
1
,
65535
),
'username'
:
setType
(
'username'
,
str
),
'passwd'
:
setType
(
'passwd'
,
str
),
Optional
(
'gpuIndices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
),
error
=
'gpuIndex format error!'
),
Optional
(
'maxTrialNumPerGpu'
):
setType
(
'maxTrialNumPerGpu'
,
int
),
Optional
(
'useActiveGpu'
):
setType
(
'useActiveGpu'
,
bool
)
},
{
'ip'
:
setType
(
'ip'
,
str
),
'ip'
:
setType
(
'ip'
,
str
),
Optional
(
'port'
):
setNumberRange
(
'port'
,
int
,
1
,
65535
),
Optional
(
'port'
):
setNumberRange
(
'port'
,
int
,
1
,
65535
),
'username'
:
setType
(
'username'
,
str
),
'username'
:
setType
(
'username'
,
str
),
...
@@ -424,6 +418,15 @@ machine_list_schema = {
...
@@ -424,6 +418,15 @@ machine_list_schema = {
Optional
(
'gpuIndices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
),
error
=
'gpuIndex format error!'
),
Optional
(
'gpuIndices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
),
error
=
'gpuIndex format error!'
),
Optional
(
'maxTrialNumPerGpu'
):
setType
(
'maxTrialNumPerGpu'
,
int
),
Optional
(
'maxTrialNumPerGpu'
):
setType
(
'maxTrialNumPerGpu'
,
int
),
Optional
(
'useActiveGpu'
):
setType
(
'useActiveGpu'
,
bool
)
Optional
(
'useActiveGpu'
):
setType
(
'useActiveGpu'
,
bool
)
},
{
'ip'
:
setType
(
'ip'
,
str
),
Optional
(
'port'
):
setNumberRange
(
'port'
,
int
,
1
,
65535
),
'username'
:
setType
(
'username'
,
str
),
'passwd'
:
setType
(
'passwd'
,
str
),
Optional
(
'gpuIndices'
):
Or
(
int
,
And
(
str
,
lambda
x
:
len
([
int
(
i
)
for
i
in
x
.
split
(
','
)])
>
0
),
error
=
'gpuIndex format error!'
),
Optional
(
'maxTrialNumPerGpu'
):
setType
(
'maxTrialNumPerGpu'
,
int
),
Optional
(
'useActiveGpu'
):
setType
(
'useActiveGpu'
,
bool
)
})]
})]
}
}
...
...
tools/nni_cmd/launcher.py
View file @
e9f137f0
...
@@ -9,7 +9,7 @@ import random
...
@@ -9,7 +9,7 @@ import random
import
site
import
site
import
time
import
time
import
tempfile
import
tempfile
from
subprocess
import
Popen
,
check_call
,
CalledProcessError
from
subprocess
import
Popen
,
check_call
,
CalledProcessError
,
PIPE
,
STDOUT
from
nni_annotation
import
expand_annotations
,
generate_search_space
from
nni_annotation
import
expand_annotations
,
generate_search_space
from
nni.constants
import
ModuleName
,
AdvisorModuleName
from
nni.constants
import
ModuleName
,
AdvisorModuleName
from
.launcher_utils
import
validate_all_content
from
.launcher_utils
import
validate_all_content
...
@@ -20,7 +20,7 @@ from .common_utils import get_yml_content, get_json_content, print_error, print_
...
@@ -20,7 +20,7 @@ from .common_utils import get_yml_content, get_json_content, print_error, print_
detect_port
,
get_user
,
get_python_dir
detect_port
,
get_user
,
get_python_dir
from
.constants
import
NNICTL_HOME_DIR
,
ERROR_INFO
,
REST_TIME_OUT
,
EXPERIMENT_SUCCESS_INFO
,
LOG_HEADER
,
PACKAGE_REQUIREMENTS
from
.constants
import
NNICTL_HOME_DIR
,
ERROR_INFO
,
REST_TIME_OUT
,
EXPERIMENT_SUCCESS_INFO
,
LOG_HEADER
,
PACKAGE_REQUIREMENTS
from
.command_utils
import
check_output_command
,
kill_command
from
.command_utils
import
check_output_command
,
kill_command
from
.nnictl_utils
import
update_experiment
,
set_monitor
from
.nnictl_utils
import
update_experiment
def
get_log_path
(
config_file_name
):
def
get_log_path
(
config_file_name
):
'''generate stdout and stderr log path'''
'''generate stdout and stderr log path'''
...
@@ -78,17 +78,17 @@ def get_nni_installation_path():
...
@@ -78,17 +78,17 @@ def get_nni_installation_path():
print_error
(
'Fail to find nni under python library'
)
print_error
(
'Fail to find nni under python library'
)
exit
(
1
)
exit
(
1
)
def
start_rest_server
(
port
,
platform
,
mode
,
config_file_name
,
experiment_id
=
None
,
log_dir
=
None
,
log_level
=
None
):
def
start_rest_server
(
args
,
platform
,
mode
,
config_file_name
,
experiment_id
=
None
,
log_dir
=
None
,
log_level
=
None
):
'''Run nni manager process'''
'''Run nni manager process'''
if
detect_port
(
port
):
if
detect_port
(
args
.
port
):
print_error
(
'Port %s is used by another process, please reset the port!
\n
'
\
print_error
(
'Port %s is used by another process, please reset the port!
\n
'
\
'You could use
\'
nnictl create --help
\'
to get help information'
%
port
)
'You could use
\'
nnictl create --help
\'
to get help information'
%
args
.
port
)
exit
(
1
)
exit
(
1
)
if
(
platform
!=
'local'
)
and
detect_port
(
int
(
port
)
+
1
):
if
(
platform
!=
'local'
)
and
detect_port
(
int
(
args
.
port
)
+
1
):
print_error
(
'PAI mode need an additional adjacent port %d, and the port %d is used by another process!
\n
'
\
print_error
(
'PAI mode need an additional adjacent port %d, and the port %d is used by another process!
\n
'
\
'You could set another port to start experiment!
\n
'
\
'You could set another port to start experiment!
\n
'
\
'You could use
\'
nnictl create --help
\'
to get help information'
%
((
int
(
port
)
+
1
),
(
int
(
port
)
+
1
)))
'You could use
\'
nnictl create --help
\'
to get help information'
%
((
int
(
args
.
port
)
+
1
),
(
int
(
args
.
port
)
+
1
)))
exit
(
1
)
exit
(
1
)
print_normal
(
'Starting restful server...'
)
print_normal
(
'Starting restful server...'
)
...
@@ -99,7 +99,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
...
@@ -99,7 +99,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
node_command
=
'node'
node_command
=
'node'
if
sys
.
platform
==
'win32'
:
if
sys
.
platform
==
'win32'
:
node_command
=
os
.
path
.
join
(
entry_dir
[:
-
3
],
'Scripts'
,
'node.exe'
)
node_command
=
os
.
path
.
join
(
entry_dir
[:
-
3
],
'Scripts'
,
'node.exe'
)
cmds
=
[
node_command
,
entry_file
,
'--port'
,
str
(
port
),
'--mode'
,
platform
]
cmds
=
[
node_command
,
entry_file
,
'--port'
,
str
(
args
.
port
),
'--mode'
,
platform
]
if
mode
==
'view'
:
if
mode
==
'view'
:
cmds
+=
[
'--start_mode'
,
'resume'
]
cmds
+=
[
'--start_mode'
,
'resume'
]
cmds
+=
[
'--readonly'
,
'true'
]
cmds
+=
[
'--readonly'
,
'true'
]
...
@@ -111,6 +111,8 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
...
@@ -111,6 +111,8 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
cmds
+=
[
'--log_level'
,
log_level
]
cmds
+=
[
'--log_level'
,
log_level
]
if
mode
in
[
'resume'
,
'view'
]:
if
mode
in
[
'resume'
,
'view'
]:
cmds
+=
[
'--experiment_id'
,
experiment_id
]
cmds
+=
[
'--experiment_id'
,
experiment_id
]
if
args
.
foreground
:
cmds
+=
[
'--foreground'
,
'true'
]
stdout_full_path
,
stderr_full_path
=
get_log_path
(
config_file_name
)
stdout_full_path
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stdout_full_path
,
'a+'
)
as
stdout_file
,
open
(
stderr_full_path
,
'a+'
)
as
stderr_file
:
with
open
(
stdout_full_path
,
'a+'
)
as
stdout_file
,
open
(
stderr_full_path
,
'a+'
)
as
stderr_file
:
time_now
=
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
time_now
=
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
...
@@ -120,9 +122,15 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
...
@@ -120,9 +122,15 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
stderr_file
.
write
(
log_header
)
stderr_file
.
write
(
log_header
)
if
sys
.
platform
==
'win32'
:
if
sys
.
platform
==
'win32'
:
from
subprocess
import
CREATE_NEW_PROCESS_GROUP
from
subprocess
import
CREATE_NEW_PROCESS_GROUP
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
stdout_file
,
stderr
=
stderr_file
,
creationflags
=
CREATE_NEW_PROCESS_GROUP
)
if
args
.
foreground
:
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
PIPE
,
stderr
=
STDOUT
,
creationflags
=
CREATE_NEW_PROCESS_GROUP
)
else
:
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
stdout_file
,
stderr
=
stderr_file
,
creationflags
=
CREATE_NEW_PROCESS_GROUP
)
else
:
else
:
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
stdout_file
,
stderr
=
stderr_file
)
if
args
.
foreground
:
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
PIPE
,
stderr
=
PIPE
)
else
:
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
stdout_file
,
stderr
=
stderr_file
)
return
process
,
str
(
time_now
)
return
process
,
str
(
time_now
)
def
set_trial_config
(
experiment_config
,
port
,
config_file_name
):
def
set_trial_config
(
experiment_config
,
port
,
config_file_name
):
...
@@ -424,7 +432,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
...
@@ -424,7 +432,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if
log_level
not
in
[
'trace'
,
'debug'
]
and
(
args
.
debug
or
experiment_config
.
get
(
'debug'
)
is
True
):
if
log_level
not
in
[
'trace'
,
'debug'
]
and
(
args
.
debug
or
experiment_config
.
get
(
'debug'
)
is
True
):
log_level
=
'debug'
log_level
=
'debug'
# start rest server
# start rest server
rest_process
,
start_time
=
start_rest_server
(
args
.
port
,
experiment_config
[
'trainingServicePlatform'
],
\
rest_process
,
start_time
=
start_rest_server
(
args
,
experiment_config
[
'trainingServicePlatform'
],
\
mode
,
config_file_name
,
experiment_id
,
log_dir
,
log_level
)
mode
,
config_file_name
,
experiment_id
,
log_dir
,
log_level
)
nni_config
.
set_config
(
'restServerPid'
,
rest_process
.
pid
)
nni_config
.
set_config
(
'restServerPid'
,
rest_process
.
pid
)
# Deal with annotation
# Deal with annotation
...
@@ -493,8 +501,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
...
@@ -493,8 +501,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
experiment_config
[
'experimentName'
])
experiment_config
[
'experimentName'
])
print_normal
(
EXPERIMENT_SUCCESS_INFO
%
(
experiment_id
,
' '
.
join
(
web_ui_url_list
)))
print_normal
(
EXPERIMENT_SUCCESS_INFO
%
(
experiment_id
,
' '
.
join
(
web_ui_url_list
)))
if
args
.
watch
:
if
args
.
foreground
:
set_monitor
(
True
,
3
,
args
.
port
,
rest_process
.
pid
)
try
:
while
True
:
log_content
=
rest_process
.
stdout
.
readline
().
strip
().
decode
(
'utf-8'
)
print
(
log_content
)
except
KeyboardInterrupt
:
kill_command
(
rest_process
.
pid
)
print_normal
(
'Stopping experiment...'
)
def
create_experiment
(
args
):
def
create_experiment
(
args
):
'''start a new experiment'''
'''start a new experiment'''
...
...
tools/nni_cmd/launcher_utils.py
View file @
e9f137f0
...
@@ -7,7 +7,7 @@ from schema import SchemaError
...
@@ -7,7 +7,7 @@ from schema import SchemaError
from
schema
import
Schema
from
schema
import
Schema
from
.config_schema
import
LOCAL_CONFIG_SCHEMA
,
REMOTE_CONFIG_SCHEMA
,
PAI_CONFIG_SCHEMA
,
PAI_YARN_CONFIG_SCHEMA
,
KUBEFLOW_CONFIG_SCHEMA
,
\
from
.config_schema
import
LOCAL_CONFIG_SCHEMA
,
REMOTE_CONFIG_SCHEMA
,
PAI_CONFIG_SCHEMA
,
PAI_YARN_CONFIG_SCHEMA
,
KUBEFLOW_CONFIG_SCHEMA
,
\
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
,
tuner_schema_dict
,
advisor_schema_dict
,
assessor_schema_dict
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
,
tuner_schema_dict
,
advisor_schema_dict
,
assessor_schema_dict
from
.common_utils
import
print_error
,
print_warning
,
print_normal
from
.common_utils
import
print_error
,
print_warning
,
print_normal
,
get_yml_content
def
expand_path
(
experiment_config
,
key
):
def
expand_path
(
experiment_config
,
key
):
'''Change '~' to user home directory'''
'''Change '~' to user home directory'''
...
@@ -63,6 +63,8 @@ def parse_path(experiment_config, config_path):
...
@@ -63,6 +63,8 @@ def parse_path(experiment_config, config_path):
if
experiment_config
.
get
(
'machineList'
):
if
experiment_config
.
get
(
'machineList'
):
for
index
in
range
(
len
(
experiment_config
[
'machineList'
])):
for
index
in
range
(
len
(
experiment_config
[
'machineList'
])):
expand_path
(
experiment_config
[
'machineList'
][
index
],
'sshKeyPath'
)
expand_path
(
experiment_config
[
'machineList'
][
index
],
'sshKeyPath'
)
if
experiment_config
[
'trial'
].
get
(
'paiConfigPath'
):
expand_path
(
experiment_config
[
'trial'
],
'paiConfigPath'
)
#if users use relative path, convert it to absolute path
#if users use relative path, convert it to absolute path
root_path
=
os
.
path
.
dirname
(
config_path
)
root_path
=
os
.
path
.
dirname
(
config_path
)
...
@@ -94,6 +96,8 @@ def parse_path(experiment_config, config_path):
...
@@ -94,6 +96,8 @@ def parse_path(experiment_config, config_path):
if
experiment_config
.
get
(
'machineList'
):
if
experiment_config
.
get
(
'machineList'
):
for
index
in
range
(
len
(
experiment_config
[
'machineList'
])):
for
index
in
range
(
len
(
experiment_config
[
'machineList'
])):
parse_relative_path
(
root_path
,
experiment_config
[
'machineList'
][
index
],
'sshKeyPath'
)
parse_relative_path
(
root_path
,
experiment_config
[
'machineList'
][
index
],
'sshKeyPath'
)
if
experiment_config
[
'trial'
].
get
(
'paiConfigPath'
):
parse_relative_path
(
root_path
,
experiment_config
[
'trial'
],
'paiConfigPath'
)
def
validate_search_space_content
(
experiment_config
):
def
validate_search_space_content
(
experiment_config
):
'''Validate searchspace content,
'''Validate searchspace content,
...
@@ -254,6 +258,45 @@ def validate_machine_list(experiment_config):
...
@@ -254,6 +258,45 @@ def validate_machine_list(experiment_config):
print_error
(
'Please set machineList!'
)
print_error
(
'Please set machineList!'
)
exit
(
1
)
exit
(
1
)
def
validate_pai_config_path
(
experiment_config
):
'''validate paiConfigPath field'''
if
experiment_config
.
get
(
'trainingServicePlatform'
)
==
'pai'
:
if
experiment_config
.
get
(
'trial'
,
{}).
get
(
'paiConfigPath'
):
# validate the file format of paiConfigPath, ensure it is yaml format
pai_config
=
get_yml_content
(
experiment_config
[
'trial'
][
'paiConfigPath'
])
if
experiment_config
[
'trial'
].
get
(
'image'
)
is
None
:
if
pai_config
.
get
(
'prerequisites'
,
[{}])[
0
].
get
(
'uri'
)
is
None
:
print_error
(
'Please set image field, or set image uri in your own paiConfig!'
)
exit
(
1
)
experiment_config
[
'trial'
][
'image'
]
=
pai_config
[
'prerequisites'
][
0
][
'uri'
]
if
experiment_config
[
'trial'
].
get
(
'gpuNum'
)
is
None
:
if
pai_config
.
get
(
'taskRoles'
,
{}).
get
(
'taskrole'
,
{}).
get
(
'resourcePerInstance'
,
{}).
get
(
'gpu'
)
is
None
:
print_error
(
'Please set gpuNum field, or set resourcePerInstance gpu in your own paiConfig!'
)
exit
(
1
)
experiment_config
[
'trial'
][
'gpuNum'
]
=
pai_config
[
'taskRoles'
][
'taskrole'
][
'resourcePerInstance'
][
'gpu'
]
if
experiment_config
[
'trial'
].
get
(
'cpuNum'
)
is
None
:
if
pai_config
.
get
(
'taskRoles'
,
{}).
get
(
'taskrole'
,
{}).
get
(
'resourcePerInstance'
,
{}).
get
(
'cpu'
)
is
None
:
print_error
(
'Please set cpuNum field, or set resourcePerInstance cpu in your own paiConfig!'
)
exit
(
1
)
experiment_config
[
'trial'
][
'cpuNum'
]
=
pai_config
[
'taskRoles'
][
'taskrole'
][
'resourcePerInstance'
][
'cpu'
]
if
experiment_config
[
'trial'
].
get
(
'memoryMB'
)
is
None
:
if
pai_config
.
get
(
'taskRoles'
,
{}).
get
(
'taskrole'
,
{}).
get
(
'resourcePerInstance'
,
{}).
get
(
'memoryMB'
,
{})
is
None
:
print_error
(
'Please set memoryMB field, or set resourcePerInstance memoryMB in your own paiConfig!'
)
exit
(
1
)
experiment_config
[
'trial'
][
'memoryMB'
]
=
pai_config
[
'taskRoles'
][
'taskrole'
][
'resourcePerInstance'
][
'memoryMB'
]
if
experiment_config
[
'trial'
].
get
(
'paiStoragePlugin'
)
is
None
:
if
pai_config
.
get
(
'extras'
,
{}).
get
(
'com.microsoft.pai.runtimeplugin'
,
[{}])[
0
].
get
(
'plugin'
)
is
None
:
print_error
(
'Please set paiStoragePlugin field, or set plugin in your own paiConfig!'
)
exit
(
1
)
experiment_config
[
'trial'
][
'paiStoragePlugin'
]
=
pai_config
[
'extras'
][
'com.microsoft.pai.runtimeplugin'
][
0
][
'plugin'
]
else
:
pai_trial_fields_required_list
=
[
'image'
,
'gpuNum'
,
'cpuNum'
,
'memoryMB'
,
'paiStoragePlugin'
]
for
trial_field
in
pai_trial_fields_required_list
:
if
experiment_config
[
'trial'
].
get
(
trial_field
)
is
None
:
print_error
(
'Please set {0} in trial configuration,
\
or set additional pai configuration file path in paiConfigPath!'
.
format
(
trial_field
))
exit
(
1
)
def
validate_pai_trial_conifg
(
experiment_config
):
def
validate_pai_trial_conifg
(
experiment_config
):
'''validate the trial config in pai platform'''
'''validate the trial config in pai platform'''
if
experiment_config
.
get
(
'trainingServicePlatform'
)
in
[
'pai'
,
'paiYarn'
]:
if
experiment_config
.
get
(
'trainingServicePlatform'
)
in
[
'pai'
,
'paiYarn'
]:
...
@@ -269,6 +312,7 @@ def validate_pai_trial_conifg(experiment_config):
...
@@ -269,6 +312,7 @@ def validate_pai_trial_conifg(experiment_config):
print_warning
(
warning_information
.
format
(
'dataDir'
))
print_warning
(
warning_information
.
format
(
'dataDir'
))
if
experiment_config
.
get
(
'trial'
).
get
(
'outputDir'
):
if
experiment_config
.
get
(
'trial'
).
get
(
'outputDir'
):
print_warning
(
warning_information
.
format
(
'outputDir'
))
print_warning
(
warning_information
.
format
(
'outputDir'
))
validate_pai_config_path
(
experiment_config
)
def
validate_all_content
(
experiment_config
,
config_path
):
def
validate_all_content
(
experiment_config
,
config_path
):
'''Validate whether experiment_config is valid'''
'''Validate whether experiment_config is valid'''
...
...
tools/nni_cmd/nnictl.py
View file @
e9f137f0
...
@@ -51,7 +51,7 @@ def parse_args():
...
@@ -51,7 +51,7 @@ def parse_args():
parser_start
.
add_argument
(
'--config'
,
'-c'
,
required
=
True
,
dest
=
'config'
,
help
=
'the path of yaml config file'
)
parser_start
.
add_argument
(
'--config'
,
'-c'
,
required
=
True
,
dest
=
'config'
,
help
=
'the path of yaml config file'
)
parser_start
.
add_argument
(
'--port'
,
'-p'
,
default
=
DEFAULT_REST_PORT
,
dest
=
'port'
,
help
=
'the port of restful server'
)
parser_start
.
add_argument
(
'--port'
,
'-p'
,
default
=
DEFAULT_REST_PORT
,
dest
=
'port'
,
help
=
'the port of restful server'
)
parser_start
.
add_argument
(
'--debug'
,
'-d'
,
action
=
'store_true'
,
help
=
' set debug mode'
)
parser_start
.
add_argument
(
'--debug'
,
'-d'
,
action
=
'store_true'
,
help
=
' set debug mode'
)
parser_start
.
add_argument
(
'--
watch
'
,
'-
w
'
,
action
=
'store_true'
,
help
=
' set
watch mode
'
)
parser_start
.
add_argument
(
'--
foreground
'
,
'-
f
'
,
action
=
'store_true'
,
help
=
' set
foreground mode, print log content to terminal
'
)
parser_start
.
set_defaults
(
func
=
create_experiment
)
parser_start
.
set_defaults
(
func
=
create_experiment
)
# parse resume command
# parse resume command
...
@@ -59,7 +59,7 @@ def parse_args():
...
@@ -59,7 +59,7 @@ def parse_args():
parser_resume
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'The id of the experiment you want to resume'
)
parser_resume
.
add_argument
(
'id'
,
nargs
=
'?'
,
help
=
'The id of the experiment you want to resume'
)
parser_resume
.
add_argument
(
'--port'
,
'-p'
,
default
=
DEFAULT_REST_PORT
,
dest
=
'port'
,
help
=
'the port of restful server'
)
parser_resume
.
add_argument
(
'--port'
,
'-p'
,
default
=
DEFAULT_REST_PORT
,
dest
=
'port'
,
help
=
'the port of restful server'
)
parser_resume
.
add_argument
(
'--debug'
,
'-d'
,
action
=
'store_true'
,
help
=
' set debug mode'
)
parser_resume
.
add_argument
(
'--debug'
,
'-d'
,
action
=
'store_true'
,
help
=
' set debug mode'
)
parser_resume
.
add_argument
(
'--
watch
'
,
'-
w
'
,
action
=
'store_true'
,
help
=
' set
watch mode
'
)
parser_resume
.
add_argument
(
'--
foreground
'
,
'-
f
'
,
action
=
'store_true'
,
help
=
' set
foreground mode, print log content to terminal
'
)
parser_resume
.
set_defaults
(
func
=
resume_experiment
)
parser_resume
.
set_defaults
(
func
=
resume_experiment
)
# parse view command
# parse view command
...
...
tools/nni_cmd/nnictl_utils.py
View file @
e9f137f0
...
@@ -403,11 +403,13 @@ def remote_clean(machine_list, experiment_id=None):
...
@@ -403,11 +403,13 @@ def remote_clean(machine_list, experiment_id=None):
userName
=
machine
.
get
(
'username'
)
userName
=
machine
.
get
(
'username'
)
host
=
machine
.
get
(
'ip'
)
host
=
machine
.
get
(
'ip'
)
port
=
machine
.
get
(
'port'
)
port
=
machine
.
get
(
'port'
)
sshKeyPath
=
machine
.
get
(
'sshKeyPath'
)
passphrase
=
machine
.
get
(
'passphrase'
)
if
experiment_id
:
if
experiment_id
:
remote_dir
=
'/'
+
'/'
.
join
([
'tmp'
,
'nni'
,
'experiments'
,
experiment_id
])
remote_dir
=
'/'
+
'/'
.
join
([
'tmp'
,
'nni'
,
'experiments'
,
experiment_id
])
else
:
else
:
remote_dir
=
'/'
+
'/'
.
join
([
'tmp'
,
'nni'
,
'experiments'
])
remote_dir
=
'/'
+
'/'
.
join
([
'tmp'
,
'nni'
,
'experiments'
])
sftp
=
create_ssh_sftp_client
(
host
,
port
,
userName
,
passwd
)
sftp
=
create_ssh_sftp_client
(
host
,
port
,
userName
,
passwd
,
sshKeyPath
,
passphrase
)
print_normal
(
'removing folder {0}'
.
format
(
host
+
':'
+
str
(
port
)
+
remote_dir
))
print_normal
(
'removing folder {0}'
.
format
(
host
+
':'
+
str
(
port
)
+
remote_dir
))
remove_remote_directory
(
sftp
,
remote_dir
)
remove_remote_directory
(
sftp
,
remote_dir
)
...
...
tools/nni_cmd/ssh_utils.py
View file @
e9f137f0
...
@@ -30,12 +30,16 @@ def copy_remote_directory_to_local(sftp, remote_path, local_path):
...
@@ -30,12 +30,16 @@ def copy_remote_directory_to_local(sftp, remote_path, local_path):
except
Exception
:
except
Exception
:
pass
pass
def
create_ssh_sftp_client
(
host_ip
,
port
,
username
,
password
):
def
create_ssh_sftp_client
(
host_ip
,
port
,
username
,
password
,
ssh_key_path
,
passphrase
):
'''create ssh client'''
'''create ssh client'''
try
:
try
:
paramiko
=
check_environment
()
paramiko
=
check_environment
()
conn
=
paramiko
.
Transport
(
host_ip
,
port
)
conn
=
paramiko
.
Transport
(
host_ip
,
port
)
conn
.
connect
(
username
=
username
,
password
=
password
)
if
ssh_key_path
is
not
None
:
ssh_key
=
paramiko
.
RSAKey
.
from_private_key_file
(
ssh_key_path
,
password
=
passphrase
)
conn
.
connect
(
username
=
username
,
pkey
=
ssh_key
)
else
:
conn
.
connect
(
username
=
username
,
password
=
password
)
sftp
=
paramiko
.
SFTPClient
.
from_transport
(
conn
)
sftp
=
paramiko
.
SFTPClient
.
from_transport
(
conn
)
return
sftp
return
sftp
except
Exception
as
exception
:
except
Exception
as
exception
:
...
...
tools/nni_cmd/tensorboard_utils.py
View file @
e9f137f0
...
@@ -37,12 +37,14 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list,
...
@@ -37,12 +37,14 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list,
machine_dict
=
{}
machine_dict
=
{}
local_path_list
=
[]
local_path_list
=
[]
for
machine
in
machine_list
:
for
machine
in
machine_list
:
machine_dict
[
machine
[
'ip'
]]
=
{
'port'
:
machine
[
'port'
],
'passwd'
:
machine
[
'passwd'
],
'username'
:
machine
[
'username'
]}
machine_dict
[
machine
[
'ip'
]]
=
{
'port'
:
machine
[
'port'
],
'passwd'
:
machine
[
'passwd'
],
'username'
:
machine
[
'username'
],
'sshKeyPath'
:
machine
.
get
(
'sshKeyPath'
),
'passphrase'
:
machine
.
get
(
'passphrase'
)}
for
index
,
host
in
enumerate
(
host_list
):
for
index
,
host
in
enumerate
(
host_list
):
local_path
=
os
.
path
.
join
(
temp_nni_path
,
trial_content
[
index
].
get
(
'id'
))
local_path
=
os
.
path
.
join
(
temp_nni_path
,
trial_content
[
index
].
get
(
'id'
))
local_path_list
.
append
(
local_path
)
local_path_list
.
append
(
local_path
)
print_normal
(
'Copying log data from %s to %s'
%
(
host
+
':'
+
path_list
[
index
],
local_path
))
print_normal
(
'Copying log data from %s to %s'
%
(
host
+
':'
+
path_list
[
index
],
local_path
))
sftp
=
create_ssh_sftp_client
(
host
,
machine_dict
[
host
][
'port'
],
machine_dict
[
host
][
'username'
],
machine_dict
[
host
][
'passwd'
])
sftp
=
create_ssh_sftp_client
(
host
,
machine_dict
[
host
][
'port'
],
machine_dict
[
host
][
'username'
],
machine_dict
[
host
][
'passwd'
],
machine_dict
[
host
][
'sshKeyPath'
],
machine_dict
[
host
][
'passphrase'
])
copy_remote_directory_to_local
(
sftp
,
path_list
[
index
],
local_path
)
copy_remote_directory_to_local
(
sftp
,
path_list
[
index
],
local_path
)
print_normal
(
'Copy done!'
)
print_normal
(
'Copy done!'
)
return
local_path_list
return
local_path_list
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment