Unverified Commit 00e4debb authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Allow postponed annotations for config classes (#4883)

parent 3d6ddb9a
......@@ -89,7 +89,7 @@ class ConfigBase:
"""
self._base_path = utils.get_base_path()
args = {utils.case_insensitive(key): value for key, value in kwargs.items()}
for field in dataclasses.fields(self):
for field in utils.fields(self):
value = args.pop(utils.case_insensitive(field.name), field.default)
setattr(self, field.name, value)
if args: # maybe a key is misspelled
......@@ -98,7 +98,7 @@ class ConfigBase:
raise AttributeError(f'{class_name} does not have field(s) {fields}')
# try to unpack nested config
for field in dataclasses.fields(self):
for field in utils.fields(self):
value = getattr(self, field.name)
if utils.is_instance(value, field.type):
continue # already accepted by subclass, don't touch it
......@@ -214,7 +214,7 @@ class ConfigBase:
For example local training service's ``trialGpuNumber`` will be copied from top level when not set,
in this case it will be invoked like ``localConfig._canonicalize([experimentConfig])``.
"""
for field in dataclasses.fields(self):
for field in utils.fields(self):
value = getattr(self, field.name)
if isinstance(value, (Path, str)) and utils.is_path_like(field.type):
setattr(self, field.name, utils.resolve_path(value, self._base_path))
......@@ -235,7 +235,7 @@ class ConfigBase:
2. Call ``_validate_canonical()`` on children config objects, including those inside list and dict
"""
utils.validate_type(self)
for field in dataclasses.fields(self):
for field in utils.fields(self):
value = getattr(self, field.name)
_recursive_validate_child(value)
......@@ -247,7 +247,7 @@ class ConfigBase:
if hasattr(self, name) or name.startswith('_'):
super().__setattr__(name, value)
return
if name in [field.name for field in dataclasses.fields(self)]: # might happend during __init__
if name in [field.name for field in utils.fields(self)]: # might happend during __init__
super().__setattr__(name, value)
return
raise AttributeError(f'{type(self).__name__} does not have field {name}')
......
......@@ -52,7 +52,6 @@ class TrainingServiceConfig(ConfigBase):
def _validate_canonical(self):
super()._validate_canonical()
cls = type(self)
assert self.platform == cls.platform
if not Path(self.trial_code_directory).is_dir():
raise ValueError(f'{cls.__name__}: trial_code_directory "{self.trial_code_directory}" is not a directory')
assert self.trial_gpu_number is None or self.trial_gpu_number >= 0
......@@ -18,11 +18,13 @@ __all__ = ['AmlConfig']
from dataclasses import dataclass
from typing_extensions import Literal
from ..training_service import TrainingServiceConfig
@dataclass(init=False)
class AmlConfig(TrainingServiceConfig):
platform: str = 'aml'
platform: Literal['aml'] = 'aml'
subscription_id: str
resource_group: str
workspace_name: str
......
......@@ -4,13 +4,15 @@
from dataclasses import dataclass
from typing import Optional
from typing_extensions import Literal
from ..training_service import TrainingServiceConfig
__all__ = ['DlcConfig']
@dataclass(init=False)
class DlcConfig(TrainingServiceConfig):
platform: str = 'dlc'
platform: Literal['dlc'] = 'dlc'
type: str = 'Worker'
image: str # 'registry-vpc.{region}.aliyuncs.com/pai-dlc/tensorflow-training:1.15.0-cpu-py36-ubuntu18.04',
job_type: str = 'TFJob'
......
......@@ -19,6 +19,8 @@ __all__ = ['FrameworkControllerConfig', 'FrameworkControllerRoleConfig', 'Framew
from dataclasses import dataclass
from typing import List, Optional, Union
from typing_extensions import Literal
from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
from .k8s_storage import K8sStorageConfig
......@@ -41,7 +43,7 @@ class FrameworkControllerRoleConfig(ConfigBase):
@dataclass(init=False)
class FrameworkControllerConfig(TrainingServiceConfig):
platform: str = 'frameworkcontroller'
platform: Literal['frameworkcontroller'] = 'frameworkcontroller'
storage: K8sStorageConfig
service_account_name: Optional[str]
task_roles: List[FrameworkControllerRoleConfig]
......
......@@ -10,6 +10,8 @@ __all__ = ['K8sStorageConfig', 'K8sAzureStorageConfig', 'K8sNfsConfig']
from dataclasses import dataclass
from typing import Optional
from typing_extensions import Literal
from ..base import ConfigBase
@dataclass(init=False)
......@@ -34,13 +36,13 @@ class K8sStorageConfig(ConfigBase):
@dataclass(init=False)
class K8sNfsConfig(K8sStorageConfig):
storage: str = 'nfs'
storage: Literal['nfs'] = 'nfs'
server: str
path: str
@dataclass(init=False)
class K8sAzureStorageConfig(K8sStorageConfig):
storage: str = 'azureStorage'
storage: Literal['azureStorage'] = 'azureStorage'
azure_account: str
azure_share: str
key_vault_name: str
......
......@@ -19,6 +19,8 @@ __all__ = ['KubeflowConfig', 'KubeflowRoleConfig']
from dataclasses import dataclass
from typing import Optional, Union
from typing_extensions import Literal
from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
from .k8s_storage import K8sStorageConfig
......@@ -35,7 +37,7 @@ class KubeflowRoleConfig(ConfigBase):
@dataclass(init=False)
class KubeflowConfig(TrainingServiceConfig):
platform: str = 'kubeflow'
platform: Literal['kubeflow'] = 'kubeflow'
operator: str
api_version: str
storage: K8sStorageConfig
......
......@@ -19,12 +19,14 @@ __all__ = ['LocalConfig']
from dataclasses import dataclass
from typing import List, Optional, Union
from typing_extensions import Literal
from ..training_service import TrainingServiceConfig
from .. import utils
@dataclass(init=False)
class LocalConfig(TrainingServiceConfig):
platform: str = 'local'
platform: Literal['local'] = 'local'
use_active_gpu: Optional[bool] = None
max_trial_number_per_gpu: int = 1
gpu_indices: Union[List[int], int, str, None] = None
......
......@@ -20,12 +20,14 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Union
from typing_extensions import Literal
from ..training_service import TrainingServiceConfig
from ..utils import PathLike
@dataclass(init=False)
class OpenpaiConfig(TrainingServiceConfig):
platform: str = 'openpai'
platform: Literal['openpai'] = 'openpai'
host: str
username: str
token: str
......
......@@ -21,6 +21,8 @@ from pathlib import Path
from typing import List, Optional, Union
import warnings
from typing_extensions import Literal
from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
from .. import utils
......@@ -60,7 +62,7 @@ class RemoteMachineConfig(ConfigBase):
@dataclass(init=False)
class RemoteConfig(TrainingServiceConfig):
platform: str = 'remote'
platform: Literal['remote'] = 'remote'
machine_list: List[RemoteMachineConfig]
reuse_mode: bool = True
......
......@@ -7,12 +7,25 @@ Utility functions for experiment config classes, internal part.
If you are implementing a config class for a training service, it's unlikely you will need these.
"""
from __future__ import annotations
__all__ = [
'get_base_path', 'set_base_path', 'unset_base_path', 'resolve_path',
'case_insensitive', 'camel_case',
'fields', 'is_instance', 'validate_type', 'is_path_like',
'guess_config_type', 'guess_list_config_type',
'training_service_config_factory', 'load_training_service_config',
'get_ipv4_address'
]
import copy
import dataclasses
import importlib
import json
import os.path
from pathlib import Path
import socket
import typing
import typeguard
......@@ -20,36 +33,30 @@ import nni.runtime.config
from .public import is_missing
__all__ = [
'get_base_path', 'set_base_path', 'unset_base_path', 'resolve_path',
'case_insensitive', 'camel_case',
'is_instance', 'validate_type', 'is_path_like',
'guess_config_type', 'guess_list_config_type',
'training_service_config_factory', 'load_training_service_config',
'get_ipv4_address'
]
if typing.TYPE_CHECKING:
from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
## handle relative path ##
_current_base_path = None
_current_base_path: Path | None = None
def get_base_path():
def get_base_path() -> Path:
if _current_base_path is None:
return Path()
return _current_base_path
def set_base_path(path):
def set_base_path(path: Path) -> None:
global _current_base_path
assert _current_base_path is None
_current_base_path = path
def unset_base_path():
def unset_base_path() -> None:
global _current_base_path
_current_base_path = None
def resolve_path(path, base_path):
if path is None:
return None
def resolve_path(path: Path | str, base_path: Path) -> str:
assert path is not None
# Path.resolve() does not work on Windows when file not exist, so use os.path instead
path = os.path.expanduser(path)
if not os.path.isabs(path):
......@@ -58,23 +65,32 @@ def resolve_path(path, base_path):
## field name case convertion ##
def case_insensitive(key):
def case_insensitive(key: str) -> str:
return key.lower().replace('_', '')
def camel_case(key):
def camel_case(key: str) -> str:
words = key.strip('_').split('_')
return words[0] + ''.join(word.title() for word in words[1:])
## type hint utils ##
def is_instance(value, type_hint):
def fields(config: ConfigBase) -> list[dataclasses.Field]:
# Similar to `dataclasses.fields()`, but use `typing.get_types_hints()` to get `field.type`.
# This is useful when postponed evaluation is enabled.
ret = [copy.copy(field) for field in dataclasses.fields(config)]
types = typing.get_type_hints(type(config))
for field in ret:
field.type = types[field.name]
return ret
def is_instance(value, type_hint) -> bool:
try:
typeguard.check_type('_', value, type_hint)
except TypeError:
return False
return True
def validate_type(config):
def validate_type(config: ConfigBase) -> None:
class_name = type(config).__name__
for field in dataclasses.fields(config):
value = getattr(config, field.name)
......@@ -84,17 +100,17 @@ def validate_type(config):
if not is_instance(value, field.type):
raise ValueError(f'{class_name}: type of {field.name} ({repr(value)}) is not {field.type}')
def is_path_like(type_hint):
def is_path_like(type_hint) -> bool:
# only `PathLike` and `Any` accepts `Path`; check `int` to make sure it's not `Any`
return is_instance(Path(), type_hint) and not is_instance(1, type_hint)
## type inference ##
def guess_config_type(obj, type_hint):
def guess_config_type(obj, type_hint) -> ConfigBase | None:
ret = guess_list_config_type([obj], type_hint, _hint_list_item=True)
return ret[0] if ret else None
def guess_list_config_type(objs, type_hint, _hint_list_item=False):
def guess_list_config_type(objs, type_hint, _hint_list_item=False) -> list[ConfigBase] | None:
# avoid circular import
from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
......@@ -144,20 +160,20 @@ def _all_subclasses(cls):
subclasses = set(cls.__subclasses__())
return subclasses.union(*[_all_subclasses(subclass) for subclass in subclasses])
def training_service_config_factory(platform):
def training_service_config_factory(platform: str) -> TrainingServiceConfig:
cls = _get_ts_config_class(platform)
if cls is None:
raise ValueError(f'Bad training service platform: {platform}')
return cls()
def load_training_service_config(config):
def load_training_service_config(config) -> TrainingServiceConfig:
if isinstance(config, dict) and 'platform' in config:
cls = _get_ts_config_class(config['platform'])
if cls is not None:
return cls(**config)
return config # not valid json, don't touch
def _get_ts_config_class(platform):
def _get_ts_config_class(platform: str) -> type[TrainingServiceConfig] | None:
from ..training_service import TrainingServiceConfig # avoid circular import
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
......@@ -175,7 +191,7 @@ def _get_ts_config_class(platform):
## misc ##
def get_ipv4_address():
def get_ipv4_address() -> str:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(('192.0.2.0', 80))
addr = s.getsockname()[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment