"lm_eval/tasks/pega_mmlu/default/utils.py" did not exist on "292fdae58d2e5dbdb9ad505763759a3c4b7199e9"
launcher.py 28 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
Deshui Yu's avatar
Deshui Yu committed
3
4
5

import json
import os
6
from pathlib import Path
7
import sys
8
import string
chicm-ms's avatar
chicm-ms committed
9
10
import random
import time
Deshui Yu's avatar
Deshui Yu committed
11
import tempfile
12
import re
13
from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT
liuzhe-lz's avatar
liuzhe-lz committed
14
from nni.experiment.config import ExperimentConfig, convert
15
16
from nni.tools.annotation import expand_annotations, generate_search_space
from nni.tools.package_utils import get_builtin_module_class_name
17
import nni_node  # pylint: disable=import-error, wrong-import-order
Deshui Yu's avatar
Deshui Yu committed
18
from .launcher_utils import validate_all_content
chicm-ms's avatar
chicm-ms committed
19
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
20
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls, set_prefix_url
21
from .config_utils import Config, Experiments
22
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, detect_port, get_user
chicm-ms's avatar
chicm-ms committed
23

J-shang's avatar
J-shang committed
24
from .constants import NNI_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
25
from .command_utils import check_output_command, kill_command
26
from .nnictl_utils import update_experiment
Gems Guo's avatar
Gems Guo committed
27

28
29
# Training service platforms for which create_experiment() keeps the V1 config
# schema instead of converting the config to V2.
k8s_training_services = ['kubeflow', 'frameworkcontroller', 'adl']

30
def get_log_path(experiment_id):
    '''Generate stdout and stderr log paths for an experiment.

    The directory ``NNI_HOME_DIR/<experiment_id>/log`` is created when it does
    not exist yet.

    Returns a ``(stdout_full_path, stderr_full_path)`` tuple.
    '''
    log_dir = os.path.join(NNI_HOME_DIR, experiment_id, 'log')
    os.makedirs(log_dir, exist_ok=True)
    stdout_full_path = os.path.join(log_dir, 'nnictl_stdout.log')
    stderr_full_path = os.path.join(log_dir, 'nnictl_stderr.log')
    return stdout_full_path, stderr_full_path

def print_log_content(config_file_name):
    '''Print the nnictl stdout and stderr logs of an experiment.

    ``config_file_name`` is handed to get_log_path() to locate both log files;
    their content is obtained through check_output_command().
    '''
    stdout_full_path, stderr_full_path = get_log_path(config_file_name)
    print_normal(' Stdout:')
    print(check_output_command(stdout_full_path))
    print('\n\n')
    print_normal(' Stderr:')
    print(check_output_command(stderr_full_path))
45

46
def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None, url_prefix=None):
    '''Run nni manager process.

    Exits with status 1 when ``port`` is in use, when a platform other than
    'local'/'aml' finds ``port + 1`` in use, or when the ``nni_node`` package
    directory cannot be found.

    mode: 'view' starts the server with ``--start_mode resume --readonly true``;
        any other value is passed through as ``--start_mode``.
    foreground: when true the server output is piped back to the caller
        (stdout and stderr merged) instead of being appended to the log files.
    url_prefix: validated with _validate_prefix_path() and registered through
        set_prefix_url() before being passed to the server.

    Returns ``(process, start_time_ms)`` where ``process`` is the Popen object
    and ``start_time_ms`` the launch time in milliseconds since the epoch.
    '''
    if detect_port(port):
        print_error('Port %s is used by another process, please reset the port!\n' \
        'You could use \'nnictl create --help\' to get help information' % port)
        exit(1)

    if (platform not in ['local', 'aml']) and detect_port(int(port) + 1):
        print_error('%s mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
        'You could set another port to start experiment!\n' \
        'You could use \'nnictl create --help\' to get help information' % (platform, (int(port) + 1), (int(port) + 1)))
        exit(1)

    print_normal('Starting restful server...')

    entry_dir = nni_node.__path__[0]
    if (not entry_dir) or (not os.path.exists(entry_dir)):
        print_error('Fail to find nni under python library')
        exit(1)
    entry_file = os.path.join(entry_dir, 'main.js')

    if sys.platform == 'win32':
        node_command = os.path.join(entry_dir, 'node.exe')
    else:
        node_command = os.path.join(entry_dir, 'node')
    cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform, \
            '--experiment_id', experiment_id]
    if mode == 'view':
        cmds += ['--start_mode', 'resume']
        cmds += ['--readonly', 'true']
    else:
        cmds += ['--start_mode', mode]
    if log_dir is not None:
        cmds += ['--log_dir', log_dir]
    if log_level is not None:
        cmds += ['--log_level', log_level]
    if foreground:
        cmds += ['--foreground', 'true']
    if url_prefix:
        _validate_prefix_path(url_prefix)
        set_prefix_url(url_prefix)
        cmds += ['--url_prefix', url_prefix]

    popen_kwargs = {'cwd': entry_dir}
    if sys.platform == 'win32':
        from subprocess import CREATE_NEW_PROCESS_GROUP
        popen_kwargs['creationflags'] = CREATE_NEW_PROCESS_GROUP

    stdout_full_path, stderr_full_path = get_log_path(experiment_id)
    with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
        start_time = time.time()
        time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
        #add time information in the header of log files
        log_header = LOG_HEADER % str(time_now)
        stdout_file.write(log_header)
        stderr_file.write(log_header)
        if foreground:
            # Merge stderr into stdout on every platform: the foreground reader
            # in launch_experiment() only drains stdout, so a separate stderr
            # pipe would never be read and could block the server once full.
            process = Popen(cmds, stdout=PIPE, stderr=STDOUT, **popen_kwargs)
        else:
            process = Popen(cmds, stdout=stdout_file, stderr=stderr_file, **popen_kwargs)
    return process, int(start_time * 1000)
Deshui Yu's avatar
Deshui Yu committed
109

110
def set_trial_config(experiment_config, port, config_file_name):
    '''Set trial configuration.

    PUTs ``experiment_config['trial']`` as ``trial_config`` to the cluster
    metadata URL of the rest server on ``port``.

    Returns True when check_response() accepts the response, otherwise False.
    On failure the response text (if any response was received) is printed and
    appended, pretty-printed, to the experiment's nnictl stderr log.
    '''
    request_data = {'trial_config': experiment_config['trial']}
    response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
    if check_response(response):
        return True
    # The response may be None here, so never dereference it unguarded.
    error_text = response.text if response is not None else None
    print('Error message is {}'.format(error_text))
    _, stderr_full_path = get_log_path(config_file_name)
    # Test against None (as the other setters in this file do) rather than
    # truthiness. NOTE(review): if this is a requests.Response, its truthiness
    # is False for error statuses, which would skip this log write - confirm.
    if response is not None:
        with open(stderr_full_path, 'a+') as fout:
            fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
    return False
124

125
126
def set_adl_config(experiment_config, port, config_file_name):
    '''Set adl configuration.

    Sends an empty ``adl_config``, then the common V1 settings, the NNI manager
    IP and finally the trial config.

    Returns ``(success, error_message)``; ``error_message`` is the response
    text of the failing request when one was received, else None.
    '''
    # hack for supporting v2 config, need refactor
    adl_config_data = {'adl_config': {}}
    response = rest_put(cluster_metadata_url(port), json.dumps(adl_config_data), REST_TIME_OUT)
    err_message = None
    if not response or response.status_code != 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message
    set_V1_common_config(experiment_config, port, config_file_name)
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), None

146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def validate_response(response, config_file_name):
    '''Exit with status 1 unless ``response`` is truthy with status code 200.

    On failure the response text, when a response was received, is appended
    pretty-printed to the experiment's nnictl stderr log and reported through
    print_error(). Returns None on success.
    '''
    if response and response.status_code == 200:
        return
    err_message = None
    if response is not None:
        err_message = response.text
        _, stderr_full_path = get_log_path(config_file_name)
        with open(stderr_full_path, 'a+') as fout:
            fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
    # err_message stays None when no response came back; concatenating it
    # directly would raise TypeError instead of reporting the failure.
    print_error('Error:' + (err_message if err_message is not None else 'no response received'))
    exit(1)

# hack to fix v1 version_check and log_collection bug, need refactor
def set_V1_common_config(experiment_config, port, config_file_name):
    '''Push ``version_check`` and, if configured, ``log_collection`` to the rest server.

    ``version_check`` defaults to True, is turned into ``not debug`` when
    ``debug`` is set, and is overridden by an explicit ``versionCheck``.
    Each request is checked with validate_response(), which exits on failure.
    '''
    version_check = True
    #debug mode should disable version check
    if experiment_config.get('debug') is not None:
        version_check = not experiment_config.get('debug')
    #validate version check
    if experiment_config.get('versionCheck') is not None:
        version_check = experiment_config.get('versionCheck')
    response = rest_put(cluster_metadata_url(port), json.dumps({'version_check': version_check}), REST_TIME_OUT)
    validate_response(response, config_file_name)
    if experiment_config.get('logCollection'):
        data = json.dumps({'log_collection': experiment_config.get('logCollection')})
        response = rest_put(cluster_metadata_url(port), data, REST_TIME_OUT)
        validate_response(response, config_file_name)

173
174
175
176
177
def setNNIManagerIp(experiment_config, port, config_file_name):
    '''Set nniManagerIp.

    Does nothing when ``experiment_config`` has no (or a None) ``nniManagerIp``.

    Returns ``(success, error_message)``; ``error_message`` is the response
    text of a failed request when one was received, else None.
    '''
    if experiment_config.get('nniManagerIp') is None:
        return True, None
    ip_config_dict = {'nni_manager_ip': {'nniManagerIp': experiment_config['nniManagerIp']}}
    response = rest_put(cluster_metadata_url(port), json.dumps(ip_config_dict), REST_TIME_OUT)
    err_message = None
    if not response or response.status_code != 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message
    return True, None

190
def set_kubeflow_config(experiment_config, port, config_file_name):
    '''Set kubeflow configuration.

    Sends ``experiment_config['kubeflowConfig']``, then the common V1 settings,
    the NNI manager IP and finally the trial config.

    Returns ``(success, error_message)``; ``error_message`` is the response
    text of the failing request when one was received, else None.
    '''
    kubeflow_config_data = {'kubeflow_config': experiment_config['kubeflowConfig']}
    response = rest_put(cluster_metadata_url(port), json.dumps(kubeflow_config_data), REST_TIME_OUT)
    err_message = None
    if not response or response.status_code != 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message
    set_V1_common_config(experiment_config, port, config_file_name)
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config (err_message is still None at this point)
    return set_trial_config(experiment_config, port, config_file_name), err_message

210
def set_frameworkcontroller_config(experiment_config, port, config_file_name):
    '''Set frameworkcontroller configuration.

    Sends ``experiment_config['frameworkcontrollerConfig']``, then the common
    V1 settings, the NNI manager IP and finally the trial config.

    Returns ``(success, error_message)``; ``error_message`` is the response
    text of the failing request when one was received, else None.
    '''
    frameworkcontroller_config_data = {
        'frameworkcontroller_config': experiment_config['frameworkcontrollerConfig'],
    }
    response = rest_put(cluster_metadata_url(port), json.dumps(frameworkcontroller_config_data), REST_TIME_OUT)
    err_message = None
    if not response or response.status_code != 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message
    set_V1_common_config(experiment_config, port, config_file_name)
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config (err_message is still None at this point)
    return set_trial_config(experiment_config, port, config_file_name), err_message

230
231
def set_shared_storage(experiment_config, port, config_file_name):
    '''Send the ``sharedStorage`` section, if present, to the rest server.

    Returns ``(success, error_message)``. Without a ``sharedStorage`` key
    nothing is sent and ``(True, None)`` is returned; on a failed request
    ``error_message`` is the response text when one was received, else None.
    '''
    if 'sharedStorage' not in experiment_config:
        return True, None
    data = json.dumps({'shared_storage_config': experiment_config['sharedStorage']})
    response = rest_put(cluster_metadata_url(port), data, REST_TIME_OUT)
    err_message = None
    if not response or response.status_code != 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message
    return True, None

244
def set_experiment_v1(experiment_config, mode, port, config_file_name):
    '''Call startExperiment (rest POST /experiment) with yaml file content.

    Builds the request from a V1 ``experiment_config`` and posts it to the rest
    server on ``port``. ``mode`` is accepted for signature parity with
    set_experiment_v2() and is not used.

    Returns the response when check_response() accepts it, otherwise None; in
    that case the response text (if any) is appended to the experiment's
    nnictl stderr log and printed.
    '''
    def _normalize_gpu_options(section):
        # Warn about the deprecated gpuNum and send an integer gpuIndices as a string.
        if section.get('gpuNum'):
            print_error('gpuNum is deprecated, please use gpuIndices instead.')
        # Test the type only: index 0 is falsy, so a truthiness check would
        # leave the single GPU index 0 unconverted.
        if isinstance(section.get('gpuIndices'), int):
            section['gpuIndices'] = str(section['gpuIndices'])

    request_data = dict()
    request_data['authorName'] = experiment_config['authorName']
    request_data['experimentName'] = experiment_config['experimentName']
    request_data['trialConcurrency'] = experiment_config['trialConcurrency']
    request_data['maxExecDuration'] = experiment_config['maxExecDuration']
    request_data['maxExperimentDuration'] = str(experiment_config['maxExecDuration']) + 's'
    request_data['maxTrialNum'] = experiment_config['maxTrialNum']
    request_data['maxTrialNumber'] = experiment_config['maxTrialNum']
    request_data['searchSpace'] = experiment_config.get('searchSpace')
    request_data['trainingServicePlatform'] = experiment_config.get('trainingServicePlatform')
    # hack for hotfix, fix config.trainingService undefined error, need refactor
    request_data['trainingService'] = {'platform': experiment_config.get('trainingServicePlatform')}
    if experiment_config.get('description'):
        request_data['description'] = experiment_config['description']
    if experiment_config.get('multiPhase'):
        request_data['multiPhase'] = experiment_config.get('multiPhase')
    if experiment_config.get('multiThread'):
        request_data['multiThread'] = experiment_config.get('multiThread')
    if experiment_config.get('nniManagerIp'):
        request_data['nniManagerIp'] = experiment_config.get('nniManagerIp')
    if experiment_config.get('advisor'):
        request_data['advisor'] = experiment_config['advisor']
        _normalize_gpu_options(request_data['advisor'])
    else:
        request_data['tuner'] = experiment_config['tuner']
        _normalize_gpu_options(request_data['tuner'])
        if 'assessor' in experiment_config:
            request_data['assessor'] = experiment_config['assessor']
            if request_data['assessor'].get('gpuNum'):
                print_error('gpuNum is deprecated, please remove it from your config file.')
    #debug mode should disable version check
    if experiment_config.get('debug') is not None:
        request_data['versionCheck'] = not experiment_config.get('debug')
    #validate version check
    if experiment_config.get('versionCheck') is not None:
        request_data['versionCheck'] = experiment_config.get('versionCheck')
    if experiment_config.get('logCollection'):
        request_data['logCollection'] = experiment_config.get('logCollection')
    request_data['clusterMetaData'] = []
    if experiment_config['trainingServicePlatform'] == 'kubeflow':
        request_data['clusterMetaData'].append(
            {'key': 'kubeflow_config', 'value': experiment_config['kubeflowConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'frameworkcontroller':
        request_data['clusterMetaData'].append(
            {'key': 'frameworkcontroller_config', 'value': experiment_config['frameworkcontrollerConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'adl':
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
    if check_response(response):
        return response
    _, stderr_full_path = get_log_path(config_file_name)
    if response is not None:
        with open(stderr_full_path, 'a+') as fout:
            fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
        print_error('Setting experiment error, error message is {}'.format(response.text))
    return None
Deshui Yu's avatar
Deshui Yu committed
314

315
316
317
318
319
320
321
322
323
324
325
326
327
def set_experiment_v2(experiment_config, mode, port, config_file_name):
    '''Post a V2 experiment config to the rest server (rest POST /experiment).

    Returns the response when check_response() accepts it, otherwise None
    after logging and printing the response text, if a response was received.
    '''
    target_url = experiment_url(port)
    payload = json.dumps(experiment_config)
    reply = rest_post(target_url, payload, REST_TIME_OUT, show_error=True)
    if not check_response(reply):
        _, err_log_path = get_log_path(config_file_name)
        if reply is not None:
            pretty_error = json.dumps(json.loads(reply.text), indent=4, sort_keys=True, separators=(',', ':'))
            with open(err_log_path, 'a+') as err_log:
                err_log.write(pretty_error)
            print_error('Setting experiment error, error message is {}'.format(reply.text))
        return None
    return reply

SparkSnail's avatar
SparkSnail committed
328
329
330
331
def set_platform_config(platform, experiment_config, port, config_file_name, rest_process):
    '''Call set_cluster_metadata for specific platform.

    Supported platforms are 'adl', 'kubeflow' and 'frameworkcontroller'; any
    other value raises Exception. After the platform config succeeds the
    shared storage config is sent as well.

    On failure the error is printed, the rest server process is killed and the
    program exits with status 1. If killing the process itself fails, an
    Exception is raised instead.
    '''
    print_normal('Setting {0} config...'.format(platform))
    platform_setters = {
        'adl': set_adl_config,
        'kubeflow': set_kubeflow_config,
        'frameworkcontroller': set_frameworkcontroller_config,
    }
    setter = platform_setters.get(platform)
    if setter is None:
        raise Exception(ERROR_INFO % 'Unsupported platform!')
    config_result, err_msg = setter(experiment_config, port, config_file_name)
    if config_result:
        config_result, err_msg = set_shared_storage(experiment_config, port, config_file_name)
    if config_result:
        print_normal('Successfully set {0} config!'.format(platform))
    else:
        print_error('Failed! Error is: {}'.format(err_msg))
        try:
            kill_command(rest_process.pid)
        except Exception as err:
            # keep the original failure attached for debugging
            raise Exception(ERROR_INFO % 'Rest server stopped!') from err
        exit(1)

353
def launch_experiment(args, experiment_config, mode, experiment_id, config_version):
    '''Follow steps to start rest server and start experiment.

    args: parsed command line arguments; ``port``, ``url_prefix``,
        ``foreground`` and ``debug`` are read.
    experiment_config: experiment config dict (V1 or V2); updated in place
        (search space, annotation code directory, debug flag).
    mode: 'view' disables foreground/debug handling and platform setup; other
        values are forwarded to the rest server as its start mode.
    config_version: 1 for the V1 schema, anything else is treated as V2.

    Exits with status 1 when the rest server or the experiment fails to start.
    In foreground mode the call blocks, echoing the server output until the
    server closes its stdout or the user interrupts with Ctrl-C.
    '''
    # check packages for tuner
    package_name, module_name = None, None
    if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
        package_name = experiment_config['tuner']['builtinTunerName']
        module_name, _ = get_builtin_module_class_name('tuners', package_name)
    elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
        package_name = experiment_config['advisor']['builtinAdvisorName']
        module_name, _ = get_builtin_module_class_name('advisors', package_name)
    if package_name and module_name:
        try:
            stdout_full_path, stderr_full_path = get_log_path(experiment_id)
            with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
                check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file)
        except CalledProcessError:
            print_error('some errors happen when import package %s.' %(package_name))
            print_log_content(experiment_id)
            if package_name in ['SMAC', 'BOHB', 'PPOTuner']:
                print_error(f'The dependencies for {package_name} can be installed through pip install nni[{package_name}]')
            raise
    if config_version == 1:
        log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else NNI_HOME_DIR
    else:
        log_dir = experiment_config['experimentWorkingDirectory'] if experiment_config.get('experimentWorkingDirectory') else NNI_HOME_DIR
    log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else 'info'
    #view experiment mode do not need debug function, when view an experiment, there will be no new logs created
    foreground = False
    if mode != 'view':
        foreground = args.foreground
        if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
            log_level = 'debug'
    # start rest server
    if config_version == 1:
        platform = experiment_config['trainingServicePlatform']
    elif isinstance(experiment_config['trainingService'], list):
        platform = 'hybrid'
    else:
        platform = experiment_config['trainingService']['platform']

    rest_process, start_time = start_rest_server(args.port, platform, \
                                                 mode, experiment_id, foreground, log_dir, log_level, args.url_prefix)
    # save experiment information
    Experiments().add_experiment(experiment_id, args.port, start_time,
                                 platform,
                                 experiment_config.get('experimentName', 'N/A')
                                 , pid=rest_process.pid, logDir=log_dir, prefixUrl=args.url_prefix)
    # Deal with annotation
    if experiment_config.get('useAnnotation'):
        path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
        os.makedirs(path, exist_ok=True)
        path = tempfile.mkdtemp(dir=path)
        if config_version == 1:
            nas_mode = experiment_config['trial'].get('nasMode', 'classic_mode')
            code_dir = expand_annotations(experiment_config['trial']['codeDir'], path, nas_mode=nas_mode)
            experiment_config['trial']['codeDir'] = code_dir
        else:
            code_dir = expand_annotations(experiment_config['trialCodeDirectory'], path)
            experiment_config['trialCodeDirectory'] = code_dir
        search_space = generate_search_space(code_dir)
        experiment_config['searchSpace'] = search_space
        assert search_space, ERROR_INFO % 'Generated search space is empty'
    elif config_version == 1:
        if experiment_config.get('searchSpacePath'):
            search_space = get_json_content(experiment_config.get('searchSpacePath'))
            experiment_config['searchSpace'] = search_space
        else:
            experiment_config['searchSpace'] = ''

    # check rest server
    running, _ = check_rest_server(args.port)
    if running:
        print_normal('Successfully started Restful server!')
    else:
        print_error('Restful server start failed!')
        print_log_content(experiment_id)
        try:
            kill_command(rest_process.pid)
        except Exception as err:
            raise Exception(ERROR_INFO % 'Rest server stopped!') from err
        exit(1)
    if config_version == 1 and mode != 'view':
        # set platform configuration
        set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\
                            experiment_id, rest_process)

    # start a new experiment
    print_normal('Starting experiment...')
    # set debug configuration
    if mode != 'view' and experiment_config.get('debug') is None:
        experiment_config['debug'] = args.debug
    if config_version == 1:
        response = set_experiment_v1(experiment_config, mode, args.port, experiment_id)
    else:
        response = set_experiment_v2(experiment_config, mode, args.port, experiment_id)
    if response:
        if experiment_id is None:
            experiment_id = json.loads(response.text).get('experiment_id')
    else:
        print_error('Start experiment failed!')
        print_log_content(experiment_id)
        try:
            kill_command(rest_process.pid)
        except Exception as err:
            raise Exception(ERROR_INFO % 'Restful server stopped!') from err
        exit(1)
    url_prefix_format = '' if args.url_prefix is None else '/{0}'.format(args.url_prefix)
    if experiment_config.get('nniManagerIp'):
        web_ui_url_list = ['http://{0}:{1}{2}'.format(experiment_config['nniManagerIp'], str(args.port), url_prefix_format)]
    else:
        web_ui_url_list = get_local_urls(args.port, url_prefix_format)
    Experiments().update_experiment(experiment_id, 'webuiUrl', web_ui_url_list)

    print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, '   '.join(web_ui_url_list)))
    if mode != 'view' and args.foreground:
        try:
            # readline() returns b'' once the rest server has closed its
            # stdout; stop there instead of printing empty lines forever.
            for raw_line in iter(rest_process.stdout.readline, b''):
                print(raw_line.strip().decode('utf-8'))
        except KeyboardInterrupt:
            kill_command(rest_process.pid)
            print_normal('Stopping experiment...')
Deshui Yu's avatar
Deshui Yu committed
476

liuzhe-lz's avatar
liuzhe-lz committed
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
def _validate_v1(config, path):
    '''Validate a V1 config with validate_all_content(); print the error and exit(1) on failure.'''
    try:
        validate_all_content(config, path)
    except Exception as err:
        failure = 'Config V1 validation failed: {!r}'.format(err)
        print_error(failure)
        exit(1)

def _validate_v2(config, path):
    '''Validate a V2 config and return its canonical form.

    Builds an ExperimentConfig from ``config`` with the parent directory of
    ``path`` as base path and returns ``conf.json()``.

    On failure the error is printed and the program exits with status 1, the
    same way _validate_v1() does.
    '''
    base_path = Path(path).parent
    try:
        conf = ExperimentConfig(_base_path=base_path, **config)
        return conf.json()
    except Exception as e:
        print_error(f'Config V2 validation failed: {repr(e)}')
        # Do not fall through and return None: the caller would go on to
        # launch the experiment with no config at all.
        exit(1)

492
def _validate_prefix_path(path):
493
494
495
496
    assert not path.startswith('/'), 'URL prefix should not start with "/".'
    parts = path.split('/')
    valid = all(re.match('^[A-Za-z0-9_-]*$', part) for part in parts)
    assert valid, 'URL prefix should only contain letter, number, underscore, and hyphen.'
497

SparkSnail's avatar
SparkSnail committed
498
499
def create_experiment(args):
    '''Start a new experiment.

    Reads the config file at ``args.config`` and validates it. A config with a
    ``trainingServicePlatform`` key is treated as V1: it stays V1 for the
    platforms in ``k8s_training_services`` and is converted to V2 otherwise.
    Any other config is validated as V2. The experiment is then launched under
    a random 8-character id.

    Exits with status 1 when the config file is missing or the launch raises;
    in the latter case a recorded rest server process is killed first.
    '''
    experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        print_error('Please set correct config path!')
        exit(1)
    config_yml = get_yml_content(config_path)

    if 'trainingServicePlatform' in config_yml:
        _validate_v1(config_yml, config_path)
        if config_yml['trainingServicePlatform'] in k8s_training_services:
            schema, config = 1, config_yml
        else:
            schema, config = 2, convert.to_v2(config_yml).json()
    else:
        schema, config = 2, _validate_v2(config_yml, config_path)

    try:
        launch_experiment(args, config, 'new', experiment_id, schema)
    except Exception as exception:
        restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
        if restServerPid:
            kill_command(restServerPid)
        print_error(exception)
        exit(1)
SparkSnail's avatar
SparkSnail committed
531
532
533

def manage_stopped_experiment(args, mode):
    '''Relaunch a stopped experiment in the given mode.

    args: parsed command line arguments; ``id`` and ``port`` are read and
        ``url_prefix`` is overwritten with the stored ``prefixUrl``.
    mode: launch mode passed to launch_experiment(); view_experiment() uses
        'view' and resume_experiment() uses 'resume'.

    Exits with status 1 when no id is given, the id is unknown, the experiment
    is not in status 'STOPPED', or the launch raises an exception.
    '''
    update_experiment()
    experiments_config = Experiments()
    experiments_dict = experiments_config.get_all_experiments()
    # an explicit experiment id is required
    if not args.id:
        print_error('Please set experiment id! \nYou could use \'nnictl {0} id\' to {0} a stopped experiment!\n' \
        'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
        exit(1)
    if experiments_dict.get(args.id) is None:
        print_error('Id %s not exist!' % args.id)
        exit(1)
    if experiments_dict[args.id]['status'] != 'STOPPED':
        print_error('Only stopped experiments can be {0}ed!'.format(mode))
        exit(1)
    experiment_id = args.id
    print_normal('{0} experiment {1}...'.format(mode, experiment_id))
    experiment_config = Config(experiment_id, experiments_dict[args.id]['logDir']).get_config()
    experiments_config.update_experiment(args.id, 'port', args.port)
    args.url_prefix = experiments_dict[args.id]['prefixUrl']
    assert 'trainingService' in experiment_config or 'trainingServicePlatform' in experiment_config
    try:
        if 'trainingServicePlatform' in experiment_config:
            # V1 config: reuse the stored log directory
            experiment_config['logDir'] = experiments_dict[args.id]['logDir']
            launch_experiment(args, experiment_config, mode, experiment_id, 1)
        else:
            # V2 config: the stored log directory is the working directory
            experiment_config['experimentWorkingDirectory'] = experiments_dict[args.id]['logDir']
            launch_experiment(args, experiment_config, mode, experiment_id, 2)
    except Exception as exception:
        restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
        if restServerPid:
            kill_command(restServerPid)
        print_error(exception)
        exit(1)
Deshui Yu's avatar
Deshui Yu committed
569

SparkSnail's avatar
SparkSnail committed
570
571
572
def view_experiment(args):
    '''View a stopped experiment (delegates to manage_stopped_experiment with mode 'view').'''
    manage_stopped_experiment(args, 'view')
Deshui Yu's avatar
Deshui Yu committed
573

SparkSnail's avatar
SparkSnail committed
574
575
def resume_experiment(args):
    '''Resume a stopped experiment (delegates to manage_stopped_experiment with mode 'resume').'''
    manage_stopped_experiment(args, 'resume')