Unverified Commit d5857823 authored by liuzhe-lz, committed by GitHub

Config refactor (#4370)

parent cb090e8c
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from dataclasses import dataclass
-from typing import Optional
-
-from .base import ConfigBase
-from .common import TrainingServiceConfig
-from . import util
-
-__all__ = ['KubeflowConfig', 'KubeflowRoleConfig', 'KubeflowStorageConfig', 'KubeflowNfsConfig', 'KubeflowAzureStorageConfig']
-
-@dataclass(init=False)
-class KubeflowStorageConfig(ConfigBase):
-    storage_type: str
-    server: Optional[str] = None
-    path: Optional[str] = None
-    azure_account: Optional[str] = None
-    azure_share: Optional[str] = None
-    key_vault_name: Optional[str] = None
-    key_vault_key: Optional[str] = None
-
-@dataclass(init=False)
-class KubeflowNfsConfig(KubeflowStorageConfig):
-    storage: str = 'nfs'
-    server: str
-    path: str
-
-@dataclass(init=False)
-class KubeflowAzureStorageConfig(ConfigBase):
-    storage: str = 'azureStorage'
-    azure_account: str
-    azure_share: str
-    key_vault_name: str
-    key_vault_key: str
+"""
+Configuration for Kubeflow training service.
+
+Check the reference_ for explanation of each field.
+
+You may also want to check `Kubeflow training service doc`_.
+
+.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
+.. _Kubeflow training service doc: https://nni.readthedocs.io/en/stable/TrainingService/KubeflowMode.html
+"""
+
+__all__ = ['KubeflowConfig', 'KubeflowRoleConfig']
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+from ..base import ConfigBase
+from ..training_service import TrainingServiceConfig
+from .k8s_storage import K8sStorageConfig

 @dataclass(init=False)
 class KubeflowRoleConfig(ConfigBase):
...@@ -42,31 +29,21 @@ class KubeflowRoleConfig(ConfigBase):
     command: str
     gpu_number: Optional[int] = 0
     cpu_number: int
-    memory_size: str
+    memory_size: Union[str, int]
     docker_image: str = 'msranni/nni:latest'
     code_directory: str

 @dataclass(init=False)
 class KubeflowConfig(TrainingServiceConfig):
     platform: str = 'kubeflow'
     operator: str
     api_version: str
-    storage: KubeflowStorageConfig
+    storage: K8sStorageConfig
     worker: Optional[KubeflowRoleConfig] = None
     ps: Optional[KubeflowRoleConfig] = None
     master: Optional[KubeflowRoleConfig] = None
     reuse_mode: Optional[bool] = True  # set reuse mode as true for v2 config

-    def __init__(self, **kwargs):
-        kwargs = util.case_insensitive(kwargs)
-        kwargs['storage'] = util.load_config(KubeflowStorageConfig, kwargs.get('storage'))
-        kwargs['worker'] = util.load_config(KubeflowRoleConfig, kwargs.get('worker'))
-        kwargs['ps'] = util.load_config(KubeflowRoleConfig, kwargs.get('ps'))
-        kwargs['master'] = util.load_config(KubeflowRoleConfig, kwargs.get('master'))
-        super().__init__(**kwargs)
-
-    _validation_rules = {
-        'platform': lambda value: (value == 'kubeflow', 'cannot be modified'),
-        'operator': lambda value: value in ['tf-operator', 'pytorch-operator']
-    }
+    def _validate_canonical(self):
+        super()._validate_canonical()
+        assert self.operator in ['tf-operator', 'pytorch-operator']
\ No newline at end of file
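For orientation, here is a minimal sketch of how the refactored Kubeflow service might be configured from Python. It is not taken from this commit: `ExperimentConfig('kubeflow')` follows the string overload used by the Experiment class later in this diff, and the storage sub-field names are assumptions carried over from the removed NFS config, so they may differ from the actual K8sStorageConfig.

from nni.experiment.config import ExperimentConfig

# Hypothetical sketch; dicts assigned to config fields are expected to be
# converted to config objects during canonicalization.
config = ExperimentConfig('kubeflow')
config.trial_command = 'python3 mnist.py'
config.trial_code_directory = '.'
config.trial_concurrency = 1
config.training_service.operator = 'tf-operator'   # validated against ['tf-operator', 'pytorch-operator']
config.training_service.api_version = 'v1'
config.training_service.storage = {'storage_type': 'nfs', 'server': '10.0.0.1', 'path': '/exported/nni'}
config.training_service.worker = {'command': 'python3 mnist.py', 'code_directory': '.',
                                  'cpu_number': 1, 'gpu_number': 0, 'memory_size': '8gb'}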
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Configuration for local training service.
Check the reference_ for explanation of each field.
You may also want to check `local training service doc`_.
.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
.. _local training service doc: https://nni.readthedocs.io/en/stable/TrainingService/LocalMode.html
"""
__all__ = ['LocalConfig']
from dataclasses import dataclass
from typing import List, Optional, Union
from ..training_service import TrainingServiceConfig
from .. import utils
@dataclass(init=False)
class LocalConfig(TrainingServiceConfig):
platform: str = 'local'
use_active_gpu: Optional[bool] = None
max_trial_number_per_gpu: int = 1
gpu_indices: Union[List[int], int, str, None] = None
reuse_mode: bool = False
def _canonicalize(self, parents):
super()._canonicalize(parents)
self.gpu_indices = utils.canonical_gpu_indices(self.gpu_indices)
self.nni_manager_ip = None
def _validate_canonical(self):
super()._validate_canonical()
utils.validate_gpu_indices(self.gpu_indices)
if self.trial_gpu_number and self.use_active_gpu is None:
raise ValueError(
'LocalConfig: please set use_active_gpu to True if your system has GUI, '
'or set it to False if the computer runs multiple experiments concurrently.'
)
if not self.trial_gpu_number and self.max_trial_number_per_gpu != 1:
raise ValueError('LocalConfig: max_trial_number_per_gpu does not work without trial_gpu_number')
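A short usage sketch of the validation rules above, hedged as an illustration only: once trial_gpu_number is set, use_active_gpu must be given explicitly, and a comma-separated gpu_indices string is canonicalized to a list of ints.

from nni.experiment.config import ExperimentConfig

# Minimal sketch, assuming ExperimentConfig('local') builds a LocalConfig training service.
config = ExperimentConfig('local')
config.trial_command = 'python3 trial.py'
config.trial_code_directory = '.'
config.trial_concurrency = 1
config.trial_gpu_number = 1
config.training_service.use_active_gpu = False   # required once trial_gpu_number > 0
config.training_service.gpu_indices = '0,1'      # canonicalized to [0, 1]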
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from dataclasses import dataclass
-from pathlib import Path, PurePosixPath
-from typing import Any, Dict, Optional
-
-from .base import PathLike
-from .common import TrainingServiceConfig
-from . import util
+"""
+Configuration for OpenPAI training service.
+
+Check the reference_ for explanation of each field.
+
+You may also want to check `OpenPAI training service doc`_.
+
+.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
+.. _OpenPAI training service doc: https://nni.readthedocs.io/en/stable/TrainingService/PaiMode.html
+"""

 __all__ = ['OpenpaiConfig']

+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from ..training_service import TrainingServiceConfig
+from ..utils import PathLike

 @dataclass(init=False)
 class OpenpaiConfig(TrainingServiceConfig):
     platform: str = 'openpai'
...@@ -18,7 +30,7 @@ class OpenpaiConfig(TrainingServiceConfig):
     username: str
     token: str
     trial_cpu_number: int
-    trial_memory_size: str
+    trial_memory_size: Union[str, int]
     storage_config_name: str
     docker_image: str = 'msranni/nni:latest'
     virtual_cluster: Optional[str]
...@@ -26,23 +38,23 @@ class OpenpaiConfig(TrainingServiceConfig):
     container_storage_mount_point: str
     reuse_mode: bool = True

-    openpai_config: Optional[Dict[str, Any]] = None
+    openpai_config: Optional[Dict] = None
     openpai_config_file: Optional[PathLike] = None

-    _canonical_rules = {
-        'host': lambda value: 'https://' + value if '://' not in value else value,  # type: ignore
-        'local_storage_mount_point': util.canonical_path,
-        'openpai_config_file': util.canonical_path
-    }
-
-    _validation_rules = {
-        'platform': lambda value: (value == 'openpai', 'cannot be modified'),
-        'local_storage_mount_point': lambda value: Path(value).is_dir(),
-        'container_storage_mount_point': lambda value: (PurePosixPath(value).is_absolute(), 'is not absolute'),
-        'openpai_config_file': lambda value: Path(value).is_file()
-    }
-
-    def validate(self) -> None:
-        super().validate()
+    def _canonicalize(self, parents):
+        super()._canonicalize(parents)
+        if '://' not in self.host:
+            self.host = 'https://' + self.host
+
+    def _validate_canonical(self) -> None:
+        super()._validate_canonical()
+        if self.trial_gpu_number is None:
+            raise ValueError('OpenpaiConfig: trial_gpu_number is not set')
+        if not Path(self.local_storage_mount_point).is_dir():
+            raise ValueError(
+                f'OpenpaiConfig: local_storage_mount_point "{self.local_storage_mount_point}" is not a directory'
+            )
         if self.openpai_config is not None and self.openpai_config_file is not None:
             raise ValueError('openpai_config and openpai_config_file can only be set one')
+        if self.openpai_config_file is not None and not Path(self.openpai_config_file).is_file():
+            raise ValueError(f'OpenpaiConfig: openpai_config_file "{self.openpai_config_file}" is not a file')
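The `_canonicalize` hook above only prepends a scheme when the host has none; a standalone restatement of that rule, separate from the config class itself:

# Mirrors the host-normalization logic shown above.
def canonical_host(host: str) -> str:
    return host if '://' in host else 'https://' + host

assert canonical_host('pai.example.com') == 'https://pai.example.com'
assert canonical_host('http://10.0.0.2:8080') == 'http://10.0.0.2:8080'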
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Configuration for remote training service.
Check the reference_ for explanation of each field.
You may also want to check `remote training service doc`_.
.. _reference: https://nni.readthedocs.io/en/stable/reference/experiment_config.html
.. _remote training service doc: https://nni.readthedocs.io/en/stable/TrainingService/RemoteMachineMode.html
"""
__all__ = ['RemoteConfig', 'RemoteMachineConfig']
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union
import warnings
from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
from .. import utils
@dataclass(init=False)
class RemoteMachineConfig(ConfigBase):
host: str
port: int = 22
user: str
password: Optional[str] = None
ssh_key_file: Optional[utils.PathLike] = '~/.ssh/id_rsa'
ssh_passphrase: Optional[str] = None
use_active_gpu: bool = False
max_trial_number_per_gpu: int = 1
gpu_indices: Union[List[int], int, str, None] = None
python_path: Optional[str] = None
def _canonicalize(self, parents):
super()._canonicalize(parents)
if self.password is not None:
self.ssh_key_file = None
self.gpu_indices = utils.canonical_gpu_indices(self.gpu_indices)
def _validate_canonical(self):
super()._validate_canonical()
assert 0 < self.port < 65536
assert self.max_trial_number_per_gpu > 0
utils.validate_gpu_indices(self.gpu_indices)
if self.password is not None:
warnings.warn('SSH password will be exposed in web UI as plain text. We recommend to use SSH key file.')
elif not Path(self.ssh_key_file).is_file():
raise ValueError(
f'RemoteMachineConfig: You must either provide password or a valid SSH key file "{self.ssh_key_file}"'
)
@dataclass(init=False)
class RemoteConfig(TrainingServiceConfig):
platform: str = 'remote'
machine_list: List[RemoteMachineConfig]
reuse_mode: bool = True
def _validate_canonical(self):
super()._validate_canonical()
if not self.machine_list:
raise ValueError(f'RemoteConfig: must provide at least one machine in machine_list')
if not self.trial_gpu_number and any(machine.max_trial_number_per_gpu != 1 for machine in self.machine_list):
raise ValueError('RemoteConfig: max_trial_number_per_gpu does not work without trial_gpu_number')
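A hedged sketch of a remote training service config. The dict-to-RemoteMachineConfig conversion relies on the type-guessing helpers in utils shown later in this diff, so treat the exact behavior as an assumption; with a password set instead of a key file, _canonicalize() clears ssh_key_file so only one credential survives.

from nni.experiment.config import ExperimentConfig

# Sketch only; machine dicts are expected to be converted to RemoteMachineConfig
# objects during canonicalization.
config = ExperimentConfig('remote')
config.trial_command = 'python3 trial.py'
config.trial_code_directory = '.'
config.trial_concurrency = 2
config.training_service.machine_list = [
    {'host': '10.0.0.5', 'user': 'nni', 'ssh_key_file': '~/.ssh/id_rsa'},
]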
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Miscellaneous utility functions.
"""
import importlib
import json
import math
import os.path
from pathlib import Path
from typing import Any, Dict, Optional, Union, List
import nni.runtime.config
PathLike = Union[Path, str]
def case_insensitive(key_or_kwargs: Union[str, Dict[str, Any]]) -> Union[str, Dict[str, Any]]:
if isinstance(key_or_kwargs, str):
return key_or_kwargs.lower().replace('_', '')
else:
return {key.lower().replace('_', ''): value for key, value in key_or_kwargs.items()}
def camel_case(key: str) -> str:
words = key.strip('_').split('_')
return words[0] + ''.join(word.title() for word in words[1:])
def canonical_path(path: Optional[PathLike]) -> Optional[str]:
# Path.resolve() does not work on Windows when the file does not exist, so use os.path instead
return os.path.abspath(os.path.expanduser(path)) if path is not None else None
def count(*values) -> int:
return sum(value is not None and value is not False for value in values)
def training_service_config_factory(
platform: Union[str, List[str]] = None,
config: Union[List, Dict] = None,
base_path: Optional[Path] = None): # -> TrainingServiceConfig
from .common import TrainingServiceConfig
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
custom_ts_config_path = nni.runtime.config.get_config_file('training_services.json')
custom_ts_config = json.load(custom_ts_config_path.open())
for custom_ts_pkg in custom_ts_config.keys():
pkg = importlib.import_module(custom_ts_pkg)
_config_class = pkg.nni_training_service_info.config_class
ts_configs = []
if platform is not None:
assert config is None
platforms = platform if isinstance(platform, list) else [platform]
for cls in TrainingServiceConfig.__subclasses__():
if cls.platform in platforms:
ts_configs.append(cls())
if len(ts_configs) < len(platforms):
bad = ', '.join(set(platforms) - set(ts_configs))
raise RuntimeError(f'Bad training service platform: {bad}')
else:
assert config is not None
supported_platforms = {cls.platform: cls for cls in TrainingServiceConfig.__subclasses__()}
configs = config if isinstance(config, list) else [config]
for conf in configs:
if conf['platform'] not in supported_platforms:
raise RuntimeError(f'Unrecognized platform {conf["platform"]}')
ts_configs.append(supported_platforms[conf['platform']](_base_path=base_path, **conf))
return ts_configs if len(ts_configs) > 1 else ts_configs[0]
def load_config(Type, value):
if isinstance(value, list):
return [load_config(Type, item) for item in value]
if isinstance(value, dict):
return Type(**value)
return value
def strip_optional(type_hint):
return type_hint.__args__[0] if str(type_hint).startswith('typing.Optional[') else type_hint
def parse_time(time: str, target_unit: str = 's') -> int:
return _parse_unit(time.lower(), target_unit, _time_units)
def parse_size(size: str, target_unit: str = 'mb') -> int:
return _parse_unit(size.lower(), target_unit, _size_units)
_time_units = {'d': 24 * 3600, 'h': 3600, 'm': 60, 's': 1}
_size_units = {'gb': 1024 * 1024 * 1024, 'mb': 1024 * 1024, 'kb': 1024}
def _parse_unit(string, target_unit, all_units):
for unit, factor in all_units.items():
if string.endswith(unit):
number = string[:-len(unit)]
value = float(number) * factor
return math.ceil(value / all_units[target_unit])
raise ValueError(f'Unsupported unit in "{string}"')
def canonical_gpu_indices(indices: Union[List[int], str, int, None]) -> Optional[List[int]]:
if isinstance(indices, str):
return [int(idx) for idx in indices.split(',')]
if isinstance(indices, int):
return [indices]
return indices
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Utility functions for experiment config classes.
Check "public.py" to see which functions you can utilize.
"""
from .public import *
from .internal import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Utility functions for experiment config classes, internal part.
If you are implementing a config class for a training service, it's unlikely you will need these.
"""
import dataclasses
import importlib
import json
import os.path
from pathlib import Path
import socket
import typeguard
import nni.runtime.config
from .public import is_missing
## handle relative path ##
_current_base_path = None
def get_base_path():
if _current_base_path is None:
return Path()
return _current_base_path
def set_base_path(path):
global _current_base_path
assert _current_base_path is None
_current_base_path = path
def unset_base_path():
global _current_base_path
_current_base_path = None
def resolve_path(path, base_path):
if path is None:
return None
# Path.resolve() does not work on Windows when the file does not exist, so use os.path instead
path = os.path.expanduser(path)
if not os.path.isabs(path):
path = os.path.join(base_path, path)
return str(os.path.realpath(path))  # it should already be str, but the official doc does not specify its type
## field name case conversion ##
def case_insensitive(key):
return key.lower().replace('_', '')
def camel_case(key):
words = key.strip('_').split('_')
return words[0] + ''.join(word.title() for word in words[1:])
## type hint utils ##
def is_instance(value, type_hint):
try:
typeguard.check_type('_', value, type_hint)
except TypeError:
return False
return True
def validate_type(config):
class_name = type(config).__name__
for field in dataclasses.fields(config):
value = getattr(config, field.name)
# check existence
if is_missing(value):
raise ValueError(f'{class_name}: {field.name} is not set')
if not is_instance(value, field.type):
raise ValueError(f'{class_name}: type of {field.name} ({repr(value)}) is not {field.type}')
def is_path_like(type_hint):
# only `PathLike` and `Any` accepts `Path`; check `int` to make sure it's not `Any`
return is_instance(Path(), type_hint) and not is_instance(1, type_hint)
## type inference ##
def guess_config_type(obj, type_hint):
ret = guess_list_config_type([obj], type_hint, _hint_list_item=True)
return ret[0] if ret else None
def guess_list_config_type(objs, type_hint, _hint_list_item=False):
# avoid circular import
from ..base import ConfigBase
from ..training_service import TrainingServiceConfig
# because __init__ of subclasses might be complex, we first create empty objects to determine type
candidate_classes = []
for cls in _all_subclasses(ConfigBase):
if issubclass(cls, TrainingServiceConfig): # training service configs are specially handled
continue
empty_list = [cls.__new__(cls)]
if _hint_list_item:
good_type = is_instance(empty_list[0], type_hint)
else:
good_type = is_instance(empty_list, type_hint)
if good_type:
candidate_classes.append(cls)
if not candidate_classes: # it does not accept config type
return None
if len(candidate_classes) == 1: # the type is confirmed, raise error if cannot convert to this type
return [candidate_classes[0](**obj) for obj in objs]
# multiple candidates available, call __init__ to further verify
candidate_configs = []
for cls in candidate_classes:
try:
configs = [cls(**obj) for obj in objs]
except Exception:
continue
candidate_configs.append(configs)
if not candidate_configs:
return None
if len(candidate_configs) == 1:
return candidate_configs[0]
# still have multiple candidates, choose the common base class
for base in candidate_configs:
base_class = type(base[0])
is_base = all(isinstance(configs[0], base_class) for configs in candidate_configs)
if is_base:
return base
return None # cannot detect the type, give up
def _all_subclasses(cls):
subclasses = set(cls.__subclasses__())
return subclasses.union(*[_all_subclasses(subclass) for subclass in subclasses])
def training_service_config_factory(platform):
cls = _get_ts_config_class(platform)
if cls is None:
raise ValueError(f'Bad training service platform: {platform}')
return cls()
def load_training_service_config(config):
if isinstance(config, dict) and 'platform' in config:
cls = _get_ts_config_class(config['platform'])
if cls is not None:
return cls(**config)
return config # not valid json, don't touch
def _get_ts_config_class(platform):
from ..training_service import TrainingServiceConfig # avoid circular import
# import all custom config classes so they can be found in TrainingServiceConfig.__subclasses__()
custom_ts_config_path = nni.runtime.config.get_config_file('training_services.json')
with custom_ts_config_path.open() as config_file:
custom_ts_config = json.load(config_file)
for custom_ts_pkg in custom_ts_config.keys():
pkg = importlib.import_module(custom_ts_pkg)
_config_class = pkg.nni_training_service_info.config_class
for cls in TrainingServiceConfig.__subclasses__():
if cls.platform == platform:
return cls
return None
## misc ##
def get_ipv4_address():
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(('192.0.2.0', 80))
addr = s.getsockname()[0]
s.close()
return addr
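The two name-conversion helpers above are small enough to pin down with a doctest-style check (the import path is assumed from the package layout introduced in this diff):

from nni.experiment.config.utils.internal import camel_case, case_insensitive

# snake_case field names map to camelCase YAML/JSON keys and back via a
# case- and underscore-insensitive comparison.
assert camel_case('trial_gpu_number') == 'trialGpuNumber'
assert case_insensitive('trialGpuNumber') == case_insensitive('trial_gpu_number') == 'trialgpunumber'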
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Utility functions for experiment config classes.
"""
import dataclasses
import math
from pathlib import Path
from typing import Union
PathLike = Union[Path, str]
def is_missing(value):
"""
Used to check whether a dataclass field has ever been assigned.
If a field without default value has never been assigned, it will have a special value ``MISSING``.
This function checks if the parameter is ``MISSING``.
"""
# MISSING is not singleton and there is no official API to check it
return isinstance(value, type(dataclasses.MISSING))
def canonical_gpu_indices(indices):
"""
If ``indices`` is not None, cast it to list of int.
"""
if isinstance(indices, str):
return [int(idx) for idx in indices.split(',')]
if isinstance(indices, int):
return [indices]
return indices
def validate_gpu_indices(indices):
if indices is None:
return
if len(set(indices)) != len(indices):
raise ValueError(f'Duplication detected in GPU indices {indices}')
if any(idx < 0 for idx in indices):
raise ValueError(f'Negative detected in GPU indices {indices}')
def parse_time(value):
"""
If ``value`` is a string, convert it to an integral number of seconds.
"""
return _parse_unit(value, 's', _time_units)
def parse_memory_size(value):
"""
If ``value`` is a string, convert it to an integral number of megabytes.
"""
return _parse_unit(value, 'mb', _size_units)
_time_units = {'d': 24 * 3600, 'h': 3600, 'm': 60, 's': 1}
_size_units = {'tb': 1024 ** 4, 'gb': 1024 ** 3, 'mb': 1024 ** 2, 'kb': 1024, 'b': 1}
def _parse_unit(value, target_unit, all_units):
if not isinstance(value, str):
return value
value = value.lower()
for unit, factor in all_units.items():
if value.endswith(unit):
number = value[:-len(unit)]
value = float(number) * factor
return math.ceil(value / all_units[target_unit])
supported_units = ', '.join(all_units.keys())
raise ValueError(f'Bad unit in "{value}", supported units are {supported_units}')
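The helpers above are deterministic, so their behavior can be illustrated with a few checks; the values follow directly from the unit tables, and only the import path is assumed:

from nni.experiment.config.utils import parse_time, parse_memory_size, canonical_gpu_indices

assert parse_time('1h') == 3600           # seconds
assert parse_time(30) == 30               # non-strings pass through unchanged
assert parse_memory_size('8gb') == 8192   # megabytes
assert parse_memory_size('512mb') == 512
assert canonical_gpu_indices('0,1,2') == [0, 1, 2]
assert canonical_gpu_indices(1) == [1]
assert canonical_gpu_indices(None) is None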
 import atexit
+from enum import Enum
 import logging
 from pathlib import Path
 import socket
...@@ -12,7 +13,7 @@ import psutil

 import nni.runtime.log
 from nni.common import dump
-from .config import ExperimentConfig, AlgorithmConfig
+from .config import ExperimentConfig
 from .data import TrialJob, TrialMetricData, TrialResult
 from . import launcher
 from . import management
...@@ -21,6 +22,17 @@ from ..tools.nnictl.command_utils import kill_command

 _logger = logging.getLogger('nni.experiment')

+class RunMode(Enum):
+    """
+    Configures the lifecycle and output redirection of the NNI manager process.
+
+    - Background: stop NNI manager when Python script exits; do not print NNI manager log. (default)
+    - Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
+    - Detach: do not stop NNI manager when Python script exits.
+    """
+    Background = 'background'
+    Foreground = 'foreground'
+    Detach = 'detach'

 class Experiment:
     """
...@@ -73,21 +85,19 @@ class Experiment:
         nni.runtime.log.init_logger_experiment()

         self.config: Optional[ExperimentConfig] = None
-        self.id: Optional[str] = None
+        self.id: str = management.generate_experiment_id()
         self.port: Optional[int] = None
         self._proc: Optional[Popen] = None
         self.mode = 'new'
+        self.url_prefix: Optional[str] = None

         args = [config, training_service]  # deal with overloading
         if isinstance(args[0], (str, list)):
             self.config = ExperimentConfig(args[0])
-            self.config.tuner = AlgorithmConfig(name='_none_', class_args={})
-            self.config.assessor = AlgorithmConfig(name='_none_', class_args={})
-            self.config.advisor = AlgorithmConfig(name='_none_', class_args={})
         else:
             self.config = args[0]

-    def start(self, port: int = 8080, debug: bool = False) -> None:
+    def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
         """
         Start the experiment in background.
...@@ -101,25 +111,25 @@ class Experiment:
         debug
             Whether to start in debug mode.
         """
-        atexit.register(self.stop)
+        if run_mode is not RunMode.Detach:
+            atexit.register(self.stop)

-        if self.mode == 'new':
-            self.id = management.generate_experiment_id()
-        else:
-            self.config = launcher.get_stopped_experiment_config(self.id, self.mode)
-
-        if self.config.experiment_working_directory is not None:
-            log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
-        else:
+        config = self.config.canonical_copy()
+        if config.use_annotation:
+            raise RuntimeError('NNI annotation is not supported by Python experiment API.')
+
+        if config.experiment_working_directory is not None:
+            log_dir = Path(config.experiment_working_directory, self.id, 'log')
+        else:  # this should never happen in latest version, keep it until v2.7 for potential compatibility
             log_dir = Path.home() / f'nni-experiments/{self.id}/log'
         nni.runtime.log.start_experiment_log(self.id, log_dir, debug)

-        self._proc = launcher.start_experiment(self.id, self.config, port, debug, mode=self.mode)
+        self._proc = launcher.start_experiment(self.mode, self.id, config, port, debug, run_mode, self.url_prefix)
         assert self._proc is not None

         self.port = port  # port will be None if start up failed

-        ips = [self.config.nni_manager_ip]
+        ips = [config.nni_manager_ip]
         for interfaces in psutil.net_if_addrs().values():
             for interface in interfaces:
                 if interface.family == socket.AF_INET:
...@@ -135,11 +145,10 @@ class Experiment:
         _logger.info('Stopping experiment, please wait...')
         atexit.unregister(self.stop)

-        if self.id is not None:
-            nni.runtime.log.stop_experiment_log(self.id)
+        nni.runtime.log.stop_experiment_log(self.id)
         if self._proc is not None:
             try:
-                rest.delete(self.port, '/experiment')
+                rest.delete(self.port, '/experiment', self.url_prefix)
             except Exception as e:
                 _logger.exception(e)
                 _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
...@@ -197,8 +206,8 @@ class Experiment:
         _logger.info('Connect to port %d success, experiment id is %s, status is %s.', port, experiment.id, status)
         return experiment

-    @classmethod
-    def resume(cls, experiment_id: str, port: int = 8080, wait_completion: bool = True, debug: bool = False):
+    @staticmethod
+    def resume(experiment_id: str, port: int = 8080, wait_completion: bool = True, debug: bool = False):
         """
         Resume a stopped experiment.
...@@ -213,15 +222,13 @@ class Experiment:
         debug
             Whether to start in debug mode.
         """
-        experiment = Experiment()
-        experiment.id = experiment_id
-        experiment.mode = 'resume'
+        experiment = Experiment._resume(experiment_id)
         experiment.run(port=port, wait_completion=wait_completion, debug=debug)
         if not wait_completion:
             return experiment

-    @classmethod
-    def view(cls, experiment_id: str, port: int = 8080, non_blocking: bool = False):
+    @staticmethod
+    def view(experiment_id: str, port: int = 8080, non_blocking: bool = False):
         """
         View a stopped experiment.
...@@ -234,11 +241,8 @@ class Experiment:
         non_blocking
             If false, run in the foreground. If true, run in the background.
         """
-        debug = False
-        experiment = Experiment()
-        experiment.id = experiment_id
-        experiment.mode = 'view'
-        experiment.start(port=port, debug=debug)
+        experiment = Experiment._view(experiment_id)
+        experiment.start(port=port, debug=False)
         if non_blocking:
             return experiment
         else:
...@@ -250,6 +254,22 @@ class Experiment:
             finally:
                 experiment.stop()

+    @staticmethod
+    def _resume(exp_id, exp_dir=None):
+        exp = Experiment()
+        exp.id = exp_id
+        exp.mode = 'resume'
+        exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
+        return exp
+
+    @staticmethod
+    def _view(exp_id, exp_dir=None):
+        exp = Experiment()
+        exp.id = exp_id
+        exp.mode = 'view'
+        exp.config = launcher.get_stopped_experiment_config(exp_id, exp_dir)
+        return exp
+
     def get_status(self) -> str:
         """
         Return experiment status as a str.
...@@ -259,7 +279,7 @@ class Experiment:
         str
             Experiment status.
         """
-        resp = rest.get(self.port, '/check-status')
+        resp = rest.get(self.port, '/check-status', self.url_prefix)
         return resp['status']

     def get_trial_job(self, trial_job_id: str):
...@@ -276,7 +296,7 @@ class Experiment:
         TrialJob
             A `TrialJob` instance corresponding to `trial_job_id`.
         """
-        resp = rest.get(self.port, '/trial-jobs/{}'.format(trial_job_id))
+        resp = rest.get(self.port, '/trial-jobs/{}'.format(trial_job_id), self.url_prefix)
         return TrialJob(**resp)

     def list_trial_jobs(self):
...@@ -288,7 +308,7 @@ class Experiment:
         list
             List of `TrialJob`.
         """
-        resp = rest.get(self.port, '/trial-jobs')
+        resp = rest.get(self.port, '/trial-jobs', self.url_prefix)
         return [TrialJob(**trial_job) for trial_job in resp]

     def get_job_statistics(self):
...@@ -300,7 +320,7 @@ class Experiment:
         dict
             Job statistics information.
         """
-        resp = rest.get(self.port, '/job-statistics')
+        resp = rest.get(self.port, '/job-statistics', self.url_prefix)
         return resp

     def get_job_metrics(self, trial_job_id=None):
...@@ -318,7 +338,7 @@ class Experiment:
             Each key is a trialJobId, the corresponding value is a list of `TrialMetricData`.
         """
         api = '/metric-data/{}'.format(trial_job_id) if trial_job_id else '/metric-data'
-        resp = rest.get(self.port, api)
+        resp = rest.get(self.port, api, self.url_prefix)
         metric_dict = {}
         for metric in resp:
             trial_id = metric["trialJobId"]
...@@ -337,7 +357,7 @@ class Experiment:
         dict
             The profile of the experiment.
         """
-        resp = rest.get(self.port, '/experiment')
+        resp = rest.get(self.port, '/experiment', self.url_prefix)
         return resp

     def get_experiment_metadata(self, exp_id: str):
...@@ -364,7 +384,7 @@ class Experiment:
         list
             The experiments metadata.
         """
-        resp = rest.get(self.port, '/experiments-info')
+        resp = rest.get(self.port, '/experiments-info', self.url_prefix)
         return resp

     def export_data(self):
...@@ -376,7 +396,7 @@ class Experiment:
         list
             List of `TrialResult`.
         """
-        resp = rest.get(self.port, '/export-data')
+        resp = rest.get(self.port, '/export-data', self.url_prefix)
         return [TrialResult(**trial_result) for trial_result in resp]

     def _get_query_type(self, key: str):
...@@ -403,7 +423,7 @@ class Experiment:
         api = '/experiment{}'.format(self._get_query_type(key))
         experiment_profile = self.get_experiment_profile()
         experiment_profile['params'][key] = value
-        rest.put(self.port, api, experiment_profile)
+        rest.put(self.port, api, experiment_profile, self.url_prefix)
         logging.info('Successfully update %s.', key)

     def update_trial_concurrency(self, value: int):
......
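A short sketch of the new run_mode parameter. The import location of RunMode is assumed from this diff, and the rest of the config is a minimal local setup; the point is that Detach skips the atexit hook above, so the NNI manager keeps running after the script exits instead of being stopped automatically.

from nni.experiment import Experiment
from nni.experiment.experiment import RunMode   # import location assumed

experiment = Experiment('local')
experiment.config.trial_command = 'python3 trial.py'
experiment.config.trial_code_directory = '.'
experiment.config.trial_concurrency = 1
experiment.start(port=8080, run_mode=RunMode.Detach)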
...@@ -2,7 +2,10 @@
 # Licensed under the MIT license.

 import contextlib
+from dataclasses import dataclass, fields
+from datetime import datetime
 import logging
+import os.path
 from pathlib import Path
 import socket
 from subprocess import Popen
...@@ -23,29 +26,89 @@ from ..tools.nnictl.nnictl_utils import update_experiment

 _logger = logging.getLogger('nni.experiment')

-def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bool, mode: str = 'new') -> Popen:
-    proc = None
-
-    config.validate(initialized_tuner=False)
-    _ensure_port_idle(port)
-
-    if mode != 'view':
-        if isinstance(config.training_service, list):  # hybrid training service
-            _ensure_port_idle(port + 1, 'Hybrid training service requires an additional port')
-        elif config.training_service.platform in ['remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']:
-            _ensure_port_idle(port + 1, f'{config.training_service.platform} requires an additional port')
+@dataclass(init=False)
+class NniManagerArgs:
+    port: int
+    experiment_id: int
+    start_mode: str  # new or resume
+    mode: str  # training service platform
+    log_dir: str
+    log_level: str
+    readonly: bool = False
+    foreground: bool = False
+    url_prefix: Optional[str] = None
+    dispatcher_pipe: Optional[str] = None
+
+    def __init__(self, action, exp_id, config, port, debug, foreground, url_prefix):
+        self.port = port
+        self.experiment_id = exp_id
+        self.foreground = foreground
+        self.url_prefix = url_prefix
+        self.log_dir = config.experiment_working_directory
+
+        if isinstance(config.training_service, list):
+            self.mode = 'hybrid'
+        else:
+            self.mode = config.training_service.platform
+
+        self.log_level = config.log_level
+        if debug and self.log_level not in ['debug', 'trace']:
+            self.log_level = 'debug'
+
+        if action == 'resume':
+            self.start_mode = 'resume'
+        elif action == 'view':
+            self.start_mode = 'resume'
+            self.readonly = True
+        else:
+            self.start_mode = 'new'
+
+    def to_command_line_args(self):
+        ret = []
+        for field in fields(self):
+            value = getattr(self, field.name)
+            if value is not None:
+                ret.append('--' + field.name)
+                if isinstance(value, bool):
+                    ret.append(str(value).lower())
+                else:
+                    ret.append(str(value))
+        return ret
+
+def start_experiment(action, exp_id, config, port, debug, run_mode, url_prefix):
+    foreground = run_mode.value == 'foreground'
+    nni_manager_args = NniManagerArgs(action, exp_id, config, port, debug, foreground, url_prefix)
+
+    _ensure_port_idle(port)
+    websocket_platforms = ['hybrid', 'remote', 'openpai', 'kubeflow', 'frameworkcontroller', 'adl']
+    if action != 'view' and nni_manager_args.mode in websocket_platforms:
+        _ensure_port_idle(port + 1, f'{nni_manager_args.mode} requires an additional port')

+    proc = None
     try:
-        _logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
-        start_time, proc = _start_rest_server(config, port, debug, exp_id, mode=mode)
+        _logger.info(
+            'Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL
+        )
+        proc = _start_rest_server(nni_manager_args, run_mode)
+        start_time = int(time.time() * 1000)
         _logger.info('Starting web server...')
-        _check_rest_server(port)
-        platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
-        _save_experiment_information(exp_id, port, start_time, platform,
-                                     config.experiment_name, proc.pid, str(config.experiment_working_directory), [])
+        _check_rest_server(port, url_prefix=url_prefix)
+        Experiments().add_experiment(
+            exp_id,
+            port,
+            start_time,
+            nni_manager_args.mode,
+            config.experiment_name,
+            pid=proc.pid,
+            logDir=config.experiment_working_directory,
+            tag=[],
+        )
         _logger.info('Setting up...')
-        rest.post(port, '/experiment', config.json())
+        rest.post(port, '/experiment', config.json(), url_prefix)
         return proc

     except Exception as e:
...@@ -55,6 +118,33 @@ def start_experiment(exp_id: str, config: ExperimentConfig, port: int, debug: bo
             proc.kill()
         raise e

+def _start_rest_server(nni_manager_args, run_mode) -> Tuple[int, Popen]:
+    node_dir = Path(nni_node.__path__[0])
+    node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
+    main_js = str(node_dir / 'main.js')
+    cmd = [node, '--max-old-space-size=4096', main_js]
+    cmd += nni_manager_args.to_command_line_args()
+
+    if run_mode.value == 'detach':
+        log = Path(nni_manager_args.log_dir, nni_manager_args.experiment_id, 'log')
+        out = (log / 'nnictl_stdout.log').open('a')
+        err = (log / 'nnictl_stderr.log').open('a')
+        header = f'Experiment {nni_manager_args.experiment_id} start: {datetime.now()}'
+        header = '-' * 80 + '\n' + header + '\n' + '-' * 80 + '\n'
+        out.write(header)
+        err.write(header)
+    else:
+        out = None
+        err = None
+
+    if sys.platform == 'win32':
+        from subprocess import CREATE_NEW_PROCESS_GROUP
+        return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, creationflags=CREATE_NEW_PROCESS_GROUP)
+    else:
+        return Popen(cmd, stdout=out, stderr=err, cwd=node_dir, preexec_fn=os.setpgrp)
+
 def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int, debug: bool) -> Popen:
     pipe = None
     proc = None
...@@ -69,7 +159,7 @@ def start_experiment_retiarii(exp_id: str, config: ExperimentConfig, port: int,
     try:
         _logger.info('Creating experiment, Experiment ID: %s', colorama.Fore.CYAN + exp_id + colorama.Style.RESET_ALL)
         pipe = Pipe(exp_id)
-        start_time, proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
+        start_time, proc = _start_rest_server_retiarii(config, port, debug, exp_id, pipe.path)
         _logger.info('Connecting IPC pipe...')
         pipe_file = pipe.connect()
         nni.runtime.protocol._in_file = pipe_file
...@@ -101,8 +191,8 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
         raise RuntimeError(f'Port {port} is not idle {message}')

-def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str = None,
-                       mode: str = 'new') -> Tuple[int, Popen]:
+def _start_rest_server_retiarii(config: ExperimentConfig, port: int, debug: bool, experiment_id: str,
+                                pipe_path: str = None, mode: str = 'new') -> Tuple[int, Popen]:
     if isinstance(config.training_service, list):
         ts = 'hybrid'
     else:
...@@ -145,15 +235,15 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
     return int(time.time() * 1000), proc

-def _check_rest_server(port: int, retry: int = 3) -> None:
+def _check_rest_server(port: int, retry: int = 3, url_prefix: Optional[str] = None) -> None:
     for i in range(retry):
         with contextlib.suppress(Exception):
-            rest.get(port, '/check-status')
+            rest.get(port, '/check-status', url_prefix)
             return
         if i > 0:
             _logger.warning('Timeout, retry...')
         time.sleep(1)
-    rest.get(port, '/check-status')
+    rest.get(port, '/check-status', url_prefix)

 def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str,
...@@ -162,7 +252,16 @@ def _save_experiment_information(experiment_id: str, port: int, start_time: int,
     experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir, tag=tag)

-def get_stopped_experiment_config(exp_id: str, mode: str) -> None:
-    update_experiment()
-    experiments_config = Experiments()
-    experiments_dict = experiments_config.get_all_experiments()
+def get_stopped_experiment_config(exp_id, exp_dir=None):
+    if exp_dir:
+        exp_config = Config(exp_id, exp_dir).get_config()
+        config = ExperimentConfig(**exp_config)
+        if not os.path.samefile(exp_dir, config.experiment_working_directory):
+            msg = 'Experiment working directory provided in command line (%s) is different from experiment config (%s)'
+            _logger.warning(msg, exp_dir, config.experiment_working_directory)
+            config.experiment_working_directory = exp_dir
+        return config
+    else:
+        update_experiment()
+        experiments_config = Experiments()
+        experiments_dict = experiments_config.get_all_experiments()
...@@ -171,8 +270,7 @@ def get_stopped_experiment_config(exp_id: str, mode: str) -> None:
-        _logger.error('Id %s not exist!', exp_id)
-        return
-    if experiment_metadata['status'] != 'STOPPED':
-        _logger.error('Only stopped experiments can be %sed!', mode)
-        return
-    experiment_config = Config(exp_id, experiment_metadata['logDir']).get_config()
-    config = ExperimentConfig(**experiment_config)
-    return config
+            _logger.error('Id %s not exist!', exp_id)
+            return
+        if experiment_metadata['status'] != 'STOPPED':
+            _logger.error('Only stopped experiments can be resumed or viewed!')
+            return
+        experiment_config = Config(exp_id, experiment_metadata['logDir']).get_config()
+        return ExperimentConfig(**experiment_config)
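NniManagerArgs.to_command_line_args() is how the Python side hands settings to the Node.js NNI manager: every non-None field becomes a "--<field_name> <value>" pair, with booleans rendered lowercase. A standalone restatement of that convention, using a hypothetical cut-down dataclass rather than NniManagerArgs itself:

from dataclasses import dataclass, fields

@dataclass
class _Args:
    port: int = 8080
    start_mode: str = 'resume'
    readonly: bool = True
    url_prefix: str = None   # None fields are skipped entirely

def to_cli(args):
    ret = []
    for field in fields(args):
        value = getattr(args, field.name)
        if value is not None:
            ret.append('--' + field.name)
            ret.append(str(value).lower() if isinstance(value, bool) else str(value))
    return ret

assert to_cli(_Args()) == ['--port', '8080', '--start_mode', 'resume', '--readonly', 'true']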
...@@ -5,31 +5,40 @@ import requests

 _logger = logging.getLogger(__name__)

-url_template = 'http://localhost:{}/api/v1/nni{}'
 timeout = 20

-def request(method: str, port: Optional[int], api: str, data: Any = None) -> Any:
+def request(method: str, port: Optional[int], api: str, data: Any = None, prefix: Optional[str] = None) -> Any:
     if port is None:
         raise RuntimeError('Experiment is not running')
-    url = url_template.format(port, api)
+    url_parts = [
+        f'http://localhost:{port}',
+        prefix,
+        'api/v1/nni',
+        api
+    ]
+    url = '/'.join(part.strip('/') for part in url_parts if part)
     if data is None:
         resp = requests.request(method, url, timeout=timeout)
     else:
         resp = requests.request(method, url, json=data, timeout=timeout)
     if not resp.ok:
         _logger.error('rest request %s %s failed: %s %s', method.upper(), url, resp.status_code, resp.text)
     resp.raise_for_status()
     if method.lower() in ['get', 'post'] and len(resp.content) > 0:
         return resp.json()

-def get(port: Optional[int], api: str) -> Any:
-    return request('get', port, api)
+def get(port: Optional[int], api: str, prefix: Optional[str] = None) -> Any:
+    return request('get', port, api, prefix=prefix)

-def post(port: Optional[int], api: str, data: Any) -> Any:
-    return request('post', port, api, data)
+def post(port: Optional[int], api: str, data: Any, prefix: Optional[str] = None) -> Any:
+    return request('post', port, api, data, prefix=prefix)

-def put(port: Optional[int], api: str, data: Any) -> None:
-    request('put', port, api, data)
+def put(port: Optional[int], api: str, data: Any, prefix: Optional[str] = None) -> None:
+    request('put', port, api, data, prefix=prefix)

-def delete(port: Optional[int], api: str) -> None:
-    request('delete', port, api)
+def delete(port: Optional[int], api: str, prefix: Optional[str] = None) -> None:
+    request('delete', port, api, prefix=prefix)
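The request() rewrite above assembles the URL from optional parts, which is how the new prefix argument takes effect without changing any call site that passes None:

# With prefix='nni' a GET against '/experiment' hits
#   http://localhost:8080/nni/api/v1/nni/experiment
# and with prefix=None the original layout is preserved:
#   http://localhost:8080/api/v1/nni/experiment
parts = ['http://localhost:8080', 'nni', 'api/v1/nni', '/experiment']
url = '/'.join(part.strip('/') for part in parts if part)
assert url == 'http://localhost:8080/nni/api/v1/nni/experiment'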
...@@ -18,9 +18,10 @@ import torch
 import torch.nn as nn

 import nni.runtime.log
 from nni.common.device import GPUDevice
-from nni.experiment import Experiment, TrainingServiceConfig, launcher, management, rest
-from nni.experiment.config import util
-from nni.experiment.config.base import ConfigBase, PathLike
+from nni.experiment import Experiment, launcher, management, rest
+from nni.experiment.config import utils
+from nni.experiment.config.base import ConfigBase
+from nni.experiment.config.training_service import TrainingServiceConfig
 from nni.experiment.pipe import Pipe
 from nni.tools.nnictl.command_utils import kill_command
...@@ -45,7 +46,7 @@ class RetiariiExeConfig(ConfigBase):
     experiment_name: Optional[str] = None
     search_space: Any = ''  # TODO: remove
     trial_command: str = '_reserved'
-    trial_code_directory: PathLike = '.'
+    trial_code_directory: utils.PathLike = '.'
     trial_concurrency: int
     trial_gpu_number: int = 0
     devices: Optional[List[Union[str, GPUDevice]]] = None
...@@ -56,7 +57,7 @@ class RetiariiExeConfig(ConfigBase):
     nni_manager_ip: Optional[str] = None
     debug: bool = False
     log_level: Optional[str] = None
-    experiment_working_directory: PathLike = '~/nni-experiments'
+    experiment_working_directory: utils.PathLike = '~/nni-experiments'
     # remove configuration of tuner/assessor/advisor
     training_service: TrainingServiceConfig
     execution_engine: str = 'py'
...@@ -71,7 +72,7 @@ class RetiariiExeConfig(ConfigBase):
         super().__init__(**kwargs)
         if training_service_platform is not None:
             assert 'training_service' not in kwargs
-            self.training_service = util.training_service_config_factory(platform=training_service_platform)
+            self.training_service = utils.training_service_config_factory(platform=training_service_platform)
         self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry py'

     def __setattr__(self, key, value):
...@@ -100,16 +101,12 @@ class RetiariiExeConfig(ConfigBase):

     _canonical_rules = {
-        'trial_code_directory': util.canonical_path,
-        'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
-        'experiment_working_directory': util.canonical_path
     }

     _validation_rules = {
         'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
         'trial_concurrency': lambda value: value > 0,
         'trial_gpu_number': lambda value: value >= 0,
-        'max_experiment_duration': lambda value: util.parse_time(value) > 0,
         'max_trial_number': lambda value: value > 0,
         'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
         'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
......
...@@ -66,7 +66,9 @@ def start_experiment_log(experiment_id: str, log_directory: Path, debug: bool) -

 def stop_experiment_log(experiment_id: str) -> None:
     if experiment_id in handlers:
-        logging.getLogger().removeHandler(handlers.pop(experiment_id))
+        handler = handlers.pop(experiment_id, None)
+        if handler is not None:
+            logging.getLogger().removeHandler(handler)

 def _init_logger_dispatcher() -> None:
......
...@@ -62,7 +62,7 @@ def parse_args():
     # parse resume command
     parser_resume = subparsers.add_parser('resume', help='resume a new experiment')
-    parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to resume')
+    parser_resume.add_argument('id', help='The id of the experiment you want to resume')
     parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', type=int, help='the port of restful server')
     parser_resume.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
     parser_resume.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
...@@ -72,7 +72,7 @@ def parse_args():
     # parse view command
     parser_view = subparsers.add_parser('view', help='view a stopped experiment')
-    parser_view.add_argument('id', nargs='?', help='The id of the experiment you want to view')
+    parser_view.add_argument('id', help='The id of the experiment you want to view')
     parser_view.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', type=int, help='the port of restful server')
     parser_view.add_argument('--experiment_dir', '-e', help='view experiment from external folder, specify the full path of '
                              'experiment folder')
......
...@@ -199,9 +199,6 @@ testCases:
     launchCommand: nnictl view $resumeExpId
     experimentStatusCheck: False

-  - name: multi-thread
-    configFile: test/config/multi_thread/config.yml
-
 #########################################################################
 # nni assessor test
......
...@@ -132,9 +132,6 @@ testCases:
     launchCommand: nnictl view $resumeExpId
     experimentStatusCheck: False

-  - name: multi-thread
-    configFile: test/config/multi_thread/config.yml
-
 #########################################################################
 # nni assessor test
 #########################################################################
......
...@@ -42,9 +42,6 @@ testCases:
     kwargs:
       expected_result_file: expected_metrics_dict.json

-  - name: multi-thread
-    configFile: test/config/multi_thread/config.yml
-
 #########################################################################
 # nni assessor test
 #########################################################################
......
experimentName: test case
searchSpaceFile: search_space.json
trialCommand: python main.py
trialCodeDirectory: ../assets
trialConcurrency: 2
trialGpuNumber: 1
maxExperimentDuration: 1.5h
maxTrialNumber: 10
maxTrialDuration: 60
nniManagerIp: 1.2.3.4
debug: true
logLevel: warning
tunerGpuIndices: 0
assessor:
name: assess
advisor:
className: Advisor
codeDirectory: .
classArgs: {random_seed: 0}
trainingService:
platform: local
useActiveGpu: false
maxTrialNumberPerGpu: 2
gpuIndices: 1,2
reuseMode: true
sharedStorage:
storageType: NFS
localMountPoint: . # git cannot commit empty dir, so just use this
remoteMountPoint: /tmp
localMounted: usermount
nfsServer: nfs.test.case
exportedDirectory: root
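This test asset uses camelCase keys, while the config dataclasses store snake_case fields; the loader presumably matches them through the case- and underscore-insensitive comparison shown in utils/internal.py above, so for example:

# All of these spellings resolve to the same trial_gpu_number field.
for key in ('trialGpuNumber', 'trial_gpu_number', 'TRIALGPUNUMBER'):
    assert key.lower().replace('_', '') == 'trialgpunumber'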