# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the cuBLASLt GEMM benchmark."""

import os
import itertools

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke


class CublasLtBenchmark(MicroBenchmarkWithInvoke):
    """The cuBLASLt GEMM benchmark class.

    Builds one `cublaslt_gemm` command line per (in_type, batch, m, n, k)
    combination expanded from the user-supplied mrange expressions, then
    parses each run's single-line output into a flops result.
    """
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        self._bin_name = 'cublaslt_gemm'
        # Data types accepted by the benchmark binary's -t option.
        self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2']

    def mrange(self, start, stop=-1, multiplication_factor=2):
        """Range constructor with multiplication factor.

        Yields start, start * factor, start * factor^2, ... while the value
        does not exceed stop. Exactly one value (start) is yielded when stop
        is omitted, when start is 0, or when multiplication_factor < 2 —
        cases in which further multiplication would never terminate or
        never progress.

        Args:
            start (int): Start number.
            stop (int, optional): Stop number. Defaults to -1.
            multiplication_factor (int, optional): Multiplication factor. Defaults to 2.

        Yields:
            int: number in the range.
        """
        while True:
            yield start
            start *= multiplication_factor
            # Guard against infinite generation: 0 never grows and a factor
            # below 2 never increases the value.
            if start > stop or start == 0 or multiplication_factor < 2:
                break

    def validate_mrange(self, string):
        """Validate mrange string in format start[[:stop]:multiplication_factor].

        Args:
            string (str): mrange string.

        Returns:
            bool: whether the mrange is expected.
        """
        nums = string.split(':')
        # At most three colon-separated fields, each a non-negative integer.
        # `all` already returns a bool, no extra conversion needed.
        return len(nums) <= 3 and all(x.isdigit() for x in nums)

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--shapes',
            type=str,
            nargs='+',
            default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]],
            help='Shapes in m,n,k format. Support format start:stop:multiplication_factor, e.g., 16:128:2.',
        )
        self._parser.add_argument(
            '--batch',
            type=str,
            default='0',
            required=False,
            help=(
                'Batch size for strided batch GEMM, set 0 to disable.'
                ' Support format start:stop:multiplication_factor, e.g., 16:128:2.'
            ),
        )
        self._parser.add_argument(
            '--num_warmup',
            type=int,
            default=20,
            required=False,
            help='Number of warm up steps.',
        )
        self._parser.add_argument(
            '--num_steps',
            type=int,
            default=50,
            required=False,
            help='Number of steps to measure.',
        )
        self._parser.add_argument(
            '--in_types',
            type=str,
            nargs='+',
            default=['fp8e4m3'],
            required=False,
            help='List of input data types, support {}.'.format(' '.join(self._in_types)),
        )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Return:
            True if _preprocess() succeed.
        """
        if not super()._preprocess():
            return False

        self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)

        if not self.validate_mrange(self._args.batch):
            logger.error(f'Invalid batch size {self._args.batch}.')
            return False

        # Validate all inputs up front so that an invalid entry is caught
        # before any command is generated, instead of re-checking every
        # shape once per (in_type, batch) combination and leaving
        # self._commands partially populated on failure.
        for _in_type in self._args.in_types:
            if _in_type not in self._in_types:
                logger.error(f'Invalid input type {_in_type}.')
                return False
        shape_lists = []
        for shape in self._args.shapes:
            # Accept both comma- and space-separated m,n,k triples.
            shape_list = shape.replace(',', ' ').split()
            if len(shape_list) != 3 or not all(self.validate_mrange(x) for x in shape_list):
                logger.error(f'Invalid shape {shape}.')
                return False
            shape_lists.append(shape_list)

        self._commands = []
        for _in_type in self._args.in_types:
            for _b in self.mrange(*map(int, self._args.batch.split(':'))):
                for shape_list in shape_lists:
                    # Expand each m/n/k mrange expression and take the
                    # cartesian product of the three dimensions.
                    for _m, _n, _k in itertools.product(
                        *map(lambda dim: self.mrange(*map(int, dim.split(':'))), shape_list)
                    ):
                        self._commands.append(
                            f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
                            f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
                        )

        return True

    def _process_raw_result(self, cmd_idx, raw_output):
        """Function to parse raw results and save the summarized results.

          self._result.add_raw_data() and self._result.add_result() need to be called to save the results.

        Args:
            cmd_idx (int): the index of command corresponding with the raw_output.
            raw_output (str): raw output string of the micro-benchmark.

        Return:
            True if the raw output string is valid and result can be extracted.
        """
        self._result.add_raw_data(f'raw_output_{cmd_idx}', raw_output, self._args.log_raw_data)

        try:
            # Expected single line: m n k batch <time> <flops>.
            fields = raw_output.strip().split()
            if len(fields) != 6 or not all(x.isdigit() for x in fields[:4]):
                raise ValueError('Invalid result.')
            # Metric name: <in_type>_<batch>_<m>_<n>_<k>_flops; the in_type is
            # the last token of the corresponding command (-t argument).
            self._result.add_result(
                f'{self._commands[cmd_idx].split()[-1]}_{fields[3]}_{"_".join(fields[:3])}_flops', float(fields[-1])
            )
        # Catch Exception, not BaseException, so KeyboardInterrupt/SystemExit
        # still propagate instead of being reported as a parsing failure.
        except Exception as e:
            self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
            logger.error(
                'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
                    self._curr_run_index, self._name, raw_output, str(e)
                )
            )
            return False

        return True


# Register the benchmark so the name 'cublaslt-gemm' resolves to this class on CUDA platforms.
BenchmarkRegistry.register_benchmark('cublaslt-gemm', CublasLtBenchmark, platform=Platform.CUDA)