Unverified Commit 2bc98441 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[retiarii] fix experiment does not exit after done (#4916)

parent c299e576
...@@ -46,8 +46,10 @@ class MsgDispatcherBase(Recoverable): ...@@ -46,8 +46,10 @@ class MsgDispatcherBase(Recoverable):
self._channel.connect() self._channel.connect()
self.default_command_queue = Queue() self.default_command_queue = Queue()
self.assessor_command_queue = Queue() self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,)) # daemon must be True here: the parent thread is configured as a daemon so that a NAS experiment can exit smoothly.
self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,)) # If daemon is not set, these worker threads cancel the daemon behavior of their parent thread and keep the process alive.
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,), daemon=True)
self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,), daemon=True)
self.worker_exceptions = [] self.worker_exceptions = []
def run(self): def run(self):
......
import argparse
import os
import sys
import pytorch_lightning as pl
import pytest
from subprocess import Popen
from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from .test_oneshot import _mnist_net
# Module-level pytest mark: skip every test in this file when the installed
# pytorch_lightning reports a version below '1.0'.
# NOTE(review): presumably pl.__version__ is a str, so this is a lexicographic
# string comparison rather than a numeric version comparison - confirm that is
# acceptable for the versions this suite must support.
pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_multi_trial():
    """Run a one-trial local experiment for each sample MNIST model.

    Each ``(model, evaluator)`` pair returned by ``_mnist_net`` is explored
    with the ``Random`` strategy, limited to a single trial, and the first
    exported top model is checked to be a ``dict``.
    """
    evaluator_kwargs = {
        'max_epochs': 1
    }

    to_test = [
        # (model, evaluator)
        _mnist_net('simple', evaluator_kwargs),
        _mnist_net('simple_value_choice', evaluator_kwargs),
        _mnist_net('value_choice', evaluator_kwargs),
        _mnist_net('repeat', evaluator_kwargs),
        _mnist_net('custom_op', evaluator_kwargs),
    ]

    for base_model, evaluator in to_test:
        search_strategy = strategy.Random()
        exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
        exp_config = RetiariiExeConfig('local')
        exp_config.experiment_name = 'mnist_unittest'
        exp_config.trial_concurrency = 1
        exp_config.max_trial_number = 1
        exp_config.training_service.use_active_gpu = False
        exp.run(exp_config, 8080)
        try:
            assert isinstance(exp.export_top_models()[0], dict)
        finally:
            # Stop the experiment even when the check above fails: every
            # iteration passes the same 8080 to run() (presumably the port),
            # so a leftover experiment would break the following iterations.
            exp.stop()
# Source of a standalone script used by test_exp_exit_without_stop. It runs a
# single-trial experiment on the 'simple' MNIST model and checks the exported
# top model, but never calls exp.stop(), so the interpreter running it has to
# exit on its own once the script finishes.
python_script = """
from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from test_oneshot import _mnist_net
base_model, evaluator = _mnist_net('simple', {'max_epochs': 1})
search_strategy = strategy.Random()
exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_unittest'
exp_config.trial_concurrency = 1
exp_config.max_trial_number = 1
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8080)
assert isinstance(exp.export_top_models()[0], dict)
"""
@pytest.mark.timeout(600)
def test_exp_exit_without_stop():
    """Check that an experiment script exits by itself without ``exp.stop()``.

    ``python_script`` is written to ``tmp_multi_trial.py`` in the current
    working directory and executed with the current interpreter in a child
    process. The test waits for the child and requires a zero exit status.
    The ``timeout`` mark bounds the wait at 600 seconds (assuming the
    pytest-timeout plugin is active), so a child that never exits fails the
    test instead of hanging it.
    """
    script_name = 'tmp_multi_trial.py'
    with open(script_name, 'w') as f:
        f.write(python_script)

    try:
        proc = Popen([sys.executable, script_name])
        try:
            return_code = proc.wait()
        finally:
            # If the wait was interrupted (e.g. by the timeout), do not leave
            # the child process running behind the test.
            if proc.poll() is None:
                proc.kill()
                proc.wait()
    finally:
        # Remove the temporary script whether or not the child succeeded.
        os.remove(script_name)

    # A child that crashes early also "exits"; require success so that the
    # test cannot pass without the experiment actually having run.
    assert return_code == 0, f'{script_name} exited with code {return_code}'
if __name__ == '__main__':
    # Command-line entry point: run every test, or only the one named by
    # --exp (the name without its 'test_' prefix).
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp', type=str, default='all', metavar='E',
                        help='experiment to run, default = all')
    args = parser.parse_args()

    if args.exp == 'all':
        test_multi_trial()
        test_exp_exit_without_stop()
    else:
        # Resolve the test by name and report an unknown name as a usage
        # error instead of a raw KeyError traceback.
        test_func = globals().get(f'test_{args.exp}')
        if not callable(test_func):
            parser.error(f'unknown experiment: {args.exp!r}')
        test_func()
...@@ -7,6 +7,7 @@ from torchvision import transforms ...@@ -7,6 +7,7 @@ from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
from torch.utils.data import Dataset, RandomSampler from torch.utils.data import Dataset, RandomSampler
import nni
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import strategy, model_wrapper, basic_unit from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
...@@ -216,13 +217,13 @@ def _mnist_net(type_, evaluator_kwargs): ...@@ -216,13 +217,13 @@ def _mnist_net(type_, evaluator_kwargs):
raise ValueError(f'Unsupported type: {type_}') raise ValueError(f'Unsupported type: {type_}')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform) train_dataset = nni.trace(MNIST)('data/mnist', train=True, download=True, transform=transform)
# Multi-GPU combined dataloader will break this subset sampler. Expected though. # Multi-GPU combined dataloader will break this subset sampler. Expected though.
train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20)) train_random_sampler = nni.trace(RandomSampler)(train_dataset, True, int(len(train_dataset) / 20))
train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler) train_loader = nni.trace(DataLoader)(train_dataset, 64, sampler=train_random_sampler)
valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform) valid_dataset = nni.trace(MNIST)('data/mnist', train=False, download=True, transform=transform)
valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20)) valid_random_sampler = nni.trace(RandomSampler)(valid_dataset, True, int(len(valid_dataset) / 20))
valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler) valid_loader = nni.trace(DataLoader)(valid_dataset, 64, sampler=valid_random_sampler)
evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **evaluator_kwargs) evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **evaluator_kwargs)
return base_model, evaluator return base_model, evaluator
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment