Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a4e3c240
Unverified
Commit
a4e3c240
authored
Mar 10, 2021
by
liuzhe-lz
Committed by
GitHub
Mar 10, 2021
Browse files
fix config v2 relative path (#3439)
parent
e457047c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
6 deletions
+19
-6
nni/experiment/config/base.py
nni/experiment/config/base.py
+2
-0
nni/experiment/config/common.py
nni/experiment/config/common.py
+10
-3
nni/experiment/config/util.py
nni/experiment/config/util.py
+5
-2
nni/tools/nnictl/launcher.py
nni/tools/nnictl/launcher.py
+2
-1
No files found.
nni/experiment/config/base.py
View file @
a4e3c240
...
...
@@ -47,6 +47,8 @@ class ConfigBase:
They will be converted to snake_case automatically.
If a field is missing and don't have default value, it will be set to `dataclasses.MISSING`.
"""
if
'basepath'
in
kwargs
:
_base_path
=
kwargs
.
pop
(
'basepath'
)
kwargs
=
{
util
.
case_insensitive
(
key
):
value
for
key
,
value
in
kwargs
.
items
()}
if
_base_path
is
None
:
_base_path
=
Path
()
...
...
nni/experiment/config/common.py
View file @
a4e3c240
...
...
@@ -68,17 +68,24 @@ class ExperimentConfig(ConfigBase):
training_service
:
Union
[
TrainingServiceConfig
,
List
[
TrainingServiceConfig
]]
def
__init__
(
self
,
training_service_platform
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
**
kwargs
):
base_path
=
kwargs
.
pop
(
'_base_path'
,
None
)
kwargs
=
util
.
case_insensitive
(
kwargs
)
if
training_service_platform
is
not
None
:
assert
'trainingservice'
not
in
kwargs
kwargs
[
'trainingservice'
]
=
util
.
training_service_config_factory
(
platform
=
training_service_platform
)
kwargs
[
'trainingservice'
]
=
util
.
training_service_config_factory
(
platform
=
training_service_platform
,
base_path
=
base_path
)
elif
isinstance
(
kwargs
.
get
(
'trainingservice'
),
(
dict
,
list
)):
# dict means a single training service
# list means hybrid training service
kwargs
[
'trainingservice'
]
=
util
.
training_service_config_factory
(
config
=
kwargs
[
'trainingservice'
])
kwargs
[
'trainingservice'
]
=
util
.
training_service_config_factory
(
config
=
kwargs
[
'trainingservice'
],
base_path
=
base_path
)
else
:
raise
RuntimeError
(
'Unsupported Training service configuration!'
)
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
_base_path
=
base_path
,
**
kwargs
)
for
algo_type
in
[
'tuner'
,
'assessor'
,
'advisor'
]:
if
isinstance
(
kwargs
.
get
(
algo_type
),
dict
):
setattr
(
self
,
algo_type
,
_AlgorithmConfig
(
**
kwargs
.
pop
(
algo_type
)))
...
...
nni/experiment/config/util.py
View file @
a4e3c240
...
...
@@ -29,7 +29,10 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]:
def
count
(
*
values
)
->
int
:
return
sum
(
value
is
not
None
and
value
is
not
False
for
value
in
values
)
def
training_service_config_factory
(
platform
:
Union
[
str
,
List
[
str
]]
=
None
,
config
:
Union
[
List
,
Dict
]
=
None
):
# -> TrainingServiceConfig
def
training_service_config_factory
(
platform
:
Union
[
str
,
List
[
str
]]
=
None
,
config
:
Union
[
List
,
Dict
]
=
None
,
base_path
:
Optional
[
Path
]
=
None
):
# -> TrainingServiceConfig
from
.common
import
TrainingServiceConfig
ts_configs
=
[]
if
platform
is
not
None
:
...
...
@@ -47,7 +50,7 @@ def training_service_config_factory(platform: Union[str, List[str]] = None, conf
for
conf
in
configs
:
if
conf
[
'platform'
]
not
in
supported_platforms
:
raise
RuntimeError
(
f
'Unrecognized platform
{
conf
[
"platform"
]
}
'
)
ts_configs
.
append
(
supported_platforms
[
conf
[
'platform'
]](
**
conf
))
ts_configs
.
append
(
supported_platforms
[
conf
[
'platform'
]](
_base_path
=
base_path
,
**
conf
))
return
ts_configs
if
len
(
ts_configs
)
>
1
else
ts_configs
[
0
]
def
load_config
(
Type
,
value
):
...
...
nni/tools/nnictl/launcher.py
View file @
a4e3c240
...
...
@@ -3,6 +3,7 @@
import
json
import
os
from
pathlib
import
Path
import
sys
import
string
import
random
...
...
@@ -590,7 +591,7 @@ def create_experiment(args):
except
Exception
:
print_warning
(
'Validation with V1 schema failed. Trying to convert from V2 format...'
)
try
:
config
=
ExperimentConfig
(
**
experiment_config
)
config
=
ExperimentConfig
(
_base_path
=
Path
(
config_path
).
parent
,
**
experiment_config
)
experiment_config
=
convert
.
to_v1_yaml
(
config
)
except
Exception
as
e
:
print_error
(
f
'Config in v2 format validation failed, the config error in v2 format is:
{
repr
(
e
)
}
'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment