Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
32fdd32b
Unverified
Commit
32fdd32b
authored
Jun 30, 2021
by
liuzhe-lz
Committed by
GitHub
Jun 30, 2021
Browse files
Add simple HPO search space validation (#3877)
parent
749a463a
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
144 additions
and
0 deletions
+144
-0
nni/algorithms/hpo/batch_tuner.py
nni/algorithms/hpo/batch_tuner.py
+2
-0
nni/algorithms/hpo/dngo_tuner.py
nni/algorithms/hpo/dngo_tuner.py
+2
-0
nni/algorithms/hpo/gp_tuner/gp_tuner.py
nni/algorithms/hpo/gp_tuner/gp_tuner.py
+2
-0
nni/algorithms/hpo/gridsearch_tuner.py
nni/algorithms/hpo/gridsearch_tuner.py
+2
-0
nni/algorithms/hpo/hyperband_advisor.py
nni/algorithms/hpo/hyperband_advisor.py
+2
-0
nni/algorithms/hpo/hyperopt_tuner.py
nni/algorithms/hpo/hyperopt_tuner.py
+2
-0
nni/algorithms/hpo/metis_tuner/metis_tuner.py
nni/algorithms/hpo/metis_tuner/metis_tuner.py
+3
-0
nni/algorithms/hpo/smac_tuner/smac_tuner.py
nni/algorithms/hpo/smac_tuner/smac_tuner.py
+2
-0
nni/common/hpo_utils.py
nni/common/hpo_utils.py
+75
-0
test/ut/sdk/test_hpo_utils.py
test/ut/sdk/test_hpo_utils.py
+52
-0
No files found.
nni/algorithms/hpo/batch_tuner.py
View file @
32fdd32b
...
@@ -9,6 +9,7 @@ batch_tuner.py including:
...
@@ -9,6 +9,7 @@ batch_tuner.py including:
import
logging
import
logging
import
nni
import
nni
from
nni.common.hpo_utils
import
validate_search_space
from
nni.tuner
import
Tuner
from
nni.tuner
import
Tuner
TYPE
=
'_type'
TYPE
=
'_type'
...
@@ -75,6 +76,7 @@ class BatchTuner(Tuner):
...
@@ -75,6 +76,7 @@ class BatchTuner(Tuner):
----------
----------
search_space : dict
search_space : dict
"""
"""
validate_search_space
(
search_space
,
[
'choice'
])
self
.
_values
=
self
.
is_valid
(
search_space
)
self
.
_values
=
self
.
is_valid
(
search_space
)
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
...
...
nni/algorithms/hpo/dngo_tuner.py
View file @
32fdd32b
...
@@ -7,6 +7,7 @@ from torch.distributions import Normal
...
@@ -7,6 +7,7 @@ from torch.distributions import Normal
import
nni.parameter_expressions
as
parameter_expressions
import
nni.parameter_expressions
as
parameter_expressions
from
nni
import
ClassArgsValidator
from
nni
import
ClassArgsValidator
from
nni.common.hpo_utils
import
validate_search_space
from
nni.tuner
import
Tuner
from
nni.tuner
import
Tuner
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -86,6 +87,7 @@ class DNGOTuner(Tuner):
...
@@ -86,6 +87,7 @@ class DNGOTuner(Tuner):
return
new_x
return
new_x
def
update_search_space
(
self
,
search_space
):
def
update_search_space
(
self
,
search_space
):
validate_search_space
(
search_space
,
[
'choice'
,
'randint'
,
'uniform'
,
'quniform'
,
'loguniform'
,
'qloguniform'
])
self
.
searchspace_json
=
search_space
self
.
searchspace_json
=
search_space
self
.
random_state
=
np
.
random
.
RandomState
()
self
.
random_state
=
np
.
random
.
RandomState
()
...
...
nni/algorithms/hpo/gp_tuner/gp_tuner.py
View file @
32fdd32b
...
@@ -16,6 +16,7 @@ from sklearn.gaussian_process.kernels import Matern
...
@@ -16,6 +16,7 @@ from sklearn.gaussian_process.kernels import Matern
from
sklearn.gaussian_process
import
GaussianProcessRegressor
from
sklearn.gaussian_process
import
GaussianProcessRegressor
from
nni
import
ClassArgsValidator
from
nni
import
ClassArgsValidator
from
nni.common.hpo_utils
import
validate_search_space
from
nni.tuner
import
Tuner
from
nni.tuner
import
Tuner
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
...
@@ -103,6 +104,7 @@ class GPTuner(Tuner):
...
@@ -103,6 +104,7 @@ class GPTuner(Tuner):
Override of the abstract method in :class:`~nni.tuner.Tuner`.
Override of the abstract method in :class:`~nni.tuner.Tuner`.
"""
"""
validate_search_space
(
search_space
,
[
'choice'
,
'randint'
,
'uniform'
,
'quniform'
,
'loguniform'
,
'qloguniform'
])
self
.
_space
=
TargetSpace
(
search_space
,
self
.
_random_state
)
self
.
_space
=
TargetSpace
(
search_space
,
self
.
_random_state
)
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
...
...
nni/algorithms/hpo/gridsearch_tuner.py
View file @
32fdd32b
...
@@ -11,6 +11,7 @@ import logging
...
@@ -11,6 +11,7 @@ import logging
import
numpy
as
np
import
numpy
as
np
import
nni
import
nni
from
nni.common.hpo_utils
import
validate_search_space
from
nni.tuner
import
Tuner
from
nni.tuner
import
Tuner
from
nni.utils
import
convert_dict2tuple
from
nni.utils
import
convert_dict2tuple
...
@@ -144,6 +145,7 @@ class GridSearchTuner(Tuner):
...
@@ -144,6 +145,7 @@ class GridSearchTuner(Tuner):
search_space : dict
search_space : dict
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
"""
"""
validate_search_space
(
search_space
,
[
'choice'
,
'randint'
,
'quniform'
])
self
.
expanded_search_space
=
self
.
_json2parameter
(
search_space
)
self
.
expanded_search_space
=
self
.
_json2parameter
(
search_space
)
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
def
generate_parameters
(
self
,
parameter_id
,
**
kwargs
):
...
...
nni/algorithms/hpo/hyperband_advisor.py
View file @
32fdd32b
...
@@ -15,6 +15,7 @@ import numpy as np
...
@@ -15,6 +15,7 @@ import numpy as np
from
schema
import
Schema
,
Optional
from
schema
import
Schema
,
Optional
from
nni
import
ClassArgsValidator
from
nni
import
ClassArgsValidator
from
nni.common.hpo_utils
import
validate_search_space
from
nni.runtime.common
import
multi_phase_enabled
from
nni.runtime.common
import
multi_phase_enabled
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.runtime.protocol
import
CommandType
,
send
...
@@ -379,6 +380,7 @@ class Hyperband(MsgDispatcherBase):
...
@@ -379,6 +380,7 @@ class Hyperband(MsgDispatcherBase):
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
"""data: JSON object, which is search space
"""data: JSON object, which is search space
"""
"""
validate_search_space
(
data
)
self
.
searchspace_json
=
data
self
.
searchspace_json
=
data
self
.
random_state
=
np
.
random
.
RandomState
()
self
.
random_state
=
np
.
random
.
RandomState
()
...
...
nni/algorithms/hpo/hyperopt_tuner.py
View file @
32fdd32b
...
@@ -12,6 +12,7 @@ import hyperopt as hp
...
@@ -12,6 +12,7 @@ import hyperopt as hp
import
numpy
as
np
import
numpy
as
np
from
schema
import
Optional
,
Schema
from
schema
import
Optional
,
Schema
from
nni
import
ClassArgsValidator
from
nni
import
ClassArgsValidator
from
nni.common.hpo_utils
import
validate_search_space
from
nni.tuner
import
Tuner
from
nni.tuner
import
Tuner
from
nni.utils
import
NodeType
,
OptimizeMode
,
extract_scalar_reward
,
split_index
from
nni.utils
import
NodeType
,
OptimizeMode
,
extract_scalar_reward
,
split_index
...
@@ -246,6 +247,7 @@ class HyperoptTuner(Tuner):
...
@@ -246,6 +247,7 @@ class HyperoptTuner(Tuner):
----------
----------
search_space : dict
search_space : dict
"""
"""
validate_search_space
(
search_space
)
self
.
json
=
search_space
self
.
json
=
search_space
search_space_instance
=
json2space
(
self
.
json
)
search_space_instance
=
json2space
(
self
.
json
)
...
...
nni/algorithms/hpo/metis_tuner/metis_tuner.py
View file @
32fdd32b
...
@@ -16,6 +16,7 @@ from schema import Schema, Optional
...
@@ -16,6 +16,7 @@ from schema import Schema, Optional
from
nni
import
ClassArgsValidator
from
nni
import
ClassArgsValidator
from
nni.tuner
import
Tuner
from
nni.tuner
import
Tuner
from
nni.common.hpo_utils
import
validate_search_space
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
from
.
import
lib_constraint_summation
from
.
import
lib_constraint_summation
from
.
import
lib_data
from
.
import
lib_data
...
@@ -152,6 +153,8 @@ class MetisTuner(Tuner):
...
@@ -152,6 +153,8 @@ class MetisTuner(Tuner):
----------
----------
search_space : dict
search_space : dict
"""
"""
validate_search_space
(
search_space
,
[
'choice'
,
'randint'
,
'uniform'
,
'quniform'
])
self
.
x_bounds
=
[[]
for
i
in
range
(
len
(
search_space
))]
self
.
x_bounds
=
[[]
for
i
in
range
(
len
(
search_space
))]
self
.
x_types
=
[
NONE_TYPE
for
i
in
range
(
len
(
search_space
))]
self
.
x_types
=
[
NONE_TYPE
for
i
in
range
(
len
(
search_space
))]
...
...
nni/algorithms/hpo/smac_tuner/smac_tuner.py
View file @
32fdd32b
...
@@ -21,6 +21,7 @@ from ConfigSpaceNNI import Configuration
...
@@ -21,6 +21,7 @@ from ConfigSpaceNNI import Configuration
import
nni
import
nni
from
nni
import
ClassArgsValidator
from
nni
import
ClassArgsValidator
from
nni.common.hpo_utils
import
validate_search_space
from
nni.tuner
import
Tuner
from
nni.tuner
import
Tuner
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
...
@@ -143,6 +144,7 @@ class SMACTuner(Tuner):
...
@@ -143,6 +144,7 @@ class SMACTuner(Tuner):
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
The format could be referred to search space spec (https://nni.readthedocs.io/en/latest/Tutorial/SearchSpaceSpec.html).
"""
"""
self
.
logger
.
info
(
'update search space in SMAC.'
)
self
.
logger
.
info
(
'update search space in SMAC.'
)
validate_search_space
(
search_space
,
[
'choice'
,
'randint'
,
'uniform'
,
'quniform'
,
'loguniform'
])
if
not
self
.
update_ss_done
:
if
not
self
.
update_ss_done
:
self
.
categorical_dict
=
generate_scenario
(
search_space
)
self
.
categorical_dict
=
generate_scenario
(
search_space
)
if
self
.
categorical_dict
is
None
:
if
self
.
categorical_dict
is
None
:
...
...
nni/common/hpo_utils.py
0 → 100644
View file @
32fdd32b
import
logging
from
typing
import
Any
,
List
,
Optional
common_search_space_types
=
[
'choice'
,
'randint'
,
'uniform'
,
'quniform'
,
'loguniform'
,
'qloguniform'
,
'normal'
,
'qnormal'
,
'lognormal'
,
'qlognormal'
,
]
def
validate_search_space
(
search_space
:
Any
,
support_types
:
Optional
[
List
[
str
]]
=
None
,
raise_exception
:
bool
=
False
# for now, in case false positive
)
->
bool
:
if
not
raise_exception
:
try
:
validate_search_space
(
search_space
,
support_types
,
True
)
return
True
except
ValueError
as
e
:
logging
.
getLogger
(
__name__
).
error
(
e
.
args
[
0
])
return
False
if
support_types
is
None
:
support_types
=
common_search_space_types
if
not
isinstance
(
search_space
,
dict
):
raise
ValueError
(
f
'search space is a
{
type
(
search_space
).
__name__
}
, expect a dict :
{
repr
(
search_space
)
}
'
)
for
name
,
spec
in
search_space
.
items
():
if
not
isinstance
(
spec
,
dict
):
raise
ValueError
(
f
'search space "
{
name
}
" is a
{
type
(
spec
).
__name__
}
, expect a dict :
{
repr
(
spec
)
}
'
)
if
'_type'
not
in
spec
or
'_value'
not
in
spec
:
raise
ValueError
(
f
'search space "
{
name
}
" does not have "_type" or "_value" :
{
spec
}
'
)
type_
=
spec
[
'_type'
]
if
type_
not
in
support_types
:
raise
ValueError
(
f
'search space "
{
name
}
" has unsupported type "
{
type_
}
" :
{
spec
}
'
)
args
=
spec
[
'_value'
]
if
not
isinstance
(
args
,
list
):
raise
ValueError
(
f
'search space "
{
name
}
"
\'
s value is not a list :
{
spec
}
'
)
if
type_
==
'choice'
:
continue
if
type_
.
startswith
(
'q'
):
if
len
(
args
)
!=
3
:
raise
ValueError
(
f
'search space "
{
name
}
" (
{
type_
}
) must have 3 values :
{
spec
}
'
)
else
:
if
len
(
args
)
!=
2
:
raise
ValueError
(
f
'search space "
{
name
}
" (
{
type_
}
) must have 2 values :
{
spec
}
'
)
if
type_
==
'randint'
:
if
not
all
(
isinstance
(
arg
,
int
)
for
arg
in
args
):
raise
ValueError
(
f
'search space "
{
name
}
" (
{
type_
}
) must have int values :
{
spec
}
'
)
else
:
if
not
all
(
isinstance
(
arg
,
(
float
,
int
))
for
arg
in
args
):
raise
ValueError
(
f
'search space "
{
name
}
" (
{
type_
}
) must have float values :
{
spec
}
'
)
if
'normal'
not
in
type_
:
if
args
[
0
]
>=
args
[
1
]:
raise
ValueError
(
f
'search space "
{
name
}
" (
{
type_
}
) must have high > low :
{
spec
}
'
)
if
'log'
in
type_
and
args
[
0
]
<=
0
:
raise
ValueError
(
f
'search space "
{
name
}
" (
{
type_
}
) must have low > 0 :
{
spec
}
'
)
else
:
if
args
[
1
]
<=
0
:
raise
ValueError
(
f
'search space "
{
name
}
" (
{
type_
}
) must have sigma > 0 :
{
spec
}
'
)
return
True
test/ut/sdk/test_hpo_utils.py
0 → 100644
View file @
32fdd32b
from
nni.common.hpo_utils
import
validate_search_space
good
=
{
'choice'
:
{
'_type'
:
'choice'
,
'_value'
:
[
'a'
,
'b'
]
},
'randint'
:
{
'_type'
:
'randint'
,
'_value'
:
[
1
,
10
]
},
'uniform'
:
{
'_type'
:
'uniform'
,
'_value'
:
[
0
,
1.0
]
},
'quniform'
:
{
'_type'
:
'quniform'
,
'_value'
:
[
1
,
10
,
0.1
]
},
'loguniform'
:
{
'_type'
:
'loguniform'
,
'_value'
:
[
0.001
,
0.1
]
},
'qloguniform'
:
{
'_type'
:
'qloguniform'
,
'_value'
:
[
0.001
,
0.1
,
0.001
]
},
'normal'
:
{
'_type'
:
'normal'
,
'_value'
:
[
0
,
0.1
]
},
'qnormal'
:
{
'_type'
:
'qnormal'
,
'_value'
:
[
0.5
,
0.1
,
0.1
]
},
'lognormal'
:
{
'_type'
:
'lognormal'
,
'_value'
:
[
0.0
,
1
]
},
'qlognormal'
:
{
'_type'
:
'qlognormal'
,
'_value'
:
[
-
1
,
1
,
0.1
]
},
}
good_partial
=
{
'choice'
:
good
[
'choice'
],
'randint'
:
good
[
'randint'
],
}
bad_type
=
'x'
bad_spec_type
=
{
'x'
:
[
1
,
2
,
3
]
}
bad_fields
=
{
'x'
:
{
'type'
:
'choice'
,
'value'
:
[
'a'
,
'b'
]
}
}
bad_type_name
=
{
'x'
:
{
'_type'
:
'choic'
,
'_value'
:
[
'a'
]
}
}
bad_value
=
{
'x'
:
{
'_type'
:
'choice'
,
'_value'
:
'ab'
}
}
bad_2_args
=
{
'x'
:
{
'_type'
:
'randint'
,
'_value'
:
[
1
,
2
,
3
]
}
}
bad_3_args
=
{
'x'
:
{
'_type'
:
'quniform'
,
'_value'
:
[
0
]
}
}
bad_int_args
=
{
'x'
:
{
'_type'
:
'randint'
,
'_value'
:
[
1.0
,
2.0
]
}
}
bad_float_args
=
{
'x'
:
{
'_type'
:
'uniform'
,
'_value'
:
[
'0.1'
,
'0.2'
]
}
}
bad_low_high
=
{
'x'
:
{
'_type'
:
'quniform'
,
'_value'
:
[
2
,
1
,
0.1
]
}
}
bad_log
=
{
'x'
:
{
'_type'
:
'loguniform'
,
'_value'
:
[
0
,
1
]
}
}
bad_sigma
=
{
'x'
:
{
'_type'
:
'normal'
,
'_value'
:
[
0
,
0
]
}
}
def
test_hpo_utils
():
assert
validate_search_space
(
good
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_type
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_spec_type
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_fields
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_type_name
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_value
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_2_args
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_3_args
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_int_args
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_float_args
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_low_high
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_log
,
raise_exception
=
False
)
assert
not
validate_search_space
(
bad_sigma
,
raise_exception
=
False
)
assert
validate_search_space
(
good_partial
,
[
'choice'
,
'randint'
],
False
)
assert
not
validate_search_space
(
good
,
[
'choice'
,
'randint'
],
False
)
if
__name__
==
'__main__'
:
test_hpo_utils
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment