"...git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "a7730272e4aeeed198b855b7f36ef7ac88cdd76b"
Unverified Commit b0de7c93 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Add search space validation for choice types (#3975)

parent ef9e27b9
...@@ -47,6 +47,11 @@ def validate_search_space( ...@@ -47,6 +47,11 @@ def validate_search_space(
raise ValueError(f'search space "{name}"\'s value is not a list : {spec}') raise ValueError(f'search space "{name}"\'s value is not a list : {spec}')
if type_ == 'choice': if type_ == 'choice':
if not all(isinstance(arg, (float, int, str)) for arg in args):
# FIXME: need further check for each algorithm which types are actually supported
# for now validation only prints warning so it doesn't harm
if not isinstance(args[0], dict) or '_name' not in args[0]: # not nested search space
raise ValueError(f'search space "{name}" (choice) should only contain numbers or strings : {spec}')
continue continue
if type_.startswith('q'): if type_.startswith('q'):
......
...@@ -16,12 +16,22 @@ good_partial = { ...@@ -16,12 +16,22 @@ good_partial = {
'choice': good['choice'], 'choice': good['choice'],
'randint': good['randint'], 'randint': good['randint'],
} }
good_nested = {
'outer': {
'_type': 'choice',
'_value': [
{ '_name': 'empty' },
{ '_name': 'a', 'a_1': { '_type': 'choice', '_value': ['a', 'b'] } }
]
}
}
bad_type = 'x' bad_type = 'x'
bad_spec_type = { 'x': [1, 2, 3] } bad_spec_type = { 'x': [1, 2, 3] }
bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } } bad_fields = { 'x': { 'type': 'choice', 'value': ['a', 'b'] } }
bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } } bad_type_name = { 'x': { '_type': 'choic', '_value': ['a'] } }
bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } } bad_value = { 'x': { '_type': 'choice', '_value': 'ab' } }
bad_choice_args = { 'x': { '_type': 'choice', 'value': [ 'a', object() ] } }
bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } } bad_2_args = { 'x': { '_type': 'randint', '_value': [1, 2, 3] } }
bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } } bad_3_args = { 'x': { '_type': 'quniform', '_value': [0] } }
bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } } bad_int_args = { 'x': { '_type': 'randint', '_value': [1.0, 2.0] } }
...@@ -32,11 +42,13 @@ bad_sigma = { 'x': { '_type': 'normal', '_value': [0, 0] } } ...@@ -32,11 +42,13 @@ bad_sigma = { 'x': { '_type': 'normal', '_value': [0, 0] } }
def test_hpo_utils(): def test_hpo_utils():
assert validate_search_space(good, raise_exception=False) assert validate_search_space(good, raise_exception=False)
assert validate_search_space(good_nested, raise_exception=False)
assert not validate_search_space(bad_type, raise_exception=False) assert not validate_search_space(bad_type, raise_exception=False)
assert not validate_search_space(bad_spec_type, raise_exception=False) assert not validate_search_space(bad_spec_type, raise_exception=False)
assert not validate_search_space(bad_fields, raise_exception=False) assert not validate_search_space(bad_fields, raise_exception=False)
assert not validate_search_space(bad_type_name, raise_exception=False) assert not validate_search_space(bad_type_name, raise_exception=False)
assert not validate_search_space(bad_value, raise_exception=False) assert not validate_search_space(bad_value, raise_exception=False)
assert not validate_search_space(bad_choice_args, raise_exception=False)
assert not validate_search_space(bad_2_args, raise_exception=False) assert not validate_search_space(bad_2_args, raise_exception=False)
assert not validate_search_space(bad_3_args, raise_exception=False) assert not validate_search_space(bad_3_args, raise_exception=False)
assert not validate_search_space(bad_int_args, raise_exception=False) assert not validate_search_space(bad_int_args, raise_exception=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment