# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import warnings
from typing import Any, Callable

import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType

from .graph import MetricData
from .integration_api import register_advisor

_logger = logging.getLogger(__name__)


class RetiariiAdvisor(MsgDispatcherBase):
    """
    The class is to connect Retiarii components to NNI backend.

    It will function as the main thread when running a Retiarii experiment through NNI.
    Strategy will be launched as its thread, who will call APIs in execution engine. Execution
    engine will then find the advisor singleton and send payloads to advisor.

    When metrics are sent back, advisor will first receive the payloads, who will call the callback
    function (that is a member function in graph listener).

    The conversion advisor provides are minimum. It is only a send/receive module, and execution engine
    needs to handle all the rest.

    FIXME
        How does advisor exit when strategy exists?

    Attributes
    ----------
    send_trial_callback

    request_trial_jobs_callback

    trial_end_callback

    intermediate_metric_callback

    final_metric_callback
    """
    def __init__(self):
        super(RetiariiAdvisor, self).__init__()
        register_advisor(self)  # register the current advisor as the "global only" advisor
        self.search_space = None

        self.send_trial_callback: Callable[[dict], None] = None
        self.request_trial_jobs_callback: Callable[[int], None] = None
        self.trial_end_callback: Callable[[int, bool], None] = None
        self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
        self.final_metric_callback: Callable[[int, MetricData], None] = None

        self.parameters_count = 0

    def handle_initialize(self, data):
        """callback for initializing the advisor
        Parameters
        ----------
        data: dict
            search space
        """
        self.handle_update_search_space(data)
        send(CommandType.Initialized, '')

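    # Shapes accepted by the validation below (illustrative values only; host names are made up):
    #     {'type': 'None', 'gpus': []}                                  # no constraint
    #     {'type': 'GPUNumber', 'gpus': [2]}                            # GPU count on a single host
    #     {'type': 'Device', 'gpus': [('server0', 0), ('server0', 1)]}  # explicit (host, GPU index) pairs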
    def _validate_placement_constraint(self, placement_constraint):
        if placement_constraint is None:
            raise ValueError('placement_constraint is None')
        if 'type' not in placement_constraint:
            raise ValueError('placement_constraint must have `type`')
        if 'gpus' not in placement_constraint:
            raise ValueError('placement_constraint must have `gpus`')
        if placement_constraint['type'] not in ['None', 'GPUNumber', 'Device']:
            raise ValueError('placement_constraint.type must be either `None`, `GPUNumber` or `Device`')
        if placement_constraint['type'] == 'None' and len(placement_constraint['gpus']) > 0:
            raise ValueError('placement_constraint.gpus must be an empty list when type == None')
        if placement_constraint['type'] == 'GPUNumber':
            if len(placement_constraint['gpus']) != 1:
                raise ValueError('placement_constraint.gpus currently only supports one host when type == GPUNumber')
            for e in placement_constraint['gpus']:
                if not isinstance(e, int):
                    raise ValueError('placement_constraint.gpus must be a list of numbers when type == GPUNumber')
        if placement_constraint['type'] == 'Device':
            for e in placement_constraint['gpus']:
                if not isinstance(e, tuple):
                    raise ValueError('placement_constraint.gpus must be a list of tuple when type == Device')
                if not (len(e) == 2 and isinstance(e[0], str) and isinstance(e[1], int)):
                    raise ValueError("placement_constraint.gpus's tuples must be (str, int) when type == Device")

    def send_trial(self, parameters, placement_constraint=None):
        """
        Send parameters to NNI.

        Parameters
        ----------
        parameters : Any
            Any payload.

        Returns
        -------
        int
            Parameter ID that is assigned to this parameter,
            which will be used for identification in future.
        """
        self.parameters_count += 1
        if placement_constraint is None:
            placement_constraint = {
                'type': 'None',
                'gpus': []
            }
        self._validate_placement_constraint(placement_constraint)
        new_trial = {
            'parameter_id': self.parameters_count,
            'parameters': parameters,
            'parameter_source': 'algorithm',
            'placement_constraint': placement_constraint
        }
        _logger.debug('New trial sent: %s', new_trial)

        # Trial parameters can be super large; disable the pickle size limit here.
        # Nevertheless, the payload could still be blocked by the pipe / nni-manager.
        send_payload = nni.dump(new_trial, pickle_size_limit=-1)
        if len(send_payload) > 256 * 1024:
            warnings.warn(
                'The total payload of the trial is larger than 256 KB. '
                'This can cause performance issues and may even crash the NNI experiment. '
                'This is usually caused by pickling large objects (like datasets) by mistake. '
                'See https://nni.readthedocs.io/en/stable/NAS/Serialization.html for details.'
            )
        send(CommandType.NewTrialJob, send_payload)

        if self.send_trial_callback is not None:
            self.send_trial_callback(parameters)  # pylint: disable=not-callable
        return self.parameters_count
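    # Illustrative caller-side sketch (not part of this module): the execution engine is
    # expected to submit trials through the advisor singleton, e.g. (assuming the `get_advisor`
    # accessor from `.integration_api`):
    #     param_id = get_advisor().send_trial(graph_payload)   # `graph_payload` is engine-defined
    # The returned id is the same `parameter_id` later passed to `trial_end_callback` and to the
    # metric callbacks, so results can be correlated with the submitted trial.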

    def mark_experiment_as_ending(self):
        send(CommandType.NoMoreTrialJobs, '')

    def handle_request_trial_jobs(self, num_trials):
        _logger.debug('Request trial jobs: %s', num_trials)
        if self.request_trial_jobs_callback is not None:
            self.request_trial_jobs_callback(num_trials)  # pylint: disable=not-callable

    def handle_update_search_space(self, data):
        _logger.debug('Received search space: %s', data)
        self.search_space = data

    def handle_trial_end(self, data):
        _logger.debug('Trial end: %s', data)
        self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
                                data['event'] == 'SUCCEEDED')

    def handle_report_metric_data(self, data):
        _logger.debug('Metric reported: %s', data)
        if data['type'] == MetricType.REQUEST_PARAMETER:
            raise ValueError('Request parameter not supported')
        elif data['type'] == MetricType.PERIODICAL:
            self.intermediate_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
                                              self._process_value(data['value']))
        elif data['type'] == MetricType.FINAL:
            self.final_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
                                       self._process_value(data['value']))

    @staticmethod
    def _process_value(value) -> Any:  # hopefully a float
        value = nni.load(value)
        if isinstance(value, dict):
            if 'default' in value:
                return value['default']
            else:
                return value
        return value
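

# Note on `_process_value`: trials typically report metrics via `nni.report_intermediate_result`
# or `nni.report_final_result`. When a dict is reported, NNI's convention is that it carries a
# 'default' key, e.g. {'default': 0.93, 'accuracy': 0.93} is reduced to 0.93 here; a plain float
# is returned unchanged.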