test_file_channel.py 4.06 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import os
import random
import shutil
import string
import sys
import time
import unittest
from argparse import Namespace
from datetime import datetime

from tools.nni_trial_tool.base_channel import CommandType
from tools.nni_trial_tool.file_channel import (FileChannel, command_path,
                                               manager_commands_file_name)

# Add the parent directory to the module search path.
# NOTE(review): this statement runs after the `tools.nni_trial_tool` imports
# above, so it cannot influence how those imports are resolved -- confirm
# whether it is still needed.
sys.path.append("..")

# Relative paths of the command files the tests below read and write.
# NOTE(review): presumably these must match the directory (`command_path`) and
# file names that FileChannel itself uses -- verify against file_channel.py.
runner_file_name = "commands/runner_commands.txt"
manager_file_name = "commands/manager_commands.txt"


class FileChannelTest(unittest.TestCase):
    """Tests for sending and receiving commands through ``FileChannel``.

    The expected file contents used below consist of a two-character command
    code, a 14-digit zero-padded length, and a double-quoted payload whose
    length (quotes included) equals that number, terminated by a newline.
    """

    def setUp(self):
        """Create single-node arguments and delete any leftover command directory."""
        self.args = Namespace()
        self.args.node_count = 1
        self.args.node_id = None
        if os.path.exists(command_path):
            shutil.rmtree(command_path)

    def test_send(self):
        """Two sent commands must show up, in order, in the runner command file."""
        fc = None
        try:
            fc = FileChannel(self.args)
            fc.send(CommandType.ReportGpuInfo, "command1")
            fc.send(CommandType.ReportGpuInfo, "command2")

            # Wait for both lines rather than for the file merely to exist:
            # if the channel writes lazily (not visible from this file), the
            # file can appear before the second command has been written.
            self.check_timeout(2, lambda: self._complete_line_count(runner_file_name) >= 2)

            self.assertTrue(os.path.exists(runner_file_name))
            with open(runner_file_name, "rb") as runner:
                lines = runner.readlines()
            self.assertListEqual(lines, [b'GI00000000000010"command1"\n', b'GI00000000000010"command2"\n'])
        finally:
            if fc is not None:
                fc.close()

    def test_send_multi_node(self):
        """Channels created with node ids 1 and 2 write to separate, suffixed files."""
        fc1 = None
        fc2 = None
        try:
            runner1_file_name = "commands/runner_commands_1.txt"
            self.args.node_id = 1
            fc1 = FileChannel(self.args)
            fc1.send(CommandType.ReportGpuInfo, "command1")

            runner2_file_name = "commands/runner_commands_2.txt"
            self.args.node_id = 2
            fc2 = FileChannel(self.args)
            fc2.send(CommandType.ReportGpuInfo, "command1")

            # Give both commands time to be written before the channels are
            # closed in the ``finally`` clause; wait for a complete line in
            # each file, not just for the files to exist.
            self.check_timeout(
                2,
                lambda: self._complete_line_count(runner1_file_name) >= 1
                and self._complete_line_count(runner2_file_name) >= 1,
            )

            self.assertTrue(os.path.exists(runner1_file_name))
            with open(runner1_file_name, "rb") as runner:
                lines1 = runner.readlines()
            self.assertTrue(os.path.exists(runner2_file_name))
            with open(runner2_file_name, "rb") as runner:
                lines2 = runner.readlines()
            self.assertListEqual(lines1, [b'GI00000000000010"command1"\n'])
            self.assertListEqual(lines2, [b'GI00000000000010"command1"\n'])
        finally:
            if fc1 is not None:
                fc1.close()
            if fc2 is not None:
                fc2.close()

    def test_receive(self):
        """Commands appended to the manager file are received one at a time, in order."""
        fc = None
        manager_file = None
        try:
            fc = FileChannel(self.args)
            # Nothing has been written yet, so there is nothing to receive.
            message = fc.receive()
            self.assertEqual(message, (None, None))

            # exist_ok: do not fail if the directory is already present.
            os.makedirs(command_path, exist_ok=True)
            manager_file = open(manager_file_name, "wb")
            manager_file.write(b'TR00000000000009"manager"\n')
            manager_file.flush()

            self.check_timeout(2, fc.received)
            message = fc.receive()
            self.assertEqual(message, (CommandType.NewTrialJob, "manager"))

            # The file stays open so a second command can be appended to it.
            manager_file.write(b'TR00000000000010"manager2"\n')
            manager_file.flush()

            self.check_timeout(2, fc.received)
            message = fc.receive()
            self.assertEqual(message, (CommandType.NewTrialJob, "manager2"))
        finally:
            if fc is not None:
                fc.close()
            if manager_file is not None:
                manager_file.close()

    @staticmethod
    def _complete_line_count(path):
        """Return how many newline-terminated lines ``path`` holds (0 if it is missing)."""
        try:
            with open(path, "rb") as handle:
                return sum(1 for line in handle if line.endswith(b"\n"))
        except FileNotFoundError:
            return 0

    def check_timeout(self, timeout, callback):
        """Poll ``callback`` until it returns a truthy value or ``timeout`` elapses.

        Args:
            timeout: Maximum time to keep polling, in seconds.
            callback: Zero-argument callable; a truthy result stops the polling.

        Returns:
            True if ``callback`` returned a truthy value, False if the timeout
            expired first. Existing callers may ignore the result.
        """
        interval = 0.01
        # A monotonic clock keeps the measurement immune to wall-clock changes.
        start = time.monotonic()
        deadline = start + timeout
        checks = 0
        satisfied = False
        # The callback is always evaluated at least once, even for a timeout
        # shorter than the polling interval, so the summary below never refers
        # to an unset counter.
        while True:
            checks += 1
            if callback():
                satisfied = True
                break
            if time.monotonic() >= deadline:
                break
            time.sleep(interval)
        print("checked {} times, {:.3f} seconds".format(checks, time.monotonic() - start))
        return satisfied


# Run the unittest command-line runner when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()