"lightx2v.egg-info/PKG-INFO" did not exist on "f062265aefbcabf849c9aaee3a14a8efc6385a0d"
launcher.py 30.5 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
Deshui Yu's avatar
Deshui Yu committed
3
4
5

import json
import os
6
import sys
7
import string
chicm-ms's avatar
chicm-ms committed
8
9
import random
import time
Deshui Yu's avatar
Deshui Yu committed
10
import tempfile
11
from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT
12
13
from nni.tools.annotation import expand_annotations, generate_search_space
from nni.tools.package_utils import get_builtin_module_class_name
14
import nni_node
Deshui Yu's avatar
Deshui Yu committed
15
from .launcher_utils import validate_all_content
chicm-ms's avatar
chicm-ms committed
16
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
SparkSnail's avatar
SparkSnail committed
17
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
18
from .config_utils import Config, Experiments
chicm-ms's avatar
chicm-ms committed
19
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \
chicm-ms's avatar
chicm-ms committed
20
21
22
                          detect_port, get_user

from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, INSTALLABLE_PACKAGE_META
23
from .command_utils import check_output_command, kill_command
24
from .nnictl_utils import update_experiment
Gems Guo's avatar
Gems Guo committed
25

26
def get_log_path(experiment_id):
    '''Return the (stdout, stderr) log file paths of an experiment.

    The ``log`` directory under ``NNICTL_HOME_DIR/<experiment_id>`` is
    created if it does not exist yet.
    '''
    log_folder = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log')
    os.makedirs(log_folder, exist_ok=True)
    return (os.path.join(log_folder, 'nnictl_stdout.log'),
            os.path.join(log_folder, 'nnictl_stderr.log'))

def print_log_content(config_file_name):
    '''Print the stdout log and then the stderr log of an experiment.

    The two sections are separated by a run of blank lines.
    '''
    titled_logs = zip((' Stdout:', ' Stderr:'), get_log_path(config_file_name))
    for position, (title, log_path) in enumerate(titled_logs):
        if position:
            # visual gap between the stdout and stderr sections
            print('\n\n')
        print_normal(title)
        print(check_output_command(log_path))
41

42
def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None):
    '''Run nni manager process.

    Checks that ``port`` (and ``port + 1`` for every platform other than
    ``local``) is free, builds the node command line and spawns the process.

    Parameters:
        port: port the rest server listens on.
        platform: training service platform, passed as ``--mode``.
        mode: start mode; ``'view'`` is translated to a read-only ``resume``.
        experiment_id: experiment identifier, also selects the nnictl log files.
        foreground: when true the server output is piped back to the caller
            instead of being written to the nnictl log files.
        log_dir: optional ``--log_dir`` value.
        log_level: optional ``--log_level`` value.

    Returns:
        ``(process, start_time_ms)``: the ``Popen`` object and the launch
        timestamp in milliseconds.

    Exits the interpreter with status 1 when a required port is busy or the
    ``nni_node`` package directory cannot be found.
    '''
    if detect_port(port):
        print_error('Port %s is used by another process, please reset the port!\n' \
        'You could use \'nnictl create --help\' to get help information' % port)
        exit(1)

    if (platform != 'local') and detect_port(int(port) + 1):
        print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
        'You could set another port to start experiment!\n' \
        'You could use \'nnictl create --help\' to get help information' % ((int(port) + 1), (int(port) + 1)))
        exit(1)

    print_normal('Starting restful server...')

    entry_dir = nni_node.__path__[0]
    if (not entry_dir) or (not os.path.exists(entry_dir)):
        print_error('Fail to find nni under python library')
        exit(1)
    entry_file = os.path.join(entry_dir, 'main.js')

    # the node binary is looked up inside the nni_node package directory
    node_binary = 'node.exe' if sys.platform == 'win32' else 'node'
    node_command = os.path.join(entry_dir, node_binary)

    cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform,
            '--experiment_id', experiment_id]
    if mode == 'view':
        # viewing re-opens an existing experiment without allowing changes
        cmds += ['--start_mode', 'resume']
        cmds += ['--readonly', 'true']
    else:
        cmds += ['--start_mode', mode]
    if log_dir is not None:
        cmds += ['--log_dir', log_dir]
    if log_level is not None:
        cmds += ['--log_level', log_level]
    if foreground:
        cmds += ['--foreground', 'true']

    stdout_full_path, stderr_full_path = get_log_path(experiment_id)
    with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
        start_time = time.time()
        time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
        # add time information in the header of log files
        log_header = LOG_HEADER % str(time_now)
        stdout_file.write(log_header)
        stderr_file.write(log_header)
        # flush so the header reaches the files before the child process,
        # which inherits these handles, starts appending to them
        stdout_file.flush()
        stderr_file.flush()

        popen_options = {'cwd': entry_dir}
        if sys.platform == 'win32':
            from subprocess import CREATE_NEW_PROCESS_GROUP
            popen_options['creationflags'] = CREATE_NEW_PROCESS_GROUP

        if foreground:
            # Merge stderr into stdout on every OS: an stderr pipe that nobody
            # reads can fill up and block the child process.
            process = Popen(cmds, stdout=PIPE, stderr=STDOUT, **popen_options)
        else:
            process = Popen(cmds, stdout=stdout_file, stderr=stderr_file, **popen_options)
    return process, int(start_time * 1000)
Deshui Yu's avatar
Deshui Yu committed
100

101
def set_trial_config(experiment_config, port, config_file_name):
    '''Set trial configuration.

    PUTs ``experiment_config['trial']`` to the cluster metadata endpoint.

    Returns:
        True when ``check_response`` accepts the response, otherwise False.
        On failure the response body (when there is a response) is printed
        and appended to the experiment's stderr log.
    '''
    request_data = {'trial_config': experiment_config['trial']}
    response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
    if check_response(response):
        return True

    # Sibling setters treat a missing response as ``None``; guard on that
    # before touching ``response.text``.
    # NOTE(review): presumably rest_put returns a requests.Response or None; a
    # requests.Response is falsy for 4xx/5xx statuses, so a plain truthiness
    # test would skip the log write on real errors -- confirm in rest_utils.
    if response is not None:
        print('Error message is {}'.format(response.text))
        _, stderr_full_path = get_log_path(config_file_name)
        try:
            log_text = json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':'))
        except ValueError:
            # error body is not JSON: log it verbatim rather than crash here
            log_text = response.text
        with open(stderr_full_path, 'a+') as fout:
            fout.write(log_text)
    return False
115

116
def set_local_config(experiment_config, port, config_file_name):
    '''Set local configuration.

    When ``experiment_config`` has a non-empty ``localConfig`` it is PUT to the
    cluster metadata endpoint first (an integer ``gpuIndices`` is converted to
    a string in place), then the trial configuration is set.

    Returns:
        ``(False, error_text)`` when the local config request fails, otherwise
        ``(result of set_trial_config, None)``.
    '''
    local_config = experiment_config.get('localConfig')
    if local_config:
        # Convert on the type alone, as set_remote_config does: a combined
        # truthiness test would leave GPU index 0 as an integer.
        gpu_indices = local_config.get('gpuIndices')
        if isinstance(gpu_indices, int):
            local_config['gpuIndices'] = str(gpu_indices)

        request_data = {'local_config': local_config}
        response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
        if not response or not check_response(response):
            err_message = ''
            if response is not None:
                err_message = response.text
                _, stderr_full_path = get_log_path(config_file_name)
                with open(stderr_full_path, 'a+') as fout:
                    fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
            return False, err_message

    return set_trial_config(experiment_config, port, config_file_name), None
Deshui Yu's avatar
Deshui Yu committed
139

140
141
142
143
144
145
146
147
def set_adl_config(experiment_config, port, config_file_name):
    '''Set adl configuration: nniManagerIp first, then the trial config.

    Returns a ``(result, message)`` pair; the message is only set when the
    nniManagerIp step fails.
    '''
    ip_ok, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if ip_ok:
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), None
    return ip_ok, ip_message

148
def set_remote_config(experiment_config, port, config_file_name):
    '''Call setClusterMetadata to pass the remote config and machine list.

    Afterwards nniManagerIp and the trial configuration are set as well.
    Returns a ``(result, message)`` pair.
    '''
    # set machine_list
    if experiment_config.get('remoteConfig'):
        remote_config = experiment_config['remoteConfig']
    else:
        remote_config = {'reuse': False}
    machines = experiment_config['machineList']
    request_data = {'remote_config': remote_config, 'machine_list': machines}

    # integer gpuIndices are sent as strings
    for machine in machines or []:
        gpu_indices = machine.get('gpuIndices')
        if isinstance(gpu_indices, int):
            machine['gpuIndices'] = str(gpu_indices)

    # It needs to connect all remote machines, the time out of connection is 30 seconds.
    # So timeout of this place should be longer.
    response = rest_put(cluster_metadata_url(port), json.dumps(request_data), 60, True)
    err_message = ''
    if not (response and check_response(response)):
        if response is not None:
            err_message = response.text
            stderr_full_path = get_log_path(config_file_name)[1]
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message

    ip_ok, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if ip_ok:
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), err_message
    return ip_ok, ip_message
Deshui Yu's avatar
Deshui Yu committed
177

178
179
180
181
182
def setNNIManagerIp(experiment_config, port, config_file_name):
    '''Set nniManagerIp on the rest server when the config provides one.

    Returns ``(True, None)`` when no ip is configured or the request gets a
    200 response, otherwise ``(False, error_text_or_None)``.
    '''
    manager_ip = experiment_config.get('nniManagerIp')
    if manager_ip is None:
        return True, None

    payload = json.dumps({'nni_manager_ip': {'nniManagerIp': manager_ip}})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    if response and response.status_code == 200:
        return True, None

    err_message = None
    if response is not None:
        err_message = response.text
        stderr_full_path = get_log_path(config_file_name)[1]
        with open(stderr_full_path, 'a+') as fout:
            fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
    return False, err_message

195
def set_pai_config(experiment_config, port, config_file_name):
    '''Set pai configuration, then nniManagerIp and the trial config.

    Returns a ``(result, message)`` pair; on a failed request the response
    body is appended to the experiment's stderr log.
    '''
    payload = json.dumps({'pai_config': experiment_config['paiConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = None
    if not (response and response.status_code == 200):
        if response is not None:
            err_message = response.text
            stderr_full_path = get_log_path(config_file_name)[1]
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message

    ip_ok, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if ip_ok:
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), err_message
    return ip_ok, ip_message
213

214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def set_pai_yarn_config(experiment_config, port, config_file_name):
    '''Set paiYarn configuration, then nniManagerIp and the trial config.

    Returns a ``(result, message)`` pair; on a failed request the response
    body is appended to the experiment's stderr log.
    '''
    payload = json.dumps({'pai_yarn_config': experiment_config['paiYarnConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = None
    if not (response and response.status_code == 200):
        if response is not None:
            err_message = response.text
            stderr_full_path = get_log_path(config_file_name)[1]
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message

    ip_ok, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if ip_ok:
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), err_message
    return ip_ok, ip_message

233
def set_kubeflow_config(experiment_config, port, config_file_name):
    '''Set kubeflow configuration, then nniManagerIp and the trial config.

    Returns a ``(result, message)`` pair; on a failed request the response
    body is appended to the experiment's stderr log.
    '''
    payload = json.dumps({'kubeflow_config': experiment_config['kubeflowConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = None
    if not (response and response.status_code == 200):
        if response is not None:
            err_message = response.text
            stderr_full_path = get_log_path(config_file_name)[1]
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message

    ip_ok, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if ip_ok:
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), err_message
    return ip_ok, ip_message

252
def set_frameworkcontroller_config(experiment_config, port, config_file_name):
    '''Set frameworkcontroller configuration, then nniManagerIp and the trial config.

    Returns a ``(result, message)`` pair; on a failed request the response
    body is appended to the experiment's stderr log.
    '''
    payload = json.dumps({'frameworkcontroller_config': experiment_config['frameworkcontrollerConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = None
    if not (response and response.status_code == 200):
        if response is not None:
            err_message = response.text
            stderr_full_path = get_log_path(config_file_name)[1]
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message

    ip_ok, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if ip_ok:
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), err_message
    return ip_ok, ip_message

George Cheng's avatar
George Cheng committed
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
def set_dlts_config(experiment_config, port, config_file_name):
    '''Set dlts configuration, then nniManagerIp and the trial config.

    Returns a ``(result, message)`` pair; on a failed request the response
    body is appended to the experiment's stderr log.
    '''
    payload = json.dumps({'dlts_config': experiment_config['dltsConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = None
    if not (response and response.status_code == 200):
        if response is not None:
            err_message = response.text
            stderr_full_path = get_log_path(config_file_name)[1]
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message

    ip_ok, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if ip_ok:
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), err_message
    return ip_ok, ip_message

SparkSnail's avatar
SparkSnail committed
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
def set_aml_config(experiment_config, port, config_file_name):
    '''Set aml configuration, then nniManagerIp and the trial config.

    Returns a ``(result, message)`` pair; on a failed request the response
    body is appended to the experiment's stderr log.
    '''
    payload = json.dumps({'aml_config': experiment_config['amlConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = None
    if not (response and response.status_code == 200):
        if response is not None:
            err_message = response.text
            stderr_full_path = get_log_path(config_file_name)[1]
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
        return False, err_message

    ip_ok, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if ip_ok:
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), err_message
    return ip_ok, ip_message

309
def set_experiment(experiment_config, mode, port, config_file_name):
    '''Call startExperiment (rest POST /experiment) with yaml file content.

    Builds the request body from ``experiment_config`` (basic fields, the
    advisor or tuner/assessor section, version-check and log-collection
    options, and the platform specific ``clusterMetaData`` list) and posts it.

    Parameters:
        experiment_config: parsed experiment configuration; integer
            ``gpuIndices`` of the advisor/tuner are converted to strings in
            place, and a default ``remoteConfig`` is filled in for the
            ``remote`` platform.
        mode: not used by this function.
        port: rest server port.
        config_file_name: selects the stderr log that receives error bodies.

    Returns:
        The response when ``check_response`` accepts it, otherwise None.
    '''
    request_data = dict()
    for required_key in ('authorName', 'experimentName', 'trialConcurrency', 'maxExecDuration', 'maxTrialNum'):
        request_data[required_key] = experiment_config[required_key]
    request_data['searchSpace'] = experiment_config.get('searchSpace')
    request_data['trainingServicePlatform'] = experiment_config.get('trainingServicePlatform')
    for optional_key in ('description', 'multiPhase', 'multiThread'):
        if experiment_config.get(optional_key):
            request_data[optional_key] = experiment_config[optional_key]

    # an advisor takes the place of the tuner/assessor pair
    algo_key = 'advisor' if experiment_config.get('advisor') else 'tuner'
    algo_config = experiment_config[algo_key]
    request_data[algo_key] = algo_config
    if algo_config.get('gpuNum'):
        print_error('gpuNum is deprecated, please use gpuIndices instead.')
    # Convert on the type alone so that GPU index 0 (falsy) is stringified too.
    gpu_indices = algo_config.get('gpuIndices')
    if isinstance(gpu_indices, int):
        algo_config['gpuIndices'] = str(gpu_indices)
    if algo_key == 'tuner' and 'assessor' in experiment_config:
        request_data['assessor'] = experiment_config['assessor']
        if request_data['assessor'].get('gpuNum'):
            print_error('gpuNum is deprecated, please remove it from your config file.')

    # debug mode should disable version check
    if experiment_config.get('debug') is not None:
        request_data['versionCheck'] = not experiment_config.get('debug')
    # an explicit versionCheck setting overrides the debug-derived value
    if experiment_config.get('versionCheck') is not None:
        request_data['versionCheck'] = experiment_config.get('versionCheck')
    if experiment_config.get('logCollection'):
        request_data['logCollection'] = experiment_config.get('logCollection')

    # platforms whose metadata is "<platform config>, trial_config":
    # platform name -> (metadata key, experiment config key)
    platform_config_keys = {
        'pai': ('pai_config', 'paiConfig'),
        'paiYarn': ('pai_yarn_config', 'paiYarnConfig'),
        'kubeflow': ('kubeflow_config', 'kubeflowConfig'),
        'frameworkcontroller': ('frameworkcontroller_config', 'frameworkcontrollerConfig'),
        'aml': ('aml_config', 'amlConfig'),
    }
    cluster_metadata = []
    request_data['clusterMetaData'] = cluster_metadata
    platform = experiment_config['trainingServicePlatform']
    if platform == 'local':
        cluster_metadata.append({'key': 'codeDir', 'value': experiment_config['trial']['codeDir']})
        cluster_metadata.append({'key': 'command', 'value': experiment_config['trial']['command']})
    elif platform == 'remote':
        cluster_metadata.append({'key': 'machine_list', 'value': experiment_config['machineList']})
        cluster_metadata.append({'key': 'trial_config', 'value': experiment_config['trial']})
        if not experiment_config.get('remoteConfig'):
            # set default value of reuse in remoteConfig to False
            experiment_config['remoteConfig'] = {'reuse': False}
        cluster_metadata.append({'key': 'remote_config', 'value': experiment_config['remoteConfig']})
    elif platform in platform_config_keys:
        metadata_key, config_key = platform_config_keys[platform]
        cluster_metadata.append({'key': metadata_key, 'value': experiment_config[config_key]})
        cluster_metadata.append({'key': 'trial_config', 'value': experiment_config['trial']})

    response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
    if check_response(response):
        return response

    _, stderr_full_path = get_log_path(config_file_name)
    if response is not None:
        with open(stderr_full_path, 'a+') as fout:
            fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
        print_error('Setting experiment error, error message is {}'.format(response.text))
    return None
Deshui Yu's avatar
Deshui Yu committed
400

SparkSnail's avatar
SparkSnail committed
401
402
403
404
def set_platform_config(platform, experiment_config, port, config_file_name, rest_process):
    '''Call set_cluster_metadata for specific platform.

    Dispatches to the ``set_<platform>_config`` function of ``platform``.

    Raises:
        Exception: when ``platform`` is not supported, or when killing the
            rest server after a failed configuration itself fails.

    Exits the interpreter with status 1 (after killing the rest server) when
    the platform configuration could not be set.
    '''
    print_normal('Setting {0} config...'.format(platform))
    platform_setters = {
        'adl': set_adl_config,
        'local': set_local_config,
        'remote': set_remote_config,
        'pai': set_pai_config,
        'paiYarn': set_pai_yarn_config,
        'kubeflow': set_kubeflow_config,
        'frameworkcontroller': set_frameworkcontroller_config,
        'dlts': set_dlts_config,
        'aml': set_aml_config,
    }
    if platform not in platform_setters:
        raise Exception(ERROR_INFO % 'Unsupported platform!')

    config_result, err_msg = platform_setters[platform](experiment_config, port, config_file_name)
    if config_result:
        print_normal('Successfully set {0} config!'.format(platform))
    else:
        print_error('Failed! Error is: {}'.format(err_msg))
        # the server is useless without its platform config: stop it
        try:
            kill_command(rest_process.pid)
        except Exception:
            raise Exception(ERROR_INFO % 'Rest server stopped!')
        exit(1)

436
def launch_experiment(args, experiment_config, mode, experiment_id):
    '''Follow steps to start rest server and start experiment.

    Steps: verify that the builtin tuner/advisor module can be imported, start
    the rest server, prepare the search space, push the platform configuration
    (skipped in ``view`` mode), register the experiment and post it to the
    server. In foreground mode the server output is then echoed until the
    output stream ends or the user interrupts.

    Parameters:
        args: parsed command line arguments; ``port``, ``debug`` and
            ``foreground`` are read.
        experiment_config: parsed experiment configuration; updated in place
            (``searchSpace``, possibly ``trial.codeDir`` and ``debug``).
        mode: start mode forwarded to the rest server (``'view'`` disables
            foreground/debug handling and platform configuration).
        experiment_id: experiment identifier.

    Exits the interpreter with status 1 on any startup failure.
    '''
    nni_config = Config(experiment_id)
    # check packages for tuner
    package_name, module_name = None, None
    if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
        package_name = experiment_config['tuner']['builtinTunerName']
        module_name, _ = get_builtin_module_class_name('tuners', package_name)
    elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
        package_name = experiment_config['advisor']['builtinAdvisorName']
        module_name, _ = get_builtin_module_class_name('advisors', package_name)
    if package_name and module_name:
        try:
            stdout_full_path, stderr_full_path = get_log_path(experiment_id)
            with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
                # import in a child interpreter so that its output lands in the log files
                check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file)
        except CalledProcessError:
            print_error('some errors happen when import package %s.' %(package_name))
            print_log_content(experiment_id)
            if package_name in INSTALLABLE_PACKAGE_META:
                print_error('If %s is not installed, it should be installed through '\
                            '\'nnictl package install --name %s\'' % (package_name, package_name))
            exit(1)
    log_dir = experiment_config.get('logDir') or None
    log_level = experiment_config.get('logLevel') or None
    #view experiment mode do not need debug function, when view an experiment, there will be no new logs created
    foreground = False
    if mode != 'view':
        foreground = args.foreground
        if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
            log_level = 'debug'
    # start rest server
    rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'],
                                                 mode, experiment_id, foreground, log_dir, log_level)
    nni_config.set_config('restServerPid', rest_process.pid)
    # Deal with annotation
    if experiment_config.get('useAnnotation'):
        path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
        if not os.path.isdir(path):
            os.makedirs(path)
        path = tempfile.mkdtemp(dir=path)
        nas_mode = experiment_config['trial'].get('nasMode', 'classic_mode')
        code_dir = expand_annotations(experiment_config['trial']['codeDir'], path, nas_mode=nas_mode)
        experiment_config['trial']['codeDir'] = code_dir
        search_space = generate_search_space(code_dir)
        experiment_config['searchSpace'] = json.dumps(search_space)
        assert search_space, ERROR_INFO % 'Generated search space is empty'
    elif experiment_config.get('searchSpacePath'):
        search_space = get_json_content(experiment_config.get('searchSpacePath'))
        experiment_config['searchSpace'] = json.dumps(search_space)
    else:
        experiment_config['searchSpace'] = json.dumps('')

    # check rest server
    running, _ = check_rest_server(args.port)
    if running:
        print_normal('Successfully started Restful server!')
    else:
        print_error('Restful server start failed!')
        print_log_content(experiment_id)
        try:
            kill_command(rest_process.pid)
        except Exception:
            raise Exception(ERROR_INFO % 'Rest server stopped!')
        exit(1)
    if mode != 'view':
        # set platform configuration
        set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,
                            experiment_id, rest_process)

    # start a new experiment
    print_normal('Starting experiment...')
    # save experiment information
    nnictl_experiment_config = Experiments()
    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
                                            experiment_config['trainingServicePlatform'],
                                            experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
    # set debug configuration
    if mode != 'view' and experiment_config.get('debug') is None:
        experiment_config['debug'] = args.debug
    response = set_experiment(experiment_config, mode, args.port, experiment_id)
    if response:
        if experiment_id is None:
            experiment_id = json.loads(response.text).get('experiment_id')
    else:
        print_error('Start experiment failed!')
        print_log_content(experiment_id)
        try:
            kill_command(rest_process.pid)
        except Exception:
            raise Exception(ERROR_INFO % 'Restful server stopped!')
        exit(1)
    if experiment_config.get('nniManagerIp'):
        web_ui_url_list = ['{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
    else:
        web_ui_url_list = get_local_urls(args.port)
    nni_config.set_config('webuiUrl', web_ui_url_list)

    print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, '   '.join(web_ui_url_list)))
    if mode != 'view' and args.foreground:
        try:
            # readline() returns b'' only at end of stream; stop there instead
            # of spinning forever on empty reads once the server has exited
            for raw_line in iter(rest_process.stdout.readline, b''):
                print(raw_line.strip().decode('utf-8'))
        except KeyboardInterrupt:
            kill_command(rest_process.pid)
            print_normal('Stopping experiment...')
Deshui Yu's avatar
Deshui Yu committed
543

SparkSnail's avatar
SparkSnail committed
544
545
def create_experiment(args):
    '''Start a new experiment described by the config file at ``args.config``.

    A fresh 8-character experiment id is generated, the config file is loaded and
    validated, the id / config / REST port are stored through ``Config``, and
    ``launch_experiment`` is invoked in ``'new'`` mode.

    args: parsed command-line arguments; ``args.config`` and ``args.port`` are read here.

    The process exits with status 1 when the config path does not exist, when
    validation fails, or when launching raises an exception.
    '''
    # random.sample draws without replacement, so no character repeats within an id.
    experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
    nni_config = Config(experiment_id)
    nni_config.set_config('experimentId', experiment_id)

    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        print_error('Please set correct config path!')
        # sys.exit instead of the site-provided exit(), which is meant for interactive use.
        sys.exit(1)
    experiment_config = get_yml_content(config_path)
    try:
        validate_all_content(experiment_config, config_path)
    except Exception as e:
        print_error(e)
        sys.exit(1)

    nni_config.set_config('experimentConfig', experiment_config)
    nni_config.set_config('restServerPort', args.port)
    try:
        launch_experiment(args, experiment_config, 'new', experiment_id)
    except Exception as exception:
        # NOTE(review): a new Config object is built here, presumably so that a
        # 'restServerPid' written during launch_experiment is visible -- confirm.
        nni_config = Config(experiment_id)
        restServerPid = nni_config.get_config('restServerPid')
        if restServerPid:
            # Do not leave a REST server process behind after a failed launch.
            kill_command(restServerPid)
        print_error(exception)
        sys.exit(1)
SparkSnail's avatar
SparkSnail committed
571
572
573

def manage_stopped_experiment(args, mode):
    '''View or resume a stopped experiment, depending on ``mode``.

    args: parsed command-line arguments; ``args.id`` and ``args.port`` are read here.
    mode: name of the action; the callers in this file pass ``'view'`` or ``'resume'``.
          It is interpolated into user-facing messages and forwarded to
          ``launch_experiment``.

    The process exits with status 1 when no id is given, when the id is unknown,
    when the experiment is not in ``'STOPPED'`` status, or when launching raises.
    '''
    update_experiment()
    experiments = Experiments()
    experiment_dict = experiments.get_all_experiments()

    # An explicit id is mandatory; nothing is picked automatically.
    if not args.id:
        print_error(
            'Please set experiment id! \nYou could use \'nnictl {0} id\' to {0} a stopped experiment!\n'
            'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode)
        )
        # sys.exit instead of the site-provided exit(), which is meant for interactive use.
        sys.exit(1)
    if experiment_dict.get(args.id) is None:
        print_error('Id %s not exist!' % args.id)
        sys.exit(1)
    if experiment_dict[args.id]['status'] != 'STOPPED':
        # Build the past participle properly: 'view' -> 'viewed', 'resume' -> 'resumed'
        # (blindly appending 'ed' produced 'resumeed').
        past_participle = mode + 'd' if mode.endswith('e') else mode + 'ed'
        print_error('Only stopped experiments can be {0}!'.format(past_participle))
        sys.exit(1)
    experiment_id = args.id

    print_normal('{0} experiment {1}...'.format(mode, experiment_id))
    nni_config = Config(experiment_id)
    experiment_config = nni_config.get_config('experimentConfig')
    nni_config.set_config('restServerPort', args.port)
    try:
        launch_experiment(args, experiment_config, mode, experiment_id)
    except Exception as exception:
        # NOTE(review): a new Config object is built here, presumably so that a
        # 'restServerPid' written during launch_experiment is visible -- confirm.
        nni_config = Config(experiment_id)
        restServerPid = nni_config.get_config('restServerPid')
        if restServerPid:
            # Do not leave a REST server process behind after a failed launch.
            kill_command(restServerPid)
        print_error(exception)
        sys.exit(1)
Deshui Yu's avatar
Deshui Yu committed
604

SparkSnail's avatar
SparkSnail committed
605
606
607
def view_experiment(args):
    '''Run manage_stopped_experiment for ``args`` in 'view' mode.'''
    manage_stopped_experiment(args, mode='view')
Deshui Yu's avatar
Deshui Yu committed
608

SparkSnail's avatar
SparkSnail committed
609
610
def resume_experiment(args):
    '''Run manage_stopped_experiment for ``args`` in 'resume' mode.'''
    manage_stopped_experiment(args, mode='resume')