nnictl_utils.py 44.4 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
Deshui Yu's avatar
Deshui Yu committed
3

Yan Ni's avatar
Yan Ni committed
4
import csv
Deshui Yu's avatar
Deshui Yu committed
5
import os
6
import sys
Deshui Yu's avatar
Deshui Yu committed
7
import json
8
import time
SparkSnail's avatar
SparkSnail committed
9
10
import re
import shutil
Yuge Zhang's avatar
Yuge Zhang committed
11
import subprocess
12
from functools import cmp_to_key
chicm-ms's avatar
chicm-ms committed
13
from datetime import datetime, timezone
14
from subprocess import Popen
chicm-ms's avatar
chicm-ms committed
15
from pyhdfs import HdfsClient
chicm-ms's avatar
chicm-ms committed
16
from nni.package_utils import get_nni_installation_path
17
from nni_annotation import expand_annotations
18
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
19
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url, metric_data_url
20
from .config_utils import Config, Experiments
21
from .constants import NNICTL_HOME_DIR, NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
22
     EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
23
from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content, generate_temp_dir
24
from .command_utils import check_output_command, kill_command
SparkSnail's avatar
SparkSnail committed
25
from .ssh_utils import create_ssh_sftp_client, remove_remote_directory
Deshui Yu's avatar
Deshui Yu committed
26

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def get_experiment_time(port):
    '''Query the REST server on ``port`` for the experiment's start and end time.

    Returns a ``(startTime, endTime)`` tuple taken from the experiment info
    after convert_time_stamp_to_date, or ``(None, None)`` when the request
    fails or the response does not pass check_response.
    '''
    response = rest_get(experiment_url(port), REST_TIME_OUT)
    if not response or not check_response(response):
        return None, None
    experiment_info = convert_time_stamp_to_date(json.loads(response.text))
    return experiment_info.get('startTime'), experiment_info.get('endTime')

def get_experiment_status(port):
    '''Return the 'status' field reported by the REST server on ``port``.

    Returns None when check_rest_server_quick reports the server is not reachable.
    '''
    is_running, response = check_rest_server_quick(port)
    if not is_running:
        return None
    return json.loads(response.text).get('status')

def update_experiment():
    '''Refresh the stored status/startTime/endTime of every non-stopped experiment.

    Experiments whose REST server process is no longer alive are marked
    'STOPPED'. The others are refreshed from their REST server; a field is
    only written when the server returned a truthy value for it.
    Always returns None.
    '''
    experiments = Experiments()
    all_experiments = experiments.get_all_experiments()
    if not all_experiments:
        return None
    for exp_id in all_experiments:
        record = all_experiments[exp_id]
        # Entries that are not dicts (old config format) and stopped experiments are left alone.
        if not isinstance(record, dict) or record.get('status') == 'STOPPED':
            continue
        nni_config = Config(record['fileName'])
        if not detect_process(nni_config.get_config('restServerPid')):
            experiments.update_experiment(exp_id, 'status', 'STOPPED')
            continue
        rest_port = nni_config.get_config('restServerPort')
        start_time, end_time = get_experiment_time(rest_port)
        for field, value in (('startTime', start_time), ('endTime', end_time)):
            if value:
                experiments.update_experiment(exp_id, field, value)
        status = get_experiment_status(rest_port)
        if status:
            experiments.update_experiment(exp_id, 'status', status)
    return None
SparkSnail's avatar
SparkSnail committed
65

SparkSnail's avatar
SparkSnail committed
66
def check_experiment_id(args, update=True):
    '''Resolve the experiment id a command should act on.

    Args:
        args: parsed command-line arguments; ``args.id`` may be empty.
        update: when True, refresh experiment metadata via update_experiment() first.

    Returns:
        ``args.id`` when it names a known experiment; the id of the single
        non-stopped experiment when ``args.id`` is not set; otherwise None
        (after printing an error).

    Exits with status 1 when no id is given and several experiments are running.
    '''
    if update:
        update_experiment()
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    if not experiment_dict:
        print_normal('There is no experiment running...')
        return None
    if not args.id:
        running_experiment_list = []
        # Iterate over a snapshot of the keys: remove_experiment() may modify
        # the same dict, which would break iteration over a live keys() view.
        for key in list(experiment_dict.keys()):
            if isinstance(experiment_dict[key], dict):
                if experiment_dict[key].get('status') != 'STOPPED':
                    running_experiment_list.append(key)
            elif isinstance(experiment_dict[key], list):
                # if the config file is old version, remove the configuration from file
                experiment_config.remove_experiment(key)
        if len(running_experiment_list) > 1:
            print_error('There are multiple experiments, please set the experiment id...')
            experiment_information = ''.join(
                EXPERIMENT_DETAIL_FORMAT % (key,
                                            experiment_dict[key].get('experimentName', 'N/A'),
                                            experiment_dict[key]['status'],
                                            experiment_dict[key]['port'],
                                            experiment_dict[key].get('platform'),
                                            experiment_dict[key]['startTime'],
                                            experiment_dict[key]['endTime'])
                for key in running_experiment_list)
            print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
            sys.exit(1)
        if not running_experiment_list:
            print_error('There is no experiment running.')
            return None
        return running_experiment_list[0]
    if experiment_dict.get(args.id):
        return args.id
    print_error('Id not correct.')
    return None
Deshui Yu's avatar
Deshui Yu committed
108

109
def parse_ids(args):
    '''Resolve the experiment ids targeted by ``nnictl stop``.

    Only experiments whose status is not 'STOPPED' are considered. Rules:
    1. ``args.all``: return every running experiment id.
    2. ``args.port`` set: return the running experiment(s) on that port; exit
       with status 1 if ``args.id`` is also set and does not match.
    3. no id: return the single running experiment; exit with status 1 (after
       listing them) when several are running.
    4. id ending with '*': return every running id starting with the text before '*'.
    5. id equal to a running experiment id: return it.
    6. otherwise treat the id as a prefix: return the unique match, or print
       an ambiguity error and return None when several ids match.

    Returns:
        A (possibly empty) list of experiment ids, or None when there is no
        experiment metadata at all or the prefix is ambiguous.
    '''
    update_experiment()
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    if not experiment_dict:
        print_normal('Experiment is not running...')
        return None
    result_list = []
    running_experiment_list = []
    # Iterate over a snapshot of the keys: remove_experiment() may modify the same dict.
    for key in list(experiment_dict.keys()):
        if isinstance(experiment_dict[key], dict):
            if experiment_dict[key].get('status') != 'STOPPED':
                running_experiment_list.append(key)
        elif isinstance(experiment_dict[key], list):
            # if the config file is old version, remove the configuration from file
            experiment_config.remove_experiment(key)
    if args.all:
        return running_experiment_list
    if args.port is not None:
        # Compare as strings so the match works whether args.port is parsed as str or int.
        port = str(args.port)
        result_list = [key for key in running_experiment_list
                       if str(experiment_dict[key]['port']) == port]
        if args.id and result_list and args.id != result_list[0]:
            print_error('Experiment id and restful server port not match')
            sys.exit(1)
    elif not args.id:
        if len(running_experiment_list) > 1:
            print_error('There are multiple experiments, please set the experiment id...')
            experiment_information = ''.join(
                EXPERIMENT_DETAIL_FORMAT % (key,
                                            experiment_dict[key].get('experimentName', 'N/A'),
                                            experiment_dict[key]['status'],
                                            experiment_dict[key]['port'],
                                            experiment_dict[key].get('platform'),
                                            experiment_dict[key]['startTime'],
                                            experiment_dict[key]['endTime'])
                for key in running_experiment_list)
            print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
            sys.exit(1)
        result_list = running_experiment_list
    elif args.id.endswith('*'):
        prefix = args.id[:-1]
        result_list = [exp_id for exp_id in running_experiment_list if exp_id.startswith(prefix)]
    elif args.id in running_experiment_list:
        result_list.append(args.id)
    else:
        result_list = [exp_id for exp_id in running_experiment_list if exp_id.startswith(args.id)]
        if len(result_list) > 1:
            print_error(args.id + ' is ambiguous, please choose ' + ' '.join(result_list))
            return None
    if not result_list and (args.id or args.port):
        print_error('There are no experiments matched, please set correct experiment id or restful server port')
    elif not result_list:
        print_error('There is no experiment running...')
    return result_list

179
180
181
def get_config_filename(args):
    '''Return the 'fileName' recorded for the experiment selected by ``args``.

    Exits with status 1 when check_experiment_id cannot resolve an id.
    '''
    experiment_id = check_experiment_id(args)
    if experiment_id is None:
        print_error('Please set correct experiment id.')
        exit(1)
    all_experiments = Experiments().get_all_experiments()
    return all_experiments[experiment_id]['fileName']

def get_experiment_port(args):
    '''Return the 'port' recorded for the experiment selected by ``args``.

    Exits with status 1 when check_experiment_id cannot resolve an id.
    '''
    experiment_id = check_experiment_id(args)
    if experiment_id is None:
        print_error('Please set correct experiment id.')
        exit(1)
    experiment_record = Experiments().get_all_experiments()[experiment_id]
    return experiment_record['port']

def convert_time_stamp_to_date(content):
    '''Replace epoch-millisecond timestamps in ``content`` with local date strings.

    The 'startTime' and 'endTime' entries, when present and truthy, are
    divided by 1000 (sub-second precision is dropped), converted to the local
    time zone and formatted as "%Y/%m/%d %H:%M:%S". ``content`` is modified
    in place and also returned.
    '''
    for time_key in ('startTime', 'endTime'):
        milliseconds = content.get(time_key)
        if not milliseconds:
            continue
        local_time = datetime.fromtimestamp(milliseconds // 1000, timezone.utc).astimezone()
        content[time_key] = local_time.strftime("%Y/%m/%d %H:%M:%S")
    return content

def check_rest(args):
    '''Print whether the selected experiment's REST server is reachable and return that flag.'''
    rest_port = Config(get_config_filename(args)).get_config('restServerPort')
    is_running = check_rest_server_quick(rest_port)[0]
    if is_running:
        message = 'Restful server is running...'
    else:
        message = 'Restful server is not running...'
    print_normal(message)
    return is_running
221

Deshui Yu's avatar
Deshui Yu committed
222
223
def stop_experiment(args):
    '''Stop the experiment(s) selected by ``args`` (resolved through parse_ids).

    For each experiment: kill its REST server process and any recorded
    tensorboard processes (tensorboard kill failures are printed, not raised),
    then mark it 'STOPPED' with the current local time as 'endTime'.
    Exits with status 1 for the removed ``nnictl stop all`` syntax.
    '''
    if args.id == 'all':
        print_warning('\'nnictl stop all\' is abolished, please use \'nnictl stop --all\' to stop all of experiments!')
        exit(1)
    experiment_id_list = parse_ids(args)
    if not experiment_id_list:
        return
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    for experiment_id in experiment_id_list:
        print_normal('Stopping experiment %s' % experiment_id)
        nni_config = Config(experiment_dict[experiment_id]['fileName'])
        rest_pid = nni_config.get_config('restServerPid')
        if rest_pid:
            kill_command(rest_pid)
            tensorboard_pids = nni_config.get_config('tensorboardPidList')
            if tensorboard_pids:
                for pid in tensorboard_pids:
                    # Best effort: one failing kill must not prevent stopping the rest.
                    try:
                        kill_command(pid)
                    except Exception as exception:
                        print_error(exception)
                nni_config.set_config('tensorboardPidList', [])
        print_normal('Stop experiment success.')
        experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
        stopped_at = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        experiment_config.update_experiment(experiment_id, 'endTime', str(stopped_at))
Deshui Yu's avatar
Deshui Yu committed
249
250
251

def trial_ls(args):
    '''Print and return the trial jobs of the selected experiment.

    When ``args.head`` (or ``args.tail``) is set, only trials that report a
    final metric are kept. They are sorted by that metric and the first
    ``head`` in descending order (or ``tail`` in ascending order) are returned.

    Returns:
        The list of trial dicts with timestamps converted by
        convert_time_stamp_to_date, or None on any error (which is printed).
    '''
    def final_metric_value(trial):
        '''Return the numeric final metric of ``trial`` (the 'default' entry for dict metrics).'''
        # The stored data is decoded twice: the REST payload holds a JSON string of JSON.
        metric = json.loads(json.loads(trial['finalMetricData'][0]['data']))
        if isinstance(metric, dict):
            metric = metric['default']
        # Accept ints as well as floats; bool is an int subclass but not a metric.
        if isinstance(metric, (int, float)) and not isinstance(metric, bool):
            return metric
        print_error('Unexpected data format. Please check your data.')
        raise ValueError

    if args.head and args.tail:
        print_error('Head and tail cannot be set at the same time.')
        return None
    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    rest_pid = nni_config.get_config('restServerPid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return None
    running, _ = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running...')
        return None
    response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
    if not response or not check_response(response):
        print_error('List trial failed...')
        return None
    content = json.loads(response.text)
    if args.head or args.tail:
        count = args.head or args.tail
        # Validate explicitly instead of with assert, which is stripped under `python -O`.
        if count <= 0:
            print_error('The number of requested data must be greater than 0.')
            return None
        # Trials without any reported final metric cannot be ranked.
        ranked = [trial for trial in content if trial.get('finalMetricData')]
        content = sorted(ranked, key=final_metric_value, reverse=bool(args.head))[:count]
    for index, value in enumerate(content):
        content[index] = convert_time_stamp_to_date(value)
    print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
    return content
Deshui Yu's avatar
Deshui Yu committed
294
295
296

def trial_kill(args):
    '''Ask the selected experiment's REST server to kill trial ``args.trial_id``.

    Returns True when the delete request succeeded and False when it failed
    or the REST server is unreachable. Returns None when the REST server
    process is not alive.
    '''
    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    if not detect_process(nni_config.get_config('restServerPid')):
        print_error('Experiment is not running...')
        return
    is_running, _ = check_rest_server_quick(rest_port)
    if not is_running:
        print_error('Restful server is not running...')
        return False
    response = rest_delete(trial_job_id_url(rest_port, args.trial_id), REST_TIME_OUT)
    if response and check_response(response):
        print(response.text)
        return True
    print_error('Kill trial job failed...')
    return False
Deshui Yu's avatar
Deshui Yu committed
314

315
316
317
318
319
320
321
322
323
324
325
def trial_codegen(args):
    '''Expand the annotations of the experiment's trial code for one trial.

    The expanded code is written to ``./exp_<experiment id>_trial_<trial id>_code``
    via expand_annotations. Exits with status 1 when the experiment config
    does not enable 'useAnnotation'.
    '''
    print_warning('Currently, this command is only for nni nas programming interface.')
    exp_id = check_experiment_id(args)
    nni_config = Config(get_config_filename(args))
    experiment_config = nni_config.get_config('experimentConfig')
    # Treat a missing 'useAnnotation' key as "not using annotation" instead of raising KeyError.
    if not experiment_config.get('useAnnotation'):
        print_error('The experiment is not using annotation')
        sys.exit(1)
    code_dir = experiment_config['trial']['codeDir']
    expand_annotations(code_dir, './exp_%s_trial_%s_code' % (exp_id, args.trial_id), exp_id, args.trial_id)

Deshui Yu's avatar
Deshui Yu committed
326
327
def list_experiment(args):
    '''Print and return the selected experiment's information from its REST server.

    Returns the experiment dict (timestamps converted by
    convert_time_stamp_to_date), or None when the experiment or its REST
    server is not running or the request fails.
    '''
    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    rest_pid = nni_config.get_config('restServerPid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return None
    if not check_rest_server_quick(rest_port)[0]:
        print_error('Restful server is not running...')
        return None
    response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
    if not (response and check_response(response)):
        print_error('List experiment failed...')
        return None
    content = convert_time_stamp_to_date(json.loads(response.text))
    print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
    return content
Deshui Yu's avatar
Deshui Yu committed
346

347
348
def experiment_status(args):
    '''Print the selected experiment's status JSON; return whether the REST server answered.'''
    rest_port = Config(get_config_filename(args)).get_config('restServerPort')
    is_alive, response = check_rest_server_quick(rest_port)
    if is_alive:
        status = json.loads(response.text)
        print(json.dumps(status, indent=4, sort_keys=True, separators=(',', ':')))
    else:
        print_normal('Restful server is not running...')
    return is_alive
357

Deshui Yu's avatar
Deshui Yu committed
358
359
def log_internal(args, filetype):
    '''Print the selected experiment's stdout or stderr log file.

    ``filetype == 'stdout'`` selects the stdout file; any other value selects
    stderr. ``args.head`` / ``args.tail`` are forwarded to check_output_command.
    '''
    file_name = get_config_filename(args)
    stream_name = 'stdout' if filetype == 'stdout' else 'stderr'
    log_path = os.path.join(NNICTL_HOME_DIR, file_name, stream_name)
    print(check_output_command(log_path, head=args.head, tail=args.tail))
366

Deshui Yu's avatar
Deshui Yu committed
367
368
369
370
371
372
373
374
def log_stdout(args):
    '''Print the stdout log of the selected experiment.'''
    return log_internal(args, 'stdout')

def log_stderr(args):
    '''Print the stderr log of the selected experiment.'''
    return log_internal(args, 'stderr')

375
376
377
def log_trial(args):
    '''Print the log path of trial ``args.trial_id``, or of every trial when it is unset.

    Returns None (after printing an error) when the REST server process is not
    alive. Exits with status 1 when the REST server is unreachable, the trial
    list cannot be fetched, the trial id is unknown, or the requested trial
    has no log path yet.
    '''
    trial_id_path_dict = {}
    trial_id_list = []
    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    rest_pid = nni_config.get_config('restServerPid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return
    running, _ = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running...')
        sys.exit(1)
    response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
    if not response or not check_response(response):
        # Without the trial list, the id lookup below would report a misleading error.
        print_error('List trial failed...')
        sys.exit(1)
    for trial in json.loads(response.text):
        trial_id = trial.get('id')
        trial_id_list.append(trial_id)
        if trial.get('logPath'):
            trial_id_path_dict[trial_id] = trial['logPath']
    if args.trial_id:
        if args.trial_id not in trial_id_list:
            print_error('Trial id {0} not correct, please check your command!'.format(args.trial_id))
            sys.exit(1)
        if trial_id_path_dict.get(args.trial_id):
            print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
        else:
            print_error('Log path is not available yet, please wait...')
            sys.exit(1)
    else:
        print_normal('All of trial log info:')
        for trial_id, log_path in trial_id_path_dict.items():
            print_normal('id:' + trial_id + ' path:' + log_path)
        if not trial_id_path_dict:
            print_normal('None')
412

Deshui Yu's avatar
Deshui Yu committed
413
414
def get_config(args):
    '''Print every entry of the selected experiment's nnictl config file.'''
    config_file = get_config_filename(args)
    all_config = Config(config_file).get_all_config()
    print(all_config)
417
418
419

def webui_url(args):
    '''Print the web UI URL(s) stored under 'webuiUrl' for the selected experiment.'''
    urls = Config(get_config_filename(args)).get_config('webuiUrl')
    print_normal('{0} {1}'.format('Web UI url:', ' '.join(urls)))

Yuge Zhang's avatar
Yuge Zhang committed
423
424
425
426
def webui_nas(args):
    '''Launch the NAS UI node server (nasui/server.js) on ``args.port`` for ``args.logdir``.

    Blocks in subprocess.run until the server exits; Ctrl-C is swallowed.
    '''
    print_normal('Starting NAS UI...')
    try:
        install_dir = get_nni_installation_path()
        server_script = os.path.join(install_dir, 'nasui', 'server.js')
        if sys.platform == 'win32':
            # NOTE(review): drops the last three characters of the installation
            # path to locate the bundled node.exe — confirm the expected layout.
            node_executable = os.path.join(install_dir[:-3], 'Scripts', 'node.exe')
        else:
            node_executable = 'node'
        command = [node_executable, '--max-old-space-size=4096', server_script,
                   '--port', str(args.port), '--logdir', args.logdir]
        subprocess.run(command)
    except KeyboardInterrupt:
        pass

SparkSnail's avatar
SparkSnail committed
437
438
439
440
441
def local_clean(directory):
    '''Recursively delete ``directory``; print an error instead of raising if it is missing.'''
    removing_msg = 'removing folder {0}'.format(directory)
    print_normal(removing_msg)
    try:
        shutil.rmtree(directory)
    except FileNotFoundError:
        missing_msg = '{0} does not exist.'.format(directory)
        print_error(missing_msg)
chicm-ms's avatar
chicm-ms committed
444

SparkSnail's avatar
SparkSnail committed
445
446
447
448
449
450
451
def remote_clean(machine_list, experiment_id=None):
    '''Remove NNI experiment data under /tmp/nni-experiments on every remote machine.

    With ``experiment_id`` only ``/tmp/nni-experiments/<experiment_id>`` is
    removed; otherwise the whole ``/tmp/nni-experiments`` directory is removed.
    Each machine dict supplies 'ip', 'port', 'username', 'passwd',
    'sshKeyPath' and 'passphrase' for create_ssh_sftp_client.
    '''
    for machine in machine_list:
        path_parts = ['tmp', 'nni-experiments']
        if experiment_id:
            path_parts.append(experiment_id)
        remote_dir = '/' + '/'.join(path_parts)
        host, port = machine.get('ip'), machine.get('port')
        sftp = create_ssh_sftp_client(host, port, machine.get('username'), machine.get('passwd'),
                                      machine.get('sshKeyPath'), machine.get('passphrase'))
        print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir))
        remove_remote_directory(sftp, remote_dir)
chicm-ms's avatar
chicm-ms committed
461

SparkSnail's avatar
SparkSnail committed
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
def hdfs_clean(host, user_name, output_dir, experiment_id=None):
    '''Remove NNI experiment data from the HDFS service on ``host`` (WebHDFS, port 80).

    Always removes ``/<user_name>/nni/experiments[/<experiment_id>]``. When
    ``output_dir`` is an ``hdfs://<ipv4>[:port]/<path>`` URI on the same host,
    also removes ``<path>[/<experiment_id>]``. A URI on a different host, or
    one that names no path below the root, is skipped with a warning.
    '''
    hdfs_client = HdfsClient(hosts='{0}:80'.format(host), user_name=user_name, webhdfs_path='/webhdfs/api/v1', timeout=5)
    if experiment_id:
        full_path = '/' + '/'.join([user_name, 'nni', 'experiments', experiment_id])
    else:
        full_path = '/' + '/'.join([user_name, 'nni', 'experiments'])
    print_normal('removing folder {0} in hdfs'.format(full_path))
    hdfs_client.delete(full_path, recursive=True)
    if not output_dir:
        return
    # Raw string with escaped dots so the host group only matches dotted IPv4 addresses.
    pattern = re.compile(r'hdfs://(?P<host>([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?')
    match_result = pattern.match(output_dir)
    if not match_result:
        return
    output_host = match_result.group('host')
    base_dir = match_result.group('baseDir')
    #check if the host is valid
    if output_host != host:
        # Report the full URI (not just its path component) so the mismatch is visible.
        print_warning('The host in {0} is not consistent with {1}'.format(output_dir, host))
        return
    if not base_dir or not base_dir.strip('/'):
        # No path (or only '/') in the URI: never delete None or the HDFS root.
        print_warning('No folder specified in {0}, skip removing it.'.format(output_dir))
        return
    if experiment_id:
        base_dir = base_dir + '/' + experiment_id
    print_normal('removing folder {0} in hdfs'.format(base_dir))
    hdfs_client.delete(base_dir, recursive=True)

def experiment_clean(args):
    '''Delete the data and metadata of experiment ``args.id``, or of all experiments with ``args.all``.

    After a y/N confirmation, for each experiment:
    - remove its remote-machine data ('remote') or HDFS data ('pai');
    - remove the local checkpoint/db/log/trials folders, and the experiment
      folder itself if it is then empty;
    - drop its metadata entry.
    Experiments on other non-local platforms are skipped with a warning so the
    remaining ones are still cleaned. Exits with status 1 on a missing or
    unknown id, and with status 0 if the user declines.
    '''
    experiment_id_list = []
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments() or {}
    if args.all:
        experiment_id_list = list(experiment_dict.keys())
        if not experiment_id_list:
            print_normal('There is no experiment to clean.')
            return
    else:
        if args.id is None:
            print_error('please set experiment id.')
            sys.exit(1)
        if args.id not in experiment_dict:
            print_error('Cannot find experiment {0}.'.format(args.id))
            sys.exit(1)
        experiment_id_list.append(args.id)
    while True:
        print('INFO: This action will delete experiment {0}, and it\'s not recoverable.'.format(' '.join(experiment_id_list)))
        inputs = input('INFO: do you want to continue?[y/N]:')
        if not inputs.lower() or inputs.lower() in ['n', 'no']:
            sys.exit(0)
        elif inputs.lower() not in ['y', 'n', 'yes', 'no']:
            print_warning('please input Y or N.')
        else:
            break
    for experiment_key in experiment_id_list:
        nni_config = Config(experiment_dict[experiment_key]['fileName'])
        experiment_cfg = nni_config.get_config('experimentConfig')
        platform = experiment_cfg.get('trainingServicePlatform')
        # Data paths use the id stored in the experiment's own config; fall back to the metadata key.
        experiment_id = nni_config.get_config('experimentId') or experiment_key
        if platform == 'remote':
            remote_clean(experiment_cfg.get('machineList'), experiment_id)
        elif platform == 'pai':
            pai_config = experiment_cfg.get('paiConfig')
            host = pai_config.get('host')
            user_name = pai_config.get('userName')
            output_dir = experiment_cfg.get('trial').get('outputDir')
            hdfs_clean(host, user_name, output_dir, experiment_id)
        elif platform != 'local':
            #TODO: support all platforms
            print_warning('platform {0} clean up not supported yet.'.format(platform))
            # Skip only this experiment instead of aborting the whole run (matters with --all).
            continue
        #clean local data
        local_base_dir = experiment_cfg.get('logDir') or NNI_HOME_DIR
        local_experiment_dir = os.path.join(local_base_dir, experiment_id)
        for folder_name in ['checkpoint', 'db', 'log', 'trials']:
            local_clean(os.path.join(local_experiment_dir, folder_name))
        # The experiment folder may not exist at all; os.listdir would raise FileNotFoundError.
        if os.path.isdir(local_experiment_dir) and not os.listdir(local_experiment_dir):
            local_clean(local_experiment_dir)
        experiment_config = Experiments()
        print_normal('removing metadata of experiment {0}'.format(experiment_id))
        experiment_config.remove_experiment(experiment_id)
        print_normal('Done.')
SparkSnail's avatar
SparkSnail committed
540
541
542
543
544
545
546
547
548
549

def get_platform_dir(config_content):
    '''List, for display, the NNI cache locations used by an experiment config's platform.

    Args:
        config_content: parsed experiment config dict.

    Returns:
        For 'remote': one "<ip>:<port>/tmp/nni" entry per entry of 'machineList'.
        For 'pai': "server: <host>, path: <userName>/nni" when 'paiConfig' is
        present, plus 'trial.outputDir' when set.
        Otherwise an empty list. Missing 'machineList', 'paiConfig' or 'trial'
        sections contribute nothing instead of raising.
    '''
    platform = config_content.get('trainingServicePlatform')
    dir_list = []
    if platform == 'remote':
        for machine in config_content.get('machineList') or []:
            host = machine.get('ip')
            port = machine.get('port')
            dir_list.append(host + ':' + str(port) + '/tmp/nni')
    elif platform == 'pai':
        pai_config = config_content.get('paiConfig')
        if pai_config:
            dir_list.append('server: {0}, path: {1}/nni'.format(pai_config.get('host'), pai_config.get('userName')))
        output_dir = (config_content.get('trial') or {}).get('outputDir')
        if output_dir:
            dir_list.append(output_dir)
    return dir_list

def platform_clean(args):
    '''clean up the experiment data

    Remove the NNI cache folders left on remote machines (via ``remote_clean``)
    or on PAI's HDFS (via ``hdfs_clean``), after listing them and asking the
    user for confirmation.

    Args:
        args: parsed CLI arguments; ``args.config`` is the experiment YAML path.

    Exits with status 0 for the local platform, unsupported platforms, or when
    the user declines; with status 1 for a bad config path or when there is
    nothing to clean.
    '''
    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        print_error('Please set correct config path.')
        exit(1)
    config_content = get_yml_content(config_path)
    platform = config_content.get('trainingServicePlatform')
    if platform == 'local':
        print_normal('it doesn’t need to clean local platform.')
        exit(0)
    if platform not in ['remote', 'pai']:
        print_normal('platform {0} not supported.'.format(platform))
        exit(0)
    update_experiment()
    dir_list = get_platform_dir(config_content)
    if not dir_list:
        print_normal('No folder of NNI caches is found.')
        exit(1)
    while True:
        print_normal('This command will remove below folders of NNI caches. If other users are using experiments' \
                     ' on below hosts, it will be broken.')
        for value in dir_list:
            print('       ' + value)
        answer = input('INFO: do you want to continue?[y/N]:').strip().lower()
        # empty answer means the default, "N"
        if answer in ('', 'n', 'no'):
            exit(0)
        if answer in ('y', 'yes'):
            break
        print_warning('please input Y or N.')
    if platform == 'remote':
        remote_clean(config_content.get('machineList'), None)
    else:
        # platform is 'pai' here; `or {}` tolerates missing sections instead of raising AttributeError
        pai_config = config_content.get('paiConfig') or {}
        trial_config = config_content.get('trial') or {}
        hdfs_clean(pai_config.get('host'), pai_config.get('userName'), trial_config.get('outputDir'), None)
    print_normal('Done.')
SparkSnail's avatar
SparkSnail committed
600

601
602
def experiment_list(args):
    '''Print a table of the experiments nnictl knows about and return their ids.

    Only experiments whose status is not ``STOPPED`` are listed unless
    ``args.all`` is set, in which case every experiment is listed. Exits with
    status 1 when no experiment is recorded at all.
    '''
    update_experiment()
    experiments = Experiments().get_all_experiments()
    if not experiments:
        print_normal('Cannot find experiments.')
        exit(1)

    if args.all:
        selected_ids = list(experiments.keys())
    else:
        selected_ids = [exp_id for exp_id in experiments.keys()
                        if experiments[exp_id]['status'] != 'STOPPED']
        if not selected_ids:
            print_warning('There is no experiment running...\nYou can use \'nnictl experiment list --all\' to list all experiments.')

    rows = []
    for exp_id in selected_ids:
        info = experiments[exp_id]
        rows.append(EXPERIMENT_DETAIL_FORMAT % (exp_id,
                                                info.get('experimentName', 'N/A'),
                                                info['status'],
                                                info['port'],
                                                info.get('platform'),
                                                info['startTime'],
                                                info['endTime']))
    print(EXPERIMENT_INFORMATION_FORMAT % ''.join(rows))
    return selected_ids
SparkSnail's avatar
SparkSnail committed
630

SparkSnail's avatar
SparkSnail committed
631
632
633
634
def get_time_interval(time1, time2):
    '''get the interval of two times

    Args:
        time1: start time, formatted ``'%Y/%m/%d %H:%M:%S'`` (interpreted as
            local time by ``time.mktime``).
        time2: end time, same format.

    Returns:
        The elapsed time as ``'<d>d <h>h <m>m <s>s'``; a negative interval is
        reported as zero. Returns ``'N/A'`` when either value cannot be parsed.
    '''
    time_format = '%Y/%m/%d %H:%M:%S'
    try:
        start = time.mktime(time.strptime(time1, time_format))
        end = time.mktime(time.strptime(time2, time_format))
    except (TypeError, ValueError, OverflowError):
        return 'N/A'
    # total elapsed seconds; timedelta.seconds (used before) drops the whole-day part
    seconds = max(0, int(end - start))
    days, seconds = divmod(seconds, 86400)
    hours, seconds = divmod(seconds, 3600)
    minutes, seconds = divmod(seconds, 60)
    return '%dd %dh %dm %ds' % (days, hours, minutes, seconds)

def show_experiment_info():
    '''Print the monitor view for every experiment that is not STOPPED.

    Each such experiment gets a summary line (status, port, platform, start
    time and elapsed time) followed by a table of its trial jobs. The trial
    rows are only fetched when the experiment's REST server responds. Exits
    with status 1 when nnictl has no experiments recorded.
    '''
    update_experiment()
    experiments = Experiments().get_all_experiments()
    if not experiments:
        print('There is no experiment running...')
        exit(1)

    active_ids = [exp_id for exp_id in experiments.keys()
                  if experiments[exp_id]['status'] != 'STOPPED']
    if not active_ids:
        print_warning('There is no experiment running...')
        return

    for exp_id in active_ids:
        info = experiments[exp_id]
        elapsed = get_time_interval(info['startTime'], info['endTime'])
        print(EXPERIMENT_MONITOR_INFO % (exp_id, info['status'], info['port'],
                                         info.get('platform'), info['startTime'], elapsed))
        print(TRIAL_MONITOR_HEAD)
        server_up, _ = check_rest_server_quick(info['port'])
        if server_up:
            trials_response = rest_get(trial_jobs_url(info['port']), REST_TIME_OUT)
            if trials_response and check_response(trials_response):
                for raw_trial in json.loads(trials_response.text):
                    trial = convert_time_stamp_to_date(raw_trial)
                    print(TRIAL_MONITOR_CONTENT % (trial.get('id'), trial.get('startTime'),
                                                   trial.get('endTime'), trial.get('status')))
        print(TRIAL_MONITOR_TAIL)

680
681
def set_monitor(auto_exit, time_interval, port=None, pid=None):
    '''Run the monitor loop: clear the terminal, redraw, sleep, repeat.

    Args:
        auto_exit: when true, poll the experiment on ``port`` each round and
            kill process ``pid`` once its status is DONE, ERROR or STOPPED;
            Ctrl-C also kills it.
        time_interval: seconds to sleep between refreshes.
        port: REST port polled by ``get_experiment_status`` when ``auto_exit``.
        pid: process passed to ``kill_command`` when stopping.

    Loops until it exits: status 0 on completion or Ctrl-C, status 1 (after
    printing the error) on any other exception.
    '''
    def stop_experiment():
        # shared by the "finished" path and the Ctrl-C path
        print_normal('Stopping experiment...')
        kill_command(pid)
        print_normal('Stop experiment success.')

    clear_screen = 'cls' if sys.platform == 'win32' else 'clear'
    while True:
        try:
            os.system(clear_screen)
            update_experiment()
            show_experiment_info()
            if auto_exit:
                status = get_experiment_status(port)
                if status in ['DONE', 'ERROR', 'STOPPED']:
                    print_normal('Experiment status is {0}.'.format(status))
                    stop_experiment()
                    exit(0)
            time.sleep(time_interval)
        except KeyboardInterrupt:
            if auto_exit:
                stop_experiment()
            else:
                print_normal('Exiting...')
            exit(0)
        except Exception as exception:
            print_error(exception)
            exit(1)
Yan Ni's avatar
Yan Ni committed
710

711
712
713
714
715
716
717
def monitor_experiment(args):
    '''Start the experiment monitor, refreshing every ``args.time`` seconds.

    Non-positive intervals are rejected with an error and exit status 1.
    The monitor runs without auto-exit.
    '''
    refresh_seconds = args.time
    if refresh_seconds <= 0:
        print_error('please input a positive integer as time interval, the unit is second.')
        exit(1)
    set_monitor(False, refresh_seconds)

Yan Ni's avatar
Yan Ni committed
718
def export_trials_data(args):
    '''export experiment metadata and intermediate results to json or csv

    Finds the REST server of the experiment selected by ``args``. If it is
    alive, downloads every trial record and writes it to ``args.path``:

    * ``args.type == 'json'``: all records as one JSON document;
    * ``args.type == 'csv'``: one row per trial holding its parameters, its
      final result (spread into columns when not a number, otherwise stored as
      ``reward``) and its ``id``.

    With ``args.intermediate`` each record also carries the trial's
    intermediate results, ordered by timestamp.
    '''
    def groupby_trial_id(intermediate_results):
        '''Map trial job id -> list of decoded metric payloads, oldest first.'''
        # sorted() returns a new list; its result used to be discarded, so no
        # ordering was actually applied
        ordered_results = sorted(intermediate_results, key=lambda x: x['timestamp'])
        groupby = dict()
        for content in ordered_results:
            groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
        return groupby

    def format_intermediate(values):
        '''Render intermediate results as ``[v1,v2,...]`` for a single csv cell.'''
        # decoded payloads are not guaranteed to be str, and str.join rejects other types
        return '[' + ','.join(value if isinstance(value, str) else json.dumps(value)
                              for value in values) + ']'

    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    rest_pid = nni_config.get_config('restServerPid')

    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return
    running, _ = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running')
        return
    response = rest_get(export_data_url(rest_port), 20)
    if response is None or not check_response(response):
        print_error('Export failed...')
        return

    content = json.loads(response.text)
    if args.intermediate:
        intermediate_results_response = rest_get(metric_data_url(rest_port), REST_TIME_OUT)
        if not intermediate_results_response or not check_response(intermediate_results_response):
            print_error('Error getting intermediate results.')
            return
        intermediate_results = groupby_trial_id(json.loads(intermediate_results_response.text))
        for record in content:
            # trials that never reported an intermediate result have no entry
            record['intermediate'] = intermediate_results.get(record['id'], [])

    if args.type == 'json':
        with open(args.path, 'w') as file:
            file.write(json.dumps(content))
    elif args.type == 'csv':
        trial_records = []
        for record in content:
            formated_record = dict()
            if args.intermediate:
                formated_record['intermediate'] = format_intermediate(record['intermediate'])
            record_value = json.loads(record['value'])
            if not isinstance(record_value, (float, int)):
                formated_record.update({**record['parameter'], **record_value, **{'id': record['id']}})
            else:
                formated_record.update({**record['parameter'], **{'reward': record_value, 'id': record['id']}})
            trial_records.append(formated_record)
        if not trial_records:
            print_error('No trial results collected! Please check your trial log...')
            exit(0)
        # union of all columns in first-seen order, so the header is stable across runs
        fieldnames = list(dict.fromkeys(key for row in trial_records for key in row))
        with open(args.path, 'w', newline='') as file:
            writer = csv.DictWriter(file, fieldnames)
            writer.writeheader()
            writer.writerows(trial_records)
    else:
        print_error('Unknown type: %s' % args.type)
        return
777
778
779
780
781
782

def search_space_auto_gen(args):
    '''dry run trial code to generate search space file

    Runs ``args.trial_command`` through the shell inside ``args.trial_dir``.
    The environment variable ``NNI_GEN_SEARCH_SPACE`` is set to the target
    file (``args.file``, made absolute against the current directory). The
    function then reports whether that file was produced. Exits with status 1
    when the trial directory does not exist.
    '''
    trial_dir = os.path.expanduser(args.trial_dir)
    file_path = os.path.expanduser(args.file)
    if not os.path.isabs(file_path):
        file_path = os.path.join(os.getcwd(), file_path)
    # explicit check instead of `assert`, which is stripped under `python -O`;
    # the directory is used as the subprocess cwd, so it must be a directory
    if not os.path.isdir(trial_dir):
        print_error('trial_dir %s does not exist or is not a directory!' % trial_dir)
        exit(1)
    if os.path.exists(file_path):
        print_warning('%s already exists, will be overwritten.' % file_path)
    print_normal('Dry run to generate search space...')
    # shell=True: the trial command is supplied by the local user on the CLI
    Popen(args.trial_command, cwd=trial_dir, env=dict(os.environ, NNI_GEN_SEARCH_SPACE=file_path), shell=True).wait()
    if not os.path.exists(file_path):
        print_warning('Expected search space file \'{}\' generated, but not found.'.format(file_path))
    else:
        print_normal('Generate search space done: \'{}\'.'.format(file_path))
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838

def save_experiment(args):
    '''save experiment data to a zip file

    Packages the STOPPED experiment ``args.id`` into ``nni_experiment_<id>.zip``.
    The zip goes in ``args.path`` when given, otherwise the current directory.
    Archive layout:

    * ``experiment/``  - the experiment's log directory;
    * ``nnictl/``      - nnictl's metadata entry (``.experiment``) and config folder;
    * ``code/``        - the trial codeDir, only when ``args.saveCodeDir`` is set;
    * ``searchSpace/`` - the search space file, when configured and present.

    The temporary staging folder is always removed. Exits with status 1 on
    invalid input or when the metadata file cannot be written.
    '''
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    if args.id is None:
        print_error('Please set experiment id.')
        exit(1)
    if args.id not in experiment_dict:
        print_error('Cannot find experiment {0}.'.format(args.id))
        exit(1)
    if experiment_dict[args.id].get('status') != 'STOPPED':
        print_error('Can only save stopped experiment!')
        exit(1)
    print_normal('Saving...')
    file_name = experiment_dict[args.id]['fileName']
    nni_config = Config(file_name)
    exp_conf = nni_config.get_config('experimentConfig') or {}
    # Keep the top-level key this function always honoured, but also fall back
    # to experimentConfig.logDir (what experiment cleanup and load_experiment read)
    # before defaulting to NNI_HOME_DIR.
    log_base_dir = nni_config.get_config('logDir') or exp_conf.get('logDir') or NNI_HOME_DIR
    log_dir = os.path.join(log_base_dir, args.id)
    # validate before creating the temp folder so an early exit leaves nothing behind
    if not os.path.exists(log_dir):
        print_error('logDir: %s does not exist!' % log_dir)
        exit(1)

    temp_root_dir = generate_temp_dir()
    try:
        # Step1. Copy logDir to temp folder
        shutil.copytree(log_dir, os.path.join(temp_root_dir, 'experiment'))

        # Step2. Copy nnictl metadata to temp folder
        temp_nnictl_dir = os.path.join(temp_root_dir, 'nnictl')
        os.makedirs(temp_nnictl_dir, exist_ok=True)
        metadata_path = os.path.join(temp_nnictl_dir, '.experiment')
        # copy, so the dict returned by get_all_experiments is not mutated
        metadata = dict(experiment_dict[args.id], id=args.id)
        try:
            with open(metadata_path, 'w') as file:
                json.dump(metadata, file)
        except IOError:
            print_error('Write file to %s failed!' % metadata_path)
            exit(1)
        shutil.copytree(os.path.join(NNICTL_HOME_DIR, file_name), os.path.join(temp_nnictl_dir, file_name))

        # Step3. Copy code dir
        if args.saveCodeDir:
            shutil.copytree(exp_conf['trial']['codeDir'], os.path.join(temp_root_dir, 'code'))

        # Step4. Copy searchSpace file
        search_space_path = exp_conf.get('searchSpacePath')
        if search_space_path:
            if not os.path.exists(search_space_path):
                print_warning('search space %s does not exist!' % search_space_path)
            else:
                temp_search_space_dir = os.path.join(temp_root_dir, 'searchSpace')
                os.makedirs(temp_search_space_dir, exist_ok=True)
                search_space_name = os.path.basename(search_space_path)
                shutil.copyfile(search_space_path, os.path.join(temp_search_space_dir, search_space_name))

        # Step5. Archive folder
        zip_package_name = 'nni_experiment_%s' % args.id
        if args.path:
            os.makedirs(args.path, exist_ok=True)
            zip_package_name = os.path.join(args.path, zip_package_name)
        shutil.make_archive(zip_package_name, 'zip', temp_root_dir)
        print_normal('Save to %s.zip success!' % zip_package_name)
    finally:
        # Step6. Cleanup temp data (also after a failure or exit() above)
        shutil.rmtree(temp_root_dir, ignore_errors=True)

def load_experiment(args):
    '''load experiment data

    Restore an archive produced by ``save_experiment``:

    1. validate the command-line paths and the archive layout;
    2. copy nnictl's config folder into NNICTL_HOME_DIR;
    3. copy the experiment data to ``<logDir>/<id>``. logDir is ``args.logDir``,
       else the archived experimentConfig's logDir, else NNI_HOME_DIR;
    4. copy archived code into ``args.codeDir``; files already present there
       are not overwritten;
    5. place the search space at ``args.searchSpacePath`` (default
       ``<codeDir>/search_space.json``); an existing file there is kept;
    6. register the experiment with nnictl.

    The temporary extraction folder is always removed. Exits with status 1 on
    invalid input.
    '''
    package_path = os.path.expanduser(args.path)
    if not os.path.exists(package_path):
        print_error('file path %s does not exist!' % args.path)
        exit(1)
    search_space_arg = os.path.expanduser(args.searchSpacePath) if args.searchSpacePath else None
    if search_space_arg and os.path.isdir(search_space_arg):
        print_error('search space path should be a full path with filename, not a directory!')
        exit(1)
    # expand user paths before validating them, so '~/...' arguments are accepted
    code_dir = os.path.expanduser(args.codeDir)
    if not os.path.exists(code_dir):
        print_error('Invalid: codeDir path does not exist!')
        exit(1)
    log_dir_arg = None
    if args.logDir:
        log_dir_arg = os.path.expanduser(args.logDir)
        if not os.path.exists(log_dir_arg):
            print_error('Invalid: logDir path does not exist!')
            exit(1)
        # stored into the config below, so make it independent of the current directory
        log_dir_arg = os.path.abspath(log_dir_arg)

    temp_root_dir = generate_temp_dir()
    try:
        shutil.unpack_archive(package_path, temp_root_dir)
        print_normal('Loading...')

        # Step1. Validation of the archive content
        experiment_temp_dir = os.path.join(temp_root_dir, 'experiment')
        if not os.path.exists(os.path.join(experiment_temp_dir, 'db')):
            print_error('Invalid archive file: db file does not exist!')
            exit(1)
        nnictl_temp_dir = os.path.join(temp_root_dir, 'nnictl')
        metadata_path = os.path.join(nnictl_temp_dir, '.experiment')
        if not os.path.exists(metadata_path):
            print_error('Invalid archive file: nnictl metadata file does not exist!')
            exit(1)
        try:
            with open(metadata_path, 'r') as file:
                experiment_metadata = json.load(file)
        except ValueError as err:
            print_error('Invalid nnictl metadata file: %s' % err)
            exit(1)
        experiment_id = experiment_metadata.get('id')
        file_name = experiment_metadata.get('fileName')
        # both are used to build paths below; a missing value would raise TypeError
        if not experiment_id or not file_name:
            print_error('Invalid nnictl metadata file: missing id or fileName!')
            exit(1)
        experiment_config = Experiments()
        experiment_dict = experiment_config.get_all_experiments()
        if experiment_id in experiment_dict:
            print_error('Invalid: experiment id already exist!')
            exit(1)
        src_nnictl_dir = os.path.join(nnictl_temp_dir, file_name)
        if not os.path.exists(src_nnictl_dir):
            print_error('Invalid: experiment metadata does not exist!')
            exit(1)

        # Step2. Copy nnictl metadata
        dest_nnictl_dir = os.path.join(NNICTL_HOME_DIR, file_name)
        if os.path.exists(dest_nnictl_dir):
            shutil.rmtree(dest_nnictl_dir)
        shutil.copytree(src_nnictl_dir, dest_nnictl_dir)

        # Step3. Copy experiment data
        nni_config = Config(file_name)
        nnictl_exp_config = nni_config.get_config('experimentConfig')
        if log_dir_arg:
            log_dir = log_dir_arg
            nnictl_exp_config['logDir'] = log_dir
        else:
            log_dir = nnictl_exp_config.get('logDir') or NNI_HOME_DIR
        dest_data_dir = os.path.join(log_dir, experiment_id)
        if os.path.exists(dest_data_dir):
            shutil.rmtree(dest_data_dir)
        shutil.copytree(experiment_temp_dir, dest_data_dir)

        # Step4. Copy code dir
        if not os.path.isabs(code_dir):
            code_dir = os.path.join(os.getcwd(), code_dir)
            print_normal('Expand codeDir to %s' % code_dir)
        nnictl_exp_config['trial']['codeDir'] = code_dir
        archive_code_dir = os.path.join(temp_root_dir, 'code')
        if os.path.exists(archive_code_dir):
            for file_name_in_code in os.listdir(archive_code_dir):
                src_path = os.path.join(archive_code_dir, file_name_in_code)
                target_path = os.path.join(code_dir, file_name_in_code)
                if os.path.exists(target_path):
                    print_error('Copy %s failed, %s exist!' % (file_name_in_code, target_path))
                    continue
                if os.path.isdir(src_path):
                    shutil.copytree(src_path, target_path)
                else:
                    shutil.copy(src_path, target_path)

        # Step5. Copy searchSpace file
        archive_search_space_dir = os.path.join(temp_root_dir, 'searchSpace')
        # default path is inside codeDir
        target_path = search_space_arg or os.path.join(code_dir, 'search_space.json')
        if not os.path.isabs(target_path):
            target_path = os.path.join(os.getcwd(), target_path)
            print_normal('Expand search space path to %s' % target_path)
        nnictl_exp_config['searchSpacePath'] = target_path
        # if the path already has a search space file, use the original one, otherwise use archived one
        if not os.path.isfile(target_path):
            # save_experiment omits the folder entirely when no search space was saved
            archived_files = []
            if os.path.isdir(archive_search_space_dir):
                archived_files = sorted(os.listdir(archive_search_space_dir))
            if not archived_files:
                print_error('Archive file does not contain search space file!')
                exit(1)
            os.makedirs(os.path.dirname(target_path), exist_ok=True)
            shutil.copyfile(os.path.join(archive_search_space_dir, archived_files[0]), target_path)
        elif not search_space_arg:
            print_warning('%s exist, will not load search_space file' % target_path)

        # Step6. Create experiment metadata
        nni_config.set_config('experimentConfig', nnictl_exp_config)
        experiment_config.add_experiment(experiment_id,
                                         experiment_metadata.get('port'),
                                         experiment_metadata.get('startTime'),
                                         file_name,
                                         experiment_metadata.get('platform'),
                                         experiment_metadata.get('experimentName'),
                                         experiment_metadata.get('endTime'),
                                         experiment_metadata.get('status'))
        print_normal('Load experiment %s success!' % experiment_id)
    finally:
        # Step7. Cleanup temp data (also after a failure or exit() above)
        shutil.rmtree(temp_root_dir, ignore_errors=True)
994