Unverified commit d5857823 authored by liuzhe-lz, committed by GitHub

Config refactor (#4370)

parent cb090e8c
@@ -6,6 +6,7 @@ pyyaml >= 5.4
requests
responses
schema
+typeguard
PythonWebHDFS
colorama
scikit-learn >= 0.24.1 ; python_version >= "3.7"
...
@@ -314,6 +314,4 @@ Azure Blob Config
.. autoattribute:: nni.experiment.config.AzureBlobConfig.storage_account_key
-.. autoattribute:: nni.experiment.config.AzureBlobConfig.resource_group_name
.. autoattribute:: nni.experiment.config.AzureBlobConfig.container_name
@@ -33,11 +33,10 @@ def main():
        enable_multi_thread()

    if 'trainingServicePlatform' in exp_params:  # config schema is v1
-       from types import SimpleNamespace
        from .experiment.config.convert import convert_algo
        for algo_type in ['tuner', 'assessor', 'advisor']:
            if algo_type in exp_params:
-               exp_params[algo_type] = convert_algo(algo_type, exp_params, SimpleNamespace()).json()
+               exp_params[algo_type] = convert_algo(algo_type, exp_params[algo_type])

    if exp_params.get('advisor') is not None:
        # advisor is enabled and starts to run
...
@@ -2,5 +2,5 @@
# Licensed under the MIT license.

from .config import *
-from .experiment import Experiment
+from .experiment import Experiment, RunMode
from .data import *

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

-from .common import *
-from .local import *
-from .remote import *
-from .openpai import *
-from .aml import *
-from .kubeflow import *
-from .frameworkcontroller import *
-from .adl import *
-from .dlc import *
+from .exp_config import ExperimentConfig
+from .algorithm import AlgorithmConfig, CustomAlgorithmConfig
+from .training_services import *
from .shared_storage import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from .common import TrainingServiceConfig
__all__ = ['AdlConfig']
@dataclass(init=False)
class AdlConfig(TrainingServiceConfig):
platform: str = 'adl'
docker_image: str = 'msranni/nni:latest'
_validation_rules = {
'platform': lambda value: (value == 'adl', 'cannot be modified')
}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Config classes for tuner/assessor/advisor algorithms.
Use ``AlgorithmConfig`` to specify a built-in algorithm;
use ``CustomAlgorithmConfig`` to specify a custom algorithm.
Check the reference_ for explanation of each field.
You may also want to check `tuner's overview`_.
.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
.. _tuner's overview: https://nni.readthedocs.io/en/stable/Tuner/BuiltinTuner.html
"""
__all__ = ['AlgorithmConfig', 'CustomAlgorithmConfig']
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional
from .base import ConfigBase
from .utils import PathLike
@dataclass(init=False)
class _AlgorithmConfig(ConfigBase):
"""
Common base class for ``AlgorithmConfig`` and ``CustomAlgorithmConfig``.
It's a "union set" of 2 derived classes. So users can use it as either one.
"""
name: Optional[str] = None
class_name: Optional[str] = None
code_directory: Optional[PathLike] = None
class_args: Optional[Dict[str, Any]] = None
def _validate_canonical(self):
super()._validate_canonical()
if self.class_name is None: # assume it's built-in algorithm by default
assert self.name
assert self.code_directory is None
else: # custom algorithm
assert self.name is None
assert self.class_name
if not Path(self.code_directory).is_dir():
raise ValueError(f'CustomAlgorithmConfig: code_directory "{self.code_directory}" is not a directory')
@dataclass(init=False)
class AlgorithmConfig(_AlgorithmConfig):
"""
Configuration for built-in algorithm.
"""
name: str
class_args: Optional[Dict[str, Any]] = None
@dataclass(init=False)
class CustomAlgorithmConfig(_AlgorithmConfig):
"""
Configuration for custom algorithm.
"""
class_name: str
code_directory: Optional[PathLike] = '.'
class_args: Optional[Dict[str, Any]] = None
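Usage sketch (not part of the diff): the two public classes above can be built directly in Python or unpacked from a JSON-style dict. The import path follows the package layout in this commit; the custom tuner's class name and code directory below are made-up placeholders.

from nni.experiment.config import AlgorithmConfig, CustomAlgorithmConfig

# Built-in algorithm: only `name` (plus optional `class_args`) is needed.
tuner = AlgorithmConfig(name='TPE', class_args={'optimize_mode': 'maximize'})
tuner.validate()      # raises if the config is ill-formed
print(tuner.json())   # camelCase JSON object consumed by NNI manager

# Custom algorithm: `class_name` and `code_directory` are hypothetical values;
# validation requires the directory to actually exist.
custom_tuner = CustomAlgorithmConfig(
    class_name='my_tuner.MyTuner',
    code_directory='./my_tuner_code',
    class_args={'seed': 42},
)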
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import dataclasses
from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar

import yaml

from . import util

__all__ = ['ConfigBase', 'PathLike']

T = TypeVar('T', bound='ConfigBase')

PathLike = util.PathLike


def _is_missing(obj: Any) -> bool:
    return isinstance(obj, type(dataclasses.MISSING))


class ConfigBase:
    """
    Base class of config classes.
    Subclass may override `_canonical_rules` and `_validation_rules`,
    and `validate()` if the logic is complex.
    """

    # Rules to convert field value to canonical format.
    # The key is field name.
    # The value is callable `value -> canonical_value`
    # It is not type-hinted so dataclass won't treat it as field
    _canonical_rules = {}  # type: ignore

    # Rules to validate field value.
    # The key is field name.
    # The value is callable `value -> valid` or `value -> (valid, error_message)`
    # The rule will be called with canonical format and is only called when `value` is not None.
    # `error_message` is used when `valid` is False.
    # It will be prepended with class name and field name in exception message.
    _validation_rules = {}  # type: ignore

    def __init__(self, *, _base_path: Optional[Path] = None, **kwargs):
        """
        Initialize a config object and set some fields.
        Name of keyword arguments can either be snake_case or camelCase.
        They will be converted to snake_case automatically.
        If a field is missing and doesn't have a default value, it will be set to `dataclasses.MISSING`.
        """
        if 'basepath' in kwargs:
            _base_path = kwargs.pop('basepath')
        kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}
        if _base_path is None:
            _base_path = Path()
        for field in dataclasses.fields(self):
            value = kwargs.pop(util.case_insensitive(field.name), field.default)
            if value is not None and not _is_missing(value):
                # relative paths loaded from config file are not relative to pwd
                if 'Path' in str(field.type):
                    value = Path(value).expanduser()
                    if not value.is_absolute():
                        value = _base_path / value
            setattr(self, field.name, value)
        if kwargs:
            cls = type(self).__name__
            fields = ', '.join(kwargs.keys())
            raise ValueError(f'{cls}: Unrecognized fields {fields}')

    @classmethod
    def load(cls: Type[T], path: PathLike) -> T:
        """
        Load config from YAML (or JSON) file.
        Keys in YAML file can either be camelCase or snake_case.
        """
        data = yaml.safe_load(open(path))
        if not isinstance(data, dict):
            raise ValueError(f'Content of config file {path} is not a dict/object')
        return cls(**data, _base_path=Path(path).parent)

    def json(self) -> Dict[str, Any]:
        """
        Convert config to JSON object.
        The keys of returned object will be camelCase.
        """
        self.validate()
        return dataclasses.asdict(
            self.canonical(),
            dict_factory=lambda items: dict((util.camel_case(k), v) for k, v in items if v is not None)
        )

    def canonical(self: T) -> T:
        """
        Returns a deep copy, where the fields supporting multiple formats are converted to the canonical format.
        Noticeably, relative path may be converted to absolute path.
        """
        ret = copy.deepcopy(self)
        for field in dataclasses.fields(ret):
            key, value = field.name, getattr(ret, field.name)
            rule = ret._canonical_rules.get(key)
            if rule is not None:
                setattr(ret, key, rule(value))
            elif isinstance(value, ConfigBase):
                setattr(ret, key, value.canonical())
                # value will be copied twice, should not be a performance issue anyway
            elif isinstance(value, Path):
                setattr(ret, key, str(value))
        return ret

    def validate(self) -> None:
        """
        Validate the config object and raise Exception if it's ill-formed.
        """
        class_name = type(self).__name__
        config = self.canonical()

        for field in dataclasses.fields(config):
            key, value = field.name, getattr(config, field.name)

            # check existence
            if _is_missing(value):
                raise ValueError(f'{class_name}: {key} is not set')

            # check type (TODO)
            type_name = str(field.type).replace('typing.', '')
            optional = any([
                type_name.startswith('Optional['),
                type_name.startswith('Union[') and 'None' in type_name,
                type_name == 'Any'
            ])
            if value is None:
                if optional:
                    continue
                else:
                    raise ValueError(f'{class_name}: {key} cannot be None')

            # check value
            rule = config._validation_rules.get(key)
            if rule is not None:
                try:
                    result = rule(value)
                except Exception:
                    raise ValueError(f'{class_name}: {key} has bad value {repr(value)}')

                if isinstance(result, bool):
                    if not result:
                        raise ValueError(f'{class_name}: {key} ({repr(value)}) is out of range')
                else:
                    if not result[0]:
                        raise ValueError(f'{class_name}: {key} {result[1]}')

            # check nested config
            if isinstance(value, ConfigBase):
                value.validate()

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
``ConfigBase`` class. Nothing else.

Docstrings in this file are mainly for NNI contributors instead of end users.
"""

__all__ = ['ConfigBase']

import copy
import dataclasses
from pathlib import Path

import yaml

from . import utils


class ConfigBase:
    """
    The abstract base class of experiment config classes.

    A config class should be a type-hinted dataclass inheriting ``ConfigBase``.
    Or for a training service config class, it can inherit ``TrainingServiceConfig``.

    .. code-block:: python

        @dataclass(init=False)
        class ExperimentConfig(ConfigBase):
            name: Optional[str]
            ...

    Subclasses are suggested to override ``_canonicalize()`` and ``_validate_canonical()`` methods.

    Users can create a config object with constructor or ``ConfigBase.load()``,
    validate its legality with ``ConfigBase.validate()``,
    and finally convert it to the format accepted by NNI manager with ``ConfigBase.json()``.

    Example usage:

    .. code-block:: python

        # when using Python API
        config1 = ExperimentConfig(trialCommand='...', trialConcurrency=1, ...)
        config1.validate()
        print(config1.json())

        # when using config file
        config2 = ExperimentConfig.load('examples/config.yml')
        config2.validate()
        print(config2.json())

    Config objects will remember where they are loaded; therefore relative paths can be resolved smartly.
    If a config object is created with constructor, the base path will be current working directory.
    If it is loaded with ``ConfigBase.load(path)``, the base path will be ``path``'s parent.
    """

    def __init__(self, **kwargs):
        """
        There are two common ways to use the constructor,
        directly writing Python code and unpacking from a JSON (YAML) object:

        .. code-block:: python

            config1 = AlgorithmConfig(name='TPE', class_args={'optimize_mode': 'maximize'})

            json = {'name': 'TPE', 'classArgs': {'optimize_mode': 'maximize'}}
            config2 = AlgorithmConfig(**json)

        If the config class has fields whose type is another config class, or list of another config class,
        they will recursively load dict values.

        Because JSON objects can use "camelCase" for field names,
        cases and underscores in ``kwargs`` keys are ignored in this constructor.
        For example if a config class has a field ``hello_world``,
        then using ``hello_world=1``, ``helloWorld=1``, and ``_HELLOWORLD_=1`` in constructor
        will all assign to the same field.

        If ``kwargs`` contain extra keys, a ``ValueError`` will be raised.
        If ``kwargs`` do not have enough keys, missing fields are silently set to ``MISSING()``.
        You can use ``utils.is_missing()`` to check them.
        """
        self._base_path = utils.get_base_path()
        args = {utils.case_insensitive(key): value for key, value in kwargs.items()}

        for field in dataclasses.fields(self):
            value = args.pop(utils.case_insensitive(field.name), field.default)
            setattr(self, field.name, value)

        if args:  # maybe a key is misspelled
            class_name = type(self).__name__
            fields = ', '.join(args.keys())
            raise ValueError(f'{class_name} does not have field(s) {fields}')

        # try to unpack nested config
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if utils.is_instance(value, field.type):
                continue  # already accepted by subclass, don't touch it
            if isinstance(value, dict):
                config = utils.guess_config_type(value, field.type)
                if config is not None:
                    setattr(self, field.name, config)
            elif isinstance(value, list) and value and isinstance(value[0], dict):
                configs = utils.guess_list_config_type(value, field.type)
                if configs:
                    setattr(self, field.name, configs)

    @classmethod
    def load(cls, path):
        """
        Load a YAML config file from file system.

        Since YAML is a superset of JSON, it can also load JSON files.

        This method raises exception if:

        - The file is not available
        - The file content is not valid YAML
        - Top level value of the YAML is not object
        - The YAML contains not supported fields

        It does not raise exception when the YAML misses fields or contains bad fields.

        Parameters
        ----------
        path : PathLike
            Path of the config file.

        Returns
        -------
        cls
            An object of ConfigBase subclass.
        """
        with open(path) as yaml_file:
            data = yaml.safe_load(yaml_file)
        if not isinstance(data, dict):
            raise ValueError(f'Content of config file {path} is not a dict/object')
        utils.set_base_path(Path(path).parent)
        config = cls(**data)
        utils.unset_base_path()
        return config

    def canonical_copy(self):
        """
        Create a canonicalized copy of the config, and validate it.
        This function is mainly used internally by NNI.

        Returns
        -------
        type(self)
            A deep copy.
        """
        canon = copy.deepcopy(self)
        canon._canonicalize([])
        canon._validate_canonical()
        return canon

    def validate(self):
        """
        Validate legality of the config object. Raise exception if any error occurred.
        This function does **not** return truth value. Do not write ``if config.validate()``.

        Returns
        -------
        None
        """
        self.canonical_copy()

    def json(self):
        """
        Convert the config to JSON object (not JSON string).

        In current implementation ``json()`` will invoke ``validate()``, but this might change in future version.
        It is recommended to call ``validate()`` before ``json()`` for now.

        Returns
        -------
        dict
            JSON object.
        """
        canon = self.canonical_copy()
        return dataclasses.asdict(canon, dict_factory=_dict_factory)  # this is recursive

    def _canonicalize(self, parents):
        """
        The config schema for end users is more flexible than the format NNI manager accepts.
        This method converts a config object to the constrained format accepted by NNI manager.

        The default implementation will:

        1. Resolve all ``PathLike`` fields to absolute path
        2. Call ``_canonicalize()`` on all children config objects, including those inside list and dict

        Subclasses are recommended to call ``super()._canonicalize(parents)`` at the end of their overridden version.

        Parameters
        ----------
        parents : list[ConfigBase]
            The upper level config objects.
            For example local training service's ``trialGpuNumber`` will be copied from top level when not set,
            in this case it will be invoked like ``localConfig._canonicalize([experimentConfig])``.
        """
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if isinstance(value, (Path, str)) and utils.is_path_like(field.type):
                setattr(self, field.name, utils.resolve_path(value, self._base_path))
            else:
                _recursive_canonicalize_child(value, [self] + parents)

    def _validate_canonical(self):
        """
        Validate legality of a canonical config object. It's the caller's responsibility to ensure the config is canonical.

        Raise exception if any problem found. This function does **not** return truth value.

        The default implementation will:

        1. Validate that all fields match their type hint
        2. Call ``_validate_canonical()`` on children config objects, including those inside list and dict

        Subclasses are recommended to call ``super()._validate_canonical()``.
        """
        utils.validate_type(self)
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            _recursive_validate_child(value)

    def __setattr__(self, name, value):
        if hasattr(self, name) or name.startswith('_'):
            super().__setattr__(name, value)
            return
        if name in [field.name for field in dataclasses.fields(self)]:  # might happen during __init__
            super().__setattr__(name, value)
            return
        raise AttributeError(f'{type(self).__name__} does not have field {name}')


def _dict_factory(items):
    ret = {}
    for key, value in items:
        if value is not None:
            k = utils.camel_case(key)
            v = str(value) if isinstance(value, Path) else value
            ret[k] = v
    return ret


def _recursive_canonicalize_child(child, parents):
    if isinstance(child, ConfigBase):
        child._canonicalize(parents)
    elif isinstance(child, list):
        for item in child:
            _recursive_canonicalize_child(item, parents)
    elif isinstance(child, dict):
        for item in child.values():
            _recursive_canonicalize_child(item, parents)


def _recursive_validate_child(child):
    if isinstance(child, ConfigBase):
        child._validate_canonical()
    elif isinstance(child, list):
        for item in child:
            _recursive_validate_child(item)
    elif isinstance(child, dict):
        for item in child.values():
            _recursive_validate_child(item)
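A small sketch of the constructor behaviour documented above, using the ``AlgorithmConfig`` subclass from this same commit: camelCase and snake_case keyword names resolve to the same field, unknown keys raise ``ValueError``, and ``json()`` emits camelCase keys.

from nni.experiment.config import AlgorithmConfig

# 'classArgs' and 'class_args' are treated as the same field by ConfigBase.__init__.
config = AlgorithmConfig(name='TPE', classArgs={'optimize_mode': 'maximize'})
assert config.class_args == {'optimize_mode': 'maximize'}
print(config.json())   # {'name': 'TPE', 'classArgs': {'optimize_mode': 'maximize'}}

try:
    AlgorithmConfig(name='TPE', classArg={})   # misspelled key
except ValueError as error:
    print(error)                               # reports the unrecognized field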
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
from .base import ConfigBase, PathLike
from . import util
__all__ = [
'ExperimentConfig',
'AlgorithmConfig',
'CustomAlgorithmConfig',
'TrainingServiceConfig',
]
@dataclass(init=False)
class _AlgorithmConfig(ConfigBase):
name: Optional[str] = None
class_name: Optional[str] = None
code_directory: Optional[PathLike] = None
class_args: Optional[Dict[str, Any]] = None
def validate(self):
super().validate()
_validate_algo(self)
_canonical_rules = {'code_directory': util.canonical_path}
@dataclass(init=False)
class AlgorithmConfig(_AlgorithmConfig):
name: str
class_args: Optional[Dict[str, Any]] = None
@dataclass(init=False)
class CustomAlgorithmConfig(_AlgorithmConfig):
class_name: str
code_directory: Optional[PathLike] = '.'
class_args: Optional[Dict[str, Any]] = None
class TrainingServiceConfig(ConfigBase):
platform: str
@dataclass(init=False)
class SharedStorageConfig(ConfigBase):
storage_type: str
local_mount_point: PathLike
remote_mount_point: str
local_mounted: str
storage_account_name: Optional[str] = None
storage_account_key: Optional[str] = None
container_name: Optional[str] = None
nfs_server: Optional[str] = None
exported_directory: Optional[str] = None
def __init__(self, *, _base_path: Optional[Path] = None, **kwargs):
kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}
if 'localmountpoint' in kwargs:
kwargs['localmountpoint'] = Path(kwargs['localmountpoint']).expanduser()
if not kwargs['localmountpoint'].is_absolute():
raise ValueError('localMountPoint can only be set as an absolute path.')
super().__init__(_base_path=_base_path, **kwargs)
@dataclass(init=False)
class ExperimentConfig(ConfigBase):
experiment_name: Optional[str] = None
search_space_file: Optional[PathLike] = None
search_space: Any = None
trial_command: str
trial_code_directory: PathLike = '.'
trial_concurrency: int
trial_gpu_number: Optional[int] = None # TODO: in openpai cannot be None
max_experiment_duration: Optional[str] = None
max_trial_number: Optional[int] = None
max_trial_duration: Optional[int] = None
nni_manager_ip: Optional[str] = None
use_annotation: bool = False
debug: bool = False
log_level: Optional[str] = None
experiment_working_directory: PathLike = '~/nni-experiments'
tuner_gpu_indices: Union[List[int], str, int, None] = None
tuner: Optional[_AlgorithmConfig] = None
assessor: Optional[_AlgorithmConfig] = None
advisor: Optional[_AlgorithmConfig] = None
training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]]
shared_storage: Optional[SharedStorageConfig] = None
_deprecated: Optional[Dict[str, Any]] = None
def __init__(self, training_service_platform: Optional[Union[str, List[str]]] = None, **kwargs):
base_path = kwargs.pop('_base_path', None)
kwargs = util.case_insensitive(kwargs)
if training_service_platform is not None:
assert 'trainingservice' not in kwargs
kwargs['trainingservice'] = util.training_service_config_factory(
platform=training_service_platform,
base_path=base_path
)
elif isinstance(kwargs.get('trainingservice'), (dict, list)):
# dict means a single training service
# list means hybrid training service
kwargs['trainingservice'] = util.training_service_config_factory(
config=kwargs['trainingservice'],
base_path=base_path
)
else:
raise RuntimeError('Unsupported Training service configuration!')
super().__init__(_base_path=base_path, **kwargs)
for algo_type in ['tuner', 'assessor', 'advisor']:
if isinstance(kwargs.get(algo_type), dict):
setattr(self, algo_type, _AlgorithmConfig(**kwargs.pop(algo_type)))
if isinstance(kwargs.get('sharedstorage'), dict):
setattr(self, 'shared_storage', SharedStorageConfig(_base_path=base_path, **kwargs.pop('sharedstorage')))
def canonical(self):
ret = super().canonical()
if isinstance(ret.training_service, list):
for i, ts in enumerate(ret.training_service):
ret.training_service[i] = ts.canonical()
return ret
def validate(self, initialized_tuner: bool = False) -> None:
super().validate()
if initialized_tuner:
_validate_for_exp(self.canonical())
else:
_validate_for_nnictl(self.canonical())
if self.trial_gpu_number and hasattr(self.training_service, 'use_active_gpu'):
if self.training_service.use_active_gpu is None:
raise ValueError('Please set "use_active_gpu"')
def json(self) -> Dict[str, Any]:
obj = super().json()
if obj.get('searchSpaceFile'):
obj['searchSpace'] = yaml.safe_load(open(obj.pop('searchSpaceFile')))
return obj
## End of public API ##
@property
def _canonical_rules(self):
return _canonical_rules
@property
def _validation_rules(self):
return _validation_rules
_canonical_rules = {
'search_space_file': util.canonical_path,
'trial_code_directory': util.canonical_path,
'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
'experiment_working_directory': util.canonical_path,
'tuner_gpu_indices': util.canonical_gpu_indices,
'tuner': lambda config: None if config is None or config.name == '_none_' else config.canonical(),
'assessor': lambda config: None if config is None or config.name == '_none_' else config.canonical(),
'advisor': lambda config: None if config is None or config.name == '_none_' else config.canonical(),
}
_validation_rules = {
'search_space_file': lambda value: (Path(value).is_file(), f'"{value}" does not exist or is not regular file'),
'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
'trial_concurrency': lambda value: value > 0,
'trial_gpu_number': lambda value: value >= 0,
'max_experiment_duration': lambda value: util.parse_time(value) > 0,
'max_trial_number': lambda value: value > 0,
'max_trial_duration': lambda value: util.parse_time(value) > 0,
'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
'tuner_gpu_indices': lambda value: all(i >= 0 for i in value) and len(value) == len(set(value)),
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
def _validate_for_exp(config: ExperimentConfig) -> None:
# validate experiment for nni.Experiment, where tuner is already initialized outside
if config.use_annotation:
raise ValueError('ExperimentConfig: annotation is not supported in this mode')
if util.count(config.search_space, config.search_space_file) != 1:
raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')
if util.count(config.tuner, config.assessor, config.advisor) != 0:
raise ValueError('ExperimentConfig: tuner, assessor, and advisor must not be set in for this mode')
if config.tuner_gpu_indices is not None:
raise ValueError('ExperimentConfig: tuner_gpu_indices is not supported in this mode')
def _validate_for_nnictl(config: ExperimentConfig) -> None:
# validate experiment for normal launching approach
if config.use_annotation:
if util.count(config.search_space, config.search_space_file) != 0:
raise ValueError('ExperimentConfig: search_space and search_space_file must not be set with annotation')
else:
if util.count(config.search_space, config.search_space_file) != 1:
raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')
if util.count(config.tuner, config.advisor) != 1:
raise ValueError('ExperimentConfig: tuner and advisor must be set one')
def _validate_algo(algo: AlgorithmConfig) -> None:
if algo.name is None:
if algo.class_name is None:
raise ValueError('Missing algorithm name')
if algo.code_directory is not None and not Path(algo.code_directory).is_dir():
raise ValueError(f'code_directory "{algo.code_directory}" does not exist or is not directory')
else:
if algo.class_name is not None or algo.code_directory is not None:
raise ValueError(f'When name is set for registered algorithm, class_name and code_directory cannot be used')
# TODO: verify algorithm installation and class args
This diff is collapsed.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Top level experiment configuration class, ``ExperimentConfig``.
"""
__all__ = ['ExperimentConfig']
from dataclasses import dataclass
import logging
from typing import Any, List, Optional, Union
import yaml
from .algorithm import _AlgorithmConfig
from .base import ConfigBase
from .shared_storage import SharedStorageConfig
from .training_service import TrainingServiceConfig
from . import utils
@dataclass(init=False)
class ExperimentConfig(ConfigBase):
"""
Class of experiment configuration. Check the reference_ for explanation of each field.
When used with the Python experiment API, it can be constructed in two ways:
1. Create an empty config object and then set each field
.. code-block:: python
config = ExperimentConfig('local')
config.search_space = {...}
config.tuner.name = 'random'
config.training_service.use_active_gpu = True
2. Use kwargs directly
.. code-block:: python
config = ExperimentConfig(
search_space = {...},
tuner = AlgorithmConfig(name='random'),
training_service = LocalConfig(
use_active_gpu = True
)
)
Fields commented as "training service field" act as shortcuts for all training services.
Users can either specify them here or inside the training service config.
In the latter case, hybrid training services can have different settings.
.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
"""
experiment_name: Optional[str] = None
search_space_file: Optional[utils.PathLike] = None
search_space: Any = None
trial_command: Optional[str] = None # training service field
trial_code_directory: utils.PathLike = '.' # training service field
trial_concurrency: int
trial_gpu_number: Optional[int] = None # training service field
max_experiment_duration: Union[str, int, None] = None
max_trial_number: Optional[int] = None
max_trial_duration: Union[str, int, None] = None
nni_manager_ip: Optional[str] = None # training service field
use_annotation: bool = False
debug: bool = False
log_level: Optional[str] = None
experiment_working_directory: utils.PathLike = '~/nni-experiments'
tuner_gpu_indices: Union[List[int], int, str, None] = None
tuner: Optional[_AlgorithmConfig] = None
assessor: Optional[_AlgorithmConfig] = None
advisor: Optional[_AlgorithmConfig] = None
training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]]
shared_storage: Optional[SharedStorageConfig] = None
def __init__(self, training_service_platform=None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
# the user chose to init with `config = ExperimentConfig('local')` and set fields later
# we need to create empty training service & algorithm configs to support `config.tuner.name = 'random'`
assert utils.is_missing(self.training_service)
if isinstance(training_service_platform, list):
self.training_service = [utils.training_service_config_factory(ts) for ts in training_service_platform]
else:
self.training_service = utils.training_service_config_factory(training_service_platform)
for algo_type in ['tuner', 'assessor', 'advisor']:
# add placeholder items, so users can write `config.tuner.name = 'random'`
if getattr(self, algo_type) is None:
setattr(self, algo_type, _AlgorithmConfig(name='_none_'))
elif not utils.is_missing(self.training_service):
# training service is set via json or constructor
if isinstance(self.training_service, list):
self.training_service = [utils.load_training_service_config(ts) for ts in self.training_service]
else:
self.training_service = utils.load_training_service_config(self.training_service)
def _canonicalize(self, _parents):
if self.log_level is None:
self.log_level = 'debug' if self.debug else 'info'
self.tuner_gpu_indices = utils.canonical_gpu_indices(self.tuner_gpu_indices)
for algo_type in ['tuner', 'assessor', 'advisor']:
algo = getattr(self, algo_type)
if algo is not None and algo.name == '_none_':
setattr(self, algo_type, None)
super()._canonicalize([self])
if self.nni_manager_ip is None:
# show a warning if user does not set nni_manager_ip. we have many issues caused by this
# the simple detection logic won't work for hybrid, but advanced users should not need it
# ideally we should check accessibility of the ip, but it needs much more work
platform = getattr(self.training_service, 'platform')
has_ip = isinstance(getattr(self.training_service, 'nni_manager_ip'), str) # not None or MISSING
if platform and platform != 'local' and not has_ip:
ip = utils.get_ipv4_address()
msg = f'nni_manager_ip is not set, please make sure {ip} is accessible from training machines'
logging.getLogger('nni.experiment.config').warning(msg)
def _validate_canonical(self):
super()._validate_canonical()
space_cnt = (self.search_space is not None) + (self.search_space_file is not None)
if self.use_annotation and space_cnt != 0:
raise ValueError('ExperimentConfig: search space must not be set when annotation is enabled')
if not self.use_annotation and space_cnt < 1:
raise ValueError('ExperimentConfig: exactly one of search_space and search_space_file must be set')
if self.search_space_file is not None:
with open(self.search_space_file) as ss_file:
self.search_space = yaml.safe_load(ss_file)
# to make the error message clear, ideally it should be:
# `if concurrency <= 0: raise ValueError(f'trial_concurrency ({concurrency}) must be greater than 0')`
# but I believe hardly any users will make this kind of mistake, so let's keep it simple
assert self.trial_concurrency > 0
assert self.max_experiment_duration is None or utils.parse_time(self.max_experiment_duration) > 0
assert self.max_trial_number is None or self.max_trial_number > 0
assert self.max_trial_duration is None or utils.parse_time(self.max_trial_duration) > 0
assert self.log_level in ['fatal', 'error', 'warning', 'info', 'debug', 'trace']
# following line is disabled because it has side effect
# enable it if users encounter problems caused by failure in creating experiment directory
# currently I have only seen one issue of this kind
#Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True)
utils.validate_gpu_indices(self.tuner_gpu_indices)
tuner_cnt = (self.tuner is not None) + (self.advisor is not None)
if tuner_cnt != 1:
raise ValueError('ExperimentConfig: exactly one of tuner and advisor must be set')
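To tie the two construction styles together, here is a sketch of a minimal local experiment config that satisfies the checks above (exactly one of search_space/search_space_file, exactly one of tuner/advisor). It assumes ``LocalConfig`` is re-exported from ``nni.experiment.config`` as the new ``__init__`` suggests and keeps the fields shown in the old local.py below; the trial command and search space are placeholders.

from nni.experiment.config import AlgorithmConfig, ExperimentConfig, LocalConfig

config = ExperimentConfig(
    search_space={'lr': {'_type': 'loguniform', '_value': [1e-5, 1e-1]}},  # placeholder space
    trial_command='python trial.py',                                       # placeholder command
    trial_code_directory='.',
    trial_concurrency=2,
    tuner=AlgorithmConfig(name='TPE', class_args={'optimize_mode': 'maximize'}),
    training_service=LocalConfig(use_active_gpu=False),
)
config.validate()   # raises if, for example, both tuner and advisor were set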
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import List, Optional, Union
from .common import TrainingServiceConfig
from . import util
__all__ = ['LocalConfig']
@dataclass(init=False)
class LocalConfig(TrainingServiceConfig):
platform: str = 'local'
reuse_mode: bool = False
use_active_gpu: Optional[bool] = None
max_trial_number_per_gpu: int = 1
gpu_indices: Union[List[int], str, int, None] = None
_canonical_rules = {
'gpu_indices': util.canonical_gpu_indices
}
_validation_rules = {
'platform': lambda value: (value == 'local', 'cannot be modified'),
'max_trial_number_per_gpu': lambda value: value > 0,
'gpu_indices': lambda value: all(idx >= 0 for idx in value) and len(value) == len(set(value))
}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
import warnings
from .base import ConfigBase, PathLike
from .common import TrainingServiceConfig
from . import util
__all__ = ['RemoteConfig', 'RemoteMachineConfig']
@dataclass(init=False)
class RemoteMachineConfig(ConfigBase):
host: str
port: int = 22
user: str
password: Optional[str] = None
ssh_key_file: PathLike = None #'~/.ssh/id_rsa'
ssh_passphrase: Optional[str] = None
use_active_gpu: bool = False
max_trial_number_per_gpu: int = 1
gpu_indices: Union[List[int], str, int, None] = None
python_path: Optional[str] = None
_canonical_rules = {
'ssh_key_file': util.canonical_path,
'gpu_indices': util.canonical_gpu_indices
}
_validation_rules = {
'port': lambda value: 0 < value < 65536,
'max_trial_number_per_gpu': lambda value: value > 0,
'gpu_indices': lambda value: all(idx >= 0 for idx in value) and len(value) == len(set(value))
}
def validate(self):
super().validate()
if self.password is None and not Path(self.ssh_key_file).is_file():
raise ValueError(f'Password is not provided and cannot find SSH key file "{self.ssh_key_file}"')
if self.password:
warnings.warn('Password will be exposed through web UI in plain text. We recommend to use SSH key file.')
@dataclass(init=False)
class RemoteConfig(TrainingServiceConfig):
platform: str = 'remote'
reuse_mode: bool = True
machine_list: List[RemoteMachineConfig]
def __init__(self, **kwargs):
kwargs = util.case_insensitive(kwargs)
kwargs['machinelist'] = util.load_config(RemoteMachineConfig, kwargs.get('machinelist'))
super().__init__(**kwargs)
_canonical_rules = {
'machine_list': lambda value: [config.canonical() for config in value]
}
_validation_rules = {
'platform': lambda value: (value == 'remote', 'cannot be modified')
}
@@ -4,10 +4,23 @@
from dataclasses import dataclass
from typing import Optional

-from .common import SharedStorageConfig
+from .base import ConfigBase
+from .utils import PathLike

__all__ = ['NfsConfig', 'AzureBlobConfig']

+@dataclass(init=False)
+class SharedStorageConfig(ConfigBase):
+    storage_type: str
+    local_mount_point: PathLike
+    remote_mount_point: str
+    local_mounted: str
+    storage_account_name: Optional[str] = None
+    storage_account_key: Optional[str] = None
+    container_name: Optional[str] = None
+    nfs_server: Optional[str] = None
+    exported_directory: Optional[str] = None

@dataclass(init=False)
class NfsConfig(SharedStorageConfig):
    storage_type: str = 'NFS'
@@ -19,5 +32,4 @@ class AzureBlobConfig(SharedStorageConfig):
    storage_type: str = 'AzureBlob'
    storage_account_name: str
    storage_account_key: Optional[str] = None
-    resource_group_name: Optional[str] = None
    container_name: str
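For illustration only (all values are placeholders): a shared storage section maps onto these classes roughly as follows, with ``NfsConfig`` pre-filling ``storage_type`` and only the NFS-related optional fields used.

from nni.experiment.config import NfsConfig

storage = NfsConfig(
    local_mount_point='/mnt/nfs/nni',    # placeholder local path
    remote_mount_point='/exported/nni',  # placeholder path on the NFS server
    local_mounted='usermount',           # see NNI shared storage docs for accepted values
    nfs_server='10.0.0.10',              # placeholder server address
    exported_directory='/exported/nni',
)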
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
``TrainingServiceConfig`` class.
Docstrings in this file are mainly for NNI contributors, or training service authors.
"""
__all__ = ['TrainingServiceConfig']
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from .base import ConfigBase
from .utils import PathLike, is_missing
@dataclass(init=False)
class TrainingServiceConfig(ConfigBase):
"""
The base class of training service config classes.
See ``LocalConfig`` for example usage.
"""
platform: str
trial_command: str
trial_code_directory: PathLike
trial_gpu_number: Optional[int]
nni_manager_ip: Optional[str]
debug: bool
def _canonicalize(self, parents):
"""
In addition to ``ConfigBase._canonicalize()``, this overridden version will also
copy training service specific fields from ``ExperimentConfig``.
"""
shortcuts = [  # fields that can be set in the root-level config as a shortcut
'trial_command',
'trial_code_directory',
'trial_gpu_number',
'nni_manager_ip',
'debug',
]
for field_name in shortcuts:
if is_missing(getattr(self, field_name)):
value = getattr(parents[0], field_name)
setattr(self, field_name, value)
super()._canonicalize(parents)
def _validate_canonical(self):
super()._validate_canonical()
cls = type(self)
assert self.platform == cls.platform
if not Path(self.trial_code_directory).is_dir():
raise ValueError(f'{cls.__name__}: trial_code_directory "{self.trial_code_directory}" is not a directory')
assert self.trial_gpu_number is None or self.trial_gpu_number >= 0
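To make the shortcut-copying behaviour concrete, here is a hypothetical minimal training service config built on the base class above; everything named ``Demo`` is invented for the example. Fields left missing on the training service (``trial_command``, ``trial_gpu_number``, ...) are copied from the top-level ``ExperimentConfig`` during ``_canonicalize()``.

from dataclasses import dataclass

from nni.experiment.config.training_service import TrainingServiceConfig

@dataclass(init=False)
class DemoConfig(TrainingServiceConfig):   # hypothetical training service
    platform: str = 'demo'
    worker_count: int = 1                  # hypothetical extra field

    def _validate_canonical(self):
        super()._validate_canonical()      # checks platform, paths and GPU number
        assert self.worker_count > 0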
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .local import *
from .remote import *
from .openpai import *
from .k8s_storage import *
from .kubeflow import *
from .frameworkcontroller import *
from .aml import *
from .dlc import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

+"""
+Configuration for AML training service.
+
+Check the reference_ for explanation of each field.
+
+You may also want to check `AML training service doc`_.
+
+.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
+.. _AML training service doc: https://nni.readthedocs.io/en/stable/TrainingService/AMLMode.html
+"""

-from dataclasses import dataclass
-from .common import TrainingServiceConfig

__all__ = ['AmlConfig']

+from dataclasses import dataclass
+from ..training_service import TrainingServiceConfig

@dataclass(init=False)
class AmlConfig(TrainingServiceConfig):
    platform: str = 'aml'
@@ -16,7 +29,3 @@ class AmlConfig(TrainingServiceConfig):
    compute_target: str
    docker_image: str = 'msranni/nni:latest'
    max_trial_number_per_gpu: int = 1
-
-    _validation_rules = {
-        'platform': lambda value: (value == 'aml', 'cannot be modified')
-    }
@@ -3,7 +3,7 @@
from dataclasses import dataclass

-from .common import TrainingServiceConfig
+from ..training_service import TrainingServiceConfig

__all__ = ['DlcConfig']
@@ -21,7 +21,3 @@ class DlcConfig(TrainingServiceConfig):
    access_key_secret: str
    local_storage_mount_point: str
    container_storage_mount_point: str
-
-    _validation_rules = {
-        'platform': lambda value: (value == 'dlc', 'cannot be modified')
-    }
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

+"""
+Configuration for FrameworkController training service.
+
+Check the reference_ for explanation of each field.
+
+You may also want to check `FrameworkController training service doc`_.
+
+.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
+.. _FrameworkController training service doc: https://nni.readthedocs.io/en/stable/TrainingService/FrameworkControllerMode.html
+"""

-from dataclasses import dataclass
-from typing import List, Optional
-
-from .base import ConfigBase
-from .common import TrainingServiceConfig
-from . import util
-
-__all__ = [
-    'FrameworkControllerConfig',
-    'FrameworkControllerRoleConfig',
-    '_FrameworkControllerStorageConfig'
-]
+__all__ = ['FrameworkControllerConfig', 'FrameworkControllerRoleConfig', 'FrameworkAttemptCompletionPolicy']

-@dataclass(init=False)
-class _FrameworkControllerStorageConfig(ConfigBase):
-    storage_type: str
-    server: Optional[str] = None
-    path: Optional[str] = None
-    azure_account: Optional[str] = None
-    azure_share: Optional[str] = None
-    key_vault_name: Optional[str] = None
-    key_vault_key: Optional[str] = None
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+from ..base import ConfigBase
+from ..training_service import TrainingServiceConfig
+from .k8s_storage import K8sStorageConfig

@dataclass(init=False)
class FrameworkAttemptCompletionPolicy(ConfigBase):
@@ -38,25 +36,13 @@ class FrameworkControllerRoleConfig(ConfigBase):
    command: str
    gpu_number: int
    cpu_number: int
-    memory_size: str
+    memory_size: Union[str, int]
    framework_attempt_completion_policy: FrameworkAttemptCompletionPolicy

@dataclass(init=False)
class FrameworkControllerConfig(TrainingServiceConfig):
    platform: str = 'frameworkcontroller'
-    service_account_name: str
-    storage: _FrameworkControllerStorageConfig
-    task_roles: List[FrameworkControllerRoleConfig]
-    reuse_mode: Optional[bool] = True  # set reuse mode as true for v2 config
+    storage: K8sStorageConfig
    service_account_name: Optional[str]
+    task_roles: List[FrameworkControllerRoleConfig]
+    reuse_mode: Optional[bool] = True

-    def __init__(self, **kwargs):
-        kwargs = util.case_insensitive(kwargs)
-        kwargs['storage'] = util.load_config(_FrameworkControllerStorageConfig, kwargs.get('storage'))
-        kwargs['taskroles'] = util.load_config(FrameworkControllerRoleConfig, kwargs.get('taskroles'))
-        super().__init__(**kwargs)
-
-    _validation_rules = {
-        'platform': lambda value: (value == 'frameworkcontroller', 'cannot be modified')
-    }
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Storage config classes for ``KubeflowConfig`` and ``FrameworkControllerConfig``
"""
__all__ = ['K8sStorageConfig', 'K8sAzureStorageConfig', 'K8sNfsConfig']
from dataclasses import dataclass
from typing import Optional
from ..base import ConfigBase
@dataclass(init=False)
class K8sStorageConfig(ConfigBase):
storage_type: str
azure_account: Optional[str] = None
azure_share: Optional[str] = None
key_vault_name: Optional[str] = None
key_vault_key: Optional[str] = None
server: Optional[str] = None
path: Optional[str] = None
def _validate_canonical(self):
super()._validate_canonical()
if self.storage_type == 'azureStorage':
assert self.server is None and self.path is None
elif self.storage_type == 'nfs':
assert self.azure_account is None and self.azure_share is None
assert self.key_vault_name is None and self.key_vault_key is None
else:
raise ValueError(f'Kubernetes storage_type ("{self.storage_type}") must either be "azureStorage" or "nfs"')
@dataclass(init=False)
class K8sNfsConfig(K8sStorageConfig):
storage: str = 'nfs'
server: str
path: str
@dataclass(init=False)
class K8sAzureStorageConfig(K8sStorageConfig):
storage: str = 'azureStorage'
azure_account: str
azure_share: str
key_vault_name: str
key_vault_key: str
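A short sketch of how the two storage flavours above are meant to be filled in (account names, shares and server addresses are placeholders); ``_validate_canonical()`` rejects a config that mixes NFS fields with Azure fields.

from nni.experiment.config.training_services import K8sStorageConfig

nfs_storage = K8sStorageConfig(storage_type='nfs', server='10.0.0.10', path='/exported/nni')
nfs_storage.validate()    # passes: only NFS fields are set

azure_storage = K8sStorageConfig(
    storage_type='azureStorage',
    azure_account='myaccount',     # placeholder storage account
    azure_share='nni-share',       # placeholder file share
    key_vault_name='myvault',      # key vault holding the storage key
    key_vault_key='storage-key',
)
azure_storage.validate()  # passes: only Azure fields are set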