Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
00e4debb
"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "6dc96db7bf12ff0118527f9f831e388490d44a9d"
Unverified
Commit
00e4debb
authored
Jun 29, 2022
by
liuzhe-lz
Committed by
GitHub
Jun 29, 2022
Browse files
Allow postponed annotations for config classes (#4883)
parent
3d6ddb9a
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
72 additions
and
41 deletions
+72
-41
nni/experiment/config/base.py
nni/experiment/config/base.py
+5
-5
nni/experiment/config/training_service.py
nni/experiment/config/training_service.py
+0
-1
nni/experiment/config/training_services/aml.py
nni/experiment/config/training_services/aml.py
+3
-1
nni/experiment/config/training_services/dlc.py
nni/experiment/config/training_services/dlc.py
+3
-1
nni/experiment/config/training_services/frameworkcontroller.py
...xperiment/config/training_services/frameworkcontroller.py
+3
-1
nni/experiment/config/training_services/k8s_storage.py
nni/experiment/config/training_services/k8s_storage.py
+4
-2
nni/experiment/config/training_services/kubeflow.py
nni/experiment/config/training_services/kubeflow.py
+3
-1
nni/experiment/config/training_services/local.py
nni/experiment/config/training_services/local.py
+3
-1
nni/experiment/config/training_services/openpai.py
nni/experiment/config/training_services/openpai.py
+3
-1
nni/experiment/config/training_services/remote.py
nni/experiment/config/training_services/remote.py
+3
-1
nni/experiment/config/utils/internal.py
nni/experiment/config/utils/internal.py
+42
-26
No files found.
nni/experiment/config/base.py
View file @
00e4debb
...
@@ -89,7 +89,7 @@ class ConfigBase:
...
@@ -89,7 +89,7 @@ class ConfigBase:
"""
"""
self
.
_base_path
=
utils
.
get_base_path
()
self
.
_base_path
=
utils
.
get_base_path
()
args
=
{
utils
.
case_insensitive
(
key
):
value
for
key
,
value
in
kwargs
.
items
()}
args
=
{
utils
.
case_insensitive
(
key
):
value
for
key
,
value
in
kwargs
.
items
()}
for
field
in
dataclasse
s
.
fields
(
self
):
for
field
in
util
s
.
fields
(
self
):
value
=
args
.
pop
(
utils
.
case_insensitive
(
field
.
name
),
field
.
default
)
value
=
args
.
pop
(
utils
.
case_insensitive
(
field
.
name
),
field
.
default
)
setattr
(
self
,
field
.
name
,
value
)
setattr
(
self
,
field
.
name
,
value
)
if
args
:
# maybe a key is misspelled
if
args
:
# maybe a key is misspelled
...
@@ -98,7 +98,7 @@ class ConfigBase:
...
@@ -98,7 +98,7 @@ class ConfigBase:
raise
AttributeError
(
f
'
{
class_name
}
does not have field(s)
{
fields
}
'
)
raise
AttributeError
(
f
'
{
class_name
}
does not have field(s)
{
fields
}
'
)
# try to unpack nested config
# try to unpack nested config
for
field
in
dataclasse
s
.
fields
(
self
):
for
field
in
util
s
.
fields
(
self
):
value
=
getattr
(
self
,
field
.
name
)
value
=
getattr
(
self
,
field
.
name
)
if
utils
.
is_instance
(
value
,
field
.
type
):
if
utils
.
is_instance
(
value
,
field
.
type
):
continue
# already accepted by subclass, don't touch it
continue
# already accepted by subclass, don't touch it
...
@@ -214,7 +214,7 @@ class ConfigBase:
...
@@ -214,7 +214,7 @@ class ConfigBase:
For example local training service's ``trialGpuNumber`` will be copied from top level when not set,
For example local training service's ``trialGpuNumber`` will be copied from top level when not set,
in this case it will be invoked like ``localConfig._canonicalize([experimentConfig])``.
in this case it will be invoked like ``localConfig._canonicalize([experimentConfig])``.
"""
"""
for
field
in
dataclasse
s
.
fields
(
self
):
for
field
in
util
s
.
fields
(
self
):
value
=
getattr
(
self
,
field
.
name
)
value
=
getattr
(
self
,
field
.
name
)
if
isinstance
(
value
,
(
Path
,
str
))
and
utils
.
is_path_like
(
field
.
type
):
if
isinstance
(
value
,
(
Path
,
str
))
and
utils
.
is_path_like
(
field
.
type
):
setattr
(
self
,
field
.
name
,
utils
.
resolve_path
(
value
,
self
.
_base_path
))
setattr
(
self
,
field
.
name
,
utils
.
resolve_path
(
value
,
self
.
_base_path
))
...
@@ -235,7 +235,7 @@ class ConfigBase:
...
@@ -235,7 +235,7 @@ class ConfigBase:
2. Call ``_validate_canonical()`` on children config objects, including those inside list and dict
2. Call ``_validate_canonical()`` on children config objects, including those inside list and dict
"""
"""
utils
.
validate_type
(
self
)
utils
.
validate_type
(
self
)
for
field
in
dataclasse
s
.
fields
(
self
):
for
field
in
util
s
.
fields
(
self
):
value
=
getattr
(
self
,
field
.
name
)
value
=
getattr
(
self
,
field
.
name
)
_recursive_validate_child
(
value
)
_recursive_validate_child
(
value
)
...
@@ -247,7 +247,7 @@ class ConfigBase:
...
@@ -247,7 +247,7 @@ class ConfigBase:
if
hasattr
(
self
,
name
)
or
name
.
startswith
(
'_'
):
if
hasattr
(
self
,
name
)
or
name
.
startswith
(
'_'
):
super
().
__setattr__
(
name
,
value
)
super
().
__setattr__
(
name
,
value
)
return
return
if
name
in
[
field
.
name
for
field
in
dataclasse
s
.
fields
(
self
)]:
# might happend during __init__
if
name
in
[
field
.
name
for
field
in
util
s
.
fields
(
self
)]:
# might happend during __init__
super
().
__setattr__
(
name
,
value
)
super
().
__setattr__
(
name
,
value
)
return
return
raise
AttributeError
(
f
'
{
type
(
self
).
__name__
}
does not have field
{
name
}
'
)
raise
AttributeError
(
f
'
{
type
(
self
).
__name__
}
does not have field
{
name
}
'
)
...
...
nni/experiment/config/training_service.py
View file @
00e4debb
...
@@ -52,7 +52,6 @@ class TrainingServiceConfig(ConfigBase):
...
@@ -52,7 +52,6 @@ class TrainingServiceConfig(ConfigBase):
def
_validate_canonical
(
self
):
def
_validate_canonical
(
self
):
super
().
_validate_canonical
()
super
().
_validate_canonical
()
cls
=
type
(
self
)
cls
=
type
(
self
)
assert
self
.
platform
==
cls
.
platform
if
not
Path
(
self
.
trial_code_directory
).
is_dir
():
if
not
Path
(
self
.
trial_code_directory
).
is_dir
():
raise
ValueError
(
f
'
{
cls
.
__name__
}
: trial_code_directory "
{
self
.
trial_code_directory
}
" is not a directory'
)
raise
ValueError
(
f
'
{
cls
.
__name__
}
: trial_code_directory "
{
self
.
trial_code_directory
}
" is not a directory'
)
assert
self
.
trial_gpu_number
is
None
or
self
.
trial_gpu_number
>=
0
assert
self
.
trial_gpu_number
is
None
or
self
.
trial_gpu_number
>=
0
nni/experiment/config/training_services/aml.py
View file @
00e4debb
...
@@ -18,11 +18,13 @@ __all__ = ['AmlConfig']
...
@@ -18,11 +18,13 @@ __all__ = ['AmlConfig']
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing_extensions
import
Literal
from
..training_service
import
TrainingServiceConfig
from
..training_service
import
TrainingServiceConfig
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
AmlConfig
(
TrainingServiceConfig
):
class
AmlConfig
(
TrainingServiceConfig
):
platform
:
str
=
'aml'
platform
:
Literal
[
'aml'
]
=
'aml'
subscription_id
:
str
subscription_id
:
str
resource_group
:
str
resource_group
:
str
workspace_name
:
str
workspace_name
:
str
...
...
nni/experiment/config/training_services/dlc.py
View file @
00e4debb
...
@@ -4,13 +4,15 @@
...
@@ -4,13 +4,15 @@
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
Optional
from
typing_extensions
import
Literal
from
..training_service
import
TrainingServiceConfig
from
..training_service
import
TrainingServiceConfig
__all__
=
[
'DlcConfig'
]
__all__
=
[
'DlcConfig'
]
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
DlcConfig
(
TrainingServiceConfig
):
class
DlcConfig
(
TrainingServiceConfig
):
platform
:
str
=
'dlc'
platform
:
Literal
[
'dlc'
]
=
'dlc'
type
:
str
=
'Worker'
type
:
str
=
'Worker'
image
:
str
# 'registry-vpc.{region}.aliyuncs.com/pai-dlc/tensorflow-training:1.15.0-cpu-py36-ubuntu18.04',
image
:
str
# 'registry-vpc.{region}.aliyuncs.com/pai-dlc/tensorflow-training:1.15.0-cpu-py36-ubuntu18.04',
job_type
:
str
=
'TFJob'
job_type
:
str
=
'TFJob'
...
...
nni/experiment/config/training_services/frameworkcontroller.py
View file @
00e4debb
...
@@ -19,6 +19,8 @@ __all__ = ['FrameworkControllerConfig', 'FrameworkControllerRoleConfig', 'Framew
...
@@ -19,6 +19,8 @@ __all__ = ['FrameworkControllerConfig', 'FrameworkControllerRoleConfig', 'Framew
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
from
typing_extensions
import
Literal
from
..base
import
ConfigBase
from
..base
import
ConfigBase
from
..training_service
import
TrainingServiceConfig
from
..training_service
import
TrainingServiceConfig
from
.k8s_storage
import
K8sStorageConfig
from
.k8s_storage
import
K8sStorageConfig
...
@@ -41,7 +43,7 @@ class FrameworkControllerRoleConfig(ConfigBase):
...
@@ -41,7 +43,7 @@ class FrameworkControllerRoleConfig(ConfigBase):
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
FrameworkControllerConfig
(
TrainingServiceConfig
):
class
FrameworkControllerConfig
(
TrainingServiceConfig
):
platform
:
str
=
'frameworkcontroller'
platform
:
Literal
[
'frameworkcontroller'
]
=
'frameworkcontroller'
storage
:
K8sStorageConfig
storage
:
K8sStorageConfig
service_account_name
:
Optional
[
str
]
service_account_name
:
Optional
[
str
]
task_roles
:
List
[
FrameworkControllerRoleConfig
]
task_roles
:
List
[
FrameworkControllerRoleConfig
]
...
...
nni/experiment/config/training_services/k8s_storage.py
View file @
00e4debb
...
@@ -10,6 +10,8 @@ __all__ = ['K8sStorageConfig', 'K8sAzureStorageConfig', 'K8sNfsConfig']
...
@@ -10,6 +10,8 @@ __all__ = ['K8sStorageConfig', 'K8sAzureStorageConfig', 'K8sNfsConfig']
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
Optional
from
typing_extensions
import
Literal
from
..base
import
ConfigBase
from
..base
import
ConfigBase
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
...
@@ -34,13 +36,13 @@ class K8sStorageConfig(ConfigBase):
...
@@ -34,13 +36,13 @@ class K8sStorageConfig(ConfigBase):
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
K8sNfsConfig
(
K8sStorageConfig
):
class
K8sNfsConfig
(
K8sStorageConfig
):
storage
:
str
=
'nfs'
storage
:
Literal
[
'nfs'
]
=
'nfs'
server
:
str
server
:
str
path
:
str
path
:
str
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
K8sAzureStorageConfig
(
K8sStorageConfig
):
class
K8sAzureStorageConfig
(
K8sStorageConfig
):
storage
:
str
=
'azureStorage'
storage
:
Literal
[
'azureStorage'
]
=
'azureStorage'
azure_account
:
str
azure_account
:
str
azure_share
:
str
azure_share
:
str
key_vault_name
:
str
key_vault_name
:
str
...
...
nni/experiment/config/training_services/kubeflow.py
View file @
00e4debb
...
@@ -19,6 +19,8 @@ __all__ = ['KubeflowConfig', 'KubeflowRoleConfig']
...
@@ -19,6 +19,8 @@ __all__ = ['KubeflowConfig', 'KubeflowRoleConfig']
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
typing_extensions
import
Literal
from
..base
import
ConfigBase
from
..base
import
ConfigBase
from
..training_service
import
TrainingServiceConfig
from
..training_service
import
TrainingServiceConfig
from
.k8s_storage
import
K8sStorageConfig
from
.k8s_storage
import
K8sStorageConfig
...
@@ -35,7 +37,7 @@ class KubeflowRoleConfig(ConfigBase):
...
@@ -35,7 +37,7 @@ class KubeflowRoleConfig(ConfigBase):
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
KubeflowConfig
(
TrainingServiceConfig
):
class
KubeflowConfig
(
TrainingServiceConfig
):
platform
:
str
=
'kubeflow'
platform
:
Literal
[
'kubeflow'
]
=
'kubeflow'
operator
:
str
operator
:
str
api_version
:
str
api_version
:
str
storage
:
K8sStorageConfig
storage
:
K8sStorageConfig
...
...
nni/experiment/config/training_services/local.py
View file @
00e4debb
...
@@ -19,12 +19,14 @@ __all__ = ['LocalConfig']
...
@@ -19,12 +19,14 @@ __all__ = ['LocalConfig']
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
from
typing_extensions
import
Literal
from
..training_service
import
TrainingServiceConfig
from
..training_service
import
TrainingServiceConfig
from
..
import
utils
from
..
import
utils
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
LocalConfig
(
TrainingServiceConfig
):
class
LocalConfig
(
TrainingServiceConfig
):
platform
:
str
=
'local'
platform
:
Literal
[
'local'
]
=
'local'
use_active_gpu
:
Optional
[
bool
]
=
None
use_active_gpu
:
Optional
[
bool
]
=
None
max_trial_number_per_gpu
:
int
=
1
max_trial_number_per_gpu
:
int
=
1
gpu_indices
:
Union
[
List
[
int
],
int
,
str
,
None
]
=
None
gpu_indices
:
Union
[
List
[
int
],
int
,
str
,
None
]
=
None
...
...
nni/experiment/config/training_services/openpai.py
View file @
00e4debb
...
@@ -20,12 +20,14 @@ from dataclasses import dataclass
...
@@ -20,12 +20,14 @@ from dataclasses import dataclass
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
Optional
,
Union
from
typing
import
Dict
,
Optional
,
Union
from
typing_extensions
import
Literal
from
..training_service
import
TrainingServiceConfig
from
..training_service
import
TrainingServiceConfig
from
..utils
import
PathLike
from
..utils
import
PathLike
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
OpenpaiConfig
(
TrainingServiceConfig
):
class
OpenpaiConfig
(
TrainingServiceConfig
):
platform
:
str
=
'openpai'
platform
:
Literal
[
'openpai'
]
=
'openpai'
host
:
str
host
:
str
username
:
str
username
:
str
token
:
str
token
:
str
...
...
nni/experiment/config/training_services/remote.py
View file @
00e4debb
...
@@ -21,6 +21,8 @@ from pathlib import Path
...
@@ -21,6 +21,8 @@ from pathlib import Path
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
import
warnings
import
warnings
from
typing_extensions
import
Literal
from
..base
import
ConfigBase
from
..base
import
ConfigBase
from
..training_service
import
TrainingServiceConfig
from
..training_service
import
TrainingServiceConfig
from
..
import
utils
from
..
import
utils
...
@@ -60,7 +62,7 @@ class RemoteMachineConfig(ConfigBase):
...
@@ -60,7 +62,7 @@ class RemoteMachineConfig(ConfigBase):
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
RemoteConfig
(
TrainingServiceConfig
):
class
RemoteConfig
(
TrainingServiceConfig
):
platform
:
str
=
'remote'
platform
:
Literal
[
'remote'
]
=
'remote'
machine_list
:
List
[
RemoteMachineConfig
]
machine_list
:
List
[
RemoteMachineConfig
]
reuse_mode
:
bool
=
True
reuse_mode
:
bool
=
True
...
...
nni/experiment/config/utils/internal.py
View file @
00e4debb
...
@@ -7,12 +7,25 @@ Utility functions for experiment config classes, internal part.
...
@@ -7,12 +7,25 @@ Utility functions for experiment config classes, internal part.
If you are implementing a config class for a training service, it's unlikely you will need these.
If you are implementing a config class for a training service, it's unlikely you will need these.
"""
"""
from
__future__
import
annotations
__all__
=
[
'get_base_path'
,
'set_base_path'
,
'unset_base_path'
,
'resolve_path'
,
'case_insensitive'
,
'camel_case'
,
'fields'
,
'is_instance'
,
'validate_type'
,
'is_path_like'
,
'guess_config_type'
,
'guess_list_config_type'
,
'training_service_config_factory'
,
'load_training_service_config'
,
'get_ipv4_address'
]
import
copy
import
dataclasses
import
dataclasses
import
importlib
import
importlib
import
json
import
json
import
os.path
import
os.path
from
pathlib
import
Path
from
pathlib
import
Path
import
socket
import
socket
import
typing
import
typeguard
import
typeguard
...
@@ -20,36 +33,30 @@ import nni.runtime.config
...
@@ -20,36 +33,30 @@ import nni.runtime.config
from
.public
import
is_missing
from
.public
import
is_missing
__all__
=
[
if
typing
.
TYPE_CHECKING
:
'get_base_path'
,
'set_base_path'
,
'unset_base_path'
,
'resolve_path'
,
from
..base
import
ConfigBase
'case_insensitive'
,
'camel_case'
,
from
..training_service
import
TrainingServiceConfig
'is_instance'
,
'validate_type'
,
'is_path_like'
,
'guess_config_type'
,
'guess_list_config_type'
,
'training_service_config_factory'
,
'load_training_service_config'
,
'get_ipv4_address'
]
## handle relative path ##
## handle relative path ##
_current_base_path
=
None
_current_base_path
:
Path
|
None
=
None
def
get_base_path
():
def
get_base_path
()
->
Path
:
if
_current_base_path
is
None
:
if
_current_base_path
is
None
:
return
Path
()
return
Path
()
return
_current_base_path
return
_current_base_path
def
set_base_path
(
path
)
:
def
set_base_path
(
path
:
Path
)
->
None
:
global
_current_base_path
global
_current_base_path
assert
_current_base_path
is
None
assert
_current_base_path
is
None
_current_base_path
=
path
_current_base_path
=
path
def
unset_base_path
():
def
unset_base_path
()
->
None
:
global
_current_base_path
global
_current_base_path
_current_base_path
=
None
_current_base_path
=
None
def
resolve_path
(
path
,
base_path
):
def
resolve_path
(
path
:
Path
|
str
,
base_path
:
Path
)
->
str
:
if
path
is
None
:
assert
path
is
not
None
return
None
# Path.resolve() does not work on Windows when file not exist, so use os.path instead
# Path.resolve() does not work on Windows when file not exist, so use os.path instead
path
=
os
.
path
.
expanduser
(
path
)
path
=
os
.
path
.
expanduser
(
path
)
if
not
os
.
path
.
isabs
(
path
):
if
not
os
.
path
.
isabs
(
path
):
...
@@ -58,23 +65,32 @@ def resolve_path(path, base_path):
...
@@ -58,23 +65,32 @@ def resolve_path(path, base_path):
## field name case convertion ##
## field name case convertion ##
def
case_insensitive
(
key
)
:
def
case_insensitive
(
key
:
str
)
->
str
:
return
key
.
lower
().
replace
(
'_'
,
''
)
return
key
.
lower
().
replace
(
'_'
,
''
)
def
camel_case
(
key
)
:
def
camel_case
(
key
:
str
)
->
str
:
words
=
key
.
strip
(
'_'
).
split
(
'_'
)
words
=
key
.
strip
(
'_'
).
split
(
'_'
)
return
words
[
0
]
+
''
.
join
(
word
.
title
()
for
word
in
words
[
1
:])
return
words
[
0
]
+
''
.
join
(
word
.
title
()
for
word
in
words
[
1
:])
## type hint utils ##
## type hint utils ##
def
is_instance
(
value
,
type_hint
):
def
fields
(
config
:
ConfigBase
)
->
list
[
dataclasses
.
Field
]:
# Similar to `dataclasses.fields()`, but use `typing.get_types_hints()` to get `field.type`.
# This is useful when postponed evaluation is enabled.
ret
=
[
copy
.
copy
(
field
)
for
field
in
dataclasses
.
fields
(
config
)]
types
=
typing
.
get_type_hints
(
type
(
config
))
for
field
in
ret
:
field
.
type
=
types
[
field
.
name
]
return
ret
def
is_instance
(
value
,
type_hint
)
->
bool
:
try
:
try
:
typeguard
.
check_type
(
'_'
,
value
,
type_hint
)
typeguard
.
check_type
(
'_'
,
value
,
type_hint
)
except
TypeError
:
except
TypeError
:
return
False
return
False
return
True
return
True
def
validate_type
(
config
)
:
def
validate_type
(
config
:
ConfigBase
)
->
None
:
class_name
=
type
(
config
).
__name__
class_name
=
type
(
config
).
__name__
for
field
in
dataclasses
.
fields
(
config
):
for
field
in
dataclasses
.
fields
(
config
):
value
=
getattr
(
config
,
field
.
name
)
value
=
getattr
(
config
,
field
.
name
)
...
@@ -84,17 +100,17 @@ def validate_type(config):
...
@@ -84,17 +100,17 @@ def validate_type(config):
if
not
is_instance
(
value
,
field
.
type
):
if
not
is_instance
(
value
,
field
.
type
):
raise
ValueError
(
f
'
{
class_name
}
: type of
{
field
.
name
}
(
{
repr
(
value
)
}
) is not
{
field
.
type
}
'
)
raise
ValueError
(
f
'
{
class_name
}
: type of
{
field
.
name
}
(
{
repr
(
value
)
}
) is not
{
field
.
type
}
'
)
def
is_path_like
(
type_hint
):
def
is_path_like
(
type_hint
)
->
bool
:
# only `PathLike` and `Any` accepts `Path`; check `int` to make sure it's not `Any`
# only `PathLike` and `Any` accepts `Path`; check `int` to make sure it's not `Any`
return
is_instance
(
Path
(),
type_hint
)
and
not
is_instance
(
1
,
type_hint
)
return
is_instance
(
Path
(),
type_hint
)
and
not
is_instance
(
1
,
type_hint
)
## type inference ##
## type inference ##
def
guess_config_type
(
obj
,
type_hint
):
def
guess_config_type
(
obj
,
type_hint
)
->
ConfigBase
|
None
:
ret
=
guess_list_config_type
([
obj
],
type_hint
,
_hint_list_item
=
True
)
ret
=
guess_list_config_type
([
obj
],
type_hint
,
_hint_list_item
=
True
)
return
ret
[
0
]
if
ret
else
None
return
ret
[
0
]
if
ret
else
None
def
guess_list_config_type
(
objs
,
type_hint
,
_hint_list_item
=
False
):
def
guess_list_config_type
(
objs
,
type_hint
,
_hint_list_item
=
False
)
->
list
[
ConfigBase
]
|
None
:
# avoid circular import
# avoid circular import
from
..base
import
ConfigBase
from
..base
import
ConfigBase
from
..training_service
import
TrainingServiceConfig
from
..training_service
import
TrainingServiceConfig
...
@@ -144,20 +160,20 @@ def _all_subclasses(cls):
...
@@ -144,20 +160,20 @@ def _all_subclasses(cls):
subclasses
=
set
(
cls
.
__subclasses__
())
subclasses
=
set
(
cls
.
__subclasses__
())
return
subclasses
.
union
(
*
[
_all_subclasses
(
subclass
)
for
subclass
in
subclasses
])
return
subclasses
.
union
(
*
[
_all_subclasses
(
subclass
)
for
subclass
in
subclasses
])
def
training_service_config_factory
(
platform
)
:
def
training_service_config_factory
(
platform
:
str
)
->
TrainingServiceConfig
:
cls
=
_get_ts_config_class
(
platform
)
cls
=
_get_ts_config_class
(
platform
)
if
cls
is
None
:
if
cls
is
None
:
raise
ValueError
(
f
'Bad training service platform:
{
platform
}
'
)
raise
ValueError
(
f
'Bad training service platform:
{
platform
}
'
)
return
cls
()
return
cls
()
def
load_training_service_config
(
config
):
def
load_training_service_config
(
config
)
->
TrainingServiceConfig
:
if
isinstance
(
config
,
dict
)
and
'platform'
in
config
:
if
isinstance
(
config
,
dict
)
and
'platform'
in
config
:
cls
=
_get_ts_config_class
(
config
[
'platform'
])
cls
=
_get_ts_config_class
(
config
[
'platform'
])
if
cls
is
not
None
:
if
cls
is
not
None
:
return
cls
(
**
config
)
return
cls
(
**
config
)
return
config
# not valid json, don't touch
return
config
# not valid json, don't touch
def
_get_ts_config_class
(
platform
)
:
def
_get_ts_config_class
(
platform
:
str
)
->
type
[
TrainingServiceConfig
]
|
None
:
from
..training_service
import
TrainingServiceConfig
# avoid circular import
from
..training_service
import
TrainingServiceConfig
# avoid circular import
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
...
@@ -175,7 +191,7 @@ def _get_ts_config_class(platform):
...
@@ -175,7 +191,7 @@ def _get_ts_config_class(platform):
## misc ##
## misc ##
def
get_ipv4_address
():
def
get_ipv4_address
()
->
str
:
s
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_DGRAM
)
s
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_DGRAM
)
s
.
connect
((
'192.0.2.0'
,
80
))
s
.
connect
((
'192.0.2.0'
,
80
))
addr
=
s
.
getsockname
()[
0
]
addr
=
s
.
getsockname
()[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment