Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
00e4debb
Unverified
Commit
00e4debb
authored
Jun 29, 2022
by
liuzhe-lz
Committed by
GitHub
Jun 29, 2022
Browse files
Allow postponed annotations for config classes (#4883)
parent
3d6ddb9a
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
72 additions
and
41 deletions
+72
-41
nni/experiment/config/base.py
nni/experiment/config/base.py
+5
-5
nni/experiment/config/training_service.py
nni/experiment/config/training_service.py
+0
-1
nni/experiment/config/training_services/aml.py
nni/experiment/config/training_services/aml.py
+3
-1
nni/experiment/config/training_services/dlc.py
nni/experiment/config/training_services/dlc.py
+3
-1
nni/experiment/config/training_services/frameworkcontroller.py
...xperiment/config/training_services/frameworkcontroller.py
+3
-1
nni/experiment/config/training_services/k8s_storage.py
nni/experiment/config/training_services/k8s_storage.py
+4
-2
nni/experiment/config/training_services/kubeflow.py
nni/experiment/config/training_services/kubeflow.py
+3
-1
nni/experiment/config/training_services/local.py
nni/experiment/config/training_services/local.py
+3
-1
nni/experiment/config/training_services/openpai.py
nni/experiment/config/training_services/openpai.py
+3
-1
nni/experiment/config/training_services/remote.py
nni/experiment/config/training_services/remote.py
+3
-1
nni/experiment/config/utils/internal.py
nni/experiment/config/utils/internal.py
+42
-26
No files found.
nni/experiment/config/base.py
View file @
00e4debb
...
...
@@ -89,7 +89,7 @@ class ConfigBase:
"""
self
.
_base_path
=
utils
.
get_base_path
()
args
=
{
utils
.
case_insensitive
(
key
):
value
for
key
,
value
in
kwargs
.
items
()}
for
field
in
dataclasse
s
.
fields
(
self
):
for
field
in
util
s
.
fields
(
self
):
value
=
args
.
pop
(
utils
.
case_insensitive
(
field
.
name
),
field
.
default
)
setattr
(
self
,
field
.
name
,
value
)
if
args
:
# maybe a key is misspelled
...
...
@@ -98,7 +98,7 @@ class ConfigBase:
raise
AttributeError
(
f
'
{
class_name
}
does not have field(s)
{
fields
}
'
)
# try to unpack nested config
for
field
in
dataclasse
s
.
fields
(
self
):
for
field
in
util
s
.
fields
(
self
):
value
=
getattr
(
self
,
field
.
name
)
if
utils
.
is_instance
(
value
,
field
.
type
):
continue
# already accepted by subclass, don't touch it
...
...
@@ -214,7 +214,7 @@ class ConfigBase:
For example local training service's ``trialGpuNumber`` will be copied from top level when not set,
in this case it will be invoked like ``localConfig._canonicalize([experimentConfig])``.
"""
for
field
in
dataclasse
s
.
fields
(
self
):
for
field
in
util
s
.
fields
(
self
):
value
=
getattr
(
self
,
field
.
name
)
if
isinstance
(
value
,
(
Path
,
str
))
and
utils
.
is_path_like
(
field
.
type
):
setattr
(
self
,
field
.
name
,
utils
.
resolve_path
(
value
,
self
.
_base_path
))
...
...
@@ -235,7 +235,7 @@ class ConfigBase:
2. Call ``_validate_canonical()`` on children config objects, including those inside list and dict
"""
utils
.
validate_type
(
self
)
for
field
in
dataclasse
s
.
fields
(
self
):
for
field
in
util
s
.
fields
(
self
):
value
=
getattr
(
self
,
field
.
name
)
_recursive_validate_child
(
value
)
...
...
@@ -247,7 +247,7 @@ class ConfigBase:
if
hasattr
(
self
,
name
)
or
name
.
startswith
(
'_'
):
super
().
__setattr__
(
name
,
value
)
return
if
name
in
[
field
.
name
for
field
in
dataclasse
s
.
fields
(
self
)]:
# might happend during __init__
if
name
in
[
field
.
name
for
field
in
util
s
.
fields
(
self
)]:
# might happend during __init__
super
().
__setattr__
(
name
,
value
)
return
raise
AttributeError
(
f
'
{
type
(
self
).
__name__
}
does not have field
{
name
}
'
)
...
...
nni/experiment/config/training_service.py
View file @
00e4debb
...
...
@@ -52,7 +52,6 @@ class TrainingServiceConfig(ConfigBase):
def
_validate_canonical
(
self
):
super
().
_validate_canonical
()
cls
=
type
(
self
)
assert
self
.
platform
==
cls
.
platform
if
not
Path
(
self
.
trial_code_directory
).
is_dir
():
raise
ValueError
(
f
'
{
cls
.
__name__
}
: trial_code_directory "
{
self
.
trial_code_directory
}
" is not a directory'
)
assert
self
.
trial_gpu_number
is
None
or
self
.
trial_gpu_number
>=
0
nni/experiment/config/training_services/aml.py
View file @
00e4debb
...
...
@@ -18,11 +18,13 @@ __all__ = ['AmlConfig']
from
dataclasses
import
dataclass
from
typing_extensions
import
Literal
from
..training_service
import
TrainingServiceConfig
@
dataclass
(
init
=
False
)
class
AmlConfig
(
TrainingServiceConfig
):
platform
:
str
=
'aml'
platform
:
Literal
[
'aml'
]
=
'aml'
subscription_id
:
str
resource_group
:
str
workspace_name
:
str
...
...
nni/experiment/config/training_services/dlc.py
View file @
00e4debb
...
...
@@ -4,13 +4,15 @@
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing_extensions
import
Literal
from
..training_service
import
TrainingServiceConfig
__all__
=
[
'DlcConfig'
]
@
dataclass
(
init
=
False
)
class
DlcConfig
(
TrainingServiceConfig
):
platform
:
str
=
'dlc'
platform
:
Literal
[
'dlc'
]
=
'dlc'
type
:
str
=
'Worker'
image
:
str
# 'registry-vpc.{region}.aliyuncs.com/pai-dlc/tensorflow-training:1.15.0-cpu-py36-ubuntu18.04',
job_type
:
str
=
'TFJob'
...
...
nni/experiment/config/training_services/frameworkcontroller.py
View file @
00e4debb
...
...
@@ -19,6 +19,8 @@ __all__ = ['FrameworkControllerConfig', 'FrameworkControllerRoleConfig', 'Framew
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Union
from
typing_extensions
import
Literal
from
..base
import
ConfigBase
from
..training_service
import
TrainingServiceConfig
from
.k8s_storage
import
K8sStorageConfig
...
...
@@ -41,7 +43,7 @@ class FrameworkControllerRoleConfig(ConfigBase):
@
dataclass
(
init
=
False
)
class
FrameworkControllerConfig
(
TrainingServiceConfig
):
platform
:
str
=
'frameworkcontroller'
platform
:
Literal
[
'frameworkcontroller'
]
=
'frameworkcontroller'
storage
:
K8sStorageConfig
service_account_name
:
Optional
[
str
]
task_roles
:
List
[
FrameworkControllerRoleConfig
]
...
...
nni/experiment/config/training_services/k8s_storage.py
View file @
00e4debb
...
...
@@ -10,6 +10,8 @@ __all__ = ['K8sStorageConfig', 'K8sAzureStorageConfig', 'K8sNfsConfig']
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing_extensions
import
Literal
from
..base
import
ConfigBase
@
dataclass
(
init
=
False
)
...
...
@@ -34,13 +36,13 @@ class K8sStorageConfig(ConfigBase):
@
dataclass
(
init
=
False
)
class
K8sNfsConfig
(
K8sStorageConfig
):
storage
:
str
=
'nfs'
storage
:
Literal
[
'nfs'
]
=
'nfs'
server
:
str
path
:
str
@
dataclass
(
init
=
False
)
class
K8sAzureStorageConfig
(
K8sStorageConfig
):
storage
:
str
=
'azureStorage'
storage
:
Literal
[
'azureStorage'
]
=
'azureStorage'
azure_account
:
str
azure_share
:
str
key_vault_name
:
str
...
...
nni/experiment/config/training_services/kubeflow.py
View file @
00e4debb
...
...
@@ -19,6 +19,8 @@ __all__ = ['KubeflowConfig', 'KubeflowRoleConfig']
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Union
from
typing_extensions
import
Literal
from
..base
import
ConfigBase
from
..training_service
import
TrainingServiceConfig
from
.k8s_storage
import
K8sStorageConfig
...
...
@@ -35,7 +37,7 @@ class KubeflowRoleConfig(ConfigBase):
@
dataclass
(
init
=
False
)
class
KubeflowConfig
(
TrainingServiceConfig
):
platform
:
str
=
'kubeflow'
platform
:
Literal
[
'kubeflow'
]
=
'kubeflow'
operator
:
str
api_version
:
str
storage
:
K8sStorageConfig
...
...
nni/experiment/config/training_services/local.py
View file @
00e4debb
...
...
@@ -19,12 +19,14 @@ __all__ = ['LocalConfig']
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Union
from
typing_extensions
import
Literal
from
..training_service
import
TrainingServiceConfig
from
..
import
utils
@
dataclass
(
init
=
False
)
class
LocalConfig
(
TrainingServiceConfig
):
platform
:
str
=
'local'
platform
:
Literal
[
'local'
]
=
'local'
use_active_gpu
:
Optional
[
bool
]
=
None
max_trial_number_per_gpu
:
int
=
1
gpu_indices
:
Union
[
List
[
int
],
int
,
str
,
None
]
=
None
...
...
nni/experiment/config/training_services/openpai.py
View file @
00e4debb
...
...
@@ -20,12 +20,14 @@ from dataclasses import dataclass
from
pathlib
import
Path
from
typing
import
Dict
,
Optional
,
Union
from
typing_extensions
import
Literal
from
..training_service
import
TrainingServiceConfig
from
..utils
import
PathLike
@
dataclass
(
init
=
False
)
class
OpenpaiConfig
(
TrainingServiceConfig
):
platform
:
str
=
'openpai'
platform
:
Literal
[
'openpai'
]
=
'openpai'
host
:
str
username
:
str
token
:
str
...
...
nni/experiment/config/training_services/remote.py
View file @
00e4debb
...
...
@@ -21,6 +21,8 @@ from pathlib import Path
from
typing
import
List
,
Optional
,
Union
import
warnings
from
typing_extensions
import
Literal
from
..base
import
ConfigBase
from
..training_service
import
TrainingServiceConfig
from
..
import
utils
...
...
@@ -60,7 +62,7 @@ class RemoteMachineConfig(ConfigBase):
@
dataclass
(
init
=
False
)
class
RemoteConfig
(
TrainingServiceConfig
):
platform
:
str
=
'remote'
platform
:
Literal
[
'remote'
]
=
'remote'
machine_list
:
List
[
RemoteMachineConfig
]
reuse_mode
:
bool
=
True
...
...
nni/experiment/config/utils/internal.py
View file @
00e4debb
...
...
@@ -7,12 +7,25 @@ Utility functions for experiment config classes, internal part.
If you are implementing a config class for a training service, it's unlikely you will need these.
"""
from
__future__
import
annotations
__all__
=
[
'get_base_path'
,
'set_base_path'
,
'unset_base_path'
,
'resolve_path'
,
'case_insensitive'
,
'camel_case'
,
'fields'
,
'is_instance'
,
'validate_type'
,
'is_path_like'
,
'guess_config_type'
,
'guess_list_config_type'
,
'training_service_config_factory'
,
'load_training_service_config'
,
'get_ipv4_address'
]
import
copy
import
dataclasses
import
importlib
import
json
import
os.path
from
pathlib
import
Path
import
socket
import
typing
import
typeguard
...
...
@@ -20,36 +33,30 @@ import nni.runtime.config
from
.public
import
is_missing
__all__
=
[
'get_base_path'
,
'set_base_path'
,
'unset_base_path'
,
'resolve_path'
,
'case_insensitive'
,
'camel_case'
,
'is_instance'
,
'validate_type'
,
'is_path_like'
,
'guess_config_type'
,
'guess_list_config_type'
,
'training_service_config_factory'
,
'load_training_service_config'
,
'get_ipv4_address'
]
if
typing
.
TYPE_CHECKING
:
from
..base
import
ConfigBase
from
..training_service
import
TrainingServiceConfig
## handle relative path ##
_current_base_path
=
None
_current_base_path
:
Path
|
None
=
None
def
get_base_path
():
def
get_base_path
()
->
Path
:
if
_current_base_path
is
None
:
return
Path
()
return
_current_base_path
def
set_base_path
(
path
)
:
def
set_base_path
(
path
:
Path
)
->
None
:
global
_current_base_path
assert
_current_base_path
is
None
_current_base_path
=
path
def
unset_base_path
():
def
unset_base_path
()
->
None
:
global
_current_base_path
_current_base_path
=
None
def
resolve_path
(
path
,
base_path
):
if
path
is
None
:
return
None
def
resolve_path
(
path
:
Path
|
str
,
base_path
:
Path
)
->
str
:
assert
path
is
not
None
# Path.resolve() does not work on Windows when file not exist, so use os.path instead
path
=
os
.
path
.
expanduser
(
path
)
if
not
os
.
path
.
isabs
(
path
):
...
...
@@ -58,23 +65,32 @@ def resolve_path(path, base_path):
## field name case convertion ##
def
case_insensitive
(
key
)
:
def
case_insensitive
(
key
:
str
)
->
str
:
return
key
.
lower
().
replace
(
'_'
,
''
)
def
camel_case
(
key
)
:
def
camel_case
(
key
:
str
)
->
str
:
words
=
key
.
strip
(
'_'
).
split
(
'_'
)
return
words
[
0
]
+
''
.
join
(
word
.
title
()
for
word
in
words
[
1
:])
## type hint utils ##
def
is_instance
(
value
,
type_hint
):
def
fields
(
config
:
ConfigBase
)
->
list
[
dataclasses
.
Field
]:
# Similar to `dataclasses.fields()`, but use `typing.get_types_hints()` to get `field.type`.
# This is useful when postponed evaluation is enabled.
ret
=
[
copy
.
copy
(
field
)
for
field
in
dataclasses
.
fields
(
config
)]
types
=
typing
.
get_type_hints
(
type
(
config
))
for
field
in
ret
:
field
.
type
=
types
[
field
.
name
]
return
ret
def
is_instance
(
value
,
type_hint
)
->
bool
:
try
:
typeguard
.
check_type
(
'_'
,
value
,
type_hint
)
except
TypeError
:
return
False
return
True
def
validate_type
(
config
)
:
def
validate_type
(
config
:
ConfigBase
)
->
None
:
class_name
=
type
(
config
).
__name__
for
field
in
dataclasses
.
fields
(
config
):
value
=
getattr
(
config
,
field
.
name
)
...
...
@@ -84,17 +100,17 @@ def validate_type(config):
if
not
is_instance
(
value
,
field
.
type
):
raise
ValueError
(
f
'
{
class_name
}
: type of
{
field
.
name
}
(
{
repr
(
value
)
}
) is not
{
field
.
type
}
'
)
def
is_path_like
(
type_hint
):
def
is_path_like
(
type_hint
)
->
bool
:
# only `PathLike` and `Any` accepts `Path`; check `int` to make sure it's not `Any`
return
is_instance
(
Path
(),
type_hint
)
and
not
is_instance
(
1
,
type_hint
)
## type inference ##
def
guess_config_type
(
obj
,
type_hint
):
def
guess_config_type
(
obj
,
type_hint
)
->
ConfigBase
|
None
:
ret
=
guess_list_config_type
([
obj
],
type_hint
,
_hint_list_item
=
True
)
return
ret
[
0
]
if
ret
else
None
def
guess_list_config_type
(
objs
,
type_hint
,
_hint_list_item
=
False
):
def
guess_list_config_type
(
objs
,
type_hint
,
_hint_list_item
=
False
)
->
list
[
ConfigBase
]
|
None
:
# avoid circular import
from
..base
import
ConfigBase
from
..training_service
import
TrainingServiceConfig
...
...
@@ -144,20 +160,20 @@ def _all_subclasses(cls):
subclasses
=
set
(
cls
.
__subclasses__
())
return
subclasses
.
union
(
*
[
_all_subclasses
(
subclass
)
for
subclass
in
subclasses
])
def
training_service_config_factory
(
platform
)
:
def
training_service_config_factory
(
platform
:
str
)
->
TrainingServiceConfig
:
cls
=
_get_ts_config_class
(
platform
)
if
cls
is
None
:
raise
ValueError
(
f
'Bad training service platform:
{
platform
}
'
)
return
cls
()
def
load_training_service_config
(
config
):
def
load_training_service_config
(
config
)
->
TrainingServiceConfig
:
if
isinstance
(
config
,
dict
)
and
'platform'
in
config
:
cls
=
_get_ts_config_class
(
config
[
'platform'
])
if
cls
is
not
None
:
return
cls
(
**
config
)
return
config
# not valid json, don't touch
def
_get_ts_config_class
(
platform
)
:
def
_get_ts_config_class
(
platform
:
str
)
->
type
[
TrainingServiceConfig
]
|
None
:
from
..training_service
import
TrainingServiceConfig
# avoid circular import
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
...
...
@@ -175,7 +191,7 @@ def _get_ts_config_class(platform):
## misc ##
def
get_ipv4_address
():
def
get_ipv4_address
()
->
str
:
s
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_DGRAM
)
s
.
connect
((
'192.0.2.0'
,
80
))
addr
=
s
.
getsockname
()[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment