# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import os
import argparse
import subprocess
import time
import datetime
import shlex
import traceback
import json
import ruamel.yaml as yaml

from utils import get_experiment_status, get_yml_content, dump_yml_content, get_experiment_id, \
    parse_max_duration_time, get_trial_stats, deep_update, print_trial_job_log, get_failed_trial_jobs, \
    get_experiment_dir, print_experiment_log
from utils import GREEN, RED, CLEAR, STATUS_URL, TRIAL_JOBS_URL, EXPERIMENT_URL, REST_ENDPOINT, detect_port
import validators

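# variables substituted into test case commands, e.g. '$configFile' and any
# experiment IDs captured via 'setExperimentIdtoVar'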
it_variables = {}

def update_training_service_config(config, training_service):
    it_ts_config = get_yml_content(os.path.join('config', 'training_service.yml'))

    # hack for kubeflow trial config
    if training_service == 'kubeflow':
        it_ts_config[training_service]['trial']['worker']['command'] = config['trial']['command']
        config['trial'].pop('command')
        if 'gpuNum' in config['trial']:
            config['trial'].pop('gpuNum')

    if training_service == 'frameworkcontroller':
        it_ts_config[training_service]['trial']['taskRoles'][0]['command'] = config['trial']['command']
        config['trial'].pop('command')
        if 'gpuNum' in config['trial']:
            config['trial'].pop('gpuNum')

    deep_update(config, it_ts_config['all'])
    deep_update(config, it_ts_config[training_service])

def prepare_config_file(test_case_config, it_config, args):
    config_path = args.nni_source_dir + test_case_config['configFile']
    test_yml_config = get_yml_content(config_path)

    # apply test case specific config
    if test_case_config.get('config') is not None:
        deep_update(test_yml_config, test_case_config['config'])

    # hack for windows
    if sys.platform == 'win32' and args.ts == 'local':
        test_yml_config['trial']['command'] = test_yml_config['trial']['command'].replace('python3', 'python')

    # apply training service config
    # the user's gpuNum and logCollection settings are overwritten by the config in training_service.yml
    # the kubeflow hack should be applied as the last step
    update_training_service_config(test_yml_config, args.ts)

    # generate temporary config yml file to launch experiment
    new_config_file = config_path + '.tmp'
    dump_yml_content(new_config_file, test_yml_config)
    print(yaml.dump(test_yml_config, default_flow_style=False), flush=True)

    return new_config_file

def run_test_case(test_case_config, it_config, args):
    new_config_file = prepare_config_file(test_case_config, it_config, args)
    # set configFile variable
    it_variables['$configFile'] = new_config_file

    try:
        launch_test(new_config_file, args.ts, test_case_config)
        invoke_validator(test_case_config, args.nni_source_dir)
    finally:
        stop_command = get_command(test_case_config, 'stopCommand')
        print('Stop command:', stop_command, flush=True)
        if stop_command:
            subprocess.run(shlex.split(stop_command))
        # remove tmp config file
        if os.path.exists(new_config_file):
            os.remove(new_config_file)

def invoke_validator(test_case_config, nni_source_dir):
    validator_config = test_case_config.get('validator')
    if validator_config is None or validator_config.get('class') is None:
        return

    validator = validators.__dict__[validator_config.get('class')]()
    kwargs = validator_config.get('kwargs', {})
    print('kwargs:', kwargs)
    validator(REST_ENDPOINT, get_experiment_dir(EXPERIMENT_URL), nni_source_dir, **kwargs)

def get_max_values(config_file):
    experiment_config = get_yml_content(config_file)
    return parse_max_duration_time(experiment_config['maxExecDuration']), experiment_config['maxTrialNum']

def get_command(test_case_config, commandKey):
    command = test_case_config.get(commandKey)
    if commandKey == 'launchCommand':
        assert command is not None, 'launchCommand is required for every test case'
    if command is None:
        return None

    # replace variables
    for k in it_variables:
        command = command.replace(k, it_variables[k])

    # hack for windows, not limited to local training service
    if sys.platform == 'win32':
        command = command.replace('python3', 'python')

    return command

def launch_test(config_file, training_service, test_case_config):
    launch_command = get_command(test_case_config, 'launchCommand')
    print('Launch command:', launch_command, flush=True)

    proc = subprocess.run(shlex.split(launch_command))

    assert proc.returncode == 0, '`nnictl create` failed with code %d' % proc.returncode

    # set experiment ID into variable
    exp_var_name = test_case_config.get('setExperimentIdtoVar')
    if exp_var_name is not None:
        assert exp_var_name.startswith('$')
        it_variables[exp_var_name] = get_experiment_id(EXPERIMENT_URL)
    print('variables:', it_variables)

    max_duration, max_trial_num = get_max_values(config_file)
    print('max_duration:', max_duration, ' max_trial_num:', max_trial_num)

    if not test_case_config.get('experimentStatusCheck'):
        return

    bg_time = time.time()
    print(str(datetime.datetime.now()), ' waiting ...', flush=True)
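    # poll every 3 seconds until the experiment reaches a final state, a trial
    # job fails, or maxExecDuration (plus a 10-second grace period) elapses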
    while True:
        time.sleep(3)
        waited_time = time.time() - bg_time
        if waited_time > max_duration + 10:
            print('waited: {}, max_duration: {}'.format(waited_time, max_duration))
            break
        status = get_experiment_status(STATUS_URL)
        if status in ['DONE', 'ERROR']:
            print('experiment status:', status)
            break
        num_failed = len(get_failed_trial_jobs(TRIAL_JOBS_URL))
        if num_failed > 0:
            print('failed jobs: ', num_failed)
            break

    print(str(datetime.datetime.now()), ' waiting done', flush=True)
    # refresh the status: the loop may have exited on timeout or on failed trial
    # jobs, so the value polled inside the loop can be stale (or unset)
    status = get_experiment_status(STATUS_URL)
    if status == 'ERROR':
        print_experiment_log(EXPERIMENT_URL)

    trial_stats = get_trial_stats(TRIAL_JOBS_URL)
    print(json.dumps(trial_stats, indent=4), flush=True)
    if status != 'DONE' or trial_stats['SUCCEEDED'] + trial_stats['EARLY_STOPPED'] < max_trial_num:
        print_trial_job_log(training_service, TRIAL_JOBS_URL)
        raise AssertionError('Experiment failed to complete all trials within maxExecDuration')

def case_excluded(name, excludes):
    if name is None:
        return False
    if excludes is not None:
        excludes = excludes.split(',')
        for e in excludes:
            if name in e or e in name:
                return True
    return False

def case_included(name, cases):
    assert cases is not None
    for case in cases.split(','):
        if case in name:
            return True
    return False

def wait_for_port_available(port, timeout):
    begin_time = time.time()
    while True:
        if not detect_port(port):
            return
        if time.time() - begin_time > timeout:
            msg = 'port {} did not become available within {} seconds.'.format(port, timeout)
            raise RuntimeError(msg)
        time.sleep(5)

def match_platform(test_case_config):
    return sys.platform in test_case_config['platform'].split(' ')

def run(args):
    it_config = get_yml_content(args.config)

    for test_case_config in it_config['testCases']:
        name = test_case_config['name']
        if case_excluded(name, args.exclude):
            print('{} excluded'.format(name))
            continue
        if args.cases and not case_included(name, args.cases):
            continue

        # fill test case default config
        for k in it_config['defaultTestCaseConfig']:
            if k not in test_case_config:
                test_case_config[k] = it_config['defaultTestCaseConfig'][k]
        print(json.dumps(test_case_config, indent=4))

        if not match_platform(test_case_config):
            print('skipped {}: platform {} does not match [{}]'.format(name, sys.platform, test_case_config['platform']))
            continue

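        # 8080 is NNI's default REST/web UI port; make sure the previous
        # experiment has released it before launching the next one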
        wait_for_port_available(8080, 30)
        print('{}Testing: {}{}'.format(GREEN, name, CLEAR))
        begin_time = time.time()

        run_test_case(test_case_config, it_config, args)
        print('{}Test {}: TEST PASS IN {} SECONDS{}'.format(GREEN, name, int(time.time()-begin_time), CLEAR), flush=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--nni_source_dir", type=str, default='../')
    parser.add_argument("--cases", type=str, default=None)
    parser.add_argument("--exclude", type=str, default=None)
    parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'], default='local')
    args = parser.parse_args()

    run(args)