run_tests.py 8.79 KB
Newer Older
chicm-ms's avatar
chicm-ms committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
import os
import argparse
import subprocess
import time
import datetime
import shlex
import traceback
import json
import ruamel.yaml as yaml

from utils import get_experiment_status, get_yml_content, dump_yml_content, get_experiment_id, \
    parse_max_duration_time, get_trial_stats, deep_update, print_trial_job_log, get_failed_trial_jobs, \
    get_experiment_dir, print_experiment_log
18
from utils import GREEN, RED, CLEAR, STATUS_URL, TRIAL_JOBS_URL, EXPERIMENT_URL, REST_ENDPOINT, wait_for_port_available
chicm-ms's avatar
chicm-ms committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import validators

# Substitution table shared by the test cases in this module: get_command()
# replaces every key found in a command string with its value. Keys are
# '$'-prefixed names such as '$configFile' (set by run_test_case) or the name
# given by a test case's 'setExperimentIdtoVar' option (set by launch_test).
it_variables = {}

def update_training_service_config(config, training_service):
    """Merge the integration-test training service settings into ``config``.

    The settings are read from ``config/training_service.yml``. ``config`` is
    modified in place: first the ``all`` section is merged in, then the
    section named after ``training_service``.

    For ``kubeflow`` and ``frameworkcontroller`` the trial command is moved
    out of ``config['trial']`` into the role entry of the training service
    section, and ``gpuNum`` is dropped from ``config['trial']``.
    """
    ts_settings = get_yml_content(os.path.join('config', 'training_service.yml'))

    # hack for kubeflow / frameworkcontroller trial config: the command lives
    # under a role entry instead of directly under 'trial'
    if training_service in ('kubeflow', 'frameworkcontroller'):
        trial_command = config['trial']['command']
        ts_trial = ts_settings[training_service]['trial']
        if training_service == 'kubeflow':
            role = ts_trial['worker']
        else:
            role = ts_trial['taskRoles'][0]
        role['command'] = trial_command

        user_trial = config['trial']
        user_trial.pop('command')
        user_trial.pop('gpuNum', None)

    for section in ('all', training_service):
        deep_update(config, ts_settings[section])

def prepare_config_file(test_case_config, it_config, args):
    """Build the temporary experiment config file for one test case.

    The base config is loaded from ``args.nni_source_dir`` concatenated with
    the test case's ``configFile`` entry. The test case's optional ``config``
    overrides are applied next, then the training service settings for
    ``args.ts``. The result is written next to the original file with a
    ``.tmp`` suffix and echoed to stdout.

    Returns the path of the generated file. ``it_config`` is accepted but not
    used here.
    """
    source_path = args.nni_source_dir + test_case_config['configFile']
    experiment_config = get_yml_content(source_path)

    # test case specific settings win over the base config file
    case_overrides = test_case_config.get('config')
    if case_overrides is not None:
        deep_update(experiment_config, case_overrides)

    # hack for windows: the interpreter is called 'python' there
    windows_local = sys.platform == 'win32' and args.ts == 'local'
    if windows_local:
        trial = experiment_config['trial']
        trial['command'] = trial['command'].replace('python3', 'python')

    # training service settings go last: they overwrite the user's gpuNum and
    # logCollection, and the kubeflow hack must see the final trial command
    update_training_service_config(experiment_config, args.ts)

    # the experiment is launched from this temporary yml file
    tmp_path = '{}.tmp'.format(source_path)
    dump_yml_content(tmp_path, experiment_config)
    rendered = yaml.dump(experiment_config, default_flow_style=False)
    print(rendered, flush=True)

    return tmp_path

def run_test_case(test_case_config, it_config, args):
    """Run a single test case: launch it, validate it, then clean up.

    A temporary config file is generated and published to commands as the
    ``$configFile`` variable. The experiment is launched and the optional
    validator is invoked. Whatever happens, the test case's ``stopCommand``
    (if any) is executed afterwards and the temporary config file is removed.

    Exceptions raised by the launch, the validator or the stop command
    propagate to the caller.
    """
    new_config_file = prepare_config_file(test_case_config, it_config, args)
    # set configFile variable
    it_variables['$configFile'] = new_config_file

    try:
        launch_test(new_config_file, args.ts, test_case_config)
        invoke_validator(test_case_config, args.nni_source_dir, args.ts)
    finally:
        try:
            stop_command = get_command(test_case_config, 'stopCommand')
            print('Stop command:', stop_command, flush=True)
            if stop_command:
                subprocess.run(shlex.split(stop_command))
        finally:
            # remove tmp config file; nested so that a failing stop command
            # cannot leave the temporary file behind
            if os.path.exists(new_config_file):
                os.remove(new_config_file)

83
def invoke_validator(test_case_config, nni_source_dir, training_service):
    """Run the validator configured for a test case, if there is one.

    The test case's ``validator`` entry names a ``class`` looked up in the
    ``validators`` module and optional ``kwargs`` passed on to it. Nothing
    happens when the entry or its ``class`` is missing.

    If the validator raises, the experiment log and the trial job logs are
    printed before the exception is re-raised.
    """
    spec = test_case_config.get('validator')
    if spec is None:
        return
    class_name = spec.get('class')
    if class_name is None:
        return

    validator_cls = vars(validators)[class_name]
    run_validation = validator_cls()
    extra_kwargs = spec.get('kwargs', {})
    print('kwargs:', extra_kwargs)

    # fetched up front so the failure handler can always print the log
    experiment_id = get_experiment_id(EXPERIMENT_URL)
    try:
        run_validation(REST_ENDPOINT, get_experiment_dir(EXPERIMENT_URL), nni_source_dir, **extra_kwargs)
    except BaseException:
        # dump diagnostics for any failure, then let it propagate
        print_experiment_log(experiment_id=experiment_id)
        print_trial_job_log(training_service, TRIAL_JOBS_URL)
        raise
chicm-ms's avatar
chicm-ms committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125

def get_max_values(config_file):
    """Return ``(max_duration, max_trial_num)`` read from an experiment config.

    ``max_duration`` is the file's ``maxExecDuration`` entry passed through
    ``parse_max_duration_time``; ``max_trial_num`` is its ``maxTrialNum``
    entry as stored.
    """
    settings = get_yml_content(config_file)
    max_duration = parse_max_duration_time(settings['maxExecDuration'])
    max_trial_num = settings['maxTrialNum']
    return max_duration, max_trial_num

def get_command(test_case_config, commandKey):
    """Return the command stored under ``commandKey``, ready to execute.

    Every key of ``it_variables`` occurring in the command is replaced by its
    value, and on Windows ``python3`` is replaced by ``python``.

    Returns ``None`` when the test case does not define the command.

    Raises:
        AssertionError: if ``commandKey`` is ``'launchCommand'`` and the test
            case does not define it. This is raised explicitly rather than
            through ``assert`` so the check is still enforced under
            ``python -O``.
    """
    command = test_case_config.get(commandKey)
    if command is None:
        if commandKey == 'launchCommand':
            raise AssertionError("test case does not define 'launchCommand'")
        return None

    # replace variables
    for var_name, var_value in it_variables.items():
        command = command.replace(var_name, var_value)

    # hack for windows, not limited to local training service
    if sys.platform == 'win32':
        command = command.replace('python3', 'python')

    return command

def launch_test(config_file, training_service, test_case_config):
    """Launch a test case's experiment and optionally wait for it to finish.

    The test case's ``launchCommand`` is executed first. If the test case sets
    ``setExperimentIdtoVar``, the experiment ID is stored in ``it_variables``
    under that name. When ``experimentStatusCheck`` is falsy the function
    returns right after launching.

    Otherwise the experiment is polled once per second until its status is
    ``DONE`` or ``ERROR``, a trial job has failed, or more than
    ``maxExecDuration`` + 10 seconds have passed. The run counts as
    successful only if the last polled status is ``DONE`` and the number of
    succeeded plus early-stopped trials reaches ``maxTrialNum``.

    Args:
        config_file: experiment config file providing ``maxExecDuration`` and
            ``maxTrialNum``.
        training_service: training service name, used when printing trial
            job logs.
        test_case_config: the test case's settings.

    Raises:
        AssertionError: if the launch command exits with a non-zero code, if
            ``setExperimentIdtoVar`` does not start with ``'$'``, or if the
            experiment did not finish successfully.
    """
    launch_command = get_command(test_case_config, 'launchCommand')
    print('launch command: ', launch_command, flush=True)

    proc = subprocess.run(shlex.split(launch_command))

    # explicit raise instead of assert: must not be skipped under `python -O`
    if proc.returncode != 0:
        raise AssertionError('launch command failed with code %d' % proc.returncode)

    # set experiment ID into variable
    exp_var_name = test_case_config.get('setExperimentIdtoVar')
    if exp_var_name is not None:
        if not exp_var_name.startswith('$'):
            raise AssertionError("setExperimentIdtoVar must start with '$'")
        it_variables[exp_var_name] = get_experiment_id(EXPERIMENT_URL)
    print('variables:', it_variables)

    max_duration, max_trial_num = get_max_values(config_file)
    print('max_duration:', max_duration, ' max_trial_num:', max_trial_num)

    if not test_case_config.get('experimentStatusCheck'):
        return

    bg_time = time.time()
    print(str(datetime.datetime.now()), ' waiting ...', flush=True)

    # Bound before the try block so that neither the failure handler nor the
    # final check can hit an unbound name when polling fails or times out
    # before the first status has been read.
    experiment_id = None
    status = None
    try:
        # wait restful server to be ready
        time.sleep(3)
        experiment_id = get_experiment_id(EXPERIMENT_URL)
        while True:
            waited_time = time.time() - bg_time
            if waited_time > max_duration + 10:
                print('waited: {}, max_duration: {}'.format(waited_time, max_duration))
                break
            status = get_experiment_status(STATUS_URL)
            if status in ['DONE', 'ERROR']:
                print('experiment status:', status)
                break
            num_failed = len(get_failed_trial_jobs(TRIAL_JOBS_URL))
            if num_failed > 0:
                print('failed jobs: ', num_failed)
                break
            time.sleep(1)
    except BaseException:
        # Deliberately broad: dump the log for any failure, then re-raise.
        # The log can only be printed once the experiment ID is known.
        if experiment_id is not None:
            print_experiment_log(experiment_id=experiment_id)
        raise

    print(str(datetime.datetime.now()), ' waiting done', flush=True)
    if get_experiment_status(STATUS_URL) == 'ERROR':
        print_experiment_log(experiment_id=experiment_id)

    trial_stats = get_trial_stats(TRIAL_JOBS_URL)
    print(json.dumps(trial_stats, indent=4), flush=True)
    if status != 'DONE' or trial_stats['SUCCEEDED'] + trial_stats['EARLY_STOPPED'] < max_trial_num:
        print_experiment_log(experiment_id=experiment_id)
        print_trial_job_log(training_service, TRIAL_JOBS_URL)
        raise AssertionError('Failed to finish in maxExecDuration')

def case_excluded(name, excludes):
    """Tell whether the test case ``name`` is hit by the exclude list.

    ``excludes`` is a comma-separated string of patterns, or ``None``. A case
    is excluded when its name contains a pattern or a pattern contains its
    name. Empty patterns (from an empty string or a stray comma) are ignored;
    otherwise they would match every name.

    Returns ``False`` when ``name`` or ``excludes`` is ``None``.
    """
    if name is None or excludes is None:
        return False
    patterns = [pattern for pattern in excludes.split(',') if pattern]
    return any(name in pattern or pattern in name for pattern in patterns)

def case_included(name, cases):
    """Tell whether the test case ``name`` is selected by ``cases``.

    ``cases`` is a comma-separated string of patterns; the case is selected
    when its name contains at least one of them. Empty patterns (from a stray
    comma) are ignored; otherwise they would select every case.

    Raises:
        AssertionError: if ``cases`` is ``None``.
    """
    if cases is None:
        raise AssertionError('cases must not be None')
    return any(pattern in name for pattern in cases.split(',') if pattern)

def match_platform(test_case_config):
    """Tell whether the test case lists the current ``sys.platform``.

    The test case's ``platform`` entry is a string of platform names
    separated by single spaces.
    """
    supported_platforms = test_case_config['platform'].split(' ')
    for platform_name in supported_platforms:
        if platform_name == sys.platform:
            return True
    return False

def run(args):
    """Run every selected test case of the integration test config.

    The config is loaded from ``args.config``. A test case is skipped when it
    is hit by ``args.exclude``, when ``args.cases`` is given and does not
    select it, or when its ``platform`` entry does not list the current
    platform. Settings a test case does not define are taken from the
    config's ``defaultTestCaseConfig``.

    Test cases run one after another; an exception raised by a test case is
    not caught here.
    """
    suite = get_yml_content(args.config)

    for case in suite['testCases']:
        case_name = case['name']

        if case_excluded(case_name, args.exclude):
            print('{} excluded'.format(case_name))
            continue
        selected = not args.cases or case_included(case_name, args.cases)
        if not selected:
            continue

        # fill test case default config
        defaults = suite['defaultTestCaseConfig']
        for key in defaults:
            if key not in case:
                case[key] = defaults[key]
        print(json.dumps(case, indent=4))

        if not match_platform(case):
            print('skipped {}, platform {} not match [{}]'.format(case_name, sys.platform, case['platform']))
            continue

        # wait until port 8080 is available before starting this case
        wait_for_port_available(8080, 30)
        print('{}Testing: {}{}'.format(GREEN, case_name, CLEAR))
        started = time.time()

        run_test_case(case, suite, args)

        elapsed = int(time.time() - started)
        print('{}Test {}: TEST PASS IN {} SECONDS{}'.format(GREEN, case_name, elapsed, CLEAR), flush=True)


if __name__ == '__main__':
    # Script entry point: parse the command line and run the test cases.
    training_services = ['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller']

    parser = argparse.ArgumentParser()
    # integration test config file listing the test cases (mandatory)
    parser.add_argument('--config', type=str, required=True)
    # prefix put in front of each test case's 'configFile' entry
    parser.add_argument('--nni_source_dir', type=str, default='../')
    # comma-separated name patterns of the cases to run / to leave out
    parser.add_argument('--cases', type=str, default=None)
    parser.add_argument('--exclude', type=str, default=None)
    # training service whose settings are merged into each experiment config
    parser.add_argument('--ts', type=str, choices=training_services, default='local')
    args = parser.parse_args()

    run(args)