updater.py 6.54 KB
Newer Older
Deshui Yu's avatar
Deshui Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


import json
import os
24
25
from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url, import_data_url
Deshui Yu's avatar
Deshui Yu committed
26
from .config_utils import Config
27
from .common_utils import get_json_content, print_normal, print_error, print_warning
28
from .nnictl_utils import check_experiment_id, get_experiment_port, get_config_filename
29
from .launcher_utils import parse_time
30
from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA
Deshui Yu's avatar
Deshui Yu committed
31
32
33
34
35
36
37
38
39
40
41

def validate_digit(value, start, end):
    '''Ensure ``value`` is an all-digit integer within the inclusive range [start, end].

    Raises:
        ValueError: if ``str(value)`` is not made only of digit characters, or
            its integer value lies outside [start, end].
    '''
    looks_numeric = str(value).isdigit()
    if looks_numeric and start <= int(value) <= end:
        return
    raise ValueError('%s must be a digit from %s to %s' % (value, start, end))

def validate_file(path):
    '''Raise FileNotFoundError unless ``path`` names an existing regular file.

    Directories are rejected as well, because callers go on to hand ``path``
    to load_search_space as a file to parse.
    '''
    if not os.path.isfile(path):
        raise FileNotFoundError('%s is not a valid file path' % path)

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def validate_dispatcher(args):
    '''Exit the process unless the experiment's dispatcher can import data.

    The dispatcher name is the builtin tuner name from the experiment config,
    or else the builtin advisor name. Customized dispatchers (neither set) are
    not checked. If the dispatcher is listed in TUNERS_NO_NEED_TO_IMPORT_DATA,
    a warning is printed and the process exits with status 0. Any other
    dispatcher not in TUNERS_SUPPORTING_IMPORT_DATA prints an error and exits
    with status 1.
    '''
    # sys.exit rather than the site-provided exit(), which is meant for the
    # interactive interpreter and is absent when Python runs without `site`.
    import sys

    nni_config = Config(get_config_filename(args)).get_config('experimentConfig')
    if nni_config.get('tuner') and nni_config['tuner'].get('builtinTunerName'):
        dispatcher_name = nni_config['tuner']['builtinTunerName']
    elif nni_config.get('advisor') and nni_config['advisor'].get('builtinAdvisorName'):
        dispatcher_name = nni_config['advisor']['builtinAdvisorName']
    else:  # otherwise it should be a customized one
        return
    if dispatcher_name in TUNERS_SUPPORTING_IMPORT_DATA:
        return
    if dispatcher_name in TUNERS_NO_NEED_TO_IMPORT_DATA:
        print_warning("There is no need to import data for %s" % dispatcher_name)
        sys.exit(0)
    print_error("%s does not support importing additional data" % dispatcher_name)
    sys.exit(1)

Deshui Yu's avatar
Deshui Yu committed
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def load_search_space(path):
    '''Load the JSON file at ``path`` and return it re-serialized as a JSON string.

    Raises:
        ValueError: if the value returned by get_json_content is falsy
            (for example None, ``{}`` or ``[]``).
    '''
    data = get_json_content(path)
    # Test the parsed value, not the serialized text: json.dumps never returns
    # an empty string (None -> 'null', {} -> '{}'), so a check on its output
    # can never fire.
    if not data:
        raise ValueError('searchSpace file should not be empty')
    return json.dumps(data)

# Experiment-profile key -> REST query suffix selecting the matching update type.
_PROFILE_UPDATE_QUERIES = {
    'trialConcurrency': '?update_type=TRIAL_CONCURRENCY',
    'maxExecDuration': '?update_type=MAX_EXEC_DURATION',
    'searchSpace': '?update_type=SEARCH_SPACE',
    'maxTrialNum': '?update_type=MAX_TRIAL_NUM',
}

def get_query_type(key):
    '''Return the ``?update_type=...`` query suffix for an experiment-profile key.

    Returns None when ``key`` has no corresponding update type.
    '''
    return _PROFILE_UPDATE_QUERIES.get(key)
Deshui Yu's avatar
Deshui Yu committed
76

goooxu's avatar
goooxu committed
77
def update_experiment_profile(args, key, value):
    '''Update one parameter of the running experiment's profile via the REST server.

    Fetches the current profile, sets ``profile['params'][key] = value`` and
    PUTs the whole profile back with the update-type query for ``key``.

    Args:
        args: parsed CLI arguments, used to locate the experiment config file.
        key: profile parameter name; must be a key known to get_query_type.
        value: new value for that parameter.

    Returns:
        The PUT response when both REST calls succeed, otherwise None. An
        error is printed when the REST server is not running.

    Raises:
        ValueError: if ``key`` has no corresponding update type.
    '''
    query = get_query_type(key)
    if query is None:
        # Fail clearly here instead of with a TypeError from `url + None` below.
        raise ValueError('%s is not an updatable experiment profile key' % key)
    nni_config = Config(get_config_filename(args))
    rest_port = nni_config.get_config('restServerPort')
    running, _ = check_rest_server_quick(rest_port)
    if not running:
        print_error('Restful server is not running...')
        return None
    response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
    if not (response and check_response(response)):
        return None
    experiment_profile = json.loads(response.text)
    experiment_profile['params'][key] = value
    response = rest_put(experiment_url(rest_port) + query, json.dumps(experiment_profile), REST_TIME_OUT)
    if response and check_response(response):
        return response
    return None

def update_searchspace(args):
    '''Replace the running experiment's search space with the JSON in ``args.filename``.

    Sets ``args.port`` from get_experiment_port and does nothing further if it
    is None. Otherwise prints whether the profile update succeeded.
    '''
    validate_file(args.filename)
    content = load_search_space(args.filename)
    args.port = get_experiment_port(args)
    if args.port is not None:
        if update_experiment_profile(args, 'searchSpace', content):
            print_normal('Update %s success!' % 'searchSpace')
        else:
            print_error('Update %s failed!' % 'searchSpace')

Deshui Yu's avatar
Deshui Yu committed
104
105
106

def update_concurrency(args):
    '''Set the running experiment's trialConcurrency to ``args.value`` (1 to 1000).

    Sets ``args.port`` from get_experiment_port and does nothing further if it
    is None. Otherwise prints whether the profile update succeeded.
    '''
    validate_digit(args.value, 1, 1000)
    args.port = get_experiment_port(args)
    if args.port is not None:
        if update_experiment_profile(args, 'trialConcurrency', int(args.value)):
            print_normal('Update %s success!' % 'concurrency')
        else:
            print_error('Update %s failed!' % 'concurrency')
Deshui Yu's avatar
Deshui Yu committed
113
114

def update_duration(args):
    '''Set the running experiment's maxExecDuration from ``args.value``.

    ``args.value`` is overwritten with the result of parse_time and sent as an
    int. Sets ``args.port`` from get_experiment_port and does nothing further
    if it is None. Otherwise prints whether the profile update succeeded.
    '''
    # parse time, change time unit to seconds
    args.value = parse_time(args.value)
    args.port = get_experiment_port(args)
    if args.port is not None:
        if update_experiment_profile(args, 'maxExecDuration', int(args.value)):
            print_normal('Update %s success!' % 'duration')
        else:
            print_error('Update %s failed!' % 'duration')
Deshui Yu's avatar
Deshui Yu committed
123

QuanluZhang's avatar
QuanluZhang committed
124
125
def update_trialnum(args):
    '''Set the running experiment's maxTrialNum to ``args.value`` (1 to 999999999).

    Sets ``args.port`` from get_experiment_port and does nothing further if it
    is None. Otherwise prints whether the profile update succeeded.
    '''
    validate_digit(args.value, 1, 999999999)
    # Resolve and check the port exactly like update_concurrency, update_duration
    # and update_searchspace; this command previously skipped that step.
    args.port = get_experiment_port(args)
    if args.port is not None:
        if update_experiment_profile(args, 'maxTrialNum', int(args.value)):
            print_normal('Update %s success!' % 'trialnum')
        else:
            print_error('Update %s failed!' % 'trialnum')

def import_data(args):
    '''Import additional data from ``args.filename`` into the running experiment.

    Validates the file, then checks that the experiment's dispatcher accepts
    imported data (validate_dispatcher may exit the process). The file's JSON
    content is then POSTed. Does nothing further if get_experiment_port
    returns None, and prints an error if the import request fails.
    '''
    validate_file(args.filename)
    validate_dispatcher(args)
    content = load_search_space(args.filename)
    args.port = get_experiment_port(args)
    if args.port is not None and not import_data_to_restful_server(args, content):
        print_error('Import data failed!')

def import_data_to_restful_server(args, content):
    '''POST ``content`` to the running experiment's import-data endpoint.

    Returns:
        The response when the REST server is running and the request passes
        check_response, otherwise None. An error is printed when the server
        is not running.
    '''
    port = Config(get_config_filename(args)).get_config('restServerPort')
    server_up, _ = check_rest_server_quick(port)
    if not server_up:
        print_error('Restful server is not running...')
        return None
    reply = rest_post(import_data_url(port), content, REST_TIME_OUT)
    return reply if reply and check_response(reply) else None