OpenDAS / nni / Commits / 2fc47247

Commit 2fc47247 (unverified), authored May 24, 2022 by QuanluZhang, committed by GitHub on May 24, 2022
[retiarii] refactor of nas experiment (#4841)
Parent: c80bda29

Showing 20 changed files with 346 additions and 324 deletions (+346 −324)
nni/experiment/config/base.py                          +5    −0
nni/experiment/config/experiment_config.py             +4    −3
nni/experiment/experiment.py                           +40   −25
nni/retiarii/execution/api.py                          +7    −6
nni/retiarii/execution/base.py                         +17   −3
nni/retiarii/execution/cgo_engine.py                   +31   −7
nni/retiarii/experiment/__init__.py                    +2    −0
nni/retiarii/experiment/config/__init__.py             +5    −0
nni/retiarii/experiment/config/engine_config.py        +41   −0
nni/retiarii/experiment/config/experiment_config.py    +60   −0
nni/retiarii/experiment/pytorch.py                     +88   −257
nni/retiarii/integration_api.py                        +4    −1
nni/runtime/msg_dispatcher_base.py                     +18   −2
test/retiarii_test/cgo_mnasnet/base_mnasnet.py         +0    −1
test/retiarii_test/cgo_mnasnet/test.py                 +3    −5
test/ut/retiarii/test_cgo_engine.py                    +13   −9
test/ut/retiarii/test_engine.py                        +2    −2
test/ut/sdk/test_assessor.py                           +1    −1
test/ut/sdk/test_msg_dispatcher.py                     +1    −1
ts/nni_manager/core/nnimanager.ts                      +4    −1
nni/experiment/config/base.py  (view file @ 2fc47247)

@@ -54,6 +54,11 @@ class ConfigBase:
     Config objects will remember where they are loaded; therefore relative paths can be resolved smartly.
     If a config object is created with constructor, the base path will be current working directory.
     If it is loaded with ``ConfigBase.load(path)``, the base path will be ``path``'s parent.
+
+    .. attention::
+
+        All the classes that inherit ``ConfigBase`` are not allowed to use ``from __future__ import annotations``,
+        because ``ConfigBase`` uses ``typeguard`` to perform runtime check and it does not support lazy annotations.
     """

     def __init__(self, **kwargs):
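The attention note added above concerns runtime type checking. A minimal sketch of the failure mode it guards against, using a hypothetical stand-in class rather than a real NNI config:

    # With `from __future__ import annotations`, annotations are stored as strings,
    # so a runtime checker that reads __annotations__ sees 'int' instead of the type int
    # and cannot validate assigned values.
    import dataclasses

    @dataclasses.dataclass
    class DemoConfig:                  # hypothetical stand-in for a ConfigBase subclass
        port: int = 8080

    print(DemoConfig.__annotations__)  # {'port': <class 'int'>} as written here;
                                       # it would be {'port': 'int'} if this module
                                       # enabled lazy annotations via the future import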
nni/experiment/config/experiment_config.py  (view file @ 2fc47247)

@@ -164,10 +164,11 @@ class ExperimentConfig(ConfigBase):
         # currently I have only seen one issue of this kind
         #Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True)

-        utils.validate_gpu_indices(self.tuner_gpu_indices)
-
-        if self.tuner is None:
-            raise ValueError('ExperimentConfig: tuner must be set')
+        if type(self).__name__ != 'RetiariiExeConfig':
+            utils.validate_gpu_indices(self.tuner_gpu_indices)
+
+            if self.tuner is None:
+                raise ValueError('ExperimentConfig: tuner must be set')

 def _load_search_space_file(search_space_path):
     # FIXME
nni/experiment/experiment.py  (view file @ 2fc47247)

@@ -84,20 +84,9 @@ class Experiment:
         else:
             self.config = config_or_platform

-    def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
-        """
-        Start the experiment in background.
-        This method will raise exception on failure.
-        If it returns, the experiment should have been successfully started.
-
-        Parameters
-        ----------
-        port
-            The port of web UI.
-        debug
-            Whether to start in debug mode.
-        """
+    def _start_impl(self, port: int, debug: bool, run_mode: RunMode,
+                    tuner_command_channel: str | None, tags: list[str] = []) -> ExperimentConfig:
         assert self.config is not None

         if run_mode is not RunMode.Detach:
             atexit.register(self.stop)

@@ -111,7 +100,8 @@ class Experiment:
         log_level = 'debug' if (debug or config.log_level == 'trace') else config.log_level
         start_experiment_logging(self.id, log_file, cast(str, log_level))

-        self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix)
+        self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode,
+                                               self.url_prefix, tuner_command_channel, tags)
         assert self._proc is not None

         self.port = port  # port will be None if start up failed

@@ -124,12 +114,27 @@ class Experiment:
         ips = [f'http://{ip}:{port}' for ip in ips if ip]
         msg = 'Web portal URLs: ${CYAN}' + ' '.join(ips)
         _logger.info(msg)

+        return config
+
-    def stop(self) -> None:
+    def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
         """
-        Stop the experiment.
+        Start the experiment in background.
+        This method will raise exception on failure.
+        If it returns, the experiment should have been successfully started.
+
+        Parameters
+        ----------
+        port
+            The port of web UI.
+        debug
+            Whether to start in debug mode.
+        run_mode
+            Running the experiment in foreground or background
         """
-        _logger.info('Stopping experiment, please wait...')
+        self._start_impl(port, debug, run_mode, None, [])
+
+    def _stop_impl(self) -> None:
         atexit.unregister(self.stop)
         stop_experiment_logging(self.id)

@@ -144,8 +149,24 @@ class Experiment:
         self.id = None  # type: ignore
         self.port = None
         self._proc = None

+    def stop(self) -> None:
+        """
+        Stop the experiment.
+        """
+        _logger.info('Stopping experiment, please wait...')
+        self._stop_impl()
+        _logger.info('Experiment stopped')
+
+    def _wait_completion(self) -> bool:
+        while True:
+            status = self.get_status()
+            if status == 'DONE' or status == 'STOPPED':
+                return True
+            if status == 'ERROR':
+                return False
+            time.sleep(10)
+
     def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
         """
         Run the experiment.

@@ -159,13 +180,7 @@ class Experiment:
         self.start(port, debug)
         if wait_completion:
             try:
-                while True:
-                    time.sleep(10)
-                    status = self.get_status()
-                    if status == 'DONE' or status == 'STOPPED':
-                        return True
-                    if status == 'ERROR':
-                        return False
+                self._wait_completion()
             except KeyboardInterrupt:
                 _logger.warning('KeyboardInterrupt detected')
                 self.stop()
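For orientation, the refactor above keeps start/run/stop as the public surface of Experiment and moves the shared logic into _start_impl, _stop_impl and _wait_completion. A minimal sketch of the unchanged public usage (the 'local' platform string and the omitted config fields are assumptions for illustration):

    from nni.experiment import Experiment

    exp = Experiment('local')      # a real experiment also needs search space, trial command, etc.
    exp.start(port=8080)           # now delegates to _start_impl(...)
    # ... inspect the web portal, wait for trials ...
    exp.stop()                     # now delegates to _stop_impl()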
nni/retiarii/execution/api.py  (view file @ 2fc47247)

@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 import time
+import warnings
 from typing import Iterable

 from ..graph import Model, ModelStatus

@@ -18,12 +19,12 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener',
 def set_execution_engine(engine: AbstractExecutionEngine) -> None:
     global _execution_engine
-    if _execution_engine is None:
-        _execution_engine = engine
-    else:
-        raise RuntimeError('Execution engine is already set. '
-                           'You should avoid instantiating RetiariiExperiment twice in one process. '
-                           'If you are running in a Jupyter notebook, please restart the kernel.')
+    if _execution_engine is not None:
+        warnings.warn('Execution engine is already set. '
+                      'You should avoid instantiating RetiariiExperiment twice in one process. '
+                      'If you are running in a Jupyter notebook, please restart the kernel.',
+                      RuntimeWarning)
+    _execution_engine = engine

 def get_execution_engine() -> AbstractExecutionEngine:
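A small sketch (not from the commit) of the behavior change in set_execution_engine: a second call now warns and replaces the engine instead of raising. `DummyEngine` is hypothetical; a real engine must implement AbstractExecutionEngine and may require a registered advisor:

    import warnings
    import nni.retiarii.execution.api as api

    class DummyEngine:                            # hypothetical, for illustration only
        pass

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        api.set_execution_engine(DummyEngine())   # first call: engine recorded silently
        api.set_execution_engine(DummyEngine())   # second call: RuntimeWarning, engine replaced
    print([w.category.__name__ for w in caught])  # expected: ['RuntimeWarning']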
nni/retiarii/execution/base.py  (view file @ 2fc47247)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 import logging
 import os
 import random
 import string
 from typing import Any, Dict, Iterable, List

+from nni.experiment import rest
+
 from .interface import AbstractExecutionEngine, AbstractGraphListener
 from .utils import get_mutation_summary
 from .. import codegen, utils

@@ -54,12 +58,22 @@ class BaseExecutionEngine(AbstractExecutionEngine):
     Resource management is implemented in this class.
     """

-    def __init__(self) -> None:
+    def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None:
         """
         Upon initialization, advisor callbacks need to be registered.
         Advisor will call the callbacks when the corresponding event has been triggered.
         Base execution engine will get those callbacks and broadcast them to graph listener.
+
+        Parameters
+        ----------
+        rest_port
+            The port of the experiment's rest server
+        rest_url_prefix
+            The url prefix of the experiment's rest entry
         """
+        self.port = rest_port
+        self.url_prefix = rest_url_prefix
+
         self._listeners: List[AbstractGraphListener] = []

         # register advisor callbacks

@@ -123,8 +137,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         return self.resources

     def budget_exhausted(self) -> bool:
-        advisor = get_advisor()
-        return advisor.stopping
+        resp = rest.get(self.port, '/check-status', self.url_prefix)
+        return resp['status'] == 'DONE'

     @classmethod
     def pack_model_data(cls, model: Model) -> Any:
nni/retiarii/execution/cgo_engine.py  (view file @ 2fc47247)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 import logging
 import os
 import random
 import string
 import time
 import threading
-from typing import Iterable, List, Dict, Tuple
+from typing import Iterable, List, Dict, Tuple, cast
 from dataclasses import dataclass

 from nni.common.device import GPUDevice, Device
+from nni.experiment.config.training_services import RemoteConfig

 from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
 from .. import codegen, utils
 from ..graph import Model, ModelStatus, MetricData, Node

@@ -31,7 +34,6 @@ class TrialSubmission:
     placement: Dict[Node, Device]
     grouped_models: List[Model]

 class CGOExecutionEngine(AbstractExecutionEngine):
     """
     The execution engine with Cross-Graph Optimization (CGO).

@@ -41,24 +43,35 @@ class CGOExecutionEngine(AbstractExecutionEngine):
     Parameters
     ----------
-    devices : List[Device]
-        Available devices for execution.
-    max_concurrency : int
+    training_service
+        The remote training service config.
+    max_concurrency
         The maximum number of trials to run concurrently.
-    batch_waiting_time : int
+    batch_waiting_time
         Seconds to wait for each batch of trial submission.
         The trials within one batch could apply cross-graph optimization.
+    rest_port
+        The port of the experiment's rest server
+    rest_url_prefix
+        The url prefix of the experiment's rest entry
     """

-    def __init__(self, devices: List[Device] = None,
+    def __init__(self, training_service: RemoteConfig,
                  max_concurrency: int = None,
                  batch_waiting_time: int = 60,
+                 rest_port: int | None = None,
+                 rest_url_prefix: str | None = None,
                  ) -> None:
+        self.port = rest_port
+        self.url_prefix = rest_url_prefix
+
         self._listeners: List[AbstractGraphListener] = []
         self._running_models: Dict[int, Model] = dict()
         self.logical_plan_counter = 0
         self.available_devices: List[Device] = []
         self.max_concurrency: int = max_concurrency

+        devices = self._construct_devices(training_service)
         for device in devices:
             self.available_devices.append(device)
         self.all_devices = self.available_devices.copy()

@@ -88,6 +101,17 @@ class CGOExecutionEngine(AbstractExecutionEngine):
         self._consumer_thread = threading.Thread(target=self._consume_models)
         self._consumer_thread.start()

+    def _construct_devices(self, training_service):
+        devices = []
+        if hasattr(training_service, 'machine_list'):
+            for machine in cast(RemoteConfig, training_service).machine_list:
+                assert machine.gpu_indices is not None, \
+                    'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
+                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
+                for gpu_idx in machine.gpu_indices:
+                    devices.append(GPUDevice(machine.host, gpu_idx))
+        return devices
+
     def join(self):
         self._stopped = True
         self._consumer_thread.join()
nni/retiarii/experiment/__init__.py  (view file @ 2fc47247)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
\ No newline at end of file
nni/retiarii/experiment/config/__init__.py  (new file, 0 → 100644, view file @ 2fc47247)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .experiment_config import *
from .engine_config import *
\ No newline at end of file
nni/retiarii/experiment/config/engine_config.py  (new file, 0 → 100644, view file @ 2fc47247)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from typing import Optional, List

from nni.experiment.config.base import ConfigBase

__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig',
           'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig']

@dataclass(init=False)
class ExecutionEngineConfig(ConfigBase):
    name: str

@dataclass(init=False)
class PyEngineConfig(ExecutionEngineConfig):
    name: str = 'py'

@dataclass(init=False)
class OneshotEngineConfig(ExecutionEngineConfig):
    name: str = 'oneshot'

@dataclass(init=False)
class BaseEngineConfig(ExecutionEngineConfig):
    name: str = 'base'
    # input used in GraphConverterWithShape. Currently support shape tuple only.
    dummy_input: Optional[List[int]] = None

@dataclass(init=False)
class CgoEngineConfig(ExecutionEngineConfig):
    name: str = 'cgo'
    max_concurrency_cgo: Optional[int] = None
    batch_waiting_time: Optional[int] = None
    # input used in GraphConverterWithShape. Currently support shape tuple only.
    dummy_input: Optional[List[int]] = None

@dataclass(init=False)
class BenchmarkEngineConfig(ExecutionEngineConfig):
    name: str = 'benchmark'
    benchmark: Optional[str] = None
\ No newline at end of file
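A short usage sketch for the new engine config dataclasses; the CGO values mirror the updated cgo_mnasnet test below, while the dummy_input shape is an illustrative assumption:

    from nni.retiarii.experiment.config import CgoEngineConfig, BaseEngineConfig

    cgo = CgoEngineConfig(max_concurrency_cgo=3, batch_waiting_time=0)
    base = BaseEngineConfig(dummy_input=[1, 3, 224, 224])   # shape list for GraphConverterWithShape
    print(cgo.name, base.name)                              # 'cgo', 'base'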
nni/retiarii/experiment/config/experiment_config.py  (new file, 0 → 100644, view file @ 2fc47247)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
from dataclasses import dataclass
from typing import Any, Union

from nni.experiment.config import utils, ExperimentConfig

from .engine_config import ExecutionEngineConfig

__all__ = ['RetiariiExeConfig']

def execution_engine_config_factory(engine_name):
    # FIXME: may move this function to experiment utils in future
    cls = _get_ee_config_class(engine_name)
    if cls is None:
        raise ValueError(f'Invalid execution engine name: {engine_name}')
    return cls()

def _get_ee_config_class(engine_name):
    for cls in ExecutionEngineConfig.__subclasses__():
        if cls.name == engine_name:
            return cls
    return None

@dataclass(init=False)
class RetiariiExeConfig(ExperimentConfig):
    # FIXME: refactor this class to inherit from a new common base class with HPO config
    search_space: Any = ''
    trial_code_directory: utils.PathLike = '.'
    trial_command: str = '_reserved'
    # new config field for NAS
    execution_engine: Union[str, ExecutionEngineConfig]

    def __init__(self, training_service_platform: Union[str, None] = None,
                 execution_engine: Union[str, ExecutionEngineConfig] = 'py',
                 **kwargs):
        super().__init__(training_service_platform, **kwargs)
        self.execution_engine = execution_engine

    def _canonicalize(self, _parents):
        msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
        if self.search_space != '':
            raise ValueError(msg.format('search_space', self.search_space))
        # TODO: maybe we should also allow users to specify trial_code_directory
        if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
            raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
        if self.trial_command != '_reserved' and \
            not self.trial_command.startswith('python3 -m nni.retiarii.trial_entry '):
            raise ValueError(msg.format('trial_command', self.trial_command))
        if isinstance(self.execution_engine, str):
            self.execution_engine = execution_engine_config_factory(self.execution_engine)
        if self.execution_engine.name in ('py', 'base', 'cgo'):
            # TODO: replace python3 with more elegant approach
            # maybe use sys.executable rendered in trial side (e.g., trial_runner)
            self.trial_command = 'python3 -m nni.retiarii.trial_entry ' + self.execution_engine.name
        super()._canonicalize([self])
nni/retiarii/experiment/pytorch.py  (view file @ 2fc47247)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-import atexit
+from __future__ import annotations
+
 import logging
 import os
 import socket
 import time
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
 from subprocess import Popen
 from threading import Thread
-from typing import Any, List, Optional, Union, cast
+from typing import Any, List, Union, cast

 import colorama
 import psutil
 import torch
 import torch.nn as nn

 import nni.runtime.log
 from nni.common.device import GPUDevice
-from nni.experiment import Experiment, RunMode, launcher, management, rest
 from nni.experiment.config import utils
 from nni.experiment.config.base import ConfigBase
 from nni.experiment.config.training_service import TrainingServiceConfig
+from nni.experiment import Experiment, RunMode
 from nni.experiment.config.training_services import RemoteConfig
 from nni.runtime.tuner_command_channel import TunerCommandChannel
 from nni.tools.nnictl.command_utils import kill_command
+from .config import (
+    RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
+    PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
+)
 from ..codegen import model_to_pytorch_script
 from ..converter import convert_to_graph
 from ..converter.graph_gen import GraphConverterWithShape

@@ -46,79 +39,7 @@ from ..strategy.utils import dry_run_for_formatted_search_space
 _logger = logging.getLogger(__name__)

-__all__ = ['RetiariiExeConfig', 'RetiariiExperiment']
-
-
-@dataclass(init=False)
-class RetiariiExeConfig(ConfigBase):
-    experiment_name: Optional[str] = None
-    search_space: Any = ''  # TODO: remove
-    trial_command: str = '_reserved'
-    trial_code_directory: utils.PathLike = '.'
-    trial_concurrency: int
-    trial_gpu_number: int = 0
-    devices: Optional[List[Union[str, GPUDevice]]] = None
-    max_experiment_duration: Optional[str] = None
-    max_trial_number: Optional[int] = None
-    max_concurrency_cgo: Optional[int] = None
-    batch_waiting_time: Optional[int] = None
-    nni_manager_ip: Optional[str] = None
-    debug: bool = False
-    log_level: str = 'info'
-    experiment_working_directory: utils.PathLike = '~/nni-experiments'
-    # remove configuration of tuner/assessor/advisor
-    training_service: TrainingServiceConfig
-    execution_engine: str = 'py'
-
-    # input used in GraphConverterWithShape. Currently support shape tuple only.
-    dummy_input: Optional[List[int]] = None
-
-    # input used for benchmark engine.
-    benchmark: Optional[str] = None
-
-    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
-        super().__init__(**kwargs)
-        if training_service_platform is not None:
-            assert 'training_service' not in kwargs
-            self.training_service = utils.training_service_config_factory(platform=training_service_platform)
-        self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry py'
-
-    def __setattr__(self, key, value):
-        fixed_attrs = {'search_space': '',
-                       'trial_command': '_reserved'}
-        if key in fixed_attrs and fixed_attrs[key] != value:
-            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
-        # 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
-        if key == 'trial_code_directory' and not (str(value) == '.' or os.path.isabs(value)):
-            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
-        if key == 'execution_engine':
-            assert value in ['base', 'py', 'cgo', 'benchmark', 'oneshot'], f'The specified execution engine "{value}" is not supported.'
-            self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
-        self.__dict__[key] = value
-
-    def validate(self, initialized_tuner: bool = False) -> None:
-        super().validate()
-
-    @property
-    def _canonical_rules(self):
-        return _canonical_rules
-
-    @property
-    def _validation_rules(self):
-        return _validation_rules
-
-
-_canonical_rules = {
-}
-
-_validation_rules = {
-    'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
-    'trial_concurrency': lambda value: value > 0,
-    'trial_gpu_number': lambda value: value >= 0,
-    'max_trial_number': lambda value: value > 0,
-    'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
-    'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
-}
+__all__ = ['RetiariiExperiment']

 def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):

@@ -252,9 +173,14 @@ class RetiariiExperiment(Experiment):
         ...     final_model = Net()
     """

     def __init__(self, base_model: nn.Module,
                  evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
                  applied_mutators: List[Mutator] = cast(List[Mutator], None),
                  strategy: BaseStrategy = cast(BaseStrategy, None),
                  trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
+        super().__init__(None)
+        self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
+
         if trainer is not None:
             warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
                           'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)

@@ -263,25 +189,13 @@ class RetiariiExperiment(Experiment):
         if evaluator is None:
             raise ValueError('Evaluator should not be none.')

+        # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
-        self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
-        self.port: Optional[int] = None
-
         self.base_model = base_model
         self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
         self.applied_mutators = applied_mutators
         self.strategy = strategy

-        from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
-        if not isinstance(strategy, OneShotStrategy):
-            # FIXME: Dispatcher should not be created this early.
-            self._dispatcher = RetiariiAdvisor('_placeholder_')
-        else:
-            self._dispatcher = cast(RetiariiAdvisor, None)
-        self._dispatcher_thread: Optional[Thread] = None
-        self._proc: Optional[Popen] = None
-        self.url_prefix = None
+        self._dispatcher = None
+        self._dispatcher_thread = None

         # check for sanity
         if not is_model_wrapped(base_model):

@@ -290,11 +204,12 @@ class RetiariiExperiment(Experiment):
                           'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
                           RuntimeWarning)

-    def _start_strategy(self):
+    def _run_strategy(self, config: RetiariiExeConfig):
         base_model_ir, self.applied_mutators = preprocess_model(
             self.base_model, self.evaluator, self.applied_mutators,
-            full_ir=self.config.execution_engine not in ['py', 'benchmark'],
-            dummy_input=self.config.dummy_input
+            full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
+            dummy_input=config.execution_engine.dummy_input
+            if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
         )

         _logger.info('Start strategy...')

@@ -303,102 +218,49 @@ class RetiariiExperiment(Experiment):
         self.strategy.run(base_model_ir, self.applied_mutators)
         _logger.info('Strategy exit')
         # TODO: find out a proper way to show no more trial message on WebUI
         # self._dispatcher.mark_experiment_as_ending()

-    def start(self, port: int = 8080, debug: bool = False) -> None:
-        """
-        Start the experiment in background.
-        This method will raise exception on failure.
-        If it returns, the experiment should have been successfully started.
-
-        Parameters
-        ----------
-        port
-            The port of web UI.
-        debug
-            Whether to start in debug mode.
-        """
-        atexit.register(self.stop)
-
-        self.config = self.config.canonical_copy()
-
-        # we will probably need a execution engine factory to make this clean and elegant
-        if self.config.execution_engine == 'base':
+    def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
+        #TODO: we will probably need a execution engine factory to make this clean and elegant
+        if isinstance(config.execution_engine, BaseEngineConfig):
             from ..execution.base import BaseExecutionEngine
-            engine = BaseExecutionEngine()
-        elif self.config.execution_engine == 'cgo':
+            engine = BaseExecutionEngine(self.port, self.url_prefix)
+        elif isinstance(config.execution_engine, CgoEngineConfig):
             from ..execution.cgo_engine import CGOExecutionEngine

-            assert self.config.training_service.platform == 'remote', \
+            assert not isinstance(config.training_service, list) \
+                and config.training_service.platform == 'remote', \
                 "CGO execution engine currently only supports remote training service"
-            assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None
-            devices = self._construct_devices()
-            engine = CGOExecutionEngine(devices,
-                                        max_concurrency=self.config.max_concurrency_cgo,
-                                        batch_waiting_time=self.config.batch_waiting_time)
-        elif self.config.execution_engine == 'py':
+            assert config.execution_engine.batch_waiting_time is not None \
+                and config.execution_engine.max_concurrency_cgo is not None
+            engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service),
+                                        max_concurrency=config.execution_engine.max_concurrency_cgo,
+                                        batch_waiting_time=config.execution_engine.batch_waiting_time,
+                                        rest_port=self.port,
+                                        rest_url_prefix=self.url_prefix)
+        elif isinstance(config.execution_engine, PyEngineConfig):
             from ..execution.python import PurePythonExecutionEngine
-            engine = PurePythonExecutionEngine()
-        elif self.config.execution_engine == 'benchmark':
+            engine = PurePythonExecutionEngine(self.port, self.url_prefix)
+        elif isinstance(config.execution_engine, BenchmarkEngineConfig):
             from ..execution.benchmark import BenchmarkExecutionEngine
-            assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.'
-            engine = BenchmarkExecutionEngine(self.config.benchmark)
+            assert config.execution_engine.benchmark is not None, \
+                '"benchmark" must be set when benchmark execution engine is used.'
+            engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
         else:
-            raise ValueError(f'Unsupported engine type: {self.config.execution_engine}')
+            raise ValueError(f'Unsupported engine type: {config.execution_engine}')
         set_execution_engine(engine)

-        self.id = management.generate_experiment_id()
-
-        log_file = Path(self.config.experiment_working_directory, self.id, 'log', 'experiment.log')
-        log_file.parent.mkdir(parents=True, exist_ok=True)
-        log_level = 'debug' if (debug or self.config.log_level == 'trace') else self.config.log_level
-        nni.runtime.log.start_experiment_logging(self.id, log_file, cast(str, log_level))
-
-        ws_url = f'ws://localhost:{port}/tuner'
-        self._proc = launcher.start_experiment('create', self.id, self.config, port, debug,  # type: ignore
-                                               RunMode.Background, None, ws_url, ['retiarii'])
-        assert self._proc is not None
-
-        self.port = port  # port will be None if start up failed
-
-        # dispatcher must be launched after pipe initialized
-        # the logic to launch dispatcher in background should be refactored into dispatcher api
-        self._dispatcher = self._create_dispatcher()
-        if self._dispatcher is not None:
-            self._dispatcher._channel = TunerCommandChannel(ws_url)
-        self._dispatcher_thread = Thread(target=self._dispatcher.run)
-        self._dispatcher_thread.start()
-
-        ips = [self.config.nni_manager_ip]
-        for interfaces in psutil.net_if_addrs().values():
-            for interface in interfaces:
-                if interface.family == socket.AF_INET:
-                    ips.append(interface.address)
-        ips = [f'http://{ip}:{port}' for ip in ips if ip]
-        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
-        _logger.info(msg)
-
-        exp_status_checker = Thread(target=self._check_exp_status)
-        exp_status_checker.start()
-        self._start_strategy()
-        # TODO: the experiment should be completed, when strategy exits and there is no running job
-        _logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...')
-        exp_status_checker.join()
-
-    def _construct_devices(self):
-        devices = []
-        if hasattr(self.config.training_service, 'machine_list'):
-            for machine in cast(RemoteConfig, self.config.training_service).machine_list:
-                assert machine.gpu_indices is not None, \
-                    'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
-                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
-                for gpu_idx in machine.gpu_indices:
-                    devices.append(GPUDevice(machine.host, gpu_idx))
-        return devices
-
-    def _create_dispatcher(self):
-        return self._dispatcher
-
-    def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None:
+    def start(self, *args, **kwargs) -> None:
+        """
+        By design, the only different between `start` and `run` is that `start` is asynchronous,
+        while `run` waits the experiment to complete. RetiariiExperiment always waits the experiment
+        to complete as strategy runs in foreground.
+        """
+        raise NotImplementedError('RetiariiExperiment is not supposed to provide `start` method')
+
+    def run(self, config: RetiariiExeConfig | None = None, port: int = 8080, debug: bool = False) -> None:
         """
         Run the experiment.
         This function will block until experiment finish or error.

@@ -410,75 +272,47 @@ class RetiariiExperiment(Experiment):
             #     'In case you want to stick to the old implementation, '
            #     'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
             self.evaluator.fit()
             return

         if config is None:
             warnings.warn('config = None is deprecate in future. If you are running a one-shot experiment, '
                           'please consider creating a config and set execution engine to `oneshot`.', DeprecationWarning)
-            config = RetiariiExeConfig()
-            config.execution_engine = 'oneshot'
+            self.config = RetiariiExeConfig()
+            self.config.execution_engine = OneshotEngineConfig()
+        else:
+            self.config = config

-        if config.execution_engine == 'oneshot':
+        if isinstance(self.config.execution_engine, OneshotEngineConfig) \
+            or (isinstance(self.config.execution_engine, str) and self.config.execution_engine == 'oneshot'):
             # this is hacky, will be refactored when oneshot can run on training services
             base_model_ir, self.applied_mutators = preprocess_model(
                 self.base_model, self.evaluator, self.applied_mutators, oneshot=True)
             self.strategy.run(base_model_ir, self.applied_mutators)
         else:
-            assert config is not None, 'You are using classic search mode, config cannot be None!'
-            self.config = config
-            self.start(port, debug)
-
-    def _check_exp_status(self) -> bool:
-        """
-        Run the experiment.
-        This function will block until experiment finish or error.
-        Return `True` when experiment done; or return `False` when experiment failed.
-        """
-        assert self._proc is not None
-        try:
-            while True:
-                time.sleep(10)
-                # this if is to deal with the situation that
-                # nnimanager is cleaned up by ctrl+c first
-                if self._proc.poll() is None:
-                    status = self.get_status()
-                else:
-                    return False
-                if status == 'DONE' or status == 'STOPPED':
-                    return True
-                if status == 'ERROR':
-                    return False
-        except KeyboardInterrupt:
-            _logger.warning('KeyboardInterrupt detected')
-        finally:
-            self.stop()
-        raise RuntimeError('Check experiment status failed.')
+            ws_url = f'ws://localhost:{port}/tuner'
+            canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
+            canonicalized_config = cast(RetiariiExeConfig, canonicalized_config)
+            self._dispatcher = RetiariiAdvisor(ws_url)
+            self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True)
+            self._dispatcher_thread.start()
+            # FIXME: engine cannot be created twice
+            self._create_execution_engine(canonicalized_config)
+            try:
+                self._run_strategy(canonicalized_config)
+                # FIXME: move this logic to strategy with a new API provided by execution engine
+                self._wait_completion()
+            except KeyboardInterrupt:
+                _logger.warning('KeyboardInterrupt detected')
+                self.stop()
+            _logger.info('Search process is done, the experiment is still alive, `stop()` can terminate the experiment.')

     def stop(self) -> None:
         """
         Stop background experiment.
         """
         _logger.info('Stopping experiment, please wait...')
-        atexit.unregister(self.stop)

         # stop strategy first
         if self._dispatcher_thread is not None:
             self._dispatcher.stopping = True
             self._dispatcher_thread.join(timeout=1)

-        if self.id is not None:
-            nni.runtime.log.stop_experiment_logging(self.id)
-        if self._proc is not None:
-            try:
-                # this if is to deal with the situation that
-                # nnimanager is cleaned up by ctrl+c first
-                if self._proc.poll() is None:
-                    rest.delete(self.port, '/experiment')
-            except Exception as e:
-                _logger.exception(e)
-                _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
-                kill_command(self._proc.pid)
-
-        self.id = cast(str, None)
-        self.port = cast(int, None)
-        self._proc = None
+        self._stop_impl()
+        if self._dispatcher_thread:
+            self._dispatcher_thread.join()
+
         self._dispatcher = cast(RetiariiAdvisor, None)
         self._dispatcher_thread = None
         _logger.info('Experiment stopped')

@@ -502,8 +336,11 @@ class RetiariiExperiment(Experiment):
             If ``code``, the python code of model will be returned.
             If ``dict``, the mutation history will be returned.
         """
         # TODO: the base class may also need this method
         if formatter == 'code':
-            assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.'
+            config = self.config.canonical_copy()
+            assert not isinstance(config.execution_engine, PyEngineConfig), \
+                'You should use `dict` formatter when using Python execution engine.'
         if isinstance(self.evaluator, BaseOneShotTrainer):
             assert top_k == 1, 'Only support top_k is 1 for now.'
             return self.evaluator.export()

@@ -520,9 +357,3 @@ class RetiariiExperiment(Experiment):
             return [model_to_pytorch_script(model) for model in all_models[:top_k]]
         elif formatter == 'dict':
             return [get_mutation_dict(model) for model in all_models[:top_k]]
-
-    def retrain_model(self, model):
-        """
-        this function retrains the exported model, and test it to output test accuracy
-        """
-        raise NotImplementedError
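Putting the pytorch.py changes together, a hedged sketch of the refactored flow (model, evaluator and strategy setup are elided because this commit does not change them; the port and CGO values mirror the updated cgo_mnasnet test):

    from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig, CgoEngineConfig

    exp = RetiariiExperiment(base_model=..., evaluator=..., strategy=...)    # set up as before
    exp_config = RetiariiExeConfig('remote',
                                   execution_engine=CgoEngineConfig(max_concurrency_cgo=3, batch_waiting_time=0))
    exp.run(exp_config, 8099)                       # blocks: starts nnimanager via _start_impl,
                                                    # launches the dispatcher thread, builds the engine,
                                                    # runs the strategy, then waits for completion
    top = exp.export_top_models(formatter='dict')
    exp.stop()                                      # the experiment stays alive after run() until stop()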
nni/retiarii/integration_api.py  (view file @ 2fc47247)

@@ -22,7 +22,10 @@ def get_advisor() -> 'RetiariiAdvisor':

 def register_advisor(advisor: 'RetiariiAdvisor'):
     global _advisor
-    assert _advisor is None
+    if _advisor is not None:
+        warnings.warn('Advisor is already set.'
+                      'You should avoid instantiating RetiariiExperiment twice in one proces.'
+                      'If you are running in a Jupyter notebook, please restart the kernel.')
     _advisor = advisor
nni/runtime/msg_dispatcher_base.py  (view file @ 2fc47247)

@@ -18,8 +18,15 @@ _worker_fast_exit_on_terminate = True

 class MsgDispatcherBase(Recoverable):
-    """This is where tuners and assessors are not defined yet.
+    """
+    This is where tuners and assessors are not defined yet.
     Inherits this class to make your own advisor.
+
+    .. note::
+
+        The class inheriting MsgDispatcherBase should be instantiated
+        after nnimanager (rest server) is started, so that the object
+        is ready to use right after its instantiation.
     """

     def __init__(self, command_channel_url=None):

@@ -27,6 +34,16 @@ class MsgDispatcherBase(Recoverable):
         if command_channel_url is None:
             command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
         self._channel = TunerCommandChannel(command_channel_url)
+        # NOTE: `connect()` should be put in __init__. First, this `connect()` affects nnimanager's
+        # starting process, without `connect()` nnimanager is blocked in `dispatcher.init()`.
+        # Second, nas experiment uses a thread to execute `run()` of this class, thus, there is
+        # no way to know when the websocket between nnimanager and dispatcher is built. The following
+        # logic may crash is websocket is not built. One example is updating search space. If updating
+        # search space too soon, as the websocket has not been built, the rest api of updating search
+        # space will timeout.
+        # FIXME: this is making unittest happy
+        if not command_channel_url.startswith('ws://_unittest_'):
+            self._channel.connect()
         self.default_command_queue = Queue()
         self.assessor_command_queue = Queue()
         self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))

@@ -39,7 +56,6 @@ class MsgDispatcherBase(Recoverable):
         """
         _logger.info('Dispatcher started')
-        self._channel.connect()
         self.default_worker.start()
         self.assessor_worker.start()
test/retiarii_test/cgo_mnasnet/base_mnasnet.py  (view file @ 2fc47247)

@@ -4,7 +4,6 @@ import warnings

 import torch
 import torch.nn as torch_nn
-from torchvision.models.utils import load_state_dict_from_url
 import torch.nn.functional as F
 import sys
test/retiarii_test/cgo_mnasnet/test.py  (view file @ 2fc47247)

@@ -8,7 +8,7 @@ import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
 from nni.retiarii import serialize
 from base_mnasnet import MNASNet
 from nni.experiment import RemoteMachineConfig
-from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
+from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig, CgoEngineConfig
 from nni.retiarii.strategy import TPEStrategy
 from torchvision import transforms
 from torchvision.datasets import CIFAR10

@@ -59,8 +59,6 @@ if __name__ == '__main__':
     exp_config.max_trial_number = 10
     exp_config.trial_gpu_number = 1
     exp_config.training_service.reuse_mode = True
-    exp_config.max_concurrency_cgo = 3
-    exp_config.batch_waiting_time = 0

     rm_conf = RemoteMachineConfig()
     rm_conf.host = '127.0.0.1'

@@ -73,6 +71,6 @@ if __name__ == '__main__':
     rm_conf.max_trial_number_per_gpu = 3

     exp_config.training_service.machine_list = [rm_conf]
-    exp_config.execution_engine = 'cgo'
+    exp_config.execution_engine = CgoEngineConfig(max_concurrency_cgo=3, batch_waiting_time=0)

-    exp.run(exp_config, 8099)
\ No newline at end of file
+    exp.run(exp_config, 8099)
test/ut/retiarii/test_cgo_engine.py  (view file @ 2fc47247)

@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
 from pathlib import Path

 import nni
+from nni.experiment.config import RemoteConfig, RemoteMachineConfig
 import nni.runtime.platform.test
 from nni.runtime.tuner_command_channel import legacy as protocol
 import json

@@ -263,13 +264,14 @@ class CGOEngineTest(unittest.TestCase):
         opt = DedupInputOptimizer()
         opt.convert(lp)

-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = protocol.LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

-        available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)]
-        cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
+        remote = RemoteConfig(machine_list=[])
+        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1, 2, 3]))
+        cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)

         phy_models = cgo._assemble(lp)
         self.assertTrue(len(phy_models) == 1)

@@ -286,13 +288,14 @@ class CGOEngineTest(unittest.TestCase):
         opt = DedupInputOptimizer()
         opt.convert(lp)

-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = protocol.LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

-        available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)]
-        cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
+        remote = RemoteConfig(machine_list=[])
+        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1]))
+        cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)

         phy_models = cgo._assemble(lp)
         self.assertTrue(len(phy_models) == 2)

@@ -311,13 +314,14 @@ class CGOEngineTest(unittest.TestCase):
         models = _load_mnist(2)

-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = protocol.LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

-        cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1),
-                                                 GPUDevice("test", 2), GPUDevice("test", 3)],
-                                        batch_waiting_time=0)
+        remote = RemoteConfig(machine_list=[])
+        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1, 2, 3]))
+        cgo_engine = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
         set_execution_engine(cgo_engine)
         submit_models(*models)
         time.sleep(3)
test/ut/retiarii/test_engine.py  (view file @ 2fc47247)

@@ -25,7 +25,7 @@ class EngineTest(unittest.TestCase):
     def test_base_execution_engine(self):
         nni.retiarii.integration_api._advisor = None
         nni.retiarii.execution.api._execution_engine = None
-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

@@ -42,7 +42,7 @@ class EngineTest(unittest.TestCase):
     def test_py_execution_engine(self):
         nni.retiarii.integration_api._advisor = None
         nni.retiarii.execution.api._execution_engine = None
-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()
test/ut/sdk/test_assessor.py  (view file @ 2fc47247)

@@ -57,7 +57,7 @@ class AssessorTestCase(TestCase):
         _restore_io()

         assessor = NaiveAssessor()
-        dispatcher = MsgDispatcher('ws://_placeholder_', None, assessor)
+        dispatcher = MsgDispatcher('ws://_unittest_placeholder_', None, assessor)
         dispatcher._channel = LegacyCommandChannel()
         msg_dispatcher_base._worker_fast_exit_on_terminate = False
test/ut/sdk/test_msg_dispatcher.py  (view file @ 2fc47247)

@@ -66,7 +66,7 @@ class MsgDispatcherTestCase(TestCase):
         _restore_io()

         tuner = NaiveTuner()
-        dispatcher = MsgDispatcher('ws://_placeholder_', tuner)
+        dispatcher = MsgDispatcher('ws://_unittest_placeholder_', tuner)
         dispatcher._channel = LegacyCommandChannel()
         msg_dispatcher_base._worker_fast_exit_on_terminate = False
ts/nni_manager/core/nnimanager.ts  (view file @ 2fc47247)

@@ -303,8 +303,11 @@ class NNIManager implements Manager {
         }
         this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
+        // NOTE: this sending TERMINATE should be out of the if clause,
+        // because when python dispatcher is started before nnimanager
+        // this.dispatcherPid would not have a valid value (i.e., not >0).
+        this.dispatcher.sendCommand(TERMINATE);
         if (this.dispatcherPid > 0) {
-            this.dispatcher.sendCommand(TERMINATE);
             // gracefully terminate tuner and assessor here, wait at most 30 seconds.
             for (let i: number = 0; i < 30; i++) {
                 if (!await isAlive(this.dispatcherPid)) {