# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================

import logging
from collections import defaultdict
import json_tricks

from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from .common import multi_thread_enabled, multi_phase_enabled
from .env_vars import dispatcher_env_vars
from .utils import MetricType
# Module-level logger for the dispatcher.
_logger = logging.getLogger(__name__)

# Assessor global variables
# Per-trial intermediate results: outer key is the trial job ID, inner dict
# maps a metric's sequence number to the reported value.
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''

# Trials that have already been reported as ended; used to drop late metrics.
_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''

def _sort_history(history):
    ret = [ ]
    for i, _ in enumerate(history):
        if i in history:
            ret.append(history[i])
        else:
            break
    return ret

# Tuner global variables
_next_parameter_id = 0
_trial_params = {}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids = set()

def _create_parameter_id():
    global _next_parameter_id  # pylint: disable=global-statement
    _next_parameter_id += 1
    return _next_parameter_id - 1

def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
    """Serialize one set of hyperparameters into the JSON payload sent to NNI manager.

    parameter_id: id allocated by ``_create_parameter_id()``.
    params: hyperparameters produced by the tuner (or supplied by the user).
    customized: True when the parameters came from the user rather than the
        tuner's algorithm.
    trial_job_id / parameter_index: included only when given (multi-phase).

    Side effect: records ``params`` in the module-level ``_trial_params``
    table so final results can later be matched back to their parameters.
    """
    _trial_params[parameter_id] = params
    payload = {
        'parameter_id': parameter_id,
        'parameter_source': 'customized' if customized else 'algorithm',
        'parameters': params
    }
    if trial_job_id is not None:
        payload['trial_job_id'] = trial_job_id
    # A missing index defaults to 0 (the first parameter set of a trial).
    payload['parameter_index'] = parameter_index if parameter_index is not None else 0
    return json_tricks.dumps(payload)

class MsgDispatcher(MsgDispatcherBase):
    """Default dispatcher: routes NNI manager commands to a tuner and an optional assessor."""

    def __init__(self, tuner, assessor=None):
        """
        tuner: Tuner instance that generates hyperparameter configurations.
        assessor: optional Assessor used for early stopping; when None,
            intermediate (PERIODICAL) metrics are ignored.
        """
        super(MsgDispatcher, self).__init__()
        self.tuner = tuner
        self.assessor = assessor
        if assessor is None:
            _logger.debug('Assessor is not configured')

    def load_checkpoint(self):
        """Restore tuner (and assessor, if configured) state from checkpoint."""
        self.tuner.load_checkpoint()
        if self.assessor is not None:
            self.assessor.load_checkpoint()

    def save_checkpoint(self):
        """Persist tuner (and assessor, if configured) state to checkpoint."""
        self.tuner.save_checkpoint()
        if self.assessor is not None:
            self.assessor.save_checkpoint()

    def handle_initialize(self, data):
        """Initialize the tuner with the search space and acknowledge the manager.

        data: the search space.
        """
        self.tuner.update_search_space(data)
        send(CommandType.Initialized, '')

    def handle_request_trial_jobs(self, data):
        """Generate and send new trial parameters.

        data: number of trial jobs requested by NNI manager.
        """
        ids = [_create_parameter_id() for _ in range(data)]
        # Lazy %-style args avoid formatting when debug logging is disabled.
        _logger.debug("requesting for generating params of %s", ids)
        params_list = self.tuner.generate_multiple_parameters(ids)

        for i, _ in enumerate(params_list):
            send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
        # The tuner may return fewer parameter sets than requested (e.g. the
        # search space is exhausted); tell the manager no more jobs can start.
        if len(params_list) < len(ids):
            send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], ''))

    def handle_update_search_space(self, data):
        """data: the updated search space."""
        self.tuner.update_search_space(data)

    def handle_import_data(self, data):
        """Import additional data for tuning
        data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
        """
        self.tuner.import_data(data)

    def handle_add_customized_trial(self, data):
        """Register user-provided parameters (data) as a new trial job."""
        id_ = _create_parameter_id()
        # Remember the id so the final result is routed to
        # receive_customized_trial_result instead of receive_trial_result.
        _customized_parameter_ids.add(id_)
        send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))

    def handle_report_metric_data(self, data):
        """Dispatch one metric report from NNI manager according to its type.

        data: a dict received from nni_manager, which contains:
              - 'parameter_id': id of the trial
              - 'value': metric value reported by nni.report_final_result()
              - 'type': report type, support {'FINAL', 'PERIODICAL', 'REQUEST_PARAMETER'}

        Raises ValueError when 'type' is none of the supported kinds.
        """
        if data['type'] == MetricType.FINAL:
            self._handle_final_metric_data(data)
        elif data['type'] == MetricType.PERIODICAL:
            if self.assessor is not None:
                self._handle_intermediate_metric_data(data)
        elif data['type'] == MetricType.REQUEST_PARAMETER:
            # Multi-phase only: a running trial asks for its next parameter set.
            assert multi_phase_enabled()
            assert data['trial_job_id'] is not None
            assert data['parameter_index'] is not None
            param_id = _create_parameter_id()
            param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
            send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
        else:
            raise ValueError('Data type not supported: {}'.format(data['type']))

    def handle_trial_end(self, data):
        """
        data: it has three keys: trial_job_id, event, hyper_params
             - trial_job_id: the id generated by training service
             - event: the job's state
             - hyper_params: the hyperparameters generated and returned by tuner
        """
        trial_job_id = data['trial_job_id']
        # Record the trial so late PERIODICAL metrics for it are ignored.
        _ended_trials.add(trial_job_id)
        if trial_job_id in _trial_history:
            _trial_history.pop(trial_job_id)
            if self.assessor is not None:
                self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
        if self.tuner is not None:
            self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')

    def _handle_final_metric_data(self, data):
        """Call tuner to process final results
        """
        id_ = data['parameter_id']
        value = data['value']
        # Customized (user-supplied) trials and tuner-generated trials use
        # different tuner callbacks; multi-phase additionally passes the job id.
        if id_ in _customized_parameter_ids:
            if multi_phase_enabled():
                self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
            else:
                self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
        else:
            if multi_phase_enabled():
                self.tuner.receive_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
            else:
                self.tuner.receive_trial_result(id_, _trial_params[id_], value)

    def _handle_intermediate_metric_data(self, data):
        """Call assessor to process intermediate results
        """
        if data['type'] != MetricType.PERIODICAL:
            return
        if self.assessor is None:
            return

        trial_job_id = data['trial_job_id']
        if trial_job_id in _ended_trials:
            # NNI manager may still forward metrics after reporting a trial end.
            return

        history = _trial_history[trial_job_id]
        history[data['sequence']] = data['value']
        ordered_history = _sort_history(history)
        if len(ordered_history) < data['sequence']:  # no user-visible update since last time
            return

        try:
            result = self.assessor.assess_trial(trial_job_id, ordered_history)
        except Exception:
            # Bug fix: the original logged here and then fell through with
            # `result` unbound, turning any assessor failure into a confusing
            # NameError below. Log the real error and re-raise it instead.
            _logger.exception('Assessor error')
            raise

        # Assessors may return a plain bool; normalize it to AssessResult.
        if isinstance(result, bool):
            result = AssessResult.Good if result else AssessResult.Bad
        elif not isinstance(result, AssessResult):
            msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
            raise RuntimeError(msg % type(result))

        if result is AssessResult.Bad:
            _logger.debug('BAD, kill %s', trial_job_id)
            send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
            # notify tuner
            _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
            if dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS == 'true':
                self._earlystop_notify_tuner(data)
        else:
            _logger.debug('GOOD')

    def _earlystop_notify_tuner(self, data):
        """Send last intermediate result as final result to tuner in case the
        trial is early stopped.
        """
        _logger.debug('Early stop notify tuner data: [%s]', data)
        data['type'] = MetricType.FINAL
        if multi_thread_enabled():
            self._handle_final_metric_data(data)
        else:
            # In single-thread mode, re-enqueue so the command loop processes
            # it in order with other manager commands.
            self.enqueue_command(CommandType.ReportMetricData, data)