Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
4a0cc125
Unverified
Commit
4a0cc125
authored
Apr 21, 2021
by
QuanluZhang
Committed by
GitHub
Apr 21, 2021
Browse files
[Retiarii] fix experiment early exit (#3547)
parent
336d671c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
43 additions
and
10 deletions
+43
-10
nni/retiarii/execution/api.py
nni/retiarii/execution/api.py
+7
-1
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+4
-0
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+3
-0
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+7
-0
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+15
-5
nni/retiarii/strategy/tpe_strategy.py
nni/retiarii/strategy/tpe_strategy.py
+4
-4
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+3
-0
No files found.
nni/retiarii/execution/api.py
View file @
4a0cc125
...
@@ -13,7 +13,7 @@ _default_listener = None
...
@@ -13,7 +13,7 @@ _default_listener = None
__all__
=
[
'get_execution_engine'
,
'get_and_register_default_listener'
,
__all__
=
[
'get_execution_engine'
,
'get_and_register_default_listener'
,
'list_models'
,
'submit_models'
,
'wait_models'
,
'query_available_resources'
,
'list_models'
,
'submit_models'
,
'wait_models'
,
'query_available_resources'
,
'set_execution_engine'
,
'is_stopped_exec'
]
'set_execution_engine'
,
'is_stopped_exec'
,
'budget_exhausted'
]
def
set_execution_engine
(
engine
)
->
None
:
def
set_execution_engine
(
engine
)
->
None
:
global
_execution_engine
global
_execution_engine
...
@@ -22,6 +22,7 @@ def set_execution_engine(engine) -> None:
...
@@ -22,6 +22,7 @@ def set_execution_engine(engine) -> None:
else
:
else
:
raise
RuntimeError
(
'execution engine is already set'
)
raise
RuntimeError
(
'execution engine is already set'
)
def
get_execution_engine
()
->
AbstractExecutionEngine
:
def
get_execution_engine
()
->
AbstractExecutionEngine
:
"""
"""
Currently we assume the default execution engine is BaseExecutionEngine.
Currently we assume the default execution engine is BaseExecutionEngine.
...
@@ -67,3 +68,8 @@ def query_available_resources() -> int:
...
@@ -67,3 +68,8 @@ def query_available_resources() -> int:
def
is_stopped_exec
(
model
:
Model
)
->
bool
:
def
is_stopped_exec
(
model
:
Model
)
->
bool
:
return
model
.
status
in
(
ModelStatus
.
Trained
,
ModelStatus
.
Failed
)
return
model
.
status
in
(
ModelStatus
.
Trained
,
ModelStatus
.
Failed
)
def
budget_exhausted
()
->
bool
:
engine
=
get_execution_engine
()
return
engine
.
budget_exhausted
()
nni/retiarii/execution/base.py
View file @
4a0cc125
...
@@ -104,6 +104,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -104,6 +104,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def
query_available_resource
(
self
)
->
int
:
def
query_available_resource
(
self
)
->
int
:
return
self
.
resources
return
self
.
resources
def
budget_exhausted
(
self
)
->
bool
:
advisor
=
get_advisor
()
return
advisor
.
stopping
@
classmethod
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
def
trial_execute_graph
(
cls
)
->
None
:
"""
"""
...
...
nni/retiarii/execution/cgo_engine.py
View file @
4a0cc125
...
@@ -130,6 +130,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
...
@@ -130,6 +130,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
def
query_available_resource
(
self
)
->
List
[
WorkerInfo
]:
def
query_available_resource
(
self
)
->
List
[
WorkerInfo
]:
raise
NotImplementedError
# move the method from listener to here?
raise
NotImplementedError
# move the method from listener to here?
def
budget_exhausted
(
self
)
->
bool
:
raise
NotImplementedError
@
classmethod
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
def
trial_execute_graph
(
cls
)
->
None
:
"""
"""
...
...
nni/retiarii/execution/interface.py
View file @
4a0cc125
...
@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC):
...
@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
budget_exhausted
(
self
)
->
bool
:
"""
Check whether user configured max trial number or max execution duration has been reached
"""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
"""
"""
...
...
nni/retiarii/experiment/pytorch.py
View file @
4a0cc125
...
@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment):
...
@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment):
_logger
.
info
(
'Start strategy...'
)
_logger
.
info
(
'Start strategy...'
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
_logger
.
info
(
'Strategy exit'
)
_logger
.
info
(
'Strategy exit'
)
self
.
_dispatcher
.
mark_experiment_as_ending
()
# TODO: find out a proper way to show no more trial message on WebUI
#self._dispatcher.mark_experiment_as_ending()
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
"""
...
@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment):
...
@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment):
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
+
colorama
.
Style
.
RESET_ALL
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
+
colorama
.
Style
.
RESET_ALL
_logger
.
info
(
msg
)
_logger
.
info
(
msg
)
Thread
(
target
=
self
.
_check_exp_status
).
start
()
exp_status_checker
=
Thread
(
target
=
self
.
_check_exp_status
)
exp_status_checker
.
start
()
self
.
_start_strategy
()
self
.
_start_strategy
()
# TODO: the experiment should be completed, when strategy exits and there is no running job
# TODO: the experiment should be completed, when strategy exits and there is no running job
# _logger.info('Waiting for submitted trial jobs to finish...')
_logger
.
info
(
'Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...'
)
_logger
.
info
(
'Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...'
)
exp_status_checker
.
join
()
def
_create_dispatcher
(
self
):
def
_create_dispatcher
(
self
):
return
self
.
_dispatcher
return
self
.
_dispatcher
...
@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment):
...
@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment):
try
:
try
:
while
True
:
while
True
:
time
.
sleep
(
10
)
time
.
sleep
(
10
)
status
=
self
.
get_status
()
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if
self
.
_proc
.
poll
()
is
None
:
status
=
self
.
get_status
()
else
:
return
False
if
status
==
'DONE'
or
status
==
'STOPPED'
:
if
status
==
'DONE'
or
status
==
'STOPPED'
:
return
True
return
True
if
status
==
'ERROR'
:
if
status
==
'ERROR'
:
...
@@ -261,7 +268,10 @@ class RetiariiExperiment(Experiment):
...
@@ -261,7 +268,10 @@ class RetiariiExperiment(Experiment):
nni
.
runtime
.
log
.
stop_experiment_log
(
self
.
id
)
nni
.
runtime
.
log
.
stop_experiment_log
(
self
.
id
)
if
self
.
_proc
is
not
None
:
if
self
.
_proc
is
not
None
:
try
:
try
:
rest
.
delete
(
self
.
port
,
'/experiment'
)
# this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if
self
.
_proc
.
poll
()
is
None
:
rest
.
delete
(
self
.
port
,
'/experiment'
)
except
Exception
as
e
:
except
Exception
as
e
:
_logger
.
exception
(
e
)
_logger
.
exception
(
e
)
_logger
.
warning
(
'Cannot gracefully stop experiment, killing NNI process...'
)
_logger
.
warning
(
'Cannot gracefully stop experiment, killing NNI process...'
)
...
...
nni/retiarii/strategy/tpe_strategy.py
View file @
4a0cc125
...
@@ -6,7 +6,7 @@ import time
...
@@ -6,7 +6,7 @@ import time
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
,
budget_exhausted
from
.base
import
BaseStrategy
from
.base
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy):
...
@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy):
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
_logger
.
info
(
'TPE strategy has been started.'
)
_logger
.
info
(
'TPE strategy has been started.'
)
while
True
:
while
not
budget_exhausted
()
:
avail_resource
=
query_available_resources
()
avail_resource
=
query_available_resources
()
if
avail_resource
>
0
:
if
avail_resource
>
0
:
model
=
base_model
model
=
base_model
...
@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy):
...
@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy):
else
:
else
:
time
.
sleep
(
2
)
time
.
sleep
(
2
)
_logger
.
warnin
g
(
'num of running models: %d'
,
len
(
self
.
running_models
))
_logger
.
debu
g
(
'num of running models: %d'
,
len
(
self
.
running_models
))
to_be_deleted
=
[]
to_be_deleted
=
[]
for
_id
,
_model
in
self
.
running_models
.
items
():
for
_id
,
_model
in
self
.
running_models
.
items
():
if
is_stopped_exec
(
_model
):
if
is_stopped_exec
(
_model
):
if
_model
.
metric
is
not
None
:
if
_model
.
metric
is
not
None
:
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
_logger
.
warnin
g
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
_logger
.
debu
g
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
to_be_deleted
.
append
(
_id
)
to_be_deleted
.
append
(
_id
)
for
_id
in
to_be_deleted
:
for
_id
in
to_be_deleted
:
del
self
.
running_models
[
_id
]
del
self
.
running_models
[
_id
]
test/ut/retiarii/test_strategy.py
View file @
4a0cc125
...
@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
...
@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
return
self
.
_resource_left
return
self
.
_resource_left
def
budget_exhausted
(
self
)
->
bool
:
pass
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
pass
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment