"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "6dc96db7bf12ff0118527f9f831e388490d44a9d"
Unverified Commit 00e4debb authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Allow postponed annotations for config classes (#4883)

parent 3d6ddb9a
...@@ -89,7 +89,7 @@ class ConfigBase: ...@@ -89,7 +89,7 @@ class ConfigBase:
""" """
self._base_path = utils.get_base_path() self._base_path = utils.get_base_path()
args = {utils.case_insensitive(key): value for key, value in kwargs.items()} args = {utils.case_insensitive(key): value for key, value in kwargs.items()}
for field in dataclasses.fields(self): for field in utils.fields(self):
value = args.pop(utils.case_insensitive(field.name), field.default) value = args.pop(utils.case_insensitive(field.name), field.default)
setattr(self, field.name, value) setattr(self, field.name, value)
if args: # maybe a key is misspelled if args: # maybe a key is misspelled
...@@ -98,7 +98,7 @@ class ConfigBase: ...@@ -98,7 +98,7 @@ class ConfigBase:
raise AttributeError(f'{class_name} does not have field(s) {fields}') raise AttributeError(f'{class_name} does not have field(s) {fields}')
# try to unpack nested config # try to unpack nested config
for field in dataclasses.fields(self): for field in utils.fields(self):
value = getattr(self, field.name) value = getattr(self, field.name)
if utils.is_instance(value, field.type): if utils.is_instance(value, field.type):
continue # already accepted by subclass, don't touch it continue # already accepted by subclass, don't touch it
...@@ -214,7 +214,7 @@ class ConfigBase: ...@@ -214,7 +214,7 @@ class ConfigBase:
For example local training service's ``trialGpuNumber`` will be copied from top level when not set, For example local training service's ``trialGpuNumber`` will be copied from top level when not set,
in this case it will be invoked like ``localConfig._canonicalize([experimentConfig])``. in this case it will be invoked like ``localConfig._canonicalize([experimentConfig])``.
""" """
for field in dataclasses.fields(self): for field in utils.fields(self):
value = getattr(self, field.name) value = getattr(self, field.name)
if isinstance(value, (Path, str)) and utils.is_path_like(field.type): if isinstance(value, (Path, str)) and utils.is_path_like(field.type):
setattr(self, field.name, utils.resolve_path(value, self._base_path)) setattr(self, field.name, utils.resolve_path(value, self._base_path))
...@@ -235,7 +235,7 @@ class ConfigBase: ...@@ -235,7 +235,7 @@ class ConfigBase:
2. Call ``_validate_canonical()`` on children config objects, including those inside list and dict 2. Call ``_validate_canonical()`` on children config objects, including those inside list and dict
""" """
utils.validate_type(self) utils.validate_type(self)
for field in dataclasses.fields(self): for field in utils.fields(self):
value = getattr(self, field.name) value = getattr(self, field.name)
_recursive_validate_child(value) _recursive_validate_child(value)
...@@ -247,7 +247,7 @@ class ConfigBase: ...@@ -247,7 +247,7 @@ class ConfigBase:
if hasattr(self, name) or name.startswith('_'): if hasattr(self, name) or name.startswith('_'):
super().__setattr__(name, value) super().__setattr__(name, value)
return return
if name in [field.name for field in dataclasses.fields(self)]: # might happend during __init__ if name in [field.name for field in utils.fields(self)]: # might happend during __init__
super().__setattr__(name, value) super().__setattr__(name, value)
return return
raise AttributeError(f'{type(self).__name__} does not have field {name}') raise AttributeError(f'{type(self).__name__} does not have field {name}')
......
...@@ -52,7 +52,6 @@ class TrainingServiceConfig(ConfigBase): ...@@ -52,7 +52,6 @@ class TrainingServiceConfig(ConfigBase):
def _validate_canonical(self): def _validate_canonical(self):
super()._validate_canonical() super()._validate_canonical()
cls = type(self) cls = type(self)
assert self.platform == cls.platform
if not Path(self.trial_code_directory).is_dir(): if not Path(self.trial_code_directory).is_dir():
raise ValueError(f'{cls.__name__}: trial_code_directory "{self.trial_code_directory}" is not a directory') raise ValueError(f'{cls.__name__}: trial_code_directory "{self.trial_code_directory}" is not a directory')
assert self.trial_gpu_number is None or self.trial_gpu_number >= 0 assert self.trial_gpu_number is None or self.trial_gpu_number >= 0
...@@ -18,11 +18,13 @@ __all__ = ['AmlConfig'] ...@@ -18,11 +18,13 @@ __all__ = ['AmlConfig']
from dataclasses import dataclass from dataclasses import dataclass
from typing_extensions import Literal
from ..training_service import TrainingServiceConfig from ..training_service import TrainingServiceConfig
@dataclass(init=False) @dataclass(init=False)
class AmlConfig(TrainingServiceConfig): class AmlConfig(TrainingServiceConfig):
platform: str = 'aml' platform: Literal['aml'] = 'aml'
subscription_id: str subscription_id: str
resource_group: str resource_group: str
workspace_name: str workspace_name: str
......
...@@ -4,13 +4,15 @@ ...@@ -4,13 +4,15 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from typing_extensions import Literal
from ..training_service import TrainingServiceConfig from ..training_service import TrainingServiceConfig
__all__ = ['DlcConfig'] __all__ = ['DlcConfig']
@dataclass(init=False) @dataclass(init=False)
class DlcConfig(TrainingServiceConfig): class DlcConfig(TrainingServiceConfig):
platform: str = 'dlc' platform: Literal['dlc'] = 'dlc'
type: str = 'Worker' type: str = 'Worker'
image: str # 'registry-vpc.{region}.aliyuncs.com/pai-dlc/tensorflow-training:1.15.0-cpu-py36-ubuntu18.04', image: str # 'registry-vpc.{region}.aliyuncs.com/pai-dlc/tensorflow-training:1.15.0-cpu-py36-ubuntu18.04',
job_type: str = 'TFJob' job_type: str = 'TFJob'
......
...@@ -19,6 +19,8 @@ __all__ = ['FrameworkControllerConfig', 'FrameworkControllerRoleConfig', 'Framew ...@@ -19,6 +19,8 @@ __all__ = ['FrameworkControllerConfig', 'FrameworkControllerRoleConfig', 'Framew
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Union
from typing_extensions import Literal
from ..base import ConfigBase from ..base import ConfigBase
from ..training_service import TrainingServiceConfig from ..training_service import TrainingServiceConfig
from .k8s_storage import K8sStorageConfig from .k8s_storage import K8sStorageConfig
...@@ -41,7 +43,7 @@ class FrameworkControllerRoleConfig(ConfigBase): ...@@ -41,7 +43,7 @@ class FrameworkControllerRoleConfig(ConfigBase):
@dataclass(init=False) @dataclass(init=False)
class FrameworkControllerConfig(TrainingServiceConfig): class FrameworkControllerConfig(TrainingServiceConfig):
platform: str = 'frameworkcontroller' platform: Literal['frameworkcontroller'] = 'frameworkcontroller'
storage: K8sStorageConfig storage: K8sStorageConfig
service_account_name: Optional[str] service_account_name: Optional[str]
task_roles: List[FrameworkControllerRoleConfig] task_roles: List[FrameworkControllerRoleConfig]
......
...@@ -10,6 +10,8 @@ __all__ = ['K8sStorageConfig', 'K8sAzureStorageConfig', 'K8sNfsConfig'] ...@@ -10,6 +10,8 @@ __all__ = ['K8sStorageConfig', 'K8sAzureStorageConfig', 'K8sNfsConfig']
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from typing_extensions import Literal
from ..base import ConfigBase from ..base import ConfigBase
@dataclass(init=False) @dataclass(init=False)
...@@ -34,13 +36,13 @@ class K8sStorageConfig(ConfigBase): ...@@ -34,13 +36,13 @@ class K8sStorageConfig(ConfigBase):
@dataclass(init=False) @dataclass(init=False)
class K8sNfsConfig(K8sStorageConfig): class K8sNfsConfig(K8sStorageConfig):
storage: str = 'nfs' storage: Literal['nfs'] = 'nfs'
server: str server: str
path: str path: str
@dataclass(init=False) @dataclass(init=False)
class K8sAzureStorageConfig(K8sStorageConfig): class K8sAzureStorageConfig(K8sStorageConfig):
storage: str = 'azureStorage' storage: Literal['azureStorage'] = 'azureStorage'
azure_account: str azure_account: str
azure_share: str azure_share: str
key_vault_name: str key_vault_name: str
......
...@@ -19,6 +19,8 @@ __all__ = ['KubeflowConfig', 'KubeflowRoleConfig'] ...@@ -19,6 +19,8 @@ __all__ = ['KubeflowConfig', 'KubeflowRoleConfig']
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union from typing import Optional, Union
from typing_extensions import Literal
from ..base import ConfigBase from ..base import ConfigBase
from ..training_service import TrainingServiceConfig from ..training_service import TrainingServiceConfig
from .k8s_storage import K8sStorageConfig from .k8s_storage import K8sStorageConfig
...@@ -35,7 +37,7 @@ class KubeflowRoleConfig(ConfigBase): ...@@ -35,7 +37,7 @@ class KubeflowRoleConfig(ConfigBase):
@dataclass(init=False) @dataclass(init=False)
class KubeflowConfig(TrainingServiceConfig): class KubeflowConfig(TrainingServiceConfig):
platform: str = 'kubeflow' platform: Literal['kubeflow'] = 'kubeflow'
operator: str operator: str
api_version: str api_version: str
storage: K8sStorageConfig storage: K8sStorageConfig
......
...@@ -19,12 +19,14 @@ __all__ = ['LocalConfig'] ...@@ -19,12 +19,14 @@ __all__ = ['LocalConfig']
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Union
from typing_extensions import Literal
from ..training_service import TrainingServiceConfig from ..training_service import TrainingServiceConfig
from .. import utils from .. import utils
@dataclass(init=False) @dataclass(init=False)
class LocalConfig(TrainingServiceConfig): class LocalConfig(TrainingServiceConfig):
platform: str = 'local' platform: Literal['local'] = 'local'
use_active_gpu: Optional[bool] = None use_active_gpu: Optional[bool] = None
max_trial_number_per_gpu: int = 1 max_trial_number_per_gpu: int = 1
gpu_indices: Union[List[int], int, str, None] = None gpu_indices: Union[List[int], int, str, None] = None
......
...@@ -20,12 +20,14 @@ from dataclasses import dataclass ...@@ -20,12 +20,14 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from typing_extensions import Literal
from ..training_service import TrainingServiceConfig from ..training_service import TrainingServiceConfig
from ..utils import PathLike from ..utils import PathLike
@dataclass(init=False) @dataclass(init=False)
class OpenpaiConfig(TrainingServiceConfig): class OpenpaiConfig(TrainingServiceConfig):
platform: str = 'openpai' platform: Literal['openpai'] = 'openpai'
host: str host: str
username: str username: str
token: str token: str
......
...@@ -21,6 +21,8 @@ from pathlib import Path ...@@ -21,6 +21,8 @@ from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
import warnings import warnings
from typing_extensions import Literal
from ..base import ConfigBase from ..base import ConfigBase
from ..training_service import TrainingServiceConfig from ..training_service import TrainingServiceConfig
from .. import utils from .. import utils
...@@ -60,7 +62,7 @@ class RemoteMachineConfig(ConfigBase): ...@@ -60,7 +62,7 @@ class RemoteMachineConfig(ConfigBase):
@dataclass(init=False) @dataclass(init=False)
class RemoteConfig(TrainingServiceConfig): class RemoteConfig(TrainingServiceConfig):
platform: str = 'remote' platform: Literal['remote'] = 'remote'
machine_list: List[RemoteMachineConfig] machine_list: List[RemoteMachineConfig]
reuse_mode: bool = True reuse_mode: bool = True
......
...@@ -7,12 +7,25 @@ Utility functions for experiment config classes, internal part. ...@@ -7,12 +7,25 @@ Utility functions for experiment config classes, internal part.
If you are implementing a config class for a training service, it's unlikely you will need these. If you are implementing a config class for a training service, it's unlikely you will need these.
""" """
from __future__ import annotations
__all__ = [
'get_base_path', 'set_base_path', 'unset_base_path', 'resolve_path',
'case_insensitive', 'camel_case',
'fields', 'is_instance', 'validate_type', 'is_path_like',
'guess_config_type', 'guess_list_config_type',
'training_service_config_factory', 'load_training_service_config',
'get_ipv4_address'
]
import copy
import dataclasses import dataclasses
import importlib import importlib
import json import json
import os.path import os.path
from pathlib import Path from pathlib import Path
import socket import socket
import typing
import typeguard import typeguard
...@@ -20,36 +33,30 @@ import nni.runtime.config ...@@ -20,36 +33,30 @@ import nni.runtime.config
from .public import is_missing from .public import is_missing
__all__ = [ if typing.TYPE_CHECKING:
'get_base_path', 'set_base_path', 'unset_base_path', 'resolve_path', from ..base import ConfigBase
'case_insensitive', 'camel_case', from ..training_service import TrainingServiceConfig
'is_instance', 'validate_type', 'is_path_like',
'guess_config_type', 'guess_list_config_type',
'training_service_config_factory', 'load_training_service_config',
'get_ipv4_address'
]
## handle relative path ## ## handle relative path ##
_current_base_path = None _current_base_path: Path | None = None
def get_base_path(): def get_base_path() -> Path:
if _current_base_path is None: if _current_base_path is None:
return Path() return Path()
return _current_base_path return _current_base_path
def set_base_path(path): def set_base_path(path: Path) -> None:
global _current_base_path global _current_base_path
assert _current_base_path is None assert _current_base_path is None
_current_base_path = path _current_base_path = path
def unset_base_path(): def unset_base_path() -> None:
global _current_base_path global _current_base_path
_current_base_path = None _current_base_path = None
def resolve_path(path, base_path): def resolve_path(path: Path | str, base_path: Path) -> str:
if path is None: assert path is not None
return None
# Path.resolve() does not work on Windows when file not exist, so use os.path instead # Path.resolve() does not work on Windows when file not exist, so use os.path instead
path = os.path.expanduser(path) path = os.path.expanduser(path)
if not os.path.isabs(path): if not os.path.isabs(path):
...@@ -58,23 +65,32 @@ def resolve_path(path, base_path): ...@@ -58,23 +65,32 @@ def resolve_path(path, base_path):
## field name case convertion ## ## field name case convertion ##
def case_insensitive(key): def case_insensitive(key: str) -> str:
return key.lower().replace('_', '') return key.lower().replace('_', '')
def camel_case(key): def camel_case(key: str) -> str:
words = key.strip('_').split('_') words = key.strip('_').split('_')
return words[0] + ''.join(word.title() for word in words[1:]) return words[0] + ''.join(word.title() for word in words[1:])
## type hint utils ## ## type hint utils ##
def is_instance(value, type_hint): def fields(config: ConfigBase) -> list[dataclasses.Field]:
# Similar to `dataclasses.fields()`, but use `typing.get_types_hints()` to get `field.type`.
# This is useful when postponed evaluation is enabled.
ret = [copy.copy(field) for field in dataclasses.fields(config)]
types = typing.get_type_hints(type(config))
for field in ret:
field.type = types[field.name]
return ret
def is_instance(value, type_hint) -> bool:
try: try:
typeguard.check_type('_', value, type_hint) typeguard.check_type('_', value, type_hint)
except TypeError: except TypeError:
return False return False
return True return True
def validate_type(config): def validate_type(config: ConfigBase) -> None:
class_name = type(config).__name__ class_name = type(config).__name__
for field in dataclasses.fields(config): for field in dataclasses.fields(config):
value = getattr(config, field.name) value = getattr(config, field.name)
...@@ -84,17 +100,17 @@ def validate_type(config): ...@@ -84,17 +100,17 @@ def validate_type(config):
if not is_instance(value, field.type): if not is_instance(value, field.type):
raise ValueError(f'{class_name}: type of {field.name} ({repr(value)}) is not {field.type}') raise ValueError(f'{class_name}: type of {field.name} ({repr(value)}) is not {field.type}')
def is_path_like(type_hint): def is_path_like(type_hint) -> bool:
# only `PathLike` and `Any` accepts `Path`; check `int` to make sure it's not `Any` # only `PathLike` and `Any` accepts `Path`; check `int` to make sure it's not `Any`
return is_instance(Path(), type_hint) and not is_instance(1, type_hint) return is_instance(Path(), type_hint) and not is_instance(1, type_hint)
## type inference ## ## type inference ##
def guess_config_type(obj, type_hint): def guess_config_type(obj, type_hint) -> ConfigBase | None:
ret = guess_list_config_type([obj], type_hint, _hint_list_item=True) ret = guess_list_config_type([obj], type_hint, _hint_list_item=True)
return ret[0] if ret else None return ret[0] if ret else None
def guess_list_config_type(objs, type_hint, _hint_list_item=False): def guess_list_config_type(objs, type_hint, _hint_list_item=False) -> list[ConfigBase] | None:
# avoid circular import # avoid circular import
from ..base import ConfigBase from ..base import ConfigBase
from ..training_service import TrainingServiceConfig from ..training_service import TrainingServiceConfig
...@@ -144,20 +160,20 @@ def _all_subclasses(cls): ...@@ -144,20 +160,20 @@ def _all_subclasses(cls):
subclasses = set(cls.__subclasses__()) subclasses = set(cls.__subclasses__())
return subclasses.union(*[_all_subclasses(subclass) for subclass in subclasses]) return subclasses.union(*[_all_subclasses(subclass) for subclass in subclasses])
def training_service_config_factory(platform): def training_service_config_factory(platform: str) -> TrainingServiceConfig:
cls = _get_ts_config_class(platform) cls = _get_ts_config_class(platform)
if cls is None: if cls is None:
raise ValueError(f'Bad training service platform: {platform}') raise ValueError(f'Bad training service platform: {platform}')
return cls() return cls()
def load_training_service_config(config): def load_training_service_config(config) -> TrainingServiceConfig:
if isinstance(config, dict) and 'platform' in config: if isinstance(config, dict) and 'platform' in config:
cls = _get_ts_config_class(config['platform']) cls = _get_ts_config_class(config['platform'])
if cls is not None: if cls is not None:
return cls(**config) return cls(**config)
return config # not valid json, don't touch return config # not valid json, don't touch
def _get_ts_config_class(platform): def _get_ts_config_class(platform: str) -> type[TrainingServiceConfig] | None:
from ..training_service import TrainingServiceConfig # avoid circular import from ..training_service import TrainingServiceConfig # avoid circular import
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__() # import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
...@@ -175,7 +191,7 @@ def _get_ts_config_class(platform): ...@@ -175,7 +191,7 @@ def _get_ts_config_class(platform):
## misc ## ## misc ##
def get_ipv4_address(): def get_ipv4_address() -> str:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(('192.0.2.0', 80)) s.connect(('192.0.2.0', 80))
addr = s.getsockname()[0] addr = s.getsockname()[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment