Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b7cd20e6
Unverified
Commit
b7cd20e6
authored
Jul 12, 2019
by
QuanluZhang
Committed by
GitHub
Jul 12, 2019
Browse files
support multi-phase in hyperband (#1257)
* support multi-phase in hyperband and bohb
parent
917ce97f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
180 additions
and
91 deletions
+180
-91
src/sdk/pynni/nni/__main__.py
src/sdk/pynni/nni/__main__.py
+0
-2
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
+106
-46
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
+60
-38
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+6
-5
src/sdk/pynni/nni/utils.py
src/sdk/pynni/nni/utils.py
+8
-0
No files found.
src/sdk/pynni/nni/__main__.py
View file @
b7cd20e6
...
@@ -130,8 +130,6 @@ def main():
...
@@ -130,8 +130,6 @@ def main():
if
args
.
advisor_class_name
:
if
args
.
advisor_class_name
:
# advisor is enabled and starts to run
# advisor is enabled and starts to run
if
args
.
multi_phase
:
raise
AssertionError
(
'multi_phase has not been supported in advisor'
)
if
args
.
advisor_class_name
in
AdvisorModuleName
:
if
args
.
advisor_class_name
in
AdvisorModuleName
:
dispatcher
=
create_builtin_class_instance
(
dispatcher
=
create_builtin_class_instance
(
args
.
advisor_class_name
,
args
.
advisor_class_name
,
...
...
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
View file @
b7cd20e6
...
@@ -31,7 +31,8 @@ import ConfigSpace.hyperparameters as CSH
...
@@ -31,7 +31,8 @@ import ConfigSpace.hyperparameters as CSH
from
nni.protocol
import
CommandType
,
send
from
nni.protocol
import
CommandType
,
send
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.utils
import
OptimizeMode
,
extract_scalar_reward
,
randint_to_quniform
from
nni.utils
import
OptimizeMode
,
MetricType
,
extract_scalar_reward
,
randint_to_quniform
from
nni.common
import
multi_phase_enabled
from
.config_generator
import
CG_BOHB
from
.config_generator
import
CG_BOHB
...
@@ -79,7 +80,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
...
@@ -79,7 +80,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
return
params_id
return
params_id
class
Bracket
():
class
Bracket
(
object
):
"""
"""
A bracket in BOHB, all the information of a bracket is managed by
A bracket in BOHB, all the information of a bracket is managed by
an instance of this class.
an instance of this class.
...
@@ -329,11 +330,12 @@ class BOHB(MsgDispatcherBase):
...
@@ -329,11 +330,12 @@ class BOHB(MsgDispatcherBase):
# config generator
# config generator
self
.
cg
=
None
self
.
cg
=
None
def
load_checkpoint
(
self
):
# record the latest parameter_id of the trial job trial_job_id.
pass
# if there is no running parameter_id, self.job_id_para_id_map[trial_job_id] == None
# new trial job is added to this dict and finished trial job is removed from it.
def
save_checkpoint
(
self
):
self
.
job_id_para_id_map
=
dict
()
pass
# record the unsatisfied parameter request from trial jobs
self
.
unsatisfied_jobs
=
[]
def
handle_initialize
(
self
,
data
):
def
handle_initialize
(
self
,
data
):
"""Initialize Tuner, including creating Bayesian optimization-based parametric models
"""Initialize Tuner, including creating Bayesian optimization-based parametric models
...
@@ -399,7 +401,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -399,7 +401,7 @@ class BOHB(MsgDispatcherBase):
for
_
in
range
(
self
.
credit
):
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
self
.
_request_one_trial_job
()
def
_
reques
t_one_trial_job
(
self
):
def
_
ge
t_one_trial_job
(
self
):
"""get one trial job, i.e., one hyperparameter configuration.
"""get one trial job, i.e., one hyperparameter configuration.
If this function is called, Command will be sent by BOHB:
If this function is called, Command will be sent by BOHB:
...
@@ -423,7 +425,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -423,7 +425,7 @@ class BOHB(MsgDispatcherBase):
'parameters'
:
''
'parameters'
:
''
}
}
send
(
CommandType
.
NoMoreTrialJobs
,
json_tricks
.
dumps
(
ret
))
send
(
CommandType
.
NoMoreTrialJobs
,
json_tricks
.
dumps
(
ret
))
return
return
None
assert
self
.
generated_hyper_configs
assert
self
.
generated_hyper_configs
params
=
self
.
generated_hyper_configs
.
pop
()
params
=
self
.
generated_hyper_configs
.
pop
()
ret
=
{
ret
=
{
...
@@ -432,8 +434,29 @@ class BOHB(MsgDispatcherBase):
...
@@ -432,8 +434,29 @@ class BOHB(MsgDispatcherBase):
'parameters'
:
params
[
1
]
'parameters'
:
params
[
1
]
}
}
self
.
parameters
[
params
[
0
]]
=
params
[
1
]
self
.
parameters
[
params
[
0
]]
=
params
[
1
]
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
return
ret
self
.
credit
-=
1
def
_request_one_trial_job
(
self
):
"""get one trial job, i.e., one hyperparameter configuration.
If this function is called, Command will be sent by BOHB:
a. If there is a parameter need to run, will return "NewTrialJob" with a dict:
{
'parameter_id': id of new hyperparameter
'parameter_source': 'algorithm'
'parameters': value of new hyperparameter
}
b. If BOHB don't have parameter waiting, will return "NoMoreTrialJobs" with
{
'parameter_id': '-1_0_0',
'parameter_source': 'algorithm',
'parameters': ''
}
"""
ret
=
self
.
_get_one_trial_job
()
if
ret
is
not
None
:
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
self
.
credit
-=
1
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
"""change json format to ConfigSpace format dict<dict> -> configspace
"""change json format to ConfigSpace format dict<dict> -> configspace
...
@@ -502,23 +525,38 @@ class BOHB(MsgDispatcherBase):
...
@@ -502,23 +525,38 @@ class BOHB(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
"""
logger
.
debug
(
'Tuner handle trial end, result is %s'
,
data
)
logger
.
debug
(
'Tuner handle trial end, result is %s'
,
data
)
hyper_params
=
json_tricks
.
loads
(
data
[
'hyper_params'
])
hyper_params
=
json_tricks
.
loads
(
data
[
'hyper_params'
])
s
,
i
,
_
=
hyper_params
[
'parameter_id'
].
split
(
'_'
)
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
def
_send_new_trial
(
self
):
while
self
.
unsatisfied_jobs
:
ret
=
self
.
_get_one_trial_job
()
if
ret
is
None
:
break
one_unsatisfied
=
self
.
unsatisfied_jobs
.
pop
(
0
)
ret
[
'trial_job_id'
]
=
one_unsatisfied
[
'trial_job_id'
]
ret
[
'parameter_index'
]
=
one_unsatisfied
[
'parameter_index'
]
# update parameter_id in self.job_id_para_id_map
self
.
job_id_para_id_map
[
ret
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dumps
(
ret
))
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
def
_handle_trial_end
(
self
,
parameter_id
):
s
,
i
,
_
=
parameter_id
.
split
(
'_'
)
hyper_configs
=
self
.
brackets
[
int
(
s
)].
inform_trial_end
(
int
(
i
))
hyper_configs
=
self
.
brackets
[
int
(
s
)].
inform_trial_end
(
int
(
i
))
if
hyper_configs
is
not
None
:
if
hyper_configs
is
not
None
:
logger
.
debug
(
logger
.
debug
(
'bracket %s next round %s, hyper_configs: %s'
,
s
,
i
,
hyper_configs
)
'bracket %s next round %s, hyper_configs: %s'
,
s
,
i
,
hyper_configs
)
self
.
generated_hyper_configs
=
self
.
generated_hyper_configs
+
hyper_configs
self
.
generated_hyper_configs
=
self
.
generated_hyper_configs
+
hyper_configs
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
# Finish this bracket and generate a new bracket
# Finish this bracket and generate a new bracket
elif
self
.
brackets
[
int
(
s
)].
no_more_trial
:
elif
self
.
brackets
[
int
(
s
)].
no_more_trial
:
self
.
curr_s
-=
1
self
.
curr_s
-=
1
self
.
generate_new_bracket
()
self
.
generate_new_bracket
()
for
_
in
range
(
self
.
credit
):
self
.
_send_new_trial
()
self
.
_request_one_trial_job
()
def
handle_report_metric_data
(
self
,
data
):
def
handle_report_metric_data
(
self
,
data
):
"""reveice the metric data and update Bayesian optimization with final result
"""reveice the metric data and update Bayesian optimization with final result
...
@@ -535,36 +573,58 @@ class BOHB(MsgDispatcherBase):
...
@@ -535,36 +573,58 @@ class BOHB(MsgDispatcherBase):
"""
"""
logger
.
debug
(
'handle report metric data = %s'
,
data
)
logger
.
debug
(
'handle report metric data = %s'
,
data
)
assert
'value'
in
data
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
value
=
extract_scalar_reward
(
data
[
'value'
])
assert
multi_phase_enabled
()
if
self
.
optimize_mode
is
OptimizeMode
.
Maximize
:
assert
data
[
'trial_job_id'
]
is
not
None
reward
=
-
value
assert
data
[
'parameter_index'
]
is
not
None
else
:
assert
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
reward
=
value
self
.
_handle_trial_end
(
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]])
assert
'parameter_id'
in
data
ret
=
self
.
_get_one_trial_job
()
s
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
if
ret
is
None
:
self
.
unsatisfied_jobs
.
append
({
'trial_job_id'
:
data
[
'trial_job_id'
],
'parameter_index'
:
data
[
'parameter_index'
]})
logger
.
debug
(
'bracket id = %s, metrics value = %s, type = %s'
,
s
,
value
,
data
[
'type'
])
else
:
s
=
int
(
s
)
ret
[
'trial_job_id'
]
=
data
[
'trial_job_id'
]
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
assert
'type'
in
data
# update parameter_id in self.job_id_para_id_map
if
data
[
'type'
]
==
'FINAL'
:
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
# and PERIODICAL metric are independent, thus, not comparable.
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dumps
(
ret
))
assert
'sequence'
in
data
self
.
brackets
[
s
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
sys
.
maxsize
,
value
)
self
.
completed_hyper_configs
.
append
(
data
)
_parameters
=
self
.
parameters
[
data
[
'parameter_id'
]]
_parameters
.
pop
(
_KEY
)
# update BO with loss, max_s budget, hyperparameters
self
.
cg
.
new_result
(
loss
=
reward
,
budget
=
data
[
'sequence'
],
parameters
=
_parameters
,
update_model
=
True
)
elif
data
[
'type'
]
==
'PERIODICAL'
:
self
.
brackets
[
s
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
data
[
'sequence'
],
value
)
else
:
else
:
raise
ValueError
(
assert
'value'
in
data
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
value
=
extract_scalar_reward
(
data
[
'value'
])
if
self
.
optimize_mode
is
OptimizeMode
.
Maximize
:
reward
=
-
value
else
:
reward
=
value
assert
'parameter_id'
in
data
s
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
logger
.
debug
(
'bracket id = %s, metrics value = %s, type = %s'
,
s
,
value
,
data
[
'type'
])
s
=
int
(
s
)
# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# because when the first parameter_id is created, trial_job_id is not known yet.
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
assert
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
==
data
[
'parameter_id'
]
else
:
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
data
[
'parameter_id'
]
assert
'type'
in
data
if
data
[
'type'
]
==
MetricType
.
FINAL
:
# and PERIODICAL metric are independent, thus, not comparable.
assert
'sequence'
in
data
self
.
brackets
[
s
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
sys
.
maxsize
,
value
)
self
.
completed_hyper_configs
.
append
(
data
)
_parameters
=
self
.
parameters
[
data
[
'parameter_id'
]]
_parameters
.
pop
(
_KEY
)
# update BO with loss, max_s budget, hyperparameters
self
.
cg
.
new_result
(
loss
=
reward
,
budget
=
data
[
'sequence'
],
parameters
=
_parameters
,
update_model
=
True
)
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
self
.
brackets
[
s
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
data
[
'sequence'
],
value
)
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
pass
pass
...
...
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
View file @
b7cd20e6
...
@@ -30,8 +30,8 @@ import json_tricks
...
@@ -30,8 +30,8 @@ import json_tricks
from
nni.protocol
import
CommandType
,
send
from
nni.protocol
import
CommandType
,
send
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.common
import
init_logger
from
nni.common
import
init_logger
,
multi_phase_enabled
from
nni.utils
import
NodeType
,
OptimizeMode
,
extract_scalar_reward
,
randint_to_quniform
from
nni.utils
import
NodeType
,
OptimizeMode
,
MetricType
,
extract_scalar_reward
,
randint_to_quniform
import
nni.parameter_expressions
as
parameter_expressions
import
nni.parameter_expressions
as
parameter_expressions
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -277,7 +277,7 @@ class Hyperband(MsgDispatcherBase):
...
@@ -277,7 +277,7 @@ class Hyperband(MsgDispatcherBase):
optimize_mode: str
optimize_mode: str
optimize mode, 'maximize' or 'minimize'
optimize mode, 'maximize' or 'minimize'
"""
"""
def
__init__
(
self
,
R
,
eta
=
3
,
optimize_mode
=
'maximize'
):
def
__init__
(
self
,
R
=
60
,
eta
=
3
,
optimize_mode
=
'maximize'
):
"""B = (s_max + 1)R"""
"""B = (s_max + 1)R"""
super
(
Hyperband
,
self
).
__init__
()
super
(
Hyperband
,
self
).
__init__
()
self
.
R
=
R
# pylint: disable=invalid-name
self
.
R
=
R
# pylint: disable=invalid-name
...
@@ -296,11 +296,10 @@ class Hyperband(MsgDispatcherBase):
...
@@ -296,11 +296,10 @@ class Hyperband(MsgDispatcherBase):
# In this case, tuner increases self.credit to issue a trial config sometime later.
# In this case, tuner increases self.credit to issue a trial config sometime later.
self
.
credit
=
0
self
.
credit
=
0
def
load_checkpoint
(
self
):
# record the latest parameter_id of the trial job trial_job_id.
pass
# if there is no running parameter_id, self.job_id_para_id_map[trial_job_id] == None
# new trial job is added to this dict and finished trial job is removed from it.
def
save_checkpoint
(
self
):
self
.
job_id_para_id_map
=
dict
()
pass
def
handle_initialize
(
self
,
data
):
def
handle_initialize
(
self
,
data
):
"""data is search space
"""data is search space
...
@@ -321,9 +320,10 @@ class Hyperband(MsgDispatcherBase):
...
@@ -321,9 +320,10 @@ class Hyperband(MsgDispatcherBase):
number of trial jobs
number of trial jobs
"""
"""
for
_
in
range
(
data
):
for
_
in
range
(
data
):
self
.
_request_one_trial_job
()
ret
=
self
.
_get_one_trial_job
()
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
def
_
reques
t_one_trial_job
(
self
):
def
_
ge
t_one_trial_job
(
self
):
"""get one trial job, i.e., one hyperparameter configuration."""
"""get one trial job, i.e., one hyperparameter configuration."""
if
not
self
.
generated_hyper_configs
:
if
not
self
.
generated_hyper_configs
:
if
self
.
curr_s
<
0
:
if
self
.
curr_s
<
0
:
...
@@ -346,7 +346,8 @@ class Hyperband(MsgDispatcherBase):
...
@@ -346,7 +346,8 @@ class Hyperband(MsgDispatcherBase):
'parameter_source'
:
'algorithm'
,
'parameter_source'
:
'algorithm'
,
'parameters'
:
params
[
1
]
'parameters'
:
params
[
1
]
}
}
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
return
ret
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
"""data: JSON object, which is search space
"""data: JSON object, which is search space
...
@@ -360,6 +361,18 @@ class Hyperband(MsgDispatcherBase):
...
@@ -360,6 +361,18 @@ class Hyperband(MsgDispatcherBase):
randint_to_quniform
(
self
.
searchspace_json
)
randint_to_quniform
(
self
.
searchspace_json
)
self
.
random_state
=
np
.
random
.
RandomState
()
self
.
random_state
=
np
.
random
.
RandomState
()
def
_handle_trial_end
(
self
,
parameter_id
):
"""
Parameters
----------
parameter_id: parameter id of the finished config
"""
bracket_id
,
i
,
_
=
parameter_id
.
split
(
'_'
)
hyper_configs
=
self
.
brackets
[
int
(
bracket_id
)].
inform_trial_end
(
int
(
i
))
if
hyper_configs
is
not
None
:
_logger
.
debug
(
'bracket %s next round %s, hyper_configs: %s'
,
bracket_id
,
i
,
hyper_configs
)
self
.
generated_hyper_configs
=
self
.
generated_hyper_configs
+
hyper_configs
def
handle_trial_end
(
self
,
data
):
def
handle_trial_end
(
self
,
data
):
"""
"""
Parameters
Parameters
...
@@ -371,22 +384,9 @@ class Hyperband(MsgDispatcherBase):
...
@@ -371,22 +384,9 @@ class Hyperband(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
"""
hyper_params
=
json_tricks
.
loads
(
data
[
'hyper_params'
])
hyper_params
=
json_tricks
.
loads
(
data
[
'hyper_params'
])
bracket_id
,
i
,
_
=
hyper_params
[
'parameter_id'
].
split
(
'_'
)
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
hyper_configs
=
self
.
brackets
[
int
(
bracket_id
)].
inform_trial_end
(
int
(
i
))
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
if
hyper_configs
is
not
None
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
_logger
.
debug
(
'bracket %s next round %s, hyper_configs: %s'
,
bracket_id
,
i
,
hyper_configs
)
self
.
generated_hyper_configs
=
self
.
generated_hyper_configs
+
hyper_configs
for
_
in
range
(
self
.
credit
):
if
not
self
.
generated_hyper_configs
:
break
params
=
self
.
generated_hyper_configs
.
pop
()
ret
=
{
'parameter_id'
:
params
[
0
],
'parameter_source'
:
'algorithm'
,
'parameters'
:
params
[
1
]
}
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
self
.
credit
-=
1
def
handle_report_metric_data
(
self
,
data
):
def
handle_report_metric_data
(
self
,
data
):
"""
"""
...
@@ -400,18 +400,40 @@ class Hyperband(MsgDispatcherBase):
...
@@ -400,18 +400,40 @@ class Hyperband(MsgDispatcherBase):
ValueError
ValueError
Data type not supported
Data type not supported
"""
"""
value
=
extract_scalar_reward
(
data
[
'value'
])
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
bracket_id
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
assert
multi_phase_enabled
()
bracket_id
=
int
(
bracket_id
)
assert
data
[
'trial_job_id'
]
is
not
None
if
data
[
'type'
]
==
'FINAL'
:
assert
data
[
'parameter_index'
]
is
not
None
# sys.maxsize indicates this value is from FINAL metric data, because data['sequence'] from FINAL metric
assert
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
# and PERIODICAL metric are independent, thus, not comparable.
self
.
_handle_trial_end
(
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]])
self
.
brackets
[
bracket_id
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
sys
.
maxsize
,
value
)
ret
=
self
.
_get_one_trial_job
()
self
.
completed_hyper_configs
.
append
(
data
)
if
data
[
'trial_job_id'
]
is
not
None
:
elif
data
[
'type'
]
==
'PERIODICAL'
:
ret
[
'trial_job_id'
]
=
data
[
'trial_job_id'
]
self
.
brackets
[
bracket_id
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
data
[
'sequence'
],
value
)
if
data
[
'parameter_index'
]
is
not
None
:
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
json_tricks
.
dumps
(
ret
))
else
:
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
value
=
extract_scalar_reward
(
data
[
'value'
])
bracket_id
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
bracket_id
=
int
(
bracket_id
)
# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# because when the first parameter_id is created, trial_job_id is not known yet.
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
assert
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
==
data
[
'parameter_id'
]
else
:
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
data
[
'parameter_id'
]
if
data
[
'type'
]
==
MetricType
.
FINAL
:
# sys.maxsize indicates this value is from FINAL metric data, because data['sequence'] from FINAL metric
# and PERIODICAL metric are independent, thus, not comparable.
self
.
brackets
[
bracket_id
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
sys
.
maxsize
,
value
)
self
.
completed_hyper_configs
.
append
(
data
)
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
self
.
brackets
[
bracket_id
].
set_config_perf
(
int
(
i
),
data
[
'parameter_id'
],
data
[
'sequence'
],
value
)
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
pass
pass
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
b7cd20e6
...
@@ -27,6 +27,7 @@ from .msg_dispatcher_base import MsgDispatcherBase
...
@@ -27,6 +27,7 @@ from .msg_dispatcher_base import MsgDispatcherBase
from
.assessor
import
AssessResult
from
.assessor
import
AssessResult
from
.common
import
multi_thread_enabled
,
multi_phase_enabled
from
.common
import
multi_thread_enabled
,
multi_phase_enabled
from
.env_vars
import
dispatcher_env_vars
from
.env_vars
import
dispatcher_env_vars
from
.utils
import
MetricType
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -133,12 +134,12 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -133,12 +134,12 @@ class MsgDispatcher(MsgDispatcherBase):
- 'value': metric value reported by nni.report_final_result()
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
"""
if
data
[
'type'
]
==
'
FINAL
'
:
if
data
[
'type'
]
==
MetricType
.
FINAL
:
self
.
_handle_final_metric_data
(
data
)
self
.
_handle_final_metric_data
(
data
)
elif
data
[
'type'
]
==
'
PERIODICAL
'
:
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
if
self
.
assessor
is
not
None
:
if
self
.
assessor
is
not
None
:
self
.
_handle_intermediate_metric_data
(
data
)
self
.
_handle_intermediate_metric_data
(
data
)
elif
data
[
'type'
]
==
'
REQUEST_PARAMETER
'
:
elif
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
assert
multi_phase_enabled
()
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'trial_job_id'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
assert
data
[
'parameter_index'
]
is
not
None
...
@@ -183,7 +184,7 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -183,7 +184,7 @@ class MsgDispatcher(MsgDispatcherBase):
def
_handle_intermediate_metric_data
(
self
,
data
):
def
_handle_intermediate_metric_data
(
self
,
data
):
"""Call assessor to process intermediate results
"""Call assessor to process intermediate results
"""
"""
if
data
[
'type'
]
!=
'
PERIODICAL
'
:
if
data
[
'type'
]
!=
MetricType
.
PERIODICAL
:
return
return
if
self
.
assessor
is
None
:
if
self
.
assessor
is
None
:
return
return
...
@@ -224,7 +225,7 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -224,7 +225,7 @@ class MsgDispatcher(MsgDispatcherBase):
trial is early stopped.
trial is early stopped.
"""
"""
_logger
.
debug
(
'Early stop notify tuner data: [%s]'
,
data
)
_logger
.
debug
(
'Early stop notify tuner data: [%s]'
,
data
)
data
[
'type'
]
=
'
FINAL
'
data
[
'type'
]
=
MetricType
.
FINAL
if
multi_thread_enabled
():
if
multi_thread_enabled
():
self
.
_handle_final_metric_data
(
data
)
self
.
_handle_final_metric_data
(
data
)
else
:
else
:
...
...
src/sdk/pynni/nni/utils.py
View file @
b7cd20e6
...
@@ -51,6 +51,14 @@ class NodeType:
...
@@ -51,6 +51,14 @@ class NodeType:
NAME
=
'_name'
NAME
=
'_name'
class
MetricType
:
"""The types of metric data
"""
FINAL
=
'FINAL'
PERIODICAL
=
'PERIODICAL'
REQUEST_PARAMETER
=
'REQUEST_PARAMETER'
def
split_index
(
params
):
def
split_index
(
params
):
"""
"""
Delete index infromation from params
Delete index infromation from params
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment