cublaslt_function.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the cuBLASLt GEMM benchmark."""

import os

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import BlasLtBaseBenchmark


class CublasLtBenchmark(BlasLtBaseBenchmark):
    """The cuBLASLt GEMM benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        self._bin_name = 'cublaslt_gemm'
        self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2', 'int8']
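        # Note: the fp8 input types require an FP8-capable GPU and cuBLASLt build
        # (e.g. NVIDIA Hopper-class hardware); availability is not checked here.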

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--in_types',
            type=str,
            nargs='+',
            default=['fp8e4m3'],
            required=False,
            help='List of input data types, supports {}.'.format(' '.join(self._in_types)),
        )
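        # The autotune options below map onto the cublaslt_gemm binary's -a/-W/-I flags
        # (see the command construction in _preprocess).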
        self._parser.add_argument(
            '--enable_autotune',
            action='store_true',
            required=False,
            help='Enable exhaustive autotune mode to find the best algorithm.',
        )
        self._parser.add_argument(
            '--num_warmup_autotune',
            type=int,
            default=20,
            required=False,
            help='Number of warm up steps for autotune.',
        )
        self._parser.add_argument(
            '--num_steps_autotune',
            type=int,
            default=50,
            required=False,
            help='Number of steps to measure for autotune.',
        )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Return:
            True if _preprocess() succeed.
        """
        if not super()._preprocess():
            return False

        self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)

        self._commands = []
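        # Each entry of _shapes_to_run is an (m, n, k, batch, in_type) tuple prepared by the
        # BlasLtBaseBenchmark parent class from the shape- and type-related arguments.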
        for _m, _n, _k, _b, _in_type in self._shapes_to_run:
            # Exhaustive autotune runs with its own warmup/step counts via the -a/-W/-I flags.
            # The leading space makes the string safe to append directly to the command.
            autotune_args = (
                f' -a -W {self._args.num_warmup_autotune} -I {self._args.num_steps_autotune}'
                if self._args.enable_autotune else ''
            )
            self._commands.append(
                f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
                f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}{autotune_args}'
            )

        return True

    def _process_raw_result(self, cmd_idx, raw_output):
        """Function to parse raw results and save the summarized results.

          self._result.add_raw_data() and self._result.add_result() need to be called to save the results.

        Args:
            cmd_idx (int): the index of command corresponding with the raw_output.
            raw_output (str): raw output string of the micro-benchmark.

        Return:
            True if the raw output string is valid and result can be extracted.
        """
        self._result.add_raw_data(f'raw_output_{cmd_idx}', raw_output, self._args.log_raw_data)

        try:
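            # The binary is expected to print a single line of six whitespace-separated fields:
            # the first four are the integer m, n, k and batch sizes, the last is the achieved FLOPS.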
            fields = raw_output.strip().split()
            if len(fields) != 6 or not all(x.isdigit() for x in fields[:4]):
                raise ValueError('Invalid result.')
            # The command may end with autotune flags, so recover the input type from the -t
            # argument instead of taking the last token of the command line.
            in_type = self._commands[cmd_idx].split(' -t ')[1].split()[0]
            self._result.add_result(f'{in_type}_{fields[3]}_{"_".join(fields[:3])}_flops', float(fields[-1]))
        except BaseException as e:
            self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
            logger.error(
                'The result format is invalid - round: {}, benchmark: {}, raw output: {}, message: {}.'.format(
                    self._curr_run_index, self._name, raw_output, str(e)
                )
            )
            return False

        return True


BenchmarkRegistry.register_benchmark('cublaslt-gemm', CublasLtBenchmark, platform=Platform.CUDA)
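
# A minimal usage sketch, assuming a CUDA machine with SuperBench installed; '--shapes'
# comes from BlasLtBaseBenchmark and the values shown are illustrative only:
#
#   context = BenchmarkRegistry.create_benchmark_context(
#       'cublaslt-gemm', platform=Platform.CUDA, parameters='--shapes 4096,4096,4096 --in_types fp16'
#   )
#   benchmark = BenchmarkRegistry.launch_benchmark(context)
#   print(benchmark.result)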