Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
2fc47247
Unverified
Commit
2fc47247
authored
May 24, 2022
by
QuanluZhang
Committed by
GitHub
May 24, 2022
Browse files
[retiarii] refactor of nas experiment (#4841)
parent
c80bda29
Changes
20
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
346 additions
and
324 deletions
+346
-324
nni/experiment/config/base.py
nni/experiment/config/base.py
+5
-0
nni/experiment/config/experiment_config.py
nni/experiment/config/experiment_config.py
+4
-3
nni/experiment/experiment.py
nni/experiment/experiment.py
+40
-25
nni/retiarii/execution/api.py
nni/retiarii/execution/api.py
+7
-6
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+17
-3
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+31
-7
nni/retiarii/experiment/__init__.py
nni/retiarii/experiment/__init__.py
+2
-0
nni/retiarii/experiment/config/__init__.py
nni/retiarii/experiment/config/__init__.py
+5
-0
nni/retiarii/experiment/config/engine_config.py
nni/retiarii/experiment/config/engine_config.py
+41
-0
nni/retiarii/experiment/config/experiment_config.py
nni/retiarii/experiment/config/experiment_config.py
+60
-0
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+88
-257
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+4
-1
nni/runtime/msg_dispatcher_base.py
nni/runtime/msg_dispatcher_base.py
+18
-2
test/retiarii_test/cgo_mnasnet/base_mnasnet.py
test/retiarii_test/cgo_mnasnet/base_mnasnet.py
+0
-1
test/retiarii_test/cgo_mnasnet/test.py
test/retiarii_test/cgo_mnasnet/test.py
+3
-5
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+13
-9
test/ut/retiarii/test_engine.py
test/ut/retiarii/test_engine.py
+2
-2
test/ut/sdk/test_assessor.py
test/ut/sdk/test_assessor.py
+1
-1
test/ut/sdk/test_msg_dispatcher.py
test/ut/sdk/test_msg_dispatcher.py
+1
-1
ts/nni_manager/core/nnimanager.ts
ts/nni_manager/core/nnimanager.ts
+4
-1
No files found.
nni/experiment/config/base.py
View file @
2fc47247
...
@@ -54,6 +54,11 @@ class ConfigBase:
...
@@ -54,6 +54,11 @@ class ConfigBase:
Config objects will remember where they are loaded; therefore relative paths can be resolved smartly.
Config objects will remember where they are loaded; therefore relative paths can be resolved smartly.
If a config object is created with constructor, the base path will be current working directory.
If a config object is created with constructor, the base path will be current working directory.
If it is loaded with ``ConfigBase.load(path)``, the base path will be ``path``'s parent.
If it is loaded with ``ConfigBase.load(path)``, the base path will be ``path``'s parent.
.. attention::
All the classes that inherit ``ConfigBase`` are not allowed to use ``from __future__ import annotations``,
because ``ConfigBase`` uses ``typeguard`` to perform runtime check and it does not support lazy annotations.
"""
"""
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
...
...
nni/experiment/config/experiment_config.py
View file @
2fc47247
...
@@ -164,10 +164,11 @@ class ExperimentConfig(ConfigBase):
...
@@ -164,10 +164,11 @@ class ExperimentConfig(ConfigBase):
# currently I have only seen one issue of this kind
# currently I have only seen one issue of this kind
#Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True)
#Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True)
utils
.
validate_gpu_indices
(
self
.
tuner_gpu_indices
)
if
type
(
self
).
__name__
!=
'RetiariiExeConfig'
:
utils
.
validate_gpu_indices
(
self
.
tuner_gpu_indices
)
if
self
.
tuner
is
None
:
if
self
.
tuner
is
None
:
raise
ValueError
(
'ExperimentConfig: tuner must be set'
)
raise
ValueError
(
'ExperimentConfig: tuner must be set'
)
def
_load_search_space_file
(
search_space_path
):
def
_load_search_space_file
(
search_space_path
):
# FIXME
# FIXME
...
...
nni/experiment/experiment.py
View file @
2fc47247
...
@@ -84,20 +84,9 @@ class Experiment:
...
@@ -84,20 +84,9 @@ class Experiment:
else
:
else
:
self
.
config
=
config_or_platform
self
.
config
=
config_or_platform
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
,
run_mode
:
RunMode
=
RunMode
.
Background
)
->
None
:
def
_start_impl
(
self
,
port
:
int
,
debug
:
bool
,
run_mode
:
RunMode
,
"""
tuner_command_channel
:
str
|
None
,
Start the experiment in background.
tags
:
list
[
str
]
=
[])
->
ExperimentConfig
:
This method will raise exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
"""
assert
self
.
config
is
not
None
assert
self
.
config
is
not
None
if
run_mode
is
not
RunMode
.
Detach
:
if
run_mode
is
not
RunMode
.
Detach
:
atexit
.
register
(
self
.
stop
)
atexit
.
register
(
self
.
stop
)
...
@@ -111,7 +100,8 @@ class Experiment:
...
@@ -111,7 +100,8 @@ class Experiment:
log_level
=
'debug'
if
(
debug
or
config
.
log_level
==
'trace'
)
else
config
.
log_level
log_level
=
'debug'
if
(
debug
or
config
.
log_level
==
'trace'
)
else
config
.
log_level
start_experiment_logging
(
self
.
id
,
log_file
,
cast
(
str
,
log_level
))
start_experiment_logging
(
self
.
id
,
log_file
,
cast
(
str
,
log_level
))
self
.
_proc
=
launcher
.
start_experiment
(
self
.
_action
,
self
.
id
,
config
,
port
,
debug
,
run_mode
,
self
.
url_prefix
)
self
.
_proc
=
launcher
.
start_experiment
(
self
.
_action
,
self
.
id
,
config
,
port
,
debug
,
run_mode
,
self
.
url_prefix
,
tuner_command_channel
,
tags
)
assert
self
.
_proc
is
not
None
assert
self
.
_proc
is
not
None
self
.
port
=
port
# port will be None if start up failed
self
.
port
=
port
# port will be None if start up failed
...
@@ -124,12 +114,27 @@ class Experiment:
...
@@ -124,12 +114,27 @@ class Experiment:
ips
=
[
f
'http://
{
ip
}
:
{
port
}
'
for
ip
in
ips
if
ip
]
ips
=
[
f
'http://
{
ip
}
:
{
port
}
'
for
ip
in
ips
if
ip
]
msg
=
'Web portal URLs: ${CYAN}'
+
' '
.
join
(
ips
)
msg
=
'Web portal URLs: ${CYAN}'
+
' '
.
join
(
ips
)
_logger
.
info
(
msg
)
_logger
.
info
(
msg
)
return
config
def
st
op
(
self
)
->
None
:
def
st
art
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
,
run_mode
:
RunMode
=
RunMode
.
Background
)
->
None
:
"""
"""
Stop the experiment.
Start the experiment in background.
This method will raise exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
run_mode
Running the experiment in foreground or background
"""
"""
_logger
.
info
(
'Stopping experiment, please wait...'
)
self
.
_start_impl
(
port
,
debug
,
run_mode
,
None
,
[])
def
_stop_impl
(
self
)
->
None
:
atexit
.
unregister
(
self
.
stop
)
atexit
.
unregister
(
self
.
stop
)
stop_experiment_logging
(
self
.
id
)
stop_experiment_logging
(
self
.
id
)
...
@@ -144,8 +149,24 @@ class Experiment:
...
@@ -144,8 +149,24 @@ class Experiment:
self
.
id
=
None
# type: ignore
self
.
id
=
None
# type: ignore
self
.
port
=
None
self
.
port
=
None
self
.
_proc
=
None
self
.
_proc
=
None
def
stop
(
self
)
->
None
:
"""
Stop the experiment.
"""
_logger
.
info
(
'Stopping experiment, please wait...'
)
self
.
_stop_impl
()
_logger
.
info
(
'Experiment stopped'
)
_logger
.
info
(
'Experiment stopped'
)
def
_wait_completion
(
self
)
->
bool
:
while
True
:
status
=
self
.
get_status
()
if
status
==
'DONE'
or
status
==
'STOPPED'
:
return
True
if
status
==
'ERROR'
:
return
False
time
.
sleep
(
10
)
def
run
(
self
,
port
:
int
=
8080
,
wait_completion
:
bool
=
True
,
debug
:
bool
=
False
)
->
bool
|
None
:
def
run
(
self
,
port
:
int
=
8080
,
wait_completion
:
bool
=
True
,
debug
:
bool
=
False
)
->
bool
|
None
:
"""
"""
Run the experiment.
Run the experiment.
...
@@ -159,13 +180,7 @@ class Experiment:
...
@@ -159,13 +180,7 @@ class Experiment:
self
.
start
(
port
,
debug
)
self
.
start
(
port
,
debug
)
if
wait_completion
:
if
wait_completion
:
try
:
try
:
while
True
:
self
.
_wait_completion
()
time
.
sleep
(
10
)
status
=
self
.
get_status
()
if
status
==
'DONE'
or
status
==
'STOPPED'
:
return
True
if
status
==
'ERROR'
:
return
False
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
_logger
.
warning
(
'KeyboardInterrupt detected'
)
_logger
.
warning
(
'KeyboardInterrupt detected'
)
self
.
stop
()
self
.
stop
()
...
...
nni/retiarii/execution/api.py
View file @
2fc47247
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
time
import
time
import
warnings
from
typing
import
Iterable
from
typing
import
Iterable
from
..graph
import
Model
,
ModelStatus
from
..graph
import
Model
,
ModelStatus
...
@@ -18,12 +19,12 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener',
...
@@ -18,12 +19,12 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener',
def
set_execution_engine
(
engine
:
AbstractExecutionEngine
)
->
None
:
def
set_execution_engine
(
engine
:
AbstractExecutionEngine
)
->
None
:
global
_execution_engine
global
_execution_engine
if
_execution_engine
is
None
:
if
_execution_engine
is
not
None
:
_e
xecution
_
engine
=
engine
warnings
.
warn
(
'E
xecution
engine
is already set. '
else
:
'You should avoid instantiating RetiariiExperiment twice in one process. '
rais
e
R
un
timeError
(
'Execution engine is already set. '
'If you ar
e
r
un
ning in a Jupyter notebook, please restart the kernel.'
,
'You should avoid instantiating RetiariiExperiment twice in one process. '
RuntimeWarning
)
'If you are runni
ng
in
a Jupyter notebook, please restart the kernel.'
)
_execution_e
ngin
e
=
engine
def
get_execution_engine
()
->
AbstractExecutionEngine
:
def
get_execution_engine
()
->
AbstractExecutionEngine
:
...
...
nni/retiarii/execution/base.py
View file @
2fc47247
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
__future__
import
annotations
import
logging
import
logging
import
os
import
os
import
random
import
random
import
string
import
string
from
typing
import
Any
,
Dict
,
Iterable
,
List
from
typing
import
Any
,
Dict
,
Iterable
,
List
from
nni.experiment
import
rest
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
from
.utils
import
get_mutation_summary
from
.utils
import
get_mutation_summary
from
..
import
codegen
,
utils
from
..
import
codegen
,
utils
...
@@ -54,12 +58,22 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -54,12 +58,22 @@ class BaseExecutionEngine(AbstractExecutionEngine):
Resource management is implemented in this class.
Resource management is implemented in this class.
"""
"""
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
rest_port
:
int
|
None
=
None
,
rest_url_prefix
:
str
|
None
=
None
)
->
None
:
"""
"""
Upon initialization, advisor callbacks need to be registered.
Upon initialization, advisor callbacks need to be registered.
Advisor will call the callbacks when the corresponding event has been triggered.
Advisor will call the callbacks when the corresponding event has been triggered.
Base execution engine will get those callbacks and broadcast them to graph listener.
Base execution engine will get those callbacks and broadcast them to graph listener.
Parameters
----------
rest_port
The port of the experiment's rest server
rest_url_prefix
The url prefix of the experiment's rest entry
"""
"""
self
.
port
=
rest_port
self
.
url_prefix
=
rest_url_prefix
self
.
_listeners
:
List
[
AbstractGraphListener
]
=
[]
self
.
_listeners
:
List
[
AbstractGraphListener
]
=
[]
# register advisor callbacks
# register advisor callbacks
...
@@ -123,8 +137,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -123,8 +137,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
return
self
.
resources
return
self
.
resources
def
budget_exhausted
(
self
)
->
bool
:
def
budget_exhausted
(
self
)
->
bool
:
advisor
=
get_advisor
(
)
resp
=
rest
.
get
(
self
.
port
,
'/check-status'
,
self
.
url_prefix
)
return
advisor
.
stopping
return
resp
[
'status'
]
==
'DONE'
@
classmethod
@
classmethod
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
...
...
nni/retiarii/execution/cgo_engine.py
View file @
2fc47247
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
__future__
import
annotations
import
logging
import
logging
import
os
import
os
import
random
import
random
import
string
import
string
import
time
import
time
import
threading
import
threading
from
typing
import
Iterable
,
List
,
Dict
,
Tuple
from
typing
import
Iterable
,
List
,
Dict
,
Tuple
,
cast
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
nni.common.device
import
GPUDevice
,
Device
from
nni.common.device
import
GPUDevice
,
Device
from
nni.experiment.config.training_services
import
RemoteConfig
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
..
import
codegen
,
utils
from
..
import
codegen
,
utils
from
..graph
import
Model
,
ModelStatus
,
MetricData
,
Node
from
..graph
import
Model
,
ModelStatus
,
MetricData
,
Node
...
@@ -31,7 +34,6 @@ class TrialSubmission:
...
@@ -31,7 +34,6 @@ class TrialSubmission:
placement
:
Dict
[
Node
,
Device
]
placement
:
Dict
[
Node
,
Device
]
grouped_models
:
List
[
Model
]
grouped_models
:
List
[
Model
]
class
CGOExecutionEngine
(
AbstractExecutionEngine
):
class
CGOExecutionEngine
(
AbstractExecutionEngine
):
"""
"""
The execution engine with Cross-Graph Optimization (CGO).
The execution engine with Cross-Graph Optimization (CGO).
...
@@ -41,24 +43,35 @@ class CGOExecutionEngine(AbstractExecutionEngine):
...
@@ -41,24 +43,35 @@ class CGOExecutionEngine(AbstractExecutionEngine):
Parameters
Parameters
----------
----------
devices : List[De
vice
]
training_ser
vice
Available devices for execution
.
The remote training service config
.
max_concurrency
: int
max_concurrency
The maximum number of trials to run concurrently.
The maximum number of trials to run concurrently.
batch_waiting_time
: int
batch_waiting_time
Seconds to wait for each batch of trial submission.
Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization.
The trials within one batch could apply cross-graph optimization.
rest_port
The port of the experiment's rest server
rest_url_prefix
The url prefix of the experiment's rest entry
"""
"""
def
__init__
(
self
,
devices
:
List
[
Device
]
=
None
,
def
__init__
(
self
,
training_service
:
RemoteConfig
,
max_concurrency
:
int
=
None
,
max_concurrency
:
int
=
None
,
batch_waiting_time
:
int
=
60
,
batch_waiting_time
:
int
=
60
,
rest_port
:
int
|
None
=
None
,
rest_url_prefix
:
str
|
None
=
None
)
->
None
:
)
->
None
:
self
.
port
=
rest_port
self
.
url_prefix
=
rest_url_prefix
self
.
_listeners
:
List
[
AbstractGraphListener
]
=
[]
self
.
_listeners
:
List
[
AbstractGraphListener
]
=
[]
self
.
_running_models
:
Dict
[
int
,
Model
]
=
dict
()
self
.
_running_models
:
Dict
[
int
,
Model
]
=
dict
()
self
.
logical_plan_counter
=
0
self
.
logical_plan_counter
=
0
self
.
available_devices
:
List
[
Device
]
=
[]
self
.
available_devices
:
List
[
Device
]
=
[]
self
.
max_concurrency
:
int
=
max_concurrency
self
.
max_concurrency
:
int
=
max_concurrency
devices
=
self
.
_construct_devices
(
training_service
)
for
device
in
devices
:
for
device
in
devices
:
self
.
available_devices
.
append
(
device
)
self
.
available_devices
.
append
(
device
)
self
.
all_devices
=
self
.
available_devices
.
copy
()
self
.
all_devices
=
self
.
available_devices
.
copy
()
...
@@ -88,6 +101,17 @@ class CGOExecutionEngine(AbstractExecutionEngine):
...
@@ -88,6 +101,17 @@ class CGOExecutionEngine(AbstractExecutionEngine):
self
.
_consumer_thread
=
threading
.
Thread
(
target
=
self
.
_consume_models
)
self
.
_consumer_thread
=
threading
.
Thread
(
target
=
self
.
_consume_models
)
self
.
_consumer_thread
.
start
()
self
.
_consumer_thread
.
start
()
def
_construct_devices
(
self
,
training_service
):
devices
=
[]
if
hasattr
(
training_service
,
'machine_list'
):
for
machine
in
cast
(
RemoteConfig
,
training_service
).
machine_list
:
assert
machine
.
gpu_indices
is
not
None
,
\
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert
isinstance
(
machine
.
gpu_indices
,
list
),
'gpu_indices must be a list'
for
gpu_idx
in
machine
.
gpu_indices
:
devices
.
append
(
GPUDevice
(
machine
.
host
,
gpu_idx
))
return
devices
def
join
(
self
):
def
join
(
self
):
self
.
_stopped
=
True
self
.
_stopped
=
True
self
.
_consumer_thread
.
join
()
self
.
_consumer_thread
.
join
()
...
...
nni/retiarii/experiment/__init__.py
View file @
2fc47247
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
\ No newline at end of file
nni/retiarii/experiment/config/__init__.py
0 → 100644
View file @
2fc47247
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.experiment_config
import
*
from
.engine_config
import
*
\ No newline at end of file
nni/retiarii/experiment/config/engine_config.py
0 → 100644
View file @
2fc47247
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
from
typing
import
Optional
,
List
from
nni.experiment.config.base
import
ConfigBase
__all__
=
[
'ExecutionEngineConfig'
,
'BaseEngineConfig'
,
'OneshotEngineConfig'
,
'PyEngineConfig'
,
'CgoEngineConfig'
,
'BenchmarkEngineConfig'
]
@
dataclass
(
init
=
False
)
class
ExecutionEngineConfig
(
ConfigBase
):
name
:
str
@
dataclass
(
init
=
False
)
class
PyEngineConfig
(
ExecutionEngineConfig
):
name
:
str
=
'py'
@
dataclass
(
init
=
False
)
class
OneshotEngineConfig
(
ExecutionEngineConfig
):
name
:
str
=
'oneshot'
@
dataclass
(
init
=
False
)
class
BaseEngineConfig
(
ExecutionEngineConfig
):
name
:
str
=
'base'
# input used in GraphConverterWithShape. Currently support shape tuple only.
dummy_input
:
Optional
[
List
[
int
]]
=
None
@
dataclass
(
init
=
False
)
class
CgoEngineConfig
(
ExecutionEngineConfig
):
name
:
str
=
'cgo'
max_concurrency_cgo
:
Optional
[
int
]
=
None
batch_waiting_time
:
Optional
[
int
]
=
None
# input used in GraphConverterWithShape. Currently support shape tuple only.
dummy_input
:
Optional
[
List
[
int
]]
=
None
@
dataclass
(
init
=
False
)
class
BenchmarkEngineConfig
(
ExecutionEngineConfig
):
name
:
str
=
'benchmark'
benchmark
:
Optional
[
str
]
=
None
\ No newline at end of file
nni/retiarii/experiment/config/experiment_config.py
0 → 100644
View file @
2fc47247
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
os
from
dataclasses
import
dataclass
from
typing
import
Any
,
Union
from
nni.experiment.config
import
utils
,
ExperimentConfig
from
.engine_config
import
ExecutionEngineConfig
__all__
=
[
'RetiariiExeConfig'
]
def
execution_engine_config_factory
(
engine_name
):
# FIXME: may move this function to experiment utils in future
cls
=
_get_ee_config_class
(
engine_name
)
if
cls
is
None
:
raise
ValueError
(
f
'Invalid execution engine name:
{
engine_name
}
'
)
return
cls
()
def
_get_ee_config_class
(
engine_name
):
for
cls
in
ExecutionEngineConfig
.
__subclasses__
():
if
cls
.
name
==
engine_name
:
return
cls
return
None
@
dataclass
(
init
=
False
)
class
RetiariiExeConfig
(
ExperimentConfig
):
# FIXME: refactor this class to inherit from a new common base class with HPO config
search_space
:
Any
=
''
trial_code_directory
:
utils
.
PathLike
=
'.'
trial_command
:
str
=
'_reserved'
# new config field for NAS
execution_engine
:
Union
[
str
,
ExecutionEngineConfig
]
def
__init__
(
self
,
training_service_platform
:
Union
[
str
,
None
]
=
None
,
execution_engine
:
Union
[
str
,
ExecutionEngineConfig
]
=
'py'
,
**
kwargs
):
super
().
__init__
(
training_service_platform
,
**
kwargs
)
self
.
execution_engine
=
execution_engine
def
_canonicalize
(
self
,
_parents
):
msg
=
'{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
if
self
.
search_space
!=
''
:
raise
ValueError
(
msg
.
format
(
'search_space'
,
self
.
search_space
))
# TODO: maybe we should also allow users to specify trial_code_directory
if
str
(
self
.
trial_code_directory
)
!=
'.'
and
not
os
.
path
.
isabs
(
self
.
trial_code_directory
):
raise
ValueError
(
msg
.
format
(
'trial_code_directory'
,
self
.
trial_code_directory
))
if
self
.
trial_command
!=
'_reserved'
and
\
not
self
.
trial_command
.
startswith
(
'python3 -m nni.retiarii.trial_entry '
):
raise
ValueError
(
msg
.
format
(
'trial_command'
,
self
.
trial_command
))
if
isinstance
(
self
.
execution_engine
,
str
):
self
.
execution_engine
=
execution_engine_config_factory
(
self
.
execution_engine
)
if
self
.
execution_engine
.
name
in
(
'py'
,
'base'
,
'cgo'
):
# TODO: replace python3 with more elegant approach
# maybe use sys.executable rendered in trial side (e.g., trial_runner)
self
.
trial_command
=
'python3 -m nni.retiarii.trial_entry '
+
self
.
execution_engine
.
name
super
().
_canonicalize
([
self
])
nni/retiarii/experiment/pytorch.py
View file @
2fc47247
This diff is collapsed.
Click to expand it.
nni/retiarii/integration_api.py
View file @
2fc47247
...
@@ -22,7 +22,10 @@ def get_advisor() -> 'RetiariiAdvisor':
...
@@ -22,7 +22,10 @@ def get_advisor() -> 'RetiariiAdvisor':
def
register_advisor
(
advisor
:
'RetiariiAdvisor'
):
def
register_advisor
(
advisor
:
'RetiariiAdvisor'
):
global
_advisor
global
_advisor
assert
_advisor
is
None
if
_advisor
is
not
None
:
warnings
.
warn
(
'Advisor is already set.'
'You should avoid instantiating RetiariiExperiment twice in one proces.'
'If you are running in a Jupyter notebook, please restart the kernel.'
)
_advisor
=
advisor
_advisor
=
advisor
...
...
nni/runtime/msg_dispatcher_base.py
View file @
2fc47247
...
@@ -18,8 +18,15 @@ _worker_fast_exit_on_terminate = True
...
@@ -18,8 +18,15 @@ _worker_fast_exit_on_terminate = True
class
MsgDispatcherBase
(
Recoverable
):
class
MsgDispatcherBase
(
Recoverable
):
"""This is where tuners and assessors are not defined yet.
"""
This is where tuners and assessors are not defined yet.
Inherits this class to make your own advisor.
Inherits this class to make your own advisor.
.. note::
The class inheriting MsgDispatcherBase should be instantiated
after nnimanager (rest server) is started, so that the object
is ready to use right after its instantiation.
"""
"""
def
__init__
(
self
,
command_channel_url
=
None
):
def
__init__
(
self
,
command_channel_url
=
None
):
...
@@ -27,6 +34,16 @@ class MsgDispatcherBase(Recoverable):
...
@@ -27,6 +34,16 @@ class MsgDispatcherBase(Recoverable):
if
command_channel_url
is
None
:
if
command_channel_url
is
None
:
command_channel_url
=
dispatcher_env_vars
.
NNI_TUNER_COMMAND_CHANNEL
command_channel_url
=
dispatcher_env_vars
.
NNI_TUNER_COMMAND_CHANNEL
self
.
_channel
=
TunerCommandChannel
(
command_channel_url
)
self
.
_channel
=
TunerCommandChannel
(
command_channel_url
)
# NOTE: `connect()` should be put in __init__. First, this `connect()` affects nnimanager's
# starting process, without `connect()` nnimanager is blocked in `dispatcher.init()`.
# Second, nas experiment uses a thread to execute `run()` of this class, thus, there is
# no way to know when the websocket between nnimanager and dispatcher is built. The following
# logic may crash is websocket is not built. One example is updating search space. If updating
# search space too soon, as the websocket has not been built, the rest api of updating search
# space will timeout.
# FIXME: this is making unittest happy
if
not
command_channel_url
.
startswith
(
'ws://_unittest_'
):
self
.
_channel
.
connect
()
self
.
default_command_queue
=
Queue
()
self
.
default_command_queue
=
Queue
()
self
.
assessor_command_queue
=
Queue
()
self
.
assessor_command_queue
=
Queue
()
self
.
default_worker
=
threading
.
Thread
(
target
=
self
.
command_queue_worker
,
args
=
(
self
.
default_command_queue
,))
self
.
default_worker
=
threading
.
Thread
(
target
=
self
.
command_queue_worker
,
args
=
(
self
.
default_command_queue
,))
...
@@ -39,7 +56,6 @@ class MsgDispatcherBase(Recoverable):
...
@@ -39,7 +56,6 @@ class MsgDispatcherBase(Recoverable):
"""
"""
_logger
.
info
(
'Dispatcher started'
)
_logger
.
info
(
'Dispatcher started'
)
self
.
_channel
.
connect
()
self
.
default_worker
.
start
()
self
.
default_worker
.
start
()
self
.
assessor_worker
.
start
()
self
.
assessor_worker
.
start
()
...
...
test/retiarii_test/cgo_mnasnet/base_mnasnet.py
View file @
2fc47247
...
@@ -4,7 +4,6 @@ import warnings
...
@@ -4,7 +4,6 @@ import warnings
import
torch
import
torch
import
torch.nn
as
torch_nn
import
torch.nn
as
torch_nn
from
torchvision.models.utils
import
load_state_dict_from_url
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
sys
import
sys
...
...
test/retiarii_test/cgo_mnasnet/test.py
View file @
2fc47247
...
@@ -8,7 +8,7 @@ import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
...
@@ -8,7 +8,7 @@ import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
from
nni.retiarii
import
serialize
from
nni.retiarii
import
serialize
from
base_mnasnet
import
MNASNet
from
base_mnasnet
import
MNASNet
from
nni.experiment
import
RemoteMachineConfig
from
nni.experiment
import
RemoteMachineConfig
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
,
CgoEngineConfig
from
nni.retiarii.strategy
import
TPEStrategy
from
nni.retiarii.strategy
import
TPEStrategy
from
torchvision
import
transforms
from
torchvision
import
transforms
from
torchvision.datasets
import
CIFAR10
from
torchvision.datasets
import
CIFAR10
...
@@ -59,8 +59,6 @@ if __name__ == '__main__':
...
@@ -59,8 +59,6 @@ if __name__ == '__main__':
exp_config
.
max_trial_number
=
10
exp_config
.
max_trial_number
=
10
exp_config
.
trial_gpu_number
=
1
exp_config
.
trial_gpu_number
=
1
exp_config
.
training_service
.
reuse_mode
=
True
exp_config
.
training_service
.
reuse_mode
=
True
exp_config
.
max_concurrency_cgo
=
3
exp_config
.
batch_waiting_time
=
0
rm_conf
=
RemoteMachineConfig
()
rm_conf
=
RemoteMachineConfig
()
rm_conf
.
host
=
'127.0.0.1'
rm_conf
.
host
=
'127.0.0.1'
...
@@ -73,6 +71,6 @@ if __name__ == '__main__':
...
@@ -73,6 +71,6 @@ if __name__ == '__main__':
rm_conf
.
max_trial_number_per_gpu
=
3
rm_conf
.
max_trial_number_per_gpu
=
3
exp_config
.
training_service
.
machine_list
=
[
rm_conf
]
exp_config
.
training_service
.
machine_list
=
[
rm_conf
]
exp_config
.
execution_engine
=
'cgo'
exp_config
.
execution_engine
=
CgoEngineConfig
(
max_concurrency_cgo
=
3
,
batch_waiting_time
=
0
)
exp
.
run
(
exp_config
,
8099
)
exp
.
run
(
exp_config
,
8099
)
\ No newline at end of file
test/ut/retiarii/test_cgo_engine.py
View file @
2fc47247
...
@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
...
@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
from
pathlib
import
Path
from
pathlib
import
Path
import
nni
import
nni
from
nni.experiment.config
import
RemoteConfig
,
RemoteMachineConfig
import
nni.runtime.platform.test
import
nni.runtime.platform.test
from
nni.runtime.tuner_command_channel
import
legacy
as
protocol
from
nni.runtime.tuner_command_channel
import
legacy
as
protocol
import
json
import
json
...
@@ -263,13 +264,14 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -263,13 +264,14 @@ class CGOEngineTest(unittest.TestCase):
opt
=
DedupInputOptimizer
()
opt
=
DedupInputOptimizer
()
opt
.
convert
(
lp
)
opt
.
convert
(
lp
)
advisor
=
RetiariiAdvisor
(
'ws://_placeholder_'
)
advisor
=
RetiariiAdvisor
(
'ws://_
unittest_
placeholder_'
)
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
advisor
.
assessor_worker
.
start
()
available_devices
=
[
GPUDevice
(
"test"
,
0
),
GPUDevice
(
"test"
,
1
),
GPUDevice
(
"test"
,
2
),
GPUDevice
(
"test"
,
3
)]
remote
=
RemoteConfig
(
machine_list
=
[])
cgo
=
CGOExecutionEngine
(
devices
=
available_devices
,
batch_waiting_time
=
0
)
remote
.
machine_list
.
append
(
RemoteMachineConfig
(
host
=
'test'
,
gpu_indices
=
[
0
,
1
,
2
,
3
]))
cgo
=
CGOExecutionEngine
(
training_service
=
remote
,
batch_waiting_time
=
0
)
phy_models
=
cgo
.
_assemble
(
lp
)
phy_models
=
cgo
.
_assemble
(
lp
)
self
.
assertTrue
(
len
(
phy_models
)
==
1
)
self
.
assertTrue
(
len
(
phy_models
)
==
1
)
...
@@ -286,13 +288,14 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -286,13 +288,14 @@ class CGOEngineTest(unittest.TestCase):
opt
=
DedupInputOptimizer
()
opt
=
DedupInputOptimizer
()
opt
.
convert
(
lp
)
opt
.
convert
(
lp
)
advisor
=
RetiariiAdvisor
(
'ws://_placeholder_'
)
advisor
=
RetiariiAdvisor
(
'ws://_
unittest_
placeholder_'
)
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
advisor
.
assessor_worker
.
start
()
available_devices
=
[
GPUDevice
(
"test"
,
0
),
GPUDevice
(
"test"
,
1
)]
remote
=
RemoteConfig
(
machine_list
=
[])
cgo
=
CGOExecutionEngine
(
devices
=
available_devices
,
batch_waiting_time
=
0
)
remote
.
machine_list
.
append
(
RemoteMachineConfig
(
host
=
'test'
,
gpu_indices
=
[
0
,
1
]))
cgo
=
CGOExecutionEngine
(
training_service
=
remote
,
batch_waiting_time
=
0
)
phy_models
=
cgo
.
_assemble
(
lp
)
phy_models
=
cgo
.
_assemble
(
lp
)
self
.
assertTrue
(
len
(
phy_models
)
==
2
)
self
.
assertTrue
(
len
(
phy_models
)
==
2
)
...
@@ -311,13 +314,14 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -311,13 +314,14 @@ class CGOEngineTest(unittest.TestCase):
models
=
_load_mnist
(
2
)
models
=
_load_mnist
(
2
)
advisor
=
RetiariiAdvisor
(
'ws://_placeholder_'
)
advisor
=
RetiariiAdvisor
(
'ws://_
unittest_
placeholder_'
)
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
advisor
.
assessor_worker
.
start
()
cgo_engine
=
CGOExecutionEngine
(
devices
=
[
GPUDevice
(
"test"
,
0
),
GPUDevice
(
"test"
,
1
),
remote
=
RemoteConfig
(
machine_list
=
[])
GPUDevice
(
"test"
,
2
),
GPUDevice
(
"test"
,
3
)],
batch_waiting_time
=
0
)
remote
.
machine_list
.
append
(
RemoteMachineConfig
(
host
=
'test'
,
gpu_indices
=
[
0
,
1
,
2
,
3
]))
cgo_engine
=
CGOExecutionEngine
(
training_service
=
remote
,
batch_waiting_time
=
0
)
set_execution_engine
(
cgo_engine
)
set_execution_engine
(
cgo_engine
)
submit_models
(
*
models
)
submit_models
(
*
models
)
time
.
sleep
(
3
)
time
.
sleep
(
3
)
...
...
test/ut/retiarii/test_engine.py
View file @
2fc47247
...
@@ -25,7 +25,7 @@ class EngineTest(unittest.TestCase):
...
@@ -25,7 +25,7 @@ class EngineTest(unittest.TestCase):
def
test_base_execution_engine
(
self
):
def
test_base_execution_engine
(
self
):
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
advisor
=
RetiariiAdvisor
(
'ws://_placeholder_'
)
advisor
=
RetiariiAdvisor
(
'ws://_
unittest_
placeholder_'
)
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
advisor
.
assessor_worker
.
start
()
...
@@ -42,7 +42,7 @@ class EngineTest(unittest.TestCase):
...
@@ -42,7 +42,7 @@ class EngineTest(unittest.TestCase):
def
test_py_execution_engine
(
self
):
def
test_py_execution_engine
(
self
):
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
advisor
=
RetiariiAdvisor
(
'ws://_placeholder_'
)
advisor
=
RetiariiAdvisor
(
'ws://_
unittest_
placeholder_'
)
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
advisor
.
assessor_worker
.
start
()
...
...
test/ut/sdk/test_assessor.py
View file @
2fc47247
...
@@ -57,7 +57,7 @@ class AssessorTestCase(TestCase):
...
@@ -57,7 +57,7 @@ class AssessorTestCase(TestCase):
_restore_io
()
_restore_io
()
assessor
=
NaiveAssessor
()
assessor
=
NaiveAssessor
()
dispatcher
=
MsgDispatcher
(
'ws://_placeholder_'
,
None
,
assessor
)
dispatcher
=
MsgDispatcher
(
'ws://_
unittest_
placeholder_'
,
None
,
assessor
)
dispatcher
.
_channel
=
LegacyCommandChannel
()
dispatcher
.
_channel
=
LegacyCommandChannel
()
msg_dispatcher_base
.
_worker_fast_exit_on_terminate
=
False
msg_dispatcher_base
.
_worker_fast_exit_on_terminate
=
False
...
...
test/ut/sdk/test_msg_dispatcher.py
View file @
2fc47247
...
@@ -66,7 +66,7 @@ class MsgDispatcherTestCase(TestCase):
...
@@ -66,7 +66,7 @@ class MsgDispatcherTestCase(TestCase):
_restore_io
()
_restore_io
()
tuner
=
NaiveTuner
()
tuner
=
NaiveTuner
()
dispatcher
=
MsgDispatcher
(
'ws://_placeholder_'
,
tuner
)
dispatcher
=
MsgDispatcher
(
'ws://_
unittest_
placeholder_'
,
tuner
)
dispatcher
.
_channel
=
LegacyCommandChannel
()
dispatcher
.
_channel
=
LegacyCommandChannel
()
msg_dispatcher_base
.
_worker_fast_exit_on_terminate
=
False
msg_dispatcher_base
.
_worker_fast_exit_on_terminate
=
False
...
...
ts/nni_manager/core/nnimanager.ts
View file @
2fc47247
...
@@ -303,8 +303,11 @@ class NNIManager implements Manager {
...
@@ -303,8 +303,11 @@ class NNIManager implements Manager {
}
}
this
.
trainingService
.
removeTrialJobMetricListener
(
this
.
trialJobMetricListener
);
this
.
trainingService
.
removeTrialJobMetricListener
(
this
.
trialJobMetricListener
);
// NOTE: this sending TERMINATE should be out of the if clause,
// because when python dispatcher is started before nnimanager
// this.dispatcherPid would not have a valid value (i.e., not >0).
this
.
dispatcher
.
sendCommand
(
TERMINATE
);
if
(
this
.
dispatcherPid
>
0
)
{
if
(
this
.
dispatcherPid
>
0
)
{
this
.
dispatcher
.
sendCommand
(
TERMINATE
);
// gracefully terminate tuner and assessor here, wait at most 30 seconds.
// gracefully terminate tuner and assessor here, wait at most 30 seconds.
for
(
let
i
:
number
=
0
;
i
<
30
;
i
++
)
{
for
(
let
i
:
number
=
0
;
i
<
30
;
i
++
)
{
if
(
!
await
isAlive
(
this
.
dispatcherPid
))
{
if
(
!
await
isAlive
(
this
.
dispatcherPid
))
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment