Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
4a0cc125
Unverified
Commit
4a0cc125
authored
Apr 21, 2021
by
QuanluZhang
Committed by
GitHub
Apr 21, 2021
Browse files
[Retiarii] fix experiment early exit (#3547)
parent
336d671c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
43 additions
and
10 deletions
+43
-10
nni/retiarii/execution/api.py
nni/retiarii/execution/api.py
+7
-1
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+4
-0
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+3
-0
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+7
-0
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+15
-5
nni/retiarii/strategy/tpe_strategy.py
nni/retiarii/strategy/tpe_strategy.py
+4
-4
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+3
-0
No files found.
nni/retiarii/execution/api.py
View file @
4a0cc125
...
@@ -13,7 +13,7 @@ _default_listener = None
...
@@ -13,7 +13,7 @@ _default_listener = None
__all__
=
[
'get_execution_engine'
,
'get_and_register_default_listener'
,
__all__
=
[
'get_execution_engine'
,
'get_and_register_default_listener'
,
'list_models'
,
'submit_models'
,
'wait_models'
,
'query_available_resources'
,
'list_models'
,
'submit_models'
,
'wait_models'
,
'query_available_resources'
,
'set_execution_engine'
,
'is_stopped_exec'
]
'set_execution_engine'
,
'is_stopped_exec'
,
'budget_exhausted'
]
def
set_execution_engine
(
engine
)
->
None
:
def
set_execution_engine
(
engine
)
->
None
:
global
_execution_engine
global
_execution_engine
...
@@ -22,6 +22,7 @@ def set_execution_engine(engine) -> None:
...
@@ -22,6 +22,7 @@ def set_execution_engine(engine) -> None:
else
:
else
:
raise
RuntimeError
(
'execution engine is already set'
)
raise
RuntimeError
(
'execution engine is already set'
)
def
get_execution_engine
()
->
AbstractExecutionEngine
:
def
get_execution_engine
()
->
AbstractExecutionEngine
:
"""
"""
Currently we assume the default execution engine is BaseExecutionEngine.
Currently we assume the default execution engine is BaseExecutionEngine.
...
@@ -67,3 +68,8 @@ def query_available_resources() -> int:
...
@@ -67,3 +68,8 @@ def query_available_resources() -> int:
def
is_stopped_exec
(
model
:
Model
)
->
bool
:
def
is_stopped_exec
(
model
:
Model
)
->
bool
:
return
model
.
status
in
(
ModelStatus
.
Trained
,
ModelStatus
.
Failed
)
return
model
.
status
in
(
ModelStatus
.
Trained
,
ModelStatus
.
Failed
)
def
budget_exhausted
()
->
bool
:
engine
=
get_execution_engine
()
return
engine
.
budget_exhausted
()
nni/retiarii/execution/base.py
View file @
4a0cc125
...
@@ -104,6 +104,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -104,6 +104,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def
query_available_resource
(
self
)
->
int
:
def
query_available_resource
(
self
)
->
int
:
return
self
.
resources
return
self
.
resources
def
budget_exhausted
(
self
)
->
bool
:
advisor
=
get_advisor
()
return
advisor
.
stopping
@
classmethod
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
def
trial_execute_graph
(
cls
)
->
None
:
"""
"""
...
...
nni/retiarii/execution/cgo_engine.py
View file @
4a0cc125
...
@@ -130,6 +130,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
...
@@ -130,6 +130,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
def
query_available_resource
(
self
)
->
List
[
WorkerInfo
]:
def
query_available_resource
(
self
)
->
List
[
WorkerInfo
]:
raise
NotImplementedError
# move the method from listener to here?
raise
NotImplementedError
# move the method from listener to here?
def
budget_exhausted
(
self
)
->
bool
:
raise
NotImplementedError
@
classmethod
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
def
trial_execute_graph
(
cls
)
->
None
:
"""
"""
...
...
nni/retiarii/execution/interface.py
View file @
4a0cc125
...
@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC):
...
@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
budget_exhausted
(
self
)
->
bool
:
"""
Check whether user configured max trial number or max execution duration has been reached
"""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
"""
"""
...
...
nni/retiarii/experiment/pytorch.py
View file @
4a0cc125
...
@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment):
...
@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment):
_logger
.
info
(
'Start strategy...'
)
_logger
.
info
(
'Start strategy...'
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
_logger
.
info
(
'Strategy exit'
)
_logger
.
info
(
'Strategy exit'
)
self
.
_dispatcher
.
mark_experiment_as_ending
()
# TODO: find out a proper way to show no more trial message on WebUI
#self._dispatcher.mark_experiment_as_ending()
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
"""
...
@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment):
...
@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment):
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
+
colorama
.
Style
.
RESET_ALL
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
+
colorama
.
Style
.
RESET_ALL
_logger
.
info
(
msg
)
_logger
.
info
(
msg
)
Thread
(
target
=
self
.
_check_exp_status
).
start
()
exp_status_checker
=
Thread
(
target
=
self
.
_check_exp_status
)
exp_status_checker
.
start
()
self
.
_start_strategy
()
self
.
_start_strategy
()
# TODO: the experiment should be completed, when strategy exits and there is no running job
# TODO: the experiment should be completed, when strategy exits and there is no running job
# _logger.info('Waiting for submitted trial jobs to finish...')
_logger
.
info
(
'Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...'
)
_logger
.
info
(
'Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...'
)
exp_status_checker
.
join
()
def
_create_dispatcher
(
self
):
def
_create_dispatcher
(
self
):
return
self
.
_dispatcher
return
self
.
_dispatcher
...
@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment):
...
@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment):
try
:
try
:
while
True
:
while
True
:
time
.
sleep
(
10
)
time
.
sleep
(
10
)
status
=
self
.
get_status
()
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if
self
.
_proc
.
poll
()
is
None
:
status
=
self
.
get_status
()
else
:
return
False
if
status
==
'DONE'
or
status
==
'STOPPED'
:
if
status
==
'DONE'
or
status
==
'STOPPED'
:
return
True
return
True
if
status
==
'ERROR'
:
if
status
==
'ERROR'
:
...
@@ -261,7 +268,10 @@ class RetiariiExperiment(Experiment):
...
@@ -261,7 +268,10 @@ class RetiariiExperiment(Experiment):
nni
.
runtime
.
log
.
stop_experiment_log
(
self
.
id
)
nni
.
runtime
.
log
.
stop_experiment_log
(
self
.
id
)
if
self
.
_proc
is
not
None
:
if
self
.
_proc
is
not
None
:
try
:
try
:
rest
.
delete
(
self
.
port
,
'/experiment'
)
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if
self
.
_proc
.
poll
()
is
None
:
rest
.
delete
(
self
.
port
,
'/experiment'
)
except
Exception
as
e
:
except
Exception
as
e
:
_logger
.
exception
(
e
)
_logger
.
exception
(
e
)
_logger
.
warning
(
'Cannot gracefully stop experiment, killing NNI process...'
)
_logger
.
warning
(
'Cannot gracefully stop experiment, killing NNI process...'
)
...
...
nni/retiarii/strategy/tpe_strategy.py
View file @
4a0cc125
...
@@ -6,7 +6,7 @@ import time
...
@@ -6,7 +6,7 @@ import time
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
,
budget_exhausted
from
.base
import
BaseStrategy
from
.base
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy):
...
@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy):
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
_logger
.
info
(
'TPE strategy has been started.'
)
_logger
.
info
(
'TPE strategy has been started.'
)
while
True
:
while
not
budget_exhausted
()
:
avail_resource
=
query_available_resources
()
avail_resource
=
query_available_resources
()
if
avail_resource
>
0
:
if
avail_resource
>
0
:
model
=
base_model
model
=
base_model
...
@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy):
...
@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy):
else
:
else
:
time
.
sleep
(
2
)
time
.
sleep
(
2
)
_logger
.
warnin
g
(
'num of running models: %d'
,
len
(
self
.
running_models
))
_logger
.
debu
g
(
'num of running models: %d'
,
len
(
self
.
running_models
))
to_be_deleted
=
[]
to_be_deleted
=
[]
for
_id
,
_model
in
self
.
running_models
.
items
():
for
_id
,
_model
in
self
.
running_models
.
items
():
if
is_stopped_exec
(
_model
):
if
is_stopped_exec
(
_model
):
if
_model
.
metric
is
not
None
:
if
_model
.
metric
is
not
None
:
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
_logger
.
warnin
g
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
_logger
.
debu
g
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
to_be_deleted
.
append
(
_id
)
to_be_deleted
.
append
(
_id
)
for
_id
in
to_be_deleted
:
for
_id
in
to_be_deleted
:
del
self
.
running_models
[
_id
]
del
self
.
running_models
[
_id
]
test/ut/retiarii/test_strategy.py
View file @
4a0cc125
...
@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
...
@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
return
self
.
_resource_left
return
self
.
_resource_left
def
budget_exhausted
(
self
)
->
bool
:
pass
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
pass
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment