Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
df6145a2
Commit
df6145a2
authored
Dec 16, 2020
by
Yuge Zhang
Browse files
Merge branch 'master' of
https://github.com/microsoft/nni
into dev-retiarii
parents
0f0c6288
f8424a9f
Changes
205
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
114 additions
and
28 deletions
+114
-28
nni/algorithms/hpo/pbt_tuner/__init__.py
nni/algorithms/hpo/pbt_tuner/__init__.py
+0
-0
nni/algorithms/hpo/ppo_tuner/__init__.py
nni/algorithms/hpo/ppo_tuner/__init__.py
+1
-1
nni/algorithms/hpo/ppo_tuner/requirements.txt
nni/algorithms/hpo/ppo_tuner/requirements.txt
+0
-2
nni/algorithms/hpo/regularized_evolution_tuner.py
nni/algorithms/hpo/regularized_evolution_tuner.py
+0
-0
nni/algorithms/hpo/regularized_evolution_tuner/__init__.py
nni/algorithms/hpo/regularized_evolution_tuner/__init__.py
+0
-1
nni/algorithms/hpo/smac_tuner/__init__.py
nni/algorithms/hpo/smac_tuner/__init__.py
+1
-1
nni/algorithms/hpo/smac_tuner/requirements.txt
nni/algorithms/hpo/smac_tuner/requirements.txt
+0
-2
nni/algorithms/nas/pytorch/cdarts/__init__.py
nni/algorithms/nas/pytorch/cdarts/__init__.py
+1
-1
nni/algorithms/nas/pytorch/random/__init__.py
nni/algorithms/nas/pytorch/random/__init__.py
+1
-1
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+1
-1
nni/experiment/config/base.py
nni/experiment/config/base.py
+1
-1
nni/experiment/launcher.py
nni/experiment/launcher.py
+1
-1
nni/nas/pytorch/__init__.py
nni/nas/pytorch/__init__.py
+6
-0
nni/runtime/env_vars.py
nni/runtime/env_vars.py
+2
-1
nni/runtime/log.py
nni/runtime/log.py
+1
-1
nni/runtime/platform/__init__.py
nni/runtime/platform/__init__.py
+1
-1
nni/runtime/platform/local.py
nni/runtime/platform/local.py
+2
-1
nni/tools/nnictl/config_schema.py
nni/tools/nnictl/config_schema.py
+43
-4
nni/tools/nnictl/config_utils.py
nni/tools/nnictl/config_utils.py
+4
-1
nni/tools/nnictl/launcher.py
nni/tools/nnictl/launcher.py
+48
-7
No files found.
nni/algorithms/hpo/pbt_tuner/__init__.py
deleted
100644 → 0
View file @
0f0c6288
nni/algorithms/hpo/ppo_tuner/__init__.py
View file @
df6145a2
from
.ppo_tuner
import
PPOTuner
from
.ppo_tuner
import
PPOTuner
,
PPOClassArgsValidator
nni/algorithms/hpo/ppo_tuner/requirements.txt
deleted
100644 → 0
View file @
0f0c6288
enum34
gym
nni/algorithms/hpo/regularized_evolution_tuner
/regularized_evolution_tuner
.py
→
nni/algorithms/hpo/regularized_evolution_tuner.py
View file @
df6145a2
File moved
nni/algorithms/hpo/regularized_evolution_tuner/__init__.py
deleted
100644 → 0
View file @
0f0c6288
from
.regularized_evolution_tuner
import
RegularizedEvolutionTuner
nni/algorithms/hpo/smac_tuner/__init__.py
View file @
df6145a2
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
.smac_tuner
import
SMACTuner
from
.smac_tuner
import
SMACTuner
,
SMACClassArgsValidator
nni/algorithms/hpo/smac_tuner/requirements.txt
deleted
100644 → 0
View file @
0f0c6288
git+https://github.com/QuanluZhang/ConfigSpace.git
git+https://github.com/QuanluZhang/SMAC3.git
nni/algorithms/nas/pytorch/cdarts/__init__.py
View file @
df6145a2
...
@@ -2,4 +2,4 @@
...
@@ -2,4 +2,4 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
from
.mutator
import
RegularizedDartsMutator
,
RegularizedMutatorParallel
,
DartsDiscreteMutator
from
.mutator
import
RegularizedDartsMutator
,
RegularizedMutatorParallel
,
DartsDiscreteMutator
from
.trainer
import
CdartsTrainer
from
.trainer
import
CdartsTrainer
\ No newline at end of file
nni/algorithms/nas/pytorch/random/__init__.py
View file @
df6145a2
from
.mutator
import
RandomMutator
from
.mutator
import
RandomMutator
\ No newline at end of file
nni/compression/pytorch/compressor.py
View file @
df6145a2
...
@@ -662,7 +662,7 @@ class QuantGrad(torch.autograd.Function):
...
@@ -662,7 +662,7 @@ class QuantGrad(torch.autograd.Function):
if
quant_type
==
QuantType
.
QUANT_INPUT
:
if
quant_type
==
QuantType
.
QUANT_INPUT
:
output
=
wrapper
.
quantizer
.
quantize_input
(
tensor
,
wrapper
,
**
kwargs
)
output
=
wrapper
.
quantizer
.
quantize_input
(
tensor
,
wrapper
,
**
kwargs
)
elif
quant_type
==
QuantType
.
QUANT_WEIGHT
:
elif
quant_type
==
QuantType
.
QUANT_WEIGHT
:
output
=
wrapper
.
quantizer
.
quantize_weight
(
wrapper
,
**
kwargs
)
output
=
wrapper
.
quantizer
.
quantize_weight
(
wrapper
,
**
kwargs
)
elif
quant_type
==
QuantType
.
QUANT_OUTPUT
:
elif
quant_type
==
QuantType
.
QUANT_OUTPUT
:
output
=
wrapper
.
quantizer
.
quantize_output
(
tensor
,
wrapper
,
**
kwargs
)
output
=
wrapper
.
quantizer
.
quantize_output
(
tensor
,
wrapper
,
**
kwargs
)
else
:
else
:
...
...
nni/experiment/config/base.py
View file @
df6145a2
...
@@ -87,7 +87,7 @@ class ConfigBase:
...
@@ -87,7 +87,7 @@ class ConfigBase:
"""
"""
return
dataclasses
.
asdict
(
return
dataclasses
.
asdict
(
self
.
canonical
(),
self
.
canonical
(),
dict_factory
=
lambda
items
:
dict
((
util
.
camel_case
(
k
),
v
)
for
k
,
v
in
items
if
v
is
not
None
)
dict_factory
=
lambda
items
:
dict
((
util
.
camel_case
(
k
),
v
)
for
k
,
v
in
items
if
v
is
not
None
)
)
)
def
canonical
(
self
:
T
)
->
T
:
def
canonical
(
self
:
T
)
->
T
:
...
...
nni/experiment/launcher.py
View file @
df6145a2
...
@@ -32,7 +32,7 @@ def start_experiment(config: ExperimentConfig, port: int, debug: bool) -> Tuple[
...
@@ -32,7 +32,7 @@ def start_experiment(config: ExperimentConfig, port: int, debug: bool) -> Tuple[
exp_id
=
management
.
generate_experiment_id
()
exp_id
=
management
.
generate_experiment_id
()
try
:
try
:
_logger
.
info
(
f
'Creating experiment
{
colorama
.
Fore
.
CYAN
}{
exp_id
}
'
)
_logger
.
info
(
'Creating experiment
%s%s'
,
colorama
.
Fore
.
CYAN
,
exp_id
)
pipe
=
Pipe
(
exp_id
)
pipe
=
Pipe
(
exp_id
)
proc
=
_start_rest_server
(
config
,
port
,
debug
,
exp_id
,
pipe
.
path
)
proc
=
_start_rest_server
(
config
,
port
,
debug
,
exp_id
,
pipe
.
path
)
_logger
.
info
(
'Connecting IPC pipe...'
)
_logger
.
info
(
'Connecting IPC pipe...'
)
...
...
nni/nas/pytorch/__init__.py
View file @
df6145a2
from
.base_mutator
import
BaseMutator
from
.base_trainer
import
BaseTrainer
from
.fixed
import
apply_fixed_architecture
from
.mutables
import
Mutable
,
LayerChoice
,
InputChoice
from
.mutator
import
Mutator
from
.trainer
import
Trainer
nni/runtime/env_vars.py
View file @
df6145a2
...
@@ -12,7 +12,8 @@ _trial_env_var_names = [
...
@@ -12,7 +12,8 @@ _trial_env_var_names = [
'NNI_SYS_DIR'
,
'NNI_SYS_DIR'
,
'NNI_OUTPUT_DIR'
,
'NNI_OUTPUT_DIR'
,
'NNI_TRIAL_SEQ_ID'
,
'NNI_TRIAL_SEQ_ID'
,
'MULTI_PHASE'
'MULTI_PHASE'
,
'REUSE_MODE'
]
]
_dispatcher_env_var_names
=
[
_dispatcher_env_var_names
=
[
...
...
nni/runtime/log.py
View file @
df6145a2
...
@@ -31,7 +31,7 @@ def init_logger() -> None:
...
@@ -31,7 +31,7 @@ def init_logger() -> None:
if
trial_platform
==
'unittest'
:
if
trial_platform
==
'unittest'
:
return
return
if
trial_platform
:
if
trial_platform
and
not
trial_env_vars
.
REUSE_MODE
:
_init_logger_trial
()
_init_logger_trial
()
return
return
...
...
nni/runtime/platform/__init__.py
View file @
df6145a2
...
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
...
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
from
.standalone
import
*
from
.standalone
import
*
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
from
.test
import
*
from
.test
import
*
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'adl'
,
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
):
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
,
'adl'
,
'heterogeneous'
):
from
.local
import
*
from
.local
import
*
else
:
else
:
raise
RuntimeError
(
'Unknown platform %s'
%
trial_env_vars
.
NNI_PLATFORM
)
raise
RuntimeError
(
'Unknown platform %s'
%
trial_env_vars
.
NNI_PLATFORM
)
nni/runtime/platform/local.py
View file @
df6145a2
...
@@ -19,6 +19,7 @@ _outputdir = trial_env_vars.NNI_OUTPUT_DIR
...
@@ -19,6 +19,7 @@ _outputdir = trial_env_vars.NNI_OUTPUT_DIR
if
not
os
.
path
.
exists
(
_outputdir
):
if
not
os
.
path
.
exists
(
_outputdir
):
os
.
makedirs
(
_outputdir
)
os
.
makedirs
(
_outputdir
)
_reuse_mode
=
trial_env_vars
.
REUSE_MODE
_nni_platform
=
trial_env_vars
.
NNI_PLATFORM
_nni_platform
=
trial_env_vars
.
NNI_PLATFORM
_multiphase
=
trial_env_vars
.
MULTI_PHASE
_multiphase
=
trial_env_vars
.
MULTI_PHASE
...
@@ -58,7 +59,7 @@ def get_next_parameter():
...
@@ -58,7 +59,7 @@ def get_next_parameter():
return
params
return
params
def
send_metric
(
string
):
def
send_metric
(
string
):
if
_nni_platform
!=
'local'
:
if
_nni_platform
!=
'local'
or
_reuse_mode
in
(
'true'
,
'True'
)
:
assert
len
(
string
)
<
1000000
,
'Metric too long'
assert
len
(
string
)
<
1000000
,
'Metric too long'
print
(
"NNISDK_MEb'%s'"
%
(
string
),
flush
=
True
)
print
(
"NNISDK_MEb'%s'"
%
(
string
),
flush
=
True
)
else
:
else
:
...
...
nni/tools/nnictl/config_schema.py
View file @
df6145a2
...
@@ -124,7 +124,7 @@ common_schema = {
...
@@ -124,7 +124,7 @@ common_schema = {
Optional
(
'maxExecDuration'
):
And
(
Regex
(
r
'^[1-9][0-9]*[s|m|h|d]$'
,
error
=
'ERROR: maxExecDuration format is [digit]{s,m,h,d}'
)),
Optional
(
'maxExecDuration'
):
And
(
Regex
(
r
'^[1-9][0-9]*[s|m|h|d]$'
,
error
=
'ERROR: maxExecDuration format is [digit]{s,m,h,d}'
)),
Optional
(
'maxTrialNum'
):
setNumberRange
(
'maxTrialNum'
,
int
,
1
,
99999
),
Optional
(
'maxTrialNum'
):
setNumberRange
(
'maxTrialNum'
,
int
,
1
,
99999
),
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
,
'adl'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
),
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
,
'adl'
,
'heterogeneous'
),
Optional
(
'searchSpacePath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'searchSpacePath'
),
Optional
(
'searchSpacePath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'searchSpacePath'
),
Optional
(
'multiPhase'
):
setType
(
'multiPhase'
,
bool
),
Optional
(
'multiPhase'
):
setType
(
'multiPhase'
,
bool
),
Optional
(
'multiThread'
):
setType
(
'multiThread'
,
bool
),
Optional
(
'multiThread'
):
setType
(
'multiThread'
,
bool
),
...
@@ -208,7 +208,7 @@ pai_trial_schema = {
...
@@ -208,7 +208,7 @@ pai_trial_schema = {
}
}
pai_config_schema
=
{
pai_config_schema
=
{
'paiConfig'
:
{
Optional
(
'paiConfig'
)
:
{
'userName'
:
setType
(
'userName'
,
str
),
'userName'
:
setType
(
'userName'
,
str
),
Or
(
'passWord'
,
'token'
,
only_one
=
True
):
str
,
Or
(
'passWord'
,
'token'
,
only_one
=
True
):
str
,
'host'
:
setType
(
'host'
,
str
),
'host'
:
setType
(
'host'
,
str
),
...
@@ -252,7 +252,7 @@ aml_trial_schema = {
...
@@ -252,7 +252,7 @@ aml_trial_schema = {
}
}
aml_config_schema
=
{
aml_config_schema
=
{
'amlConfig'
:
{
Optional
(
'amlConfig'
)
:
{
'subscriptionId'
:
setType
(
'subscriptionId'
,
str
),
'subscriptionId'
:
setType
(
'subscriptionId'
,
str
),
'resourceGroup'
:
setType
(
'resourceGroup'
,
str
),
'resourceGroup'
:
setType
(
'resourceGroup'
,
str
),
'workspaceName'
:
setType
(
'workspaceName'
,
str
),
'workspaceName'
:
setType
(
'workspaceName'
,
str
),
...
@@ -262,6 +262,29 @@ aml_config_schema = {
...
@@ -262,6 +262,29 @@ aml_config_schema = {
}
}
}
}
heterogeneous_trial_schema
=
{
'trial'
:
{
'codeDir'
:
setPathCheck
(
'codeDir'
),
Optional
(
'nniManagerNFSMountPath'
):
setPathCheck
(
'nniManagerNFSMountPath'
),
Optional
(
'containerNFSMountPath'
):
setType
(
'containerNFSMountPath'
,
str
),
Optional
(
'nasMode'
):
setChoice
(
'nasMode'
,
'classic_mode'
,
'enas_mode'
,
'oneshot_mode'
,
'darts_mode'
),
'command'
:
setType
(
'command'
,
str
),
Optional
(
'gpuNum'
):
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
Optional
(
'cpuNum'
):
setNumberRange
(
'cpuNum'
,
int
,
0
,
99999
),
Optional
(
'memoryMB'
):
setType
(
'memoryMB'
,
int
),
Optional
(
'image'
):
setType
(
'image'
,
str
),
Optional
(
'virtualCluster'
):
setType
(
'virtualCluster'
,
str
),
Optional
(
'paiStorageConfigName'
):
setType
(
'paiStorageConfigName'
,
str
),
Optional
(
'paiConfigPath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'paiConfigPath'
)
}
}
heterogeneous_config_schema
=
{
'heterogeneousConfig'
:
{
'trainingServicePlatforms'
:
[
'local'
,
'remote'
,
'pai'
,
'aml'
]
}
}
adl_trial_schema
=
{
adl_trial_schema
=
{
'trial'
:{
'trial'
:{
'codeDir'
:
setType
(
'codeDir'
,
str
),
'codeDir'
:
setType
(
'codeDir'
,
str
),
...
@@ -404,7 +427,7 @@ remote_config_schema = {
...
@@ -404,7 +427,7 @@ remote_config_schema = {
}
}
machine_list_schema
=
{
machine_list_schema
=
{
'machineList'
:
[
Or
(
Optional
(
'machineList'
)
:
[
Or
(
{
{
'ip'
:
setType
(
'ip'
,
str
),
'ip'
:
setType
(
'ip'
,
str
),
Optional
(
'port'
):
setNumberRange
(
'port'
,
int
,
1
,
65535
),
Optional
(
'port'
):
setNumberRange
(
'port'
,
int
,
1
,
65535
),
...
@@ -438,6 +461,8 @@ training_service_schema_dict = {
...
@@ -438,6 +461,8 @@ training_service_schema_dict = {
'frameworkcontroller'
:
Schema
({
**
common_schema
,
**
frameworkcontroller_trial_schema
,
**
frameworkcontroller_config_schema
}),
'frameworkcontroller'
:
Schema
({
**
common_schema
,
**
frameworkcontroller_trial_schema
,
**
frameworkcontroller_config_schema
}),
'aml'
:
Schema
({
**
common_schema
,
**
aml_trial_schema
,
**
aml_config_schema
}),
'aml'
:
Schema
({
**
common_schema
,
**
aml_trial_schema
,
**
aml_config_schema
}),
'dlts'
:
Schema
({
**
common_schema
,
**
dlts_trial_schema
,
**
dlts_config_schema
}),
'dlts'
:
Schema
({
**
common_schema
,
**
dlts_trial_schema
,
**
dlts_config_schema
}),
'heterogeneous'
:
Schema
({
**
common_schema
,
**
heterogeneous_trial_schema
,
**
heterogeneous_config_schema
,
**
machine_list_schema
,
**
pai_config_schema
,
**
aml_config_schema
,
**
remote_config_schema
}),
}
}
...
@@ -454,6 +479,7 @@ class NNIConfigSchema:
...
@@ -454,6 +479,7 @@ class NNIConfigSchema:
self
.
validate_pai_trial_conifg
(
experiment_config
)
self
.
validate_pai_trial_conifg
(
experiment_config
)
self
.
validate_kubeflow_operators
(
experiment_config
)
self
.
validate_kubeflow_operators
(
experiment_config
)
self
.
validate_eth0_device
(
experiment_config
)
self
.
validate_eth0_device
(
experiment_config
)
self
.
validate_heterogeneous_platforms
(
experiment_config
)
def
validate_tuner_adivosr_assessor
(
self
,
experiment_config
):
def
validate_tuner_adivosr_assessor
(
self
,
experiment_config
):
if
experiment_config
.
get
(
'advisor'
):
if
experiment_config
.
get
(
'advisor'
):
...
@@ -563,3 +589,16 @@ class NNIConfigSchema:
...
@@ -563,3 +589,16 @@ class NNIConfigSchema:
and
not
experiment_config
.
get
(
'nniManagerIp'
)
\
and
not
experiment_config
.
get
(
'nniManagerIp'
)
\
and
'eth0'
not
in
netifaces
.
interfaces
():
and
'eth0'
not
in
netifaces
.
interfaces
():
raise
SchemaError
(
'This machine does not contain eth0 network device, please set nniManagerIp in config file!'
)
raise
SchemaError
(
'This machine does not contain eth0 network device, please set nniManagerIp in config file!'
)
def
validate_heterogeneous_platforms
(
self
,
experiment_config
):
required_config_name_map
=
{
'remote'
:
'machineList'
,
'aml'
:
'amlConfig'
,
'pai'
:
'paiConfig'
}
if
experiment_config
.
get
(
'trainingServicePlatform'
)
==
'heterogeneous'
:
for
platform
in
experiment_config
[
'heterogeneousConfig'
][
'trainingServicePlatforms'
]:
config_name
=
required_config_name_map
.
get
(
platform
)
if
config_name
and
not
experiment_config
.
get
(
config_name
):
raise
SchemaError
(
'Need to set {0} for {1} in heterogeneous mode!'
.
format
(
config_name
,
platform
))
\ No newline at end of file
nni/tools/nnictl/config_utils.py
View file @
df6145a2
...
@@ -85,7 +85,10 @@ class Experiments:
...
@@ -85,7 +85,10 @@ class Experiments:
self
.
experiments
=
self
.
read_file
()
self
.
experiments
=
self
.
read_file
()
if
expId
not
in
self
.
experiments
:
if
expId
not
in
self
.
experiments
:
return
False
return
False
self
.
experiments
[
expId
][
key
]
=
value
if
value
is
None
:
self
.
experiments
[
expId
].
pop
(
key
,
None
)
else
:
self
.
experiments
[
expId
][
key
]
=
value
self
.
write_file
()
self
.
write_file
()
return
True
return
True
...
...
nni/tools/nnictl/launcher.py
View file @
df6145a2
...
@@ -118,13 +118,6 @@ def set_local_config(experiment_config, port, config_file_name):
...
@@ -118,13 +118,6 @@ def set_local_config(experiment_config, port, config_file_name):
request_data
=
dict
()
request_data
=
dict
()
if
experiment_config
.
get
(
'localConfig'
):
if
experiment_config
.
get
(
'localConfig'
):
request_data
[
'local_config'
]
=
experiment_config
[
'localConfig'
]
request_data
[
'local_config'
]
=
experiment_config
[
'localConfig'
]
if
request_data
[
'local_config'
]:
if
request_data
[
'local_config'
].
get
(
'gpuIndices'
)
and
isinstance
(
request_data
[
'local_config'
].
get
(
'gpuIndices'
),
int
):
request_data
[
'local_config'
][
'gpuIndices'
]
=
str
(
request_data
[
'local_config'
].
get
(
'gpuIndices'
))
if
request_data
[
'local_config'
].
get
(
'maxTrialNumOnEachGpu'
):
request_data
[
'local_config'
][
'maxTrialNumOnEachGpu'
]
=
request_data
[
'local_config'
].
get
(
'maxTrialNumOnEachGpu'
)
if
request_data
[
'local_config'
].
get
(
'useActiveGpu'
):
request_data
[
'local_config'
][
'useActiveGpu'
]
=
request_data
[
'local_config'
].
get
(
'useActiveGpu'
)
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
(
request_data
),
REST_TIME_OUT
)
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
(
request_data
),
REST_TIME_OUT
)
err_message
=
''
err_message
=
''
if
not
response
or
not
check_response
(
response
):
if
not
response
or
not
check_response
(
response
):
...
@@ -306,6 +299,37 @@ def set_aml_config(experiment_config, port, config_file_name):
...
@@ -306,6 +299,37 @@ def set_aml_config(experiment_config, port, config_file_name):
#set trial_config
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def
set_heterogeneous_config
(
experiment_config
,
port
,
config_file_name
):
'''set heterogeneous configuration'''
heterogeneous_config_data
=
dict
()
heterogeneous_config_data
[
'heterogeneous_config'
]
=
experiment_config
[
'heterogeneousConfig'
]
platform_list
=
experiment_config
[
'heterogeneousConfig'
][
'trainingServicePlatforms'
]
for
platform
in
platform_list
:
if
platform
==
'aml'
:
heterogeneous_config_data
[
'aml_config'
]
=
experiment_config
[
'amlConfig'
]
elif
platform
==
'remote'
:
if
experiment_config
.
get
(
'remoteConfig'
):
heterogeneous_config_data
[
'remote_config'
]
=
experiment_config
[
'remoteConfig'
]
heterogeneous_config_data
[
'machine_list'
]
=
experiment_config
[
'machineList'
]
elif
platform
==
'local'
and
experiment_config
.
get
(
'localConfig'
):
heterogeneous_config_data
[
'local_config'
]
=
experiment_config
[
'localConfig'
]
elif
platform
==
'pai'
:
heterogeneous_config_data
[
'pai_config'
]
=
experiment_config
[
'paiConfig'
]
response
=
rest_put
(
cluster_metadata_url
(
port
),
json
.
dumps
(
heterogeneous_config_data
),
REST_TIME_OUT
)
err_message
=
None
if
not
response
or
not
response
.
status_code
==
200
:
if
response
is
not
None
:
err_message
=
response
.
text
_
,
stderr_full_path
=
get_log_path
(
config_file_name
)
with
open
(
stderr_full_path
,
'a+'
)
as
fout
:
fout
.
write
(
json
.
dumps
(
json
.
loads
(
err_message
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
False
,
err_message
result
,
message
=
setNNIManagerIp
(
experiment_config
,
port
,
config_file_name
)
if
not
result
:
return
result
,
message
#set trial_config
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
err_message
def
set_experiment
(
experiment_config
,
mode
,
port
,
config_file_name
):
def
set_experiment
(
experiment_config
,
mode
,
port
,
config_file_name
):
'''Call startExperiment (rest POST /experiment) with yaml file content'''
'''Call startExperiment (rest POST /experiment) with yaml file content'''
request_data
=
dict
()
request_data
=
dict
()
...
@@ -387,6 +411,21 @@ def set_experiment(experiment_config, mode, port, config_file_name):
...
@@ -387,6 +411,21 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{
'key'
:
'aml_config'
,
'value'
:
experiment_config
[
'amlConfig'
]})
{
'key'
:
'aml_config'
,
'value'
:
experiment_config
[
'amlConfig'
]})
request_data
[
'clusterMetaData'
].
append
(
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
elif
experiment_config
[
'trainingServicePlatform'
]
==
'heterogeneous'
:
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'heterogeneous_config'
,
'value'
:
experiment_config
[
'heterogeneousConfig'
]})
platform_list
=
experiment_config
[
'heterogeneousConfig'
][
'trainingServicePlatforms'
]
request_dict
=
{
'aml'
:
{
'key'
:
'aml_config'
,
'value'
:
experiment_config
.
get
(
'amlConfig'
)},
'remote'
:
{
'key'
:
'machine_list'
,
'value'
:
experiment_config
.
get
(
'machineList'
)},
'pai'
:
{
'key'
:
'pai_config'
,
'value'
:
experiment_config
.
get
(
'paiConfig'
)},
'local'
:
{
'key'
:
'local_config'
,
'value'
:
experiment_config
.
get
(
'localConfig'
)}
}
for
platform
in
platform_list
:
if
request_dict
.
get
(
platform
):
request_data
[
'clusterMetaData'
].
append
(
request_dict
[
platform
])
request_data
[
'clusterMetaData'
].
append
(
{
'key'
:
'trial_config'
,
'value'
:
experiment_config
[
'trial'
]})
response
=
rest_post
(
experiment_url
(
port
),
json
.
dumps
(
request_data
),
REST_TIME_OUT
,
show_error
=
True
)
response
=
rest_post
(
experiment_url
(
port
),
json
.
dumps
(
request_data
),
REST_TIME_OUT
,
show_error
=
True
)
if
check_response
(
response
):
if
check_response
(
response
):
return
response
return
response
...
@@ -420,6 +459,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
...
@@ -420,6 +459,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result
,
err_msg
=
set_dlts_config
(
experiment_config
,
port
,
config_file_name
)
config_result
,
err_msg
=
set_dlts_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'aml'
:
elif
platform
==
'aml'
:
config_result
,
err_msg
=
set_aml_config
(
experiment_config
,
port
,
config_file_name
)
config_result
,
err_msg
=
set_aml_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'heterogeneous'
:
config_result
,
err_msg
=
set_heterogeneous_config
(
experiment_config
,
port
,
config_file_name
)
else
:
else
:
raise
Exception
(
ERROR_INFO
%
'Unsupported platform!'
)
raise
Exception
(
ERROR_INFO
%
'Unsupported platform!'
)
exit
(
1
)
exit
(
1
)
...
...
Prev
1
…
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment