Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b7cd20e6
Unverified
Commit
b7cd20e6
authored
Jul 12, 2019
by
QuanluZhang
Committed by
GitHub
Jul 12, 2019
Browse files
support multi-phase in hyperband (#1257)
* support multi-phase in hyperband and bohb
parent
917ce97f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
180 additions
and
91 deletions
+180
-91
src/sdk/pynni/nni/__main__.py
src/sdk/pynni/nni/__main__.py
+0
-2
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
+106
-46
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
+60
-38
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+6
-5
src/sdk/pynni/nni/utils.py
src/sdk/pynni/nni/utils.py
+8
-0
No files found.
src/sdk/pynni/nni/__main__.py
View file @
b7cd20e6
...
...
@@ -130,8 +130,6 @@ def main():
if
args
.
advisor_class_name
:
# advisor is enabled and starts to run
if
args
.
multi_phase
:
raise
AssertionError
(
'multi_phase has not been supported in advisor'
)
if
args
.
advisor_class_name
in
AdvisorModuleName
:
dispatcher
=
create_builtin_class_instance
(
args
.
advisor_class_name
,
...
...
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
View file @
b7cd20e6
...
...
@@ -31,7 +31,8 @@ import ConfigSpace.hyperparameters as CSH
from
nni.protocol
import
CommandType
,
send
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
,
randint_to_quniform
from
nni.utils
import
OptimizeMode
,
MetricType
,
extract_scalar_reward
,
randint_to_quniform
from
nni.common
import
multi_phase_enabled
from
.config_generator
import
CG_BOHB
...
...
@@ -79,7 +80,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
return
params_id
class
Bracket
():
class
Bracket
(
object
):
"""
A bracket in BOHB, all the information of a bracket is managed by
an instance of this class.
...
...
@@ -329,11 +330,12 @@ class BOHB(MsgDispatcherBase):
# config generator
self
.
cg
=
None
def
load_checkpoint
(
self
):
pass
def
save_checkpoint
(
self
):
pass
# record the latest parameter_id of the trial job trial_job_id.
# if there is no running parameter_id, self.job_id_para_id_map[trial_job_id] == None
# new trial job is added to this dict and finished trial job is removed from it.
self
.
job_id_para_id_map
=
dict
()
# record the unsatisfied parameter request from trial jobs
self
.
unsatisfied_jobs
=
[]
def
handle_initialize
(
self
,
data
):
"""Initialize Tuner, including creating Bayesian optimization-based parametric models
...
...
@@ -399,7 +401,7 @@ class BOHB(MsgDispatcherBase):
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
def
_
reques
t_one_trial_job
(
self
):
def
_
ge
t_one_trial_job
(
self
):
"""get one trial job, i.e., one hyperparameter configuration.
If this function is called, Command will be sent by BOHB:
...
...
@@ -423,7 +425,7 @@ class BOHB(MsgDispatcherBase):
'parameters'
:
''
}
send
(
CommandType
.
NoMoreTrialJobs
,
json_tricks
.
dumps
(
ret
))
return
return
None
assert
self
.
generated_hyper_configs
params
=
self
.
generated_hyper_configs
.
pop
()
ret
=
{
...
...
@@ -432,6 +434,27 @@ class BOHB(MsgDispatcherBase):
'parameters'
:
params
[
1
]
}
self
.
parameters
[
params
[
0
]]
=
params
[
1
]
return
ret
def _request_one_trial_job(self):
    """Ask for the next hyperparameter configuration and launch it as a trial.

    The configuration comes from ``self._get_one_trial_job()``. When one is
    returned it is sent as a "NewTrialJob" command carrying a dict of the form:
        {
            'parameter_id': id of new hyperparameter
            'parameter_source': 'algorithm'
            'parameters': value of new hyperparameter
        }
    and one unit of ``self.credit`` is consumed.

    When ``self._get_one_trial_job()`` returns None nothing is sent from here
    and ``self.credit`` is left untouched.
    """
    trial_config = self._get_one_trial_job()
    if trial_config is None:
        # nothing to launch right now; keep the credit for a later attempt
        return
    payload = json_tricks.dumps(trial_config)
    send(CommandType.NewTrialJob, payload)
    self.credit -= 1
...
...
@@ -502,23 +525,38 @@ class BOHB(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
logger
.
debug
(
'Tuner handle trial end, result is %s'
,
data
)
hyper_params
=
json_tricks
.
loads
(
data
[
'hyper_params'
])
s
,
i
,
_
=
hyper_params
[
'parameter_id'
].
split
(
'_'
)
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
def _send_new_trial(self):
    """Serve queued parameter requests, then spend any remaining credit.

    Requests recorded in ``self.unsatisfied_jobs`` are answered oldest first:
    each one receives a fresh configuration from ``self._get_one_trial_job()``
    through a "SendTrialJobParameter" command. Serving stops as soon as no
    configuration is available. Afterwards one new trial job is requested for
    every unit of ``self.credit`` held at that moment.
    """
    while self.unsatisfied_jobs:
        config = self._get_one_trial_job()
        if config is None:
            # no configuration available; leave the remaining requests queued
            break
        waiting = self.unsatisfied_jobs.pop(0)
        config['trial_job_id'] = waiting['trial_job_id']
        config['parameter_index'] = waiting['parameter_index']
        # remember which parameter_id this trial job is running now
        job_id = config['trial_job_id']
        self.job_id_para_id_map[job_id] = config['parameter_id']
        payload = json_tricks.dumps(config)
        send(CommandType.SendTrialJobParameter, payload)
    # the loop count is fixed up front, exactly like iterating range(self.credit)
    pending_credit = self.credit
    for _ in range(pending_credit):
        self._request_one_trial_job()
def
_handle_trial_end
(
self
,
parameter_id
):
s
,
i
,
_
=
parameter_id
.
split
(
'_'
)
hyper_configs
=
self
.
brackets
[
int
(
s
)].
inform_trial_end
(
int
(
i
))
if
hyper_configs
is
not
None
:
logger
.
debug
(
'bracket %s next round %s, hyper_configs: %s'
,
s
,
i
,
hyper_configs
)
self
.
generated_hyper_configs
=
self
.
generated_hyper_configs
+
hyper_configs
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
# Finish this bracket and generate a new bracket
elif
self
.
brackets
[
int
(
s
)].
no_more_trial
:
self
.
curr_s
-=
1
self
.
generate_new_bracket
()
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
self
.
_send_new_trial
()
def
handle_report_metric_data
(
self
,
data
):
"""reveice the metric data and update Bayesian optimization with final result
...
...
@@ -535,6 +573,22 @@ class BOHB(MsgDispatcherBase):
"""
logger
.
debug
(
'handle report metric data = %s'
,
data
)
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
assert
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
self
.
_handle_trial_end
(
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]])
ret
=
self
.
_get_one_trial_job
()
if
ret
is
None
:
self
.
unsatisfied_jobs
.
append
({
'trial_job_id'
:
data
[
'trial_job_id'
],
'parameter_index'
:
data
[
'parameter_index'
]})
else
:
ret
[
'trial_job_id'
]
=
data
[
'trial_job_id'
]
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
# update parameter_id in self.job_id_para_id_map
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dumps
(
ret
))
else
:
assert
'value'
in
data
value
=
extract_scalar_reward
(
data
[
'value'
])
if
self
.
optimize_mode
is
OptimizeMode
.
Maximize
:
...
...
@@ -543,12 +597,18 @@ class BOHB(MsgDispatcherBase):
reward
=
value
assert
'parameter_id'
in
data
s
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
logger
.
debug
(
'bracket id = %s, metrics value = %s, type = %s'
,
s
,
value
,
data
[
'type'
])
s
=
int
(
s
)
# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# because when the first parameter_id is created, trial_job_id is not known yet.
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
assert
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
==
data
[
'parameter_id'
]
else
:
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
data
[
'parameter_id'
]
assert
'type'
in
data
if
data
[
'type'
]
==
'
FINAL
'
:
if
data
[
'type'
]
==
MetricType
.
FINAL
:
# and PERIODICAL metric are independent, thus, not comparable.
assert
'sequence'
in
data
self
.
brackets
[
s
].
set_config_perf
(
...
...
@@ -559,7 +619,7 @@ class BOHB(MsgDispatcherBase):
_parameters
.
pop
(
_KEY
)
# update BO with loss, max_s budget, hyperparameters
self
.
cg
.
new_result
(
loss
=
reward
,
budget
=
data
[
'sequence'
],
parameters
=
_parameters
,
update_model
=
True
)
elif
data
[
'type'
]
==
'
PERIODICAL
'
:
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
self
.
brackets
[
s
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
data
[
'sequence'
],
value
)
else
:
...
...
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
View file @
b7cd20e6
...
...
@@ -30,8 +30,8 @@ import json_tricks
from
nni.protocol
import
CommandType
,
send
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.common
import
init_logger
from
nni.utils
import
NodeType
,
OptimizeMode
,
extract_scalar_reward
,
randint_to_quniform
from
nni.common
import
init_logger
,
multi_phase_enabled
from
nni.utils
import
NodeType
,
OptimizeMode
,
MetricType
,
extract_scalar_reward
,
randint_to_quniform
import
nni.parameter_expressions
as
parameter_expressions
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -277,7 +277,7 @@ class Hyperband(MsgDispatcherBase):
optimize_mode: str
optimize mode, 'maximize' or 'minimize'
"""
def
__init__
(
self
,
R
,
eta
=
3
,
optimize_mode
=
'maximize'
):
def
__init__
(
self
,
R
=
60
,
eta
=
3
,
optimize_mode
=
'maximize'
):
"""B = (s_max + 1)R"""
super
(
Hyperband
,
self
).
__init__
()
self
.
R
=
R
# pylint: disable=invalid-name
...
...
@@ -296,11 +296,10 @@ class Hyperband(MsgDispatcherBase):
# In this case, tuner increases self.credit to issue a trial config sometime later.
self
.
credit
=
0
def
load_checkpoint
(
self
):
pass
def
save_checkpoint
(
self
):
pass
# record the latest parameter_id of the trial job trial_job_id.
# if there is no running parameter_id, self.job_id_para_id_map[trial_job_id] == None
# new trial job is added to this dict and finished trial job is removed from it.
self
.
job_id_para_id_map
=
dict
()
def
handle_initialize
(
self
,
data
):
"""data is search space
...
...
@@ -321,9 +320,10 @@ class Hyperband(MsgDispatcherBase):
number of trial jobs
"""
for
_
in
range
(
data
):
self
.
_request_one_trial_job
()
ret
=
self
.
_get_one_trial_job
()
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
def
_
reques
t_one_trial_job
(
self
):
def
_
ge
t_one_trial_job
(
self
):
"""get one trial job, i.e., one hyperparameter configuration."""
if
not
self
.
generated_hyper_configs
:
if
self
.
curr_s
<
0
:
...
...
@@ -346,7 +346,8 @@ class Hyperband(MsgDispatcherBase):
'parameter_source'
:
'algorithm'
,
'parameters'
:
params
[
1
]
}
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
return
ret
def
handle_update_search_space
(
self
,
data
):
"""data: JSON object, which is search space
...
...
@@ -360,6 +361,18 @@ class Hyperband(MsgDispatcherBase):
randint_to_quniform
(
self
.
searchspace_json
)
self
.
random_state
=
np
.
random
.
RandomState
()
def _handle_trial_end(self, parameter_id):
    """Tell the owning bracket that one of its configurations has finished.

    Parameters
    ----------
    parameter_id: str
        parameter id of the finished config; it is split on '_' into three
        fields, of which the first two are used as integer indices

    If the bracket hands back a list of configurations, they are appended to
    ``self.generated_hyper_configs``.
    """
    bracket_key, round_key, _ = parameter_id.split('_')
    bracket = self.brackets[int(bracket_key)]
    next_configs = bracket.inform_trial_end(int(round_key))
    if next_configs is None:
        # the bracket produced nothing new for this finished config
        return
    _logger.debug('bracket %s next round %s, hyper_configs: %s', bracket_key, round_key, next_configs)
    self.generated_hyper_configs = self.generated_hyper_configs + next_configs
def
handle_trial_end
(
self
,
data
):
"""
Parameters
...
...
@@ -371,22 +384,9 @@ class Hyperband(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
hyper_params
=
json_tricks
.
loads
(
data
[
'hyper_params'
])
bracket_id
,
i
,
_
=
hyper_params
[
'parameter_id'
].
split
(
'_'
)
hyper_configs
=
self
.
brackets
[
int
(
bracket_id
)].
inform_trial_end
(
int
(
i
))
if
hyper_configs
is
not
None
:
_logger
.
debug
(
'bracket %s next round %s, hyper_configs: %s'
,
bracket_id
,
i
,
hyper_configs
)
self
.
generated_hyper_configs
=
self
.
generated_hyper_configs
+
hyper_configs
for
_
in
range
(
self
.
credit
):
if
not
self
.
generated_hyper_configs
:
break
params
=
self
.
generated_hyper_configs
.
pop
()
ret
=
{
'parameter_id'
:
params
[
0
],
'parameter_source'
:
'algorithm'
,
'parameters'
:
params
[
1
]
}
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
self
.
credit
-=
1
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
def
handle_report_metric_data
(
self
,
data
):
"""
...
...
@@ -400,15 +400,37 @@ class Hyperband(MsgDispatcherBase):
ValueError
Data type not supported
"""
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
assert
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
self
.
_handle_trial_end
(
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]])
ret
=
self
.
_get_one_trial_job
()
if
data
[
'trial_job_id'
]
is
not
None
:
ret
[
'trial_job_id'
]
=
data
[
'trial_job_id'
]
if
data
[
'parameter_index'
]
is
not
None
:
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dumps
(
ret
))
else
:
value
=
extract_scalar_reward
(
data
[
'value'
])
bracket_id
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
bracket_id
=
int
(
bracket_id
)
if
data
[
'type'
]
==
'FINAL'
:
# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# because when the first parameter_id is created, trial_job_id is not known yet.
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
assert
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
==
data
[
'parameter_id'
]
else
:
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
data
[
'parameter_id'
]
if
data
[
'type'
]
==
MetricType
.
FINAL
:
# sys.maxsize indicates this value is from FINAL metric data, because data['sequence'] from FINAL metric
# and PERIODICAL metric are independent, thus, not comparable.
self
.
brackets
[
bracket_id
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
sys
.
maxsize
,
value
)
self
.
completed_hyper_configs
.
append
(
data
)
elif
data
[
'type'
]
==
'
PERIODICAL
'
:
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
self
.
brackets
[
bracket_id
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
data
[
'sequence'
],
value
)
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
b7cd20e6
...
...
@@ -27,6 +27,7 @@ from .msg_dispatcher_base import MsgDispatcherBase
from
.assessor
import
AssessResult
from
.common
import
multi_thread_enabled
,
multi_phase_enabled
from
.env_vars
import
dispatcher_env_vars
from
.utils
import
MetricType
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -133,12 +134,12 @@ class MsgDispatcher(MsgDispatcherBase):
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL', 'REQUEST_PARAMETER'}
"""
if
data
[
'type'
]
==
'
FINAL
'
:
if
data
[
'type'
]
==
MetricType
.
FINAL
:
self
.
_handle_final_metric_data
(
data
)
elif
data
[
'type'
]
==
'
PERIODICAL
'
:
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
if
self
.
assessor
is
not
None
:
self
.
_handle_intermediate_metric_data
(
data
)
elif
data
[
'type'
]
==
'
REQUEST_PARAMETER
'
:
elif
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
...
...
@@ -183,7 +184,7 @@ class MsgDispatcher(MsgDispatcherBase):
def
_handle_intermediate_metric_data
(
self
,
data
):
"""Call assessor to process intermediate results
"""
if
data
[
'type'
]
!=
'
PERIODICAL
'
:
if
data
[
'type'
]
!=
MetricType
.
PERIODICAL
:
return
if
self
.
assessor
is
None
:
return
...
...
@@ -224,7 +225,7 @@ class MsgDispatcher(MsgDispatcherBase):
trial is early stopped.
"""
_logger
.
debug
(
'Early stop notify tuner data: [%s]'
,
data
)
data
[
'type'
]
=
'
FINAL
'
data
[
'type'
]
=
MetricType
.
FINAL
if
multi_thread_enabled
():
self
.
_handle_final_metric_data
(
data
)
else
:
...
...
src/sdk/pynni/nni/utils.py
View file @
b7cd20e6
...
...
@@ -51,6 +51,14 @@ class NodeType:
NAME
=
'_name'
class MetricType:
    """String constants naming the kinds of metric data.

    Each attribute is a plain ``str`` equal to its own name, so it compares
    equal to the raw literal found in a message's ``'type'`` field.
    """
    FINAL, PERIODICAL, REQUEST_PARAMETER = (
        'FINAL',
        'PERIODICAL',
        'REQUEST_PARAMETER',
    )
def
split_index
(
params
):
"""
Delete index information from params
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment