launcher.py 27.2 KB
Newer Older
Deshui Yu's avatar
Deshui Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


import json
import os
24
import sys
25
import string
chicm-ms's avatar
chicm-ms committed
26
27
28
import random
import site
import time
Deshui Yu's avatar
Deshui Yu committed
29
import tempfile
chicm-ms's avatar
chicm-ms committed
30
31
from subprocess import Popen, check_call, CalledProcessError
from nni_annotation import expand_annotations, generate_search_space
32
from nni.constants import ModuleName, AdvisorModuleName
Deshui Yu's avatar
Deshui Yu committed
33
from .launcher_utils import validate_all_content
chicm-ms's avatar
chicm-ms committed
34
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
SparkSnail's avatar
SparkSnail committed
35
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
36
from .config_utils import Config, Experiments
chicm-ms's avatar
chicm-ms committed
37
38
39
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \
                          detect_port, get_user, get_python_dir
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, PACKAGE_REQUIREMENTS
40
from .command_utils import check_output_command, kill_command
SparkSnail's avatar
SparkSnail committed
41
from .nnictl_utils import update_experiment
Gems Guo's avatar
Gems Guo committed
42

43
44
45
46
47
48
49
50
51
52
def get_log_path(config_file_name):
    '''Build the stdout and stderr log paths of an experiment.

    Both paths live under NNICTL_HOME_DIR/<config_file_name>/ and are
    returned as a (stdout_path, stderr_path) tuple.
    '''
    return tuple(
        os.path.join(NNICTL_HOME_DIR, config_file_name, stream_name)
        for stream_name in ('stdout', 'stderr')
    )

def print_log_content(config_file_name):
    '''Print the experiment's stdout log followed by its stderr log.

    The log locations come from get_log_path; the content of each file is
    obtained through check_output_command.
    '''
    out_path, err_path = get_log_path(config_file_name)
    sections = ((' Stdout:', out_path), (' Stderr:', err_path))
    for position, (title, log_path) in enumerate(sections):
        if position:
            # blank separator between the two log sections
            print('\n\n')
        print_normal(title)
        print(check_output_command(log_path))
57

Zejun Lin's avatar
Zejun Lin committed
58
59
60
61
62
63
64
65
66
def get_nni_installation_path():
    '''Locate the installed nni directory (the one containing main.js).

    Search order:
      * $VIRTUAL_ENV when it is set,
      * otherwise the user site-packages and the first system site-packages,
        with the user one tried first when the system path starts with
        '/usr' or '/Library'.

    Returns the path of the 'nni' directory. Prints an error and exits with
    status 1 when main.js cannot be found.
    '''
    def _has_entry_file(root_dir):
        '''Tell whether <root_dir>/nni/main.js is an existing file.'''
        return os.path.isfile(os.path.join(root_dir, 'nni', 'main.js'))

    def _first_installation_dir(candidates):
        '''Return the first python dir (derived from each candidate
        site-packages path) that holds nni, or None if nothing is found.'''
        for candidate in candidates:
            root_dir = get_python_dir(candidate)
            if _has_entry_file(root_dir) and root_dir:
                return root_dir
        return None

    virtual_env = os.getenv('VIRTUAL_ENV')
    if virtual_env:
        # if 'virtualenv' package is used, `site` has not attr getsitepackages, so VIRTUAL_ENV is used instead
        # Note that conda venv will not have VIRTUAL_ENV
        nni_root = virtual_env
    else:
        system_site = site.getsitepackages()[0]
        user_site = site.getusersitepackages()
        # If system-wide python is used, priority is given to the user site-packages,
        # provided that nni exists there
        if system_site.startswith(('/usr', '/Library')):
            search_order = (user_site, system_site)
        else:
            search_order = (system_site, user_site)
        nni_root = _first_installation_dir(search_order)

    if nni_root and _has_entry_file(nni_root):
        return os.path.join(nni_root, 'nni')
    print_error('Fail to find nni under python library')
    exit(1)
98

99
def start_rest_server(port, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
    '''Run nni manager process

    Verifies that `port` is not in use (and, for every platform other than
    'local', that `port + 1` is not in use either), then starts main.js from
    the nni installation directory with node. The child's stdout and stderr
    are appended to the experiment log files, each prefixed with a
    timestamped header.

    Returns a tuple (Popen object of the started process, start time
    formatted as '%Y-%m-%d %H:%M:%S'). Exits with status 1 when a required
    port is already taken.
    '''
    if detect_port(port):
        print_error('Port %s is used by another process, please reset the port!\n'
                    'You could use \'nnictl create --help\' to get help information' % port)
        exit(1)

    if platform != 'local':
        adjacent_port = int(port) + 1
        if detect_port(adjacent_port):
            print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n'
                        'You could set another port to start experiment!\n'
                        'You could use \'nnictl create --help\' to get help information'
                        % (adjacent_port, adjacent_port))
            exit(1)

    print_normal('Starting restful server...')

    entry_dir = get_nni_installation_path()
    entry_file = os.path.join(entry_dir, 'main.js')

    on_windows = sys.platform == 'win32'
    # On Windows node.exe is looked up in the 'Scripts' directory next to the
    # installation path with its last three characters removed.
    node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe') if on_windows else 'node'

    cmds = [node_command, entry_file, '--port', str(port), '--mode', platform]
    if mode == 'view':
        # viewing is a read-only resume
        cmds.extend(['--start_mode', 'resume', '--readonly', 'true'])
    else:
        cmds.extend(['--start_mode', mode])
    if log_dir is not None:
        cmds.extend(['--log_dir', log_dir])
    if log_level is not None:
        cmds.extend(['--log_level', log_level])
    if mode in ['resume', 'view']:
        cmds.extend(['--experiment_id', experiment_id])

    stdout_full_path, stderr_full_path = get_log_path(config_file_name)
    with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
        start_time = str(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
        # add time information in the header of log files
        log_header = LOG_HEADER % start_time
        for log_file in (stdout_file, stderr_file):
            log_file.write(log_header)

        popen_options = {'cwd': entry_dir, 'stdout': stdout_file, 'stderr': stderr_file}
        if on_windows:
            from subprocess import CREATE_NEW_PROCESS_GROUP
            popen_options['creationflags'] = CREATE_NEW_PROCESS_GROUP
        process = Popen(cmds, **popen_options)
    return process, start_time
Deshui Yu's avatar
Deshui Yu committed
145

146
def set_trial_config(experiment_config, port, config_file_name):
    '''set trial configuration

    Sends experiment_config['trial'] as 'trial_config' to the cluster
    metadata url of the rest server listening on `port`.

    Returns True when check_response accepts the reply, False otherwise.
    On failure the reply text (if there is a reply) is printed and appended,
    pretty-printed, to the experiment's stderr log.
    '''
    request_data = {'trial_config': experiment_config['trial']}
    response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
    if check_response(response):
        return True

    # The sibling set_*_config functions test `response is not None`, so the
    # reply may be missing altogether; only touch `.text` when it exists.
    # A plain truthiness test is not used because an error reply object may
    # itself be falsy (NOTE(review): true for requests.Response -- confirm
    # that is what rest_put returns), which would skip the log entry.
    if response is not None:
        print('Error message is {}'.format(response.text))
        _, stderr_full_path = get_log_path(config_file_name)
        with open(stderr_full_path, 'a+') as fout:
            fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
    return False
160

161
def set_local_config(experiment_config, port, config_file_name):
    '''set local configuration

    When experiment_config contains a non-empty 'localConfig', it is sent as
    'local_config' to the cluster metadata url; afterwards the trial
    configuration is sent through set_trial_config.

    Returns a tuple (success, error_message):
      * (False, text) when sending the local config fails; text is the reply
        text, or '' when there was no reply object,
      * (result of set_trial_config, None) otherwise.
    '''
    local_config = experiment_config.get('localConfig')
    if local_config:
        # An integer gpuIndices is sent as a string. Only the type is tested
        # (as set_remote_config does) so that GPU index 0 is converted too.
        # The dict is updated in place, as before.
        gpu_indices = local_config.get('gpuIndices')
        if isinstance(gpu_indices, int):
            local_config['gpuIndices'] = str(gpu_indices)

        request_data = {'local_config': local_config}
        response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
        if not response or not check_response(response):
            err_message = ''
            if response is not None:
                err_message = response.text
                _, stderr_full_path = get_log_path(config_file_name)
                with open(stderr_full_path, 'a+') as fout:
                    fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
            return False, err_message

    return set_trial_config(experiment_config, port, config_file_name), None
Deshui Yu's avatar
Deshui Yu committed
184

185
def set_remote_config(experiment_config, port, config_file_name):
    '''Call setClusterMetadata to pass the machine list, nniManagerIp and trial config

    Returns a tuple (success, error_message). When sending the machine list
    fails the message is the reply text ('' if there was no reply object);
    when setting the nniManagerIp fails its result is returned unchanged.
    '''
    # set machine_list
    machines = experiment_config['machineList']
    if machines:
        for machine in machines:
            gpu_indices = machine.get('gpuIndices')
            # integer gpuIndices are sent as strings
            if isinstance(gpu_indices, int):
                machine['gpuIndices'] = str(gpu_indices)

    response = rest_put(cluster_metadata_url(port), json.dumps({'machine_list': machines}), REST_TIME_OUT)
    err_message = ''
    request_ok = response and check_response(response)
    if not request_ok:
        if response is not None:
            err_message = response.text
            stderr_full_path = get_log_path(config_file_name)[1]
            with open(stderr_full_path, 'a+') as fout:
                pretty_error = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
                fout.write(pretty_error)
        return False, err_message

    ip_result, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not ip_result:
        return ip_result, ip_message

    # set trial_config
    trial_result = set_trial_config(experiment_config, port, config_file_name)
    return trial_result, err_message
Deshui Yu's avatar
Deshui Yu committed
208

209
210
211
212
213
def setNNIManagerIp(experiment_config, port, config_file_name):
    '''Send the configured nniManagerIp to the rest server.

    Nothing is sent when experiment_config has no 'nniManagerIp'.

    Returns (True, None) on success or when there is nothing to send, and
    (False, reply text or None) when the request did not answer with
    status code 200. A failed reply is also appended to the stderr log.
    '''
    manager_ip = experiment_config.get('nniManagerIp')
    if manager_ip is None:
        return True, None

    response = rest_put(cluster_metadata_url(port),
                        json.dumps({'nni_manager_ip': {'nniManagerIp': manager_ip}}),
                        REST_TIME_OUT)
    if response and response.status_code == 200:
        return True, None

    failure_text = None
    if response is not None:
        failure_text = response.text
        stderr_full_path = get_log_path(config_file_name)[1]
        with open(stderr_full_path, 'a+') as fout:
            fout.write(json.dumps(json.loads(failure_text), indent=4, sort_keys=True, separators=(',', ':')))
    return False, failure_text

226
def set_pai_config(experiment_config, port, config_file_name):
    '''Send the pai configuration, then the nniManagerIp and the trial config.

    Returns a tuple (success, error_message). A pai config request that does
    not answer with status code 200 yields (False, reply text or None) and the
    reply is appended to the stderr log.
    '''
    payload = json.dumps({'pai_config': experiment_config['paiConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)

    if not (response and response.status_code == 200):
        failure_text = None
        if response is not None:
            failure_text = response.text
            stderr_full_path = get_log_path(config_file_name)[1]
            with open(stderr_full_path, 'a+') as fout:
                fout.write(json.dumps(json.loads(failure_text), indent=4, sort_keys=True, separators=(',', ':')))
        return False, failure_text

    ip_result, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not ip_result:
        return ip_result, ip_message

    # set trial_config
    return set_trial_config(experiment_config, port, config_file_name), None
244

245
def set_kubeflow_config(experiment_config, port, config_file_name):
    '''Send the kubeflow configuration, then the nniManagerIp and the trial config.

    Returns a tuple (success, error_message). When the kubeflow config
    request does not answer with status code 200, the result is
    (False, reply text or None) and the reply is appended to the stderr log.
    '''
    request_body = {'kubeflow_config': experiment_config['kubeflowConfig']}
    response = rest_put(cluster_metadata_url(port), json.dumps(request_body), REST_TIME_OUT)

    accepted = response and response.status_code == 200
    if not accepted:
        reply_text = None
        if response is not None:
            reply_text = response.text
            _, err_log_path = get_log_path(config_file_name)
            formatted = json.dumps(json.loads(reply_text), indent=4, sort_keys=True, separators=(',', ':'))
            with open(err_log_path, 'a+') as fout:
                fout.write(formatted)
        return False, reply_text

    ip_result, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if ip_result:
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), None
    return ip_result, ip_message

264
def set_frameworkcontroller_config(experiment_config, port, config_file_name):
    '''Send the frameworkcontroller configuration, then the nniManagerIp and the trial config.

    Returns a tuple (success, error_message). When the frameworkcontroller
    config request does not answer with status code 200, the result is
    (False, reply text or None) and the reply is appended to the stderr log.
    '''
    config_payload = dict(frameworkcontroller_config=experiment_config['frameworkcontrollerConfig'])
    response = rest_put(cluster_metadata_url(port), json.dumps(config_payload), REST_TIME_OUT)

    if response and response.status_code == 200:
        manager_ip_ok, manager_ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
        if not manager_ip_ok:
            return manager_ip_ok, manager_ip_message
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), None

    if response is None:
        return False, None
    reply_text = response.text
    err_log_path = get_log_path(config_file_name)[1]
    with open(err_log_path, 'a+') as fout:
        fout.write(json.dumps(json.loads(reply_text), indent=4, sort_keys=True, separators=(',', ':')))
    return False, reply_text

283
def set_experiment(experiment_config, mode, port, config_file_name):
    '''Call startExperiment (rest POST /experiment) with yaml file content

    Builds the request from experiment_config (general settings, advisor or
    tuner/assessor, version check, log collection and the platform specific
    cluster metadata) and posts it to the experiment url of the rest server
    listening on `port`. `mode` is accepted but not used here.

    Returns the response when check_response accepts it. Otherwise returns
    None, after appending the reply (if any) to the experiment's stderr log
    and printing it.
    '''
    def _normalize_gpu_settings(component_config):
        '''Warn about deprecated gpuNum and send an integer gpuIndices as a string.'''
        if component_config.get('gpuNum'):
            print_error('gpuNum is deprecated, please use gpuIndices instead.')
        # Only the type is tested so that GPU index 0 is converted as well
        # (set_remote_config handles machine gpuIndices the same way).
        if isinstance(component_config.get('gpuIndices'), int):
            component_config['gpuIndices'] = str(component_config['gpuIndices'])

    request_data = {
        'authorName': experiment_config['authorName'],
        'experimentName': experiment_config['experimentName'],
        'trialConcurrency': experiment_config['trialConcurrency'],
        'maxExecDuration': experiment_config['maxExecDuration'],
        'maxTrialNum': experiment_config['maxTrialNum'],
        'searchSpace': experiment_config.get('searchSpace'),
        'trainingServicePlatform': experiment_config.get('trainingServicePlatform'),
    }

    # optional settings are only forwarded when they have a truthy value
    for optional_key in ('description', 'multiPhase', 'multiThread'):
        if experiment_config.get(optional_key):
            request_data[optional_key] = experiment_config[optional_key]

    if experiment_config.get('advisor'):
        request_data['advisor'] = experiment_config['advisor']
        _normalize_gpu_settings(request_data['advisor'])
    else:
        request_data['tuner'] = experiment_config['tuner']
        _normalize_gpu_settings(request_data['tuner'])
        if 'assessor' in experiment_config:
            request_data['assessor'] = experiment_config['assessor']
            if request_data['assessor'].get('gpuNum'):
                print_error('gpuNum is deprecated, please remove it from your config file.')

    # debug mode should disable version check
    if experiment_config.get('debug') is not None:
        request_data['versionCheck'] = not experiment_config.get('debug')
    # an explicit versionCheck setting takes precedence over the debug flag
    if experiment_config.get('versionCheck') is not None:
        request_data['versionCheck'] = experiment_config.get('versionCheck')
    if experiment_config.get('logCollection'):
        request_data['logCollection'] = experiment_config.get('logCollection')

    # platform -> (metadata key, experiment_config key) sent before trial_config
    platform_metadata = {
        'remote': ('machine_list', 'machineList'),
        'pai': ('pai_config', 'paiConfig'),
        'kubeflow': ('kubeflow_config', 'kubeflowConfig'),
        'frameworkcontroller': ('frameworkcontroller_config', 'frameworkcontrollerConfig'),
    }
    platform = experiment_config['trainingServicePlatform']
    cluster_metadata = []
    if platform == 'local':
        cluster_metadata.append({'key': 'codeDir', 'value': experiment_config['trial']['codeDir']})
        cluster_metadata.append({'key': 'command', 'value': experiment_config['trial']['command']})
    elif platform in platform_metadata:
        metadata_key, config_key = platform_metadata[platform]
        cluster_metadata.append({'key': metadata_key, 'value': experiment_config[config_key]})
        cluster_metadata.append({'key': 'trial_config', 'value': experiment_config['trial']})
    request_data['clusterMetaData'] = cluster_metadata

    response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
    if check_response(response):
        return response

    if response is not None:
        _, stderr_full_path = get_log_path(config_file_name)
        with open(stderr_full_path, 'a+') as fout:
            fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
        print_error('Setting experiment error, error message is {}'.format(response.text))
    return None
Deshui Yu's avatar
Deshui Yu committed
361

SparkSnail's avatar
SparkSnail committed
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
def set_platform_config(platform, experiment_config, port, config_file_name, rest_process):
    '''call set_cluster_metadata for specific platform

    Dispatches to the set_<platform>_config function matching `platform`.
    When that function reports a failure, the error is printed, the rest
    server process (`rest_process`) is killed and the program exits with
    status 1.

    Raises Exception for an unsupported platform, or when killing the rest
    server fails.
    '''
    config_setters = {
        'local': set_local_config,
        'remote': set_remote_config,
        'pai': set_pai_config,
        'kubeflow': set_kubeflow_config,
        'frameworkcontroller': set_frameworkcontroller_config,
    }
    print_normal('Setting {0} config...'.format(platform))
    config_setter = config_setters.get(platform)
    if config_setter is None:
        raise Exception(ERROR_INFO % 'Unsupported platform!')

    config_result, err_msg = config_setter(experiment_config, port, config_file_name)
    if config_result:
        print_normal('Successfully set {0} config!'.format(platform))
        return

    print_error('Failed! Error is: {}'.format(err_msg))
    try:
        kill_command(rest_process.pid)
    except Exception as err:
        # keep the original failure attached to the reported one
        raise Exception(ERROR_INFO % 'Rest server stopped!') from err
    exit(1)

389
def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None):
    '''follow steps to start rest server and start experiment

    Steps, in order:
      1. check that the module of a builtin tuner/advisor can be imported,
      2. start the rest server and record its pid,
      3. prepare the search space (annotation, search space file or empty),
      4. check that the rest server is running,
      5. send the platform configuration (skipped in 'view' mode),
      6. start the experiment and record its id, web UI urls and metadata.

    Any failure after the rest server has been started in steps 4-6 kills
    the rest server and exits with status 1.

    Raises AssertionError when annotation yields an empty search space, and
    Exception when the rest server cannot be killed during an abort.
    '''
    def _abort_and_stop_rest_server(stopped_message):
        '''Print the logs, kill the rest server and exit with status 1.'''
        print_log_content(config_file_name)
        try:
            kill_command(rest_process.pid)
        except Exception as err:
            raise Exception(ERROR_INFO % stopped_message) from err
        exit(1)

    nni_config = Config(config_file_name)

    # check packages for tuner
    package_name, module_name = None, None
    if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
        package_name = experiment_config['tuner']['builtinTunerName']
        module_name = ModuleName.get(package_name)
    elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
        package_name = experiment_config['advisor']['builtinAdvisorName']
        module_name = AdvisorModuleName.get(package_name)
    if package_name and module_name:
        try:
            stdout_full_path, stderr_full_path = get_log_path(config_file_name)
            with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
                # import in a separate interpreter; its output goes to the log files
                check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file)
        except CalledProcessError:
            print_error('some errors happen when import package %s.' %(package_name))
            print_log_content(config_file_name)
            if package_name in PACKAGE_REQUIREMENTS:
                print_error('If %s is not installed, it should be installed through '\
                            '\'nnictl package install --name %s\''%(package_name, package_name))
            exit(1)

    log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
    log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
    # view experiment mode do not need debug function, when view an experiment, there will be no new logs created
    if mode != 'view':
        if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
            log_level = 'debug'

    # start rest server
    rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
                                                 mode, config_file_name, experiment_id, log_dir, log_level)
    nni_config.set_config('restServerPid', rest_process.pid)

    # Deal with annotation
    if experiment_config.get('useAnnotation'):
        annotation_root = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
        if not os.path.isdir(annotation_root):
            os.makedirs(annotation_root)
        annotation_dir = tempfile.mkdtemp(dir=annotation_root)
        nas_mode = experiment_config['trial'].get('nasMode', 'classic_mode')
        code_dir = expand_annotations(experiment_config['trial']['codeDir'], annotation_dir, nas_mode=nas_mode)
        experiment_config['trial']['codeDir'] = code_dir
        search_space = generate_search_space(code_dir)
        experiment_config['searchSpace'] = json.dumps(search_space)
        # Explicit check instead of `assert`, which is skipped when python
        # runs with -O; the exception type is kept.
        if not search_space:
            raise AssertionError(ERROR_INFO % 'Generated search space is empty')
    elif experiment_config.get('searchSpacePath'):
        search_space = get_json_content(experiment_config.get('searchSpacePath'))
        experiment_config['searchSpace'] = json.dumps(search_space)
    else:
        experiment_config['searchSpace'] = json.dumps('')

    # check rest server
    running, _ = check_rest_server(args.port)
    if running:
        print_normal('Successfully started Restful server!')
    else:
        print_error('Restful server start failed!')
        _abort_and_stop_rest_server('Rest server stopped!')

    if mode != 'view':
        # set platform configuration
        set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\
                            config_file_name, rest_process)

    # start a new experiment
    print_normal('Starting experiment...')
    # set debug configuration
    if mode != 'view' and experiment_config.get('debug') is None:
        experiment_config['debug'] = args.debug
    response = set_experiment(experiment_config, mode, args.port, config_file_name)
    if response:
        if experiment_id is None:
            experiment_id = json.loads(response.text).get('experiment_id')
        nni_config.set_config('experimentId', experiment_id)
    else:
        print_error('Start experiment failed!')
        _abort_and_stop_rest_server('Restful server stopped!')

    if experiment_config.get('nniManagerIp'):
        web_ui_url_list = ['{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
    else:
        web_ui_url_list = get_local_urls(args.port)
    nni_config.set_config('webuiUrl', web_ui_url_list)

    # save experiment information
    nnictl_experiment_config = Experiments()
    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name,
                                            experiment_config['trainingServicePlatform'],
                                            experiment_config['experimentName'])

    print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, '   '.join(web_ui_url_list)))
Deshui Yu's avatar
Deshui Yu committed
488

SparkSnail's avatar
SparkSnail committed
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
def create_experiment(args):
    '''Create and launch a new experiment from the config file at args.config.

    A random 8-character name identifies the experiment's nnictl config.
    Exits with status 1 when the config file does not exist.
    '''
    name_alphabet = string.ascii_letters + string.digits
    config_file_name = ''.join(random.sample(name_alphabet, 8))
    nni_config = Config(config_file_name)

    config_path = os.path.abspath(args.config)
    config_found = os.path.exists(config_path)
    if not config_found:
        print_error('Please set correct config path!')
        exit(1)

    experiment_config = get_yml_content(config_path)
    validate_all_content(experiment_config, config_path)
    nni_config.set_config('experimentConfig', experiment_config)

    launch_experiment(args, experiment_config, 'new', config_file_name)
    # the port is recorded only once the launch went through
    nni_config.set_config('restServerPort', args.port)

def manage_stopped_experiment(args, mode):
    '''view or resume a stopped experiment

    `mode` is the nnictl action ('view' or 'resume' in this module) and
    args.id selects the experiment. The stored experiment config is copied
    into a new, randomly named nnictl config and the experiment is launched
    again in the given mode.

    Exits with status 1 when no id is given, the id is unknown, or the
    experiment's status is not 'STOPPED'.
    '''
    update_experiment()
    experiments = Experiments()
    experiment_dict = experiments.get_all_experiments()

    if not args.id:
        # `{{id}}` is escaped: it is a literal placeholder shown to the user,
        # not a str.format field.
        print_error('Please set experiment id! \nYou could use \'nnictl {0} {{id}}\' to {0} a stopped experiment!\n' \
        'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
        exit(1)
    if experiment_dict.get(args.id) is None:
        print_error('Id %s not exist!' % args.id)
        exit(1)
    if experiment_dict[args.id]['status'] != 'STOPPED':
        # avoid a doubled 'e' for modes ending in 'e' ("resumed", not "resumeed")
        past_tense = mode + ('d' if mode.endswith('e') else 'ed')
        print_error('Only stopped experiments can be {0}!'.format(past_tense))
        exit(1)

    print_normal('{0} experiment {1}...'.format(mode, args.id))
    nni_config = Config(experiment_dict[args.id]['fileName'])
    experiment_config = nni_config.get_config('experimentConfig')
    # the id stored in the experiment's own config is the one passed on
    experiment_id = nni_config.get_config('experimentId')

    new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
    new_nni_config = Config(new_config_file_name)
    new_nni_config.set_config('experimentConfig', experiment_config)
    launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id)
    new_nni_config.set_config('restServerPort', args.port)
Deshui Yu's avatar
Deshui Yu committed
532

SparkSnail's avatar
SparkSnail committed
533
534
535
def view_experiment(args):
    '''Open a stopped experiment in 'view' mode (delegates to manage_stopped_experiment).'''
    manage_stopped_experiment(args, mode='view')
Deshui Yu's avatar
Deshui Yu committed
536

SparkSnail's avatar
SparkSnail committed
537
538
539
def resume_experiment(args):
    '''Restart a stopped experiment in 'resume' mode (delegates to manage_stopped_experiment).'''
    manage_stopped_experiment(args, mode='resume')