Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
aea98dd6
Unverified
Commit
aea98dd6
authored
Apr 03, 2021
by
Yuge Zhang
Committed by
GitHub
Apr 03, 2021
Browse files
[Retiarii] Export topk models (#3464)
parent
0494cae1
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
57 additions
and
10 deletions
+57
-10
nni/retiarii/execution/api.py
nni/retiarii/execution/api.py
+8
-1
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+6
-1
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+4
-1
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+10
-1
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+22
-5
test/retiarii_test/mnist/test.py
test/retiarii_test/mnist/test.py
+4
-1
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+3
-0
No files found.
nni/retiarii/execution/api.py
View file @
aea98dd6
...
...
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import
time
from
typing
import
Iterable
from
..graph
import
Model
,
ModelStatus
from
.interface
import
AbstractExecutionEngine
...
...
@@ -11,7 +12,7 @@ _execution_engine = None
_default_listener
=
None
__all__
=
[
'get_execution_engine'
,
'get_and_register_default_listener'
,
'submit_models'
,
'wait_models'
,
'query_available_resources'
,
'list_models'
,
'submit_models'
,
'wait_models'
,
'query_available_resources'
,
'set_execution_engine'
,
'is_stopped_exec'
]
def
set_execution_engine
(
engine
)
->
None
:
...
...
@@ -43,6 +44,12 @@ def submit_models(*models: Model) -> None:
engine
.
submit_models
(
*
models
)
def
list_models
(
*
models
:
Model
)
->
Iterable
[
Model
]:
engine
=
get_execution_engine
()
get_and_register_default_listener
(
engine
)
return
engine
.
list_models
()
def
wait_models
(
*
models
:
Model
)
->
None
:
get_and_register_default_listener
(
get_execution_engine
())
while
True
:
...
...
nni/retiarii/execution/base.py
View file @
aea98dd6
...
...
@@ -5,7 +5,7 @@ import logging
import
os
import
random
import
string
from
typing
import
Dict
,
List
from
typing
import
Dict
,
Iterable
,
List
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
from
..
import
codegen
,
utils
...
...
@@ -53,6 +53,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
advisor
.
final_metric_callback
=
self
.
_final_metric_callback
self
.
_running_models
:
Dict
[
int
,
Model
]
=
dict
()
self
.
_history
:
List
[
Model
]
=
[]
self
.
resources
=
0
...
...
@@ -60,6 +61,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
for
model
in
models
:
data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
)
self
.
_running_models
[
send_trial
(
data
.
dump
())]
=
model
self
.
_history
.
append
(
model
)
def
list_models
(
self
)
->
Iterable
[
Model
]:
return
self
.
_history
def
register_graph_listener
(
self
,
listener
:
AbstractGraphListener
)
->
None
:
self
.
_listeners
.
append
(
listener
)
...
...
nni/retiarii/execution/cgo_engine.py
View file @
aea98dd6
...
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import
logging
from
typing
import
List
,
Dict
,
Tuple
from
typing
import
Iterable
,
List
,
Dict
,
Tuple
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
..
import
codegen
,
utils
...
...
@@ -58,6 +58,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# model.config['trainer_module'], model.config['trainer_kwargs'])
# self._running_models[send_trial(data.dump())] = model
def
list_models
(
self
)
->
Iterable
[
Model
]:
raise
NotImplementedError
def
_assemble
(
self
,
logical_plan
:
LogicalPlan
)
->
List
[
Tuple
[
Model
,
PhysicalDevice
]]:
# unique_models = set()
# for node in logical_plan.graph.nodes:
...
...
nni/retiarii/execution/interface.py
View file @
aea98dd6
...
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
from
abc
import
ABC
,
abstractmethod
,
abstractclassmethod
from
typing
import
Any
,
NewType
,
List
,
Union
from
typing
import
Any
,
Iterable
,
NewType
,
List
,
Union
from
..graph
import
Model
,
MetricData
...
...
@@ -104,6 +104,15 @@ class AbstractExecutionEngine(ABC):
"""
raise
NotImplementedError
@
abstractmethod
def
list_models
(
self
)
->
Iterable
[
Model
]:
"""
Get all models that have been submitted.
Execution engine should store a copy of models that have been submitted and return a list of copies in this method.
"""
raise
NotImplementedError
@
abstractmethod
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
"""
...
...
nni/retiarii/experiment/pytorch.py
View file @
aea98dd6
...
...
@@ -26,7 +26,9 @@ from nni.experiment.config.base import ConfigBase, PathLike
from
nni.experiment.pipe
import
Pipe
from
nni.tools.nnictl.command_utils
import
kill_command
from
..codegen
import
model_to_pytorch_script
from
..converter
import
convert_to_graph
from
..execution
import
list_models
from
..graph
import
Model
,
Evaluator
from
..integration
import
RetiariiAdvisor
from
..mutator
import
Mutator
...
...
@@ -257,16 +259,31 @@ class RetiariiExperiment(Experiment):
self
.
_dispatcher_thread
=
None
_logger
.
info
(
'Experiment stopped'
)
def
export_top_models
(
self
,
top_
n
:
int
=
1
)
:
def
export_top_models
(
self
,
top_
k
:
int
=
1
,
optimize_mode
:
str
=
'maximize'
,
formatter
:
str
=
'code'
)
->
Any
:
"""
export several top performing models
Export several top performing models.
For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
available for customization.
top_k : int
How many models are intended to be exported.
optimize_mode : str
``maximize`` or ``minimize``. Not supported by one-shot algorithms.
``optimize_mode`` is likely to be removed and defined in strategy in future.
formatter : str
Only model code is supported for now. Not supported by one-shot algorithms.
"""
if
top_n
!=
1
:
_logger
.
warning
(
'Only support top_n is 1 for now.'
)
if
isinstance
(
self
.
trainer
,
BaseOneShotTrainer
):
assert
top_k
==
1
,
'Only support top_k is 1 for now.'
return
self
.
trainer
.
export
()
else
:
_logger
.
info
(
'For this experiment, you can find out the best one from WebUI.'
)
all_models
=
filter
(
lambda
m
:
m
.
metric
is
not
None
,
list_models
())
assert
optimize_mode
in
[
'maximize'
,
'minimize'
]
all_models
=
sorted
(
all_models
,
key
=
lambda
m
:
m
.
metric
,
reverse
=
optimize_mode
==
'maximize'
)
assert
formatter
==
'code'
,
'Export formatter other than "code" is not supported yet.'
if
formatter
==
'code'
:
return
[
model_to_pytorch_script
(
model
)
for
model
in
all_models
[:
top_k
]]
def
retrain_model
(
self
,
model
):
"""
...
...
test/retiarii_test/mnist/test.py
View file @
aea98dd6
...
...
@@ -49,7 +49,10 @@ if __name__ == '__main__':
exp_config
=
RetiariiExeConfig
(
'local'
)
exp_config
.
experiment_name
=
'mnist_search'
exp_config
.
trial_concurrency
=
2
exp_config
.
max_trial_number
=
10
exp_config
.
max_trial_number
=
2
exp_config
.
training_service
.
use_active_gpu
=
False
exp
.
run
(
exp_config
,
8081
+
random
.
randint
(
0
,
100
))
print
(
'Final model:'
)
for
model_code
in
exp
.
export_top_models
():
print
(
model_code
)
test/ut/retiarii/test_strategy.py
View file @
aea98dd6
...
...
@@ -37,6 +37,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
self
.
_resource_left
-=
1
threading
.
Thread
(
target
=
self
.
_model_complete
,
args
=
(
model
,
)).
start
()
def
list_models
(
self
)
->
List
[
Model
]:
return
self
.
models
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
],
int
]:
return
self
.
_resource_left
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment