utils.py 6.35 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3

Zejun Lin's avatar
Zejun Lin committed
4
import contextlib
5
import collections
Zejun Lin's avatar
Zejun Lin committed
6
import os
7
import socket
8
import sys
Zejun Lin's avatar
Zejun Lin committed
9
10
import subprocess
import requests
11
import time
12
import yaml
13
import shlex
14
import warnings
Zejun Lin's avatar
Zejun Lin committed
15

chicm-ms's avatar
chicm-ms committed
16
# Text searched for in nnimanager.log to decide whether an experiment is done.
EXPERIMENT_DONE_SIGNAL = "Experiment done"
Zejun Lin's avatar
Zejun Lin committed
17

18
19
20
21
# ANSI terminal escape sequences for coloured console output (ESC is \x1b).
GREEN = '\x1b[32m'
RED = '\x1b[31m'
CLEAR = '\x1b[0m'

chicm-ms's avatar
chicm-ms committed
22
23
24
25
26
27
# Base address of the local NNI REST server and the API routes built on it.
REST_ENDPOINT = 'http://localhost:8080'
API_ROOT_URL = '{}/api/v1/nni'.format(REST_ENDPOINT)
EXPERIMENT_URL = '{}/experiment'.format(API_ROOT_URL)
STATUS_URL = '{}/check-status'.format(API_ROOT_URL)
TRIAL_JOBS_URL = '{}/trial-jobs'.format(API_ROOT_URL)
METRICS_URL = '{}/metric-data'.format(API_ROOT_URL)
GET_IMPORTED_DATA_URL = '{}/imported-data'.format(EXPERIMENT_URL)
29

Zejun Lin's avatar
Zejun Lin committed
30
def read_last_line(file_name):
    '''Read the last line of a file.

    Returns the last line with surrounding whitespace stripped, or None if the
    file does not exist, is empty, or cannot be decoded as text.
    '''
    try:
        # ``with`` guarantees the handle is closed; deque(maxlen=1) keeps only
        # the final line, so memory use does not grow with the file size.
        with open(file_name) as file:
            tail = collections.deque(file, maxlen=1)
    except (FileNotFoundError, ValueError):
        # ValueError also covers UnicodeDecodeError (a subclass), matching the
        # previous behaviour of returning None for undecodable files.
        return None
    if not tail:
        return None
    return tail[0].strip()

def remove_files(file_list):
    '''Delete every path in ``file_list``; paths that do not exist are ignored.'''
    for path in file_list:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

def get_yml_content(file_path):
    '''Return the content of the YAML file at ``file_path``, parsed with ``yaml.safe_load``.'''
    with open(file_path, 'r') as yml_file:
        parsed = yaml.safe_load(yml_file)
    return parsed
Zejun Lin's avatar
Zejun Lin committed
48
49
50
51

def dump_yml_content(file_path, content):
    '''Write ``content`` to ``file_path`` as block-style YAML via ``yaml.safe_dump``.'''
    with open(file_path, 'w') as yml_file:
        serialized = yaml.safe_dump(content, default_flow_style=False)
        yml_file.write(serialized)
Zejun Lin's avatar
Zejun Lin committed
53

54
55
def setup_experiment(installed=True):
    '''Set up the environment for running experiments when nni is not installed.

    When ``installed`` is False:
    - the current working directory is appended to ``PATH``;
    - the absolute paths of ``../src/sdk/pynni`` and ``../tools`` are appended
      to ``PYTHONPATH``.
    When ``installed`` is True the environment is left unchanged.
    '''
    if installed:
        return
    # os.pathsep is ':' on POSIX and ';' on Windows. A hard-coded ':' breaks
    # both variables on Windows, where paths themselves contain ':' (e.g. "C:").
    current_path = os.environ.get('PATH')
    if current_path:
        os.environ['PATH'] = os.pathsep.join([current_path, os.getcwd()])
    else:
        os.environ['PATH'] = os.getcwd()

    sdk_path = os.path.abspath('../src/sdk/pynni')
    cmd_path = os.path.abspath('../tools')
    pypath = os.environ.get('PYTHONPATH')
    entries = [pypath, sdk_path, cmd_path] if pypath else [sdk_path, cmd_path]
    os.environ['PYTHONPATH'] = os.pathsep.join(entries)

