# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
import datetime
import json
import os
import subprocess
import sys
import time

import ruamel.yaml as yaml

import validators
from utils import (CLEAR, EXPERIMENT_URL, GREEN, RED, REST_ENDPOINT,
                   STATUS_URL, TRIAL_JOBS_URL, deep_update, dump_yml_content,
                   get_experiment_dir, get_experiment_id,
                   get_experiment_status, get_failed_trial_jobs,
                   get_trial_stats, get_yml_content, parse_max_duration_time,
                   print_experiment_log, print_trial_job_log,
                   wait_for_port_available)

it_variables = {}


def update_training_service_config(config, training_service, config_file_path):
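    # Merge platform-specific settings from config/training_service.yml into the test
    # case config; kubeflow, frameworkcontroller and adl need extra tweaks because their
    # trial command (and, for adl, codeDir) lives inside the platform section.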
    it_ts_config = get_yml_content(os.path.join('config', 'training_service.yml'))

    # hack for kubeflow trial config
    if training_service == 'kubeflow':
        it_ts_config[training_service]['trial']['worker']['command'] = config['trial']['command']
        config['trial'].pop('command')
        if 'gpuNum' in config['trial']:
            config['trial'].pop('gpuNum')

    if training_service == 'frameworkcontroller':
        it_ts_config[training_service]['trial']['taskRoles'][0]['command'] = config['trial']['command']
        config['trial'].pop('command')
        if 'gpuNum' in config['trial']:
            config['trial'].pop('gpuNum')
    
    if training_service == 'adl':
        # hack for adl trial config: codeDir in adl mode refers to a path inside the container
        containerCodeDir = config['trial']['codeDir']
        # map the metric test folders to their container paths
        if config['trial']['codeDir'] == '.':
            containerCodeDir = '/' + config_file_path[:config_file_path.rfind('/')]
        elif config['trial']['codeDir'] == '../naive_trial':
            containerCodeDir = '/test/config/naive_trial'
        elif '../../../' in config['trial']['codeDir']:
            # map example folders to their container paths
            containerCodeDir = config['trial']['codeDir'].replace('../../../', '/')
        it_ts_config[training_service]['trial']['codeDir'] = containerCodeDir
        it_ts_config[training_service]['trial']['command'] = 'cd {0} && {1}'.format(containerCodeDir, config['trial']['command'])

    deep_update(config, it_ts_config['all'])
    deep_update(config, it_ts_config[training_service])


def prepare_config_file(test_case_config, it_config, args):
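    # Build the experiment config for one test case: load its configFile, apply the
    # per-case overrides and training service settings, then dump the result to a
    # temporary '<configFile>.tmp' that is used to launch the experiment.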
    config_path = args.nni_source_dir + test_case_config['configFile']
    test_yml_config = get_yml_content(config_path)

    # apply test case specific config
    if test_case_config.get('config') is not None:
        deep_update(test_yml_config, test_case_config['config'])

    # hack for windows
    if sys.platform == 'win32' and args.ts == 'local':
        test_yml_config['trial']['command'] = test_yml_config['trial']['command'].replace('python3', 'python')

    # apply training service config
    # the user's gpuNum and logCollection settings are overwritten by the config in training_service.yml
    # the kubeflow hack must be applied as the last step
    update_training_service_config(test_yml_config, args.ts, test_case_config['configFile'])

    # generate a temporary config yml file to launch the experiment
    new_config_file = config_path + '.tmp'
    dump_yml_content(new_config_file, test_yml_config)
    print(yaml.dump(test_yml_config, default_flow_style=False), flush=True)

    return new_config_file


def run_test_case(test_case_config, it_config, args):
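    # Launch the experiment for one test case and run its validator; the stop/exit
    # commands and the temporary config cleanup always run, even when the test fails.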
    new_config_file = prepare_config_file(test_case_config, it_config, args)
    # set configFile variable
    it_variables['$configFile'] = new_config_file

    try:
        launch_test(new_config_file, args.ts, test_case_config)
        invoke_validator(test_case_config, args.nni_source_dir, args.ts)
    finally:
        stop_command = get_command(test_case_config, 'stopCommand')
        print('Stop command:', stop_command, flush=True)
        if stop_command:
            subprocess.run(stop_command, shell=True)
        exit_command = get_command(test_case_config, 'onExitCommand')
        print('Exit command:', exit_command, flush=True)
        if exit_command:
            subprocess.run(exit_command, shell=True, check=True)
        # remove tmp config file
        if os.path.exists(new_config_file):
            os.remove(new_config_file)


def invoke_validator(test_case_config, nni_source_dir, training_service):
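    # Instantiate the validator class named in the test case config (if any) and run it
    # against the experiment; on failure, dump the experiment and trial job logs.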
    validator_config = test_case_config.get('validator')
    if validator_config is None or validator_config.get('class') is None:
        return

    validator = validators.__dict__[validator_config.get('class')]()
    kwargs = validator_config.get('kwargs', {})
    print('kwargs:', kwargs)
    experiment_id = get_experiment_id(EXPERIMENT_URL)
    try:
        validator(REST_ENDPOINT, get_experiment_dir(EXPERIMENT_URL), nni_source_dir, **kwargs)
    except:
        print_experiment_log(experiment_id=experiment_id)
        print_trial_job_log(training_service, TRIAL_JOBS_URL)
        raise


def get_max_values(config_file):
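    # Return the parsed maxExecDuration and maxTrialNum from the experiment config file.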
    experiment_config = get_yml_content(config_file)
    return parse_max_duration_time(experiment_config['maxExecDuration']), experiment_config['maxTrialNum']


def get_command(test_case_config, commandKey):
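    # Look up a command (launchCommand, stopCommand or onExitCommand) in the test case
    # config, substitute the collected $-variables, and apply the python3 -> python
    # replacement on Windows; launchCommand is mandatory.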
    command = test_case_config.get(commandKey)
    if commandKey == 'launchCommand':
        assert command is not None
    if command is None:
        return None

    # replace variables
    for k in it_variables:
        command = command.replace(k, it_variables[k])

    # hack for windows, not limited to local training service
    if sys.platform == 'win32':
        command = command.replace('python3', 'python')

    return command


def launch_test(config_file, training_service, test_case_config):
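    # Run the case's launch command, then (unless experimentStatusCheck is disabled)
    # poll the REST API until the experiment finishes, errors out, a trial fails, or
    # maxExecDuration is exceeded; assert that the experiment reaches DONE with
    # maxTrialNum trials succeeded or early stopped.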
    launch_command = get_command(test_case_config, 'launchCommand')
    print('launch command: ', launch_command, flush=True)

    proc = subprocess.run(launch_command, shell=True)

    assert proc.returncode == 0, 'launch command failed with code %d' % proc.returncode

    # store the experiment ID in a variable for use in later commands
    exp_var_name = test_case_config.get('setExperimentIdtoVar')
    if exp_var_name is not None:
        assert exp_var_name.startswith('$')
        it_variables[exp_var_name] = get_experiment_id(EXPERIMENT_URL)
    print('variables:', it_variables)

    max_duration, max_trial_num = get_max_values(config_file)
    print('max_duration:', max_duration, ' max_trial_num:', max_trial_num)

    if not test_case_config.get('experimentStatusCheck'):
        return

    bg_time = time.time()
    print(str(datetime.datetime.now()), ' waiting ...', flush=True)
    try:
        # wait for the RESTful server to be ready
        time.sleep(3)
        experiment_id = get_experiment_id(EXPERIMENT_URL)
        while True:
            waited_time = time.time() - bg_time
            if waited_time > max_duration + 10:
                print('waited: {}, max_duration: {}'.format(waited_time, max_duration))
                break
            status = get_experiment_status(STATUS_URL)
            if status in ['DONE', 'ERROR']:
                print('experiment status:', status)
                break
            num_failed = len(get_failed_trial_jobs(TRIAL_JOBS_URL))
            if num_failed > 0:
                print('failed jobs: ', num_failed)
                break
            time.sleep(1)
    except:
        print_experiment_log(experiment_id=experiment_id)
        raise
    print(str(datetime.datetime.now()), ' waiting done', flush=True)
    if get_experiment_status(STATUS_URL) == 'ERROR':
        print_experiment_log(experiment_id=experiment_id)

    trial_stats = get_trial_stats(TRIAL_JOBS_URL)
    print(json.dumps(trial_stats, indent=4), flush=True)
    if status != 'DONE' or trial_stats['SUCCEEDED'] + trial_stats['EARLY_STOPPED'] < max_trial_num:
        print_experiment_log(experiment_id=experiment_id)
        print_trial_job_log(training_service, TRIAL_JOBS_URL)
        raise AssertionError('Failed to finish in maxExecDuration')


def case_excluded(name, excludes):
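    # Return True if `name` matches an entry of the comma-separated excludes list
    # (substring match in either direction).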
    if name is None:
        return False
    if excludes is not None:
        excludes = excludes.split(',')
        for e in excludes:
            if name in e or e in name:
                return True
    return False


def case_included(name, cases):
    assert cases is not None
    for case in cases.split(','):
        if case in name:
            return True
    return False


def match_platform(test_case_config):
    return sys.platform in test_case_config['platform'].split(' ')


def match_training_service(test_case_config, cur_training_service):
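    # A test case matches when its trainingService field is 'all' or its
    # space-separated list contains the current training service.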
    case_ts = test_case_config['trainingService']
    assert case_ts is not None
    if case_ts == 'all':
        return True
    if cur_training_service in case_ts.split(' '):
        return True
    return False


def run(args):
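    # Iterate over the test cases in the integration test config, apply the
    # include/exclude, platform and training service filters, wait for port 8080 to
    # be free, then run each matching case.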
    it_config = get_yml_content(args.config)

    for test_case_config in it_config['testCases']:
        name = test_case_config['name']
        if case_excluded(name, args.exclude):
            print('{} excluded'.format(name))
            continue
        if args.cases and not case_included(name, args.cases):
            continue

        # fill test case default config
        for k in it_config['defaultTestCaseConfig']:
            if k not in test_case_config:
                test_case_config[k] = it_config['defaultTestCaseConfig'][k]
        print(json.dumps(test_case_config, indent=4))

        if not match_platform(test_case_config):
            print('skipped {}, platform {} does not match [{}]'.format(name, sys.platform, test_case_config['platform']))
            continue

        if not match_training_service(test_case_config, args.ts):
            print('skipped {}, training service {} does not match [{}]'.format(
                name, args.ts, test_case_config['trainingService']))
            continue
        # remote mode needs more time to clean up
        if args.ts == 'remote':
            wait_for_port_available(8080, 180)
        else:
            wait_for_port_available(8080, 30)

        # adl mode needs more time to clean up the PVC
        if args.ts == 'adl' and name == 'nnictl-resume-2':
            time.sleep(30)
        print('## {}Testing: {}{} ##'.format(GREEN, name, CLEAR))
        begin_time = time.time()

        run_test_case(test_case_config, it_config, args)
        print('{}Test {}: TEST PASS IN {} SECONDS{}'.format(GREEN, name, int(time.time()-begin_time), CLEAR), flush=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--nni_source_dir", type=str, default='../')
    parser.add_argument("--cases", type=str, default=None)
    parser.add_argument("--exclude", type=str, default=None)
    parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai',
                                                   'kubeflow', 'frameworkcontroller', 'adl'], default='local')
    args = parser.parse_args()

    run(args)