utils.py 5.21 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3

Zejun Lin's avatar
Zejun Lin committed
4
import contextlib
5
import collections
Zejun Lin's avatar
Zejun Lin committed
6
import os
7
import socket
8
import sys
Zejun Lin's avatar
Zejun Lin committed
9
10
import subprocess
import requests
11
import time
12
import ruamel.yaml as yaml
Zejun Lin's avatar
Zejun Lin committed
13

chicm-ms's avatar
chicm-ms committed
14
# Marker text searched for in the nnimanager log by is_experiment_done().
EXPERIMENT_DONE_SIGNAL = 'Experiment done'
Zejun Lin's avatar
Zejun Lin committed
15

16
17
18
19
20
21
22
23
24
25
# ANSI escape sequences for coloured terminal output ('\33' is the ESC character).
GREEN = '\33[32m'
RED = '\33[31m'
CLEAR = '\33[0m'

# Base URL of the NNI REST API on localhost port 8080, and the resource URLs built on it.
REST_ENDPOINT = 'http://localhost:8080/api/v1/nni'
EXPERIMENT_URL = REST_ENDPOINT + '/experiment'
STATUS_URL = REST_ENDPOINT + '/check-status'
TRIAL_JOBS_URL = REST_ENDPOINT + '/trial-jobs'
METRICS_URL = REST_ENDPOINT + '/metric-data'

Zejun Lin's avatar
Zejun Lin committed
26
def read_last_line(file_name):
    '''Return the last line of a file with surrounding whitespace stripped.

    Returns None when the file does not exist, contains no lines, or
    reading it raises a ValueError (which includes text decoding errors).
    '''
    try:
        # The context manager closes the handle; unpacking a bare
        # ``open(file_name)`` left it open until garbage collection.
        with open(file_name) as file:
            # A one-slot deque keeps only the most recent line instead of
            # holding every line of the file in memory.
            tail = collections.deque(file, maxlen=1)
    except (FileNotFoundError, ValueError):
        return None
    if not tail:
        return None
    return tail[0].strip()

def remove_files(file_list):
    '''Remove every path in ``file_list``, silently skipping missing files.'''
    for file_path in file_list:
        # Best effort: a file that is already gone is not an error.
        with contextlib.suppress(FileNotFoundError):
            os.remove(file_path)

def get_yml_content(file_path):
    '''Load a yaml file and return its parsed content.'''
    with open(file_path, 'r') as file:
        # NOTE(review): the full ``yaml.Loader`` can construct arbitrary
        # Python objects, so this must only be used on trusted files.
        return yaml.load(file, Loader=yaml.Loader)
Zejun Lin's avatar
Zejun Lin committed
44
45
46
47
48
49

def dump_yml_content(file_path, content):
    '''Serialise ``content`` as block-style yaml and write it to ``file_path``.'''
    with open(file_path, 'w') as yml_file:
        serialized = yaml.dump(content, default_flow_style=False)
        yml_file.write(serialized)

50
51
def setup_experiment(installed=True):
    '''Extend PATH and PYTHONPATH when nni is not installed.

    When ``installed`` is false, the current working directory is appended to
    ``PATH`` and the directories ``../src/sdk/pynni`` and ``../tools``
    (resolved against the working directory) are appended to ``PYTHONPATH``.
    Nothing is changed when ``installed`` is true.
    '''
    if installed:
        return

    # os.pathsep rather than a literal ':' so the entries are also separated
    # correctly on Windows, where the separator is ';'.
    current_path = os.environ.get('PATH')
    path_entries = [current_path] if current_path else []
    path_entries.append(os.getcwd())
    os.environ['PATH'] = os.pathsep.join(path_entries)

    sdk_path = os.path.abspath('../src/sdk/pynni')
    cmd_path = os.path.abspath('../tools')
    current_pypath = os.environ.get('PYTHONPATH')
    pypath_entries = [current_pypath] if current_pypath else []
    pypath_entries.extend([sdk_path, cmd_path])
    os.environ['PYTHONPATH'] = os.pathsep.join(pypath_entries)

63
64
65
66
def get_experiment_id(experiment_url):
    '''GET ``experiment_url`` and return the ``id`` field of its JSON body.'''
    response = requests.get(experiment_url)
    experiment_info = response.json()
    return experiment_info['id']

67
68
def get_experiment_dir(experiment_url):
    '''Return the experiment root directory, ``~/nni/experiments/<experiment id>``.

    The experiment id is fetched from ``experiment_url``.
    '''
    experiment_id = get_experiment_id(experiment_url)
    return os.path.join(os.path.expanduser('~'), 'nni', 'experiments', experiment_id)
Zejun Lin's avatar
Zejun Lin committed
71

72
73
74
75
76
77
78
def get_nni_log_dir(experiment_url):
    '''Return the ``log`` directory inside the experiment directory for ``experiment_url``.'''
    experiment_dir = get_experiment_dir(experiment_url)
    return os.path.join(experiment_dir, 'log')

def get_nni_log_path(experiment_url):
    '''Return the path of ``nnimanager.log`` in the log directory for ``experiment_url``.'''
    log_dir = get_nni_log_dir(experiment_url)
    return os.path.join(log_dir, 'nnimanager.log')
Zejun Lin's avatar
Zejun Lin committed
79

80
def is_experiment_done(nnimanager_log_path):
    '''Return True when the log file contains EXPERIMENT_DONE_SIGNAL.

    Raises AssertionError when ``nnimanager_log_path`` does not exist.
    '''
    # Raised explicitly rather than through ``assert`` so the check is not
    # stripped under ``python -O``; exception type and message are unchanged.
    if not os.path.exists(nnimanager_log_path):
        raise AssertionError('Experiment starts failed')

    with open(nnimanager_log_path, 'r') as log_file:
        # The signal contains no newline, so a line-by-line scan matches
        # exactly what a whole-file search would, without loading the
        # entire log into memory.
        return any(EXPERIMENT_DONE_SIGNAL in line for line in log_file)
88
89
90
91
92
93
94
95
96
97
98
99
100
101

def get_experiment_status(status_url):
    '''GET ``status_url`` and return the ``status`` field of its JSON body.'''
    return requests.get(status_url).json()['status']

def get_succeeded_trial_num(trial_jobs_url):
    '''Count the trial jobs whose status is SUCCEEDED or EARLY_STOPPED.

    The job list is fetched from ``trial_jobs_url``; the count is also printed.
    '''
    done_states = ('SUCCEEDED', 'EARLY_STOPPED')
    response = requests.get(trial_jobs_url)
    num_succeed = sum(1 for job in response.json() if job['status'] in done_states)
    print('num_succeed:', num_succeed)
    return num_succeed

102
103
def get_failed_trial_jobs(trial_jobs_url):
    '''Return the trial jobs fetched from ``trial_jobs_url`` whose status is FAILED.'''
    trial_jobs = requests.get(trial_jobs_url).json()
    return [trial_job for trial_job in trial_jobs if trial_job['status'] == 'FAILED']

def print_failed_job_log(training_service, trial_jobs_url):
    '''Print the log file name and content of every FAILED trial job.

    For the ``local`` training service the log path is taken from the job's
    ``stderrPath``; otherwise ``stdout_log_collection.log`` under the job's
    trial directory in the experiment directory is used.
    '''
    # Looked up at most once: it costs an HTTP request and does not vary
    # between jobs.
    experiment_dir = None
    for trial_job in get_failed_trial_jobs(trial_jobs_url):
        if training_service == 'local':
            path_parts = trial_job['stderrPath'].split(':')
            if sys.platform == "win32":
                # Keep the last two parts so the drive letter stays attached
                # to the path (e.g. ``C:\...``).
                log_filename = ':'.join(path_parts[-2:])
            else:
                log_filename = path_parts[-1]
        else:
            if experiment_dir is None:
                experiment_dir = get_experiment_dir(EXPERIMENT_URL)
            log_filename = os.path.join(experiment_dir, 'trials', trial_job['id'], 'stdout_log_collection.log')
        with open(log_filename, 'r') as f:
            log_content = f.read()
        print(log_filename, flush=True)
        print(log_content, flush=True)
127
128
129
130
131
132

def parse_max_duration_time(max_exec_duration):
    '''Convert a duration such as ``'10m'`` into a number of seconds.

    The last character is the unit (``s``, ``m``, ``h`` or ``d``) and
    everything before it is the integer amount.
    '''
    seconds_per_unit = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
    amount, unit = max_exec_duration[:-1], max_exec_duration[-1]
    return int(amount) * seconds_per_unit[unit]
133
134
135
136
137
138
139
140
141
142
143
144
145

def deep_update(source, overrides):
    """Update a nested dictionary or similar mapping.

    Modify ``source`` in place and return it. A non-empty mapping in
    ``overrides`` is merged recursively into the value stored under the same
    key; any other value (including an empty mapping) replaces it outright.
    """
    # Imported locally: the ``collections.Mapping`` alias was removed in
    # Python 3.10, and ``import collections`` alone does not guarantee that
    # the ``collections.abc`` submodule is loaded.
    from collections.abc import Mapping

    for key, value in overrides.items():
        if isinstance(value, Mapping) and value:
            existing = source.get(key)
            # A missing, None or scalar value cannot be merged into, so the
            # merge starts from an empty dict instead of failing.
            if not isinstance(existing, Mapping):
                existing = {}
            source[key] = deep_update(existing, value)
        else:
            source[key] = value
    return source
146
147
148
149
150
151
152
153
154
155
156
157
158

def detect_port(port):
    '''Return True when a TCP connection to 127.0.0.1 on ``port`` succeeds.

    Returns False when the connection fails or ``port`` is not a valid
    port number.
    '''
    try:
        port_number = int(port)
    except (TypeError, ValueError):
        return False
    # The context manager closes the socket on the failure path as well,
    # where it used to leak.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.connect(('127.0.0.1', port_number))
        except (OSError, OverflowError):
            # OSError: nothing is listening; OverflowError: port out of range.
            return False
    return True

def snooze(seconds=6):
    '''Sleep so a previously stopped experiment has enough time to exit.

    ``seconds`` is the delay; it defaults to the 6 seconds that used to be
    hard-coded.
    '''
    time.sleep(seconds)