Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
9cbbf6f8
Unverified
Commit
9cbbf6f8
authored
Dec 23, 2019
by
SparkSnail
Committed by
GitHub
Dec 23, 2019
Browse files
Support pai and paiYarn trainingservice (#1853)
parent
9d01d083
Changes
26
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
68 additions
and
11 deletions
+68
-11
test/generate_ts_config.py
test/generate_ts_config.py
+3
-3
test/training_service.yml
test/training_service.yml
+2
-2
tools/nni_cmd/config_schema.py
tools/nni_cmd/config_schema.py
+32
-2
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+26
-0
tools/nni_cmd/launcher_utils.py
tools/nni_cmd/launcher_utils.py
+4
-3
tools/nni_trial_tool/trial_keeper.py
tools/nni_trial_tool/trial_keeper.py
+1
-1
No files found.
test/generate_ts_config.py
View file @
9cbbf6f8
...
...
@@ -14,11 +14,11 @@ def update_training_service_config(args):
config
[
args
.
ts
][
'nniManagerIp'
]
=
args
.
nni_manager_ip
if
args
.
ts
==
'pai'
:
if
args
.
pai_user
is
not
None
:
config
[
args
.
ts
][
'paiConfig'
][
'userName'
]
=
args
.
pai_user
config
[
args
.
ts
][
'pai
Yarn
Config'
][
'userName'
]
=
args
.
pai_user
if
args
.
pai_pwd
is
not
None
:
config
[
args
.
ts
][
'paiConfig'
][
'passWord'
]
=
args
.
pai_pwd
config
[
args
.
ts
][
'pai
Yarn
Config'
][
'passWord'
]
=
args
.
pai_pwd
if
args
.
pai_host
is
not
None
:
config
[
args
.
ts
][
'paiConfig'
][
'host'
]
=
args
.
pai_host
config
[
args
.
ts
][
'pai
Yarn
Config'
][
'host'
]
=
args
.
pai_host
if
args
.
nni_docker_image
is
not
None
:
config
[
args
.
ts
][
'trial'
][
'image'
]
=
args
.
nni_docker_image
if
args
.
data_dir
is
not
None
:
...
...
test/training_service.yml
View file @
9cbbf6f8
...
...
@@ -29,11 +29,11 @@ local:
pai
:
nniManagerIp
:
maxExecDuration
:
15m
paiConfig
:
pai
Yarn
Config
:
host
:
passWord
:
userName
:
trainingServicePlatform
:
pai
trainingServicePlatform
:
pai
Yarn
trial
:
gpuNum
:
1
cpuNum
:
1
...
...
tools/nni_cmd/config_schema.py
View file @
9cbbf6f8
...
...
@@ -32,7 +32,7 @@ common_schema = {
'trialConcurrency'
:
setNumberRange
(
'trialConcurrency'
,
int
,
1
,
99999
),
Optional
(
'maxExecDuration'
):
And
(
Regex
(
r
'^[1-9][0-9]*[s|m|h|d]$'
,
error
=
'ERROR: maxExecDuration format is [digit]{s,m,h,d}'
)),
Optional
(
'maxTrialNum'
):
setNumberRange
(
'maxTrialNum'
,
int
,
1
,
99999
),
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
),
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
),
Optional
(
'searchSpacePath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'searchSpacePath'
),
Optional
(
'multiPhase'
):
setType
(
'multiPhase'
,
bool
),
Optional
(
'multiThread'
):
setType
(
'multiThread'
,
bool
),
...
...
@@ -232,7 +232,7 @@ common_trial_schema = {
}
}
pai_trial_schema
=
{
pai_
yarn_
trial_schema
=
{
'trial'
:{
'command'
:
setType
(
'command'
,
str
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
...
...
@@ -256,6 +256,34 @@ pai_trial_schema = {
}
}
# Schema fragment for the 'paiYarnConfig' section of an experiment config.
# Two credential layouts are accepted: userName + passWord + host, or
# userName + token + host.  Every field is declared through setType(<name>, str).
pai_yarn_config_schema = {
    'paiYarnConfig': Or(
        {field: setType(field, str) for field in ('userName', 'passWord', 'host')},
        {field: setType(field, str) for field in ('userName', 'token', 'host')},
    )
}
# Schema fragment for the 'trial' section used by the 'pai' training service.
# Every key is required except 'virtualCluster', which is wrapped in Optional.
# NOTE(review): setType / setNumberRange / setPathCheck are defined elsewhere in
# this module; the per-field remarks below are inferred from their arguments.
pai_trial_schema = {
    'trial':{
        'command': setType('command', str),
        'codeDir': setPathCheck('codeDir'),
        # Numeric resources are passed bounds 0 and 99999 (presumably min/max).
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
        'memoryMB': setType('memoryMB', int),
        'image': setType('image', str),
        Optional('virtualCluster'): setType('virtualCluster', str),
        # The manager-side mount path goes through setPathCheck, while the
        # container-side path is only type-checked as str.
        'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
        'containerNFSMountPath': setType('containerNFSMountPath', str),
        'paiStoragePlugin': setType('paiStoragePlugin', str)
    }
}
pai_config_schema
=
{
'paiConfig'
:
Or
({
'userName'
:
setType
(
'userName'
,
str
),
...
...
@@ -405,6 +433,8 @@ REMOTE_CONFIG_SCHEMA = Schema({**common_schema, **common_trial_schema, **machine
# Full experiment-config schemas, one per training service platform.  Each is a
# Schema built from a single dict that merges the shared common_schema keys with
# that platform's trial-section fragment and its cluster-config fragment.
PAI_CONFIG_SCHEMA = Schema({**common_schema, **pai_trial_schema, **pai_config_schema})
PAI_YARN_CONFIG_SCHEMA = Schema({**common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema})
KUBEFLOW_CONFIG_SCHEMA = Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema})
FRAMEWORKCONTROLLER_CONFIG_SCHEMA = Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema})
tools/nni_cmd/launcher.py
View file @
9cbbf6f8
...
...
@@ -224,6 +224,25 @@ def set_pai_config(experiment_config, port, config_file_name):
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def set_pai_yarn_config(experiment_config, port, config_file_name):
    '''Set paiYarn configuration.

    PUTs ``experiment_config['paiYarnConfig']`` (under the key
    ``'pai_yarn_config'``) to the cluster-metadata URL for ``port``, then sets
    the NNI manager IP and finally the trial config.

    Parameters
    ----------
    experiment_config : dict
        Experiment configuration; must contain the key ``'paiYarnConfig'``.
    port : int
        Port handed to ``cluster_metadata_url`` and to the follow-up setters.
    config_file_name : str
        Handed to ``get_log_path`` to locate the stderr log and to the
        follow-up setters.

    Returns
    -------
    tuple
        ``(False, err_message)`` when the PUT fails, where ``err_message`` is
        the response text, or ``None`` if no response object was returned.
        ``(result, message)`` from ``setNNIManagerIp`` when that step fails.
        Otherwise ``(set_trial_config(...), None)``.
    '''
    pai_yarn_config_data = {'pai_yarn_config': experiment_config['paiYarnConfig']}
    response = rest_put(cluster_metadata_url(port), json.dumps(pai_yarn_config_data), REST_TIME_OUT)
    err_message = None
    if not response or not response.status_code == 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            try:
                # Pretty-print the error payload when it is valid JSON.
                log_text = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
            except ValueError:
                # A failed request may return a non-JSON body (e.g. plain text
                # or HTML); log it verbatim instead of crashing while
                # reporting the error.
                log_text = err_message
            with open(stderr_full_path, 'a+') as fout:
                fout.write(log_text)
        return False, err_message
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), err_message
def
set_kubeflow_config
(
experiment_config
,
port
,
config_file_name
):
'''set kubeflow configuration'''
kubeflow_config_data
=
dict
()
...
...
@@ -320,6 +339,11 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{
'key'
:
'pai_config'
,
'value'
:
experiment_config
[
'paiConfig'
]})
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
elif
experiment_config
[
'trainingServicePlatform'
]
==
'paiYarn'
:
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'pai_yarn_config'
,
'value'
:
experiment_config
[
'paiYarnConfig'
]})
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
elif
experiment_config
[
'trainingServicePlatform'
]
==
'kubeflow'
:
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'kubeflow_config'
,
'value'
:
experiment_config
[
'kubeflowConfig'
]})
...
...
@@ -351,6 +375,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result
,
err_msg
=
set_remote_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'pai'
:
config_result
,
err_msg
=
set_pai_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'paiYarn'
:
config_result
,
err_msg
=
set_pai_yarn_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'kubeflow'
:
config_result
,
err_msg
=
set_kubeflow_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'frameworkcontroller'
:
...
...
tools/nni_cmd/launcher_utils.py
View file @
9cbbf6f8
...
...
@@ -5,7 +5,7 @@ import os
import
json
from
schema
import
SchemaError
from
schema
import
Schema
from
.config_schema
import
LOCAL_CONFIG_SCHEMA
,
REMOTE_CONFIG_SCHEMA
,
PAI_CONFIG_SCHEMA
,
KUBEFLOW_CONFIG_SCHEMA
,
\
from
.config_schema
import
LOCAL_CONFIG_SCHEMA
,
REMOTE_CONFIG_SCHEMA
,
PAI_CONFIG_SCHEMA
,
PAI_YARN_CONFIG_SCHEMA
,
KUBEFLOW_CONFIG_SCHEMA
,
\
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
,
tuner_schema_dict
,
advisor_schema_dict
,
assessor_schema_dict
from
.common_utils
import
print_error
,
print_warning
,
print_normal
...
...
@@ -143,13 +143,14 @@ def validate_kubeflow_operators(experiment_config):
def
validate_common_content
(
experiment_config
):
'''Validate whether the common values in experiment_config is valid'''
if
not
experiment_config
.
get
(
'trainingServicePlatform'
)
or
\
experiment_config
.
get
(
'trainingServicePlatform'
)
not
in
[
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
]:
experiment_config
.
get
(
'trainingServicePlatform'
)
not
in
[
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
]:
print_error
(
'Please set correct trainingServicePlatform!'
)
exit
(
1
)
schema_dict
=
{
'local'
:
LOCAL_CONFIG_SCHEMA
,
'remote'
:
REMOTE_CONFIG_SCHEMA
,
'pai'
:
PAI_CONFIG_SCHEMA
,
'paiYarn'
:
PAI_YARN_CONFIG_SCHEMA
,
'kubeflow'
:
KUBEFLOW_CONFIG_SCHEMA
,
'frameworkcontroller'
:
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
}
...
...
@@ -255,7 +256,7 @@ def validate_machine_list(experiment_config):
def
validate_pai_trial_conifg
(
experiment_config
):
'''validate the trial config in pai platform'''
if
experiment_config
.
get
(
'trainingServicePlatform'
)
==
'pai'
:
if
experiment_config
.
get
(
'trainingServicePlatform'
)
in
[
'pai'
,
'paiYarn'
]
:
if
experiment_config
.
get
(
'trial'
).
get
(
'shmMB'
)
and
\
experiment_config
[
'trial'
][
'shmMB'
]
>
experiment_config
[
'trial'
][
'memoryMB'
]:
print_error
(
'shmMB should be no more than memoryMB!'
)
...
...
tools/nni_trial_tool/trial_keeper.py
View file @
9cbbf6f8
...
...
@@ -223,7 +223,7 @@ if __name__ == '__main__':
exit
(
1
)
check_version
(
args
)
try
:
if
NNI_PLATFORM
==
'pai'
and
is_multi_phase
():
if
NNI_PLATFORM
==
'pai
Yarn
'
and
is_multi_phase
():
fetch_parameter_file
(
args
)
main_loop
(
args
)
except
SystemExit
as
se
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment