tuner_test.py 3.18 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3

chicm-ms's avatar
chicm-ms committed
4
5
import sys
import os.path as osp
6
7
8
9
10
import subprocess
import sys
import time
import traceback

11
12
from utils import get_yml_content, dump_yml_content, setup_experiment, get_nni_log_path, is_experiment_done
from utils import GREEN, RED, CLEAR, EXPERIMENT_URL
13

14
# Builtin tuners exercised (in this order) by run('Tuner').
TUNER_LIST = [
    'GridSearch',
    'BatchTuner',
    'TPE',
    'Random',
    'Anneal',
    'Evolution',
]
# Builtin assessors exercised by run('Assessor').
ASSESSOR_LIST = ['Medianstop']


chicm-ms's avatar
chicm-ms committed
18
19
20
21
22
23
24
def get_config_file_path():
    '''Return the relative path of the experiment config used by these tests.

    Windows uses ``tuner_test/local_win32.yml``; every other platform uses
    ``tuner_test/local.yml``.
    '''
    file_name = 'local_win32.yml' if sys.platform == 'win32' else 'local.yml'
    return osp.join('tuner_test', file_name)

25
26
def switch(dispatch_type, dispatch_name):
    '''Point the experiment config file at the given builtin dispatcher.

    Loads the YAML returned by get_config_file_path(), replaces its
    ``dispatch_type.lower()`` section (e.g. ``tuner``) with a
    ``builtin<dispatch_type>Name`` entry, selects the matching search space
    file, and writes the config back in place.

    Args:
        dispatch_type: 'Tuner' or 'Assessor'; used both as the lowercase
            section key and inside the ``builtin...Name`` key.
        dispatch_name: name of the builtin dispatcher to configure.
    '''
    config_path = get_config_file_path()
    experiment_config = get_yml_content(config_path)

    section = {'builtin' + dispatch_type + 'Name': dispatch_name}
    # These three are configured without classArgs; every other dispatcher
    # is given optimize_mode='maximize'.
    if dispatch_name not in ['GridSearch', 'BatchTuner', 'Random']:
        section['classArgs'] = {'optimize_mode': 'maximize'}
    experiment_config[dispatch_type.lower()] = section

    # BatchTuner needs its own search space file.
    experiment_config['searchSpacePath'] = (
        'batchtuner_search_space.json' if dispatch_name == 'BatchTuner' else 'search_space.json'
    )

    dump_yml_content(config_path, experiment_config)

def test_builtin_dispatcher(dispatch_type, dispatch_name):
    '''Run one experiment with the given builtin dispatcher and wait for it to finish.

    Rewrites the config via switch(), launches it with ``nnictl create``, then
    polls the NNI manager log every 3 seconds, at most 20 times, until
    is_experiment_done() returns a truthy value.

    Raises:
        AssertionError: if ``nnictl create`` exits non-zero, or if the
            experiment is not reported done within the polling window.
    '''
    switch(dispatch_type, dispatch_name)

    print('Testing %s...' % dispatch_name)
    create_cmd = ['nnictl', 'create', '--config', get_config_file_path()]
    completed = subprocess.run(create_cmd)
    assert completed.returncode == 0, '`nnictl create` failed with code %d' % completed.returncode

    log_path = get_nni_log_path(EXPERIMENT_URL)

    done = False
    attempts = 0
    # 20 polls x 3 s sleep = roughly one minute of waiting.
    while attempts < 20 and not done:
        time.sleep(3)
        done = is_experiment_done(log_path)
        attempts += 1

    assert done, 'Failed to finish in 1 min'
64
65
66
67
68
69
70

def run(dispatch_type):
    '''Test every builtin dispatcher of the given type, stopping at the first failure.

    Args:
        dispatch_type: 'Tuner' (iterates TUNER_LIST) or 'Assessor'
            (iterates ASSESSOR_LIST).

    Raises:
        AssertionError: if dispatch_type is not one of the supported types.
        Exception: whatever the failing dispatcher test raised, re-raised
            unchanged after the failure has been printed.
    '''
    # One-element tuple so the message formats correctly for any value
    # (a bare `% (x)` is not a tuple and breaks when x itself is a tuple).
    assert dispatch_type in ['Tuner', 'Assessor'], 'Unsupported dispatcher type: %s' % (dispatch_type,)
    dispatcher_list = TUNER_LIST if dispatch_type == 'Tuner' else ASSESSOR_LIST
    for dispatcher_name in dispatcher_list:
        try:
            # Sleep here to make sure previous stopped exp has enough time to exit to avoid port conflict
            time.sleep(6)
            test_builtin_dispatcher(dispatch_type, dispatcher_name)
            print(GREEN + 'Test %s %s: TEST PASS' % (dispatcher_name, dispatch_type) + CLEAR)
        except Exception as error:
            print(RED + 'Test %s %s: TEST FAIL' % (dispatcher_name, dispatch_type) + CLEAR)
            print('%r' % (error,))
            traceback.print_exc()
            # Bare raise propagates the active exception and its original traceback.
            raise
        finally:
            # Stop the experiment whether or not the test passed, so the next
            # dispatcher does not hit a port conflict.
            subprocess.run(['nnictl', 'stop'])

if __name__ == '__main__':
    # Only the final CLI argument is inspected: '--preinstall' there makes
    # setup_experiment receive False, anything else makes it receive True.
    preinstall = sys.argv[-1] == '--preinstall'
    setup_experiment(not preinstall)

    for dispatch_type in ('Tuner', 'Assessor'):
        run(dispatch_type)