Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b6894c1e
Unverified
Commit
b6894c1e
authored
Sep 27, 2021
by
Erik Fäßler
Committed by
GitHub
Sep 27, 2021
Browse files
Pass ConfigSpace definition file directly to BOHB (#4153)
parent
f1bfdd80
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
66 additions
and
44 deletions
+66
-44
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
+55
-44
nni/tools/nnictl/launcher_utils.py
nni/tools/nnictl/launcher_utils.py
+4
-0
nni/utils.py
nni/utils.py
+7
-0
No files found.
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
View file @
b6894c1e
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
'''
'''
bohb_advisor.py
bohb_advisor.py
'''
'''
import
sys
import
sys
import
math
import
math
import
logging
import
logging
...
@@ -12,6 +11,7 @@ import json_tricks
...
@@ -12,6 +11,7 @@ import json_tricks
from
schema
import
Schema
,
Optional
from
schema
import
Schema
,
Optional
import
ConfigSpace
as
CS
import
ConfigSpace
as
CS
import
ConfigSpace.hyperparameters
as
CSH
import
ConfigSpace.hyperparameters
as
CSH
from
ConfigSpace.read_and_write
import
pcs_new
from
nni
import
ClassArgsValidator
from
nni
import
ClassArgsValidator
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.runtime.protocol
import
CommandType
,
send
...
@@ -244,6 +244,7 @@ class BOHBClassArgsValidator(ClassArgsValidator):
...
@@ -244,6 +244,7 @@ class BOHBClassArgsValidator(ClassArgsValidator):
Optional
(
'random_fraction'
):
self
.
range
(
'random_fraction'
,
float
,
0
,
9999
),
Optional
(
'random_fraction'
):
self
.
range
(
'random_fraction'
,
float
,
0
,
9999
),
Optional
(
'bandwidth_factor'
):
self
.
range
(
'bandwidth_factor'
,
float
,
0
,
9999
),
Optional
(
'bandwidth_factor'
):
self
.
range
(
'bandwidth_factor'
,
float
,
0
,
9999
),
Optional
(
'min_bandwidth'
):
self
.
range
(
'min_bandwidth'
,
float
,
0
,
9999
),
Optional
(
'min_bandwidth'
):
self
.
range
(
'min_bandwidth'
,
float
,
0
,
9999
),
Optional
(
'config_space'
):
self
.
path
(
'config_space'
)
}).
validate
(
kwargs
)
}).
validate
(
kwargs
)
class
BOHB
(
MsgDispatcherBase
):
class
BOHB
(
MsgDispatcherBase
):
...
@@ -297,7 +298,8 @@ class BOHB(MsgDispatcherBase):
...
@@ -297,7 +298,8 @@ class BOHB(MsgDispatcherBase):
num_samples
=
64
,
num_samples
=
64
,
random_fraction
=
1
/
3
,
random_fraction
=
1
/
3
,
bandwidth_factor
=
3
,
bandwidth_factor
=
3
,
min_bandwidth
=
1e-3
):
min_bandwidth
=
1e-3
,
config_space
=
None
):
super
(
BOHB
,
self
).
__init__
()
super
(
BOHB
,
self
).
__init__
()
self
.
optimize_mode
=
OptimizeMode
(
optimize_mode
)
self
.
optimize_mode
=
OptimizeMode
(
optimize_mode
)
self
.
min_budget
=
min_budget
self
.
min_budget
=
min_budget
...
@@ -309,6 +311,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -309,6 +311,7 @@ class BOHB(MsgDispatcherBase):
self
.
random_fraction
=
random_fraction
self
.
random_fraction
=
random_fraction
self
.
bandwidth_factor
=
bandwidth_factor
self
.
bandwidth_factor
=
bandwidth_factor
self
.
min_bandwidth
=
min_bandwidth
self
.
min_bandwidth
=
min_bandwidth
self
.
config_space
=
config_space
# all the configs waiting for run
# all the configs waiting for run
self
.
generated_hyper_configs
=
[]
self
.
generated_hyper_configs
=
[]
...
@@ -468,6 +471,14 @@ class BOHB(MsgDispatcherBase):
...
@@ -468,6 +471,14 @@ class BOHB(MsgDispatcherBase):
search space of this experiment
search space of this experiment
"""
"""
search_space
=
data
search_space
=
data
cs
=
None
logger
.
debug
(
f
'Received data:
{
data
}
'
)
if
self
.
config_space
:
logger
.
info
(
f
'Got a ConfigSpace file path, parsing the search space directly from
{
self
.
config_space
}
. '
'The NNI search space is ignored.'
)
with
open
(
self
.
config_space
,
'r'
)
as
fh
:
cs
=
pcs_new
.
read
(
fh
)
else
:
cs
=
CS
.
ConfigurationSpace
()
cs
=
CS
.
ConfigurationSpace
()
for
var
in
search_space
:
for
var
in
search_space
:
_type
=
str
(
search_space
[
var
][
"_type"
])
_type
=
str
(
search_space
[
var
][
"_type"
])
...
...
nni/tools/nnictl/launcher_utils.py
View file @
b6894c1e
...
@@ -94,6 +94,10 @@ def parse_path(experiment_config, config_path):
...
@@ -94,6 +94,10 @@ def parse_path(experiment_config, config_path):
parse_relative_path
(
root_path
,
experiment_config
[
'assessor'
],
'codeDir'
)
parse_relative_path
(
root_path
,
experiment_config
[
'assessor'
],
'codeDir'
)
if
experiment_config
.
get
(
'advisor'
):
if
experiment_config
.
get
(
'advisor'
):
parse_relative_path
(
root_path
,
experiment_config
[
'advisor'
],
'codeDir'
)
parse_relative_path
(
root_path
,
experiment_config
[
'advisor'
],
'codeDir'
)
# for BOHB when delivering a ConfigSpace file directly
if
experiment_config
.
get
(
'advisor'
).
get
(
'classArgs'
)
and
experiment_config
.
get
(
'advisor'
).
get
(
'classArgs'
).
get
(
'config_space'
):
parse_relative_path
(
root_path
,
experiment_config
.
get
(
'advisor'
).
get
(
'classArgs'
),
'config_space'
)
if
experiment_config
.
get
(
'machineList'
):
if
experiment_config
.
get
(
'machineList'
):
for
index
in
range
(
len
(
experiment_config
[
'machineList'
])):
for
index
in
range
(
len
(
experiment_config
[
'machineList'
])):
parse_relative_path
(
root_path
,
experiment_config
[
'machineList'
][
index
],
'sshKeyPath'
)
parse_relative_path
(
root_path
,
experiment_config
[
'machineList'
][
index
],
'sshKeyPath'
)
...
...
nni/utils.py
View file @
b6894c1e
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
copy
import
copy
import
functools
import
functools
from
enum
import
Enum
,
unique
from
enum
import
Enum
,
unique
from
pathlib
import
Path
import
json_tricks
import
json_tricks
from
schema
import
And
from
schema
import
And
...
@@ -305,3 +306,9 @@ class ClassArgsValidator(object):
...
@@ -305,3 +306,9 @@ class ClassArgsValidator(object):
And
(
keyType
,
error
=
'%s should be %s type!'
%
(
key
,
keyType
.
__name__
)),
And
(
keyType
,
error
=
'%s should be %s type!'
%
(
key
,
keyType
.
__name__
)),
And
(
lambda
n
:
start
<=
n
<=
end
,
error
=
'%s should be in range of (%s, %s)!'
%
(
key
,
start
,
end
))
And
(
lambda
n
:
start
<=
n
<=
end
,
error
=
'%s should be in range of (%s, %s)!'
%
(
key
,
start
,
end
))
)
)
def
path
(
self
,
key
):
return
And
(
And
(
str
,
error
=
'%s should be a string!'
%
key
),
And
(
lambda
p
:
Path
(
p
).
exists
(),
error
=
'%s path does not exist!'
%
(
key
))
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment