Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bcc640c4
Unverified
Commit
bcc640c4
authored
Oct 12, 2022
by
QuanluZhang
Committed by
GitHub
Oct 12, 2022
Browse files
[nas] fix issue introduced by the trial recovery feature (#5109)
parent
87677df8
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
105 additions
and
41 deletions
+105
-41
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
+15
-2
nni/algorithms/hpo/hyperband_advisor.py
nni/algorithms/hpo/hyperband_advisor.py
+10
-1
nni/algorithms/hpo/tpe_tuner.py
nni/algorithms/hpo/tpe_tuner.py
+0
-13
nni/nas/execution/common/integration.py
nni/nas/execution/common/integration.py
+23
-3
nni/nas/execution/common/integration_api.py
nni/nas/execution/common/integration_api.py
+1
-0
nni/recoverable.py
nni/recoverable.py
+30
-0
nni/runtime/msg_dispatcher.py
nni/runtime/msg_dispatcher.py
+15
-9
nni/runtime/msg_dispatcher_base.py
nni/runtime/msg_dispatcher_base.py
+1
-0
nni/tuner.py
nni/tuner.py
+0
-8
test/algo/nas/test_cgo_engine.py
test/algo/nas/test_cgo_engine.py
+3
-0
test/ut/nas/test_engine.py
test/ut/nas/test_engine.py
+2
-0
test/ut/sdk/test_assessor.py
test/ut/sdk/test_assessor.py
+5
-5
No files found.
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
View file @
bcc640c4
...
...
@@ -648,8 +648,11 @@ class BOHB(MsgDispatcherBase):
event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
logger
.
debug
(
'Tuner handle trial end, result is %s'
,
data
)
hyper_params
=
nni
.
load
(
data
[
'hyper_params'
])
if
self
.
is_created_in_previous_exp
(
hyper_params
[
'parameter_id'
]):
# The end of the recovered trial is ignored
return
logger
.
debug
(
'Tuner handle trial end, result is %s'
,
data
)
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
...
...
@@ -695,6 +698,13 @@ class BOHB(MsgDispatcherBase):
ValueError
Data type not supported
"""
if
self
.
is_created_in_previous_exp
(
data
[
'parameter_id'
]):
if
data
[
'type'
]
==
MetricType
.
FINAL
:
# only deal with final metric using import data
param
=
self
.
get_previous_param
(
data
[
'parameter_id'
])
trial_data
=
[{
'parameter'
:
param
,
'value'
:
nni
.
load
(
data
[
'value'
])}]
self
.
handle_import_data
(
trial_data
)
return
logger
.
debug
(
'handle report metric data = %s'
,
data
)
if
'value'
in
data
:
data
[
'value'
]
=
nni
.
load
(
data
[
'value'
])
...
...
@@ -752,7 +762,10 @@ class BOHB(MsgDispatcherBase):
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
def
handle_add_customized_trial
(
self
,
data
):
pass
global
_next_parameter_id
# data: parameters
previous_max_param_id
=
self
.
recover_parameter_id
(
data
)
_next_parameter_id
=
previous_max_param_id
+
1
def
handle_import_data
(
self
,
data
):
"""Import additional data for tuning
...
...
nni/algorithms/hpo/hyperband_advisor.py
View file @
bcc640c4
...
...
@@ -522,6 +522,9 @@ class Hyperband(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
hyper_params
=
nni
.
load
(
data
[
'hyper_params'
])
if
self
.
is_created_in_previous_exp
(
hyper_params
[
'parameter_id'
]):
# The end of the recovered trial is ignored
return
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
...
...
@@ -538,6 +541,9 @@ class Hyperband(MsgDispatcherBase):
ValueError
Data type not supported
"""
if
self
.
is_created_in_previous_exp
(
data
[
'parameter_id'
]):
# do not support recovering the algorithm state
return
if
'value'
in
data
:
data
[
'value'
]
=
nni
.
load
(
data
[
'value'
])
# multiphase? need to check
...
...
@@ -576,7 +582,10 @@ class Hyperband(MsgDispatcherBase):
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
def
handle_add_customized_trial
(
self
,
data
):
pass
global
_next_parameter_id
# data: parameters
previous_max_param_id
=
self
.
recover_parameter_id
(
data
)
_next_parameter_id
=
previous_max_param_id
+
1
def
handle_import_data
(
self
,
data
):
pass
nni/algorithms/hpo/tpe_tuner.py
View file @
bcc640c4
...
...
@@ -218,19 +218,6 @@ class TpeTuner(Tuner):
self
.
dedup
.
add_history
(
param
)
_logger
.
info
(
f
'Replayed
{
len
(
data
)
}
FINISHED trials'
)
def
import_customized_data
(
self
,
data
):
# for dedup customized / resumed
if
isinstance
(
data
,
str
):
data
=
nni
.
load
(
data
)
for
trial
in
data
:
# {'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128, ...}
if
isinstance
(
trial
,
str
):
trial
=
nni
.
load
(
trial
)
param
=
format_parameters
(
trial
[
'parameters'
],
self
.
space
)
self
.
_running_params
[
trial
[
'parameter_id'
]]
=
param
self
.
dedup
.
add_history
(
param
)
_logger
.
info
(
f
'Replayed
{
len
(
data
)
}
RUNING/WAITING trials'
)
def
suggest
(
args
,
rng
,
space
,
history
):
params
=
{}
for
key
,
spec
in
space
.
items
():
...
...
nni/nas/execution/common/integration.py
View file @
bcc640c4
...
...
@@ -4,6 +4,7 @@
__all__
=
[
'RetiariiAdvisor'
]
import
logging
import
time
import
os
from
typing
import
Any
,
Callable
,
Optional
,
Dict
,
List
,
Tuple
...
...
@@ -60,11 +61,12 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
final_metric_callback
:
Optional
[
Callable
[[
int
,
MetricData
],
None
]]
=
None
self
.
parameters_count
=
0
# Sometimes messages arrive first before the callbacks get registered.
# Or in case that we allow engine to be absent during the experiment.
# Here we need to store the messages and invoke them later.
self
.
call_queue
:
List
[
Tuple
[
str
,
list
]]
=
[]
# this is for waiting the to-be-recovered trials from nnimanager
self
.
_advisor_initialized
=
False
def
register_callbacks
(
self
,
callbacks
:
Dict
[
str
,
Callable
[...,
None
]]):
"""
...
...
@@ -167,6 +169,10 @@ class RetiariiAdvisor(MsgDispatcherBase):
Parameter ID that is assigned to this parameter,
which will be used for identification in future.
"""
while
not
self
.
_advisor_initialized
:
_logger
.
info
(
'Wait for RetiariiAdvisor to be initialized...'
)
time
.
sleep
(
0.5
)
self
.
parameters_count
+=
1
if
placement_constraint
is
None
:
placement_constraint
=
{
...
...
@@ -204,6 +210,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
def
handle_request_trial_jobs
(
self
,
num_trials
):
self
.
_advisor_initialized
=
True
_logger
.
debug
(
'Request trial jobs: %s'
,
num_trials
)
self
.
invoke_callback
(
'request_trial_jobs'
,
num_trials
)
...
...
@@ -212,10 +219,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
search_space
=
data
def
handle_trial_end
(
self
,
data
):
# TODO: we should properly handle the trials in self._customized_parameter_ids instead of ignoring
id_
=
nni
.
load
(
data
[
'hyper_params'
])[
'parameter_id'
]
if
self
.
is_created_in_previous_exp
(
id_
):
_logger
.
info
(
'The end of the recovered trial %d is ignored'
,
id_
)
return
_logger
.
debug
(
'Trial end: %s'
,
data
)
self
.
invoke_callback
(
'trial_end'
,
nni
.
load
(
data
[
'hyper_params'
])[
'parameter_id'
]
,
data
[
'event'
]
==
'SUCCEEDED'
)
self
.
invoke_callback
(
'trial_end'
,
id_
,
data
[
'event'
]
==
'SUCCEEDED'
)
def
handle_report_metric_data
(
self
,
data
):
# TODO: we should properly handle the trials in self._customized_parameter_ids instead of ignoring
if
self
.
is_created_in_previous_exp
(
data
[
'parameter_id'
]):
_logger
.
info
(
'The metrics of the recovered trial %d are ignored'
,
data
[
'parameter_id'
])
return
# NOTE: this part is not aligned with hpo tuners.
# in hpo tuners, trial_job_id is used for intermediate results handling
# parameter_id is for final result handling.
_logger
.
debug
(
'Metric reported: %s'
,
data
)
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
raise
ValueError
(
'Request parameter not supported'
)
...
...
@@ -239,4 +258,5 @@ class RetiariiAdvisor(MsgDispatcherBase):
pass
def
handle_add_customized_trial
(
self
,
data
):
pass
previous_max_param_id
=
self
.
recover_parameter_id
(
data
)
self
.
parameters_count
=
previous_max_param_id
nni/nas/execution/common/integration_api.py
View file @
bcc640c4
...
...
@@ -12,6 +12,7 @@ from typing import NewType, Any
import
nni
from
nni.common.version
import
version_check
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor
=
NewType
(
'RetiariiAdvisor'
,
Any
)
...
...
nni/recoverable.py
View file @
bcc640c4
...
...
@@ -4,8 +4,12 @@
from
__future__
import
annotations
import
os
import
nni
class
Recoverable
:
def
__init__
(
self
):
self
.
recovered_max_param_id
=
-
1
self
.
recovered_trial_params
=
{}
def
load_checkpoint
(
self
)
->
None
:
pass
...
...
@@ -18,3 +22,29 @@ class Recoverable:
if
ckp_path
is
not
None
and
os
.
path
.
isdir
(
ckp_path
):
return
ckp_path
return
None
def
recover_parameter_id
(
self
,
data
)
->
int
:
# this is for handling the resuming of the interrupted data: parameters
if
not
isinstance
(
data
,
list
):
data
=
[
data
]
previous_max_param_id
=
0
for
trial
in
data
:
# {'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128, ...}
if
isinstance
(
trial
,
str
):
trial
=
nni
.
load
(
trial
)
if
not
isinstance
(
trial
[
'parameter_id'
],
int
):
# for dealing with user customized trials
# skip for now
continue
self
.
recovered_trial_params
[
trial
[
'parameter_id'
]]
=
trial
[
'parameters'
]
if
previous_max_param_id
<
trial
[
'parameter_id'
]:
previous_max_param_id
=
trial
[
'parameter_id'
]
self
.
recovered_max_param_id
=
previous_max_param_id
return
previous_max_param_id
def
is_created_in_previous_exp
(
self
,
param_id
:
int
)
->
bool
:
return
param_id
<=
self
.
recovered_max_param_id
def
get_previous_param
(
self
,
param_id
:
int
)
->
dict
:
return
self
.
recovered_trial_params
[
param_id
]
\ No newline at end of file
nni/runtime/msg_dispatcher.py
View file @
bcc640c4
...
...
@@ -120,15 +120,10 @@ class MsgDispatcher(MsgDispatcherBase):
self
.
tuner
.
import_data
(
data
)
def
handle_add_customized_trial
(
self
,
data
):
global
_next_parameter_id
# data: parameters
if
not
isinstance
(
data
,
list
):
data
=
[
data
]
for
_
in
data
:
id_
=
_create_parameter_id
()
_customized_parameter_ids
.
add
(
id_
)
self
.
tuner
.
import_customized_data
(
data
)
previous_max_param_id
=
self
.
recover_parameter_id
(
data
)
_next_parameter_id
=
previous_max_param_id
+
1
def
handle_report_metric_data
(
self
,
data
):
"""
...
...
@@ -137,6 +132,13 @@ class MsgDispatcher(MsgDispatcherBase):
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
if
self
.
is_created_in_previous_exp
(
data
[
'parameter_id'
]):
if
data
[
'type'
]
==
MetricType
.
FINAL
:
# only deal with final metric using import data
param
=
self
.
get_previous_param
(
data
[
'parameter_id'
])
trial_data
=
[{
'parameter'
:
param
,
'value'
:
load
(
data
[
'value'
])}]
self
.
handle_import_data
(
trial_data
)
return
# metrics value is dumped as json string in trial, so we need to decode it here
if
'value'
in
data
:
data
[
'value'
]
=
load
(
data
[
'value'
])
...
...
@@ -166,6 +168,10 @@ class MsgDispatcher(MsgDispatcherBase):
- event: the job's state
- hyper_params: the hyperparameters generated and returned by tuner
"""
id_
=
load
(
data
[
'hyper_params'
])[
'parameter_id'
]
if
self
.
is_created_in_previous_exp
(
id_
):
# The end of the recovered trial is ignored
return
trial_job_id
=
data
[
'trial_job_id'
]
_ended_trials
.
add
(
trial_job_id
)
if
trial_job_id
in
_trial_history
:
...
...
@@ -173,7 +179,7 @@ class MsgDispatcher(MsgDispatcherBase):
if
self
.
assessor
is
not
None
:
self
.
assessor
.
trial_end
(
trial_job_id
,
data
[
'event'
]
==
'SUCCEEDED'
)
if
self
.
tuner
is
not
None
:
self
.
tuner
.
trial_end
(
load
(
data
[
'hyper_params'
])[
'parameter_id'
]
,
data
[
'event'
]
==
'SUCCEEDED'
)
self
.
tuner
.
trial_end
(
id_
,
data
[
'event'
]
==
'SUCCEEDED'
)
def
_handle_final_metric_data
(
self
,
data
):
"""Call tuner to process final results
...
...
nni/runtime/msg_dispatcher_base.py
View file @
bcc640c4
...
...
@@ -30,6 +30,7 @@ class MsgDispatcherBase(Recoverable):
"""
def
__init__
(
self
,
command_channel_url
=
None
):
super
().
__init__
()
self
.
stopping
=
False
if
command_channel_url
is
None
:
command_channel_url
=
dispatcher_env_vars
.
NNI_TUNER_COMMAND_CHANNEL
...
...
nni/tuner.py
View file @
bcc640c4
...
...
@@ -219,14 +219,6 @@ class Tuner(Recoverable):
# data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
pass
def
import_customized_data
(
self
,
data
:
list
[
TrialRecord
])
->
None
:
"""
Internal API under revising, not recommended for end users.
"""
# Import resume data for avoiding duplications
# data: a list of dictionarys, each of which has at least two keys, 'parameter_id' and 'parameters'
pass
def
_on_exit
(
self
)
->
None
:
pass
...
...
test/algo/nas/test_cgo_engine.py
View file @
bcc640c4
...
...
@@ -319,6 +319,9 @@ class CGOEngineTest(unittest.TestCase):
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
# this is because RetiariiAdvisor only works after `_advisor_initialized` becomes True.
# normally it becomes true when `handle_request_trial_jobs` is invoked
advisor
.
_advisor_initialized
=
True
remote
=
RemoteConfig
(
machine_list
=
[])
remote
.
machine_list
.
append
(
RemoteMachineConfig
(
host
=
'test'
,
gpu_indices
=
[
0
,
1
,
2
,
3
]))
...
...
test/ut/nas/test_engine.py
View file @
bcc640c4
...
...
@@ -27,6 +27,7 @@ class EngineTest(unittest.TestCase):
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
advisor
=
RetiariiAdvisor
(
'ws://_unittest_placeholder_'
)
advisor
.
_advisor_initialized
=
True
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
...
...
@@ -44,6 +45,7 @@ class EngineTest(unittest.TestCase):
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
advisor
=
RetiariiAdvisor
(
'ws://_unittest_placeholder_'
)
advisor
.
_advisor_initialized
=
True
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
...
...
test/ut/sdk/test_assessor.py
View file @
bcc640c4
...
...
@@ -48,11 +48,11 @@ class AssessorTestCase(TestCase):
def
test_assessor
(
self
):
pass
_reverse_io
()
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"A","event":"SYS_CANCELED"}'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"B","event":"SUCCEEDED"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"
parameter_id": 0,"
trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"
parameter_id": 1,"
trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"
parameter_id": 0,"
trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"A","event":"SYS_CANCELED"
,"hyper_params":"{
\\
"parameter_id
\\
": 0}"
}'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"B","event":"SUCCEEDED"
,"hyper_params":"{
\\
"parameter_id
\\
": 1}"
}'
)
send
(
CommandType
.
NewTrialJob
,
'null'
)
_restore_io
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment