Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
98c1a77f
Unverified
Commit
98c1a77f
authored
May 14, 2022
by
liuzhe-lz
Committed by
GitHub
May 14, 2022
Browse files
Support multiple HPO experiments in one process (#4855)
parent
5dc80762
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
207 additions
and
197 deletions
+207
-197
nni/__main__.py
nni/__main__.py
+4
-6
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
+6
-6
nni/algorithms/hpo/hyperband_advisor.py
nni/algorithms/hpo/hyperband_advisor.py
+5
-5
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+5
-3
nni/retiarii/integration.py
nni/retiarii/integration.py
+6
-6
nni/runtime/common.py
nni/runtime/common.py
+0
-8
nni/runtime/env_vars.py
nni/runtime/env_vars.py
+2
-1
nni/runtime/msg_dispatcher.py
nni/runtime/msg_dispatcher.py
+12
-15
nni/runtime/msg_dispatcher_base.py
nni/runtime/msg_dispatcher_base.py
+26
-48
nni/runtime/protocol.py
nni/runtime/protocol.py
+0
-39
nni/runtime/tuner_command_channel/__init__.py
nni/runtime/tuner_command_channel/__init__.py
+4
-3
nni/runtime/tuner_command_channel/channel.py
nni/runtime/tuner_command_channel/channel.py
+61
-0
nni/runtime/tuner_command_channel/legacy.py
nni/runtime/tuner_command_channel/legacy.py
+33
-0
nni/runtime/tuner_command_channel/shim.py
nni/runtime/tuner_command_channel/shim.py
+0
-33
nni/tools/nnictl/config_schema.py
nni/tools/nnictl/config_schema.py
+0
-2
nni/tools/nnictl/launcher.py
nni/tools/nnictl/launcher.py
+7
-4
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+16
-4
test/ut/retiarii/test_engine.py
test/ut/retiarii/test_engine.py
+13
-7
test/ut/sdk/helper/__init__.py
test/ut/sdk/helper/__init__.py
+0
-0
test/ut/sdk/test_assessor.py
test/ut/sdk/test_assessor.py
+7
-7
No files found.
nni/__main__.py
View file @
98c1a77f
...
@@ -9,7 +9,6 @@ import base64
...
@@ -9,7 +9,6 @@ import base64
from
.runtime.msg_dispatcher
import
MsgDispatcher
from
.runtime.msg_dispatcher
import
MsgDispatcher
from
.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
.runtime.protocol
import
connect_websocket
from
.tools.package_utils
import
create_builtin_class_instance
,
create_customized_class_instance
from
.tools.package_utils
import
create_builtin_class_instance
,
create_customized_class_instance
logger
=
logging
.
getLogger
(
'nni.main'
)
logger
=
logging
.
getLogger
(
'nni.main'
)
...
@@ -21,10 +20,6 @@ if os.environ.get('COVERAGE_PROCESS_START'):
...
@@ -21,10 +20,6 @@ if os.environ.get('COVERAGE_PROCESS_START'):
def
main
():
def
main
():
# the url should be "ws://localhost:{port}/tuner" or "ws://localhost:{port}/{url_prefix}/tuner"
ws_url
=
os
.
environ
[
'NNI_TUNER_COMMAND_CHANNEL'
]
connect_websocket
(
ws_url
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Dispatcher command line parser'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Dispatcher command line parser'
)
parser
.
add_argument
(
'--exp_params'
,
type
=
str
,
required
=
True
)
parser
.
add_argument
(
'--exp_params'
,
type
=
str
,
required
=
True
)
args
,
_
=
parser
.
parse_known_args
()
args
,
_
=
parser
.
parse_known_args
()
...
@@ -56,7 +51,10 @@ def main():
...
@@ -56,7 +51,10 @@ def main():
assessor
=
_create_algo
(
exp_params
[
'assessor'
],
'assessor'
)
assessor
=
_create_algo
(
exp_params
[
'assessor'
],
'assessor'
)
else
:
else
:
assessor
=
None
assessor
=
None
dispatcher
=
MsgDispatcher
(
tuner
,
assessor
)
# the url should be "ws://localhost:{port}/tuner" or "ws://localhost:{port}/{url_prefix}/tuner"
url
=
os
.
environ
[
'NNI_TUNER_COMMAND_CHANNEL'
]
dispatcher
=
MsgDispatcher
(
url
,
tuner
,
assessor
)
try
:
try
:
dispatcher
.
run
()
dispatcher
.
run
()
...
...
nni/algorithms/hpo/bohb_advisor/bohb_advisor.py
View file @
98c1a77f
...
@@ -14,7 +14,7 @@ from ConfigSpace.read_and_write import pcs_new
...
@@ -14,7 +14,7 @@ from ConfigSpace.read_and_write import pcs_new
import
nni
import
nni
from
nni
import
ClassArgsValidator
from
nni
import
ClassArgsValidator
from
nni.runtime.
protoco
l
import
CommandType
,
send
from
nni.runtime.
tuner_command_channe
l
import
CommandType
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.utils
import
OptimizeMode
,
MetricType
,
extract_scalar_reward
from
nni.utils
import
OptimizeMode
,
MetricType
,
extract_scalar_reward
from
nni.runtime.common
import
multi_phase_enabled
from
nni.runtime.common
import
multi_phase_enabled
...
@@ -483,7 +483,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -483,7 +483,7 @@ class BOHB(MsgDispatcherBase):
raise
ValueError
(
'Error: Search space is None'
)
raise
ValueError
(
'Error: Search space is None'
)
# generate first brackets
# generate first brackets
self
.
generate_new_bracket
()
self
.
generate_new_bracket
()
send
(
CommandType
.
Initialized
,
''
)
self
.
send
(
CommandType
.
Initialized
,
''
)
def
generate_new_bracket
(
self
):
def
generate_new_bracket
(
self
):
"""generate a new bracket"""
"""generate a new bracket"""
...
@@ -541,7 +541,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -541,7 +541,7 @@ class BOHB(MsgDispatcherBase):
'parameter_source'
:
'algorithm'
,
'parameter_source'
:
'algorithm'
,
'parameters'
:
''
'parameters'
:
''
}
}
send
(
CommandType
.
NoMoreTrialJobs
,
nni
.
dump
(
ret
))
self
.
send
(
CommandType
.
NoMoreTrialJobs
,
nni
.
dump
(
ret
))
return
None
return
None
assert
self
.
generated_hyper_configs
assert
self
.
generated_hyper_configs
params
=
self
.
generated_hyper_configs
.
pop
(
0
)
params
=
self
.
generated_hyper_configs
.
pop
(
0
)
...
@@ -572,7 +572,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -572,7 +572,7 @@ class BOHB(MsgDispatcherBase):
"""
"""
ret
=
self
.
_get_one_trial_job
()
ret
=
self
.
_get_one_trial_job
()
if
ret
is
not
None
:
if
ret
is
not
None
:
send
(
CommandType
.
NewTrialJob
,
nni
.
dump
(
ret
))
self
.
send
(
CommandType
.
NewTrialJob
,
nni
.
dump
(
ret
))
self
.
credit
-=
1
self
.
credit
-=
1
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
...
@@ -664,7 +664,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -664,7 +664,7 @@ class BOHB(MsgDispatcherBase):
ret
[
'parameter_index'
]
=
one_unsatisfied
[
'parameter_index'
]
ret
[
'parameter_index'
]
=
one_unsatisfied
[
'parameter_index'
]
# update parameter_id in self.job_id_para_id_map
# update parameter_id in self.job_id_para_id_map
self
.
job_id_para_id_map
[
ret
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
self
.
job_id_para_id_map
[
ret
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
self
.
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
for
_
in
range
(
self
.
credit
):
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
self
.
_request_one_trial_job
()
...
@@ -712,7 +712,7 @@ class BOHB(MsgDispatcherBase):
...
@@ -712,7 +712,7 @@ class BOHB(MsgDispatcherBase):
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
# update parameter_id in self.job_id_para_id_map
# update parameter_id in self.job_id_para_id_map
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
self
.
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
else
:
else
:
assert
'value'
in
data
assert
'value'
in
data
value
=
extract_scalar_reward
(
data
[
'value'
])
value
=
extract_scalar_reward
(
data
[
'value'
])
...
...
nni/algorithms/hpo/hyperband_advisor.py
View file @
98c1a77f
...
@@ -18,7 +18,7 @@ from nni import ClassArgsValidator
...
@@ -18,7 +18,7 @@ from nni import ClassArgsValidator
from
nni.common.hpo_utils
import
validate_search_space
from
nni.common.hpo_utils
import
validate_search_space
from
nni.runtime.common
import
multi_phase_enabled
from
nni.runtime.common
import
multi_phase_enabled
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.
protoco
l
import
CommandType
,
send
from
nni.runtime.
tuner_command_channe
l
import
CommandType
from
nni.utils
import
NodeType
,
OptimizeMode
,
MetricType
,
extract_scalar_reward
from
nni.utils
import
NodeType
,
OptimizeMode
,
MetricType
,
extract_scalar_reward
from
nni
import
parameter_expressions
from
nni
import
parameter_expressions
...
@@ -432,7 +432,7 @@ class Hyperband(MsgDispatcherBase):
...
@@ -432,7 +432,7 @@ class Hyperband(MsgDispatcherBase):
search space
search space
"""
"""
self
.
handle_update_search_space
(
data
)
self
.
handle_update_search_space
(
data
)
send
(
CommandType
.
Initialized
,
''
)
self
.
send
(
CommandType
.
Initialized
,
''
)
def
handle_request_trial_jobs
(
self
,
data
):
def
handle_request_trial_jobs
(
self
,
data
):
"""
"""
...
@@ -449,7 +449,7 @@ class Hyperband(MsgDispatcherBase):
...
@@ -449,7 +449,7 @@ class Hyperband(MsgDispatcherBase):
def
_request_one_trial_job
(
self
):
def
_request_one_trial_job
(
self
):
ret
=
self
.
_get_one_trial_job
()
ret
=
self
.
_get_one_trial_job
()
if
ret
is
not
None
:
if
ret
is
not
None
:
send
(
CommandType
.
NewTrialJob
,
nni
.
dump
(
ret
))
self
.
send
(
CommandType
.
NewTrialJob
,
nni
.
dump
(
ret
))
self
.
credit
-=
1
self
.
credit
-=
1
def
_get_one_trial_job
(
self
):
def
_get_one_trial_job
(
self
):
...
@@ -478,7 +478,7 @@ class Hyperband(MsgDispatcherBase):
...
@@ -478,7 +478,7 @@ class Hyperband(MsgDispatcherBase):
'parameter_source'
:
'algorithm'
,
'parameter_source'
:
'algorithm'
,
'parameters'
:
''
'parameters'
:
''
}
}
send
(
CommandType
.
NoMoreTrialJobs
,
nni
.
dump
(
ret
))
self
.
send
(
CommandType
.
NoMoreTrialJobs
,
nni
.
dump
(
ret
))
return
None
return
None
assert
self
.
generated_hyper_configs
assert
self
.
generated_hyper_configs
...
@@ -553,7 +553,7 @@ class Hyperband(MsgDispatcherBase):
...
@@ -553,7 +553,7 @@ class Hyperband(MsgDispatcherBase):
if
data
[
'parameter_index'
]
is
not
None
:
if
data
[
'parameter_index'
]
is
not
None
:
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
ret
[
'parameter_index'
]
=
data
[
'parameter_index'
]
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
self
.
job_id_para_id_map
[
data
[
'trial_job_id'
]]
=
ret
[
'parameter_id'
]
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
self
.
send
(
CommandType
.
SendTrialJobParameter
,
nni
.
dump
(
ret
))
else
:
else
:
value
=
extract_scalar_reward
(
data
[
'value'
])
value
=
extract_scalar_reward
(
data
[
'value'
])
bracket_id
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
bracket_id
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
...
...
nni/retiarii/experiment/pytorch.py
View file @
98c1a77f
...
@@ -24,7 +24,7 @@ from nni.experiment.config import utils
...
@@ -24,7 +24,7 @@ from nni.experiment.config import utils
from
nni.experiment.config.base
import
ConfigBase
from
nni.experiment.config.base
import
ConfigBase
from
nni.experiment.config.training_service
import
TrainingServiceConfig
from
nni.experiment.config.training_service
import
TrainingServiceConfig
from
nni.experiment.config.training_services
import
RemoteConfig
from
nni.experiment.config.training_services
import
RemoteConfig
from
nni.runtime.
protocol
import
connect_websocket
from
nni.runtime.
tuner_command_channel
import
TunerCommandChannel
from
nni.tools.nnictl.command_utils
import
kill_command
from
nni.tools.nnictl.command_utils
import
kill_command
from
..codegen
import
model_to_pytorch_script
from
..codegen
import
model_to_pytorch_script
...
@@ -274,7 +274,8 @@ class RetiariiExperiment(Experiment):
...
@@ -274,7 +274,8 @@ class RetiariiExperiment(Experiment):
from
nni.retiarii.oneshot.pytorch.strategy
import
OneShotStrategy
from
nni.retiarii.oneshot.pytorch.strategy
import
OneShotStrategy
if
not
isinstance
(
strategy
,
OneShotStrategy
):
if
not
isinstance
(
strategy
,
OneShotStrategy
):
self
.
_dispatcher
=
RetiariiAdvisor
()
# FIXME: Dispatcher should not be created this early.
self
.
_dispatcher
=
RetiariiAdvisor
(
'_placeholder_'
)
else
:
else
:
self
.
_dispatcher
=
cast
(
RetiariiAdvisor
,
None
)
self
.
_dispatcher
=
cast
(
RetiariiAdvisor
,
None
)
self
.
_dispatcher_thread
:
Optional
[
Thread
]
=
None
self
.
_dispatcher_thread
:
Optional
[
Thread
]
=
None
...
@@ -357,13 +358,14 @@ class RetiariiExperiment(Experiment):
...
@@ -357,13 +358,14 @@ class RetiariiExperiment(Experiment):
self
.
_proc
=
launcher
.
start_experiment
(
'create'
,
self
.
id
,
self
.
config
,
port
,
debug
,
# type: ignore
self
.
_proc
=
launcher
.
start_experiment
(
'create'
,
self
.
id
,
self
.
config
,
port
,
debug
,
# type: ignore
RunMode
.
Background
,
None
,
ws_url
,
[
'retiarii'
])
RunMode
.
Background
,
None
,
ws_url
,
[
'retiarii'
])
assert
self
.
_proc
is
not
None
assert
self
.
_proc
is
not
None
connect_websocket
(
ws_url
)
self
.
port
=
port
# port will be None if start up failed
self
.
port
=
port
# port will be None if start up failed
# dispatcher must be launched after pipe initialized
# dispatcher must be launched after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
# the logic to launch dispatcher in background should be refactored into dispatcher api
self
.
_dispatcher
=
self
.
_create_dispatcher
()
self
.
_dispatcher
=
self
.
_create_dispatcher
()
if
self
.
_dispatcher
is
not
None
:
self
.
_dispatcher
.
_channel
=
TunerCommandChannel
(
ws_url
)
self
.
_dispatcher_thread
=
Thread
(
target
=
self
.
_dispatcher
.
run
)
self
.
_dispatcher_thread
=
Thread
(
target
=
self
.
_dispatcher
.
run
)
self
.
_dispatcher_thread
.
start
()
self
.
_dispatcher_thread
.
start
()
...
...
nni/retiarii/integration.py
View file @
98c1a77f
...
@@ -9,7 +9,7 @@ import nni
...
@@ -9,7 +9,7 @@ import nni
from
nni.common.serializer
import
PayloadTooLarge
from
nni.common.serializer
import
PayloadTooLarge
from
nni.common.version
import
version_dump
from
nni.common.version
import
version_dump
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.
protoco
l
import
CommandType
,
send
from
nni.runtime.
tuner_command_channe
l
import
CommandType
from
nni.utils
import
MetricType
from
nni.utils
import
MetricType
from
.graph
import
MetricData
from
.graph
import
MetricData
...
@@ -48,8 +48,8 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -48,8 +48,8 @@ class RetiariiAdvisor(MsgDispatcherBase):
final_metric_callback
final_metric_callback
"""
"""
def
__init__
(
self
):
def
__init__
(
self
,
url
:
str
):
super
(
RetiariiAdvisor
,
self
).
__init__
()
super
().
__init__
(
url
)
register_advisor
(
self
)
# register the current advisor as the "global only" advisor
register_advisor
(
self
)
# register the current advisor as the "global only" advisor
self
.
search_space
=
None
self
.
search_space
=
None
...
@@ -69,7 +69,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -69,7 +69,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
search space
search space
"""
"""
self
.
handle_update_search_space
(
data
)
self
.
handle_update_search_space
(
data
)
send
(
CommandType
.
Initialized
,
''
)
self
.
send
(
CommandType
.
Initialized
,
''
)
def
_validate_placement_constraint
(
self
,
placement_constraint
):
def
_validate_placement_constraint
(
self
,
placement_constraint
):
if
placement_constraint
is
None
:
if
placement_constraint
is
None
:
...
@@ -138,14 +138,14 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -138,14 +138,14 @@ class RetiariiAdvisor(MsgDispatcherBase):
# trial parameters can be super large, disable pickle size limit here
# trial parameters can be super large, disable pickle size limit here
# nevertheless, there could still be blocked by pipe / nni-manager
# nevertheless, there could still be blocked by pipe / nni-manager
send
(
CommandType
.
NewTrialJob
,
send_payload
)
self
.
send
(
CommandType
.
NewTrialJob
,
send_payload
)
if
self
.
send_trial_callback
is
not
None
:
if
self
.
send_trial_callback
is
not
None
:
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
return
self
.
parameters_count
return
self
.
parameters_count
def
mark_experiment_as_ending
(
self
):
def
mark_experiment_as_ending
(
self
):
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
self
.
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
def
handle_request_trial_jobs
(
self
,
num_trials
):
def
handle_request_trial_jobs
(
self
,
num_trials
):
_logger
.
debug
(
'Request trial jobs: %s'
,
num_trials
)
_logger
.
debug
(
'Request trial jobs: %s'
,
num_trials
)
...
...
nni/runtime/common.py
View file @
98c1a77f
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
_multi_thread
=
False
_multi_phase
=
False
_multi_phase
=
False
def
enable_multi_thread
():
global
_multi_thread
_multi_thread
=
True
def
multi_thread_enabled
():
return
_multi_thread
def
enable_multi_phase
():
def
enable_multi_phase
():
global
_multi_phase
global
_multi_phase
_multi_phase
=
True
_multi_phase
=
True
...
...
nni/runtime/env_vars.py
View file @
98c1a77f
...
@@ -22,7 +22,8 @@ _dispatcher_env_var_names = [
...
@@ -22,7 +22,8 @@ _dispatcher_env_var_names = [
'NNI_CHECKPOINT_DIRECTORY'
,
'NNI_CHECKPOINT_DIRECTORY'
,
'NNI_LOG_DIRECTORY'
,
'NNI_LOG_DIRECTORY'
,
'NNI_LOG_LEVEL'
,
'NNI_LOG_LEVEL'
,
'NNI_INCLUDE_INTERMEDIATE_RESULTS'
'NNI_INCLUDE_INTERMEDIATE_RESULTS'
,
'NNI_TUNER_COMMAND_CHANNEL'
,
]
]
def
_load_env_vars
(
env_var_names
):
def
_load_env_vars
(
env_var_names
):
...
...
nni/runtime/msg_dispatcher.py
View file @
98c1a77f
...
@@ -7,10 +7,10 @@ from collections import defaultdict
...
@@ -7,10 +7,10 @@ from collections import defaultdict
from
nni
import
NoMoreTrialError
from
nni
import
NoMoreTrialError
from
nni.assessor
import
AssessResult
from
nni.assessor
import
AssessResult
from
.common
import
multi_thread_enabled
,
multi_phase_enabled
from
.common
import
multi_phase_enabled
from
.env_vars
import
dispatcher_env_vars
from
.env_vars
import
dispatcher_env_vars
from
.msg_dispatcher_base
import
MsgDispatcherBase
from
.msg_dispatcher_base
import
MsgDispatcherBase
from
.
protoco
l
import
CommandType
,
send
from
.
tuner_command_channe
l
import
CommandType
from
..common.serializer
import
dump
,
load
from
..common.serializer
import
dump
,
load
from
..utils
import
MetricType
from
..utils
import
MetricType
...
@@ -67,8 +67,8 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
...
@@ -67,8 +67,8 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
class
MsgDispatcher
(
MsgDispatcherBase
):
class
MsgDispatcher
(
MsgDispatcherBase
):
def
__init__
(
self
,
tuner
,
assessor
=
None
):
def
__init__
(
self
,
command_channel_url
,
tuner
,
assessor
=
None
):
super
(
MsgDispatcher
,
self
).
__init__
(
)
super
(
).
__init__
(
command_channel_url
)
self
.
tuner
=
tuner
self
.
tuner
=
tuner
self
.
assessor
=
assessor
self
.
assessor
=
assessor
if
assessor
is
None
:
if
assessor
is
None
:
...
@@ -88,12 +88,12 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -88,12 +88,12 @@ class MsgDispatcher(MsgDispatcherBase):
"""Data is search space
"""Data is search space
"""
"""
self
.
tuner
.
update_search_space
(
data
)
self
.
tuner
.
update_search_space
(
data
)
send
(
CommandType
.
Initialized
,
''
)
self
.
send
(
CommandType
.
Initialized
,
''
)
def
send_trial_callback
(
self
,
id_
,
params
):
def
send_trial_callback
(
self
,
id_
,
params
):
"""For tuner to issue trial config when the config is generated
"""For tuner to issue trial config when the config is generated
"""
"""
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
id_
,
params
))
self
.
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
id_
,
params
))
def
handle_request_trial_jobs
(
self
,
data
):
def
handle_request_trial_jobs
(
self
,
data
):
# data: number or trial jobs
# data: number or trial jobs
...
@@ -102,10 +102,10 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -102,10 +102,10 @@ class MsgDispatcher(MsgDispatcherBase):
params_list
=
self
.
tuner
.
generate_multiple_parameters
(
ids
,
st_callback
=
self
.
send_trial_callback
)
params_list
=
self
.
tuner
.
generate_multiple_parameters
(
ids
,
st_callback
=
self
.
send_trial_callback
)
for
i
,
_
in
enumerate
(
params_list
):
for
i
,
_
in
enumerate
(
params_list
):
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
ids
[
i
],
params_list
[
i
]))
self
.
send
(
CommandType
.
NewTrialJob
,
_pack_parameter
(
ids
[
i
],
params_list
[
i
]))
# when parameters is None.
# when parameters is None.
if
len
(
params_list
)
<
len
(
ids
):
if
len
(
params_list
)
<
len
(
ids
):
send
(
CommandType
.
NoMoreTrialJobs
,
_pack_parameter
(
ids
[
0
],
''
))
self
.
send
(
CommandType
.
NoMoreTrialJobs
,
_pack_parameter
(
ids
[
0
],
''
))
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
self
.
tuner
.
update_search_space
(
data
)
self
.
tuner
.
update_search_space
(
data
)
...
@@ -148,7 +148,7 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -148,7 +148,7 @@ class MsgDispatcher(MsgDispatcherBase):
param
=
self
.
tuner
.
generate_parameters
(
param_id
,
trial_job_id
=
data
[
'trial_job_id'
])
param
=
self
.
tuner
.
generate_parameters
(
param_id
,
trial_job_id
=
data
[
'trial_job_id'
])
except
NoMoreTrialError
:
except
NoMoreTrialError
:
param
=
None
param
=
None
send
(
CommandType
.
SendTrialJobParameter
,
_pack_parameter
(
param_id
,
param
,
trial_job_id
=
data
[
'trial_job_id'
],
self
.
send
(
CommandType
.
SendTrialJobParameter
,
_pack_parameter
(
param_id
,
param
,
trial_job_id
=
data
[
'trial_job_id'
],
parameter_index
=
data
[
'parameter_index'
]))
parameter_index
=
data
[
'parameter_index'
]))
else
:
else
:
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
raise
ValueError
(
'Data type not supported: {}'
.
format
(
data
[
'type'
]))
...
@@ -222,7 +222,7 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -222,7 +222,7 @@ class MsgDispatcher(MsgDispatcherBase):
if
result
is
AssessResult
.
Bad
:
if
result
is
AssessResult
.
Bad
:
_logger
.
debug
(
'BAD, kill %s'
,
trial_job_id
)
_logger
.
debug
(
'BAD, kill %s'
,
trial_job_id
)
send
(
CommandType
.
KillTrialJob
,
dump
(
trial_job_id
))
self
.
send
(
CommandType
.
KillTrialJob
,
dump
(
trial_job_id
))
# notify tuner
# notify tuner
_logger
.
debug
(
'env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]'
,
_logger
.
debug
(
'env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]'
,
dispatcher_env_vars
.
NNI_INCLUDE_INTERMEDIATE_RESULTS
)
dispatcher_env_vars
.
NNI_INCLUDE_INTERMEDIATE_RESULTS
)
...
@@ -237,8 +237,5 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -237,8 +237,5 @@ class MsgDispatcher(MsgDispatcherBase):
"""
"""
_logger
.
debug
(
'Early stop notify tuner data: [%s]'
,
data
)
_logger
.
debug
(
'Early stop notify tuner data: [%s]'
,
data
)
data
[
'type'
]
=
MetricType
.
FINAL
data
[
'type'
]
=
MetricType
.
FINAL
if
multi_thread_enabled
():
data
[
'value'
]
=
dump
(
data
[
'value'
])
self
.
_handle_final_metric_data
(
data
)
self
.
enqueue_command
(
CommandType
.
ReportMetricData
,
data
)
else
:
data
[
'value'
]
=
dump
(
data
[
'value'
])
self
.
enqueue_command
(
CommandType
.
ReportMetricData
,
data
)
nni/runtime/msg_dispatcher_base.py
View file @
98c1a77f
...
@@ -3,14 +3,12 @@
...
@@ -3,14 +3,12 @@
import
threading
import
threading
import
logging
import
logging
from
multiprocessing.dummy
import
Pool
as
ThreadPool
from
queue
import
Queue
,
Empty
from
queue
import
Queue
,
Empty
from
.common
import
multi_thread_enabled
from
.env_vars
import
dispatcher_env_vars
from
.env_vars
import
dispatcher_env_vars
from
..common
import
load
from
..common
import
load
from
..recoverable
import
Recoverable
from
..recoverable
import
Recoverable
from
.
protoco
l
import
CommandType
,
receive
from
.
tuner_command_channe
l
import
CommandType
,
TunerCommandChannel
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -24,59 +22,52 @@ class MsgDispatcherBase(Recoverable):
...
@@ -24,59 +22,52 @@ class MsgDispatcherBase(Recoverable):
Inherits this class to make your own advisor.
Inherits this class to make your own advisor.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
,
command_channel_url
=
None
):
self
.
stopping
=
False
self
.
stopping
=
False
if
multi_thread_enabled
():
if
command_channel_url
is
None
:
self
.
pool
=
ThreadPool
()
command_channel_url
=
dispatcher_env_vars
.
NNI_TUNER_COMMAND_CHANNEL
self
.
thread_results
=
[]
self
.
_channel
=
TunerCommandChannel
(
command_channel_url
)
else
:
self
.
default_command_queue
=
Queue
()
self
.
default_command_queue
=
Queue
()
self
.
assessor_command_queue
=
Queue
()
self
.
assessor_command_queue
=
Queue
()
self
.
default_worker
=
threading
.
Thread
(
target
=
self
.
command_queue_worker
,
args
=
(
self
.
default_command_queue
,))
self
.
default_worker
=
threading
.
Thread
(
target
=
self
.
command_queue_worker
,
args
=
(
self
.
default_command_queue
,))
self
.
assessor_worker
=
threading
.
Thread
(
target
=
self
.
command_queue_worker
,
args
=
(
self
.
assessor_command_queue
,))
self
.
assessor_worker
=
threading
.
Thread
(
target
=
self
.
command_queue_worker
,
self
.
worker_exceptions
=
[]
args
=
(
self
.
assessor_command_queue
,))
self
.
default_worker
.
start
()
self
.
assessor_worker
.
start
()
self
.
worker_exceptions
=
[]
def
run
(
self
):
def
run
(
self
):
"""Run the tuner.
"""Run the tuner.
This function will never return unless raise.
This function will never return unless raise.
"""
"""
_logger
.
info
(
'Dispatcher started'
)
_logger
.
info
(
'Dispatcher started'
)
self
.
_channel
.
connect
()
self
.
default_worker
.
start
()
self
.
assessor_worker
.
start
()
if
dispatcher_env_vars
.
NNI_MODE
==
'resume'
:
if
dispatcher_env_vars
.
NNI_MODE
==
'resume'
:
self
.
load_checkpoint
()
self
.
load_checkpoint
()
while
not
self
.
stopping
:
while
not
self
.
stopping
:
command
,
data
=
receive
()
command
,
data
=
self
.
_channel
.
_
receive
()
if
data
:
if
data
:
data
=
load
(
data
)
data
=
load
(
data
)
if
command
is
None
or
command
is
CommandType
.
Terminate
:
if
command
is
None
or
command
is
CommandType
.
Terminate
:
break
break
if
multi_thread_enabled
():
self
.
enqueue_command
(
command
,
data
)
result
=
self
.
pool
.
map_async
(
self
.
process_command_thread
,
[(
command
,
data
)])
if
self
.
worker_exceptions
:
self
.
thread_results
.
append
(
result
)
break
if
any
([
thread_result
.
ready
()
and
not
thread_result
.
successful
()
for
thread_result
in
self
.
thread_results
]):
_logger
.
debug
(
'Caught thread exception'
)
break
else
:
self
.
enqueue_command
(
command
,
data
)
if
self
.
worker_exceptions
:
break
_logger
.
info
(
'Dispatcher exiting...'
)
_logger
.
info
(
'Dispatcher exiting...'
)
self
.
stopping
=
True
self
.
stopping
=
True
if
multi_thread_enabled
():
self
.
default_worker
.
join
()
self
.
pool
.
close
()
self
.
assessor_worker
.
join
()
self
.
pool
.
join
()
self
.
_channel
.
disconnect
()
else
:
self
.
default_worker
.
join
()
self
.
assessor_worker
.
join
()
_logger
.
info
(
'Dispatcher terminiated'
)
_logger
.
info
(
'Dispatcher terminiated'
)
def
send
(
self
,
command
,
data
):
self
.
_channel
.
_send
(
command
,
data
)
def
command_queue_worker
(
self
,
command_queue
):
def
command_queue_worker
(
self
,
command_queue
):
"""Process commands in command queues.
"""Process commands in command queues.
"""
"""
...
@@ -112,19 +103,6 @@ class MsgDispatcherBase(Recoverable):
...
@@ -112,19 +103,6 @@ class MsgDispatcherBase(Recoverable):
if
qsize
>=
QUEUE_LEN_WARNING_MARK
:
if
qsize
>=
QUEUE_LEN_WARNING_MARK
:
_logger
.
warning
(
'assessor queue length: %d'
,
qsize
)
_logger
.
warning
(
'assessor queue length: %d'
,
qsize
)
def
process_command_thread
(
self
,
request
):
"""Worker thread to process a command.
"""
command
,
data
=
request
if
multi_thread_enabled
():
try
:
self
.
process_command
(
command
,
data
)
except
Exception
as
e
:
_logger
.
exception
(
str
(
e
))
raise
else
:
pass
def
process_command
(
self
,
command
,
data
):
def
process_command
(
self
,
command
,
data
):
_logger
.
debug
(
'process_command: command: [%s], data: [%s]'
,
command
,
data
)
_logger
.
debug
(
'process_command: command: [%s], data: [%s]'
,
command
,
data
)
...
@@ -242,4 +220,4 @@ class MsgDispatcherBase(Recoverable):
...
@@ -242,4 +220,4 @@ class MsgDispatcherBase(Recoverable):
hyper_params: the string that is sent by message dispatcher during the creation of trials.
hyper_params: the string that is sent by message dispatcher during the creation of trials.
"""
"""
raise
NotImplementedError
(
'handle_trial_end not implemented'
)
raise
NotImplementedError
(
'handle_trial_end not implemented'
)
\ No newline at end of file
nni/runtime/protocol.py
deleted
100644 → 0
View file @
5dc80762
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=unused-import
from
__future__
import
annotations
from
.tuner_command_channel.command_type
import
CommandType
from
.tuner_command_channel
import
legacy
from
.tuner_command_channel
import
shim
_use_ws
=
False
def
connect_websocket
(
url
:
str
):
global
_use_ws
_use_ws
=
True
shim
.
connect
(
url
)
def
send
(
command
:
CommandType
,
data
:
str
)
->
None
:
if
_use_ws
:
shim
.
send
(
command
,
data
)
else
:
legacy
.
send
(
command
,
data
)
def
receive
()
->
tuple
[
CommandType
,
str
]
|
tuple
[
None
,
None
]:
if
_use_ws
:
return
shim
.
receive
()
else
:
return
legacy
.
receive
()
# for unit test compatibility
def
_set_in_file
(
in_file
):
legacy
.
_in_file
=
in_file
def
_set_out_file
(
out_file
):
legacy
.
_out_file
=
out_file
def
_get_out_file
():
return
legacy
.
_out_file
nni/runtime/tuner_command_channel/__init__.py
View file @
98c1a77f
...
@@ -2,7 +2,8 @@
...
@@ -2,7 +2,8 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
"""
"""
The IPC channel between tuner/assessor and NNI manager.
Low level APIs for algorithms to communicate with NNI manager.
Work in progress.
"""
"""
from
.command_type
import
CommandType
from
.channel
import
TunerCommandChannel
nni/runtime/tuner_command_channel/channel.py
0 → 100644
View file @
98c1a77f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Low level APIs for algorithms to communicate with NNI manager.
"""
from
__future__
import
annotations
__all__
=
[
'TunerCommandChannel'
]
from
.command_type
import
CommandType
from
.websocket
import
WebSocket
class TunerCommandChannel:
    """
    A channel to communicate with NNI manager.

    Each NNI experiment has a channel URL for tuner/assessor/strategy algorithm.
    The channel can only be connected once, so for each Python side :class:`~nni.experiment.Experiment` object,
    there should be exactly one corresponding ``TunerCommandChannel`` instance.

    :meth:`connect` must be invoked before sending or receiving data.

    The constructor does not have side effect so ``TunerCommandChannel`` can be created anywhere.
    But :meth:`connect` requires an initialized NNI manager, or otherwise the behavior is unpredictable.

    :meth:`_send` and :meth:`_receive` are underscore-prefixed because their signatures are scheduled to change by v3.0.

    Parameters
    ----------
    url
        The command channel URL.
        For now it must be like ``"ws://localhost:8080/tuner"`` or ``"ws://localhost:8080/url-prefix/tuner"``.
    """

    def __init__(self, url: str):
        self._channel = WebSocket(url)

    def connect(self) -> None:
        """Establish the underlying WebSocket connection."""
        self._channel.connect()

    def disconnect(self) -> None:
        """Close the underlying WebSocket connection."""
        self._channel.disconnect()

    # TODO: Define semantic command class like `KillTrialJob(trial_id='abc')`.
    # def send(self, command: Command) -> None:
    #     ...
    # def receive(self) -> Command | None:
    #     ...

    def _send(self, command_type: CommandType, data: str) -> None:
        """Send one command: a two-character type header followed by the payload."""
        command = command_type.value.decode() + data
        self._channel.send(command)

    def _receive(self) -> tuple[CommandType, str]:
        """Receive one command and split it into (type, payload).

        Raises
        ------
        RuntimeError
            If NNI manager has closed the connection.

        Note: the previous annotation also advertised ``tuple[None, None]``,
        but this method never returns Nones — a closed connection raises.
        """
        command = self._channel.receive()
        if command is None:
            raise RuntimeError('NNI manager closed connection')
        # The first two characters encode the command type.
        command_type = CommandType(command[:2].encode())
        return command_type, command[2:]
nni/runtime/tuner_command_channel/legacy.py
View file @
98c1a77f
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
__all__
=
[
'CommandType'
,
'LegacyCommandChannel'
,
'send'
,
'receive'
,
'_set_in_file'
,
'_set_out_file'
,
'_get_out_file'
,
]
import
logging
import
logging
import
os
import
os
import
threading
import
threading
...
@@ -18,6 +28,29 @@ try:
...
@@ -18,6 +28,29 @@ try:
except
OSError
:
except
OSError
:
_logger
.
debug
(
'IPC pipeline not exists'
)
_logger
.
debug
(
'IPC pipeline not exists'
)
def _set_in_file(in_file):
    # Unit-test hook: replace this module's input pipe.
    global _in_file
    _in_file = in_file
def _set_out_file(out_file):
    # Unit-test hook: replace this module's output pipe.
    global _out_file
    _out_file = out_file
def _get_out_file():
    # Unit-test hook: expose this module's current output pipe.
    return _out_file
class LegacyCommandChannel:
    """Adapter giving the pipe-based legacy protocol the same interface
    as ``TunerCommandChannel`` (connect/disconnect/_send/_receive).
    """

    def connect(self):
        # The legacy pipes are opened at module import time; nothing to do.
        pass

    def disconnect(self):
        # Legacy pipes have no explicit teardown here.
        pass

    def _send(self, command, data):
        # Delegate to the module-level legacy send().
        send(command, data)

    def _receive(self):
        # Delegate to the module-level legacy receive().
        return receive()
def
send
(
command
,
data
):
def
send
(
command
,
data
):
"""Send command to Training Service.
"""Send command to Training Service.
...
...
nni/runtime/tuner_command_channel/shim.py
deleted
100644 → 0
View file @
5dc80762
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Compatibility layer for old protocol APIs.
We are working on more semantic new APIs.
"""
from
__future__
import
annotations
from
.command_type
import
CommandType
from
.websocket
import
WebSocket
_ws
:
WebSocket
=
None
# type: ignore
def connect(url: str) -> None:
    """Create the module-wide WebSocket to NNI manager and open it."""
    global _ws
    _ws = WebSocket(url)
    # Connect after publishing the socket, matching the module's single-socket design.
    _ws.connect()
def send(command_type: CommandType, data: str) -> None:
    """Serialize one command (two-character type header + payload) and send it."""
    header = command_type.value.decode()
    _ws.send(header + data)
def receive() -> tuple[CommandType, str]:
    """Receive one command; disconnect automatically on Terminate.

    Raises RuntimeError if NNI manager has closed the connection.
    """
    message = _ws.receive()
    if message is None:
        raise RuntimeError('NNI manager closed connection')
    header, payload = message[:2], message[2:]
    command_type = CommandType(header.encode())
    if command_type is CommandType.Terminate:
        # Terminate is the final command; tear down the socket eagerly.
        _ws.disconnect()
    return command_type, payload
nni/tools/nnictl/config_schema.py
View file @
98c1a77f
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
json
import
json
import
logging
import
os
import
os
from
schema
import
And
,
Optional
,
Or
,
Regex
,
Schema
,
SchemaError
from
schema
import
And
,
Optional
,
Or
,
Regex
,
Schema
,
SchemaError
...
@@ -77,7 +76,6 @@ class AlgoSchema:
...
@@ -77,7 +76,6 @@ class AlgoSchema:
if
not
builtin_name
or
not
class_args
:
if
not
builtin_name
or
not
class_args
:
return
return
logging
.
getLogger
(
'nni.protocol'
).
setLevel
(
logging
.
ERROR
)
# we know IPC is not there, don't complain
validator
=
create_validator_instance
(
algo_type
+
's'
,
builtin_name
)
validator
=
create_validator_instance
(
algo_type
+
's'
,
builtin_name
)
if
validator
:
if
validator
:
try
:
try
:
...
...
nni/tools/nnictl/launcher.py
View file @
98c1a77f
...
@@ -88,11 +88,14 @@ def create_experiment(args):
...
@@ -88,11 +88,14 @@ def create_experiment(args):
exp
=
Experiment
(
config
)
exp
=
Experiment
(
config
)
exp
.
url_prefix
=
url_prefix
exp
.
url_prefix
=
url_prefix
run_mode
=
RunMode
.
Foreground
if
foreground
else
RunMode
.
Detach
exp
.
start
(
port
,
debug
,
run_mode
)
_logger
.
info
(
f
'To stop experiment run "nnictl stop
{
exp
.
id
}
" or "nnictl stop --all"'
)
if
foreground
:
_logger
.
info
(
'Reference: https://nni.readthedocs.io/en/stable/reference/nnictl.html'
)
exp
.
run
(
port
,
debug
=
debug
)
else
:
exp
.
start
(
port
,
debug
,
RunMode
.
Detach
)
_logger
.
info
(
f
'To stop experiment run "nnictl stop
{
exp
.
id
}
" or "nnictl stop --all"'
)
_logger
.
info
(
'Reference: https://nni.readthedocs.io/en/stable/reference/nnictl.html'
)
def
resume_experiment
(
args
):
def
resume_experiment
(
args
):
exp_id
=
args
.
id
exp_id
=
args
.
id
...
...
test/ut/retiarii/test_cgo_engine.py
View file @
98c1a77f
...
@@ -10,6 +10,7 @@ from pathlib import Path
...
@@ -10,6 +10,7 @@ from pathlib import Path
import
nni
import
nni
import
nni.runtime.platform.test
import
nni.runtime.platform.test
from
nni.runtime.tuner_command_channel
import
legacy
as
protocol
import
json
import
json
try
:
try
:
...
@@ -262,7 +263,11 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -262,7 +263,11 @@ class CGOEngineTest(unittest.TestCase):
opt
=
DedupInputOptimizer
()
opt
=
DedupInputOptimizer
()
opt
.
convert
(
lp
)
opt
.
convert
(
lp
)
advisor
=
RetiariiAdvisor
()
advisor
=
RetiariiAdvisor
(
'ws://_placeholder_'
)
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
available_devices
=
[
GPUDevice
(
"test"
,
0
),
GPUDevice
(
"test"
,
1
),
GPUDevice
(
"test"
,
2
),
GPUDevice
(
"test"
,
3
)]
available_devices
=
[
GPUDevice
(
"test"
,
0
),
GPUDevice
(
"test"
,
1
),
GPUDevice
(
"test"
,
2
),
GPUDevice
(
"test"
,
3
)]
cgo
=
CGOExecutionEngine
(
devices
=
available_devices
,
batch_waiting_time
=
0
)
cgo
=
CGOExecutionEngine
(
devices
=
available_devices
,
batch_waiting_time
=
0
)
...
@@ -281,7 +286,11 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -281,7 +286,11 @@ class CGOEngineTest(unittest.TestCase):
opt
=
DedupInputOptimizer
()
opt
=
DedupInputOptimizer
()
opt
.
convert
(
lp
)
opt
.
convert
(
lp
)
advisor
=
RetiariiAdvisor
()
advisor
=
RetiariiAdvisor
(
'ws://_placeholder_'
)
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
available_devices
=
[
GPUDevice
(
"test"
,
0
),
GPUDevice
(
"test"
,
1
)]
available_devices
=
[
GPUDevice
(
"test"
,
0
),
GPUDevice
(
"test"
,
1
)]
cgo
=
CGOExecutionEngine
(
devices
=
available_devices
,
batch_waiting_time
=
0
)
cgo
=
CGOExecutionEngine
(
devices
=
available_devices
,
batch_waiting_time
=
0
)
...
@@ -296,14 +305,17 @@ class CGOEngineTest(unittest.TestCase):
...
@@ -296,14 +305,17 @@ class CGOEngineTest(unittest.TestCase):
_reset
()
_reset
()
nni
.
retiarii
.
debug_configs
.
framework
=
'pytorch'
nni
.
retiarii
.
debug_configs
.
framework
=
'pytorch'
os
.
makedirs
(
'generated'
,
exist_ok
=
True
)
os
.
makedirs
(
'generated'
,
exist_ok
=
True
)
from
nni.runtime
import
protocol
import
nni.runtime.platform.test
as
tt
import
nni.runtime.platform.test
as
tt
protocol
.
_set_out_file
(
open
(
'generated/debug_protocol_out_file.py'
,
'wb'
))
protocol
.
_set_out_file
(
open
(
'generated/debug_protocol_out_file.py'
,
'wb'
))
protocol
.
_set_in_file
(
open
(
'generated/debug_protocol_out_file.py'
,
'rb'
))
protocol
.
_set_in_file
(
open
(
'generated/debug_protocol_out_file.py'
,
'rb'
))
models
=
_load_mnist
(
2
)
models
=
_load_mnist
(
2
)
advisor
=
RetiariiAdvisor
()
advisor
=
RetiariiAdvisor
(
'ws://_placeholder_'
)
advisor
.
_channel
=
protocol
.
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
cgo_engine
=
CGOExecutionEngine
(
devices
=
[
GPUDevice
(
"test"
,
0
),
GPUDevice
(
"test"
,
1
),
cgo_engine
=
CGOExecutionEngine
(
devices
=
[
GPUDevice
(
"test"
,
0
),
GPUDevice
(
"test"
,
1
),
GPUDevice
(
"test"
,
2
),
GPUDevice
(
"test"
,
3
)],
batch_waiting_time
=
0
)
GPUDevice
(
"test"
,
2
),
GPUDevice
(
"test"
,
3
)],
batch_waiting_time
=
0
)
set_execution_engine
(
cgo_engine
)
set_execution_engine
(
cgo_engine
)
...
...
test/ut/retiarii/test_engine.py
View file @
98c1a77f
...
@@ -11,7 +11,7 @@ from nni.retiarii.execution.base import BaseExecutionEngine
...
@@ -11,7 +11,7 @@ from nni.retiarii.execution.base import BaseExecutionEngine
from
nni.retiarii.execution.python
import
PurePythonExecutionEngine
from
nni.retiarii.execution.python
import
PurePythonExecutionEngine
from
nni.retiarii.graph
import
DebugEvaluator
from
nni.retiarii.graph
import
DebugEvaluator
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.runtime.tuner_command_channel.legacy
import
*
class
EngineTest
(
unittest
.
TestCase
):
class
EngineTest
(
unittest
.
TestCase
):
def
test_codegen
(
self
):
def
test_codegen
(
self
):
...
@@ -25,7 +25,11 @@ class EngineTest(unittest.TestCase):
...
@@ -25,7 +25,11 @@ class EngineTest(unittest.TestCase):
def
test_base_execution_engine
(
self
):
def
test_base_execution_engine
(
self
):
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
advisor
=
RetiariiAdvisor
()
advisor
=
RetiariiAdvisor
(
'ws://_placeholder_'
)
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
set_execution_engine
(
BaseExecutionEngine
())
set_execution_engine
(
BaseExecutionEngine
())
with
open
(
self
.
enclosing_dir
/
'mnist_pytorch.json'
)
as
f
:
with
open
(
self
.
enclosing_dir
/
'mnist_pytorch.json'
)
as
f
:
model
=
Model
.
_load
(
json
.
load
(
f
))
model
=
Model
.
_load
(
json
.
load
(
f
))
...
@@ -38,7 +42,11 @@ class EngineTest(unittest.TestCase):
...
@@ -38,7 +42,11 @@ class EngineTest(unittest.TestCase):
def
test_py_execution_engine
(
self
):
def
test_py_execution_engine
(
self
):
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
advisor
=
RetiariiAdvisor
()
advisor
=
RetiariiAdvisor
(
'ws://_placeholder_'
)
advisor
.
_channel
=
LegacyCommandChannel
()
advisor
.
default_worker
.
start
()
advisor
.
assessor_worker
.
start
()
set_execution_engine
(
PurePythonExecutionEngine
())
set_execution_engine
(
PurePythonExecutionEngine
())
model
=
Model
.
_load
({
model
=
Model
.
_load
({
'_model'
:
{
'_model'
:
{
...
@@ -63,11 +71,9 @@ class EngineTest(unittest.TestCase):
...
@@ -63,11 +71,9 @@ class EngineTest(unittest.TestCase):
def
setUp
(
self
)
->
None
:
def
setUp
(
self
)
->
None
:
self
.
enclosing_dir
=
Path
(
__file__
).
parent
self
.
enclosing_dir
=
Path
(
__file__
).
parent
os
.
makedirs
(
self
.
enclosing_dir
/
'generated'
,
exist_ok
=
True
)
os
.
makedirs
(
self
.
enclosing_dir
/
'generated'
,
exist_ok
=
True
)
from
nni.runtime
import
protocol
_set_out_file
(
open
(
self
.
enclosing_dir
/
'generated/debug_protocol_out_file.py'
,
'wb'
))
protocol
.
_set_out_file
(
open
(
self
.
enclosing_dir
/
'generated/debug_protocol_out_file.py'
,
'wb'
))
def
tearDown
(
self
)
->
None
:
def
tearDown
(
self
)
->
None
:
from
nni.runtime
import
protocol
_get_out_file
().
close
()
protocol
.
_get_out_file
().
close
()
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
nni
.
retiarii
.
integration_api
.
_advisor
=
None
nni
.
retiarii
.
integration_api
.
_advisor
=
None
test/ut/sdk/helper/__init__.py
0 → 100644
View file @
98c1a77f
test/ut/sdk/test_assessor.py
View file @
98c1a77f
...
@@ -8,8 +8,7 @@ from unittest import TestCase, main
...
@@ -8,8 +8,7 @@ from unittest import TestCase, main
from
nni.assessor
import
Assessor
,
AssessResult
from
nni.assessor
import
Assessor
,
AssessResult
from
nni.runtime
import
msg_dispatcher_base
as
msg_dispatcher_base
from
nni.runtime
import
msg_dispatcher_base
as
msg_dispatcher_base
from
nni.runtime.msg_dispatcher
import
MsgDispatcher
from
nni.runtime.msg_dispatcher
import
MsgDispatcher
from
nni.runtime
import
protocol
from
nni.runtime.tuner_command_channel.legacy
import
*
from
nni.runtime.protocol
import
CommandType
,
send
,
receive
_trials
=
[]
_trials
=
[]
_end_trials
=
[]
_end_trials
=
[]
...
@@ -34,15 +33,15 @@ _out_buf = BytesIO()
...
@@ -34,15 +33,15 @@ _out_buf = BytesIO()
def
_reverse_io
():
def
_reverse_io
():
_in_buf
.
seek
(
0
)
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
protocol
.
_set_out_file
(
_in_buf
)
_set_out_file
(
_in_buf
)
protocol
.
_set_in_file
(
_out_buf
)
_set_in_file
(
_out_buf
)
def
_restore_io
():
def
_restore_io
():
_in_buf
.
seek
(
0
)
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
protocol
.
_set_in_file
(
_in_buf
)
_set_in_file
(
_in_buf
)
protocol
.
_set_out_file
(
_out_buf
)
_set_out_file
(
_out_buf
)
class
AssessorTestCase
(
TestCase
):
class
AssessorTestCase
(
TestCase
):
...
@@ -58,7 +57,8 @@ class AssessorTestCase(TestCase):
...
@@ -58,7 +57,8 @@ class AssessorTestCase(TestCase):
_restore_io
()
_restore_io
()
assessor
=
NaiveAssessor
()
assessor
=
NaiveAssessor
()
dispatcher
=
MsgDispatcher
(
None
,
assessor
)
dispatcher
=
MsgDispatcher
(
'ws://_placeholder_'
,
None
,
assessor
)
dispatcher
.
_channel
=
LegacyCommandChannel
()
msg_dispatcher_base
.
_worker_fast_exit_on_terminate
=
False
msg_dispatcher_base
.
_worker_fast_exit_on_terminate
=
False
dispatcher
.
run
()
dispatcher
.
run
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment