test_multitrial.py 5.18 KB
Newer Older
Yuge Zhang's avatar
Yuge Zhang committed
1
2
import multiprocessing
import os
3
import subprocess
Yuge Zhang's avatar
Yuge Zhang committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import time

import pytest
import pytorch_lightning as pl
from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from ut.nas.test_experiment import nas_experiment_trial_params, ensure_success
from .test_oneshot import _mnist_net

# Module-level pytest mark: skip every test in this file when the installed
# pytorch_lightning version string compares below '1.0'.
# NOTE(review): this is a plain string comparison, not a parsed-version comparison.
pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')


@pytest.mark.parametrize('model', [
    'simple', 'simple_value_choice', 'value_choice', 'repeat', 'custom_op'
])
def test_multi_trial(model, pytestconfig):
    """Run a single-trial local experiment with the Random strategy and check its result.

    Args:
        model: Name of the model variant handed to ``_mnist_net``.
        pytestconfig: pytest's config fixture; its ``rootpath`` is forwarded to
            ``nas_experiment_trial_params``.
    """
    evaluator_kwargs = {
        'max_epochs': 1
    }

    base_model, evaluator = _mnist_net(model, evaluator_kwargs)

    search_strategy = strategy.Random()
    exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'mnist_unittest'
    exp_config.trial_concurrency = 1
    exp_config.max_trial_number = 1
    exp_config._trial_command_params = nas_experiment_trial_params(pytestconfig.rootpath)
    exp.run(exp_config)
    try:
        ensure_success(exp)
        assert isinstance(exp.export_top_models()[0], dict)
    finally:
        # Stop the experiment even when a check above fails, so a failing
        # parametrization does not leave it running for the following tests.
        exp.stop()


def _test_experiment_in_separate_process(rootpath):
    """Run a one-trial experiment and return without calling ``stop()`` on it.

    Used as the target of a child process. The registered exit handlers are
    invoked by hand at the end, whether or not the checks passed.

    Args:
        rootpath: Path forwarded to ``nas_experiment_trial_params``.
    """
    import atexit

    try:
        model_space, evaluator = _mnist_net('simple', {'max_epochs': 1})
        experiment = RetiariiExperiment(model_space, evaluator, strategy=strategy.Random())

        config = RetiariiExeConfig('local')
        config.experiment_name = 'mnist_unittest'
        config.trial_concurrency = 1
        config.max_trial_number = 1
        config._trial_command_params = nas_experiment_trial_params(rootpath)

        experiment.run(config)
        ensure_success(experiment)
        top_models = experiment.export_top_models()
        assert isinstance(top_models[0], dict)
    finally:
        # Exit handlers are triggered explicitly here; see
        # https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess
        atexit._run_exitfuncs()


def test_exp_exit_without_stop(pytestconfig):
    """Check that a process running an experiment without ``stop()`` exits cleanly.

    The experiment runs in a spawned child process. The test passes when that
    child exits with code 0 within the timeout.

    Args:
        pytestconfig: pytest's config fixture; its ``rootpath`` is passed to the child.

    Raises:
        RuntimeError: If the child process is still alive after the timeout.
    """
    # NOTE: Multiprocessing has compatibility issue with OpenMP.
    # It makes the MNIST dataset fails to load on pipeline.
    # https://github.com/pytorch/pytorch/issues/50669
    # Need to use spawn as a workaround of this issue.
    ctx = multiprocessing.get_context('spawn')
    process = ctx.Process(
        target=_test_experiment_in_separate_process,
        kwargs=dict(rootpath=pytestconfig.rootpath)
    )
    process.start()
    print('Waiting for experiment in sub-process.')
    timeout = 180
    # join() returns as soon as the child exits, or after `timeout` seconds.
    # Unlike a sleep/poll loop, the child's state is checked after the full
    # wait, so a child finishing at the very end is not misreported as hung.
    process.join(timeout)
    if process.is_alive():
        process.kill()
        # Reap the killed child so it does not linger as a zombie process.
        process.join()
        raise RuntimeError(f'Experiment fails to stop in {timeout} seconds.')
    assert process.exitcode == 0
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100


def test_multitrial_experiment_resume_view(pytestconfig):
    """Run an experiment, then resume and view it via the Python API and ``nnictl``.

    Args:
        pytestconfig: pytest's config fixture; its ``rootpath`` is forwarded to
            ``nas_experiment_trial_params`` and used as ``PYTHONPATH`` for
            ``nnictl resume``.
    """
    # start a normal nas experiment
    base_model, evaluator = _mnist_net('simple', {'max_epochs': 1})
    search_strategy = strategy.Random()
    exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
    exp_id = exp.id
    exp_config = RetiariiExeConfig('local')
    exp_config.trial_concurrency = 1
    exp_config.max_trial_number = 1
    exp_config._trial_command_params = nas_experiment_trial_params(pytestconfig.rootpath)
    exp.run(exp_config)
    try:
        ensure_success(exp)
        assert isinstance(exp.export_top_models()[0], dict)
    finally:
        # Stop even when a check fails so the experiment is not left running.
        exp.stop()

    # resume the above nas experiment. only tested the resume logic in the python side,
    # as no more trial is executed after resume, the above experiment is already finished
    print('python api resume...')
    exp = RetiariiExperiment.resume(exp_id)
    try:
        ensure_success(exp)
        # sleep here because there would be several seconds for the experiment status to change
        # to ERROR from INITIALIZED/RUNNING if the resume gets error.
        time.sleep(6)
        assert exp.get_status() == 'DONE', f'The experiment status should not be {exp.get_status()}'
        # TODO: currently `export_top_models` does not work as strategy's states are not resumed
        # assert isinstance(exp.export_top_models()[0], dict)
    finally:
        exp.stop()

    # view the above experiment in non blocking mode then stop it
    print('python api view...')
    exp = RetiariiExperiment.view(exp_id, non_blocking=True)
    try:
        assert exp.get_status() == 'VIEWED', f'The experiment status should not be {exp.get_status()}'
    finally:
        exp.stop()

    # the following is nnictl resume and view
    print('nnictl resume...')
    new_env = os.environ.copy()
    new_env['PYTHONPATH'] = str(pytestconfig.rootpath)
    # NOTE: experiment status (e.g., ERROR) is not checked, because it runs in blocking mode and
    # the rest server exits right after the command is done
    proc = subprocess.run(f'nnictl resume {exp_id}', shell=True, env=new_env)
    assert proc.returncode == 0, 'resume nas experiment failed with code %d' % proc.returncode
    print('nnictl view...')
    proc = subprocess.run(f'nnictl view {exp_id}', shell=True)
    assert proc.returncode == 0, 'view nas experiment failed with code %d' % proc.returncode
    proc = subprocess.run(f'nnictl stop {exp_id}', shell=True)
    assert proc.returncode == 0, 'stop viewed nas experiment failed with code %d' % proc.returncode