Unverified Commit 765206cb authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Create experiment from Python code (#3111)

parent 1a999d70
# FIXME: For demonstration only. It should not be here
from pathlib import Path
from nni.experiment import Experiment
from nni.algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner

# Create the tuner in-process; Experiment hosts it directly instead of
# launching a dispatcher from a registered builtin name.
tuner = HyperoptTuner('tpe')

# Standard NNI search-space definition consumed by the tuner.
search_space = {
    "dropout_rate": { "_type": "uniform", "_value": [0.5, 0.9] },
    "conv_size": { "_type": "choice", "_value": [2, 3, 5, 7] },
    "hidden_size": { "_type": "choice", "_value": [124, 512, 1024] },  # NOTE(review): 124 looks like a typo for 128 — confirm
    "batch_size": { "_type": "choice", "_value": [16, 32] },
    "learning_rate": { "_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1] }
}

experiment = Experiment(tuner, 'local')
experiment.config.experiment_name = 'test'
experiment.config.trial_concurrency = 2  # run two trials at a time
experiment.config.max_trial_number = 5
experiment.config.search_space = search_space
experiment.config.trial_command = 'python3 mnist.py'
experiment.config.trial_code_directory = Path(__file__).parent  # trial code lives next to this script
experiment.config.training_service.use_active_gpu = True
experiment.run(8081)  # blocks until the experiment finishes; web UI served on port 8081
......@@ -3,6 +3,9 @@
__version__ = '999.0.0-developing'
from .runtime.log import init_logger
init_logger()
from .runtime.env_vars import dispatcher_env_vars
from .utils import ClassArgsValidator
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .config import *
from .experiment import Experiment
from .nni_client import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .common import *
from .local import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import dataclasses
from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar
from ruamel import yaml
from . import util
__all__ = ['ConfigBase', 'PathLike']
T = TypeVar('T', bound='ConfigBase')
PathLike = util.PathLike
def _is_missing(obj: Any) -> bool:
return isinstance(obj, type(dataclasses.MISSING))
class ConfigBase:
    """
    Base class of config classes.

    Subclass may override `_canonical_rules` and `_validation_rules`,
    and `validate()` if the logic is complex.
    """

    # Rules to convert field value to canonical format.
    # The key is field name.
    # The value is callable `value -> canonical_value`
    # It is not type-hinted so dataclass won't treat it as field
    _canonical_rules = {}  # type: ignore

    # Rules to validate field value.
    # The key is field name.
    # The value is callable `value -> valid` or `value -> (valid, error_message)`
    # The rule will be called with canonical format and is only called when `value` is not None.
    # `error_message` is used when `valid` is False.
    # It will be prepended with class name and field name in exception message.
    _validation_rules = {}  # type: ignore

    def __init__(self, *, _base_path: Optional[Path] = None, **kwargs):
        """
        Initialize a config object and set some fields.

        Name of keyword arguments can either be snake_case or camelCase.
        They will be converted to snake_case automatically.

        If a field is missing and don't have default value, it will be set to
        `dataclasses.MISSING` (checked later by `validate()`).

        Parameters
        ----------
        _base_path
            Base directory used to resolve relative path fields; defaults to
            the current working directory. Set by `load()` to the config
            file's own directory.
        """
        kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}
        if _base_path is None:
            _base_path = Path()
        for field in dataclasses.fields(self):
            value = kwargs.pop(util.case_insensitive(field.name), field.default)
            if value is not None and not _is_missing(value):
                # relative paths loaded from config file are not relative to pwd
                if 'Path' in str(field.type):
                    value = Path(value).expanduser()
                    if not value.is_absolute():
                        value = _base_path / value
                # convert nested dict to config type
                if isinstance(value, dict):
                    cls = util.strip_optional(field.type)
                    if isinstance(cls, type) and issubclass(cls, ConfigBase):
                        value = cls(**value, _base_path=_base_path)
            setattr(self, field.name, value)
        if kwargs:
            cls = type(self).__name__
            fields = ', '.join(kwargs.keys())
            raise ValueError(f'{cls}: Unrecognized fields {fields}')

    @classmethod
    def load(cls: Type[T], path: PathLike) -> T:
        """
        Load config from YAML (or JSON) file.
        Keys in YAML file can either be camelCase or snake_case.

        Raises
        ------
        ValueError
            If the file's top-level value is not a mapping.
        """
        # Use a context manager so the handle is closed deterministically;
        # the previous `yaml.safe_load(open(path))` leaked the file object.
        with open(path) as config_file:
            data = yaml.safe_load(config_file)
        if not isinstance(data, dict):
            raise ValueError(f'Content of config file {path} is not a dict/object')
        return cls(**data, _base_path=Path(path).parent)

    def json(self) -> Dict[str, Any]:
        """
        Convert config to JSON object.
        The keys of returned object will be camelCase; fields whose value is
        None are dropped entirely.
        """
        return dataclasses.asdict(
            self.canonical(),
            dict_factory = lambda items: dict((util.camel_case(k), v) for k, v in items if v is not None)
        )

    def canonical(self: T) -> T:
        """
        Returns a deep copy, where the fields supporting multiple formats are converted to the canonical format.
        Noticeably, relative path may be converted to absolute path.
        """
        ret = copy.deepcopy(self)
        for field in dataclasses.fields(ret):
            key, value = field.name, getattr(ret, field.name)
            rule = ret._canonical_rules.get(key)
            if rule is not None:
                setattr(ret, key, rule(value))
            elif isinstance(value, ConfigBase):
                # value will be copied twice, should not be a performance issue anyway
                setattr(ret, key, value.canonical())
        return ret

    def validate(self) -> None:
        """
        Validate the config object and raise Exception if it's ill-formed.

        Checks, per field: presence (not MISSING), None-ness against the
        type annotation, the field's `_validation_rules` entry, and nested
        ConfigBase values recursively.
        """
        class_name = type(self).__name__
        config = self.canonical()

        for field in dataclasses.fields(config):
            key, value = field.name, getattr(config, field.name)

            # check existence
            if _is_missing(value):
                raise ValueError(f'{class_name}: {key} is not set')

            # check type (TODO)
            # Optional-ness is inferred from the annotation's string form since
            # these are plain typing hints, not runtime-checkable types.
            type_name = str(field.type).replace('typing.', '')
            optional = any([
                type_name.startswith('Optional['),
                type_name.startswith('Union[') and 'NoneType' in type_name,
                type_name == 'Any'
            ])
            if value is None:
                if optional:
                    continue
                else:
                    raise ValueError(f'{class_name}: {key} cannot be None')

            # check value
            rule = config._validation_rules.get(key)
            if rule is not None:
                try:
                    result = rule(value)
                except Exception:
                    raise ValueError(f'{class_name}: {key} has bad value {repr(value)}')
                # rule may return bool, or (bool, error_message)
                if isinstance(result, bool):
                    if not result:
                        raise ValueError(f'{class_name}: {key} ({repr(value)}) is out of range')
                else:
                    if not result[0]:
                        raise ValueError(f'{class_name}: {key} {result[1]}')

            # check nested config
            if isinstance(value, ConfigBase):
                value.validate()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from .base import ConfigBase, PathLike
from . import util
__all__ = [
'ExperimentConfig',
'AlgorithmConfig',
'CustomAlgorithmConfig',
'TrainingServiceConfig',
]
@dataclass(init=False)
class _AlgorithmConfig(ConfigBase):
    """
    Base config for a tuner/assessor/advisor.
    Either `name` (a registered builtin algorithm) or `class_name`
    (a user-provided class, optionally located via `code_directory`)
    is used; `_validate_algo` enforces the exclusivity.
    """
    name: Optional[str] = None
    class_name: Optional[str] = None
    code_directory: Optional[PathLike] = None
    class_args: Optional[Dict[str, Any]] = None

    def validate(self):
        super().validate()
        _validate_algo(self)
@dataclass(init=False)
class AlgorithmConfig(_AlgorithmConfig):
    """Config for a registered (builtin) algorithm; `name` is required."""
    name: str
    class_args: Optional[Dict[str, Any]] = None
@dataclass(init=False)
class CustomAlgorithmConfig(_AlgorithmConfig):
    """Config for a user-provided algorithm class; `class_name` is required."""
    class_name: str
    # NOTE(review): the base class and `_validate_algo` use `code_directory`;
    # this `class_directory` field looks like a typo — confirm before relying on it.
    class_directory: Optional[PathLike] = None
    class_args: Optional[Dict[str, Any]] = None
class TrainingServiceConfig(ConfigBase):
    """
    Abstract base for per-platform training service configs.
    Each concrete subclass fixes `platform` to its service name
    (e.g. 'local'); `training_service_config_factory` matches on it.
    Not decorated with @dataclass itself — subclasses apply the decorator.
    """
    platform: str
@dataclass(init=False)
class ExperimentConfig(ConfigBase):
    """
    Top-level experiment configuration (v2 schema).
    Field values are assigned via ConfigBase reflection, so keyword
    arguments may be snake_case or camelCase.
    """
    experiment_name: Optional[str] = None
    # exactly one of search_space / search_space_file must be set
    # (enforced by the module-level validators below)
    search_space_file: Optional[PathLike] = None
    search_space: Any = None
    trial_command: str
    trial_code_directory: PathLike = '.'
    trial_concurrency: int
    trial_gpu_number: int = 0
    max_experiment_duration: Optional[str] = None
    max_trial_number: Optional[int] = None
    nni_manager_ip: Optional[str] = None
    use_annotation: bool = False
    debug: bool = False
    log_level: Optional[str] = None
    experiment_working_directory: Optional[PathLike] = None
    tuner_gpu_indices: Optional[Union[List[int], str]] = None
    tuner: Optional[_AlgorithmConfig] = None
    # NOTE(review): `accessor` is presumably meant to be `assessor`
    # (convert.py iterates over 'assessor') — confirm before renaming.
    accessor: Optional[_AlgorithmConfig] = None
    advisor: Optional[_AlgorithmConfig] = None
    training_service: TrainingServiceConfig

    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
        """
        Either pass a platform name ('local', 'remote', ...) to get a fresh
        service config of that type, or supply `training_service` in kwargs.
        """
        super().__init__(**kwargs)
        if training_service_platform is not None:
            assert 'training_service' not in kwargs
            self.training_service = util.training_service_config_factory(training_service_platform)

    def validate(self, initialized_tuner: bool = False) -> None:
        """
        Validate the config.

        Parameters
        ----------
        initialized_tuner
            True when the tuner object is created in-process (nni.Experiment
            mode); False for the normal nnictl launch path.
        """
        super().validate()
        if initialized_tuner:
            _validate_for_exp(self)
        else:
            _validate_for_nnictl(self)

    ## End of public API ##

    # Expose the module-level rule dicts to ConfigBase as properties so
    # dataclass does not treat them as fields.
    @property
    def _canonical_rules(self):
        return _canonical_rules

    @property
    def _validation_rules(self):
        return _validation_rules
# Canonicalization rules for ExperimentConfig fields (see ConfigBase.canonical()).
_canonical_rules = {
    'search_space_file': util.canonical_path,
    'trial_code_directory': util.canonical_path,
    # durations are normalized to whole seconds, e.g. '1h' -> '3600s'
    'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
    'experiment_working_directory': util.canonical_path,
    # '0,1,2' -> [0, 1, 2]
    'tuner_gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value
}

# Validation rules for ExperimentConfig fields (see ConfigBase.validate());
# called on canonical values, only when the value is not None.
_validation_rules = {
    'search_space_file': lambda value: (Path(value).is_file(), f'"{value}" does not exist or is not regular file'),
    'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
    'trial_concurrency': lambda value: value > 0,
    'trial_gpu_number': lambda value: value >= 0,
    'max_experiment_duration': lambda value: util.parse_time(value) > 0,
    'max_trial_number': lambda value: value > 0,
    'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
    # indices must be non-negative and unique
    'tuner_gpu_indices': lambda value: all(i >= 0 for i in value) and len(value) == len(set(value)),
    'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
def _validate_for_exp(config: ExperimentConfig) -> None:
    """
    Validate experiment config for nni.Experiment, where the tuner is
    already initialized outside, so no tuner/assessor/advisor or
    tuner GPU settings may appear in the config.

    Raises
    ------
    ValueError
        If the config is ill-formed for this mode.
    """
    if config.use_annotation:
        raise ValueError('ExperimentConfig: annotation is not supported in this mode')
    if util.count(config.search_space, config.search_space_file) != 1:
        raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')
    if util.count(config.tuner, config.accessor, config.advisor) != 0:
        # fixed message grammar ("set in for" -> "set for")
        raise ValueError('ExperimentConfig: tuner, accessor, and advisor must not be set for this mode')
    if config.tuner_gpu_indices is not None:
        raise ValueError('ExperimentConfig: tuner_gpu_indices is not supported in this mode')
def _validate_for_nnictl(config: ExperimentConfig) -> None:
    """
    Validate experiment config for the normal launching approach (nnictl):
    exactly one search-space source unless annotation is used, and exactly
    one of tuner/advisor.

    Raises
    ------
    ValueError
        If the config is ill-formed for this mode.
    """
    if config.use_annotation:
        if util.count(config.search_space, config.search_space_file) != 0:
            # fixed typo "annotationn" in the error message
            raise ValueError('ExperimentConfig: search_space and search_space_file must not be set with annotation')
    else:
        if util.count(config.search_space, config.search_space_file) != 1:
            raise ValueError('ExperimentConfig: search_space and search_space_file must be set one')
    if util.count(config.tuner, config.advisor) != 1:
        raise ValueError('ExperimentConfig: tuner and advisor must be set one')
def _validate_algo(algo: AlgorithmConfig) -> None:
if algo.name is None:
if algo.class_name is None:
raise ValueError('Missing algorithm name')
if algo.code_directory is not None and not Path(algo.code_directory).is_dir():
raise ValueError(f'code_directory "{algo.code_directory}" does not exist or is not directory')
else:
if algo.class_name is not None or algo.code_directory is not None:
raise ValueError(f'When name is set for registered algorithm, class_name and code_directory cannot be used')
# TODO: verify algorithm installation and class args
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List
from .common import ExperimentConfig
from . import util
_logger = logging.getLogger(__name__)
def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str, Any]:
    """
    Convert a v2 `ExperimentConfig` into the legacy (v1) YAML config structure.

    Parameters
    ----------
    config
        The v2 experiment config; it is validated first.
    skip_nnictl
        Forwarded to `config.validate()`: True when the tuner is created
        in-process (nni.Experiment) rather than by nnictl.

    Returns
    -------
    Dict[str, Any]
        A dict shaped like a v1 YAML config file.
    """
    config.validate(skip_nnictl)
    data = config.json()  # camelCase keys, None fields dropped

    ts = data.pop('trainingService')
    if ts['platform'] == 'openpai':
        ts['platform'] = 'pai'  # v1 uses the old name

    data['authorName'] = 'N/A'
    data['experimentName'] = data.get('experimentName', 'N/A')
    data['maxExecDuration'] = data.pop('maxExperimentDuration', '999d')
    if data['debug']:
        data['versionCheck'] = False
    data['maxTrialNum'] = data.pop('maxTrialNumber', 99999)
    data['trainingServicePlatform'] = ts['platform']

    ss = data.pop('searchSpace', None)
    ss_file = data.pop('searchSpaceFile', None)
    if ss is not None:
        # v1 only takes a file path, so dump the inline search space to a
        # temp file.  Close it (via context manager) so content is flushed
        # before anything reads the path; the original left it open.
        with NamedTemporaryFile('w', delete=False) as space_file:
            json.dump(ss, space_file, indent=4)
            data['searchSpacePath'] = space_file.name
    elif ss_file is not None:
        data['searchSpacePath'] = ss_file
    if 'experimentWorkingDirectory' in data:
        data['logDir'] = data.pop('experimentWorkingDirectory')

    for algo_type in ['tuner', 'assessor', 'advisor']:
        algo = data.get(algo_type)
        if algo is None:
            continue
        if algo['name'] is not None:  # builtin
            algo['builtin' + algo_type.title() + 'Name'] = algo.pop('name')
            algo.pop('className', None)
            algo.pop('codeDirectory', None)
        else:  # custom algorithm given as dotted class path
            algo.pop('name', None)
            class_name_parts = algo.pop('className').split('.')
            code_dir = algo.pop('codeDirectory', '')
            package_path = '/'.join(class_name_parts[:-2])
            # join with a separator when both parts are non-empty; plain
            # concatenation merged the directory and package path together
            if code_dir and package_path:
                algo['codeDir'] = code_dir + '/' + package_path
            else:
                algo['codeDir'] = code_dir + package_path
            algo['classFileName'] = class_name_parts[-2] + '.py'
            algo['className'] = class_name_parts[-1]

    tuner_gpu_indices = _convert_gpu_indices(data.pop('tunerGpuIndices', None))
    if tuner_gpu_indices is not None:
        # fixed key spelling (was 'gpuIndicies')
        data['tuner']['gpuIndices'] = tuner_gpu_indices

    data['trial'] = {
        'command': data.pop('trialCommand'),
        'codeDir': data.pop('trialCodeDirectory'),
        'gpuNum': data.pop('trialGpuNumber', '')
    }

    if ts['platform'] == 'local':
        data['localConfig'] = {
            'useActiveGpu': ts['useActiveGpu'],
            'maxTrialNumPerGpu': ts['maxTrialNumberPerGpu']
        }
        if ts.get('gpuIndices') is not None:
            data['localConfig']['gpuIndices'] = ','.join(str(idx) for idx in ts['gpuIndices'])

    elif ts['platform'] == 'remote':
        data['remoteConfig'] = {'reuse': ts['reuseMode']}
        data['machineList'] = []
        for machine in ts['machineList']:
            # append the converted machine — the original built the dict but
            # never added it, leaving machineList always empty
            data['machineList'].append({
                'ip': machine['host'],
                'username': machine['user'],
                'passwd': machine['password'],
                'sshKeyPath': machine['sshKeyFile'],
                'passphrase': machine['sshPassphrase'],
                'gpuIndices': _convert_gpu_indices(machine['gpuIndices']),
                # NOTE(review): json() emits camelCase field names; confirm the
                # machine config's key really is 'maxTrialNumPerGpu' and not
                # 'maxTrialNumberPerGpu'
                'maxTrialNumPerGpu': machine['maxTrialNumPerGpu'],
                'useActiveGpu': machine['useActiveGpu'],
                'preCommand': machine['trialPrepareCommand']
            })

    elif ts['platform'] == 'pai':
        data['trial']['cpuNum'] = ts['trialCpuNumber']
        data['trial']['memoryMB'] = util.parse_size(ts['trialMemorySize'])
        # json() camelCases keys, so the key is 'dockerImage';
        # 'docker_image' raised KeyError
        data['trial']['image'] = ts['dockerImage']
        data['paiConfig'] = {
            'userName': ts['username'],
            'token': ts['token'],
            'host': 'https://' + ts['host'],
            'reuse': ts['reuseMode']
        }

    return data
def _convert_gpu_indices(indices):
return ','.join(str(idx) for idx in indices) if indices is not None else None
def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
    """
    Build the list of payloads to PUT to `/experiment/cluster-metadata`
    before creating the experiment, from the v1-converted config.

    Raises
    ------
    RuntimeError
        For unsupported training service platforms.
    """
    experiment_config = to_v1_yaml(config, skip_nnictl=True)
    ret = []
    if config.training_service.platform == 'local':
        request_data = dict()
        request_data['local_config'] = experiment_config['localConfig']
        if request_data['local_config']:
            # stringify a bare int gpuIndices for the REST API
            if request_data['local_config'].get('gpuIndices') and isinstance(request_data['local_config'].get('gpuIndices'), int):
                request_data['local_config']['gpuIndices'] = str(request_data['local_config'].get('gpuIndices'))
            # NOTE(review): the next two assignments are no-ops (value assigned
            # to itself) — presumably leftovers from copied nnictl code
            if request_data['local_config'].get('maxTrialNumOnEachGpu'):
                request_data['local_config']['maxTrialNumOnEachGpu'] = request_data['local_config'].get('maxTrialNumOnEachGpu')
            if request_data['local_config'].get('useActiveGpu'):
                request_data['local_config']['useActiveGpu'] = request_data['local_config'].get('useActiveGpu')
        ret.append(request_data)
    elif config.training_service.platform == 'remote':
        request_data = dict()
        if experiment_config.get('remoteConfig'):
            request_data['remote_config'] = experiment_config['remoteConfig']
        else:
            request_data['remote_config'] = {'reuse': False}
        request_data['machine_list'] = experiment_config['machineList']
        if request_data['machine_list']:
            for i in range(len(request_data['machine_list'])):
                # stringify a bare int gpuIndices for the REST API
                if isinstance(request_data['machine_list'][i].get('gpuIndices'), int):
                    request_data['machine_list'][i]['gpuIndices'] = str(request_data['machine_list'][i].get('gpuIndices'))
        ret.append(request_data)
    elif config.training_service.platform == 'openpai':
        pai_config_data = dict()
        pai_config_data['pai_config'] = experiment_config['paiConfig']
        ret.append(pai_config_data)
    else:
        raise RuntimeError('Unsupported training service ' + config.training_service.platform)
    if experiment_config.get('nniManagerIp') is not None:
        ret.append({'nni_manager_ip': {'nniManagerIp': experiment_config['nniManagerIp']}})
    # trial config must be sent last
    ret.append({'trial_config': experiment_config['trial']})
    return ret
def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]:
    """
    Convert a v2 `ExperimentConfig` into the JSON body POSTed to
    `/experiment` on the legacy (v1) REST API.
    """
    experiment_config = to_v1_yaml(config, skip_nnictl=True)
    request_data = dict()
    request_data['authorName'] = experiment_config['authorName']
    request_data['experimentName'] = experiment_config['experimentName']
    request_data['trialConcurrency'] = experiment_config['trialConcurrency']
    # duration is sent as a number of seconds
    request_data['maxExecDuration'] = util.parse_time(experiment_config['maxExecDuration'])
    request_data['maxTrialNum'] = experiment_config['maxTrialNum']
    # the REST API takes search space as a JSON string, not an object
    if config.search_space is not None:
        request_data['searchSpace'] = json.dumps(config.search_space)
    else:
        request_data['searchSpace'] = Path(config.search_space_file).read_text()
    request_data['trainingServicePlatform'] = experiment_config.get('trainingServicePlatform')
    if experiment_config.get('advisor'):
        request_data['advisor'] = experiment_config['advisor']
        if request_data['advisor'].get('gpuNum'):
            _logger.warning('gpuNum is deprecated, please use gpuIndices instead.')
        if request_data['advisor'].get('gpuIndices') and isinstance(request_data['advisor'].get('gpuIndices'), int):
            request_data['advisor']['gpuIndices'] = str(request_data['advisor'].get('gpuIndices'))
    elif experiment_config.get('tuner'):
        request_data['tuner'] = experiment_config['tuner']
        if request_data['tuner'].get('gpuNum'):
            _logger.warning('gpuNum is deprecated, please use gpuIndices instead.')
        if request_data['tuner'].get('gpuIndices') and isinstance(request_data['tuner'].get('gpuIndices'), int):
            request_data['tuner']['gpuIndices'] = str(request_data['tuner'].get('gpuIndices'))
        if 'assessor' in experiment_config:
            request_data['assessor'] = experiment_config['assessor']
            if request_data['assessor'].get('gpuNum'):
                _logger.warning('gpuNum is deprecated, please remove it from your config file.')
    else:
        # tuner was created in-process by nni.Experiment; use a placeholder name
        request_data['tuner'] = {'builtinTunerName': '_user_created_'}
    #debug mode should disable version check
    if experiment_config.get('debug') is not None:
        request_data['versionCheck'] = not experiment_config.get('debug')
    #validate version check
    if experiment_config.get('versionCheck') is not None:
        request_data['versionCheck'] = experiment_config.get('versionCheck')
    if experiment_config.get('logCollection'):
        request_data['logCollection'] = experiment_config.get('logCollection')
    # per-platform cluster metadata entries, mirroring to_cluster_metadata()
    request_data['clusterMetaData'] = []
    if experiment_config['trainingServicePlatform'] == 'local':
        request_data['clusterMetaData'].append(
            {'key':'codeDir', 'value':experiment_config['trial']['codeDir']})
        request_data['clusterMetaData'].append(
            {'key': 'command', 'value': experiment_config['trial']['command']})
    elif experiment_config['trainingServicePlatform'] == 'remote':
        request_data['clusterMetaData'].append(
            {'key': 'machine_list', 'value': experiment_config['machineList']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
        if not experiment_config.get('remoteConfig'):
            # set default value of reuse in remoteConfig to False
            experiment_config['remoteConfig'] = {'reuse': False}
        request_data['clusterMetaData'].append(
            {'key': 'remote_config', 'value': experiment_config['remoteConfig']})
    elif experiment_config['trainingServicePlatform'] == 'pai':
        request_data['clusterMetaData'].append(
            {'key': 'pai_config', 'value': experiment_config['paiConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'kubeflow':
        request_data['clusterMetaData'].append(
            {'key': 'kubeflow_config', 'value': experiment_config['kubeflowConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'frameworkcontroller':
        request_data['clusterMetaData'].append(
            {'key': 'frameworkcontroller_config', 'value': experiment_config['frameworkcontrollerConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'aml':
        request_data['clusterMetaData'].append(
            {'key': 'aml_config', 'value': experiment_config['amlConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    return request_data
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import List, Optional, Union
from .common import TrainingServiceConfig
__all__ = ['LocalConfig']
@dataclass(init=False)
class LocalConfig(TrainingServiceConfig):
    """
    Training service config for running trials on the local machine.
    """
    platform: str = 'local'
    use_active_gpu: bool
    max_trial_number_per_gpu: int = 1
    # either a list of indices or a comma-separated string like '0,1'
    gpu_indices: Optional[Union[List[int], str]] = None

    # Canonicalization: '0,1' -> [0, 1] (see ConfigBase._canonical_rules)
    _canonical_rules = {
        'gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value
    }

    # Validation on canonical values (see ConfigBase._validation_rules)
    _validation_rules = {
        'platform': lambda value: (value == 'local', 'cannot be modified'),
        'max_trial_number_per_gpu': lambda value: value > 0,
        # indices must be non-negative and unique
        'gpu_indices': lambda value: all(idx >= 0 for idx in value) and len(value) == len(set(value))
    }
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Miscellaneous utility functions.
"""
import math
import os.path
from pathlib import Path
from typing import Optional, Union
PathLike = Union[Path, str]
def case_insensitive(key: str) -> str:
    """Normalize a field name so snake_case and camelCase variants compare equal."""
    return key.replace('_', '').lower()
def camel_case(key: str) -> str:
    """Convert a snake_case name to camelCase."""
    first, *rest = key.split('_')
    return first + ''.join(part.title() for part in rest)
def canonical_path(path: Optional[PathLike]) -> Optional[str]:
    """Expand '~' and make `path` absolute; None passes through unchanged."""
    # Path.resolve() does not work on Windows when file not exist, so use os.path instead
    if path is None:
        return None
    return os.path.abspath(os.path.expanduser(path))
def count(*values) -> int:
    """Count how many of `values` are "set", i.e. neither None nor False."""
    total = 0
    for value in values:
        if value is not None and value is not False:
            total += 1
    return total
def training_service_config_factory(platform: str):  # -> TrainingServiceConfig
    """
    Instantiate the TrainingServiceConfig subclass whose `platform`
    attribute matches the given name.

    Raises
    ------
    ValueError
        If no registered subclass handles `platform`.
    """
    from .common import TrainingServiceConfig
    candidates = TrainingServiceConfig.__subclasses__()
    for candidate in candidates:
        if candidate.platform == platform:
            return candidate()
    raise ValueError(f'Unrecognized platform {platform}')
def strip_optional(type_hint):
    """Return T for an Optional[T] hint; any other hint is returned unchanged."""
    if str(type_hint).startswith('typing.Optional['):
        return type_hint.__args__[0]
    return type_hint
def parse_time(time: str, target_unit: str = 's') -> int:
    """Parse a duration string like '1h' or '30m' into `target_unit`, rounding up."""
    return _parse_unit(time.lower(), target_unit, _time_units)

def parse_size(size: str, target_unit: str = 'mb') -> int:
    """Parse a size string like '8gb' into `target_unit`, rounding up."""
    return _parse_unit(size.lower(), target_unit, _size_units)

# conversion factors to the base unit (seconds / bytes)
_time_units = {'d': 24 * 3600, 'h': 3600, 'm': 60, 's': 1}
_size_units = {'gb': 1024 * 1024 * 1024, 'mb': 1024 * 1024, 'kb': 1024}

def _parse_unit(string, target_unit, all_units):
    """Match the unit suffix of `string`, convert to `target_unit`, and round up."""
    for suffix, factor in all_units.items():
        if not string.endswith(suffix):
            continue
        amount = float(string[:-len(suffix)]) * factor
        return math.ceil(amount / all_units[target_unit])
    raise ValueError(f'Unsupported unit in "{string}"')
import atexit
import logging
import socket
from subprocess import Popen
from threading import Thread
import time
from typing import Optional, overload
import colorama
import psutil
import nni.runtime.log
from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.tuner import Tuner
from .config import ExperimentConfig
from . import launcher
from .pipe import Pipe
from . import rest
nni.runtime.log.init_logger_experiment()
_logger = logging.getLogger('nni.experiment')
class Experiment:
    """
    Create and stop an NNI experiment.

    Attributes
    ----------
    config
        Experiment configuration.
    port
        Web UI port of the experiment, or `None` if it is not running.
    """

    @overload
    def __init__(self, tuner: Tuner, config: ExperimentConfig) -> None:
        """
        Prepare an experiment.

        Use `Experiment.start()` to launch it.

        Parameters
        ----------
        tuner
            A tuner instance.
        config
            Experiment configuration.
        """
        ...

    @overload
    def __init__(self, tuner: Tuner, training_service: str) -> None:
        """
        Prepare an experiment, leaving configuration fields to be set later.

        Example usage::

            experiment = Experiment(my_tuner, 'remote')
            experiment.config.trial_command = 'python3 trial.py'
            experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
            ...
            experiment.start(8080)

        Parameters
        ----------
        tuner
            A tuner instance.
        training_service
            Name of training service.
            Supported value: "local", "remote", "openpai".
        """
        ...

    def __init__(self, tuner: Tuner, config=None, training_service=None):
        self.config: ExperimentConfig
        self.port: Optional[int] = None  # None means not running
        self.tuner: Tuner = tuner
        self._proc: Optional[Popen] = None               # NNI manager (node.js) process
        self._pipe: Optional[Pipe] = None                # IPC pipe to the dispatcher
        self._dispatcher: Optional[MsgDispatcher] = None
        self._dispatcher_thread: Optional[Thread] = None

        # the second positional argument may be either a full config object
        # or a training service name (see the overloads above)
        if isinstance(config, str):
            config, training_service = None, config

        if config is None:
            self.config = ExperimentConfig(training_service)
        else:
            self.config = config

    def start(self, port: int = 8080, debug: bool = False) -> None:
        """
        Start the experiment in background.
        This method will raise exception on failure.
        If it returns, the experiment should have been successfully started.

        Parameters
        ----------
        port
            The port of web UI.
        debug
            Whether to start in debug mode.
        """
        # make sure the child process is cleaned up even on abnormal interpreter exit
        atexit.register(self.stop)

        if debug:
            logging.getLogger('nni').setLevel(logging.DEBUG)

        self._proc, self._pipe = launcher.start_experiment(self.config, port, debug)
        assert self._proc is not None
        assert self._pipe is not None

        self.port = port  # port will be None if start up failed

        # dispatcher must be launched after pipe initialized
        # the logic to launch dispatcher in background should be refactored into dispatcher api
        self._dispatcher = MsgDispatcher(self.tuner, None)
        self._dispatcher_thread = Thread(target=self._dispatcher.run)
        self._dispatcher_thread.start()

        # print web UI URLs for every local IPv4 address (plus configured manager IP)
        ips = [self.config.nni_manager_ip]
        for interfaces in psutil.net_if_addrs().values():
            for interface in interfaces:
                if interface.family == socket.AF_INET:
                    ips.append(interface.address)
        ips = [f'http://{ip}:{port}' for ip in ips if ip]
        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips)
        _logger.info(msg)

        # TODO: register experiment management metadata

    def stop(self) -> None:
        """
        Stop background experiment.
        """
        _logger.info('Stopping experiment...')
        atexit.unregister(self.stop)

        if self._proc is not None:
            self._proc.kill()
        if self._pipe is not None:
            self._pipe.close()
        if self._dispatcher_thread is not None:
            # ask the dispatcher loop to exit, then give it a moment to do so
            self._dispatcher.stopping = True
            self._dispatcher_thread.join(timeout=1)

        self.port = None
        self._proc = None
        self._pipe = None
        self._dispatcher = None
        self._dispatcher_thread = None

    def run(self, port: int = 8080, debug: bool = False) -> bool:
        """
        Run the experiment.
        This function will block until experiment finish or error.
        Return `True` when experiment done; or return `False` when experiment failed.
        """
        self.start(port, debug)
        try:
            # poll the REST server until the experiment reaches a final state
            while True:
                time.sleep(10)
                status = self.get_status()
                if status == 'STOPPED':
                    return True
                if status == 'ERROR':
                    return False
        finally:
            self.stop()

    def get_status(self) -> str:
        # query current experiment status from the local REST server
        if self.port is None:
            raise RuntimeError('Experiment is not running')
        resp = rest.get(self.port, '/check-status')
        return resp['status']
import contextlib
import logging
from pathlib import Path
import socket
from subprocess import Popen
import sys
import time
from typing import Optional, Tuple
import colorama
import nni.runtime.protocol
import nni_node
from .config import ExperimentConfig
from .config import convert
from . import management
from .pipe import Pipe
from . import rest
_logger = logging.getLogger('nni.experiment')
def start_experiment(config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
    """
    Launch the NNI manager process and connect the dispatcher IPC pipe.

    On any failure the partially created process/pipe are cleaned up
    (best-effort) and the original exception is re-raised.

    Returns
    -------
    Tuple[Popen, Pipe]
        The REST server process handle and the connected IPC pipe.
    """
    pipe = None
    proc = None

    config.validate(initialized_tuner=True)
    _ensure_port_idle(port)
    if config.training_service.platform == 'openpai':
        _ensure_port_idle(port + 1, 'OpenPAI requires an additional port')
    exp_id = management.generate_experiment_id()

    try:
        _logger.info(f'Creating experiment {colorama.Fore.CYAN}{exp_id}')
        pipe = Pipe(exp_id)
        proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
        _logger.info('Connecting IPC pipe...')
        pipe_file = pipe.connect()
        nni.runtime.protocol._in_file = pipe_file
        nni.runtime.protocol._out_file = pipe_file
        _logger.info('Starting web server...')  # fixed typo 'Statring'
        _check_rest_server(port)
        _logger.info('Setting up...')
        _init_experiment(config, port, debug)
        return proc, pipe

    except Exception:
        _logger.error('Create experiment failed')
        if proc is not None:
            with contextlib.suppress(Exception):
                proc.kill()
        if pipe is not None:
            with contextlib.suppress(Exception):
                pipe.close()
        raise  # bare raise preserves the original traceback ('raise e' re-anchored it)
def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
sock = socket.socket()
if sock.connect_ex(('localhost', port)) == 0:
sock.close()
message = f'(message)' if message else ''
raise RuntimeError(f'Port {port} is not idle {message}')
def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str) -> Popen:
    """
    Spawn the NNI manager (node.js REST server) as a child process and
    return its `Popen` handle.
    """
    ts = config.training_service.platform
    if ts == 'openpai':
        ts = 'pai'  # the node side still uses the old platform name

    # command-line arguments passed to main.js as '--key value' pairs
    args = {
        'port': port,
        'mode': ts,
        'experiment_id': experiment_id,
        'start_mode': 'new',
        'log_level': 'debug' if debug else 'info',
        'dispatcher_pipe': pipe_path,
    }

    node_dir = Path(nni_node.__path__[0])
    node = str(node_dir / ('node.exe' if sys.platform == 'win32' else 'node'))
    main_js = str(node_dir / 'main.js')
    cmd = [node, '--max-old-space-size=4096', main_js]
    for arg_key, arg_value in args.items():
        cmd.append('--' + arg_key)
        cmd.append(str(arg_value))
    return Popen(cmd, cwd=node_dir)
def _check_rest_server(port: int, retry: int = 3) -> None:
    """
    Poll the REST server until it answers `/check-status`, retrying up to
    `retry` times with a one-second pause between attempts.  If it is still
    unreachable afterwards, the final request is made unguarded so its
    exception propagates to the caller.
    """
    for attempt in range(retry):
        try:
            rest.get(port, '/check-status')
            return
        except Exception:
            pass
        if attempt > 0:
            _logger.warning('Timeout, retry...')
        time.sleep(1)
    rest.get(port, '/check-status')
def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
    """Upload cluster metadata entries, then POST the experiment profile to the REST server."""
    metadata_entries = convert.to_cluster_metadata(config)
    for cluster_metadata in metadata_entries:
        rest.put(port, '/experiment/cluster-metadata', cluster_metadata)
    rest.post(port, '/experiment', convert.to_rest_json(config))
from pathlib import Path
import random
import string
def generate_experiment_id() -> str:
    """Generate a random 8-character experiment ID from lowercase letters and digits."""
    alphabet = string.ascii_lowercase + string.digits
    return ''.join(random.sample(alphabet, 8))
def create_experiment_directory(experiment_id: str) -> Path:
    """Create (if needed) and return `~/nni-experiments/<experiment_id>`."""
    directory = Path.home() / 'nni-experiments' / experiment_id
    directory.mkdir(parents=True, exist_ok=True)
    return directory
# TODO: port shangning's work here, and use it in Experiment.start()/.stop()
......@@ -28,7 +28,7 @@ import json
import requests
__all__ = [
'Experiment',
'ExternalExperiment',
'TrialResult',
'TrialMetricData',
'TrialHyperParameters',
......@@ -228,7 +228,7 @@ class TrialJob:
.format(self.trialJobId, self.status, self.hyperParameters, self.logPath,
self.startTime, self.endTime, self.finalMetricData, self.stderrPath)
class Experiment:
class ExternalExperiment:
def __init__(self):
self._endpoint = None
self._exp_id = None
......
from io import BufferedIOBase
import os
import sys
if sys.platform == 'win32':
import _winapi
import msvcrt
class WindowsPipe:
    """Named-pipe server used to talk to the NNI dispatcher on Windows.

    Counterpart of `UnixPipe` below; created before the dispatcher process
    starts and connected once the dispatcher opens the other end.
    """
    def __init__(self, experiment_id: str):
        # Pipe name derived from the experiment ID so concurrent experiments
        # get distinct pipes.
        self.path: str = r'\\.\pipe\nni-' + experiment_id
        self.file = None
        # Duplex message-mode pipe, single instance (max 1), blocking waits,
        # 8192-byte in/out buffers, default timeout (0), no security attrs.
        self._handle = _winapi.CreateNamedPipe(
            self.path,
            _winapi.PIPE_ACCESS_DUPLEX,
            _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_WAIT,
            1,
            8192,
            8192,
            0,
            _winapi.NULL
        )
    def connect(self) -> BufferedIOBase:
        # Blocks until the dispatcher opens the client end, then wraps the
        # OS handle into a binary read/write file object.
        _winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
        fd = msvcrt.open_osfhandle(self._handle, 0)
        self.file = os.fdopen(fd, 'w+b')
        return self.file
    def close(self) -> None:
        # Safe to call even if connect() never happened (self.file is None).
        if self.file is not None:
            self.file.close()
        _winapi.CloseHandle(self._handle)
Pipe = WindowsPipe
else:
import socket
from . import management
class UnixPipe:
    """Unix domain socket used to talk to the NNI dispatcher on POSIX systems."""
    def __init__(self, experiment_id: str):
        # Socket file lives inside the experiment's own directory.
        experiment_dir = management.create_experiment_directory(experiment_id)
        self.path: str = str(experiment_dir / 'dispatcher-pipe')
        self.file = None
        self._socket = socket.socket(socket.AF_UNIX)
        self._socket.bind(self.path)
        self._socket.listen(1)  # only accepts one connection
    def connect(self) -> BufferedIOBase:
        """Block until the dispatcher connects; return a binary file wrapper."""
        connection, _addr = self._socket.accept()
        self.file = connection.makefile('w+b')
        return self.file
    def close(self) -> None:
        """Close the connection (if any) and the listener, and remove the socket file."""
        if self.file is not None:
            self.file.close()
        self._socket.close()
        os.unlink(self.path)
import logging
from typing import Any
import requests
_logger = logging.getLogger(__name__)
# All requests target the local NNI manager REST API: format(port, api_path).
url_template = 'http://localhost:{}/api/v1/nni{}'
# Per-request timeout in seconds.
timeout = 20
def get(port: int, api: str) -> Any:
    """GET from the local NNI REST server and return the parsed JSON body.

    Raises `requests.HTTPError` on a non-success status.
    """
    url = url_template.format(port, api)
    resp = requests.get(url, timeout=timeout)
    if resp.ok:
        return resp.json()
    _logger.error('rest request GET %s %s failed: %s %s', port, api, resp.status_code, resp.text)
    resp.raise_for_status()
def post(port: int, api: str, data: Any) -> Any:
    """POST `data` as JSON to the local NNI REST server; return the parsed response.

    Raises `requests.HTTPError` on a non-success status.
    """
    url = url_template.format(port, api)
    resp = requests.post(url, json=data, timeout=timeout)
    if resp.ok:
        return resp.json()
    _logger.error('rest request POST %s %s failed: %s %s', port, api, resp.status_code, resp.text)
    resp.raise_for_status()
def put(port: int, api: str, data: Any) -> None:
    """PUT `data` as JSON to the local NNI REST server.

    Unlike `get`/`post` this returns nothing (PUT endpoints reply with an
    empty body). Raises `requests.HTTPError` on a non-success status.
    """
    url = url_template.format(port, api)
    resp = requests.put(url, json=data, timeout=timeout)
    if not resp.ok:
        # Include the response body in the log, consistent with get()/post().
        _logger.error('rest request PUT %s %s failed: %s %s', port, api, resp.status_code, resp.text)
        resp.raise_for_status()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from datetime import datetime
from io import TextIOBase
import logging
import os
import sys
import time
# Mapping from NNI config log-level names to stdlib logging levels;
# unknown names fall back to INFO at the lookup site.
log_level_map = {
    'fatal': logging.FATAL,
    'error': logging.ERROR,
    'warning': logging.WARNING,
    'info': logging.INFO,
    'debug': logging.DEBUG
}
# Timestamp format shared by log records and redirected print() output.
_time_format = '%m/%d/%Y, %I:%M:%S %p'
# FIXME
# This hotfix the bug that querying installed tuners with `package_utils` will activate dispatcher logger.
# This behavior depends on underlying implementation of `nnictl` and is likely to break in future.
# Guard flag: ensures the root/nni logger is configured at most once per process.
_logger_initialized = False
class _LoggerFileWrapper(TextIOBase):
    """File-like object that timestamps every print() write into the log file."""
    def __init__(self, logger_file):
        self.file = logger_file
    def write(self, s):
        # print() emits its trailing newline as a separate write('\n');
        # skip it so the log gets no empty timestamped lines.
        if s == '\n':
            return len(s)
        cur_time = datetime.now().strftime(_time_format)
        self.file.write('[{}] PRINT '.format(cur_time) + s + '\n')
        self.file.flush()
        return len(s)
def init_logger(logger_file_path, log_level_name='info'):
    """Initialize root logger.
    This will redirect anything from logging.getLogger() as well as stdout to specified file.
    logger_file_path: path of logger file (path-like object).
    log_level_name: one of 'fatal'/'error'/'warning'/'info'/'debug';
        unknown names fall back to INFO.
    """
    # Configure at most once per process; later calls are no-ops.
    global _logger_initialized
    if _logger_initialized:
        return
    _logger_initialized = True
    if os.environ.get('NNI_PLATFORM') == 'unittest':
        return  # fixme: launching logic needs refactor
    log_level = log_level_map.get(log_level_name, logging.INFO)
    # File is deliberately kept open for the process lifetime: both the
    # handler below and the stdout wrapper write to it.
    logger_file = open(logger_file_path, 'w')
    fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
    logging.Formatter.converter = time.localtime
    formatter = logging.Formatter(fmt, _time_format)
    handler = logging.StreamHandler(logger_file)
    handler.setFormatter(formatter)
    root_logger = logging.getLogger()
    root_logger.addHandler(handler)
    root_logger.setLevel(log_level)
    # these modules are too verbose
    # NOTE(review): sets matplotlib to the same level as root, not higher —
    # presumably intended to silence it; confirm.
    logging.getLogger('matplotlib').setLevel(log_level)
    # Redirect print() output into the same log file with timestamps.
    sys.stdout = _LoggerFileWrapper(logger_file)
def init_standalone_logger():
    """Configure the 'nni' logger for standalone mode.

    Sets NNI's log level to INFO and prints its records to stdout.
    No-op when a logger was already initialized in this process.
    """
    global _logger_initialized
    if _logger_initialized:
        return
    _logger_initialized = True
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(logging.Formatter(
        '[%(asctime)s] %(levelname)s (%(name)s) %(message)s', _time_format))
    nni_logger = logging.getLogger('nni')
    nni_logger.addHandler(stdout_handler)
    nni_logger.setLevel(logging.INFO)
    nni_logger.propagate = False
    # Following line does not affect NNI loggers, but without this user's logger won't be able to
    # print log even it's level is set to INFO, so we do it for user's convenience.
    # If this causes any issue in future, remove it and use `logging.info` instead of
    # `logging.getLogger('xxx')` in all examples.
    logging.basicConfig()
# Legacy feature flags for multi-thread / multi-phase experiment mode.
# NOTE(review): presumably toggled by setter functions elsewhere in this
# module — confirm before relying on them staying False.
_multi_thread = False
_multi_phase = False
......
from datetime import datetime
from io import TextIOBase
import logging
from logging import FileHandler, Formatter, Handler, StreamHandler
from pathlib import Path
import sys
import time
from typing import Optional
import colorama
from .env_vars import dispatcher_env_vars, trial_env_vars
def init_logger() -> None:
    """
    This function will (and should only) get invoked on the first time of importing nni (no matter which submodule).
    It will try to detect the running environment and setup logger accordingly.
    The detection should work in most cases but for `nnictl` and `nni.experiment`.
    They will be identified as "standalone" mode and must configure the logger by themselves.
    """
    colorama.init()
    # Dispatcher (tuner/assessor) processes are marked by an env var.
    if dispatcher_env_vars.SDK_PROCESS == 'dispatcher':
        _init_logger_dispatcher()
        return
    platform = trial_env_vars.NNI_PLATFORM
    if platform == 'unittest':
        return  # unit tests configure logging themselves
    if platform:
        # Any other non-empty platform means we are inside a trial process.
        _init_logger_trial()
        return
    # Neither dispatcher nor trial: fall back to standalone mode.
    _init_logger_standalone()
def init_logger_experiment() -> None:
    """
    Initialize logger for `nni.experiment.Experiment`.
    This function will get invoked after `init_logger()`.
    """
    # Monkey-patches the module-level `formatter` instance so experiment logs
    # are colorized: `_colorful_format` is a plain function assigned to the
    # instance attribute, so it receives only `record` (no `self` binding).
    formatter.format = _colorful_format
# Timestamp format and formatter shared by every handler this module sets up;
# `formatter.format` may be monkey-patched by init_logger_experiment().
time_format = '%Y-%m-%d %H:%M:%S'
formatter = Formatter(
    '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s',
    time_format
)
def _init_logger_dispatcher() -> None:
    """Configure the root logger for a dispatcher (tuner/assessor) process.

    Writes to `<NNI_LOG_DIRECTORY>/dispatcher.log` at the level named by
    NNI_LOG_LEVEL; unknown level names fall back to INFO.
    """
    level_by_name = {
        'fatal': logging.CRITICAL,
        'error': logging.ERROR,
        'warning': logging.WARNING,
        'info': logging.INFO,
        'debug': logging.DEBUG,
        'trace': 0  # 0 = NOTSET-like: everything passes
    }
    log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
    log_level = level_by_name.get(dispatcher_env_vars.NNI_LOG_LEVEL, logging.INFO)
    _setup_root_logger(FileHandler(log_path), log_level)
def _init_logger_trial() -> None:
    """Configure the root logger for a trial process.

    Logs to `<NNI_OUTPUT_DIR>/trial.log` and redirects print() output into
    the same file with timestamps.
    """
    log_path = _prepare_log_dir(trial_env_vars.NNI_OUTPUT_DIR) / 'trial.log'
    # File is deliberately left open for the process lifetime: both the
    # logging handler and the stdout wrapper below write to it.
    log_file = open(log_path, 'w')
    _setup_root_logger(StreamHandler(log_file), logging.INFO)
    sys.stdout = _LogFileWrapper(log_file)
def _init_logger_standalone() -> None:
    """Configure the 'nni' logger to print INFO records to stdout (standalone mode)."""
    _setup_nni_logger(StreamHandler(sys.stdout), logging.INFO)
    # basicConfig() does not touch NNI loggers, but without it a user's own
    # `logging.getLogger('xxx')` logger would print nothing even with its
    # level set to INFO — kept purely for user convenience. If this causes
    # any issue in future, remove it and use `logging.info()` instead of
    # `logging.getLogger('xxx').info()` in all examples.
    logging.basicConfig()
def _prepare_log_dir(path: Optional[str]) -> Path:
if path is None:
return Path()
ret = Path(path)
ret.mkdir(parents=True, exist_ok=True)
return ret
def _setup_root_logger(handler: Handler, level: int) -> None:
    # '' names the root logger, which captures every module's records.
    _setup_logger('', handler, level)
def _setup_nni_logger(handler: Handler, level: int) -> None:
    # Configures only the 'nni' namespace, leaving the root logger alone.
    _setup_logger('nni', handler, level)
def _setup_logger(name: str, handler: Handler, level: int) -> None:
    """Attach `handler` (using the module-wide formatter) to logger `name`.

    Propagation is disabled so records are not duplicated by ancestors.
    """
    handler.setFormatter(formatter)
    target = logging.getLogger(name)
    target.addHandler(handler)
    target.setLevel(level)
    target.propagate = False
def _colorful_format(record):
    """Replacement for `Formatter.format` that colorizes messages by level.

    Monkey-patched onto the module-level formatter by
    `init_logger_experiment()`; receives a `logging.LogRecord`.
    """
    if record.levelno >= logging.ERROR:
        color = colorama.Fore.RED
    elif record.levelno >= logging.WARNING:
        color = colorama.Fore.YELLOW
    elif record.levelno >= logging.INFO:
        color = colorama.Fore.GREEN
    else:
        color = colorama.Fore.BLUE
    # Use getMessage() rather than `record.msg % record.args`: the latter
    # raises when the message contains a literal '%' and no args were given.
    msg = color + record.getMessage() + colorama.Style.RESET_ALL
    time = formatter.formatTime(record, time_format)
    if record.levelno < logging.INFO:
        # Debug/trace records carry thread and logger name for diagnosis.
        return '[{}] {}:{} {}'.format(time, record.threadName, record.name, msg)
    else:
        return '[{}] {}'.format(time, msg)
class _LogFileWrapper(TextIOBase):
    # wrap the logger file so that anything written to it will automatically get formatted
    # Writes are line-buffered: complete lines get a "[time] PRINT " prefix;
    # a partial line older than 0.1s is flushed on the next write so slow
    # writers still show up in the log.
    def __init__(self, log_file: TextIOBase):
        self.file: TextIOBase = log_file
        # Partial line received but not yet written out.
        self.line_buffer: Optional[str] = None
        # When the buffered partial line started accumulating.
        self.line_start_time: Optional[datetime] = None
    def write(self, s: str) -> int:
        cur_time = datetime.now()
        # Stale partial line (>0.1s old): write it out before appending.
        if self.line_buffer and (cur_time - self.line_start_time).total_seconds() > 0.1:
            self.flush()
        if self.line_buffer:
            self.line_buffer += s
        else:
            self.line_buffer = s
            self.line_start_time = cur_time
        # No newline yet: keep buffering.
        if '\n' not in s:
            return len(s)
        # Emit every complete line; the trailing fragment stays buffered.
        time_str = cur_time.strftime(time_format)
        lines = self.line_buffer.split('\n')
        for line in lines[:-1]:
            self.file.write(f'[{time_str}] PRINT {line}\n')
        self.file.flush()
        self.line_buffer = lines[-1]
        self.line_start_time = cur_time
        return len(s)
    def flush(self) -> None:
        # Force out any buffered partial line as its own PRINT record,
        # timestamped with when the line started.
        if self.line_buffer:
            time_str = self.line_start_time.strftime(time_format)
            self.file.write(f'[{time_str}] PRINT {self.line_buffer}\n')
            self.file.flush()
            self.line_buffer = None
......@@ -9,11 +9,9 @@ import json_tricks
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
from ..utils import init_dispatcher_logger
from ..recoverable import Recoverable
from .protocol import CommandType, receive
init_dispatcher_logger()
_logger = logging.getLogger(__name__)
......@@ -27,11 +25,11 @@ class MsgDispatcherBase(Recoverable):
"""
def __init__(self):
self.stopping = False
if multi_thread_enabled():
self.pool = ThreadPool()
self.thread_results = []
else:
self.stopping = False
self.default_command_queue = Queue()
self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
......@@ -45,11 +43,11 @@ class MsgDispatcherBase(Recoverable):
"""Run the tuner.
This function will never return unless raise.
"""
_logger.info('Start dispatcher')
_logger.info('Dispatcher started')
if dispatcher_env_vars.NNI_MODE == 'resume':
self.load_checkpoint()
while True:
while not self.stopping:
command, data = receive()
if data:
data = json_tricks.loads(data)
......@@ -77,7 +75,7 @@ class MsgDispatcherBase(Recoverable):
self.default_worker.join()
self.assessor_worker.join()
_logger.info('Terminated by NNI manager')
_logger.info('Dispatcher terminiated')
def command_queue_worker(self, command_queue):
"""Process commands in command queues.
......
......@@ -7,7 +7,6 @@ import json
import time
import subprocess
from ..common import init_logger
from ..env_vars import trial_env_vars
from nni.utils import to_json
......@@ -21,9 +20,6 @@ if not os.path.exists(_outputdir):
os.makedirs(_outputdir)
_nni_platform = trial_env_vars.NNI_PLATFORM
if _nni_platform == 'local':
_log_file_path = os.path.join(_outputdir, 'trial.log')
init_logger(_log_file_path)
_multiphase = trial_env_vars.MULTI_PHASE
......
......@@ -4,8 +4,6 @@
import logging
import json_tricks
from ..common import init_standalone_logger
__all__ = [
'get_next_parameter',
'get_experiment_id',
......@@ -14,7 +12,6 @@ __all__ = [
'send_metric',
]
init_standalone_logger()
_logger = logging.getLogger('nni')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment