# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from io import BytesIO
import json
from unittest import TestCase, main

from nni.assessor import Assessor, AssessResult
from nni.runtime import msg_dispatcher_base as msg_dispatcher_base
from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.runtime.tuner_command_channel.legacy import *

# Recorders filled by NaiveAssessor's callbacks so the test can inspect them:
# _trials collects every trial id passed to assess_trial (in call order),
# _end_trials collects (trial_job_id, success) pairs passed to trial_end.
_trials = []
_end_trials = []


class NaiveAssessor(Assessor):
    """Deterministic assessor for the test: the parity of the metric sum decides."""

    def assess_trial(self, trial_job_id, trial_history):
        """Record ``trial_job_id`` and rate it by the parity of ``sum(trial_history)``.

        Returns ``AssessResult.Good`` for an even sum, ``AssessResult.Bad`` otherwise.
        """
        _trials.append(trial_job_id)
        history_is_even = sum(trial_history) % 2 == 0
        return AssessResult.Good if history_is_even else AssessResult.Bad

    def trial_end(self, trial_job_id, success):
        """Record that ``trial_job_id`` ended, together with its ``success`` flag."""
        _end_trials.append((trial_job_id, success))


# In-memory stand-ins for the legacy command channel's input/output streams.
# _reverse_io() and _restore_io() swap which of the two is read from and which
# is written to, depending on whether the test or the dispatcher is talking.
_in_buf = BytesIO()
_out_buf = BytesIO()


def _reverse_io():
    """Wire the legacy channel's I/O for the test-driver side.

    Rewinds both buffers, then makes ``_in_buf`` the output target and
    ``_out_buf`` the input source. The test can then write commands that the
    dispatcher will later read, and read back what the dispatcher wrote.
    NOTE(review): the direction of ``_set_out_file`` / ``_set_in_file`` is
    inferred from their names and from how this test uses them. They come
    from the legacy channel module via a star import; confirm there.
    """
    _in_buf.seek(0)
    _out_buf.seek(0)
    _set_out_file(_in_buf)
    _set_in_file(_out_buf)


def _restore_io():
    """Wire the legacy channel's I/O for the dispatcher side.

    Rewinds both buffers, then makes ``_in_buf`` the input source and
    ``_out_buf`` the output target. This is the mirror image of
    ``_reverse_io``: the dispatcher reads the commands the test wrote into
    ``_in_buf`` and writes its replies to ``_out_buf``.
    """
    _in_buf.seek(0)
    _out_buf.seek(0)
    _set_in_file(_in_buf)
    _set_out_file(_out_buf)


class AssessorTestCase(TestCase):
    def test_assessor(self):
        """Drive MsgDispatcher + NaiveAssessor through a scripted command stream.

        The test writes metric reports and trial-end events into the
        dispatcher's input buffer, terminated by an unsupported ``NewTrialJob``
        command. It then runs the dispatcher and checks three things: the
        assessor callbacks, the recorded worker exception, and the single
        command the dispatcher wrote back.
        """
        # The recorders and buffers are module-level. Clear them so leftovers
        # from an earlier run in the same process cannot leak into the
        # assertions below. This matters especially for the final
        # "nothing else was written" check, since the buffers are reused
        # from offset 0 without being truncated by the I/O helpers.
        _trials.clear()
        _end_trials.clear()
        for buf in (_in_buf, _out_buf):
            buf.seek(0)
            buf.truncate()

        _reverse_io()
        send(CommandType.ReportMetricData, '{"parameter_id": 0,"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}')
        send(CommandType.ReportMetricData, '{"parameter_id": 1,"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}')
        send(CommandType.ReportMetricData, '{"parameter_id": 0,"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}')
        send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED","hyper_params":"{\\"parameter_id\\": 0}"}')
        send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED","hyper_params":"{\\"parameter_id\\": 1}"}')
        # The assessor dispatcher does not support NewTrialJob; the test
        # expects the worker to reject it (see the exception checks below).
        send(CommandType.NewTrialJob, 'null')
        _restore_io()

        assessor = NaiveAssessor()
        dispatcher = MsgDispatcher('ws://_unittest_placeholder_', None, assessor)
        # NOTE(review): presumably swaps in the stream-based channel wired by
        # _set_in_file/_set_out_file, so the placeholder URL is never used.
        dispatcher._channel = LegacyCommandChannel()

        # This module-level flag is mutated for the duration of the run only.
        # NOTE(review): presumably disabling it lets the worker record its
        # exception instead of exiting fast; confirm in msg_dispatcher_base.
        saved_fast_exit = msg_dispatcher_base._worker_fast_exit_on_terminate
        msg_dispatcher_base._worker_fast_exit_on_terminate = False
        try:
            dispatcher.run()
        finally:
            msg_dispatcher_base._worker_fast_exit_on_terminate = saved_fast_exit

        # Fail with a clear message instead of an IndexError if nothing was recorded.
        self.assertTrue(dispatcher.worker_exceptions,
                        'expected the dispatcher worker to record an exception')
        e = dispatcher.worker_exceptions[0]
        self.assertIs(type(e), AssertionError)
        self.assertEqual(e.args[0], 'Unsupported command: CommandType.NewTrialJob')

        # A is assessed twice (one assessment per metric report), B once.
        # A was SYS_CANCELED (-> False); B SUCCEEDED (-> True).
        self.assertEqual(_trials, ['A', 'B', 'A'])
        self.assertEqual(_end_trials, [('A', False), ('B', True)])

        # A's second report makes its history sum odd, so NaiveAssessor rates it
        # Bad. The dispatcher is expected to answer with exactly one
        # KillTrialJob for A and nothing else.
        _reverse_io()
        command, data = receive()
        self.assertIs(command, CommandType.KillTrialJob)
        self.assertEqual(data, '"A"')
        self.assertEqual(len(_out_buf.read()), 0)


# Allow running this test module directly: `python test_assessor.py`.
if __name__ == '__main__':
    main()