Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
d965808e
Unverified
Commit
d965808e
authored
Apr 22, 2021
by
liuzhe-lz
Committed by
GitHub
Apr 22, 2021
Browse files
Fix k8s and hybrid config (#3563)
parent
9c1f5344
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
60 additions
and
26 deletions
+60
-26
docs/en_US/reference/experiment_config.rst
docs/en_US/reference/experiment_config.rst
+3
-0
nni/__main__.py
nni/__main__.py
+6
-0
nni/experiment/config/base.py
nni/experiment/config/base.py
+1
-0
nni/experiment/config/common.py
nni/experiment/config/common.py
+7
-0
nni/experiment/config/convert.py
nni/experiment/config/convert.py
+30
-25
nni/tools/nnictl/launcher.py
nni/tools/nnictl/launcher.py
+2
-0
ts/nni_manager/core/nnimanager.ts
ts/nni_manager/core/nnimanager.ts
+11
-1
No files found.
docs/en_US/reference/experiment_config.rst
View file @
d965808e
...
...
@@ -35,6 +35,7 @@ Local Mode
trialCommand: python mnist.py
trialCodeDirectory: .
trialGpuNumber: 1
trialConcurrency: 2
maxExperimentDuration: 24h
maxTrialNumber: 100
tuner:
...
...
@@ -59,6 +60,7 @@ Local Mode (Inline Search Space)
_value: [0.0001, 0.1]
trialCommand: python mnist.py
trialGpuNumber: 1
trialConcurrency: 2
tuner:
name: TPE
classArgs:
...
...
@@ -77,6 +79,7 @@ Remote Mode
trialCommand: python mnist.py
trialCodeDirectory: .
trialGpuNumber: 1
trialConcurrency: 2
maxExperimentDuration: 24h
maxTrialNumber: 100
tuner:
...
...
nni/__main__.py
View file @
d965808e
...
...
@@ -32,6 +32,12 @@ def main():
if
exp_params
.
get
(
'deprecated'
,
{}).
get
(
'multiThread'
):
enable_multi_thread
()
if
'trainingServicePlatform'
in
exp_params
:
# config schema is v1
from
.experiment.config.convert
import
convert_algo
for
algo_type
in
[
'tuner'
,
'assessor'
,
'advisor'
]:
if
algo_type
in
exp_params
:
convert_algo
(
algo_type
,
exp_params
,
exp_params
)
if
exp_params
.
get
(
'advisor'
)
is
not
None
:
# advisor is enabled and starts to run
_run_advisor
(
exp_params
)
...
...
nni/experiment/config/base.py
View file @
d965808e
...
...
@@ -82,6 +82,7 @@ class ConfigBase:
Convert config to JSON object.
The keys of returned object will be camelCase.
"""
self
.
validate
()
return
dataclasses
.
asdict
(
self
.
canonical
(),
dict_factory
=
lambda
items
:
dict
((
util
.
camel_case
(
k
),
v
)
for
k
,
v
in
items
if
v
is
not
None
)
...
...
nni/experiment/config/common.py
View file @
d965808e
...
...
@@ -98,6 +98,13 @@ class ExperimentConfig(ConfigBase):
if
isinstance
(
kwargs
.
get
(
algo_type
),
dict
):
setattr
(
self
,
algo_type
,
_AlgorithmConfig
(
**
kwargs
.
pop
(
algo_type
)))
def
canonical
(
self
):
ret
=
super
().
canonical
()
if
isinstance
(
ret
.
training_service
,
list
):
for
i
,
ts
in
enumerate
(
ret
.
training_service
):
ret
.
training_service
[
i
]
=
ts
.
canonical
()
return
ret
def
validate
(
self
,
initialized_tuner
:
bool
=
False
)
->
None
:
super
().
validate
()
if
initialized_tuner
:
...
...
nni/experiment/config/convert.py
View file @
d965808e
...
...
@@ -45,31 +45,8 @@ def to_v2(v1) -> ExperimentConfig:
_move_field
(
v1_trial
,
v2
,
'gpuNum'
,
'trial_gpu_number'
)
for
algo_type
in
[
'tuner'
,
'assessor'
,
'advisor'
]:
if
algo_type
not
in
v1
:
continue
v1_algo
=
v1
.
pop
(
algo_type
)
builtin_name
=
v1_algo
.
pop
(
f
'builtin
{
algo_type
.
title
()
}
Name'
,
None
)
class_args
=
v1_algo
.
pop
(
'classArgs'
,
None
)
if
builtin_name
is
not
None
:
v2_algo
=
AlgorithmConfig
(
name
=
builtin_name
,
class_args
=
class_args
)
else
:
class_directory
=
util
.
canonical_path
(
v1_algo
.
pop
(
'codeDir'
))
class_file_name
=
v1_algo
.
pop
(
'classFileName'
)
assert
class_file_name
.
endswith
(
'.py'
)
class_name
=
class_file_name
[:
-
3
]
+
'.'
+
v1_algo
.
pop
(
'className'
)
v2_algo
=
CustomAlgorithmConfig
(
class_name
=
class_name
,
class_directory
=
class_directory
,
class_args
=
class_args
)
setattr
(
v2
,
algo_type
,
v2_algo
)
_deprecate
(
v1_algo
,
v2
,
'includeIntermediateResults'
)
_move_field
(
v1_algo
,
v2
,
'gpuIndices'
,
'tuner_gpu_indices'
)
assert
not
v1_algo
,
v1_algo
if
algo_type
in
v1
:
convert_algo
(
algo_type
,
v1
,
v2
)
ts
=
v2
.
training_service
...
...
@@ -259,3 +236,31 @@ def _deprecate(v1, v2, key):
if
v2
.
_deprecated
is
None
:
v2
.
_deprecated
=
{}
v2
.
_deprecated
[
key
]
=
v1
.
pop
(
key
)
def
convert_algo
(
algo_type
,
v1
,
v2
):
if
algo_type
not
in
v1
:
return
None
v1_algo
=
v1
.
pop
(
algo_type
)
builtin_name
=
v1_algo
.
pop
(
f
'builtin
{
algo_type
.
title
()
}
Name'
,
None
)
class_args
=
v1_algo
.
pop
(
'classArgs'
,
None
)
if
builtin_name
is
not
None
:
v2_algo
=
AlgorithmConfig
(
name
=
builtin_name
,
class_args
=
class_args
)
else
:
class_directory
=
util
.
canonical_path
(
v1_algo
.
pop
(
'codeDir'
))
class_file_name
=
v1_algo
.
pop
(
'classFileName'
)
assert
class_file_name
.
endswith
(
'.py'
)
class_name
=
class_file_name
[:
-
3
]
+
'.'
+
v1_algo
.
pop
(
'className'
)
v2_algo
=
CustomAlgorithmConfig
(
class_name
=
class_name
,
class_directory
=
class_directory
,
class_args
=
class_args
)
setattr
(
v2
,
algo_type
,
v2_algo
)
_deprecate
(
v1_algo
,
v2
,
'includeIntermediateResults'
)
_move_field
(
v1_algo
,
v2
,
'gpuIndices'
,
'tuner_gpu_indices'
)
assert
not
v1_algo
,
v1_algo
return
v2_algo
nni/tools/nnictl/launcher.py
View file @
d965808e
...
...
@@ -333,6 +333,8 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
# start rest server
if
config_version
==
1
:
platform
=
experiment_config
[
'trainingServicePlatform'
]
elif
isinstance
(
experiment_config
[
'trainingService'
],
list
):
platform
=
'hybrid'
else
:
platform
=
experiment_config
[
'trainingService'
][
'platform'
]
...
...
ts/nni_manager/core/nnimanager.ts
View file @
d965808e
...
...
@@ -409,7 +409,17 @@ class NNIManager implements Manager {
private
async
initTrainingService
(
config
:
ExperimentConfig
):
Promise
<
TrainingService
>
{
this
.
config
=
config
;
const
platform
=
Array
.
isArray
(
config
.
trainingService
)
?
'
hybrid
'
:
config
.
trainingService
.
platform
;
let
platform
:
string
;
if
(
Array
.
isArray
(
config
.
trainingService
))
{
platform
=
'
hybrid
'
;
}
else
if
(
config
.
trainingService
.
platform
)
{
platform
=
config
.
trainingService
.
platform
;
}
else
{
platform
=
(
config
as
any
).
trainingServicePlatform
;
}
if
(
!
platform
)
{
throw
new
Error
(
'
Cannot detect training service platform
'
);
}
if
([
'
remote
'
,
'
pai
'
,
'
aml
'
,
'
hybrid
'
].
includes
(
platform
))
{
const
module_
=
await
import
(
'
../training_service/reusable/routerTrainingService
'
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment