"vscode:/vscode.git/clone" did not exist on "a0bbd30e29eee23e9718e5910bacc0b269bee043"
launcher.py 26.2 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
Deshui Yu's avatar
Deshui Yu committed
3
4
5

import json
import os
6
import sys
7
import string
chicm-ms's avatar
chicm-ms committed
8
9
10
import random
import site
import time
Deshui Yu's avatar
Deshui Yu committed
11
import tempfile
chicm-ms's avatar
chicm-ms committed
12
13
from subprocess import Popen, check_call, CalledProcessError
from nni_annotation import expand_annotations, generate_search_space
14
from nni.constants import ModuleName, AdvisorModuleName
Deshui Yu's avatar
Deshui Yu committed
15
from .launcher_utils import validate_all_content
chicm-ms's avatar
chicm-ms committed
16
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
SparkSnail's avatar
SparkSnail committed
17
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
18
from .config_utils import Config, Experiments
chicm-ms's avatar
chicm-ms committed
19
20
21
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \
                          detect_port, get_user, get_python_dir
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, PACKAGE_REQUIREMENTS
22
from .command_utils import check_output_command, kill_command
SparkSnail's avatar
SparkSnail committed
23
from .nnictl_utils import update_experiment
Gems Guo's avatar
Gems Guo committed
24

25
26
27
28
29
30
31
32
33
34
def get_log_path(config_file_name):
    '''Return the (stdout, stderr) log file paths kept under NNICTL_HOME_DIR for a config name.'''
    experiment_log_dir = os.path.join(NNICTL_HOME_DIR, config_file_name)
    return os.path.join(experiment_log_dir, 'stdout'), os.path.join(experiment_log_dir, 'stderr')

def print_log_content(config_file_name):
    '''Echo the captured stdout log, then the stderr log, of the given experiment config to the console.'''
    out_path, err_path = get_log_path(config_file_name)
    sections = ((' Stdout:', out_path), (' Stderr:', err_path))
    for position, (title, log_path) in enumerate(sections):
        if position > 0:
            # blank separation between the two sections
            print('\n\n')
        print_normal(title)
        print(check_output_command(log_path))
39

Zejun Lin's avatar
Zejun Lin committed
40
41
42
43
44
45
46
47
48
def get_nni_installation_path():
    ''' Find nni lib from the following locations in order
    Return nni root directory if it exists

    Search order:
      * ``$VIRTUAL_ENV`` when it is set (conda environments do not set it);
      * otherwise the user site-packages and the first system site-packages. The user one is
        tried first when the system site-packages path starts with /usr or /Library.
    A candidate python_dir (derived from a site-packages path by get_python_dir) is accepted
    when ``<python_dir>/nni/main.js`` exists.

    Returns:
        str: the ``<python_dir>/nni`` directory.
    Exits the process with status 1 if nni cannot be found.
    '''
    def try_installation_path_sequentially(*sitepackages):
        '''Try different installation path sequentially util nni is found.
        Return None if nothing is found
        '''
        for sitepackage in sitepackages:
            candidate_dir = get_python_dir(sitepackage)
            if os.path.isfile(os.path.join(candidate_dir, 'nni', 'main.js')):
                return candidate_dir
        return None

    if os.getenv('VIRTUAL_ENV'):
        # if 'virtualenv' package is used, `site` has not attr getsitepackages, so we will instead use VIRTUAL_ENV
        # Note that conda venv will not have VIRTUAL_ENV
        python_dir = os.getenv('VIRTUAL_ENV')
    else:
        python_sitepackage = site.getsitepackages()[0]
        # If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given
        # that nni exists there
        if python_sitepackage.startswith(('/usr', '/Library')):
            python_dir = try_installation_path_sequentially(site.getusersitepackages(), python_sitepackage)
        else:
            python_dir = try_installation_path_sequentially(python_sitepackage, site.getusersitepackages())

    if python_dir:
        entry_file = os.path.join(python_dir, 'nni', 'main.js')
        if os.path.isfile(entry_file):
            return os.path.join(python_dir, 'nni')
    print_error('Fail to find nni under python library')
    # sys.exit rather than the site-injected `exit`, which is absent under `python -S`/frozen builds
    sys.exit(1)
80

81
def start_rest_server(port, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
    '''Run nni manager process

    Launches ``node main.js`` from the nni installation directory in the background. Its
    stdout/stderr are appended to this config's log files, each prefixed with a timestamped header.

    Args:
        port: REST server port. Non-local platforms also need ``port + 1`` to be free.
        platform: training service platform, forwarded as ``--mode``.
        mode: start mode. 'view' starts the manager in resume mode with ``--readonly true``.
        config_file_name: nnictl config name, used to locate the log files.
        experiment_id: forwarded as ``--experiment_id`` when mode is 'resume' or 'view'.
        log_dir, log_level: optional manager logging options, forwarded when not None.

    Returns:
        (Popen, str): the manager process and its start time formatted as '%Y-%m-%d %H:%M:%S'.
    Exits the process with status 1 if a required port is already in use.
    '''
    if detect_port(port):
        print_error('Port %s is used by another process, please reset the port!\n' \
        'You could use \'nnictl create --help\' to get help information' % port)
        sys.exit(1)

    # every non-local platform needs the adjacent port as well
    if (platform != 'local') and detect_port(int(port) + 1):
        print_error('%s mode needs an additional adjacent port %d, and the port %d is used by another process!\n' \
        'You could set another port to start experiment!\n' \
        'You could use \'nnictl create --help\' to get help information' % (platform, int(port) + 1, int(port) + 1))
        sys.exit(1)

    print_normal('Starting restful server...')

    entry_dir = get_nni_installation_path()
    entry_file = os.path.join(entry_dir, 'main.js')

    node_command = 'node'
    if sys.platform == 'win32':
        # entry_dir is <python_dir>/nni; node.exe is looked up in <python_dir>/Scripts
        node_command = os.path.join(os.path.dirname(entry_dir), 'Scripts', 'node.exe')
    cmds = [node_command, entry_file, '--port', str(port), '--mode', platform]
    if mode == 'view':
        cmds += ['--start_mode', 'resume']
        cmds += ['--readonly', 'true']
    else:
        cmds += ['--start_mode', mode]
    if log_dir is not None:
        cmds += ['--log_dir', log_dir]
    if log_level is not None:
        cmds += ['--log_level', log_level]
    if mode in ['resume', 'view']:
        cmds += ['--experiment_id', experiment_id]
    stdout_full_path, stderr_full_path = get_log_path(config_file_name)
    with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
        time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        #add time information in the header of log files
        log_header = LOG_HEADER % str(time_now)
        stdout_file.write(log_header)
        stderr_file.write(log_header)
        # The child writes straight to the underlying file descriptors; flush now so the header
        # cannot end up behind the child's first output when these buffers are flushed on close.
        stdout_file.flush()
        stderr_file.flush()
        if sys.platform == 'win32':
            from subprocess import CREATE_NEW_PROCESS_GROUP
            process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP)
        else:
            process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
    return process, str(time_now)
Deshui Yu's avatar
Deshui Yu committed
127

128
def set_trial_config(experiment_config, port, config_file_name):
    '''set trial configuration

    PUTs ``experiment_config['trial']`` to the cluster-metadata endpoint as ``trial_config``.

    Returns:
        bool: True when check_response accepts the reply. Otherwise returns False. If a reply
        was received, its text is also printed and appended to this config's stderr log.
    '''
    request_data = dict()
    request_data['trial_config'] = experiment_config['trial']
    response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
    if check_response(response):
        return True
    # Test `is not None` rather than truthiness: rest_put may return None (the sibling setters
    # guard for it), and a requests.Response with an error status is itself falsy, so
    # `if response:` would skip logging exactly the failures we care about.
    if response is not None:
        print('Error message is {}'.format(response.text))
        _, stderr_full_path = get_log_path(config_file_name)
        try:
            # Pretty-print JSON error bodies; keep non-JSON replies verbatim instead of crashing.
            log_text = json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':'))
        except ValueError:
            log_text = response.text
        with open(stderr_full_path, 'a+') as fout:
            fout.write(log_text)
    return False
142

143
def set_local_config(experiment_config, port, config_file_name):
    '''set local configuration

    When ``localConfig`` is present, it is PUT to the cluster-metadata endpoint as ``local_config``.
    An integer ``gpuIndices`` is converted to str in place, which also updates experiment_config.
    The trial config is then sent via set_trial_config.

    Returns:
        (bool, str or None): (False, server error text or '') if the local config is rejected.
        Otherwise (result of set_trial_config, None).
    '''
    request_data = dict()
    if experiment_config.get('localConfig'):
        request_data['local_config'] = experiment_config['localConfig']
        local_config = request_data['local_config']
        # isinstance alone (no truthiness test) so that the valid single GPU index 0 is converted too;
        # this matches how set_remote_config normalizes gpuIndices.
        if isinstance(local_config.get('gpuIndices'), int):
            local_config['gpuIndices'] = str(local_config.get('gpuIndices'))
        response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
        err_message = ''
        if not response or not check_response(response):
            if response is not None:
                err_message = response.text
                _, stderr_full_path = get_log_path(config_file_name)
                try:
                    # Pretty-print JSON error bodies; keep non-JSON replies verbatim instead of crashing.
                    log_text = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
                except ValueError:
                    log_text = err_message
                with open(stderr_full_path, 'a+') as fout:
                    fout.write(log_text)
            return False, err_message

    return set_trial_config(experiment_config, port, config_file_name), None
Deshui Yu's avatar
Deshui Yu committed
166

167
def set_remote_config(experiment_config, port, config_file_name):
    '''Call setClusterMetadata to pass trial

    Sends, in order: the remote ``machineList`` (as ``machine_list``), the optional nniManagerIp
    (via setNNIManagerIp), and the trial config (via set_trial_config). An integer ``gpuIndices``
    on a machine entry is converted to str in place, which also updates experiment_config.

    Returns:
        (bool, str or None): success flag and an error message. If the machine list is rejected,
        the message is the server's text ('' when no reply was received).
    '''
    #set machine_list
    request_data = dict()
    request_data['machine_list'] = experiment_config['machineList']
    for machine in request_data['machine_list'] or []:
        if isinstance(machine.get('gpuIndices'), int):
            machine['gpuIndices'] = str(machine.get('gpuIndices'))
    response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
    err_message = ''
    if not response or not check_response(response):
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            try:
                # Pretty-print JSON error bodies; keep non-JSON replies verbatim instead of crashing.
                log_text = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
            except ValueError:
                log_text = err_message
            with open(stderr_full_path, 'a+') as fout:
                fout.write(log_text)
        return False, err_message
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), err_message
Deshui Yu's avatar
Deshui Yu committed
190

191
192
193
194
195
def setNNIManagerIp(experiment_config, port, config_file_name):
    '''set nniManagerIp

    When ``nniManagerIp`` is configured, PUTs it to the cluster-metadata endpoint as
    ``nni_manager_ip``. Otherwise this does nothing.

    Returns:
        (bool, str or None): (True, None) on success or when nothing needs to be sent.
        (False, server error text or None) when the reply is missing or not HTTP 200.
    '''
    if experiment_config.get('nniManagerIp') is None:
        return True, None
    ip_config_dict = dict()
    ip_config_dict['nni_manager_ip'] = {'nniManagerIp': experiment_config['nniManagerIp']}
    response = rest_put(cluster_metadata_url(port), json.dumps(ip_config_dict), REST_TIME_OUT)
    err_message = None
    if not response or response.status_code != 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            try:
                # Pretty-print JSON error bodies; keep non-JSON replies verbatim instead of crashing.
                log_text = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
            except ValueError:
                log_text = err_message
            with open(stderr_full_path, 'a+') as fout:
                fout.write(log_text)
        return False, err_message
    return True, None

208
def set_pai_config(experiment_config, port, config_file_name):
    '''set pai configuration

    Sends, in order: ``paiConfig`` (as ``pai_config``), the optional nniManagerIp (via
    setNNIManagerIp), and the trial config (via set_trial_config).

    Returns:
        (bool, str or None): success flag and an error message. If the PAI config is rejected,
        the message is the server's text (None when no reply was received).
    '''
    pai_config_data = dict()
    pai_config_data['pai_config'] = experiment_config['paiConfig']
    response = rest_put(cluster_metadata_url(port), json.dumps(pai_config_data), REST_TIME_OUT)
    err_message = None
    if not response or response.status_code != 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            try:
                # Pretty-print JSON error bodies; keep non-JSON replies verbatim instead of crashing.
                log_text = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
            except ValueError:
                log_text = err_message
            with open(stderr_full_path, 'a+') as fout:
                fout.write(log_text)
        return False, err_message
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), err_message
226

227
def set_kubeflow_config(experiment_config, port, config_file_name):
    '''set kubeflow configuration

    Sends, in order: ``kubeflowConfig`` (as ``kubeflow_config``), the optional nniManagerIp (via
    setNNIManagerIp), and the trial config (via set_trial_config).

    Returns:
        (bool, str or None): success flag and an error message. If the kubeflow config is
        rejected, the message is the server's text (None when no reply was received).
    '''
    kubeflow_config_data = dict()
    kubeflow_config_data['kubeflow_config'] = experiment_config['kubeflowConfig']
    response = rest_put(cluster_metadata_url(port), json.dumps(kubeflow_config_data), REST_TIME_OUT)
    err_message = None
    if not response or response.status_code != 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            try:
                # Pretty-print JSON error bodies; keep non-JSON replies verbatim instead of crashing.
                log_text = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
            except ValueError:
                log_text = err_message
            with open(stderr_full_path, 'a+') as fout:
                fout.write(log_text)
        return False, err_message
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), err_message

246
def set_frameworkcontroller_config(experiment_config, port, config_file_name):
    '''set frameworkcontroller configuration

    Sends, in order: ``frameworkcontrollerConfig`` (as ``frameworkcontroller_config``), the
    optional nniManagerIp (via setNNIManagerIp), and the trial config (via set_trial_config).

    Returns:
        (bool, str or None): success flag and an error message. If the frameworkcontroller config
        is rejected, the message is the server's text (None when no reply was received).
    '''
    frameworkcontroller_config_data = dict()
    frameworkcontroller_config_data['frameworkcontroller_config'] = experiment_config['frameworkcontrollerConfig']
    response = rest_put(cluster_metadata_url(port), json.dumps(frameworkcontroller_config_data), REST_TIME_OUT)
    err_message = None
    if not response or response.status_code != 200:
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            try:
                # Pretty-print JSON error bodies; keep non-JSON replies verbatim instead of crashing.
                log_text = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
            except ValueError:
                log_text = err_message
            with open(stderr_full_path, 'a+') as fout:
                fout.write(log_text)
        return False, err_message
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), err_message

265
def set_experiment(experiment_config, mode, port, config_file_name):
    '''Call startExperiment (rest POST /experiment) with yaml file content

    Builds the experiment request body from experiment_config and POSTs it to the REST server.
    An integer ``gpuIndices`` on the tuner/advisor is converted to str in place, which also
    updates experiment_config.

    Args:
        experiment_config: parsed experiment config dict.
        mode: start mode; not read by this function.
        port: REST server port.
        config_file_name: nnictl config name, used to locate the stderr log.

    Returns:
        The REST response when check_response accepts it; otherwise None. If a reply was
        received, the error is also printed and appended to the stderr log.
    '''
    request_data = dict()
    request_data['authorName'] = experiment_config['authorName']
    request_data['experimentName'] = experiment_config['experimentName']
    request_data['trialConcurrency'] = experiment_config['trialConcurrency']
    request_data['maxExecDuration'] = experiment_config['maxExecDuration']
    request_data['maxTrialNum'] = experiment_config['maxTrialNum']
    request_data['searchSpace'] = experiment_config.get('searchSpace')
    request_data['trainingServicePlatform'] = experiment_config.get('trainingServicePlatform')

    if experiment_config.get('description'):
        request_data['description'] = experiment_config['description']
    if experiment_config.get('multiPhase'):
        request_data['multiPhase'] = experiment_config.get('multiPhase')
    if experiment_config.get('multiThread'):
        request_data['multiThread'] = experiment_config.get('multiThread')
    if experiment_config.get('advisor'):
        request_data['advisor'] = experiment_config['advisor']
        if request_data['advisor'].get('gpuNum'):
            print_error('gpuNum is deprecated, please use gpuIndices instead.')
        # isinstance alone (no truthiness test) so that the valid single GPU index 0 is converted too
        if isinstance(request_data['advisor'].get('gpuIndices'), int):
            request_data['advisor']['gpuIndices'] = str(request_data['advisor'].get('gpuIndices'))
    else:
        request_data['tuner'] = experiment_config['tuner']
        if request_data['tuner'].get('gpuNum'):
            print_error('gpuNum is deprecated, please use gpuIndices instead.')
        if isinstance(request_data['tuner'].get('gpuIndices'), int):
            request_data['tuner']['gpuIndices'] = str(request_data['tuner'].get('gpuIndices'))
        if 'assessor' in experiment_config:
            request_data['assessor'] = experiment_config['assessor']
            if request_data['assessor'].get('gpuNum'):
                print_error('gpuNum is deprecated, please remove it from your config file.')
    #debug mode should disable version check
    if experiment_config.get('debug') is not None:
        request_data['versionCheck'] = not experiment_config.get('debug')
    #validate version check; an explicit versionCheck overrides the debug-derived value
    if experiment_config.get('versionCheck') is not None:
        request_data['versionCheck'] = experiment_config.get('versionCheck')
    if experiment_config.get('logCollection'):
        request_data['logCollection'] = experiment_config.get('logCollection')

    # platform -> (cluster metadata key, experiment config key) sent alongside trial_config
    platform_config_keys = {
        'remote': ('machine_list', 'machineList'),
        'pai': ('pai_config', 'paiConfig'),
        'kubeflow': ('kubeflow_config', 'kubeflowConfig'),
        'frameworkcontroller': ('frameworkcontroller_config', 'frameworkcontrollerConfig'),
    }
    request_data['clusterMetaData'] = []
    platform = experiment_config['trainingServicePlatform']
    if platform == 'local':
        request_data['clusterMetaData'].append(
            {'key':'codeDir', 'value':experiment_config['trial']['codeDir']})
        request_data['clusterMetaData'].append(
            {'key': 'command', 'value': experiment_config['trial']['command']})
    elif platform in platform_config_keys:
        metadata_key, config_key = platform_config_keys[platform]
        request_data['clusterMetaData'].append(
            {'key': metadata_key, 'value': experiment_config[config_key]})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
    if check_response(response):
        return response
    _, stderr_full_path = get_log_path(config_file_name)
    if response is not None:
        try:
            # Pretty-print JSON error bodies; keep non-JSON replies verbatim instead of crashing.
            log_text = json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':'))
        except ValueError:
            log_text = response.text
        with open(stderr_full_path, 'a+') as fout:
            fout.write(log_text)
        print_error('Setting experiment error, error message is {}'.format(response.text))
    return None
Deshui Yu's avatar
Deshui Yu committed
343

SparkSnail's avatar
SparkSnail committed
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
def set_platform_config(platform, experiment_config, port, config_file_name, rest_process):
    '''call set_cluster_metadata for specific platform

    Dispatches to the platform-specific setter. On failure, prints the error, kills the REST
    server process and exits with status 1.

    Raises:
        Exception: if ``platform`` is not supported, or if killing the REST server fails
            (chained to the underlying error).
    '''
    print_normal('Setting {0} config...'.format(platform))
    platform_setters = {
        'local': set_local_config,
        'remote': set_remote_config,
        'pai': set_pai_config,
        'kubeflow': set_kubeflow_config,
        'frameworkcontroller': set_frameworkcontroller_config,
    }
    setter = platform_setters.get(platform)
    if setter is None:
        raise Exception(ERROR_INFO % 'Unsupported platform!')
    config_result, err_msg = setter(experiment_config, port, config_file_name)
    if config_result:
        print_normal('Successfully set {0} config!'.format(platform))
        return
    print_error('Failed! Error is: {}'.format(err_msg))
    try:
        kill_command(rest_process.pid)
    except Exception as err:
        raise Exception(ERROR_INFO % 'Rest server stopped!') from err
    sys.exit(1)

371
def _abort_launch(rest_process, error_message):
    '''Kill the REST server started by launch_experiment, then exit with status 1.'''
    try:
        kill_command(rest_process.pid)
    except Exception as err:
        raise Exception(ERROR_INFO % error_message) from err
    sys.exit(1)

def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None):
    '''follow steps to start rest server and start experiment

    Steps, in order:
      1. check that the builtin tuner/advisor module imports;
      2. expand annotations or load the search space;
      3. start the REST server;
      4. push the platform config (skipped in view mode);
      5. POST the experiment;
      6. record the experiment id, web UI URLs and metadata.
    Failures print diagnostics and exit with status 1, killing the REST server if it has started.

    Args:
        args: parsed CLI args; ``port`` and ``debug`` are read.
        experiment_config: experiment config dict. ``searchSpace`` is written in place, as are
            ``trial.codeDir`` when annotations are used and ``debug`` when it was unset.
        mode: 'new', 'resume' or 'view'.
        config_file_name: nnictl config name for this launch.
        experiment_id: id of the experiment to resume/view; None for a new experiment.

    Raises:
        AssertionError: if annotation processing generates an empty search space.
    '''
    nni_config = Config(config_file_name)
    # check packages for tuner
    package_name, module_name = None, None
    if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
        package_name = experiment_config['tuner']['builtinTunerName']
        module_name = ModuleName.get(package_name)
    elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
        package_name = experiment_config['advisor']['builtinAdvisorName']
        module_name = AdvisorModuleName.get(package_name)
    if package_name and module_name:
        try:
            stdout_full_path, stderr_full_path = get_log_path(config_file_name)
            with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
                check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file)
        except CalledProcessError:
            print_error('some errors happen when import package %s.' %(package_name))
            print_log_content(config_file_name)
            if package_name in PACKAGE_REQUIREMENTS:
                print_error('If %s is not installed, it should be installed through '\
                            '\'nnictl package install --name %s\''%(package_name, package_name))
            sys.exit(1)
    log_dir = experiment_config.get('logDir') or None
    log_level = experiment_config.get('logLevel') or None
    #view experiment mode do not need debug function, when view an experiment, there will be no new logs created
    if mode != 'view':
        if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
            log_level = 'debug'

    # Deal with annotation / search space BEFORE starting the rest server, so a failure here
    # cannot leave an orphaned server process behind.
    if experiment_config.get('useAnnotation'):
        path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
        os.makedirs(path, exist_ok=True)
        path = tempfile.mkdtemp(dir=path)
        nas_mode = experiment_config['trial'].get('nasMode', 'classic_mode')
        code_dir = expand_annotations(experiment_config['trial']['codeDir'], path, nas_mode=nas_mode)
        experiment_config['trial']['codeDir'] = code_dir
        search_space = generate_search_space(code_dir)
        experiment_config['searchSpace'] = json.dumps(search_space)
        if not search_space:
            # explicit raise instead of `assert`, which is stripped under `python -O`
            raise AssertionError(ERROR_INFO % 'Generated search space is empty')
    elif experiment_config.get('searchSpacePath'):
        search_space = get_json_content(experiment_config.get('searchSpacePath'))
        experiment_config['searchSpace'] = json.dumps(search_space)
    else:
        experiment_config['searchSpace'] = json.dumps('')

    # start rest server
    rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
                                                 mode, config_file_name, experiment_id, log_dir, log_level)
    nni_config.set_config('restServerPid', rest_process.pid)

    # check rest server
    running, _ = check_rest_server(args.port)
    if running:
        print_normal('Successfully started Restful server!')
    else:
        print_error('Restful server start failed!')
        print_log_content(config_file_name)
        _abort_launch(rest_process, 'Rest server stopped!')
    if mode != 'view':
        # set platform configuration
        set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\
                            config_file_name, rest_process)

    # start a new experiment
    print_normal('Starting experiment...')
    # set debug configuration
    if mode != 'view' and experiment_config.get('debug') is None:
        experiment_config['debug'] = args.debug
    response = set_experiment(experiment_config, mode, args.port, config_file_name)
    if response:
        if experiment_id is None:
            experiment_id = json.loads(response.text).get('experiment_id')
        nni_config.set_config('experimentId', experiment_id)
    else:
        print_error('Start experiment failed!')
        print_log_content(config_file_name)
        _abort_launch(rest_process, 'Restful server stopped!')
    if experiment_config.get('nniManagerIp'):
        web_ui_url_list = ['{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
    else:
        web_ui_url_list = get_local_urls(args.port)
    nni_config.set_config('webuiUrl', web_ui_url_list)

    # save experiment information
    nnictl_experiment_config = Experiments()
    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name,
                                            experiment_config['trainingServicePlatform'],
                                            experiment_config['experimentName'])

    print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, '   '.join(web_ui_url_list)))
Deshui Yu's avatar
Deshui Yu committed
470

SparkSnail's avatar
SparkSnail committed
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
def create_experiment(args):
    '''Launch a brand-new experiment described by the YAML file at ``args.config``.

    A random 8-character name identifies the nnictl config that tracks this launch.
    Exits with status 1 when the config file does not exist.
    '''
    name_alphabet = string.ascii_letters + string.digits
    tracking_name = ''.join(random.sample(name_alphabet, 8))
    tracking_config = Config(tracking_name)

    yaml_path = os.path.abspath(args.config)
    if not os.path.exists(yaml_path):
        print_error('Please set correct config path!')
        exit(1)

    exp_config = get_yml_content(yaml_path)
    validate_all_content(exp_config, yaml_path)
    tracking_config.set_config('experimentConfig', exp_config)

    launch_experiment(args, exp_config, 'new', tracking_name)
    tracking_config.set_config('restServerPort', args.port)

def manage_stopped_experiment(args, mode):
    '''view or resume a stopped experiment

    Args:
        args: parsed CLI args. ``id`` selects the experiment; ``port``/``debug`` are also
            consumed by launch_experiment.
        mode: 'view' or 'resume'.

    Exits with status 1 in three cases: no id was given, the id is unknown, or the experiment
    is not in STOPPED status.
    '''
    update_experiment()
    experiments = Experiments()
    experiment_dict = experiments.get_all_experiments()
    if not args.id:
        # `{{id}}` renders a literal "{id}"; a bare `{id}` made str.format raise KeyError.
        print_error('Please set experiment id! \nYou could use \'nnictl {0} {{id}}\' to {0} a stopped experiment!\n' \
        'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
        sys.exit(1)
    if experiment_dict.get(args.id) is None:
        print_error('Id %s not exist!' % args.id)
        sys.exit(1)
    if experiment_dict[args.id]['status'] != 'STOPPED':
        # 'view' -> 'viewed', 'resume' -> 'resumed' (not 'resumeed')
        past_tense = mode + 'd' if mode.endswith('e') else mode + 'ed'
        print_error('Only stopped experiments can be {0}!'.format(past_tense))
        sys.exit(1)
    experiment_id = args.id
    print_normal('{0} experiment {1}...'.format(mode, experiment_id))
    nni_config = Config(experiment_dict[experiment_id]['fileName'])
    experiment_config = nni_config.get_config('experimentConfig')
    experiment_id = nni_config.get_config('experimentId')
    # the relaunch is tracked under a fresh nnictl config name
    new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
    new_nni_config = Config(new_config_file_name)
    new_nni_config.set_config('experimentConfig', experiment_config)
    launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id)
    new_nni_config.set_config('restServerPort', args.port)
Deshui Yu's avatar
Deshui Yu committed
514

SparkSnail's avatar
SparkSnail committed
515
516
517
def view_experiment(args):
    '''Open a stopped experiment in view mode (the REST server is started with --readonly true).'''
    return manage_stopped_experiment(args, mode='view')
Deshui Yu's avatar
Deshui Yu committed
518

SparkSnail's avatar
SparkSnail committed
519
520
def resume_experiment(args):
    '''Relaunch a stopped experiment with the REST server in resume start mode.'''
    return manage_stopped_experiment(args, mode='resume')