Unverified Commit 32fdd32b authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Add simple HPO search space validation (#3877)

parent 749a463a
...@@ -9,6 +9,7 @@ batch_tuner.py including: ...@@ -9,6 +9,7 @@ batch_tuner.py including:
import logging import logging
import nni import nni
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner from nni.tuner import Tuner
TYPE = '_type' TYPE = '_type'
...@@ -75,6 +76,7 @@ class BatchTuner(Tuner): ...@@ -75,6 +76,7 @@ class BatchTuner(Tuner):
---------- ----------
search_space : dict search_space : dict
""" """
validate_search_space(search_space, ['choice'])
self._values = self.is_valid(search_space) self._values = self.is_valid(search_space)
def generate_parameters(self, parameter_id, **kwargs): def generate_parameters(self, parameter_id, **kwargs):
......
...@@ -7,6 +7,7 @@ from torch.distributions import Normal ...@@ -7,6 +7,7 @@ from torch.distributions import Normal
import nni.parameter_expressions as parameter_expressions import nni.parameter_expressions as parameter_expressions
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner from nni.tuner import Tuner
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -86,6 +87,7 @@ class DNGOTuner(Tuner): ...@@ -86,6 +87,7 @@ class DNGOTuner(Tuner):
return new_x return new_x
def update_search_space(self, search_space): def update_search_space(self, search_space):
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform', 'qloguniform'])
self.searchspace_json = search_space self.searchspace_json = search_space
self.random_state = np.random.RandomState() self.random_state = np.random.RandomState()
......
...@@ -16,6 +16,7 @@ from sklearn.gaussian_process.kernels import Matern ...@@ -16,6 +16,7 @@ from sklearn.gaussian_process.kernels import Matern
from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process import GaussianProcessRegressor
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward from nni.utils import OptimizeMode, extract_scalar_reward
...@@ -103,6 +104,7 @@ class GPTuner(Tuner): ...@@ -103,6 +104,7 @@ class GPTuner(Tuner):
Override of the abstract method in :class:`~nni.tuner.Tuner`. Override of the abstract method in :class:`~nni.tuner.Tuner`.
""" """
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform', 'qloguniform'])
self._space = TargetSpace(search_space, self._random_state) self._space = TargetSpace(search_space, self._random_state)
def generate_parameters(self, parameter_id, **kwargs): def generate_parameters(self, parameter_id, **kwargs):
......
...@@ -11,6 +11,7 @@ import logging ...@@ -11,6 +11,7 @@ import logging
import numpy as np import numpy as np
import nni import nni
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.utils import convert_dict2tuple from nni.utils import convert_dict2tuple
...@@ -144,6 +145,7 @@ class GridSearchTuner(Tuner): ...@@ -144,6 +145,7 @@ class GridSearchTuner(Tuner):
search_space : dict search_space : dict
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html). The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
""" """
validate_search_space(search_space, ['choice', 'randint', 'quniform'])
self.expanded_search_space = self._json2parameter(search_space) self.expanded_search_space = self._json2parameter(search_space)
def generate_parameters(self, parameter_id, **kwargs): def generate_parameters(self, parameter_id, **kwargs):
......
...@@ -15,6 +15,7 @@ import numpy as np ...@@ -15,6 +15,7 @@ import numpy as np
from schema import Schema, Optional from schema import Schema, Optional
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.runtime.common import multi_phase_enabled from nni.runtime.common import multi_phase_enabled
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send from nni.runtime.protocol import CommandType, send
...@@ -379,6 +380,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -379,6 +380,7 @@ class Hyperband(MsgDispatcherBase):
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
"""data: JSON object, which is search space """data: JSON object, which is search space
""" """
validate_search_space(data)
self.searchspace_json = data self.searchspace_json = data
self.random_state = np.random.RandomState() self.random_state = np.random.RandomState()
......
...@@ -12,6 +12,7 @@ import hyperopt as hp ...@@ -12,6 +12,7 @@ import hyperopt as hp
import numpy as np import numpy as np
from schema import Optional, Schema from schema import Optional, Schema
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index
...@@ -246,6 +247,7 @@ class HyperoptTuner(Tuner): ...@@ -246,6 +247,7 @@ class HyperoptTuner(Tuner):
---------- ----------
search_space : dict search_space : dict
""" """
validate_search_space(search_space)
self.json = search_space self.json = search_space
search_space_instance = json2space(self.json) search_space_instance = json2space(self.json)
......
...@@ -16,6 +16,7 @@ from schema import Schema, Optional ...@@ -16,6 +16,7 @@ from schema import Schema, Optional
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.common.hpo_utils import validate_search_space
from nni.utils import OptimizeMode, extract_scalar_reward from nni.utils import OptimizeMode, extract_scalar_reward
from . import lib_constraint_summation from . import lib_constraint_summation
from . import lib_data from . import lib_data
...@@ -152,6 +153,8 @@ class MetisTuner(Tuner): ...@@ -152,6 +153,8 @@ class MetisTuner(Tuner):
---------- ----------
search_space : dict search_space : dict
""" """
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform'])
self.x_bounds = [[] for i in range(len(search_space))] self.x_bounds = [[] for i in range(len(search_space))]
self.x_types = [NONE_TYPE for i in range(len(search_space))] self.x_types = [NONE_TYPE for i in range(len(search_space))]
......
...@@ -21,6 +21,7 @@ from ConfigSpaceNNI import Configuration ...@@ -21,6 +21,7 @@ from ConfigSpaceNNI import Configuration
import nni import nni
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward from nni.utils import OptimizeMode, extract_scalar_reward
...@@ -143,6 +144,7 @@ class SMACTuner(Tuner): ...@@ -143,6 +144,7 @@ class SMACTuner(Tuner):
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html). The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
""" """
self.logger.info('update search space in SMAC.') self.logger.info('update search space in SMAC.')
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform'])
if not self.update_ss_done: if not self.update_ss_done:
self.categorical_dict = generate_scenario(search_space) self.categorical_dict = generate_scenario(search_space)
if self.categorical_dict is None: if self.categorical_dict is None:
......
import logging
from typing import Any, List, Optional
common_search_space_types = [
'choice',
'randint',
'uniform',
'quniform',
'loguniform',
'qloguniform',
'normal',
'qnormal',
'lognormal',
'qlognormal',
]
def validate_search_space(
search_space: Any,
support_types: Optional[List[str]] = None,
raise_exception: bool = False # for now, in case false positive
) -> bool:
if not raise_exception:
try:
validate_search_space(search_space, support_types, True)
return True
except ValueError as e:
logging.getLogger(__name__).error(e.args[0])
return False
if support_types is None:
support_types = common_search_space_types
if not isinstance(search_space, dict):
raise ValueError(f'search space is a {type(search_space).__name__}, expect a dict : {repr(search_space)}')
for name, spec in search_space.items():
if not isinstance(spec, dict):
raise ValueError(f'search space "{name}" is a {type(spec).__name__}, expect a dict : {repr(spec)}')
if '_type' not in spec or '_value' not in spec:
raise ValueError(f'search space "{name}" does not have "_type" or "_value" : {spec}')
type_ = spec['_type']
if type_ not in support_types:
raise ValueError(f'search space "{name}" has unsupported type "{type_}" : {spec}')
args = spec['_value']
if not isinstance(args, list):
raise ValueError(f'search space "{name}"\'s value is not a list : {spec}')
if type_ == 'choice':
continue
if type_.startswith('q'):
if len(args) != 3:
raise ValueError(f'search space "{name}" ({type_}) must have 3 values : {spec}')
else:
if len(args) != 2:
raise ValueError(f'search space "{name}" ({type_}) must have 2 values : {spec}')
if type_ == 'randint':
if not all(isinstance(arg, int) for arg in args):
raise ValueError(f'search space "{name}" ({type_}) must have int values : {spec}')
else:
if not all(isinstance(arg, (float, int)) for arg in args):
raise ValueError(f'search space "{name}" ({type_}) must have float values : {spec}')
if 'normal' not in type_:
if args[0] >= args[1]:
raise ValueError(f'search space "{name}" ({type_}) must have high > low : {spec}')
if 'log' in type_ and args[0] <= 0:
raise ValueError(f'search space "{name}" ({type_}) must have low > 0 : {spec}')
else:
if args[1] <= 0:
raise ValueError(f'search space "{name}" ({type_}) must have sigma > 0 : {spec}')
return True
from nni.common.hpo_utils import validate_search_space
good = {
'choice': { '_type': 'choice', '_value': ['a', 'b'] },
'randint': { '_type': 'randint', '_value': [1, 10] },
'uniform': { '_type': 'uniform', '_value': [0, 1.0] },
'quniform': { '_type': 'quniform', '_value': [1, 10, 0.1] },
'loguniform': { '_type': 'loguniform', '_value': [0.001, 0.1] },
'qloguniform': { '_type': 'qloguniform', '_value': [0.001, 0.1, 0.001] },
'normal': { '_type': 'normal', '_value': [0, 0.1] },
'qnormal': { '_type': 'qnormal', '_value': [0.5, 0.1, 0.1] },
'lognormal': { '_type': 'lognormal', '_value': [0.0, 1] },
'qlognormal': { '_type': 'qlognormal', '_value': [-1, 1, 0.1] },
}
good_partial = {
'choice': good['choice'],
'randint': good['randint'],
}
bad_type = 'x'
bad_spec_type = { 'x': [1, 2, 3] }
bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } }
bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } }
bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } }
bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } }
bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } }
bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } }
bad_float_args = { 'x': { '_type': 'uniform', '_value': ['0.1', '0.2'] } }
bad_low_high = { 'x': { '_type': 'quniform', '_value': [2, 1, 0.1] } }
bad_log = { 'x': { '_type': 'loguniform', '_value': [0, 1] } }
bad_sigma = { 'x': { '_type': 'normal', '_value': [0, 0] } }
def test_hpo_utils():
assert validate_search_space(good, raise_exception=False)
assert not validate_search_space(bad_type, raise_exception=False)
assert not validate_search_space(bad_spec_type, raise_exception=False)
assert not validate_search_space(bad_fields, raise_exception=False)
assert not validate_search_space(bad_type_name, raise_exception=False)
assert not validate_search_space(bad_value, raise_exception=False)
assert not validate_search_space(bad_2_args, raise_exception=False)
assert not validate_search_space(bad_3_args, raise_exception=False)
assert not validate_search_space(bad_int_args, raise_exception=False)
assert not validate_search_space(bad_float_args, raise_exception=False)
assert not validate_search_space(bad_low_high, raise_exception=False)
assert not validate_search_space(bad_log, raise_exception=False)
assert not validate_search_space(bad_sigma, raise_exception=False)
assert validate_search_space(good_partial, ['choice', 'randint'], False)
assert not validate_search_space(good, ['choice', 'randint'], False)
if __name__ == '__main__':
test_hpo_utils()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment