OpenDAS / nni · Commits

Commit 2bc98441 (unverified)
Authored Jun 10, 2022 by QuanluZhang, committed by GitHub on Jun 10, 2022.

[retiarii] fix experiment does not exit after done (#4916)

Parent: c299e576

Changes: 3 changed files, 87 additions and 8 deletions (+87 -8).
  nni/runtime/msg_dispatcher_base.py    +4  -2
  test/ut/retiarii/test_multitrial.py   +76 -0
  test/ut/retiarii/test_oneshot.py      +7  -6
nni/runtime/msg_dispatcher_base.py

@@ -46,8 +46,10 @@ class MsgDispatcherBase(Recoverable):
         self._channel.connect()
         self.default_command_queue = Queue()
         self.assessor_command_queue = Queue()
-        self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
-        self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,))
+        # here daemon should be True, because their parent thread is configured as daemon to enable smooth exit of NAS experiment.
+        # if daemon is not set, these threads will block the daemon effect of their parent thread.
+        self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,), daemon=True)
+        self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,), daemon=True)
         self.worker_exceptions = []
 
     def run(self):
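The comments added in this hunk carry the reasoning for the fix: the dispatcher's two command-queue workers were plain threads, so even though their parent thread is a daemon, they kept the interpreter alive after the experiment finished. A minimal standalone sketch of that behavior (plain Python, not NNI code):

import threading
import time

def worker():
    # Stands in for command_queue_worker: blocks indefinitely waiting for commands.
    while True:
        time.sleep(1)

# Before the fix (conceptually): a non-daemon worker keeps the interpreter alive
# even after every other thread has finished, so the process never exits.
# Uncommenting the next line makes this script hang at exit.
# threading.Thread(target=worker).start()

# After the fix: daemon workers are torn down automatically once only daemon
# threads remain, so the process can exit when the experiment is done.
threading.Thread(target=worker, daemon=True).start()

print('main thread done; process exits because only daemon threads remain')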
test/ut/retiarii/test_multitrial.py (new file, mode 100644)

import argparse
import os
import sys

import pytorch_lightning as pl
import pytest
from subprocess import Popen

from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

from .test_oneshot import _mnist_net

pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')


def test_multi_trial():
    evaluator_kwargs = {
        'max_epochs': 1
    }

    to_test = [
        # (model, evaluator)
        _mnist_net('simple', evaluator_kwargs),
        _mnist_net('simple_value_choice', evaluator_kwargs),
        _mnist_net('value_choice', evaluator_kwargs),
        _mnist_net('repeat', evaluator_kwargs),
        _mnist_net('custom_op', evaluator_kwargs),
    ]

    for base_model, evaluator in to_test:
        search_strategy = strategy.Random()
        exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
        exp_config = RetiariiExeConfig('local')
        exp_config.experiment_name = 'mnist_unittest'
        exp_config.trial_concurrency = 1
        exp_config.max_trial_number = 1
        exp_config.training_service.use_active_gpu = False
        exp.run(exp_config, 8080)
        assert isinstance(exp.export_top_models()[0], dict)
        exp.stop()


python_script = """
from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

from test_oneshot import _mnist_net

base_model, evaluator = _mnist_net('simple', {'max_epochs': 1})

search_strategy = strategy.Random()
exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_unittest'
exp_config.trial_concurrency = 1
exp_config.max_trial_number = 1
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8080)
assert isinstance(exp.export_top_models()[0], dict)
"""


@pytest.mark.timeout(600)
def test_exp_exit_without_stop():
    script_name = 'tmp_multi_trial.py'
    with open(script_name, 'w') as f:
        f.write(python_script)
    proc = Popen([sys.executable, script_name])
    proc.wait()
    os.remove(script_name)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp', type=str, default='all', metavar='E',
                        help='experiment to run, default = all')
    args = parser.parse_args()

    if args.exp == 'all':
        test_multi_trial()
        test_exp_exit_without_stop()
    else:
        globals()[f'test_{args.exp}']()
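test_exp_exit_without_stop writes the script above to a temporary file, runs it in a fresh interpreter, and relies on the 600-second pytest timeout to fail if the child process never exits, which is exactly the symptom this commit fixes (the script deliberately never calls exp.stop()). For reference, a hedged sketch of the same check that bounds the wait on the subprocess itself rather than on the whole test; the helper name is illustrative and not part of the repository:

import subprocess
import sys

def assert_script_exits(script_path: str, timeout_sec: float = 600.0) -> None:
    # Run the script in a fresh interpreter and fail if it is still alive
    # after timeout_sec, killing it so the test run is not left hanging.
    proc = subprocess.Popen([sys.executable, script_path])
    try:
        returncode = proc.wait(timeout=timeout_sec)
    except subprocess.TimeoutExpired:
        proc.kill()
        raise AssertionError(f'{script_path} did not exit within {timeout_sec}s')
    assert returncode == 0, f'{script_path} exited with code {returncode}'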
test/ut/retiarii/test_oneshot.py

@@ -7,6 +7,7 @@ from torchvision import transforms
 from torchvision.datasets import MNIST
 from torch.utils.data import Dataset, RandomSampler
 
+import nni
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import strategy, model_wrapper, basic_unit
 from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

@@ -216,13 +217,13 @@ def _mnist_net(type_, evaluator_kwargs):
         raise ValueError(f'Unsupported type: {type_}')
 
     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
-    train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
+    train_dataset = nni.trace(MNIST)('data/mnist', train=True, download=True, transform=transform)
     # Multi-GPU combined dataloader will break this subset sampler. Expected though.
-    train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20))
-    train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
-    valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
-    valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20))
-    valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
+    train_random_sampler = nni.trace(RandomSampler)(train_dataset, True, int(len(train_dataset) / 20))
+    train_loader = nni.trace(DataLoader)(train_dataset, 64, sampler=train_random_sampler)
+    valid_dataset = nni.trace(MNIST)('data/mnist', train=False, download=True, transform=transform)
+    valid_random_sampler = nni.trace(RandomSampler)(valid_dataset, True, int(len(valid_dataset) / 20))
+    valid_loader = nni.trace(DataLoader)(valid_dataset, 64, sampler=valid_random_sampler)
     evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **evaluator_kwargs)
 
     return base_model, evaluator
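This hunk wraps the dataset, sampler, and dataloader constructors in nni.trace, presumably because the new multi-trial test reuses _mnist_net and multi-trial execution has to recreate the evaluator and its dataloaders in separate trial processes, which requires their constructor arguments to be recorded. A minimal sketch of the pattern, assuming only the nni.trace calls shown in the diff (the wrapped class is called exactly like the original):

import nni
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

# nni.trace(MNIST) returns a wrapped constructor; calling it builds the object
# as usual while keeping track of the arguments so the instance can be
# serialized and rebuilt later (e.g. in a trial process).
train_dataset = nni.trace(MNIST)('data/mnist', train=True, download=True, transform=transform)
train_loader = nni.trace(DataLoader)(train_dataset, 64, shuffle=True)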