# test_gen_traffic_pattern_config.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for traffic pattern config generation module."""
import argparse
import unittest
7
import tempfile
8

9
from tests.helper import decorator
10
from superbench.common.utils import gen_traffic_pattern_host_groups
11
12
13
14


class GenConfigTest(unittest.TestCase):
    """Test the utils for generating traffic pattern config."""
    @decorator.load_data('tests/data/mpi_pattern.txt')    # noqa: C901
    @decorator.load_data('tests/data/ib_traffic_topo_aware_hostfile')    # noqa: C901
    def test_gen_traffic_pattern_host_group(self, expected_mpi_pattern, tp_hostfile):
        """Test generating traffic pattern host groups for each supported pattern type.

        Covers the all-nodes, pair-wise, k-batch and topo-aware patterns, the content of
        the generated mpi pattern file, and an unsupported pattern type.

        Args:
            expected_mpi_pattern (str): expected final content of the mpi pattern config file,
                loaded from tests/data/mpi_pattern.txt.
            tp_hostfile (str): whitespace-separated host names used by the topo-aware case,
                loaded from tests/data/ib_traffic_topo_aware_hostfile.
        """
        test_config_file = tempfile.NamedTemporaryFile()
        # Keep the temp file alive until every case (including the invalid-pattern one, which is
        # also given this path) has run, and close it even if an assertion fails midway.
        self.addCleanup(test_config_file.close)
        test_config_path = test_config_file.name
        test_benchmark_name = 'test_benchmark'
        hostx = ['node0', 'node1', 'node2', 'node3', 'node4', 'node5', 'node6', 'node7']

        # Patterns are built as plain namespaces rather than via ArgumentParser.parse_known_args(),
        # which would read sys.argv and could pick up the test runner's own command-line options.

        # Test for all-nodes pattern
        pattern = argparse.Namespace(type='all-nodes', mpi_pattern=True)
        expected_host_group = [[['node0', 'node1', 'node2', 'node3', 'node4', 'node5', 'node6', 'node7']]]
        self.assertEqual(
            gen_traffic_pattern_host_groups(hostx, pattern, test_config_path, test_benchmark_name), expected_host_group
        )

        # Test for pair-wise pattern
        pattern = argparse.Namespace(type='pair-wise', mpi_pattern=True)
        expected_host_group = [
            [['node0', 'node7'], ['node1', 'node6'], ['node2', 'node5'], ['node3', 'node4']],
            [['node0', 'node1'], ['node2', 'node7'], ['node3', 'node6'], ['node4', 'node5']],
            [['node0', 'node2'], ['node3', 'node1'], ['node4', 'node7'], ['node5', 'node6']],
            [['node0', 'node3'], ['node4', 'node2'], ['node5', 'node1'], ['node6', 'node7']],
            [['node0', 'node4'], ['node5', 'node3'], ['node6', 'node2'], ['node7', 'node1']],
            [['node0', 'node5'], ['node6', 'node4'], ['node7', 'node3'], ['node1', 'node2']],
            [['node0', 'node6'], ['node7', 'node5'], ['node1', 'node4'], ['node2', 'node3']]
        ]
        self.assertEqual(
            gen_traffic_pattern_host_groups(hostx, pattern, test_config_path, test_benchmark_name), expected_host_group
        )

        # Test for k-batch pattern
        pattern = argparse.Namespace(type='k-batch', batch=3, mpi_pattern=True)
        expected_host_group = [[['node0', 'node1', 'node2'], ['node3', 'node4', 'node5']]]
        self.assertEqual(
            gen_traffic_pattern_host_groups(hostx, pattern, test_config_path, test_benchmark_name), expected_host_group
        )

        # Test for topo-aware pattern
        tp_ibstat_path = 'tests/data/ib_traffic_topo_aware_ibstat.txt'
        tp_ibnetdiscover_path = 'tests/data/ib_traffic_topo_aware_ibnetdiscover.txt'
        pattern = argparse.Namespace(
            type='topo-aware',
            ibstat=tp_ibstat_path,
            ibnetdiscover=tp_ibnetdiscover_path,
            min_dist=2,
            max_dist=6,
            mpi_pattern=True,
        )
        hostx = tp_hostfile.split()
        expected_host_group = [
            [
                ['vma414bbc00005I', 'vma414bbc00005J'], ['vma414bbc00005K', 'vma414bbc00005L'],
                ['vma414bbc00005M', 'vma414bbc00005N'], ['vma414bbc00005O', 'vma414bbc00005P'],
                ['vma414bbc00005Q', 'vma414bbc00005R']
            ],
            [
                ['vma414bbc00005I', 'vma414bbc00005K'], ['vma414bbc00005J', 'vma414bbc00005L'],
                ['vma414bbc00005O', 'vma414bbc00005Q'], ['vma414bbc00005P', 'vma414bbc00005R']
            ],
            [
                ['vma414bbc00005I', 'vma414bbc00005O'], ['vma414bbc00005J', 'vma414bbc00005P'],
                ['vma414bbc00005K', 'vma414bbc00005Q'], ['vma414bbc00005L', 'vma414bbc00005R']
            ]
        ]
        self.assertEqual(
            gen_traffic_pattern_host_groups(hostx, pattern, test_config_path, test_benchmark_name), expected_host_group
        )

        # Test for mpi_pattern file: every call above wrote to the same test_config_path with
        # mpi_pattern=True, so the file's final content is compared once, after the last valid pattern.
        with open(test_config_path, 'r') as f:
            content = f.read()
        self.assertEqual(content, expected_mpi_pattern)

        # Test for invalid pattern
        hostx = ['node0', 'node1', 'node2', 'node3', 'node4', 'node5', 'node6', 'node7']
        pattern = argparse.Namespace(type='invalid pattern', mpi_pattern=True)
        # NOTE(review): the return value is not asserted; this only checks that an unsupported
        # pattern type does not raise. Consider asserting the expected result once confirmed.
        gen_traffic_pattern_host_groups(hostx, pattern, test_config_path, test_benchmark_name)