# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the NV Bandwidth Test."""

import os
import subprocess
import re

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke


class NvBandwidthBenchmark(MicroBenchmarkWithInvoke):
    """The NV Bandwidth Test benchmark class."""

    # Regular expressions for test block start, matrix header/row and summary line detection.
    re_block_start_pattern = re.compile(r'^Running\s+(.+)$')
    re_matrix_header_line = re.compile(r'^(memcpy|memory latency)')
    re_matrix_row_pattern = re.compile(r'^\s*\d')
    re_summary_pattern = re.compile(r'SUM (\S+) (\d+\.\d+)')
    re_unsupported_pattern = re.compile(r'ERROR: Testcase (\S+) not found!')

    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        self._bin_name = 'nvbandwidth'

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--buffer_size',
            type=int,
            default=64,
            required=False,
            help='Memcpy buffer size in MiB. Default is 64.',
        )

        self._parser.add_argument(
            '--test_cases',
            nargs='+',
            type=str,
            default=[],
            required=False,
            help=(
                'Specify the test case(s) to execute by name only. '
                'To view the available test case names, run the command "nvbandwidth -l" on the host. '
                'If no specific test case is specified, all test cases will be executed by default.'
            ),
        )

        self._parser.add_argument(
            '--skip_verification',
            action='store_true',
            help='Skips data verification after copy. Default is False.',
        )

        self._parser.add_argument(
            '--disable_affinity',
            action='store_true',
            help=(
                'Disable automatic CPU affinity control. Default is False. '
                'If user would like to bind the process to specific NUMA node, '
                'please use --disable_affinity along with --numa argument.'
            ),
        )

        self._parser.add_argument(
            '--use_mean',
            action='store_true',
            help='Use mean instead of median for results. Default is False.',
        )

        self._parser.add_argument(
            '--num_loops',
            type=int,
            default=3,
            required=False,
            help='Iterations of the benchmark. Default is 3.',
        )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Return:
            True if _preprocess() succeed.
        """
        if not super()._preprocess():
            return False

        if not self._set_binary_path() or not self._get_arguments_from_env():
            return False

        # Construct the command for nvbandwidth.
        command = os.path.join(self._args.bin_dir, self._bin_name)

        if self._args.buffer_size:
            command += f' --bufferSize {self._args.buffer_size}'

        if self._args.test_cases:
            command += ' --testcase ' + ' '.join(self._args.test_cases)
        else:
            # All test cases run by default; record the full list so waived
            # cases can be detected when parsing results.
            self._args.test_cases = self._get_all_test_cases()

        if self._args.skip_verification:
            command += ' --skipVerification'

        if self._args.disable_affinity:
            command += ' --disableAffinity'
            # With affinity control handed to the user, honor an explicit NUMA
            # binding for both CPU and memory via numactl.
            if self._args.numa is not None:
                command = f'numactl --cpunodebind={self._args.numa} --membind={self._args.numa} ' + command

        if self._args.use_mean:
            command += ' --useMean'

        if self._args.num_loops:
            command += f' --testSamples {self._args.num_loops}'

        self._commands.append(command)

        return True

    def _process_raw_line(self, line, parse_status):
        """Process a raw line of text and update the parse status accordingly.

        Args:
            line (str): The raw line of text to be processed.
            parse_status (dict): A dictionary containing the current parsing status,
                     which will be updated based on the content of the line.

        Returns:
            None
        """
        line = line.strip()

        # Detect unsupported test cases.
        match = self.re_unsupported_pattern.match(line)
        if match:
            parse_status['unsupported_testcases'].add(match.group(1).lower())
            return

        # Detect the start of a test, e.g. 'Running host_to_device_memcpy_ce.'.
        match = self.re_block_start_pattern.match(line)
        if match:
            # Drop the trailing '.' from the announced test name.
            parse_status['test_name'] = match.group(1).lower()[:-1]
            parse_status['executed_testcases'].add(parse_status['test_name'])
            return

        # Detect the start of matrix data.
        if parse_status['test_name'] and self.re_matrix_header_line.match(line):
            parse_status['benchmark_type'] = 'bw' if 'bandwidth' in line else 'lat'
            # Parse the row and column device names, the three characters
            # immediately preceding '(row)' / '(column)' (e.g. 'cpu', 'gpu').
            tmp_idx = line.find('(row)')
            parse_status['matrix_row'] = line[tmp_idx - 3:tmp_idx].lower()
            tmp_idx = line.find('(column)')
            parse_status['matrix_col'] = line[tmp_idx - 3:tmp_idx].lower()
            return

        # Parse the matrix header (the line of column indices).
        if (
            parse_status['test_name'] and parse_status['benchmark_type'] and not parse_status['matrix_header']
            and self.re_matrix_row_pattern.match(line)
        ):
            parse_status['matrix_header'] = line.split()
            return

        # Parse matrix rows.
        if parse_status['test_name'] and parse_status['benchmark_type'] and self.re_matrix_row_pattern.match(line):
            row_data = line.split()
            row_index = row_data[0]
            for col_index, value in enumerate(row_data[1:], start=1):
                # Skip 'N/A' values, 'N/A' indicates the test path is self to self.
                if value == 'N/A':
                    continue

                col_header = parse_status['matrix_header'][col_index - 1]
                test_name = parse_status['test_name']
                benchmark_type = parse_status['benchmark_type']
                row_name = parse_status['matrix_row']
                col_name = parse_status['matrix_col']
                metric_name = f'{test_name}_{row_name}{row_index}_{col_name}{col_header}_{benchmark_type}'
                parse_status['results'][metric_name] = float(value)
            return

        # Parse summary results. Require an active test block so that a stray
        # 'SUM' line cannot produce a malformed metric name like '_sum_None'.
        match = self.re_summary_pattern.match(line)
        if parse_status['test_name'] and match:
            value = match.group(2)
            test_name = parse_status['test_name']
            benchmark_type = parse_status['benchmark_type']
            parse_status['results'][f'{test_name}_sum_{benchmark_type}'] = float(value)

            # Reset parsing state for next test.
            parse_status['test_name'] = ''
            parse_status['benchmark_type'] = None
            parse_status['matrix_header'].clear()
            parse_status['matrix_row'] = ''
            parse_status['matrix_col'] = ''
            return

    def _process_raw_result(self, cmd_idx, raw_output):
        """Function to parse raw results and save the summarized results.

           self._result.add_raw_data() and self._result.add_result() need to be called to save the results.

        Args:
            cmd_idx (int): the index of command corresponding with the raw_output.
            raw_output (str): raw output string of the micro-benchmark.

        Return:
            True if the raw output string is valid and result can be extracted.
        """
        try:
            self._result.add_raw_data('raw_output_' + str(cmd_idx), raw_output, self._args.log_raw_data)
            content = raw_output.splitlines()
            parsing_status = {
                'results': {},
                'executed_testcases': set(),
                'unsupported_testcases': set(),
                'benchmark_type': None,
                'matrix_header': [],
                'test_name': '',
                'matrix_row': '',
                'matrix_col': '',
            }

            for line in content:
                self._process_raw_line(line, parsing_status)

            return_code = ReturnCode.SUCCESS
            # Log unsupported test cases.
            for testcase in parsing_status['unsupported_testcases']:
                logger.warning(f'Test case {testcase} is not supported.')
                return_code = ReturnCode.INVALID_ARGUMENT
                self._result.add_raw_data(testcase, 'Not supported', self._args.log_raw_data)

            # Check if the test case was waived. Parsed names are lower-cased,
            # so normalize user-specified names before the comparison to avoid
            # falsely reporting an executed test as waived.
            for testcase in self._args.test_cases:
                name = testcase.lower()
                if (
                    name not in parsing_status['unsupported_testcases']
                    and name not in parsing_status['executed_testcases']
                ):
                    logger.warning(f'Test case {testcase} was waived.')
                    self._result.add_raw_data(testcase, 'waived', self._args.log_raw_data)
                    return_code = ReturnCode.INVALID_ARGUMENT

            if not parsing_status['results']:
                self._result.add_raw_data('nvbandwidth', 'No valid results found', self._args.log_raw_data)
                # Record the failure code before bailing out; previously it was
                # computed but never stored.
                self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
                return False

            # Store parsed results.
            for metric, value in parsing_status['results'].items():
                self._result.add_result(metric, value)

            self._result.set_return_code(return_code)
            return True
        except Exception as e:
            logger.error(
                'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
                    self._curr_run_index, self._name, raw_output, str(e)
                )
            )
            self._result.add_result('abort', 1)
            return False

    def _get_all_test_cases(self):
        """Query the nvbandwidth binary for the names of all available test cases.

        Return:
            list: test case names parsed from 'nvbandwidth --list' output,
                or an empty list when the binary cannot be queried.
        """
        bin_path = os.path.join(self._args.bin_dir, self._bin_name)
        command = f'{bin_path} --list'
        # Output lines look like '0, host_to_device_memcpy_ce:'.
        test_case_pattern = re.compile(r'(\d+),\s+([\w_]+):')

        try:
            # Execute the command and capture output. Pass the argument vector
            # directly instead of using shell=True so the binary path is not
            # subject to shell interpretation.
            result = subprocess.run(
                [bin_path, '--list'], text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False
            )

            # Check the return code.
            if result.returncode != 0:
                logger.error(f'{command} failed with return code {result.returncode}')
                return []

            if result.stderr:
                logger.error(f'{command} failed with {result.stderr}')
                return []

            # Parse the output.
            return [name for _, name in test_case_pattern.findall(result.stdout)]
        except Exception as e:
            logger.error(f'Failed to get all test case names: {e}')
            return []


# Register the benchmark under the name 'nvbandwidth' for the CUDA platform.
BenchmarkRegistry.register_benchmark('nvbandwidth', NvBandwidthBenchmark, platform=Platform.CUDA)