blaslt_function_base.py 4.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the BLASLt GEMM Base Class."""
import itertools

from superbench.common.utils import logger
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke


def mrange(start, stop=-1, multiplication_factor=2, symbol='x'):
    """Generate a multiplicative ('x') or additive ('+') sequence of integers.

    The first value produced is always ``start``. Each following value is the
    previous one multiplied by (``symbol='x'``) or increased by
    (``symbol='+'``) ``multiplication_factor``. Generation ends as soon as the
    next value exceeds ``stop`` or becomes 0. It also ends when the factor
    could not make the sequence grow: below 2 for 'x', below 1 for '+'.
    With the default ``stop=-1`` a non-negative ``start`` is yielded once.

    Args:
        start (int): First value of the sequence.
        stop (int, optional): Inclusive upper bound. Defaults to -1.
        multiplication_factor (int, optional): Step factor or increment. Defaults to 2.
        symbol (str, optional): 'x' to multiply, '+' to add. Defaults to 'x' (multiplication).

    Yields:
        int: number in the range.

    Raises:
        ValueError: if ``symbol`` is neither 'x' nor '+' (raised on first iteration).
    """
    if symbol not in ('x', '+'):
        raise ValueError(f'Invalid symbol {symbol}.')

    multiplicative = symbol == 'x'
    # A factor that cannot grow the value would loop forever, so emit only the first element.
    min_factor = 2 if multiplicative else 1
    current = start
    while True:
        yield current
        if multiplicative:
            current = current * multiplication_factor
        else:
            current = current + multiplication_factor
        if current > stop or current == 0 or multiplication_factor < min_factor:
            return


def validate_mrange(string):
    """Validate mrange string in format start[:stop[:[+|x]multiplication_factor]].

    Every field must be a non-negative decimal integer. The optional third
    field may carry exactly one '+' (additive step) or 'x' (multiplicative
    step) prefix, e.g. '16', '16:128', '16:128:2', '16:128:+16', '16:128:x4'.

    Args:
        string (str): mrange string.

    Returns:
        bool: whether the mrange is expected.
    """
    nums = string.split(':')
    if len(nums) > 3:
        return False

    if len(nums) == 3 and nums[2][:1] in ('+', 'x'):
        # Strip exactly one step-symbol prefix so malformed factors like '++4' stay invalid.
        nums[2] = nums[2][1:]

    # isdecimal() (unlike isdigit()) rejects characters such as '²' that int() cannot parse.
    return all(num.isdecimal() for num in nums)


class BlasLtBaseBenchmark(MicroBenchmarkWithInvoke):
    """The BLASLt GEMM Base class.

    NOTE(review): _preprocess reads self._args.in_types and self._in_types,
    neither of which is defined in this class; presumably subclasses register
    an --in_types argument and set self._in_types. Confirm against subclasses.
    """
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

    def add_parser_arguments(self):
        """Add the specified arguments (--shapes, --batch, --num_warmup, --num_steps)."""
        super().add_parser_arguments()

        # Each shape and the batch accept mrange strings; the factor may be prefixed with
        # '+' (additive step) or 'x' (multiplicative step, the default).
        self._parser.add_argument(
            '--shapes',
            type=str,
            nargs='+',
            default=[f'{x},{x},{x}' for x in [2048, 4096, 8192]],
            help='Shapes in m,n,k format. Support format start:stop:multiplication_factor, e.g., 16:128:2.',
        )
        self._parser.add_argument(
            '--batch',
            type=str,
            default='0',
            required=False,
            help=(
                'Batch size for strided batch GEMM, set 0 to disable.'
                ' Support format start:stop:multiplication_factor, e.g., 16:128:2.'
            ),
        )
        self._parser.add_argument(
            '--num_warmup',
            type=int,
            default=20,
            required=False,
            help='Number of warm up steps.',
        )
        self._parser.add_argument(
            '--num_steps',
            type=int,
            default=50,
            required=False,
            help='Number of steps to measure.',
        )

    @staticmethod
    def _expand_mrange(spec):
        """Expand an mrange string that already passed validate_mrange into its integers.

        Args:
            spec (str): string in format start[:stop[:[+|x]multiplication_factor]].

        Returns:
            list: integers produced by mrange. A '+' prefix on the third field selects
            additive stepping; an 'x' prefix or no prefix selects multiplicative stepping.
        """
        fields = spec.split(':')
        symbol = 'x'
        if len(fields) == 3 and fields[2][:1] in ('+', 'x'):
            symbol = fields[2][0]
        return list(mrange(*(int(field.lstrip('+').lstrip('x')) for field in fields), symbol=symbol))

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Validates --batch, the requested input types and --shapes. Then fills
        self._shapes_to_run with (m, n, k, batch, in_type) tuples, ordered by
        in_type, then batch, then shape, then the m/n/k cartesian product.

        Return:
            True if _preprocess() succeed.
        """
        if not super()._preprocess():
            return False

        if not validate_mrange(self._args.batch):
            logger.error(f'Invalid batch size {self._args.batch}.')
            return False

        for _in_type in self._args.in_types:
            if _in_type not in self._in_types:
                logger.error(f'Invalid input type {_in_type}.')
                return False

        self._shapes_to_run = []

        # Validate and expand every shape once, independent of in_types/batch, so invalid
        # shapes are reported even when no input type is requested.
        expanded_shapes = []
        for shape in self._args.shapes:
            shape_list = shape.replace(',', ' ').split()
            if len(shape_list) != 3 or not all(validate_mrange(x) for x in shape_list):
                logger.error(f'Invalid shape {shape}.')
                return False
            expanded_shapes.append(list(itertools.product(*(self._expand_mrange(dim) for dim in shape_list))))

        # Use the same '+'/'x'-aware parser as shapes: a bare int() on the fields crashed on
        # 'x2' and silently ignored a '+' step.
        batches = self._expand_mrange(self._args.batch)

        for _in_type in self._args.in_types:
            for _b in batches:
                for mnk_list in expanded_shapes:
                    for _m, _n, _k in mnk_list:
                        self._shapes_to_run.append((_m, _n, _k, _b, _in_type))

        return True