Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
4a0cc125
Unverified
Commit
4a0cc125
authored
Apr 21, 2021
by
QuanluZhang
Committed by
GitHub
Apr 21, 2021
Browse files
[Retiarii] fix experiment early exit (#3547)
parent
336d671c
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
43 additions
and
10 deletions
+43
-10
nni/retiarii/execution/api.py
nni/retiarii/execution/api.py
+7
-1
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+4
-0
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+3
-0
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+7
-0
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+15
-5
nni/retiarii/strategy/tpe_strategy.py
nni/retiarii/strategy/tpe_strategy.py
+4
-4
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+3
-0
No files found.
nni/retiarii/execution/api.py
View file @
4a0cc125
...
...
@@ -13,7 +13,7 @@ _default_listener = None
__all__
=
[
'get_execution_engine'
,
'get_and_register_default_listener'
,
'list_models'
,
'submit_models'
,
'wait_models'
,
'query_available_resources'
,
'set_execution_engine'
,
'is_stopped_exec'
]
'set_execution_engine'
,
'is_stopped_exec'
,
'budget_exhausted'
]
def
set_execution_engine
(
engine
)
->
None
:
global
_execution_engine
...
...
@@ -22,6 +22,7 @@ def set_execution_engine(engine) -> None:
else
:
raise
RuntimeError
(
'execution engine is already set'
)
def
get_execution_engine
()
->
AbstractExecutionEngine
:
"""
Currently we assume the default execution engine is BaseExecutionEngine.
...
...
@@ -67,3 +68,8 @@ def query_available_resources() -> int:
def is_stopped_exec(model: Model) -> bool:
    """Return True when ``model.status`` is ``ModelStatus.Trained`` or ``ModelStatus.Failed``."""
    current = model.status
    return current in (ModelStatus.Trained, ModelStatus.Failed)
def budget_exhausted() -> bool:
    """Ask the current execution engine whether its budget is exhausted.

    Returns whatever ``budget_exhausted()`` returns on the engine obtained
    from ``get_execution_engine()``.
    """
    return get_execution_engine().budget_exhausted()
nni/retiarii/execution/base.py
View file @
4a0cc125
...
...
@@ -104,6 +104,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def query_available_resource(self) -> int:
    """Return the value currently held in ``self.resources``."""
    available = self.resources
    return available
def budget_exhausted(self) -> bool:
    """Return the ``stopping`` attribute of the advisor given by ``get_advisor()``."""
    return get_advisor().stopping
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
"""
...
...
nni/retiarii/execution/cgo_engine.py
View file @
4a0cc125
...
...
@@ -130,6 +130,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
def query_available_resource(self) -> List[WorkerInfo]:
    """Not implemented here; calling it always raises ``NotImplementedError``."""
    # move the method from listener to here?
    raise NotImplementedError()
def budget_exhausted(self) -> bool:
    """Not implemented here; calling it always raises ``NotImplementedError``."""
    raise NotImplementedError()
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
"""
...
...
nni/retiarii/execution/interface.py
View file @
4a0cc125
...
...
@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC):
"""
raise
NotImplementedError
@abstractmethod
def budget_exhausted(self) -> bool:
    """
    Tell whether the user-configured maximum trial number or maximum
    execution duration has been reached.

    This abstract declaration only raises ``NotImplementedError``.
    """
    raise NotImplementedError()
@
abstractmethod
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
"""
...
...
nni/retiarii/experiment/pytorch.py
View file @
4a0cc125
...
...
@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment):
_logger
.
info
(
'Start strategy...'
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
_logger
.
info
(
'Strategy exit'
)
self
.
_dispatcher
.
mark_experiment_as_ending
()
# TODO: find out a proper way to show no more trial message on WebUI
#self._dispatcher.mark_experiment_as_ending()
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
...
...
@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment):
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
+
colorama
.
Style
.
RESET_ALL
_logger
.
info
(
msg
)
Thread
(
target
=
self
.
_check_exp_status
).
start
()
exp_status_checker
=
Thread
(
target
=
self
.
_check_exp_status
)
exp_status_checker
.
start
()
self
.
_start_strategy
()
# TODO: the experiment should be completed, when strategy exits and there is no running job
# _logger.info('Waiting for submitted trial jobs to finish...')
_logger
.
info
(
'Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...'
)
exp_status_checker
.
join
()
def
_create_dispatcher
(
self
):
return
self
.
_dispatcher
...
...
@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment):
try
:
while
True
:
time
.
sleep
(
10
)
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if
self
.
_proc
.
poll
()
is
None
:
status
=
self
.
get_status
()
else
:
return
False
if
status
==
'DONE'
or
status
==
'STOPPED'
:
return
True
if
status
==
'ERROR'
:
...
...
@@ -261,6 +268,9 @@ class RetiariiExperiment(Experiment):
nni
.
runtime
.
log
.
stop_experiment_log
(
self
.
id
)
if
self
.
_proc
is
not
None
:
try
:
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if
self
.
_proc
.
poll
()
is
None
:
rest
.
delete
(
self
.
port
,
'/experiment'
)
except
Exception
as
e
:
_logger
.
exception
(
e
)
...
...
nni/retiarii/strategy/tpe_strategy.py
View file @
4a0cc125
...
...
@@ -6,7 +6,7 @@ import time
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
,
budget_exhausted
from
.base
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy):
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
_logger
.
info
(
'TPE strategy has been started.'
)
while
True
:
while
not
budget_exhausted
()
:
avail_resource
=
query_available_resources
()
if
avail_resource
>
0
:
model
=
base_model
...
...
@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy):
else
:
time
.
sleep
(
2
)
_logger
.
warning
(
'num of running models: %d'
,
len
(
self
.
running_models
))
_logger
.
debug
(
'num of running models: %d'
,
len
(
self
.
running_models
))
to_be_deleted
=
[]
for
_id
,
_model
in
self
.
running_models
.
items
():
if
is_stopped_exec
(
_model
):
if
_model
.
metric
is
not
None
:
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
_logger
.
warning
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
_logger
.
debug
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
to_be_deleted
.
append
(
_id
)
for
_id
in
to_be_deleted
:
del
self
.
running_models
[
_id
]
test/ut/retiarii/test_strategy.py
View file @
4a0cc125
...
...
@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
def query_available_resource(self) -> Union[List[WorkerInfo], int]:
    """Return the value currently held in ``self._resource_left``."""
    remaining = self._resource_left
    return remaining
def budget_exhausted(self) -> bool:
    """Report that the budget is never exhausted.

    The previous body was ``pass``, which returned ``None`` despite the
    ``-> bool`` annotation. Returning ``False`` keeps the same truthiness
    for callers such as ``while not budget_exhausted():`` while honouring
    the declared return type.
    """
    return False
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
    """Accept ``listener`` and do nothing with it."""
    return None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment