test_msg_dispatcher.py 3.29 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
3
4
5
6
7

import json
from io import BytesIO
from unittest import TestCase, main

8
9
from nni.runtime import msg_dispatcher_base
from nni.runtime.msg_dispatcher import MsgDispatcher
10
from nni.runtime.tuner_command_channel.legacy import *
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from nni.tuner import Tuner
from nni.utils import extract_scalar_reward

class NaiveTuner(Tuner):
    """Tuner stub that exposes its own internal state through the parameters it generates.

    Every generated parameter set carries three fields:

    * ``param`` -- a counter that grows by 2 on each ``generate_parameters`` call
      (2, 4, 6, ...);
    * ``trial_results`` -- the ``(parameter_id, param, reward, customized)`` tuples
      collected by ``receive_trial_result`` so far;
    * ``search_space`` -- the latest value given to ``update_search_space``
      (``None`` until one arrives).

    Because the state rides along in the outgoing commands, a test can inspect it
    after the dispatcher has finished, without pausing the dispatcher's main loop.
    """

    def __init__(self):
        self.param = 0
        self.trial_results = []
        self.search_space = None
        # Base-class switch; the recorded results include a ``customized`` flag, so
        # presumably this lets customized-trial results reach receive_trial_result
        # too -- confirm against Tuner.
        self._accept_customized_trials()

    def generate_parameters(self, parameter_id, **kwargs):
        """Advance the counter by 2 and return a snapshot of the tuner's state."""
        self.param = self.param + 2
        # NOTE: trial_results is handed out by reference, not copied.
        return dict(
            param=self.param,
            trial_results=self.trial_results,
            search_space=self.search_space,
        )

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        """Append ``(parameter_id, parameters['param'], reward, customized)`` to the history."""
        reward = extract_scalar_reward(value)
        entry = (parameter_id, parameters['param'], reward, kwargs.get("customized"))
        self.trial_results.append(entry)

    def update_search_space(self, search_space):
        """Store the most recent search space as-is."""
        self.search_space = search_space


# In-memory stand-ins for the legacy command channel's two streams;
# _reverse_io() / _restore_io() swap which one is read and which one is written.
_in_buf, _out_buf = BytesIO(), BytesIO()


def _reverse_io():
    """Wire the legacy channel from the tuner's *peer* side (the test acting as manager).

    The channel's output goes to ``_in_buf`` (the stream the tuner will later read)
    and its input comes from ``_out_buf`` (the stream the tuner wrote). Both buffers
    are rewound first, so reading starts at the beginning of what was written.
    """
    _in_buf.seek(0)
    _out_buf.seek(0)
    _set_out_file(_in_buf)
    _set_in_file(_out_buf)


def _restore_io():
    """Wire the legacy channel from the tuner's own side.

    The channel's input comes from ``_in_buf`` (where the test queued commands) and
    its output goes to ``_out_buf``. Both buffers are rewound first, so the queued
    commands are consumed from the start.
    """
    _in_buf.seek(0)
    _out_buf.seek(0)
    _set_in_file(_in_buf)
    _set_out_file(_out_buf)


class MsgDispatcherTestCase(TestCase):
    """End-to-end check of MsgDispatcher driving NaiveTuner over the legacy channel."""

    def test_msg_dispatcher(self):
        """Replay a fixed command script and verify every NewTrialJob the tuner emits."""
        # Act as the manager: queue commands on the tuner's incoming stream.
        _reverse_io()
        send(CommandType.RequestTrialJobs, '2')
        # Only the FINAL metric for parameter 1 is expected to show up in the tuner's
        # trial results (see the last _assert_params call below).
        send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":"10"}')
        send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":"11"}')
        send(CommandType.UpdateSearchSpace, '{"name":"SS0"}')
        send(CommandType.RequestTrialJobs, '1')
        # Expected to be rejected by a worker (asserted after dispatcher.run()).
        send(CommandType.KillTrialJob, 'null')
        _restore_io()

        tuner = NaiveTuner()
        dispatcher = MsgDispatcher('ws://_placeholder_', tuner)
        # Replace the channel built from the placeholder URL with the legacy one,
        # which goes through the buffers wired up above.
        dispatcher._channel = LegacyCommandChannel()
        # This is a module-level switch: restore it afterwards so other tests are
        # not affected by this override.
        saved_fast_exit = msg_dispatcher_base._worker_fast_exit_on_terminate
        msg_dispatcher_base._worker_fast_exit_on_terminate = False
        self.addCleanup(
            setattr, msg_dispatcher_base, '_worker_fast_exit_on_terminate', saved_fast_exit
        )

        dispatcher.run()
        # Fail with a clear message (not an IndexError) if no worker raised.
        self.assertTrue(
            dispatcher.worker_exceptions,
            'expected the KillTrialJob command to fail in a worker',
        )
        e = dispatcher.worker_exceptions[0]
        self.assertIs(type(e), AssertionError)
        self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')

        # Act as the manager again: read what the tuner sent back.
        _reverse_io()
        self._assert_params(0, 2, [], None)
        self._assert_params(1, 4, [], None)
        self._assert_params(2, 6, [[1, 4, 11, False]], {'name': 'SS0'})

        # No more commands; comparing to b'' shows any leftover bytes on failure.
        self.assertEqual(_out_buf.read(), b'')

    def _assert_params(self, parameter_id, param, trial_results, search_space):
        """Read one command and check that it is the expected NewTrialJob.

        ``param``, ``trial_results`` and ``search_space`` are compared against the
        state snapshot that NaiveTuner embeds in each generated parameter set.
        """
        command, data = receive()
        self.assertIs(command, CommandType.NewTrialJob)
        data = json.loads(data)
        self.assertEqual(data['parameter_id'], parameter_id)
        self.assertEqual(data['parameter_source'], 'algorithm')
        self.assertEqual(data['parameters']['param'], param)
        self.assertEqual(data['parameters']['trial_results'], trial_results)
        self.assertEqual(data['parameters']['search_space'], search_space)


if __name__ == '__main__':
    # Runs only when the file is executed directly; hands off to unittest's runner.
    main()