updater.py 5.49 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
Deshui Yu's avatar
Deshui Yu committed
3
4
5

import json
import os
6
7
from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url, import_data_url
Deshui Yu's avatar
Deshui Yu committed
8
from .config_utils import Config
9
from .common_utils import get_json_content, print_normal, print_error, print_warning
chicm-ms's avatar
chicm-ms committed
10
from .nnictl_utils import get_experiment_port, get_config_filename
11
from .launcher_utils import parse_time
12
from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA
Deshui Yu's avatar
Deshui Yu committed
13
14
15
16
17
18
19
20
21
22
23

def validate_digit(value, start, end):
    '''Validate that value is an unsigned decimal integer within [start, end].

    value may be an int or a string; it is checked through its str() form, so
    signs, whitespace, decimal points and empty strings are all rejected.

    Raises:
        ValueError: if value is not made only of decimal digits, or if the
            number lies outside the inclusive range start..end.
    '''
    text = str(value)
    # isdecimal() only accepts characters that int() can parse; isdigit() also
    # lets through e.g. superscript digits, on which int() fails with an
    # unrelated "invalid literal" message.
    if not text.isdecimal() or not start <= int(text) <= end:
        raise ValueError('%s must be a digit from %s to %s' % (value, start, end))

def validate_file(path):
    '''Validate that path points to an existing regular file.

    Raises:
        FileNotFoundError: if path does not exist or is not a file
            (for example, it is a directory).
    '''
    # isfile() rather than exists(): a directory exists, but it is not a file
    # whose content could be loaded afterwards.
    if not os.path.isfile(path):
        raise FileNotFoundError('%s is not a valid file path' % path)

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def validate_dispatcher(args):
    '''Validate that the dispatcher of the experiment supports importing data.

    Reads the 'experimentConfig' entry of the experiment's config file and
    picks the builtin tuner name or, failing that, the builtin advisor name.
    When neither is set the dispatcher is treated as a customized one and no
    check is performed.

    Exits the process with status 0 (after a warning) when the dispatcher is
    listed in TUNERS_NO_NEED_TO_IMPORT_DATA, and with status 1 (after an error)
    when it is in neither that list nor TUNERS_SUPPORTING_IMPORT_DATA.
    '''
    # Local import: sys.exit() is the programmatic way to terminate; the bare
    # exit() helper is injected by the site module for interactive use.
    import sys

    nni_config = Config(get_config_filename(args)).get_config('experimentConfig')
    if nni_config.get('tuner') and nni_config['tuner'].get('builtinTunerName'):
        dispatcher_name = nni_config['tuner']['builtinTunerName']
    elif nni_config.get('advisor') and nni_config['advisor'].get('builtinAdvisorName'):
        dispatcher_name = nni_config['advisor']['builtinAdvisorName']
    else:
        # otherwise it should be a customized one
        return

    if dispatcher_name in TUNERS_SUPPORTING_IMPORT_DATA:
        return
    if dispatcher_name in TUNERS_NO_NEED_TO_IMPORT_DATA:
        print_warning("There is no need to import data for %s" % dispatcher_name)
        sys.exit(0)
    print_error("%s does not support importing addtional data" % dispatcher_name)
    sys.exit(1)

Deshui Yu's avatar
Deshui Yu committed
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def load_search_space(path):
    '''Load the JSON file at path and return its content as a JSON string.

    Raises:
        ValueError: if the parsed content is empty (None, {}, [], ...).
    '''
    parsed = get_json_content(path)
    # Test the parsed object, not the serialized text: json.dumps() never
    # returns an empty string (None becomes 'null', {} becomes '{}'), so a
    # check on its output can never fire.
    if not parsed:
        raise ValueError('searchSpace file should not be empty')
    return json.dumps(parsed)

def get_query_type(key):
    '''Return the update query string for an experiment profile key.

    key is one of 'trialConcurrency', 'maxExecDuration', 'searchSpace' or
    'maxTrialNum'. Any other key yields None.
    '''
    query_types = {
        'trialConcurrency': '?update_type=TRIAL_CONCURRENCY',
        'maxExecDuration': '?update_type=MAX_EXEC_DURATION',
        'searchSpace': '?update_type=SEARCH_SPACE',
        'maxTrialNum': '?update_type=MAX_TRIAL_NUM',
    }
    # Explicit None for unknown keys (previously an implicit fall-through).
    return query_types.get(key)
Deshui Yu's avatar
Deshui Yu committed
58

goooxu's avatar
goooxu committed
59
def update_experiment_profile(args, key, value):
    '''Call the restful server to update one entry of the experiment profile.

    Fetches the current profile from the rest server, sets
    profile['params'][key] to value and sends the whole profile back with the
    update-type query matching key.

    Returns the response of the PUT request when it passes check_response,
    otherwise None (unsupported key, server not running, or a failed request).
    '''
    query = get_query_type(key)
    if query is None:
        # Without this guard the URL concatenation below would fail with an
        # opaque TypeError (str + None) after a round trip to the server.
        print_error('%s is not a supported update key' % key)
        return None

    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    running, _ = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running...')
        return None

    response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
    if not (response and check_response(response)):
        return None

    experiment_profile = json.loads(response.text)
    experiment_profile['params'][key] = value
    response = rest_put(experiment_url(rest_port) + query, json.dumps(experiment_profile), REST_TIME_OUT)
    if response and check_response(response):
        return response
    return None

def update_searchspace(args):
    '''Update the search space of the experiment from the file args.filename.

    The file is validated and loaded first; the experiment port is then stored
    in args.port. Nothing more happens when no port is found; otherwise the
    profile is updated and the outcome is printed.
    '''
    validate_file(args.filename)
    content = load_search_space(args.filename)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    if update_experiment_profile(args, 'searchSpace', content):
        print_normal('Update %s success!' % 'searchSpace')
    else:
        print_error('Update %s failed!' % 'searchSpace')

Deshui Yu's avatar
Deshui Yu committed
86
87
88

def update_concurrency(args):
    '''Update the trial concurrency of the experiment to args.value.

    args.value must be a digit from 1 to 1000 (validate_digit raises
    ValueError otherwise). The experiment port is stored in args.port; nothing
    more happens when no port is found, otherwise the profile is updated and
    the outcome is printed.
    '''
    validate_digit(args.value, 1, 1000)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    if update_experiment_profile(args, 'trialConcurrency', int(args.value)):
        print_normal('Update %s success!' % 'concurrency')
    else:
        print_error('Update %s failed!' % 'concurrency')
Deshui Yu's avatar
Deshui Yu committed
95
96

def update_duration(args):
    '''Update the maximum execution duration of the experiment.

    args.value is overwritten with the result of parse_time. The experiment
    port is stored in args.port; nothing more happens when no port is found,
    otherwise the profile is updated and the outcome is printed.
    '''
    # parse time, change time unit to seconds
    args.value = parse_time(args.value)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    if update_experiment_profile(args, 'maxExecDuration', int(args.value)):
        print_normal('Update %s success!' % 'duration')
    else:
        print_error('Update %s failed!' % 'duration')
Deshui Yu's avatar
Deshui Yu committed
105

QuanluZhang's avatar
QuanluZhang committed
106
107
def update_trialnum(args):
    '''Update the maximum trial number of the experiment to args.value.

    args.value must be a digit from 1 to 999999999 (validate_digit raises
    ValueError otherwise). The outcome of the profile update is printed.
    '''
    validate_digit(args.value, 1, 999999999)
    # NOTE(review): unlike the other update_* functions, this one does not
    # resolve args.port via get_experiment_port first - confirm whether that
    # difference is intentional.
    if update_experiment_profile(args, 'maxTrialNum', int(args.value)):
        print_normal('Update %s success!' % 'trialnum')
    else:
        print_error('Update %s failed!' % 'trialnum')

def import_data(args):
    '''Import additional data to the experiment from the file args.filename.

    The file and the experiment's dispatcher are validated first
    (validate_dispatcher may terminate the process). The experiment port is
    then stored in args.port; nothing more happens when no port is found,
    otherwise the data is posted and an error is printed on failure.
    '''
    validate_file(args.filename)
    validate_dispatcher(args)
    # load_search_space is reused here to load the data file as a JSON string.
    content = load_search_space(args.filename)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    if not import_data_to_restful_server(args, content):
        print_error('Import data failed!')

def import_data_to_restful_server(args, content):
    '''Post content to the import-data endpoint of the experiment's rest server.

    Returns the response when it passes check_response; returns None when the
    server is not running (an error is printed) or when the request fails.
    '''
    port = Config(get_config_filename(args)).get_config('restServerPort')
    is_running, _ = check_rest_server_quick(port)
    if not is_running:
        print_error('Restful server is not running...')
        return None

    reply = rest_post(import_data_url(port), content, REST_TIME_OUT)
    if reply and check_response(reply):
        return reply
    return None