67
68
69
70
def get_experiment_id(experiment_url):
    '''GET ``experiment_url`` and return the ``id`` field of its JSON reply.'''
    response = requests.get(experiment_url)
    return response.json()['id']

71
def get_experiment_dir(experiment_url=None, experiment_id=None):
    '''Return the experiment root directory ``~/nni-experiments/<experiment_id>``.

    At least one of ``experiment_url`` / ``experiment_id`` must be truthy
    (AssertionError otherwise). When ``experiment_id`` is None it is fetched
    from ``experiment_url`` via ``get_experiment_id``.
    '''
    assert experiment_url or experiment_id
    if experiment_id is None:
        experiment_id = get_experiment_id(experiment_url)
    home_dir = os.path.expanduser('~')
    return os.path.join(home_dir, 'nni-experiments', experiment_id)
Zejun Lin's avatar
Zejun Lin committed
77

78
def get_nni_log_dir(experiment_url=None, experiment_id=None):
    '''Return the ``log`` sub-directory of the experiment root directory.'''
    experiment_root = get_experiment_dir(experiment_url, experiment_id)
    return os.path.join(experiment_root, 'log')
81
82
83
84

def get_nni_log_path(experiment_url):
    '''Return the path of ``nnimanager.log`` for the experiment at ``experiment_url``.'''
    log_dir = get_nni_log_dir(experiment_url)
    return os.path.join(log_dir, 'nnimanager.log')
Zejun Lin's avatar
Zejun Lin committed
85

86
def is_experiment_done(nnimanager_log_path):
    '''Return True if the nnimanager log contains ``EXPERIMENT_DONE_SIGNAL``.

    Raises AssertionError('Experiment starts failed') when the log file does
    not exist.
    '''
    assert os.path.exists(nnimanager_log_path), 'Experiment starts failed'
    with open(nnimanager_log_path, 'r') as log_file:
        found = EXPERIMENT_DONE_SIGNAL in log_file.read()
    return found
94
95
96
97
98

def get_experiment_status(status_url):
    '''GET ``status_url`` and return the ``status`` field of its JSON reply.'''
    return requests.get(status_url).json()['status']

chicm-ms's avatar
chicm-ms committed
99
def get_trial_stats(trial_jobs_url):
    '''Count the trial jobs reported at ``trial_jobs_url`` per ``status``.

    Returns a ``collections.defaultdict(int)`` mapping each status value to the
    number of jobs that have it.
    '''
    counts = collections.defaultdict(int)
    for status in (job['status'] for job in requests.get(trial_jobs_url).json()):
        counts[status] += 1
    return counts
105

chicm-ms's avatar
chicm-ms committed
106
def get_trial_jobs(trial_jobs_url, status=None):
    '''Return trial jobs from the NNI REST API, optionally filtered by status.

    Args:
        trial_jobs_url: URL of the ``trial-jobs`` endpoint.
        status: when given, only jobs whose ``status`` field equals this value
            are returned; when None, every job is returned.

    Returns:
        A list of the trial job records decoded from the JSON response.
    '''
    trial_jobs = requests.get(trial_jobs_url).json()
    return [job for job in trial_jobs if status is None or job['status'] == status]

def get_failed_trial_jobs(trial_jobs_url):
    '''Return only the trial jobs whose status is ``'FAILED'``.'''
    return get_trial_jobs(trial_jobs_url, status='FAILED')

def print_file_content(filepath):
    '''Print ``filepath`` and then the file's full text, flushing after each print.'''
    with open(filepath, 'r') as source_file:
        text = source_file.read()
    print(filepath, flush=True)
    print(text, flush=True)
124

chicm-ms's avatar
chicm-ms committed
125
def print_trial_job_log(training_service, trial_jobs_url):
    '''Print the log files of every entry under the experiment's ``trials`` directory.

    For the ``'local'`` training service, ``stderr`` and ``trial.log`` are
    printed; for any other service, ``stdout_log_collection.log`` is printed.
    Log files that do not exist are skipped. ``trial_jobs_url`` is not used by
    this function.
    '''
    trial_log_root = os.path.join(get_experiment_dir(EXPERIMENT_URL), 'trials')
    if not os.path.exists(trial_log_root):
        print('trial log folder does not exist: {}'.format(trial_log_root), flush=True)
        return
    if training_service == 'local':
        log_names = ['stderr', 'trial.log']
    else:
        log_names = ['stdout_log_collection.log']
    for trial_name in os.listdir(trial_log_root):
        for log_name in log_names:
            candidate = os.path.join(trial_log_root, trial_name, log_name)
            if os.path.exists(candidate):
                print_file_content(candidate)
chicm-ms's avatar
chicm-ms committed
138

139
140
def print_experiment_log(experiment_id):
    '''Print the logs of experiment ``experiment_id`` to stdout.

    Prints ``dispatcher.log`` and ``nnimanager.log`` from the experiment's log
    directory; any that do not exist are reported and skipped. It then prints
    the output of ``nnictl log stderr`` and ``nnictl log stdout``.
    '''
    log_dir = get_nni_log_dir(experiment_id=experiment_id)
    for log_file in ['dispatcher.log', 'nnimanager.log']:
        filepath = os.path.join(log_dir, log_file)
        # A missing log must not abort the dump; otherwise the remaining log
        # and the nnictl output below would never be printed.
        if os.path.exists(filepath):
            print_file_content(filepath)
        else:
            print('log file does not exist: {}'.format(filepath), flush=True)

    for stream in ['stderr', 'stdout']:
        # Flush before the child process writes to the same stdout, otherwise a
        # buffered header can show up after nnictl's own output.
        print('nnictl log {}:'.format(stream), flush=True)
        # Pass an argument list so the id is never re-split on whitespace.
        subprocess.run(['nnictl', 'log', stream, str(experiment_id)])

150
151
152
153
154
def parse_max_duration_time(max_exec_duration):
    '''Convert a duration string such as ``'30m'`` into a number of seconds.

    The last character is the unit (``s``, ``m``, ``h`` or ``d``). The rest
    must parse as an integer. Surrounding whitespace is ignored.

    Raises:
        KeyError: the unit is not one of ``s``/``m``/``h``/``d``.
        ValueError: the amount is not an integer.
    '''
    units_dict = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
    duration = max_exec_duration.strip()
    # Named ``amount`` rather than ``time`` so the module-level ``time``
    # import is not shadowed inside this function.
    unit, amount = duration[-1], duration[:-1]
    return int(amount) * units_dict[unit]
155
156
157
158
159
160
161
162
163
164
165
166
167

def deep_update(source, overrides):
    """Update a nested dictionary or similar mapping.

    Modify ``source`` in place: for each key in ``overrides``, a non-empty
    mapping value is merged recursively into ``source[key]``; any other value
    (including an empty mapping) replaces ``source[key]`` outright.

    Returns ``source``.
    """
    # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``.
    import collections.abc
    for key, value in overrides.items():
        if isinstance(value, collections.abc.Mapping) and value:
            existing = source.get(key)
            # Merge into a fresh dict when the key is absent or currently holds
            # a non-mapping (e.g. ``None`` from an empty YAML section).
            if not isinstance(existing, collections.abc.Mapping):
                existing = {}
            source[key] = deep_update(existing, value)
        else:
            source[key] = value
    return source
168
169
170
171
172
173
174
175
176
177
178

def detect_port(port):
    '''Detect if the port is used.

    Returns True when a TCP connection to ``127.0.0.1:port`` succeeds, and
    False otherwise, including when ``port`` is not a valid port number.
    '''
    # ``with`` closes the socket on every path, including a failed connect().
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as socket_test:
        try:
            socket_test.connect(('127.0.0.1', int(port)))
        except (OSError, OverflowError, TypeError, ValueError):
            # OSError covers refused or failed connections.
            # OverflowError/TypeError/ValueError come from an invalid ``port``.
            # Unlike a bare ``except:``, KeyboardInterrupt still propagates.
            return False
    return True

179
180

def wait_for_port_available(port, timeout):
    '''Wait up to ``timeout`` seconds for ``port`` to stop being in use.

    Checks ``detect_port`` once per second and returns as soon as the port is
    free. Each check that finds the port still in use emits a RuntimeWarning.
    Raises RuntimeError if the port is still in use after ``timeout`` checks.
    '''
    for elapsed in range(timeout):
        port_busy = detect_port(port)
        if not port_busy:
            break
        warnings.warn("Port isn't available in {} seconds (patience: {})".format(elapsed, timeout), RuntimeWarning)
        time.sleep(1)
    else:
        raise RuntimeError('Port {} is not available in {} seconds. Maybe the previous experiment fails to stop?'.format(port, timeout))