Unverified Commit 32fdd32b authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Add simple HPO search space validation (#3877)

parent 749a463a
......@@ -9,6 +9,7 @@ batch_tuner.py including:
import logging
import nni
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
TYPE = '_type'
......@@ -75,6 +76,7 @@ class BatchTuner(Tuner):
----------
search_space : dict
"""
validate_search_space(search_space, ['choice'])
self._values = self.is_valid(search_space)
def generate_parameters(self, parameter_id, **kwargs):
......
......@@ -7,6 +7,7 @@ from torch.distributions import Normal
import nni.parameter_expressions as parameter_expressions
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
_logger = logging.getLogger(__name__)
......@@ -86,6 +87,7 @@ class DNGOTuner(Tuner):
return new_x
def update_search_space(self, search_space):
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform', 'qloguniform'])
self.searchspace_json = search_space
self.random_state = np.random.RandomState()
......
......@@ -16,6 +16,7 @@ from sklearn.gaussian_process.kernels import Matern
from sklearn.gaussian_process import GaussianProcessRegressor
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
......@@ -103,6 +104,7 @@ class GPTuner(Tuner):
Override of the abstract method in :class:`~nni.tuner.Tuner`.
"""
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform', 'qloguniform'])
self._space = TargetSpace(search_space, self._random_state)
def generate_parameters(self, parameter_id, **kwargs):
......
......@@ -11,6 +11,7 @@ import logging
import numpy as np
import nni
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import convert_dict2tuple
......@@ -144,6 +145,7 @@ class GridSearchTuner(Tuner):
search_space : dict
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
"""
validate_search_space(search_space, ['choice', 'randint', 'quniform'])
self.expanded_search_space = self._json2parameter(search_space)
def generate_parameters(self, parameter_id, **kwargs):
......
......@@ -15,6 +15,7 @@ import numpy as np
from schema import Schema, Optional
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.runtime.common import multi_phase_enabled
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
......@@ -379,6 +380,7 @@ class Hyperband(MsgDispatcherBase):
def handle_update_search_space(self, data):
"""data: JSON object, which is search space
"""
validate_search_space(data)
self.searchspace_json = data
self.random_state = np.random.RandomState()
......
......@@ -12,6 +12,7 @@ import hyperopt as hp
import numpy as np
from schema import Optional, Schema
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index
......@@ -246,6 +247,7 @@ class HyperoptTuner(Tuner):
----------
search_space : dict
"""
validate_search_space(search_space)
self.json = search_space
search_space_instance = json2space(self.json)
......
......@@ -16,6 +16,7 @@ from schema import Schema, Optional
from nni import ClassArgsValidator
from nni.tuner import Tuner
from nni.common.hpo_utils import validate_search_space
from nni.utils import OptimizeMode, extract_scalar_reward
from . import lib_constraint_summation
from . import lib_data
......@@ -152,6 +153,8 @@ class MetisTuner(Tuner):
----------
search_space : dict
"""
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform'])
self.x_bounds = [[] for i in range(len(search_space))]
self.x_types = [NONE_TYPE for i in range(len(search_space))]
......
......@@ -21,6 +21,7 @@ from ConfigSpaceNNI import Configuration
import nni
from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
......@@ -143,6 +144,7 @@ class SMACTuner(Tuner):
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
"""
self.logger.info('update search space in SMAC.')
validate_search_space(search_space, ['choice', 'randint', 'uniform', 'quniform', 'loguniform'])
if not self.update_ss_done:
self.categorical_dict = generate_scenario(search_space)
if self.categorical_dict is None:
......
import logging
from typing import Any, List, Optional
common_search_space_types = [
'choice',
'randint',
'uniform',
'quniform',
'loguniform',
'qloguniform',
'normal',
'qnormal',
'lognormal',
'qlognormal',
]
def validate_search_space(
search_space: Any,
support_types: Optional[List[str]] = None,
raise_exception: bool = False # for now, in case false positive
) -> bool:
if not raise_exception:
try:
validate_search_space(search_space, support_types, True)
return True
except ValueError as e:
logging.getLogger(__name__).error(e.args[0])
return False
if support_types is None:
support_types = common_search_space_types
if not isinstance(search_space, dict):
raise ValueError(f'search space is a {type(search_space).__name__}, expect a dict : {repr(search_space)}')
for name, spec in search_space.items():
if not isinstance(spec, dict):
raise ValueError(f'search space "{name}" is a {type(spec).__name__}, expect a dict : {repr(spec)}')
if '_type' not in spec or '_value' not in spec:
raise ValueError(f'search space "{name}" does not have "_type" or "_value" : {spec}')
type_ = spec['_type']
if type_ not in support_types:
raise ValueError(f'search space "{name}" has unsupported type "{type_}" : {spec}')
args = spec['_value']
if not isinstance(args, list):
raise ValueError(f'search space "{name}"\'s value is not a list : {spec}')
if type_ == 'choice':
continue
if type_.startswith('q'):
if len(args) != 3:
raise ValueError(f'search space "{name}" ({type_}) must have 3 values : {spec}')
else:
if len(args) != 2:
raise ValueError(f'search space "{name}" ({type_}) must have 2 values : {spec}')
if type_ == 'randint':
if not all(isinstance(arg, int) for arg in args):
raise ValueError(f'search space "{name}" ({type_}) must have int values : {spec}')
else:
if not all(isinstance(arg, (float, int)) for arg in args):
raise ValueError(f'search space "{name}" ({type_}) must have float values : {spec}')
if 'normal' not in type_:
if args[0] >= args[1]:
raise ValueError(f'search space "{name}" ({type_}) must have high > low : {spec}')
if 'log' in type_ and args[0] <= 0:
raise ValueError(f'search space "{name}" ({type_}) must have low > 0 : {spec}')
else:
if args[1] <= 0:
raise ValueError(f'search space "{name}" ({type_}) must have sigma > 0 : {spec}')
return True
from nni.common.hpo_utils import validate_search_space
good = {
'choice': { '_type': 'choice', '_value': ['a', 'b'] },
'randint': { '_type': 'randint', '_value': [1, 10] },
'uniform': { '_type': 'uniform', '_value': [0, 1.0] },
'quniform': { '_type': 'quniform', '_value': [1, 10, 0.1] },
'loguniform': { '_type': 'loguniform', '_value': [0.001, 0.1] },
'qloguniform': { '_type': 'qloguniform', '_value': [0.001, 0.1, 0.001] },
'normal': { '_type': 'normal', '_value': [0, 0.1] },
'qnormal': { '_type': 'qnormal', '_value': [0.5, 0.1, 0.1] },
'lognormal': { '_type': 'lognormal', '_value': [0.0, 1] },
'qlognormal': { '_type': 'qlognormal', '_value': [-1, 1, 0.1] },
}
good_partial = {
'choice': good['choice'],
'randint': good['randint'],
}
bad_type = 'x'
bad_spec_type = { 'x': [1, 2, 3] }
bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } }
bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } }
bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } }
bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } }
bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } }
bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } }
bad_float_args = { 'x': { '_type': 'uniform', '_value': ['0.1', '0.2'] } }
bad_low_high = { 'x': { '_type': 'quniform', '_value': [2, 1, 0.1] } }
bad_log = { 'x': { '_type': 'loguniform', '_value': [0, 1] } }
bad_sigma = { 'x': { '_type': 'normal', '_value': [0, 0] } }
def test_hpo_utils():
assert validate_search_space(good, raise_exception=False)
assert not validate_search_space(bad_type, raise_exception=False)
assert not validate_search_space(bad_spec_type, raise_exception=False)
assert not validate_search_space(bad_fields, raise_exception=False)
assert not validate_search_space(bad_type_name, raise_exception=False)
assert not validate_search_space(bad_value, raise_exception=False)
assert not validate_search_space(bad_2_args, raise_exception=False)
assert not validate_search_space(bad_3_args, raise_exception=False)
assert not validate_search_space(bad_int_args, raise_exception=False)
assert not validate_search_space(bad_float_args, raise_exception=False)
assert not validate_search_space(bad_low_high, raise_exception=False)
assert not validate_search_space(bad_log, raise_exception=False)
assert not validate_search_space(bad_sigma, raise_exception=False)
assert validate_search_space(good_partial, ['choice', 'randint'], False)
assert not validate_search_space(good, ['choice', 'randint'], False)
if __name__ == '__main__':
test_hpo_utils()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment