Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
9cbbf6f8
"src/vscode:/vscode.git/clone" did not exist on "345684d58d99e945112c9c7d9b6fa45960734565"
Unverified
Commit
9cbbf6f8
authored
Dec 23, 2019
by
SparkSnail
Committed by
GitHub
Dec 23, 2019
Browse files
Support pai and paiYarn trainingservice (#1853)
parent
9d01d083
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
68 additions
and
11 deletions
+68
-11
test/generate_ts_config.py
test/generate_ts_config.py
+3
-3
test/training_service.yml
test/training_service.yml
+2
-2
tools/nni_cmd/config_schema.py
tools/nni_cmd/config_schema.py
+32
-2
tools/nni_cmd/launcher.py
tools/nni_cmd/launcher.py
+26
-0
tools/nni_cmd/launcher_utils.py
tools/nni_cmd/launcher_utils.py
+4
-3
tools/nni_trial_tool/trial_keeper.py
tools/nni_trial_tool/trial_keeper.py
+1
-1
No files found.
test/generate_ts_config.py
View file @
9cbbf6f8
...
...
@@ -14,11 +14,11 @@ def update_training_service_config(args):
config
[
args
.
ts
][
'nniManagerIp'
]
=
args
.
nni_manager_ip
if
args
.
ts
==
'pai'
:
if
args
.
pai_user
is
not
None
:
config
[
args
.
ts
][
'paiConfig'
][
'userName'
]
=
args
.
pai_user
config
[
args
.
ts
][
'pai
Yarn
Config'
][
'userName'
]
=
args
.
pai_user
if
args
.
pai_pwd
is
not
None
:
config
[
args
.
ts
][
'paiConfig'
][
'passWord'
]
=
args
.
pai_pwd
config
[
args
.
ts
][
'pai
Yarn
Config'
][
'passWord'
]
=
args
.
pai_pwd
if
args
.
pai_host
is
not
None
:
config
[
args
.
ts
][
'paiConfig'
][
'host'
]
=
args
.
pai_host
config
[
args
.
ts
][
'pai
Yarn
Config'
][
'host'
]
=
args
.
pai_host
if
args
.
nni_docker_image
is
not
None
:
config
[
args
.
ts
][
'trial'
][
'image'
]
=
args
.
nni_docker_image
if
args
.
data_dir
is
not
None
:
...
...
test/training_service.yml
View file @
9cbbf6f8
...
...
@@ -29,11 +29,11 @@ local:
pai
:
nniManagerIp
:
maxExecDuration
:
15m
paiConfig
:
pai
Yarn
Config
:
host
:
passWord
:
userName
:
trainingServicePlatform
:
pai
trainingServicePlatform
:
pai
Yarn
trial
:
gpuNum
:
1
cpuNum
:
1
...
...
tools/nni_cmd/config_schema.py
View file @
9cbbf6f8
...
...
@@ -32,7 +32,7 @@ common_schema = {
'trialConcurrency'
:
setNumberRange
(
'trialConcurrency'
,
int
,
1
,
99999
),
Optional
(
'maxExecDuration'
):
And
(
Regex
(
r
'^[1-9][0-9]*[s|m|h|d]$'
,
error
=
'ERROR: maxExecDuration format is [digit]{s,m,h,d}'
)),
Optional
(
'maxTrialNum'
):
setNumberRange
(
'maxTrialNum'
,
int
,
1
,
99999
),
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
),
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
),
Optional
(
'searchSpacePath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'searchSpacePath'
),
Optional
(
'multiPhase'
):
setType
(
'multiPhase'
,
bool
),
Optional
(
'multiThread'
):
setType
(
'multiThread'
,
bool
),
...
...
@@ -232,7 +232,7 @@ common_trial_schema = {
}
}
pai_trial_schema
=
{
pai_
yarn_
trial_schema
=
{
'trial'
:{
'command'
:
setType
(
'command'
,
str
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
...
...
@@ -256,6 +256,34 @@ pai_trial_schema = {
}
}
pai_yarn_config_schema
=
{
'paiYarnConfig'
:
Or
({
'userName'
:
setType
(
'userName'
,
str
),
'passWord'
:
setType
(
'passWord'
,
str
),
'host'
:
setType
(
'host'
,
str
)
},
{
'userName'
:
setType
(
'userName'
,
str
),
'token'
:
setType
(
'token'
,
str
),
'host'
:
setType
(
'host'
,
str
)
})
}
pai_trial_schema
=
{
'trial'
:{
'command'
:
setType
(
'command'
,
str
),
'codeDir'
:
setPathCheck
(
'codeDir'
),
'gpuNum'
:
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
'cpuNum'
:
setNumberRange
(
'cpuNum'
,
int
,
0
,
99999
),
'memoryMB'
:
setType
(
'memoryMB'
,
int
),
'image'
:
setType
(
'image'
,
str
),
Optional
(
'virtualCluster'
):
setType
(
'virtualCluster'
,
str
),
'nniManagerNFSMountPath'
:
setPathCheck
(
'nniManagerNFSMountPath'
),
'containerNFSMountPath'
:
setType
(
'containerNFSMountPath'
,
str
),
'paiStoragePlugin'
:
setType
(
'paiStoragePlugin'
,
str
)
}
}
pai_config_schema
=
{
'paiConfig'
:
Or
({
'userName'
:
setType
(
'userName'
,
str
),
...
...
@@ -405,6 +433,8 @@ REMOTE_CONFIG_SCHEMA = Schema({**common_schema, **common_trial_schema, **machine
PAI_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
pai_trial_schema
,
**
pai_config_schema
})
PAI_YARN_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
pai_yarn_trial_schema
,
**
pai_yarn_config_schema
})
KUBEFLOW_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
kubeflow_trial_schema
,
**
kubeflow_config_schema
})
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
=
Schema
({
**
common_schema
,
**
frameworkcontroller_trial_schema
,
**
frameworkcontroller_config_schema
})
tools/nni_cmd/launcher.py
View file @
9cbbf6f8
...
...
@@ -224,6 +224,25 @@ def set_pai_config(experiment_config, port, config_file_name):
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def
set_pai_yarn_config
(
experiment_config
,
port
,
config_file_name
):
'''set paiYarn configuration'''
pai_yarn_config_data
=
dict
()
pai_yarn_config_data
[
'pai_yarn_config'
]
=
experiment_config
[
'paiYarnConfig'
]
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
(
pai_yarn_config_data
),
REST_TIME_OUT
)
err_message
=
None
if
not
response
or
not
response
.
status_code
==
200
:
if
response
is
not
None
:
err_message
=
response
.
text
_
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
return
result
,
message
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def
set_kubeflow_config
(
experiment_config
,
port
,
config_file_name
):
'''set kubeflow configuration'''
kubeflow_config_data
=
dict
()
...
...
@@ -320,6 +339,11 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{
'key'
:
'pai_config'
,
'value'
:
experiment_config
[
'paiConfig'
]})
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
elif
experiment_config
[
'trainingServicePlatform'
]
==
'paiYarn'
:
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'pai_yarn_config'
,
'value'
:
experiment_config
[
'paiYarnConfig'
]})
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
elif
experiment_config
[
'trainingServicePlatform'
]
==
'kubeflow'
:
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'kubeflow_config'
,
'value'
:
experiment_config
[
'kubeflowConfig'
]})
...
...
@@ -351,6 +375,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result
,
err_msg
=
set_remote_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'pai'
:
config_result
,
err_msg
=
set_pai_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'paiYarn'
:
config_result
,
err_msg
=
set_pai_yarn_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'kubeflow'
:
config_result
,
err_msg
=
set_kubeflow_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'frameworkcontroller'
:
...
...
tools/nni_cmd/launcher_utils.py
View file @
9cbbf6f8
...
...
@@ -5,7 +5,7 @@ import os
import
json
from
schema
import
SchemaError
from
schema
import
Schema
from
.config_schema
import
LOCAL_CONFIG_SCHEMA
,
REMOTE_CONFIG_SCHEMA
,
PAI_CONFIG_SCHEMA
,
KUBEFLOW_CONFIG_SCHEMA
,
\
from
.config_schema
import
LOCAL_CONFIG_SCHEMA
,
REMOTE_CONFIG_SCHEMA
,
PAI_CONFIG_SCHEMA
,
PAI_YARN_CONFIG_SCHEMA
,
KUBEFLOW_CONFIG_SCHEMA
,
\
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
,
tuner_schema_dict
,
advisor_schema_dict
,
assessor_schema_dict
from
.common_utils
import
print_error
,
print_warning
,
print_normal
...
...
@@ -143,13 +143,14 @@ def validate_kubeflow_operators(experiment_config):
def
validate_common_content
(
experiment_config
):
'''Validate whether the common values in experiment_config is valid'''
if
not
experiment_config
.
get
(
'trainingServicePlatform'
)
or
\
experiment_config
.
get
(
'trainingServicePlatform'
)
not
in
[
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
]:
experiment_config
.
get
(
'trainingServicePlatform'
)
not
in
[
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
]:
print_error
(
'Please set correct trainingServicePlatform!'
)
exit
(
1
)
schema_dict
=
{
'local'
:
LOCAL_CONFIG_SCHEMA
,
'remote'
:
REMOTE_CONFIG_SCHEMA
,
'pai'
:
PAI_CONFIG_SCHEMA
,
'paiYarn'
:
PAI_YARN_CONFIG_SCHEMA
,
'kubeflow'
:
KUBEFLOW_CONFIG_SCHEMA
,
'frameworkcontroller'
:
FRAMEWORKCONTROLLER_CONFIG_SCHEMA
}
...
...
@@ -255,7 +256,7 @@ def validate_machine_list(experiment_config):
def
validate_pai_trial_conifg
(
experiment_config
):
'''validate the trial config in pai platform'''
if
experiment_config
.
get
(
'trainingServicePlatform'
)
==
'pai'
:
if
experiment_config
.
get
(
'trainingServicePlatform'
)
in
[
'pai'
,
'paiYarn'
]
:
if
experiment_config
.
get
(
'trial'
).
get
(
'shmMB'
)
and
\
experiment_config
[
'trial'
][
'shmMB'
]
>
experiment_config
[
'trial'
][
'memoryMB'
]:
print_error
(
'shmMB should be no more than memoryMB!'
)
...
...
tools/nni_trial_tool/trial_keeper.py
View file @
9cbbf6f8
...
...
@@ -223,7 +223,7 @@ if __name__ == '__main__':
exit
(
1
)
check_version
(
args
)
try
:
if
NNI_PLATFORM
==
'pai'
and
is_multi_phase
():
if
NNI_PLATFORM
==
'pai
Yarn
'
and
is_multi_phase
():
fetch_parameter_file
(
args
)
main_loop
(
args
)
except
SystemExit
as
se
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment