"official/benchmark/ncf_keras_benchmark.py" did not exist on "cce82cd05480c95a07ba5eea5cff9e1913830aeb"
launcher.py 27 KB
Newer Older
Deshui Yu's avatar
Deshui Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


import json
import os
24
import sys
Deshui Yu's avatar
Deshui Yu committed
25
import shutil
26
import string
27
from subprocess import Popen, PIPE, call, check_output, check_call, CalledProcessError
Deshui Yu's avatar
Deshui Yu committed
28
import tempfile
29
from nni.constants import ModuleName, AdvisorModuleName
30
from nni_annotation import *
Deshui Yu's avatar
Deshui Yu committed
31
from .launcher_utils import validate_all_content
32
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response
SparkSnail's avatar
SparkSnail committed
33
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
34
from .config_utils import Config, Experiments
35
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, detect_process, detect_port, get_user, get_python_dir
36
from .constants import *
37
import random
QuanluZhang's avatar
QuanluZhang committed
38
import site
39
import time
Gems Guo's avatar
Gems Guo committed
40
from pathlib import Path
41
from .command_utils import check_output_command, kill_command
SparkSnail's avatar
SparkSnail committed
42
from .nnictl_utils import update_experiment
Gems Guo's avatar
Gems Guo committed
43

44
45
46
47
48
49
50
51
52
53
def get_log_path(config_file_name):
    '''Build the stdout and stderr log file paths of an experiment.

    Both files live in the directory NNICTL_HOME_DIR/<config_file_name>.
    Returns a (stdout_path, stderr_path) tuple.
    '''
    experiment_home = os.path.join(NNICTL_HOME_DIR, config_file_name)
    return tuple(os.path.join(experiment_home, stream) for stream in ('stdout', 'stderr'))

def print_log_content(config_file_name):
    '''Print the experiment's stdout log followed by its stderr log.'''
    stdout_log, stderr_log = get_log_path(config_file_name)
    sections = ((' Stdout:', stdout_log), (' Stderr:', stderr_log))
    for position, (title, log_path) in enumerate(sections):
        if position:
            # blank separator between the two sections
            print('\n\n')
        print_normal(title)
        print(check_output_command(log_path))
58

Zejun Lin's avatar
Zejun Lin committed
59
60
61
62
63
64
65
66
67
def get_nni_installation_path():
    '''Find the installed nni directory, i.e. the one that contains main.js.

    Candidate python directories are tried in this order:
      * $VIRTUAL_ENV, when that variable is set;
      * otherwise the user site-packages and the first global site-packages,
        with the user one first when the global one lives under /usr or
        /Library (system-wide python), and the global one first otherwise.

    Returns the path of the 'nni' directory. If none of the candidates
    contains nni/main.js, an error is printed and the process exits with
    status 1.
    '''
    def _contains_nni(python_dir):
        '''Tell whether <python_dir>/nni/main.js exists.'''
        return os.path.isfile(os.path.join(python_dir, 'nni', 'main.js'))

    def try_installation_path_sequentially(*sitepackages):
        '''Try different installation paths sequentially until nni is found.
        Return None if nothing is found
        '''
        for sitepackage in sitepackages:
            candidate = get_python_dir(sitepackage)
            if candidate and _contains_nni(candidate):
                return candidate
        return None

    if os.getenv('VIRTUAL_ENV'):
        # if 'virtualenv' package is used, `site` has not attr getsitepackages, so we will instead use VIRTUAL_ENV
        # Note that conda venv will not have VIRTUAL_ENV
        python_dir = os.getenv('VIRTUAL_ENV')
    else:
        global_sitepackage = site.getsitepackages()[0]
        user_sitepackage = site.getusersitepackages()
        # If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given that nni exists there
        if global_sitepackage.startswith(('/usr', '/Library')):
            python_dir = try_installation_path_sequentially(user_sitepackage, global_sitepackage)
        else:
            python_dir = try_installation_path_sequentially(global_sitepackage, user_sitepackage)

    if python_dir and _contains_nni(python_dir):
        return os.path.join(python_dir, 'nni')
    print_error('Fail to find nni under python library')
    # sys.exit instead of the `exit` helper, which only exists when the site module is loaded
    sys.exit(1)
98

99
def start_rest_server(port, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
    '''Run nni manager process.

    Builds the `node main.js ...` command line and spawns it with stdout and
    stderr appended to the experiment's log files.

    port: port passed to the manager via --port; for every platform other
        than 'local' the adjacent port (port + 1) is also checked.
    platform: training service platform, passed via --mode.
    mode: start mode; 'view' is launched as '--start_mode resume --readonly true'.
    config_file_name: experiment config name used to locate the log files.
    experiment_id: passed via --experiment_id when mode is 'resume' or 'view'.
    log_dir, log_level: optional, forwarded as --log_dir / --log_level.

    Returns (process, start_time) where process is the Popen object and
    start_time is the formatted launch time written to the log headers.
    Exits with status 1 if a required port is reported as in use.
    '''
    # NOTE(review): this instance is never used here; it is kept in case
    # constructing Config has side effects the log files depend on - confirm.
    nni_config = Config(config_file_name)
    if detect_port(port):
        print_error('Port %s is used by another process, please reset the port!\n'
                    'You could use \'nnictl create --help\' to get help information' % port)
        sys.exit(1)

    adjacent_port = int(port) + 1
    if (platform != 'local') and detect_port(adjacent_port):
        print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n'
                    'You could set another port to start experiment!\n'
                    'You could use \'nnictl create --help\' to get help information' % (adjacent_port, adjacent_port))
        sys.exit(1)

    print_normal('Starting restful server...')

    entry_dir = get_nni_installation_path()
    entry_file = os.path.join(entry_dir, 'main.js')

    node_command = 'node'
    if sys.platform == 'win32':
        # entry_dir ends with 'nni'; dropping those 3 characters gives its parent directory
        node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')

    cmds = [node_command, entry_file, '--port', str(port), '--mode', platform]
    if mode == 'view':
        cmds += ['--start_mode', 'resume', '--readonly', 'true']
    else:
        cmds += ['--start_mode', mode]
    if log_dir is not None:
        cmds += ['--log_dir', log_dir]
    if log_level is not None:
        cmds += ['--log_level', log_level]
    if mode in ['resume', 'view']:
        cmds += ['--experiment_id', experiment_id]

    stdout_full_path, stderr_full_path = get_log_path(config_file_name)
    with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
        # strftime formats the current local time when no time tuple is given
        time_now = time.strftime('%Y-%m-%d %H:%M:%S')
        #add time information in the header of log files
        log_header = LOG_HEADER % str(time_now)
        stdout_file.write(log_header)
        stderr_file.write(log_header)
        popen_kwargs = {'cwd': entry_dir, 'stdout': stdout_file, 'stderr': stderr_file}
        if sys.platform == 'win32':
            from subprocess import CREATE_NEW_PROCESS_GROUP
            popen_kwargs['creationflags'] = CREATE_NEW_PROCESS_GROUP
        process = Popen(cmds, **popen_kwargs)
    return process, str(time_now)
Deshui Yu's avatar
Deshui Yu committed
146

147
def set_trial_config(experiment_config, port, config_file_name):
    '''Set trial configuration.

    PUTs experiment_config['trial'] to the cluster metadata url of the rest
    server on `port`.

    Returns True when check_response accepts the response. Otherwise the
    error text is printed, appended (pretty-printed JSON) to the
    experiment's stderr log when a response was received, and False is
    returned.
    '''
    request_data = {'trial_config': experiment_config['trial']}
    response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
    if check_response(response):
        return True

    # The response may be missing altogether (the sibling setters check for
    # None), so never dereference it unconditionally.
    error_text = response.text if response is not None else None
    print('Error message is {}'.format(error_text))
    # Test against None explicitly: presumably this is a requests-style
    # response, which is falsy for error statuses, and those are exactly the
    # responses worth logging.
    if response is not None:
        _, stderr_full_path = get_log_path(config_file_name)
        with open(stderr_full_path, 'a+') as fout:
            fout.write(json.dumps(json.loads(error_text), indent=4, sort_keys=True, separators=(',', ':')))
    return False
161

162
def set_local_config(experiment_config, port, config_file_name):
    '''Set local configuration.

    When experiment_config has a non-empty 'localConfig', it is PUT to the
    cluster metadata url first; an integer 'gpuIndices' is sent as a string.
    The trial configuration is then set via set_trial_config.

    Returns (result, err_message): (False, <response text or ''>) when the
    localConfig request fails, otherwise (<set_trial_config result>, None).
    '''
    local_config = experiment_config.get('localConfig')
    if local_config:
        gpu_indices = local_config.get('gpuIndices')
        # Convert on type alone, as set_remote_config does: a truthiness
        # guard would leave the integer index 0 unconverted.
        if isinstance(gpu_indices, int):
            local_config['gpuIndices'] = str(gpu_indices)

        request_data = {'local_config': local_config}
        response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
        if not response or not check_response(response):
            err_message = ''
            if response is not None:
                err_message = response.text
                _, stderr_full_path = get_log_path(config_file_name)
                with open(stderr_full_path, 'a+') as fout:
                    fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
            return False, err_message

    return set_trial_config(experiment_config, port, config_file_name), None
Deshui Yu's avatar
Deshui Yu committed
185

186
def set_remote_config(experiment_config, port, config_file_name):
    '''Push the remote machine list, the nniManagerIp and the trial config.

    Each step goes through the cluster metadata url of the rest server.
    Returns (result, err_message); result is falsy as soon as a step fails.
    '''
    # machine list: integer gpuIndices are sent as strings
    machines = experiment_config['machineList']
    if machines:
        for machine in machines:
            if isinstance(machine.get('gpuIndices'), int):
                machine['gpuIndices'] = str(machine.get('gpuIndices'))
    payload = json.dumps({'machine_list': machines})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = ''
    machine_list_accepted = response and check_response(response)
    if not machine_list_accepted:
        if response is not None:
            err_message = response.text
            _, stderr_log = get_log_path(config_file_name)
            with open(stderr_log, 'a+') as log_file:
                pretty = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
                log_file.write(pretty)
        return False, err_message

    ip_result, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not ip_result:
        return ip_result, ip_message

    # trial config
    trial_result = set_trial_config(experiment_config, port, config_file_name)
    return trial_result, err_message
Deshui Yu's avatar
Deshui Yu committed
209

210
211
212
213
214
215
def setNNIManagerIp(experiment_config, port, config_file_name):
    '''Send the configured nniManagerIp to the rest server.

    Nothing is sent when experiment_config has no 'nniManagerIp'.
    Returns (True, None) on success or when there is nothing to send, and
    (False, <response text or None>) when the request fails; the response
    text is also appended to the experiment's stderr log.
    '''
    if experiment_config.get('nniManagerIp') is None:
        return True, None

    payload = json.dumps({'nni_manager_ip': {'nniManagerIp': experiment_config['nniManagerIp']}})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    if response and response.status_code == 200:
        return True, None

    err_message = None
    if response is not None:
        err_message = response.text
        _, stderr_log = get_log_path(config_file_name)
        with open(stderr_log, 'a+') as log_file:
            pretty = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
            log_file.write(pretty)
    return False, err_message

227
def set_pai_config(experiment_config, port, config_file_name):
    '''Push the pai config, the nniManagerIp and the trial config.

    Each step goes through the cluster metadata url of the rest server.
    Returns (result, err_message); result is falsy as soon as a step fails.
    '''
    payload = json.dumps({'pai_config': experiment_config['paiConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = None
    pai_config_accepted = response and response.status_code == 200
    if not pai_config_accepted:
        if response is not None:
            err_message = response.text
            _, stderr_log = get_log_path(config_file_name)
            with open(stderr_log, 'a+') as log_file:
                pretty = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
                log_file.write(pretty)
        return False, err_message

    ip_result, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not ip_result:
        return ip_result, ip_message

    # trial config
    trial_result = set_trial_config(experiment_config, port, config_file_name)
    return trial_result, err_message
245

246
def set_kubeflow_config(experiment_config, port, config_file_name):
    '''Push the kubeflow config, the nniManagerIp and the trial config.

    Each step goes through the cluster metadata url of the rest server.
    Returns (result, err_message); result is falsy as soon as a step fails.
    '''
    payload = json.dumps({'kubeflow_config': experiment_config['kubeflowConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = None
    kubeflow_config_accepted = response and response.status_code == 200
    if not kubeflow_config_accepted:
        if response is not None:
            err_message = response.text
            _, stderr_log = get_log_path(config_file_name)
            with open(stderr_log, 'a+') as log_file:
                pretty = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
                log_file.write(pretty)
        return False, err_message

    ip_result, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not ip_result:
        return ip_result, ip_message

    # trial config
    trial_result = set_trial_config(experiment_config, port, config_file_name)
    return trial_result, err_message

265
def set_frameworkcontroller_config(experiment_config, port, config_file_name):
    '''Push the frameworkcontroller config, the nniManagerIp and the trial config.

    Each step goes through the cluster metadata url of the rest server.
    Returns (result, err_message); result is falsy as soon as a step fails.
    '''
    payload = json.dumps({'frameworkcontroller_config': experiment_config['frameworkcontrollerConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = None
    controller_config_accepted = response and response.status_code == 200
    if not controller_config_accepted:
        if response is not None:
            err_message = response.text
            _, stderr_log = get_log_path(config_file_name)
            with open(stderr_log, 'a+') as log_file:
                pretty = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
                log_file.write(pretty)
        return False, err_message

    ip_result, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not ip_result:
        return ip_result, ip_message

    # trial config
    trial_result = set_trial_config(experiment_config, port, config_file_name)
    return trial_result, err_message

284
def set_experiment(experiment_config, mode, port, config_file_name):
    '''Call startExperiment (rest POST /experiment) with yaml file content.

    Builds the request body from experiment_config (general fields,
    advisor or tuner/assessor, versionCheck, logCollection and the
    platform specific clusterMetaData) and POSTs it to the experiment url.

    experiment_config: parsed experiment configuration.
    mode: accepted for interface compatibility; not used by this function.
    port: port of the rest server.
    config_file_name: experiment config name used to locate the stderr log.

    Returns the response when check_response accepts it, otherwise None
    (after logging the response text, if any, to the stderr log).
    '''
    request_data = dict()
    for required_key in ('authorName', 'experimentName', 'trialConcurrency', 'maxExecDuration', 'maxTrialNum'):
        request_data[required_key] = experiment_config[required_key]
    request_data['searchSpace'] = experiment_config.get('searchSpace')
    request_data['trainingServicePlatform'] = experiment_config.get('trainingServicePlatform')

    # optional fields are only forwarded when set to a truthy value
    for optional_key in ('description', 'multiPhase', 'multiThread'):
        if experiment_config.get(optional_key):
            request_data[optional_key] = experiment_config[optional_key]

    # an advisor replaces the tuner/assessor pair
    component_key = 'advisor' if experiment_config.get('advisor') else 'tuner'
    component = experiment_config[component_key]
    request_data[component_key] = component
    if component.get('gpuNum'):
        print_error('gpuNum is deprecated, please use gpuIndices instead.')
    # Convert on type alone: a truthiness guard would leave the integer
    # index 0 unconverted.
    if isinstance(component.get('gpuIndices'), int):
        component['gpuIndices'] = str(component.get('gpuIndices'))
    if component_key == 'tuner' and 'assessor' in experiment_config:
        request_data['assessor'] = experiment_config['assessor']
        if request_data['assessor'].get('gpuNum'):
            print_error('gpuNum is deprecated, please remove it from your config file.')

    #debug mode should disable version check
    if experiment_config.get('debug') is not None:
        request_data['versionCheck'] = not experiment_config.get('debug')
    #an explicit versionCheck setting overrides the value derived from debug
    if experiment_config.get('versionCheck') is not None:
        request_data['versionCheck'] = experiment_config.get('versionCheck')
    if experiment_config.get('logCollection'):
        request_data['logCollection'] = experiment_config.get('logCollection')

    # platform -> (metadata key, experiment_config key) of the platform config
    platform_config_keys = {
        'remote': ('machine_list', 'machineList'),
        'pai': ('pai_config', 'paiConfig'),
        'kubeflow': ('kubeflow_config', 'kubeflowConfig'),
        'frameworkcontroller': ('frameworkcontroller_config', 'frameworkcontrollerConfig'),
    }
    cluster_metadata = []
    platform = experiment_config['trainingServicePlatform']
    if platform == 'local':
        cluster_metadata.append({'key': 'codeDir', 'value': experiment_config['trial']['codeDir']})
        cluster_metadata.append({'key': 'command', 'value': experiment_config['trial']['command']})
    elif platform in platform_config_keys:
        metadata_key, config_key = platform_config_keys[platform]
        cluster_metadata.append({'key': metadata_key, 'value': experiment_config[config_key]})
        cluster_metadata.append({'key': 'trial_config', 'value': experiment_config['trial']})
    request_data['clusterMetaData'] = cluster_metadata

    response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
    if check_response(response):
        return response

    if response is not None:
        _, stderr_full_path = get_log_path(config_file_name)
        with open(stderr_full_path, 'a+') as fout:
            fout.write(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
        print_error('Setting experiment error, error message is {}'.format(response.text))
    return None
Deshui Yu's avatar
Deshui Yu committed
362

SparkSnail's avatar
SparkSnail committed
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
def set_platform_config(platform, experiment_config, port, config_file_name, rest_process):
    '''Call the set_*_config function matching the training service platform.

    platform: one of 'local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'.
    rest_process: rest server process; it is killed when configuration fails.

    Raises Exception for an unsupported platform, or when killing the rest
    server fails. Exits with status 1 when the platform configuration is
    rejected.
    '''
    print_normal('Setting {0} config...'.format(platform))
    config_setters = {
        'local': set_local_config,
        'remote': set_remote_config,
        'pai': set_pai_config,
        'kubeflow': set_kubeflow_config,
        'frameworkcontroller': set_frameworkcontroller_config,
    }
    if platform not in config_setters:
        raise Exception(ERROR_INFO % 'Unsupported platform!')

    config_result, err_msg = config_setters[platform](experiment_config, port, config_file_name)
    if config_result:
        print_normal('Successfully set {0} config!'.format(platform))
        return

    print_error('Failed! Error is: {}'.format(err_msg))
    try:
        kill_command(rest_process.pid)
    except Exception:
        raise Exception(ERROR_INFO % 'Rest server stopped!')
    # sys.exit instead of the `exit` helper, which only exists when the site module is loaded
    sys.exit(1)

390
def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None):
    '''Follow steps to start rest server and start experiment.

    Steps, in order:
      1. check that the builtin tuner/advisor module can be imported;
      2. start the rest server and record its pid;
      3. prepare the search space (annotation, searchSpacePath, or empty);
      4. wait for the rest server, then set the platform config (skipped in
         'view' mode);
      5. start the experiment and record its id, web UI urls and metadata.

    args: parsed command line arguments; `port` and `debug` are read.
    experiment_config: parsed experiment configuration; it is updated in
        place ('searchSpace', 'debug', and 'trial.codeDir' with annotation).
    mode: start mode handed to the rest server (callers in this file pass
        'new', 'resume' or 'view').
    config_file_name: name under which config and logs of this run are kept.
    experiment_id: id of the experiment when an existing one is reopened;
        when None the id returned by the rest server is used.

    Exits with status 1 when any step fails; the rest server is killed
    first if it has already been started.
    '''
    nni_config = Config(config_file_name)
    # check packages for tuner
    package_name, module_name = None, None
    if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
        package_name = experiment_config['tuner']['builtinTunerName']
        module_name = ModuleName.get(package_name)
    elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
        package_name = experiment_config['advisor']['builtinAdvisorName']
        module_name = AdvisorModuleName.get(package_name)
    if package_name and module_name:
        try:
            # import in a child interpreter so that its output lands in the log files
            stdout_full_path, stderr_full_path = get_log_path(config_file_name)
            with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
                check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file)
        except CalledProcessError:
            print_error('some errors happen when import package %s.' %(package_name))
            print_log_content(config_file_name)
            if package_name in PACKAGE_REQUIREMENTS:
                print_error('If %s is not installed, it should be installed through \'nnictl package install --name %s\''%(package_name, package_name))
            sys.exit(1)

    log_dir = experiment_config.get('logDir') or None
    log_level = experiment_config.get('logLevel') or None
    #view experiment mode do not need debug function, when view an experiment, there will be no new logs created
    if mode != 'view':
        if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
            log_level = 'debug'

    # start rest server
    rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], mode, config_file_name, experiment_id, log_dir, log_level)
    nni_config.set_config('restServerPid', rest_process.pid)

    # Deal with annotation
    if experiment_config.get('useAnnotation'):
        path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
        os.makedirs(path, exist_ok=True)
        path = tempfile.mkdtemp(dir=path)
        nas_mode = experiment_config['trial'].get('nasMode', 'classic_mode')
        code_dir = expand_annotations(experiment_config['trial']['codeDir'], path, nas_mode=nas_mode)
        experiment_config['trial']['codeDir'] = code_dir
        search_space = generate_search_space(code_dir)
        experiment_config['searchSpace'] = json.dumps(search_space)
        assert search_space, ERROR_INFO % 'Generated search space is empty'
    elif experiment_config.get('searchSpacePath'):
        search_space = get_json_content(experiment_config.get('searchSpacePath'))
        experiment_config['searchSpace'] = json.dumps(search_space)
    else:
        experiment_config['searchSpace'] = json.dumps('')

    # check rest server
    running, _ = check_rest_server(args.port)
    if running:
        print_normal('Successfully started Restful server!')
    else:
        print_error('Restful server start failed!')
        print_log_content(config_file_name)
        try:
            kill_command(rest_process.pid)
        except Exception:
            raise Exception(ERROR_INFO % 'Rest server stopped!')
        sys.exit(1)

    if mode != 'view':
        # set platform configuration
        set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port, config_file_name, rest_process)

    # start a new experiment
    print_normal('Starting experiment...')
    # set debug configuration
    if mode != 'view' and experiment_config.get('debug') is None:
        experiment_config['debug'] = args.debug
    response = set_experiment(experiment_config, mode, args.port, config_file_name)
    if response:
        if experiment_id is None:
            experiment_id = json.loads(response.text).get('experiment_id')
        nni_config.set_config('experimentId', experiment_id)
    else:
        print_error('Start experiment failed!')
        print_log_content(config_file_name)
        try:
            kill_command(rest_process.pid)
        except Exception:
            raise Exception(ERROR_INFO % 'Restful server stopped!')
        sys.exit(1)

    if experiment_config.get('nniManagerIp'):
        web_ui_url_list = ['{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
    else:
        web_ui_url_list = get_local_urls(args.port)
    nni_config.set_config('webuiUrl', web_ui_url_list)

    #save experiment information
    nnictl_experiment_config = Experiments()
    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name, experiment_config['trainingServicePlatform'])

    print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, '   '.join(web_ui_url_list)))
Deshui Yu's avatar
Deshui Yu committed
484

SparkSnail's avatar
SparkSnail committed
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
def create_experiment(args):
    '''Start a new experiment.

    Picks a random 8-character config name, loads and validates the yaml
    file given by args.config, stores it as 'experimentConfig', launches
    the experiment in 'new' mode and finally records args.port as
    'restServerPort'.

    Exits with status 1 when the config path does not exist.
    '''
    config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
    nni_config = Config(config_file_name)
    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        print_error('Please set correct config path!')
        # sys.exit instead of the `exit` helper, which only exists when the site module is loaded
        sys.exit(1)
    experiment_config = get_yml_content(config_path)
    validate_all_content(experiment_config, config_path)

    nni_config.set_config('experimentConfig', experiment_config)
    launch_experiment(args, experiment_config, 'new', config_file_name)
    nni_config.set_config('restServerPort', args.port)

def manage_stopped_experiment(args, mode):
    '''Reopen a stopped experiment in the given mode.

    args: parsed command line arguments; `id` selects the experiment and
        `port` is recorded as the new 'restServerPort'.
    mode: start mode used for the relaunch and in the messages (the callers
        in this file pass 'view' or 'resume').

    The stored experiment config is copied under a new random config name
    and the experiment is launched again with its original experiment id.
    Exits with status 1 when no id is given, the id is unknown, or the
    experiment is not in 'STOPPED' status.
    '''
    update_experiment()
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()

    if not args.id:
        # '{{id}}' prints a literal '{id}'; an unescaped '{id}' makes
        # str.format look for a keyword argument and raise KeyError.
        print_error('Please set experiment id! \nYou could use \'nnictl {0} {{id}}\' to {0} a stopped experiment!\n'
                    'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
        sys.exit(1)
    if experiment_dict.get(args.id) is None:
        print_error('Id %s not exist!' % args.id)
        sys.exit(1)
    if experiment_dict[args.id]['status'] != 'STOPPED':
        print_error('Only stopped experiments can be {0}ed!'.format(mode))
        sys.exit(1)

    experiment_id = args.id
    print_normal('{0} experiment {1}...'.format(mode, experiment_id))
    nni_config = Config(experiment_dict[experiment_id]['fileName'])
    experiment_config = nni_config.get_config('experimentConfig')
    experiment_id = nni_config.get_config('experimentId')

    # the relaunch gets its own config name, seeded with the stored experiment config
    new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
    new_nni_config = Config(new_config_file_name)
    new_nni_config.set_config('experimentConfig', experiment_config)
    launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id)
    new_nni_config.set_config('restServerPort', args.port)
Deshui Yu's avatar
Deshui Yu committed
529

SparkSnail's avatar
SparkSnail committed
530
531
532
def view_experiment(args):
    '''Open a stopped experiment in 'view' mode.'''
    manage_stopped_experiment(args, mode='view')
Deshui Yu's avatar
Deshui Yu committed
533

SparkSnail's avatar
SparkSnail committed
534
535
536
def resume_experiment(args):
    '''Relaunch a stopped experiment in 'resume' mode.'''
    manage_stopped_experiment(args, mode='resume')