integration.py 4.91 KB
Newer Older
1
import logging
2
import os
3
from typing import Any, Callable
4
5

from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
6
from nni.runtime.protocol import CommandType, send
7
8
9
from nni.utils import MetricType

from .graph import MetricData
10
11
12
13
from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
from .execution.api import set_execution_engine
from .integration_api import register_advisor
14
from .serializer import json_dumps, json_loads
15

QuanluZhang's avatar
QuanluZhang committed
16
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


class RetiariiAdvisor(MsgDispatcherBase):
    """
    The class is to connect Retiarii components to NNI backend.

    It will function as the main thread when running a Retiarii experiment through NNI.
    Strategy will be launched as its thread, who will call APIs in execution engine. Execution
    engine will then find the advisor singleton and send payloads to advisor.

    When metrics are sent back, advisor will first receive the payloads, who will call the callback
    function (that is a member function in graph listener).

    The conversion the advisor provides is minimal. It is only a send/receive module, and execution
    engine needs to handle all the rest.

    All callbacks start out as ``None``. A handler whose callback has not been set
    logs a warning and drops the event instead of raising ``TypeError``.

    FIXME
        How does advisor exit when strategy exits?

    Attributes
    ----------
    send_trial_callback
        Called with the trial parameters after a ``NewTrialJob`` command has been sent.
    request_trial_jobs_callback
        Called with the number of trials requested.
    trial_end_callback
        Called with ``(parameter_id, succeeded)``, where ``succeeded`` is ``True``
        when the reported event equals ``'SUCCEEDED'``.
    intermediate_metric_callback
        Called with ``(parameter_id, value)`` for every periodical metric.
    final_metric_callback
        Called with ``(parameter_id, value)`` for every final metric.
    """

    def __init__(self):
        super().__init__()
        register_advisor(self)  # register the current advisor as the "global only" advisor
        self.search_space = None

        # Callbacks are unset until a listener assigns them; every handler
        # below checks for ``None`` before invoking one.
        self.send_trial_callback: Callable[[dict], None] = None
        self.request_trial_jobs_callback: Callable[[int], None] = None
        self.trial_end_callback: Callable[[int, bool], None] = None
        self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
        self.final_metric_callback: Callable[[int, MetricData], None] = None

        # Running counter; its value is used as the ID of the latest trial sent.
        self.parameters_count = 0

        engine = self._create_execution_engine()
        set_execution_engine(engine)

    def _create_execution_engine(self):
        """Create the execution engine selected by the ``CGO`` environment variable.

        Returns
        -------
        CGOExecutionEngine or BaseExecutionEngine
            ``CGOExecutionEngine`` when ``CGO`` is exactly ``'true'``,
            otherwise ``BaseExecutionEngine``.
        """
        if os.environ.get('CGO') == 'true':
            return CGOExecutionEngine()
        return BaseExecutionEngine()

    def handle_initialize(self, data):
        """callback for initializing the advisor

        Stores the search space, then sends an ``Initialized`` command.

        Parameters
        ----------
        data: dict
            search space
        """
        self.handle_update_search_space(data)
        send(CommandType.Initialized, '')

    def send_trial(self, parameters):
        """
        Send parameters to NNI.

        Parameters
        ----------
        parameters : Any
            Any payload.

        Returns
        -------
        int
            Parameter ID that is assigned to this parameter,
            which will be used for identification in future.
        """
        self.parameters_count += 1
        new_trial = {
            'parameter_id': self.parameters_count,
            'parameters': parameters,
            'parameter_source': 'algorithm'
        }
        _logger.info('New trial sent: %s', new_trial)
        send(CommandType.NewTrialJob, json_dumps(new_trial))
        if self.send_trial_callback is not None:
            self.send_trial_callback(parameters)  # pylint: disable=not-callable
        return self.parameters_count

    def handle_request_trial_jobs(self, num_trials):
        """Forward a request for ``num_trials`` new trials to the registered callback, if any."""
        _logger.info('Request trial jobs: %s', num_trials)
        if self.request_trial_jobs_callback is not None:
            self.request_trial_jobs_callback(num_trials)  # pylint: disable=not-callable

    def handle_update_search_space(self, data):
        """Remember the received search space in ``self.search_space``."""
        _logger.info('Received search space: %s', data)
        self.search_space = data

    def handle_trial_end(self, data):
        """Notify the registered callback that a trial has ended.

        Parameters
        ----------
        data : dict
            Must contain ``'hyper_params'`` (a serialized payload holding
            ``'parameter_id'``) and ``'event'``.
        """
        _logger.info('Trial end: %s', data)
        if self.trial_end_callback is None:
            # Same guard as ``send_trial``: an unset callback must not raise
            # ``TypeError`` from inside the dispatcher.
            _logger.warning('Trial end callback is not set, event dropped: %s', data)
            return
        parameter_id = json_loads(data['hyper_params'])['parameter_id']
        self.trial_end_callback(parameter_id, data['event'] == 'SUCCEEDED')  # pylint: disable=not-callable

    def handle_report_metric_data(self, data):
        """Dispatch a reported metric to the intermediate or final metric callback.

        Parameters
        ----------
        data : dict
            Must contain ``'type'``; periodical and final metrics additionally
            need ``'parameter_id'`` and ``'value'``.

        Raises
        ------
        ValueError
            If the metric type is ``MetricType.REQUEST_PARAMETER``.
        """
        _logger.info('Metric reported: %s', data)
        metric_type = data['type']
        if metric_type == MetricType.REQUEST_PARAMETER:
            raise ValueError('Request parameter not supported')
        if metric_type == MetricType.PERIODICAL:
            callback = self.intermediate_metric_callback
        elif metric_type == MetricType.FINAL:
            callback = self.final_metric_callback
        else:
            # Any other metric type is ignored.
            return
        if callback is None:
            _logger.warning('No callback is set for metric type %s, metric dropped: %s', metric_type, data)
            return
        callback(data['parameter_id'], self._process_value(data['value']))  # pylint: disable=not-callable

    @staticmethod
    def _process_value(value) -> Any:  # hopefully a float
        """Deserialize a metric value; for a dict with a ``'default'`` key, return that entry."""
        value = json_loads(value)
        if isinstance(value, dict) and 'default' in value:
            return value['default']
        return value