nnictl_utils.py 42 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
Deshui Yu's avatar
Deshui Yu committed
3

Yan Ni's avatar
Yan Ni committed
4
import csv
Deshui Yu's avatar
Deshui Yu committed
5
import os
6
import sys
Deshui Yu's avatar
Deshui Yu committed
7
import json
8
import time
SparkSnail's avatar
SparkSnail committed
9
10
import re
import shutil
Yuge Zhang's avatar
Yuge Zhang committed
11
import subprocess
12
from functools import cmp_to_key
chicm-ms's avatar
chicm-ms committed
13
14
from datetime import datetime, timezone
from pathlib import Path
15
from subprocess import Popen
chicm-ms's avatar
chicm-ms committed
16
from pyhdfs import HdfsClient
chicm-ms's avatar
chicm-ms committed
17
from nni.package_utils import get_nni_installation_path
18
from nni_annotation import expand_annotations
19
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
20
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url, metric_data_url
21
from .config_utils import Config, Experiments
22
from .constants import NNICTL_HOME_DIR, NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
23
     EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
24
from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content, generate_temp_dir
25
from .command_utils import check_output_command, kill_command
SparkSnail's avatar
SparkSnail committed
26
from .ssh_utils import create_ssh_sftp_client, remove_remote_directory
Deshui Yu's avatar
Deshui Yu committed
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def get_experiment_time(port):
    '''Return (startTime, endTime) of the experiment served on ``port``.

    Both values come from the experiment profile after convert_time_stamp_to_date;
    (None, None) is returned when the REST request fails or is rejected.
    '''
    response = rest_get(experiment_url(port), REST_TIME_OUT)
    if not (response and check_response(response)):
        return None, None
    profile = convert_time_stamp_to_date(json.loads(response.text))
    return profile.get('startTime'), profile.get('endTime')

def get_experiment_status(port):
    '''Return the 'status' field reported by the REST server on ``port``, or None when it is unreachable.'''
    reachable, response = check_rest_server_quick(port)
    if not reachable:
        return None
    return json.loads(response.text).get('status')

def update_experiment():
    '''Refresh the recorded state of every experiment not already marked STOPPED.

    An experiment whose REST server process is no longer alive is marked STOPPED.
    For the others, startTime/endTime/status are re-read from the REST server and
    stored when the server returns a value.  Non-dict entries are ignored.
    '''
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    if not experiment_dict:
        return None
    for exp_id in experiment_dict:
        exp_info = experiment_dict[exp_id]
        if not isinstance(exp_info, dict) or exp_info.get('status') == 'STOPPED':
            continue
        nni_config = Config(exp_info['fileName'])
        if not detect_process(nni_config.get_config('restServerPid')):
            experiment_config.update_experiment(exp_id, 'status', 'STOPPED')
            continue
        rest_port = nni_config.get_config('restServerPort')
        start_time, end_time = get_experiment_time(rest_port)
        for field, value in (('startTime', start_time), ('endTime', end_time)):
            if value:
                experiment_config.update_experiment(exp_id, field, value)
        status = get_experiment_status(rest_port)
        if status:
            experiment_config.update_experiment(exp_id, 'status', status)
SparkSnail's avatar
SparkSnail committed
66

SparkSnail's avatar
SparkSnail committed
67
def check_experiment_id(args, update=True):
    '''Resolve which experiment id a command should act on.

    Args:
        args: parsed command line; only ``args.id`` is read.
        update: when true, refresh experiment states with update_experiment() first.

    Returns:
        ``args.id`` when it names a recorded experiment.  Without an id, the only
        experiment whose status is not STOPPED.  None (after printing a message)
        when nothing is recorded, the id is unknown, or nothing is running.  When
        no id is given and several experiments are running, their details are
        printed and the process exits with status 1.
    '''
    if update:
        update_experiment()
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    if not experiment_dict:
        print_normal('There is no experiment running...')
        return None

    if args.id:
        if experiment_dict.get(args.id):
            return args.id
        print_error('Id not correct.')
        return None

    running_experiment_list = []
    for key in experiment_dict:
        info = experiment_dict[key]
        if isinstance(info, list):
            # old-version config entry: remove it from the file
            experiment_config.remove_experiment(key)
        elif isinstance(info, dict) and info.get('status') != 'STOPPED':
            running_experiment_list.append(key)

    if not running_experiment_list:
        print_error('There is no experiment running.')
        return None
    if len(running_experiment_list) > 1:
        print_error('There are multiple experiments, please set the experiment id...')
        experiment_information = ''.join(
            EXPERIMENT_DETAIL_FORMAT % (key,
                                        experiment_dict[key].get('experimentName', 'N/A'),
                                        experiment_dict[key]['status'],
                                        experiment_dict[key]['port'],
                                        experiment_dict[key].get('platform'),
                                        experiment_dict[key]['startTime'],
                                        experiment_dict[key]['endTime'])
            for key in running_experiment_list)
        print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
        exit(1)
    return running_experiment_list[0]
Deshui Yu's avatar
Deshui Yu committed
109

110
def parse_ids(args):
    '''Resolve the experiment ids targeted by ``nnictl stop``.

    Only experiments whose status is not STOPPED are candidates.  Rules, in order:
      - ``args.all``: every running experiment.
      - ``args.port``: running experiments whose port equals ``args.port`` (compared
        as strings); if ``args.id`` is also given it must equal the first match,
        otherwise the process exits with status 1.
      - no id: the single running experiment; when several are running their
        details are printed and the process exits with status 1.
      - id ending in ``*``: every running id starting with the text before ``*``.
      - exact running id: that id.
      - otherwise: running ids starting with ``args.id``; more than one match is
        reported as ambiguous and None is returned.
    An error is printed when nothing matches.

    Returns:
        None when no experiment is recorded (or the prefix is ambiguous),
        otherwise a possibly-empty list of experiment ids.
    '''
    update_experiment()
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    if not experiment_dict:
        print_normal('Experiment is not running...')
        return None

    running_experiment_list = []
    for key in experiment_dict:
        info = experiment_dict[key]
        if isinstance(info, list):
            # old-version config entry: remove it from the file
            experiment_config.remove_experiment(key)
        elif isinstance(info, dict) and info.get('status') != 'STOPPED':
            running_experiment_list.append(key)

    if args.all:
        return running_experiment_list

    if args.port is not None:
        result_list = [key for key in running_experiment_list
                       if str(experiment_dict[key]['port']) == args.port]
        if args.id and result_list and args.id != result_list[0]:
            print_error('Experiment id and resful server port not match')
            exit(1)
    elif not args.id:
        if len(running_experiment_list) > 1:
            print_error('There are multiple experiments, please set the experiment id...')
            experiment_information = ''.join(
                EXPERIMENT_DETAIL_FORMAT % (key,
                                            experiment_dict[key].get('experimentName', 'N/A'),
                                            experiment_dict[key]['status'],
                                            experiment_dict[key]['port'],
                                            experiment_dict[key].get('platform'),
                                            experiment_dict[key]['startTime'],
                                            experiment_dict[key]['endTime'])
                for key in running_experiment_list)
            print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
            exit(1)
        result_list = running_experiment_list
    elif args.id.endswith('*'):
        prefix = args.id[:-1]
        result_list = [exp_id for exp_id in running_experiment_list if exp_id.startswith(prefix)]
    elif args.id in running_experiment_list:
        result_list = [args.id]
    else:
        result_list = [exp_id for exp_id in running_experiment_list if exp_id.startswith(args.id)]
        if len(result_list) > 1:
            print_error(args.id + ' is ambiguous, please choose ' + ' '.join(result_list))
            return None

    if not result_list:
        if args.id or args.port:
            print_error('There are no experiments matched, please set correct experiment id or restful server port')
        else:
            print_error('There is no experiment running...')
    return result_list

180
181
182
def get_config_filename(args):
    '''Return the config file name recorded for the experiment selected by ``args``.

    Exits with status 1 when no valid experiment id can be resolved.
    '''
    experiment_id = check_experiment_id(args)
    if experiment_id is None:
        print_error('Please set correct experiment id.')
        exit(1)
    all_experiments = Experiments().get_all_experiments()
    return all_experiments[experiment_id]['fileName']

def get_experiment_port(args):
    '''Return the port recorded for the experiment selected by ``args``.

    Exits with status 1 when no valid experiment id can be resolved.
    '''
    experiment_id = check_experiment_id(args)
    if experiment_id is None:
        print_error('Please set correct experiment id.')
        exit(1)
    all_experiments = Experiments().get_all_experiments()
    return all_experiments[experiment_id]['port']

def convert_time_stamp_to_date(content):
    '''Replace millisecond epoch 'startTime'/'endTime' values with local-time strings.

    The dict is modified in place and also returned.  Missing or falsy timestamps
    are left untouched.  Values are formatted as "%Y/%m/%d %H:%M:%S" in the local
    timezone, with the sub-second part discarded.
    '''
    def _to_local_string(stamp_ms):
        return datetime.fromtimestamp(stamp_ms // 1000, timezone.utc).astimezone().strftime("%Y/%m/%d %H:%M:%S")

    for field in ('startTime', 'endTime'):
        stamp = content.get(field)
        if stamp:
            content[field] = _to_local_string(stamp)
    return content

def check_rest(args):
    '''Report whether the REST server of the selected experiment is running.'''
    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    running, _ = check_rest_server_quick(rest_port)
    if running:
        print_normal('Restful server is running...')
    else:
        print_normal('Restful server is not running...')

Deshui Yu's avatar
Deshui Yu committed
222
223
def stop_experiment(args):
    '''Stop the experiment(s) selected by ``args`` (resolved by parse_ids).

    For each experiment with a recorded REST server pid, that process and the
    TensorBoard pids recorded in its config are killed, and the TensorBoard record
    is cleared.  Every selected experiment is then marked STOPPED with the current
    local time as its endTime.
    '''
    def _kill_recorded_tensorboards(nni_config):
        # best effort: a failed kill is reported and the remaining pids are still tried
        pid_list = nni_config.get_config('tensorboardPidList')
        if not pid_list:
            return
        for pid in pid_list:
            try:
                kill_command(pid)
            except Exception as exception:
                print_error(exception)
        nni_config.set_config('tensorboardPidList', [])

    if args.id == 'all':
        print_warning("'nnictl stop all' is abolished, please use 'nnictl stop --all' to stop all of experiments!")
        exit(1)
    experiment_id_list = parse_ids(args)
    if not experiment_id_list:
        return
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    for experiment_id in experiment_id_list:
        print_normal('Stopping experiment %s' % experiment_id)
        nni_config = Config(experiment_dict[experiment_id]['fileName'])
        rest_pid = nni_config.get_config('restServerPid')
        if rest_pid:
            kill_command(rest_pid)
            _kill_recorded_tensorboards(nni_config)
        print_normal('Stop experiment success.')
        experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
        stopped_at = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        experiment_config.update_experiment(experiment_id, 'endTime', stopped_at)
Deshui Yu's avatar
Deshui Yu committed
249
250
251

def trial_ls(args):
    '''Print the trial jobs of the selected experiment as JSON.

    ``args.head`` / ``args.tail`` (mutually exclusive) keep only the N trials with
    the highest / lowest final metric.  Trials without ``finalMetricData`` are
    dropped in that mode.  Start/end timestamps are converted to local dates.
    '''
    def final_metric_data_cmp(lhs, rhs):
        # the first json.loads yields a JSON string, which is decoded again
        metric_l = json.loads(json.loads(lhs['finalMetricData'][0]['data']))
        metric_r = json.loads(json.loads(rhs['finalMetricData'][0]['data']))
        # integer metrics (e.g. a reported 1) are as valid as floats
        if isinstance(metric_l, (int, float)):
            return metric_l - metric_r
        if isinstance(metric_l, dict):
            return metric_l['default'] - metric_r['default']
        print_error('Unexpected data format. Please check your data.')
        raise ValueError

    if args.head and args.tail:
        print_error('Head and tail cannot be set at the same time.')
        return
    requested = args.head or args.tail
    # explicit check instead of assert, which is stripped under `python -O`
    if requested is not None and requested < 0:
        print_error('The number of requested data must be greater than 0.')
        return
    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    rest_pid = nni_config.get_config('restServerPid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return
    running, _ = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running...')
        return
    response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
    if not (response and check_response(response)):
        print_error('List trial failed...')
        return
    content = json.loads(response.text)
    if args.head:
        content = sorted(filter(lambda x: 'finalMetricData' in x, content),
                         key=cmp_to_key(final_metric_data_cmp), reverse=True)[:args.head]
    elif args.tail:
        content = sorted(filter(lambda x: 'finalMetricData' in x, content),
                         key=cmp_to_key(final_metric_data_cmp))[:args.tail]
    content = [convert_time_stamp_to_date(trial) for trial in content]
    print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))

def trial_kill(args):
    '''Kill trial job ``args.trial_id`` of the selected experiment.

    Prints the REST server's reply on success.  Prints an error instead when any of
    these hold:
      - the experiment's REST server process is gone,
      - the server is unreachable,
      - the delete request fails.
    '''
    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    rest_pid = nni_config.get_config('restServerPid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return
    running, _ = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running...')
        return
    response = rest_delete(trial_job_id_url(rest_port, args.trial_id), REST_TIME_OUT)
    if response and check_response(response):
        print(response.text)
    else:
        print_error('Kill trial job failed...')

311
312
313
314
315
316
317
318
319
320
321
def trial_codegen(args):
    '''Generate the code of trial ``args.trial_id`` with expand_annotations.

    The experiment's trial codeDir is expanded into
    ./exp_<experiment id>_trial_<trial id>_code.  Exits with status 1 when the
    experiment does not use annotation.
    '''
    print_warning('Currently, this command is only for nni nas programming interface.')
    exp_id = check_experiment_id(args)
    nni_config = Config(get_config_filename(args))
    experiment_config = nni_config.get_config('experimentConfig') or {}
    # a config without 'useAnnotation' means annotation is off, not a KeyError crash
    if not experiment_config.get('useAnnotation'):
        print_error('The experiment is not using annotation')
        exit(1)
    code_dir = experiment_config['trial']['codeDir']
    expand_annotations(code_dir, './exp_%s_trial_%s_code' % (exp_id, args.trial_id), exp_id, args.trial_id)

Deshui Yu's avatar
Deshui Yu committed
322
323
def list_experiment(args):
    '''Print the selected experiment's profile as JSON, with timestamps converted to local dates.'''
    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    rest_pid = nni_config.get_config('restServerPid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return
    running, _ = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running...')
        return
    response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
    if not (response and check_response(response)):
        print_error('List experiment failed...')
        return
    profile = convert_time_stamp_to_date(json.loads(response.text))
    print(json.dumps(profile, indent=4, sort_keys=True, separators=(',', ':')))

341
342
def experiment_status(args):
    '''Print the REST server's status reply for the selected experiment as JSON.'''
    rest_port = Config(get_config_filename(args)).get_config('restServerPort')
    reachable, response = check_rest_server_quick(rest_port)
    if reachable:
        print(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
    else:
        print_normal('Restful server is not running...')

Deshui Yu's avatar
Deshui Yu committed
351
352
def log_internal(args, filetype):
    '''Print the nnictl 'stdout' or 'stderr' log of the selected experiment.

    Any ``filetype`` other than 'stdout' selects the stderr file.  ``args.head`` and
    ``args.tail`` are forwarded to check_output_command.
    '''
    file_name = get_config_filename(args)
    log_name = 'stdout' if filetype == 'stdout' else 'stderr'
    file_full_path = os.path.join(NNICTL_HOME_DIR, file_name, log_name)
    print(check_output_command(file_full_path, head=args.head, tail=args.tail))
359

Deshui Yu's avatar
Deshui Yu committed
360
361
362
363
364
365
366
367
def log_stdout(args):
    '''Print the nnictl stdout log of the selected experiment.'''
    log_internal(args, filetype='stdout')

def log_stderr(args):
    '''Print the nnictl stderr log of the selected experiment.'''
    log_internal(args, filetype='stderr')

368
369
370
def log_trial(args):
    '''Print the log path of trial ``args.trial_id``, or of every trial when no id is given.

    Exits with status 1 when any of these hold:
      - the REST server is down,
      - the trial list cannot be fetched,
      - the trial id is unknown,
      - the trial's log path is not available yet.
    '''
    trial_id_path_dict = {}
    trial_id_list = []
    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    rest_pid = nni_config.get_config('restServerPid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return
    running, _ = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running...')
        exit(1)
    response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
    # a failed request must not be reported as "no trials" or "wrong trial id"
    if not (response and check_response(response)):
        print_error('List trial failed...')
        exit(1)
    for trial in json.loads(response.text):
        trial_id_list.append(trial.get('id'))
        if trial.get('logPath'):
            trial_id_path_dict[trial.get('id')] = trial['logPath']

    if args.trial_id:
        if args.trial_id not in trial_id_list:
            print_error('Trial id {0} not correct, please check your command!'.format(args.trial_id))
            exit(1)
        if trial_id_path_dict.get(args.trial_id):
            print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
        else:
            print_error('Log path is not available yet, please wait...')
            exit(1)
    else:
        print_normal('All of trial log info:')
        for key in trial_id_path_dict:
            print_normal('id:' + key + ' path:' + trial_id_path_dict[key])
        if not trial_id_path_dict:
            print_normal('None')
405

Deshui Yu's avatar
Deshui Yu committed
406
407
def get_config(args):
    '''Print every config entry stored for the selected experiment.'''
    all_config = Config(get_config_filename(args)).get_all_config()
    print(all_config)
410
411
412

def webui_url(args):
    '''Print the web UI URL(s) recorded for the selected experiment, space-separated.'''
    urls = Config(get_config_filename(args)).get_config('webuiUrl')
    joined_urls = ' '.join(urls)
    print_normal('{0} {1}'.format('Web UI url:', joined_urls))

Yuge Zhang's avatar
Yuge Zhang committed
416
417
418
419
def webui_nas(args):
    '''Run the NAS UI node server (nasui/server.js) until it exits or Ctrl-C is pressed.

    ``args.port`` and ``args.logdir`` are forwarded to the server as --port / --logdir.
    '''
    print_normal('Starting NAS UI...')
    try:
        entry_dir = get_nni_installation_path()
        entry_file = os.path.join(entry_dir, 'nasui', 'server.js')
        if sys.platform == 'win32':
            # drops the last 3 characters of entry_dir; presumably a trailing 'Lib' whose
            # sibling 'Scripts' folder holds node.exe -- TODO confirm
            node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
        else:
            node_command = 'node'
        subprocess.run([node_command, '--max-old-space-size=4096', entry_file,
                        '--port', str(args.port), '--logdir', args.logdir])
    except KeyboardInterrupt:
        pass

SparkSnail's avatar
SparkSnail committed
430
431
432
433
434
def local_clean(directory):
    '''Recursively delete ``directory``; a missing directory is reported rather than raised.'''
    announcement = 'removing folder {0}'.format(directory)
    print_normal(announcement)
    try:
        shutil.rmtree(directory)
    except FileNotFoundError:
        print_error('{0} does not exist.'.format(directory))
chicm-ms's avatar
chicm-ms committed
437

SparkSnail's avatar
SparkSnail committed
438
439
440
441
442
443
444
def remote_clean(machine_list, experiment_id=None):
    '''Remove NNI experiment data from every machine in ``machine_list`` over SFTP.

    Each machine dict is read for 'ip', 'port', 'username', 'passwd', 'sshKeyPath'
    and 'passphrase'; absent keys are passed on as None.  The folder removed is
    /tmp/nni/experiments, or only its ``experiment_id`` sub-folder when one is given.
    '''
    for machine in machine_list:
        host = machine.get('ip')
        port = machine.get('port')
        path_parts = ['tmp', 'nni', 'experiments']
        if experiment_id:
            path_parts.append(experiment_id)
        remote_dir = '/' + '/'.join(path_parts)
        sftp = create_ssh_sftp_client(host, port, machine.get('username'), machine.get('passwd'),
                                      machine.get('sshKeyPath'), machine.get('passphrase'))
        print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir))
        remove_remote_directory(sftp, remote_dir)
chicm-ms's avatar
chicm-ms committed
454

SparkSnail's avatar
SparkSnail committed
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
def hdfs_clean(host, user_name, output_dir, experiment_id=None):
    '''Remove NNI experiment data stored in HDFS, via WebHDFS on ``host``:80.

    Deletes ``/<user_name>/nni/experiments``, or only its ``<experiment_id>``
    sub-folder when an id is given.  When ``output_dir`` is an
    ``hdfs://<ipv4>[:port]/<path>`` URL on the same host, it also deletes
    ``/<path>`` (or ``/<path>/<experiment_id>``).
    '''
    hdfs_client = HdfsClient(hosts='{0}:80'.format(host), user_name=user_name, webhdfs_path='/webhdfs/api/v1', timeout=5)
    if experiment_id:
        full_path = '/' + '/'.join([user_name, 'nni', 'experiments', experiment_id])
    else:
        full_path = '/' + '/'.join([user_name, 'nni', 'experiments'])
    print_normal('removing folder {0} in hdfs'.format(full_path))
    hdfs_client.delete(full_path, recursive=True)
    if not output_dir:
        return
    # dots are escaped so only dotted-quad IPv4 hosts match
    pattern = re.compile(r'hdfs://(?P<host>([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?')
    match_result = pattern.match(output_dir)
    if not match_result:
        return
    output_host = match_result.group('host')
    base_dir = match_result.group('baseDir')
    if output_host != host:
        # report the full URL, not just its path component
        print_warning('The host in {0} is not consistent with {1}'.format(output_dir, host))
    elif not base_dir:
        # the URL carries no path: there is nothing to delete
        print_warning('No folder is specified in {0}, skip cleaning it.'.format(output_dir))
    else:
        if experiment_id:
            base_dir = base_dir + '/' + experiment_id
        print_normal('removing folder {0} in hdfs'.format(base_dir))
        hdfs_client.delete(base_dir, recursive=True)

def experiment_clean(args):
    '''Delete the data and nnictl metadata of experiment ``args.id``, or of all with ``args.all``.

    After a y/N confirmation, each experiment is cleaned in this order:
      - remote data for the 'remote' platform, or HDFS data for 'pai';
      - the local log directory (``logDir``, default ~/nni/experiments/<id>);
      - the nnictl metadata entry.
    Experiments on other non-local platforms are skipped with a warning, so the
    remaining experiments are still cleaned.
    '''
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    if args.all:
        experiment_id_list = list(experiment_dict.keys())
    else:
        if args.id is None:
            print_error('please set experiment id.')
            exit(1)
        if args.id not in experiment_dict:
            print_error('Cannot find experiment {0}.'.format(args.id))
            exit(1)
        experiment_id_list = [args.id]
    while True:
        print('INFO: This action will delete experiment {0}, and it\'s not recoverable.'.format(' '.join(experiment_id_list)))
        inputs = input('INFO: do you want to continue?[y/N]:').lower()
        if not inputs or inputs in ['n', 'no']:
            exit(0)
        if inputs in ['y', 'yes']:
            break
        print_warning('please input Y or N.')
    for key in experiment_id_list:
        nni_config = Config(experiment_dict[key]['fileName'])
        experiment_conf = nni_config.get_config('experimentConfig')
        platform = experiment_conf.get('trainingServicePlatform')
        experiment_id = nni_config.get_config('experimentId')
        if platform == 'remote':
            remote_clean(experiment_conf.get('machineList'), experiment_id)
        elif platform == 'pai':
            pai_config = experiment_conf.get('paiConfig')
            output_dir = experiment_conf.get('trial').get('outputDir')
            hdfs_clean(pai_config.get('host'), pai_config.get('userName'), output_dir, experiment_id)
        elif platform != 'local':
            #TODO: support all platforms
            print_warning('platform {0} clean up not supported yet.'.format(platform))
            # skip only this experiment; exiting here would abort `--all` half-way
            continue
        #clean local data
        local_dir = experiment_conf.get('logDir')
        if not local_dir:
            local_dir = os.path.join(str(Path.home()), 'nni', 'experiments', experiment_id)
        local_clean(local_dir)
        # a fresh Experiments() instance is used for the metadata removal, as before
        experiment_config = Experiments()
        print_normal('removing metadata of experiment {0}'.format(experiment_id))
        experiment_config.remove_experiment(experiment_id)
        print_normal('Done.')
SparkSnail's avatar
SparkSnail committed
529
530
531
532
533
534
535
536
537
538

def get_platform_dir(config_content):
    '''List the locations platform cleanup would remove, as human-readable strings.

    Args:
        config_content: parsed experiment config dict.

    Returns:
        For 'remote': one '<ip>:<port>/tmp/nni' entry per machine.
        For 'pai': 'server: <host>, path: <userName>/nni', plus trial.outputDir when set.
        An empty list for other platforms, or when machineList / trial is missing.
    '''
    platform = config_content.get('trainingServicePlatform')
    dir_list = []
    if platform == 'remote':
        # a missing machineList yields no entries instead of a TypeError
        for machine in config_content.get('machineList') or []:
            host = machine.get('ip')
            port = machine.get('port')
            dir_list.append(host + ':' + str(port) + '/tmp/nni')
    elif platform == 'pai':
        host = config_content.get('paiConfig').get('host')
        user_name = config_content.get('paiConfig').get('userName')
        # a missing trial section simply means there is no outputDir to list
        output_dir = (config_content.get('trial') or {}).get('outputDir')
        dir_list.append('server: {0}, path: {1}/nni'.format(host, user_name))
        if output_dir:
            dir_list.append(output_dir)
    return dir_list

def platform_clean(args):
    '''clean up the experiment data

    Reads the experiment config at ``args.config``, lists the NNI cache folders
    of its 'remote' or 'pai' platform, asks the user to confirm, and then
    removes them. Exits early for a missing config, the local platform,
    unsupported platforms, or when no cache folder is found.
    '''
    absolute_config = os.path.abspath(args.config)
    if os.path.exists(absolute_config):
        config_content = get_yml_content(absolute_config)
    else:
        print_error('Please set correct config path.')
        exit(1)
    platform = config_content.get('trainingServicePlatform')
    if platform == 'local':
        print_normal('it doesn’t need to clean local platform.')
        exit(0)
    if platform not in ('remote', 'pai'):
        print_normal('platform {0} not supported.'.format(platform))
        exit(0)
    update_experiment()
    folders = get_platform_dir(config_content)
    if not folders:
        print_normal('No folder of NNI caches is found.')
        exit(1)

    # Keep asking until the user gives a recognizable yes/no answer.
    while True:
        print_normal('This command will remove below folders of NNI caches. If other users are using experiments' \
                     ' on below hosts, it will be broken.')
        for folder in folders:
            print('       ' + folder)
        answer = input('INFO: do you want to continue?[y/N]:').lower()
        if answer in ('y', 'yes'):
            break
        if answer in ('', 'n', 'no'):
            # An empty answer takes the default, which is "No".
            exit(0)
        print_warning('please input Y or N.')

    if platform == 'remote':
        remote_clean(config_content.get('machineList'), None)
    elif platform == 'pai':
        pai_section = config_content.get('paiConfig')
        hdfs_clean(pai_section.get('host'),
                   pai_section.get('userName'),
                   config_content.get('trial').get('outputDir'),
                   None)
    print_normal('Done.')
SparkSnail's avatar
SparkSnail committed
589

590
591
def experiment_list(args):
    '''get the information of all experiments

    Prints a table of experiments. With ``args.all`` every known experiment is
    listed; otherwise only experiments whose status is not 'STOPPED'.
    Exits with status 1 when no experiment is recorded at all.
    '''
    update_experiment()
    experiments = Experiments().get_all_experiments()
    if not experiments:
        print_normal('Cannot find experiments.')
        exit(1)

    if args.all:
        selected_ids = list(experiments.keys())
    else:
        selected_ids = [exp_id for exp_id, info in experiments.items() if info['status'] != 'STOPPED']
        if not selected_ids:
            print_warning('There is no experiment running...\nYou can use \'nnictl experiment list --all\' to list all experiments.')

    rows = ''.join(
        EXPERIMENT_DETAIL_FORMAT % (exp_id,
                                    experiments[exp_id].get('experimentName', 'N/A'),
                                    experiments[exp_id]['status'],
                                    experiments[exp_id]['port'],
                                    experiments[exp_id].get('platform'),
                                    experiments[exp_id]['startTime'],
                                    experiments[exp_id]['endTime'])
        for exp_id in selected_ids
    )
    print(EXPERIMENT_INFORMATION_FORMAT % rows)
SparkSnail's avatar
SparkSnail committed
618

SparkSnail's avatar
SparkSnail committed
619
620
621
622
def get_time_interval(time1, time2):
    '''get the interval of two times

    Args:
        time1: start time string in '%Y/%m/%d %H:%M:%S' format.
        time2: end time string in the same format.

    Returns:
        The elapsed time formatted as '<d>d <h>h <m>m <s>s', or 'N/A' when
        either value is missing or not in the expected format (e.g. the
        end time of a still-running experiment).
    '''
    time_format = '%Y/%m/%d %H:%M:%S'
    try:
        start = datetime.strptime(time1, time_format)
        end = datetime.strptime(time2, time_format)
    except (TypeError, ValueError):
        return 'N/A'
    # total_seconds() includes whole days; timedelta.seconds alone would drop
    # them. A negative span (end before start) is shown as zero.
    seconds = max(0, int((end - start).total_seconds()))
    days, seconds = divmod(seconds, 86400)
    hours, seconds = divmod(seconds, 3600)
    minutes, seconds = divmod(seconds, 60)
    return '%dd %dh %dm %ds' % (days, hours, minutes, seconds)

def show_experiment_info():
    '''show experiment information in monitor

    For each experiment whose status is not 'STOPPED', prints its summary line
    followed by a table of its trial jobs (fetched from the experiment's REST
    server when that server responds). Exits with status 1 when no experiment
    is recorded at all.
    '''
    update_experiment()
    experiments = Experiments().get_all_experiments()
    if not experiments:
        print('There is no experiment running...')
        exit(1)
    active_ids = [exp_id for exp_id, info in experiments.items() if info['status'] != 'STOPPED']
    if not active_ids:
        print_warning('There is no experiment running...')
        return

    for exp_id in active_ids:
        info = experiments[exp_id]
        summary = (exp_id, info['status'], info['port'], info.get('platform'), info['startTime'],
                   get_time_interval(info['startTime'], info['endTime']))
        print(EXPERIMENT_MONITOR_INFO % summary)
        print(TRIAL_MONITOR_HEAD)
        server_alive, _ = check_rest_server_quick(info['port'])
        if server_alive:
            trials_response = rest_get(trial_jobs_url(info['port']), REST_TIME_OUT)
            if trials_response and check_response(trials_response):
                for raw_trial in json.loads(trials_response.text):
                    trial = convert_time_stamp_to_date(raw_trial)
                    print(TRIAL_MONITOR_CONTENT % (trial.get('id'), trial.get('startTime'),
                                                   trial.get('endTime'), trial.get('status')))
        print(TRIAL_MONITOR_TAIL)

668
669
def set_monitor(auto_exit, time_interval, port=None, pid=None):
    '''set the experiment monitor engine

    Repeatedly clears the terminal and redraws the experiment monitor every
    ``time_interval`` seconds.

    Args:
        auto_exit: when true, poll the status of the experiment on ``port``
            and kill process ``pid`` once it reaches DONE/ERROR/STOPPED or the
            user presses Ctrl-C.
        time_interval: seconds to sleep between refreshes.
        port: REST port of the monitored experiment (used with auto_exit).
        pid: process id to kill (used with auto_exit).
    '''
    def stop_experiment():
        # Shared by the "experiment finished" and Ctrl-C paths.
        print_normal('Stopping experiment...')
        kill_command(pid)
        print_normal('Stop experiment success.')

    while True:
        try:
            os.system('cls' if sys.platform == 'win32' else 'clear')
            update_experiment()
            show_experiment_info()
            if auto_exit:
                status = get_experiment_status(port)
                if status in ['DONE', 'ERROR', 'STOPPED']:
                    print_normal('Experiment status is {0}.'.format(status))
                    stop_experiment()
                    exit(0)
            time.sleep(time_interval)
        except KeyboardInterrupt:
            if not auto_exit:
                print_normal('Exiting...')
            else:
                stop_experiment()
            exit(0)
        except Exception as exception:
            # SystemExit raised by exit() is not an Exception, so it passes through.
            print_error(exception)
            exit(1)
Yan Ni's avatar
Yan Ni committed
698

699
700
701
702
703
704
705
def monitor_experiment(args):
    '''monitor the experiment

    Starts the non-auto-exiting monitor loop, refreshing every ``args.time``
    seconds; rejects non-positive intervals with exit status 1.
    '''
    refresh_seconds = args.time
    if refresh_seconds <= 0:
        print_error('please input a positive integer as time interval, the unit is second.')
        exit(1)
    set_monitor(auto_exit=False, time_interval=refresh_seconds)

Yan Ni's avatar
Yan Ni committed
706
def export_trials_data(args):
    '''export experiment metadata and intermediate results to json or csv

    Args:
        args: parsed CLI arguments. Reads ``args.type`` ('json' or 'csv'),
            ``args.path`` (output file), ``args.intermediate`` (whether to
            attach intermediate results), and whatever ``get_config_filename``
            uses to locate the experiment.
    '''
    def groupby_trial_id(intermediate_results):
        '''Group decoded intermediate metrics by trial job id, in timestamp order.'''
        # sorted() returns a new list; the result must be used for the
        # per-trial lists to come out chronologically.
        ordered = sorted(intermediate_results, key=lambda x: x['timestamp'])
        groupby = dict()
        for content in ordered:
            groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
        return groupby

    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    rest_pid = nni_config.get_config('restServerPid')

    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return
    running, _ = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running')
        return
    response = rest_get(export_data_url(rest_port), 20)
    if response is None or not check_response(response):
        print_error('Export failed...')
        return

    content = json.loads(response.text)
    if args.intermediate:
        intermediate_results_response = rest_get(metric_data_url(rest_port), REST_TIME_OUT)
        if not intermediate_results_response or not check_response(intermediate_results_response):
            print_error('Error getting intermediate results.')
            return
        intermediate_results = groupby_trial_id(json.loads(intermediate_results_response.text))
        for record in content:
            # Trials that never reported an intermediate metric have no entry.
            record['intermediate'] = intermediate_results.get(record['id'], [])

    if args.type == 'json':
        with open(args.path, 'w') as file:
            file.write(json.dumps(content))
    elif args.type == 'csv':
        trial_records = []
        for record in content:
            formated_record = dict()
            if args.intermediate:
                # str() so numeric metrics can be joined as well as string ones.
                formated_record['intermediate'] = '[' + ','.join(map(str, record['intermediate'])) + ']'
            record_value = json.loads(record['value'])
            if not isinstance(record_value, (float, int)):
                formated_record.update({**record['parameter'], **record_value, **{'id': record['id']}})
            else:
                formated_record.update({**record['parameter'], **{'reward': record_value, 'id': record['id']}})
            trial_records.append(formated_record)
        if not trial_records:
            print_error('No trial results collected! Please check your trial log...')
            exit(0)
        # Union of all keys in first-seen order, so column order is deterministic.
        fieldnames = list(dict.fromkeys(key for record in trial_records for key in record))
        with open(args.path, 'w', newline='') as file:
            writer = csv.DictWriter(file, fieldnames)
            writer.writeheader()
            writer.writerows(trial_records)
    else:
        print_error('Unknown type: %s' % args.type)
        return
765
766
767
768
769
770

def search_space_auto_gen(args):
    '''dry run trial code to generate search space file

    Runs ``args.trial_command`` inside ``args.trial_dir`` with the environment
    variable NNI_GEN_SEARCH_SPACE set to the target file path, then reports
    whether that file was produced.

    Args:
        args: parsed CLI arguments providing ``trial_dir``, ``file`` (output
            path, resolved against the current directory when relative) and
            ``trial_command``.
    '''
    trial_dir = os.path.expanduser(args.trial_dir)
    file_path = os.path.expanduser(args.file)
    if not os.path.isabs(file_path):
        file_path = os.path.join(os.getcwd(), file_path)
    # Explicit check rather than ``assert``, which is stripped under ``python -O``;
    # Popen's cwd must also be a directory, not just an existing path.
    if not os.path.isdir(trial_dir):
        print_error('Trial directory \'{}\' does not exist or is not a directory.'.format(trial_dir))
        exit(1)
    if os.path.exists(file_path):
        print_warning('%s already exists, will be overwritten.' % file_path)
    print_normal('Dry run to generate search space...')
    # shell=True is intentional: trial_command is the user's own shell command line.
    Popen(args.trial_command, cwd=trial_dir, env=dict(os.environ, NNI_GEN_SEARCH_SPACE=file_path), shell=True).wait()
    if not os.path.exists(file_path):
        print_warning('Expected search space file \'{}\' generated, but not found.'.format(file_path))
    else:
        print_normal('Generate search space done: \'{}\'.'.format(file_path))
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942

def save_experiment(args):
    '''save experiment data to a zip file

    Archives a stopped experiment as ``nni_experiment_<id>.zip`` (under
    ``args.path`` when given). The archive contains:
      - experiment/ : the experiment's log directory,
      - nnictl/     : nnictl metadata for the experiment,
      - code/       : the trial code directory, only when ``args.saveCodeDir``.

    Exits with status 1 on a missing/unknown id, a running experiment, a
    missing log directory, or a metadata write failure. The temporary staging
    directory is always removed.
    '''
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    if args.id is None:
        print_error('Please set experiment id.')
        exit(1)
    if args.id not in experiment_dict:
        print_error('Cannot find experiment {0}.'.format(args.id))
        exit(1)
    if experiment_dict[args.id].get('status') != 'STOPPED':
        print_error('Can only save stopped experiment!')
        exit(1)
    print_normal('Saving...')
    nni_config = Config(experiment_dict[args.id]['fileName'])
    experiment_content = nni_config.get_config('experimentConfig') or {}
    # load_experiment stores logDir under experimentConfig; still honour a
    # top-level logDir first.
    configured_log_dir = nni_config.get_config('logDir') or experiment_content.get('logDir')
    logDir = os.path.join(NNI_HOME_DIR, args.id)
    if configured_log_dir:
        logDir = os.path.join(configured_log_dir, args.id)
    # Validate before creating the temp dir so a failure leaves nothing behind.
    if not os.path.exists(logDir):
        print_error('logDir: %s does not exist!' % logDir)
        exit(1)

    temp_root_dir = generate_temp_dir()
    try:
        # Step1. Copy logDir to temp folder
        temp_experiment_dir = os.path.join(temp_root_dir, 'experiment')
        shutil.copytree(logDir, temp_experiment_dir)

        # Step2. Copy nnictl metadata to temp folder
        temp_nnictl_dir = os.path.join(temp_root_dir, 'nnictl')
        os.makedirs(temp_nnictl_dir, exist_ok=True)
        metadata_path = os.path.join(temp_nnictl_dir, '.experiment')
        try:
            with open(metadata_path, 'w') as file:
                experiment_dict[args.id]['id'] = args.id
                json.dump(experiment_dict[args.id], file)
        except IOError:
            print_error('Write file to %s failed!' % metadata_path)
            exit(1)
        nnictl_config_dir = os.path.join(NNICTL_HOME_DIR, experiment_dict[args.id]['fileName'])
        shutil.copytree(nnictl_config_dir, os.path.join(temp_nnictl_dir, experiment_dict[args.id]['fileName']))

        # Step3. Copy code dir
        if args.saveCodeDir:
            temp_code_dir = os.path.join(temp_root_dir, 'code')
            shutil.copytree(experiment_content['trial']['codeDir'], temp_code_dir)

        # Step4. Archive folder
        zip_package_name = 'nni_experiment_%s' % args.id
        if args.path:
            os.makedirs(args.path, exist_ok=True)
            zip_package_name = os.path.join(args.path, zip_package_name)
        shutil.make_archive(zip_package_name, 'zip', temp_root_dir)
        print_normal('Save to %s.zip success!' % zip_package_name)
    finally:
        # Step5. Cleanup temp data (also on error / exit()).
        shutil.rmtree(temp_root_dir, ignore_errors=True)

def load_experiment(args):
    '''load experiment data

    Restores an archive produced by ``save_experiment`` and registers the
    experiment with nnictl:
      1. validates arguments and archive contents,
      2. restores nnictl metadata under NNICTL_HOME_DIR,
      3. copies experiment data to ``<logDir>/<experiment id>``,
      4. copies archived code (if any) into ``args.codeDir`` without
         overwriting existing files,
      5. records the experiment in nnictl's experiment list.

    Args:
        args: parsed CLI arguments providing ``path`` (archive file),
            ``codeDir`` (required, must exist) and optional ``logDir``.

    Exits with status 1 on invalid arguments or archive contents. The
    temporary unpack directory is always removed.
    '''
    package_path = os.path.expanduser(args.path)
    if not os.path.exists(package_path):
        print_error('file path %s does not exist!' % args.path)
        exit(1)
    # Step1a. Validate user-supplied paths before unpacking anything.
    code_dir = os.path.expanduser(args.codeDir)
    if not os.path.exists(code_dir):
        print_error('Invalid: codeDir path does not exist!')
        exit(1)
    log_dir_arg = os.path.expanduser(args.logDir) if args.logDir else None
    if log_dir_arg and not os.path.exists(log_dir_arg):
        print_error('Invalid: logDir path does not exist!')
        exit(1)

    temp_root_dir = generate_temp_dir()
    try:
        shutil.unpack_archive(package_path, temp_root_dir)
        print_normal('Loading...')
        # Step1b. Validate archive contents
        experiment_temp_dir = os.path.join(temp_root_dir, 'experiment')
        if not os.path.exists(os.path.join(experiment_temp_dir, 'db')):
            print_error('Invalid archive file: db file does not exist!')
            exit(1)
        nnictl_temp_dir = os.path.join(temp_root_dir, 'nnictl')
        metadata_path = os.path.join(nnictl_temp_dir, '.experiment')
        if not os.path.exists(metadata_path):
            print_error('Invalid archive file: nnictl metadata file does not exist!')
            exit(1)
        try:
            with open(metadata_path, 'r') as file:
                experiment_metadata = json.load(file)
        except ValueError as err:
            print_error('Invalid nnictl metadata file: %s' % err)
            exit(1)
        experiment_config = Experiments()
        experiment_dict = experiment_config.get_all_experiments()
        experiment_id = experiment_metadata.get('id')
        if not experiment_id:
            print_error('Invalid nnictl metadata file: experiment id is missing!')
            exit(1)
        if experiment_id in experiment_dict:
            print_error('Invalid: experiment id already exist!')
            exit(1)
        config_file_name = experiment_metadata.get('fileName')
        if not config_file_name or not os.path.exists(os.path.join(nnictl_temp_dir, config_file_name)):
            print_error('Invalid: experiment metadata does not exist!')
            exit(1)

        # Step2. Copy nnictl metadata
        src_path = os.path.join(nnictl_temp_dir, config_file_name)
        dest_path = os.path.join(NNICTL_HOME_DIR, config_file_name)
        if os.path.exists(dest_path):
            shutil.rmtree(dest_path)
        shutil.copytree(src_path, dest_path)

        # Step3. Copy experiment data to <logDir>/<experiment id>
        nni_config = Config(config_file_name)
        nnictl_exp_config = nni_config.get_config('experimentConfig')
        if log_dir_arg:
            log_dir = log_dir_arg
            nnictl_exp_config['logDir'] = log_dir
        else:
            log_dir = nnictl_exp_config.get('logDir') or NNI_HOME_DIR
        dest_path = os.path.join(log_dir, experiment_id)
        if os.path.exists(dest_path):
            shutil.rmtree(dest_path)
        shutil.copytree(experiment_temp_dir, dest_path)

        # Step4. Copy code dir (never overwrite existing files)
        if not os.path.isabs(code_dir):
            code_dir = os.path.join(os.getcwd(), code_dir)
            print_normal('Expand codeDir to %s' % code_dir)
        nnictl_exp_config['trial']['codeDir'] = code_dir
        archive_code_dir = os.path.join(temp_root_dir, 'code')
        if os.path.exists(archive_code_dir):
            for entry in os.listdir(archive_code_dir):
                src_path = os.path.join(archive_code_dir, entry)
                target_path = os.path.join(code_dir, entry)
                if os.path.exists(target_path):
                    print_error('Copy %s failed, %s exist!' % (entry, target_path))
                    continue
                if os.path.isdir(src_path):
                    shutil.copytree(src_path, target_path)
                else:
                    shutil.copy(src_path, target_path)

        # Step5. Create experiment metadata
        nni_config.set_config('experimentConfig', nnictl_exp_config)
        experiment_config.add_experiment(experiment_id,
                                         experiment_metadata.get('port'),
                                         experiment_metadata.get('startTime'),
                                         config_file_name,
                                         experiment_metadata.get('platform'),
                                         experiment_metadata.get('experimentName'),
                                         experiment_metadata.get('endTime'),
                                         experiment_metadata.get('status'))
        print_normal('Load experiment %s success!' % experiment_id)
    finally:
        # Step6. Cleanup temp data (also on error / exit()).
        shutil.rmtree(temp_root_dir, ignore_errors=True)
943