Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bcc640c4
Unverified
Commit
bcc640c4
authored
Oct 12, 2022
by
QuanluZhang
Committed by
GitHub
Oct 12, 2022
Browse files
[nas] fix issue introduced by the trial recovery feature (#5109)
parent
87677df8
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
105 additions
and
41 deletions
+105
-41
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
+15
-2
nni/algorithms/hpo/hyperband_advisor.py
nni/algorithms/hpo/hyperband_advisor.py
+10
-1
nni/algorithms/hpo/tpe_tuner.py
nni/algorithms/hpo/tpe_tuner.py
+0
-13
nni/nas/execution/common/integration.py
nni/nas/execution/common/integration.py
+23
-3
nni/nas/execution/common/integration_api.py
nni/nas/execution/common/integration_api.py
+1
-0
nni/recoverable.py
nni/recoverable.py
+30
-0
nni/runtime/msg_dispatcher.py
nni/runtime/msg_dispatcher.py
+15
-9
nni/runtime/msg_dispatcher_base.py
nni/runtime/msg_dispatcher_base.py
+1
-0
nni/tuner.py
nni/tuner.py
+0
-8
test/algo/nas/test_cgo_engine.py
test/algo/nas/test_cgo_engine.py
+3
-0
test/ut/nas/test_engine.py
test/ut/nas/test_engine.py
+2
-0
test/ut/sdk/test_assessor.py
test/ut/sdk/test_assessor.py
+5
-5
No files found.
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
View file @
bcc640c4
...
@@ -648,8 +648,11 @@ class BOHB(MsgDispatcherBase):
...
@@ -648,8 +648,11 @@ class BOHB(MsgDispatcherBase):
event: the job's state
event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
"""
logger
.
debug
(
'Tuner handle trial end, result is %s'
,
data
)
hyper_params
=
nni
.
load
(
data
[
'hyper_params'
])
hyper_params
=
nni
.
load
(
data
[
'hyper_params'
])
if
self
.
is_created_in_previous_exp
(
hyper_params
[
'parameter_id'
]):
# The end of the recovered trial is ignored
return
logger
.
debug
(
'Tuner handle trial end, result is %s'
,
data
)
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
...
@@ -695,6 +698,13 @@ class BOHB(MsgDispatcherBase):
...
@@ -695,6 +698,13 @@ class BOHB(MsgDispatcherBase):
ValueError
ValueError
Data type not supported
Data type not supported
"""
"""
if
self
.
is_created_in_previous_exp
(
data
[
'parameter_id'
]):
if
data
[
'type'
]
==
MetricType
.
FINAL
:
# only deal with final metric using import data
param
=
self
.
get_previous_param
(
data
[
'parameter_id'
])
trial_data
=
[{
'parameter'
:
param
,
'value'
:
nni
.
load
(
data
[
'value'
])}]
self
.
handle_import_data
(
trial_data
)
return
logger
.
debug
(
'handle report metric data = %s'
,
data
)
logger
.
debug
(
'handle report metric data = %s'
,
data
)
if
'value'
in
data
:
if
'value'
in
data
:
data
[
'value'
]
=
nni
.
load
(
data
[
'value'
])
data
[
'value'
]
=
nni
.
load
(
data
[
'value'
])
...
@@ -752,7 +762,10 @@ class BOHB(MsgDispatcherBase):
...
@@ -752,7 +762,10 @@ class BOHB(MsgDispatcherBase):
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
pass
global
_next_parameter_id
# data: parameters
previous_max_param_id
=
self
.
recover_parameter_id
(
data
)
_next_parameter_id
=
previous_max_param_id
+
1
def
handle_import_data
(
self
,
data
):
def
handle_import_data
(
self
,
data
):
"""Import additional data for tuning
"""Import additional data for tuning
...
...
nni/algorithms/hpo/hyperband_advisor.py
View file @
bcc640c4
...
@@ -522,6 +522,9 @@ class Hyperband(MsgDispatcherBase):
...
@@ -522,6 +522,9 @@ class Hyperband(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
"""
hyper_params
=
nni
.
load
(
data
[
'hyper_params'
])
hyper_params
=
nni
.
load
(
data
[
'hyper_params'
])
if
self
.
is_created_in_previous_exp
(
hyper_params
[
'parameter_id'
]):
# The end of the recovered trial is ignored
return
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
self
.
_handle_trial_end
(
hyper_params
[
'parameter_id'
])
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
if
data
[
'trial_job_id'
]
in
self
.
job_id_para_id_map
:
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
del
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
...
@@ -538,6 +541,9 @@ class Hyperband(MsgDispatcherBase):
...
@@ -538,6 +541,9 @@ class Hyperband(MsgDispatcherBase):
ValueError
ValueError
Data type not supported
Data type not supported
"""
"""
if
self
.
is_created_in_previous_exp
(
data
[
'parameter_id'
]):
# do not support recovering the algorithm state
return
if
'value'
in
data
:
if
'value'
in
data
:
data
[
'value'
]
=
nni
.
load
(
data
[
'value'
])
data
[
'value'
]
=
nni
.
load
(
data
[
'value'
])
# multiphase? need to check
# multiphase? need to check
...
@@ -576,7 +582,10 @@ class Hyperband(MsgDispatcherBase):
...
@@ -576,7 +582,10 @@ class Hyperband(MsgDispatcherBase):
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
pass
global
_next_parameter_id
# data: parameters
previous_max_param_id
=
self
.
recover_parameter_id
(
data
)
_next_parameter_id
=
previous_max_param_id
+
1
def
handle_import_data
(
self
,
data
):
def
handle_import_data
(
self
,
data
):
pass
pass
nni/algorithms/hpo/tpe_tuner.py
View file @
bcc640c4
...
@@ -218,19 +218,6 @@ class TpeTuner(Tuner):
...
@@ -218,19 +218,6 @@ class TpeTuner(Tuner):
self
.
dedup
.
add_history
(
param
)
self
.
dedup
.
add_history
(
param
)
_logger
.
info
(
f
'Replayed
{
len
(
data
)
}
FINISHED trials'
)
_logger
.
info
(
f
'Replayed
{
len
(
data
)
}
FINISHED trials'
)
def
import_customized_data
(
self
,
data
):
# for dedup customized / resumed
if
isinstance
(
data
,
str
):
data
=
nni
.
load
(
data
)
for
trial
in
data
:
# {'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128, ...}
if
isinstance
(
trial
,
str
):
trial
=
nni
.
load
(
trial
)
param
=
format_parameters
(
trial
[
'parameters'
],
self
.
space
)
self
.
_running_params
[
trial
[
'parameter_id'
]]
=
param
self
.
dedup
.
add_history
(
param
)
_logger
.
info
(
f
'Replayed
{
len
(
data
)
}
RUNING/WAITING trials'
)
def
suggest
(
args
,
rng
,
space
,
history
):
def
suggest
(
args
,
rng
,
space
,
history
):
params
=
{}
params
=
{}
for
key
,
spec
in
space
.
items
():
for
key
,
spec
in
space
.
items
():
...
...
nni/nas/execution/common/integration.py
View file @
bcc640c4
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
__all__
=
[
'RetiariiAdvisor'
]
__all__
=
[
'RetiariiAdvisor'
]
import
logging
import
logging
import
time
import
os
import
os
from
typing
import
Any
,
Callable
,
Optional
,
Dict
,
List
,
Tuple
from
typing
import
Any
,
Callable
,
Optional
,
Dict
,
List
,
Tuple
...
@@ -60,11 +61,12 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -60,11 +61,12 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
final_metric_callback
:
Optional
[
Callable
[[
int
,
MetricData
],
None
]]
=
None
self
.
final_metric_callback
:
Optional
[
Callable
[[
int
,
MetricData
],
None
]]
=
None
self
.
parameters_count
=
0
self
.
parameters_count
=
0
# Sometimes messages arrive first before the callbacks get registered.
# Sometimes messages arrive first before the callbacks get registered.
# Or in case that we allow engine to be absent during the experiment.
# Or in case that we allow engine to be absent during the experiment.
# Here we need to store the messages and invoke them later.
# Here we need to store the messages and invoke them later.
self
.
call_queue
:
List
[
Tuple
[
str
,
list
]]
=
[]
self
.
call_queue
:
List
[
Tuple
[
str
,
list
]]
=
[]
# this is for waiting the to-be-recovered trials from nnimanager
self
.
_advisor_initialized
=
False
def
register_callbacks
(
self
,
callbacks
:
Dict
[
str
,
Callable
[...,
None
]]):
def
register_callbacks
(
self
,
callbacks
:
Dict
[
str
,
Callable
[...,
None
]]):
"""
"""
...
@@ -167,6 +169,10 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -167,6 +169,10 @@ class RetiariiAdvisor(MsgDispatcherBase):
Parameter ID that is assigned to this parameter,
Parameter ID that is assigned to this parameter,
which will be used for identification in future.
which will be used for identification in future.
"""
"""
while
not
self
.
_advisor_initialized
:
_logger
.
info
(
'Wait for RetiariiAdvisor to be initialized...'
)
time
.
sleep
(
0.5
)
self
.
parameters_count
+=
1
self
.
parameters_count
+=
1
if
placement_constraint
is
None
:
if
placement_constraint
is
None
:
placement_constraint
=
{
placement_constraint
=
{
...
@@ -204,6 +210,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -204,6 +210,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
self
.
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
def
handle_request_trial_jobs
(
self
,
num_trials
):
def
handle_request_trial_jobs
(
self
,
num_trials
):
self
.
_advisor_initialized
=
True
_logger
.
debug
(
'Request trial jobs: %s'
,
num_trials
)
_logger
.
debug
(
'Request trial jobs: %s'
,
num_trials
)
self
.
invoke_callback
(
'request_trial_jobs'
,
num_trials
)
self
.
invoke_callback
(
'request_trial_jobs'
,
num_trials
)
...
@@ -212,10 +219,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -212,10 +219,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
search_space
=
data
self
.
search_space
=
data
def
handle_trial_end
(
self
,
data
):
def
handle_trial_end
(
self
,
data
):
# TODO: we should properly handle the trials in self._customized_parameter_ids instead of ignoring
id_
=
nni
.
load
(
data
[
'hyper_params'
])[
'parameter_id'
]
if
self
.
is_created_in_previous_exp
(
id_
):
_logger
.
info
(
'The end of the recovered trial %d is ignored'
,
id_
)
return
_logger
.
debug
(
'Trial end: %s'
,
data
)
_logger
.
debug
(
'Trial end: %s'
,
data
)
self
.
invoke_callback
(
'trial_end'
,
nni
.
load
(
data
[
'hyper_params'
])[
'parameter_id'
]
,
data
[
'event'
]
==
'SUCCEEDED'
)
self
.
invoke_callback
(
'trial_end'
,
id_
,
data
[
'event'
]
==
'SUCCEEDED'
)
def
handle_report_metric_data
(
self
,
data
):
def
handle_report_metric_data
(
self
,
data
):
# TODO: we should properly handle the trials in self._customized_parameter_ids instead of ignoring
if
self
.
is_created_in_previous_exp
(
data
[
'parameter_id'
]):
_logger
.
info
(
'The metrics of the recovered trial %d are ignored'
,
data
[
'parameter_id'
])
return
# NOTE: this part is not aligned with hpo tuners.
# in hpo tuners, trial_job_id is used for intermediate results handling
# parameter_id is for final result handling.
_logger
.
debug
(
'Metric reported: %s'
,
data
)
_logger
.
debug
(
'Metric reported: %s'
,
data
)
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
raise
ValueError
(
'Request parameter not supported'
)
raise
ValueError
(
'Request parameter not supported'
)
...
@@ -239,4 +258,5 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -239,4 +258,5 @@ class RetiariiAdvisor(MsgDispatcherBase):
pass
pass
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
pass
previous_max_param_id
=
self
.
recover_parameter_id
(
data
)
self
.
parameters_count
=
previous_max_param_id
nni/nas/execution/common/integration_api.py
View file @
bcc640c4
...
@@ -12,6 +12,7 @@ from typing import NewType, Any
...
@@ -12,6 +12,7 @@ from typing import NewType, Any
import
nni
import
nni
from
nni.common.version
import
version_check
from
nni.common.version
import
version_check
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
# because it would induce cycled import
RetiariiAdvisor
=
NewType
(
'RetiariiAdvisor'
,
Any
)
RetiariiAdvisor
=
NewType
(
'RetiariiAdvisor'
,
Any
)
...
...
nni/recoverable.py
View file @
bcc640c4
...
@@ -4,8 +4,12 @@
...
@@ -4,8 +4,12 @@
from
__future__
import
annotations
from
__future__
import
annotations
import
os
import
os
import
nni
class
Recoverable
:
class
Recoverable
:
def
__init__
(
self
):
self
.
recovered_max_param_id
=
-
1
self
.
recovered_trial_params
=
{}
def
load_checkpoint
(
self
)
->
None
:
def
load_checkpoint
(
self
)
->
None
:
pass
pass
...
@@ -18,3 +22,29 @@ class Recoverable:
...
@@ -18,3 +22,29 @@ class Recoverable:
if
ckp_path
is
not
None
and
os
.
path
.
isdir
(
ckp_path
):
if
ckp_path
is
not
None
and
os
.
path
.
isdir
(
ckp_path
):
return
ckp_path
return
ckp_path
return
None
return
None
def
recover_parameter_id
(
self
,
data
)
->
int
:
# this is for handling the resuming of the interrupted data: parameters
if
not
isinstance
(
data
,
list
):
data
=
[
data
]
previous_max_param_id
=
0
for
trial
in
data
:
# {'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128, ...}
if
isinstance
(
trial
,
str
):
trial
=
nni
.
load
(
trial
)
if
not
isinstance
(
trial
[
'parameter_id'
],
int
):
# for dealing with user customized trials
# skip for now
continue
self
.
recovered_trial_params
[
trial
[
'parameter_id'
]]
=
trial
[
'parameters'
]
if
previous_max_param_id
<
trial
[
'parameter_id'
]:
previous_max_param_id
=
trial
[
'parameter_id'
]
self
.
recovered_max_param_id
=
previous_max_param_id
return
previous_max_param_id
def
is_created_in_previous_exp
(
self
,
param_id
:
int
)
->
bool
:
return
param_id
<=
self
.
recovered_max_param_id
def
get_previous_param
(
self
,
param_id
:
int
)
->
dict
:
return
self
.
recovered_trial_params
[
param_id
]
\ No newline at end of file
nni/runtime/msg_dispatcher.py
View file @
bcc640c4
...
@@ -120,15 +120,10 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -120,15 +120,10 @@ class MsgDispatcher(MsgDispatcherBase):
self
.
tuner
.
import_data
(
data
)
self
.
tuner
.
import_data
(
data
)
def
handle_add_customized_trial
(
self
,
data
):
def
handle_add_customized_trial
(
self
,
data
):
global
_next_parameter_id
# data: parameters
# data: parameters
if
not
isinstance
(
data
,
list
):
previous_max_param_id
=
self
.
recover_parameter_id
(
data
)
data
=
[
data
]
_next_parameter_id
=
previous_max_param_id
+
1
for
_
in
data
:
id_
=
_create_parameter_id
()
_customized_parameter_ids
.
add
(
id_
)
self
.
tuner
.
import_customized_data
(
data
)
def
handle_report_metric_data
(
self
,
data
):
def
handle_report_metric_data
(
self
,
data
):
"""
"""
...
@@ -137,6 +132,13 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -137,6 +132,13 @@ class MsgDispatcher(MsgDispatcherBase):
- 'value': metric value reported by nni.report_final_result()
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
"""
if
self
.
is_created_in_previous_exp
(
data
[
'parameter_id'
]):
if
data
[
'type'
]
==
MetricType
.
FINAL
:
# only deal with final metric using import data
param
=
self
.
get_previous_param
(
data
[
'parameter_id'
])
trial_data
=
[{
'parameter'
:
param
,
'value'
:
load
(
data
[
'value'
])}]
self
.
handle_import_data
(
trial_data
)
return
# metrics value is dumped as json string in trial, so we need to decode it here
# metrics value is dumped as json string in trial, so we need to decode it here
if
'value'
in
data
:
if
'value'
in
data
:
data
[
'value'
]
=
load
(
data
[
'value'
])
data
[
'value'
]
=
load
(
data
[
'value'
])
...
@@ -166,6 +168,10 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -166,6 +168,10 @@ class MsgDispatcher(MsgDispatcherBase):
- event: the job's state
- event: the job's state
- hyper_params: the hyperparameters generated and returned by tuner
- hyper_params: the hyperparameters generated and returned by tuner
"""
"""
id_
=
load
(
data
[
'hyper_params'
])[
'parameter_id'
]
if
self
.
is_created_in_previous_exp
(
id_
):
# The end of the recovered trial is ignored
return
trial_job_id
=
data
[
'trial_job_id'
]
trial_job_id
=
data
[
'trial_job_id'
]
_ended_trials
.
add
(
trial_job_id
)
_ended_trials
.
add
(
trial_job_id
)
if
trial_job_id
in
_trial_history
:
if
trial_job_id
in
_trial_history
:
...
@@ -173,7 +179,7 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -173,7 +179,7 @@ class MsgDispatcher(MsgDispatcherBase):
if
self
.
assessor
is
not
None
:
if
self
.
assessor
is
not
None
:
self
.
assessor
.
trial_end
(
trial_job_id
,
data
[
'event'
]
==
'SUCCEEDED'
)
self
.
assessor
.
trial_end
(
trial_job_id
,
data
[
'event'
]
==
'SUCCEEDED'
)
if
self
.
tuner
is
not
None
:
if
self
.
tuner
is
not
None
:
self
.
tuner
.
trial_end
(
load
(
data
[
'hyper_params'
])[
'parameter_id'
]
,
data
[
'event'
]
==
'SUCCEEDED'
)
self
.
tuner
.
trial_end
(
id_
,
data
[
'event'
]
==
'SUCCEEDED'
)
def
_handle_final_metric_data
(
self
,
data
):
def
_handle_final_metric_data
(
self
,
data
):
"""Call tuner to process final results
"""Call tuner to process final results
...
...
nni/runtime/msg_dispatcher_base.py
View file @
bcc640c4
...
@@ -30,6 +30,7 @@ class MsgDispatcherBase(Recoverable):
...
@@ -30,6 +30,7 @@ class MsgDispatcherBase(Recoverable):
"""
"""
def
__init__
(
self
,
command_channel_url
=
None
):
def
__init__
(
self
,
command_channel_url
=
None
):
super
().
__init__
()
self
.
stopping
=
False
self
.
stopping
=
False
if
command_channel_url
is
None
:
if
command_channel_url
is
None
:
command_channel_url
=
dispatcher_env_vars
.
NNI_TUNER_COMMAND_CHANNEL
command_channel_url
=
dispatcher_env_vars
.
NNI_TUNER_COMMAND_CHANNEL
...
...
nni/tuner.py
View file @
bcc640c4
...
@@ -219,14 +219,6 @@ class Tuner(Recoverable):
...
@@ -219,14 +219,6 @@ class Tuner(Recoverable):
# data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
# data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
pass
pass
def
import_customized_data
(
self
,
data
:
list
[
TrialRecord
])
->
None
:
"""
Internal API under revising, not recommended for end users.
"""
# Import resume data for avoiding duplications
# data: a list of dictionarys, each of which has at least two keys, 'parameter_id' and 'parameters'
pass
def
_on_exit
(
self
)
->
None
:
def
_on_exit
(
self
)
->
None
:
pass
pass
...
...
test/algo/nas/test_cgo_engine.py
View file @
bcc640c4
...
@@ -319,6 +319,9 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -319,6 +319,9 @@ class CGOEngineTest(unittest.TestCase):
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
advisor
.
assessor_worker
.
start
()
# this is because RetiariiAdvisor only works after `_advisor_initialized` becomes True.
# normally it becomes true when `handle_request_trial_jobs` is invoked
advisor
.
_advisor_initialized
=
True
remote
=
RemoteConfig
(
machine_list
=
[])
remote
=
RemoteConfig
(
machine_list
=
[])
remote
.
machine_list
.
append
(
RemoteMachineConfig
(
host
=
'test'
,
gpu_indices
=
[
0
,
1
,
2
,
3
]))
remote
.
machine_list
.
append
(
RemoteMachineConfig
(
host
=
'test'
,
gpu_indices
=
[
0
,
1
,
2
,
3
]))
...
...
test/ut/nas/test_engine.py
View file @
bcc640c4
...
@@ -27,6 +27,7 @@ class EngineTest(unittest.TestCase):
...
@@ -27,6 +27,7 @@ class EngineTest(unittest.TestCase):
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
advisor
=
RetiariiAdvisor
(
'ws://_unittest_placeholder_'
)
advisor
=
RetiariiAdvisor
(
'ws://_unittest_placeholder_'
)
advisor
.
_advisor_initialized
=
True
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
advisor
.
assessor_worker
.
start
()
...
@@ -44,6 +45,7 @@ class EngineTest(unittest.TestCase):
...
@@ -44,6 +45,7 @@ class EngineTest(unittest.TestCase):
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
advisor
=
RetiariiAdvisor
(
'ws://_unittest_placeholder_'
)
advisor
=
RetiariiAdvisor
(
'ws://_unittest_placeholder_'
)
advisor
.
_advisor_initialized
=
True
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
advisor
.
assessor_worker
.
start
()
...
...
test/ut/sdk/test_assessor.py
View file @
bcc640c4
...
@@ -48,11 +48,11 @@ class AssessorTestCase(TestCase):
...
@@ -48,11 +48,11 @@ class AssessorTestCase(TestCase):
def
test_assessor
(
self
):
def
test_assessor
(
self
):
pass
pass
_reverse_io
()
_reverse_io
()
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"
parameter_id": 0,"
trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"
parameter_id": 1,"
trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}'
)
send
(
CommandType
.
ReportMetricData
,
'{"
parameter_id": 0,"
trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"A","event":"SYS_CANCELED"}'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"A","event":"SYS_CANCELED"
,"hyper_params":"{
\\
"parameter_id
\\
": 0}"
}'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"B","event":"SUCCEEDED"}'
)
send
(
CommandType
.
TrialEnd
,
'{"trial_job_id":"B","event":"SUCCEEDED"
,"hyper_params":"{
\\
"parameter_id
\\
": 1}"
}'
)
send
(
CommandType
.
NewTrialJob
,
'null'
)
send
(
CommandType
.
NewTrialJob
,
'null'
)
_restore_io
()
_restore_io
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment