# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================

import os
import logging
from collections import defaultdict
import json_tricks

from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from .common import multi_thread_enabled

_logger = logging.getLogger(__name__)

# Assessor global variables
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''

_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''

def _sort_history(history):
    ret = [ ]
    for i, _ in enumerate(history):
        if i in history:
            ret.append(history[i])
        else:
            break
    return ret

# Tuner global variables
_next_parameter_id = 0
_trial_params = {}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids = set()

def _create_parameter_id():
    global _next_parameter_id  # pylint: disable=global-statement
    _next_parameter_id += 1
    return _next_parameter_id - 1

def _pack_parameter(parameter_id, params, customized=False):
    """Record *params* under *parameter_id* and serialize the trial-job payload.

    The parameters are remembered in ``_trial_params`` so they can be passed
    back to the tuner together with the trial's final result.
    """
    _trial_params[parameter_id] = params
    source = 'customized' if customized else 'algorithm'
    payload = {
        'parameter_id': parameter_id,
        'parameter_source': source,
        'parameters': params,
    }
    return json_tricks.dumps(payload)

class MsgDispatcher(MsgDispatcherBase):
    """Default dispatcher: routes commands from NNI manager to a tuner and an
    optional assessor, and sends their decisions (new trials, kill requests)
    back to NNI manager.
    """

    def __init__(self, tuner, assessor=None):
        """
        :param tuner: tuner that generates hyper-parameters for trials
        :param assessor: optional assessor that early-stops unpromising trials
        """
        super(MsgDispatcher, self).__init__()
        self.tuner = tuner
        self.assessor = assessor
        if assessor is None:
            _logger.debug('Assessor is not configured')

    def load_checkpoint(self):
        """Restore tuner and (if configured) assessor state from checkpoint."""
        self.tuner.load_checkpoint()
        if self.assessor is not None:
            self.assessor.load_checkpoint()

    def save_checkpoint(self):
        """Persist tuner and (if configured) assessor state to checkpoint."""
        self.tuner.save_checkpoint()
        if self.assessor is not None:
            self.assessor.save_checkpoint()

    def handle_initialize(self, data):
        """Initialize the tuner with the search space and acknowledge NNI manager.

        :param data: the search space
        """
        self.tuner.update_search_space(data)
        send(CommandType.Initialized, '')

    def handle_request_trial_jobs(self, data):
        """Generate hyper-parameters for new trials and send them to NNI manager.

        :param data: number of trial jobs requested
        """
        ids = [_create_parameter_id() for _ in range(data)]
        # Lazy %-style args instead of eager str.format: the message is only
        # rendered when debug logging is enabled.
        _logger.debug('requesting for generating params of %s', ids)
        params_list = self.tuner.generate_multiple_parameters(ids)

        for i, _ in enumerate(params_list):
            send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
        # The tuner may return fewer parameter sets than requested (e.g. the
        # search space is exhausted); tell NNI manager not to expect more.
        if len(params_list) < len(ids):
            send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], ''))

    def handle_update_search_space(self, data):
        """Forward an updated search space to the tuner.

        :param data: the new search space
        """
        self.tuner.update_search_space(data)

    def handle_add_customized_trial(self, data):
        """Create a trial from user-provided parameters, bypassing the tuner.

        :param data: hyper-parameters supplied by the user
        """
        id_ = _create_parameter_id()
        # Remember the id so the final result is later routed to
        # receive_customized_trial_result instead of receive_trial_result.
        _customized_parameter_ids.add(id_)
        send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))

    def handle_report_metric_data(self, data):
        """Dispatch a metric reported by a trial.

        :param data: a dict received from nni_manager, which contains:
                    - 'parameter_id': id of the trial
                    - 'value': metric value reported by nni.report_final_result()
                    - 'type': report type, support {'FINAL', 'PERIODICAL'}
        :raises ValueError: if 'type' is neither 'FINAL' nor 'PERIODICAL'
        """
        if data['type'] == 'FINAL':
            self._handle_final_metric_data(data)
        elif data['type'] == 'PERIODICAL':
            # Intermediate results are only meaningful to an assessor;
            # without one they are silently dropped.
            if self.assessor is not None:
                self._handle_intermediate_metric_data(data)
        else:
            raise ValueError('Data type not supported: {}'.format(data['type']))

    def handle_trial_end(self, data):
        """Record the end of a trial and notify tuner and assessor.

        :param data: it has three keys: trial_job_id, event, hyper_params
            trial_job_id: the id generated by training service
            event: the job's state
            hyper_params: the hyperparameters generated and returned by tuner
        """
        trial_job_id = data['trial_job_id']
        _ended_trials.add(trial_job_id)
        if trial_job_id in _trial_history:
            _trial_history.pop(trial_job_id)
            if self.assessor is not None:
                self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
        if self.tuner is not None:
            self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')

    def _handle_final_metric_data(self, data):
        """Send a trial's final result to the tuner.

        :param data: dict with 'parameter_id' and 'value'
        """
        id_ = data['parameter_id']
        value = data['value']
        if id_ in _customized_parameter_ids:
            self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
        else:
            self.tuner.receive_trial_result(id_, _trial_params[id_], value)

    def _handle_intermediate_metric_data(self, data):
        """Feed an intermediate result to the assessor and kill the trial if
        the assessor judges it Bad.

        :param data: dict with 'trial_job_id', 'sequence', 'value', 'type'
        :raises RuntimeError: if the assessor returns an unsupported result type
        """
        if data['type'] != 'PERIODICAL':
            return
        if self.assessor is None:
            return

        trial_job_id = data['trial_job_id']
        if trial_job_id in _ended_trials:
            # NNI manager may deliver metrics after reporting the trial ended;
            # such late metrics are ignored.
            return

        history = _trial_history[trial_job_id]
        history[data['sequence']] = data['value']
        ordered_history = _sort_history(history)
        if len(ordered_history) < data['sequence']:  # no user-visible update since last time
            return

        try:
            result = self.assessor.assess_trial(trial_job_id, ordered_history)
        except Exception:
            # BUG FIX: previously the exception was only logged and execution
            # fell through to use an unbound `result`, raising a misleading
            # NameError; log with traceback and propagate the real error.
            _logger.exception('Assessor error')
            raise

        if isinstance(result, bool):
            result = AssessResult.Good if result else AssessResult.Bad
        elif not isinstance(result, AssessResult):
            msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
            raise RuntimeError(msg % type(result))

        if result is AssessResult.Bad:
            _logger.debug('BAD, kill %s', trial_job_id)
            send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
            # notify tuner
            _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS'))
            if os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS') == 'true':
                self._earlystop_notify_tuner(data)
        else:
            _logger.debug('GOOD')

    def _earlystop_notify_tuner(self, data):
        """Send last intermediate result as final result to tuner in case the
        trial is early stopped.
        """
        _logger.debug('Early stop notify tuner data: [%s]', data)
        data['type'] = 'FINAL'
        if multi_thread_enabled():
            self._handle_final_metric_data(data)
        else:
            self.enqueue_command(CommandType.ReportMetricData, data)