run_tests.py 9.49 KB
Newer Older
chicm-ms's avatar
chicm-ms committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import os
import argparse
import subprocess
import time
import datetime
import shlex
import traceback
import json
import ruamel.yaml as yaml

from utils import get_experiment_status, get_yml_content, dump_yml_content, get_experiment_id, \
    parse_max_duration_time, get_trial_stats, deep_update, print_trial_job_log, get_failed_trial_jobs, \
    get_experiment_dir, print_experiment_log
18
from utils import GREEN, RED, CLEAR, STATUS_URL, TRIAL_JOBS_URL, EXPERIMENT_URL, REST_ENDPOINT, wait_for_port_available
chicm-ms's avatar
chicm-ms committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import validators

it_variables = {}

def update_training_service_config(config, training_service):
    """Merge the integration-test training service settings into ``config``.

    The settings are read from ``config/training_service.yml`` (relative to
    the current working directory). ``config`` is updated in place: first
    with the ``all`` section, then with the section named after
    ``training_service``.

    For ``kubeflow`` and ``frameworkcontroller`` the trial command is first
    moved out of ``config['trial']`` into the training-service-specific trial
    section (so it is merged back in the layout that service expects), and
    any ``gpuNum`` entry is dropped from ``config['trial']``.

    Args:
        config: Experiment configuration mapping; mutated in place.
        training_service: Name of the training service section to apply.
    """
    ts_settings = get_yml_content(os.path.join('config', 'training_service.yml'))

    # hack for kubeflow / frameworkcontroller trial config
    if training_service in ('kubeflow', 'frameworkcontroller'):
        trial_section = config['trial']
        command = trial_section['command']
        ts_trial = ts_settings[training_service]['trial']
        if training_service == 'kubeflow':
            command_holder = ts_trial['worker']
        else:
            command_holder = ts_trial['taskRoles'][0]
        command_holder['command'] = command
        trial_section.pop('command')
        trial_section.pop('gpuNum', None)

    for section in ('all', training_service):
        deep_update(config, ts_settings[section])

def prepare_config_file(test_case_config, it_config, args):
    """Generate the temporary experiment config file for one test case.

    The base config is loaded from ``args.nni_source_dir`` concatenated with
    the test case's ``configFile``. Test-case overrides and the training
    service settings are applied on top, and the result is written next to
    the original file with a ``.tmp`` suffix. The final config is also
    printed to stdout.

    Args:
        test_case_config: Test case mapping; reads ``configFile`` and the
            optional ``config`` overrides.
        it_config: Integration test config; not used by this function.
        args: Parsed command-line arguments; reads ``nni_source_dir`` and ``ts``.

    Returns:
        Path of the generated temporary config file.
    """
    source_config_path = args.nni_source_dir + test_case_config['configFile']
    experiment_config = get_yml_content(source_config_path)

    # apply test case specific config
    overrides = test_case_config.get('config')
    if overrides is not None:
        deep_update(experiment_config, overrides)

    # hack for windows
    if sys.platform == 'win32' and args.ts == 'local':
        trial_section = experiment_config['trial']
        trial_section['command'] = trial_section['command'].replace('python3', 'python')

    # apply training service config last: it overwrites user settings with the
    # ones from training_service.yml and performs the kubeflow command hack
    update_training_service_config(experiment_config, args.ts)

    # generate temporary config yml file to launch experiment
    tmp_config_path = source_config_path + '.tmp'
    dump_yml_content(tmp_config_path, experiment_config)
    print(yaml.dump(experiment_config, default_flow_style=False), flush=True)
    return tmp_config_path

def run_test_case(test_case_config, it_config, args):
    """Run one test case: launch it, validate it, then clean up.

    A temporary config file is generated and registered as the
    ``$configFile`` variable. After the launch/validation step (whether it
    succeeded or raised) the optional ``stopCommand`` and ``onExitCommand``
    of the test case are executed, and the temporary config file is removed.

    Args:
        test_case_config: Test case mapping.
        it_config: Integration test config, forwarded to ``prepare_config_file``.
        args: Parsed command-line arguments; reads ``ts`` and ``nni_source_dir``.

    Raises:
        subprocess.CalledProcessError: If the exit command returns non-zero.
        Exception: Whatever the launch or validation step raises.
    """
    new_config_file = prepare_config_file(test_case_config, it_config, args)
    # set configFile variable
    it_variables['$configFile'] = new_config_file

    try:
        try:
            launch_test(new_config_file, args.ts, test_case_config)
            invoke_validator(test_case_config, args.nni_source_dir, args.ts)
        finally:
            # the stop command is best effort: its exit code is not checked
            stop_command = get_command(test_case_config, 'stopCommand')
            print('Stop command:', stop_command, flush=True)
            if stop_command:
                subprocess.run(shlex.split(stop_command))

            # a failing exit command fails the test case
            exit_command = get_command(test_case_config, 'onExitCommand')
            print('Exit command:', exit_command, flush=True)
            if exit_command:
                subprocess.run(shlex.split(exit_command), check=True)
    finally:
        # remove tmp config file, even when a cleanup command above raised
        if os.path.exists(new_config_file):
            os.remove(new_config_file)

87
def invoke_validator(test_case_config, nni_source_dir, training_service):
    """Run the validator configured for a test case, if any.

    The test case's ``validator`` entry names a class in the ``validators``
    module (``class``) and optional keyword arguments (``kwargs``). The class
    is instantiated without arguments and the instance is called with the
    REST endpoint, the experiment directory, ``nni_source_dir`` and the
    keyword arguments. Nothing happens when no validator class is configured.

    Args:
        test_case_config: Test case mapping.
        nni_source_dir: Source directory passed through to the validator.
        training_service: Training service name, used when printing trial
            job logs after a failure.

    Raises:
        KeyError: If the configured class does not exist in ``validators``.
        Exception: Whatever the validator raises; experiment and trial job
            logs are printed before it is re-raised.
    """
    validator_config = test_case_config.get('validator')
    if validator_config is None or validator_config.get('class') is None:
        return

    validator = validators.__dict__[validator_config.get('class')]()
    # a "kwargs:" key with no value yields None, which cannot be **-expanded
    kwargs = validator_config.get('kwargs') or {}
    print('kwargs:', kwargs)

    experiment_id = get_experiment_id(EXPERIMENT_URL)
    try:
        validator(REST_ENDPOINT, get_experiment_dir(EXPERIMENT_URL), nni_source_dir, **kwargs)
    except Exception:
        # dump logs to help diagnose the failure, then propagate it
        print_experiment_log(experiment_id=experiment_id)
        print_trial_job_log(training_service, TRIAL_JOBS_URL)
        raise
chicm-ms's avatar
chicm-ms committed
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129

def get_max_values(config_file):
    """Read the experiment limits from an experiment config file.

    Args:
        config_file: Path of the experiment YAML config.

    Returns:
        A ``(max_duration, max_trial_num)`` tuple: the parsed
        ``maxExecDuration`` (presumably in seconds, as produced by
        ``parse_max_duration_time`` -- confirm against utils) and the raw
        ``maxTrialNum`` value.
    """
    config = get_yml_content(config_file)
    max_duration = parse_max_duration_time(config['maxExecDuration'])
    max_trial_num = config['maxTrialNum']
    return max_duration, max_trial_num

def get_command(test_case_config, commandKey):
    """Return the command stored under ``commandKey`` with variables expanded.

    Every name registered in ``it_variables`` that occurs in the command is
    replaced by its value. On Windows, ``python3`` is replaced by ``python``.

    Args:
        test_case_config: Test case mapping.
        commandKey: Key of the command, e.g. ``launchCommand``.

    Returns:
        The expanded command string, or ``None`` when the test case does not
        define ``commandKey``.

    Raises:
        AssertionError: If ``commandKey`` is ``launchCommand`` and the test
            case does not define it.
    """
    command = test_case_config.get(commandKey)
    if command is None:
        if commandKey == 'launchCommand':
            # raised explicitly so the check is not stripped under ``python -O``
            raise AssertionError('test case config must define launchCommand')
        return None

    # replace variables, longest names first so that a variable whose name is
    # a prefix of another one (e.g. $exp vs. $expId) cannot clobber it
    for var_name in sorted(it_variables, key=len, reverse=True):
        command = command.replace(var_name, it_variables[var_name])

    # hack for windows, not limited to local training service
    if sys.platform == 'win32':
        command = command.replace('python3', 'python')

    return command

def launch_test(config_file, training_service, test_case_config):
    """Launch the experiment of a test case and optionally wait for it.

    The test case's ``launchCommand`` is executed and must exit with code 0.
    If ``setExperimentIdtoVar`` is set, the current experiment ID is stored
    in ``it_variables`` under that name. When ``experimentStatusCheck`` is
    truthy, the experiment is polled once per second until it reports
    ``DONE`` or ``ERROR``, a trial job fails, or ``maxExecDuration`` plus a
    10 second margin has elapsed; afterwards the outcome is verified.

    Args:
        config_file: Path of the generated experiment config file.
        training_service: Training service name, used when printing trial
            job logs after a failure.
        test_case_config: Test case mapping.

    Raises:
        AssertionError: If the launch command fails, the variable name does
            not start with ``$``, or the experiment did not end in ``DONE``
            with at least ``maxTrialNum`` succeeded/early-stopped trials.
    """
    launch_command = get_command(test_case_config, 'launchCommand')
    print('launch command: ', launch_command, flush=True)

    proc = subprocess.run(shlex.split(launch_command))

    # raised explicitly so the check is not stripped under ``python -O``
    if proc.returncode != 0:
        raise AssertionError('launch command failed with code %d' % proc.returncode)

    # set experiment ID into variable
    exp_var_name = test_case_config.get('setExperimentIdtoVar')
    if exp_var_name is not None:
        if not exp_var_name.startswith('$'):
            raise AssertionError('setExperimentIdtoVar must start with $')
        it_variables[exp_var_name] = get_experiment_id(EXPERIMENT_URL)
    print('variables:', it_variables)

    max_duration, max_trial_num = get_max_values(config_file)
    print('max_duration:', max_duration, ' max_trial_num:', max_trial_num)

    if not test_case_config.get('experimentStatusCheck'):
        return

    bg_time = time.time()
    print(str(datetime.datetime.now()), ' waiting ...', flush=True)
    # pre-bind so the error handler and the final check never hit an unbound
    # name when a failure happens before the first successful query
    experiment_id = None
    status = None
    try:
        # wait restful server to be ready
        time.sleep(3)
        experiment_id = get_experiment_id(EXPERIMENT_URL)
        while True:
            waited_time = time.time() - bg_time
            if waited_time > max_duration + 10:
                print('waited: {}, max_duration: {}'.format(waited_time, max_duration))
                break
            status = get_experiment_status(STATUS_URL)
            if status in ['DONE', 'ERROR']:
                print('experiment status:', status)
                break
            num_failed = len(get_failed_trial_jobs(TRIAL_JOBS_URL))
            if num_failed > 0:
                print('failed jobs: ', num_failed)
                break
            time.sleep(1)
    except Exception:
        # the log can only be located once the experiment ID is known
        if experiment_id is not None:
            print_experiment_log(experiment_id=experiment_id)
        raise
    print(str(datetime.datetime.now()), ' waiting done', flush=True)
    if get_experiment_status(STATUS_URL) == 'ERROR':
        print_experiment_log(experiment_id=experiment_id)

    trial_stats = get_trial_stats(TRIAL_JOBS_URL)
    print(json.dumps(trial_stats, indent=4), flush=True)
    if status != 'DONE' or trial_stats['SUCCEEDED'] + trial_stats['EARLY_STOPPED'] < max_trial_num:
        print_experiment_log(experiment_id=experiment_id)
        print_trial_job_log(training_service, TRIAL_JOBS_URL)
        raise AssertionError('Failed to finish in maxExecDuration')

def case_excluded(name, excludes):
    """Tell whether the test case ``name`` matches the exclusion list.

    A case is excluded when any non-empty entry of ``excludes`` contains
    ``name`` or is contained in ``name``. Empty entries (from an empty
    string or a stray comma) are ignored; otherwise they would match every
    name, because the empty string is a substring of any string.

    Args:
        name: Test case name, or ``None``.
        excludes: Comma-separated exclusion entries, or ``None``.

    Returns:
        ``True`` if the case is excluded, ``False`` otherwise (including when
        ``name`` or ``excludes`` is ``None``).
    """
    if name is None or excludes is None:
        return False
    return any(
        entry and (name in entry or entry in name)
        for entry in excludes.split(',')
    )

def case_included(name, cases):
    """Tell whether the test case ``name`` is selected by ``cases``.

    A case is selected when any non-empty entry of ``cases`` is a substring
    of ``name``. Empty entries (from a stray comma) are ignored; otherwise
    they would select every test case.

    Args:
        name: Test case name.
        cases: Comma-separated selection entries; must not be ``None``.

    Returns:
        ``True`` if the case is selected, ``False`` otherwise.

    Raises:
        AssertionError: If ``cases`` is ``None``.
    """
    if cases is None:
        # raised explicitly so the check is not stripped under ``python -O``
        raise AssertionError('cases must not be None')
    return any(entry and entry in name for entry in cases.split(','))

def match_platform(test_case_config):
    """Tell whether the test case supports the current platform.

    The ``platform`` entry of the test case is a space-separated list of
    ``sys.platform`` values.
    """
    supported_platforms = test_case_config['platform'].split(' ')
    return sys.platform in supported_platforms

chicm-ms's avatar
chicm-ms committed
199
200
201
202
203
204
205
206
207
def match_training_service(test_case_config, cur_training_service):
    """Tell whether the test case applies to the current training service.

    The ``trainingService`` entry of the test case is either ``all`` or a
    space-separated list of training service names.

    Args:
        test_case_config: Test case mapping; ``trainingService`` must be set.
        cur_training_service: Name of the training service under test.

    Returns:
        ``True`` if the test case should run for ``cur_training_service``.
    """
    case_ts = test_case_config['trainingService']
    assert case_ts is not None
    return case_ts == 'all' or cur_training_service in case_ts.split(' ')

chicm-ms's avatar
chicm-ms committed
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
def run(args):
    """Run every selected test case of the integration test config.

    Test cases are read from the ``testCases`` list of the YAML file at
    ``args.config``. A case is skipped when it matches ``args.exclude``, is
    not selected by ``args.cases`` (if given), or does not match the current
    platform or training service. Keys missing from a case are filled from
    ``defaultTestCaseConfig`` before the platform/training-service checks.

    Args:
        args: Parsed command-line arguments; reads ``config``, ``exclude``,
            ``cases`` and ``ts`` (and is forwarded to ``run_test_case``).

    Raises:
        Exception: Whatever a failing test case raises; remaining cases are
            not run.
    """
    it_config = get_yml_content(args.config)

    for test_case_config in it_config['testCases']:
        name = test_case_config['name']
        if case_excluded(name, args.exclude):
            print('{} excluded'.format(name))
            continue
        if args.cases and not case_included(name, args.cases):
            continue

        # fill test case default config
        for key, default_value in it_config['defaultTestCaseConfig'].items():
            test_case_config.setdefault(key, default_value)
        print(json.dumps(test_case_config, indent=4))

        if not match_platform(test_case_config):
            print('skipped {}, platform {} not match [{}]'.format(name, sys.platform, test_case_config['platform']))
            continue

        if not match_training_service(test_case_config, args.ts):
            print('skipped {}, training service {} not match [{}]'.format(name, args.ts, test_case_config['trainingService']))
            continue

        # presumably waits (up to 30 s) until port 8080 is free again after
        # the previous experiment -- confirm against utils.wait_for_port_available
        wait_for_port_available(8080, 30)
        print('{}Testing: {}{}'.format(GREEN, name, CLEAR))
        begin_time = time.time()

        run_test_case(test_case_config, it_config, args)
        print('{}Test {}: TEST PASS IN {} SECONDS{}'.format(GREEN, name, int(time.time()-begin_time), CLEAR), flush=True)


if __name__ == '__main__':
    # Command-line entry point: parse the options and run the selected cases.
    supported_training_services = ['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller']

    parser = argparse.ArgumentParser()
    # path of the integration test YAML config (mandatory)
    parser.add_argument('--config', type=str, required=True)
    # prefix prepended to each test case's configFile
    parser.add_argument('--nni_source_dir', type=str, default='../')
    # comma-separated name filters: cases to run / cases to skip
    parser.add_argument('--cases', type=str, default=None)
    parser.add_argument('--exclude', type=str, default=None)
    # training service the test cases are run against
    parser.add_argument('--ts', type=str, choices=supported_training_services, default='local')
    args = parser.parse_args()

    run(args)