# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Module of the ShardingMatmul benchmarks.

The ShardingMatmul benchmark measures the performance of large-scale matmul operations on multiple GPUs:
  allreduce: each GPU multiplies a slice of the K dimension, and AllReduce sums the partial products into
   one tensor.
  allgather: each GPU computes a slice of the output rows, and AllGather + Concat merges the slices into
   one tensor.
  nosharding: a plain matmul operation on a single GPU.

"""

import os
import time

# TODO - add mechanism to import torch as needed according to docker
import torch

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmark
from superbench.benchmarks.context import Enum


class ShardingMode(Enum):
    """Sharding strategies understood by the ShardingMatmul benchmark.

    The member values are the strings accepted by the ``--mode`` command-line option.
    """
    # K dimension split across GPUs; partial products are summed with AllReduce.
    ALLREDUCE = 'allreduce'
    # Output rows split across GPUs; row slices are merged with AllGather + Concat.
    ALLGATHER = 'allgather'
    # Whole matmul on a single GPU, no communication.
    NOSHARDING = 'nosharding'


class ShardingMatmul(MicroBenchmark):
    """The ShardingMatmul micro-benchmark.

    Times an (N, K) * (K, M) matmul in one or more modes (see ShardingMode): either on a single GPU, or
    sharded across the ranks of a torch.distributed NCCL process group and merged with AllReduce or
    AllGather + Concat.
    """
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)
        # Command lines to launch the micro-benchmarks.
        self.__commands = list()
        # Single-process defaults. When a distributed sharding mode is requested, _preprocess() replaces
        # them with the values of the WORLD_SIZE / LOCAL_RANK environment variables.
        self.__world_size = 1
        self.__local_rank = 0
        torch.backends.cudnn.benchmark = True

    def add_parser_arguments(self):
        """Add the specified arguments."""
        super().add_parser_arguments()

        self._parser.add_argument(
            '--n',
            type=int,
            default=4096,
            required=False,
            help='The N dim of matmul (N, K) * (K, M).',
        )
        self._parser.add_argument(
            '--k',
            type=int,
            default=4096,
            required=False,
            help='The K dim of matmul (N, K) * (K, M).',
        )
        self._parser.add_argument(
            '--m',
            type=int,
            default=4096,
            required=False,
            help='The M dim of matmul (N, K) * (K, M).',
        )
        self._parser.add_argument(
            '--mode',
            type=ShardingMode,
            default=[ShardingMode.NOSHARDING],
            nargs='+',
            required=False,
            help='Sharding modes. E.g. {}.'.format(' '.join(ShardingMode.get_values())),
        )
        self._parser.add_argument(
            '--num_warmup',
            type=int,
            default=10,
            required=False,
            help='The number of warmup step.',
        )
        self._parser.add_argument(
            '--num_steps',
            type=int,
            default=500,
            required=False,
            help='The number of test step.',
        )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        If allreduce or allgather is requested, initializes an NCCL process group and reads the world size
        and local rank from the WORLD_SIZE / LOCAL_RANK environment variables. Then, if CUDA is available,
        selects the GPU indexed by the local rank.

        Return:
            True if _preprocess() succeeds.
        """
        if not super()._preprocess():
            return False

        if ShardingMode.ALLGATHER in self._args.mode or ShardingMode.ALLREDUCE in self._args.mode:
            try:
                torch.distributed.init_process_group(backend='nccl')
                self.__world_size = int(os.environ['WORLD_SIZE'])
                self.__local_rank = int(os.environ['LOCAL_RANK'])
            except BaseException as e:
                self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
                logger.error(
                    'Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e))
                )
                return False

        if torch.cuda.is_available():
            torch.cuda.set_device(self.__local_rank)

        return True

    def __time_steps(self, step):
        """Run the warmup iterations, then the timed iterations, of one benchmark step.

        Every iteration calls ``step()`` and then synchronizes CUDA, so each measured interval covers
        all GPU work queued by the step.

        Args:
            step (Callable[[], None]): one iteration of the operation being measured.

        Return:
            elapse_times (List[float]): wall-clock cost of every timed step, in milliseconds.
        """
        for _ in range(self._args.num_warmup):
            step()
            torch.cuda.synchronize()

        elapse_times = list()
        for _ in range(self._args.num_steps):
            # perf_counter is monotonic and high-resolution. time.time() can jump when the system
            # clock is adjusted.
            start = time.perf_counter()
            step()
            torch.cuda.synchronize()
            elapse_times.append((time.perf_counter() - start) * 1000)

        return elapse_times

    def __matmul_nosharding(self, M, K, N):
        """Matmul with single GPU.

        Args:
            N (int): The N dim of matmul (N, K) * (K, M).
            K (int): The K dim of matmul (N, K) * (K, M).
            M (int): The M dim of matmul (N, K) * (K, M).

        Return:
            elapse_times (List[float]): cost of every test, in milliseconds.
        """
        x = torch.ones(N, K, device='cuda')
        y = torch.ones(K, M, device='cuda')

        def step():
            torch.matmul(x, y)

        return self.__time_steps(step)

    def __matmul_allreduce(self, M, K, N):
        """Matmul with allreduce sharding.

        Each rank multiplies a K // world_size slice of the inner dimension. The partial (N, M) products
        are then summed across ranks with AllReduce. When K is not divisible by the world size, the
        remainder of K is dropped.

        Args:
            N (int): The N dim of matmul (N, K) * (K, M).
            K (int): The K dim of matmul (N, K) * (K, M).
            M (int): The M dim of matmul (N, K) * (K, M).

        Return:
            elapse_times (List[float]): cost of every test, in milliseconds.
        """
        x = torch.ones(N, K // self.__world_size, device='cuda')
        y = torch.ones(K // self.__world_size, M, device='cuda')

        def step():
            z = torch.matmul(x, y)
            # Wait for the local matmul to finish before issuing the collective.
            torch.cuda.synchronize()
            torch.distributed.all_reduce(z, op=torch.distributed.ReduceOp.SUM)

        return self.__time_steps(step)

    def __matmul_allgather(self, M, K, N):
        """Matmul with allgather sharding.

        Each rank computes N // world_size rows of the output. The row slices are gathered from all
        ranks and concatenated along dim 0. When N is not divisible by the world size, the remainder
        of N is dropped.

        Args:
            N (int): The N dim of matmul (N, K) * (K, M).
            K (int): The K dim of matmul (N, K) * (K, M).
            M (int): The M dim of matmul (N, K) * (K, M).

        Return:
            elapse_times (List[float]): cost of every test, in milliseconds.
        """
        shard_rows = N // self.__world_size
        x = torch.ones(shard_rows, K, device='cuda')
        y = torch.ones(K, M, device='cuda')
        # Receive buffers, one per rank, allocated once and reused by every iteration.
        tensor_list = [torch.zeros(shard_rows, M, device='cuda') for _ in range(self.__world_size)]

        def step():
            z = torch.matmul(x, y)
            # Wait for the local matmul to finish before issuing the collective.
            torch.cuda.synchronize()
            torch.distributed.all_gather(tensor_list, z)
            # Merge the gathered slices into one tensor. Warmup runs this too, so it exercises the
            # same operations as the timed steps.
            torch.cat(tensor_list, 0)

        return self.__time_steps(step)

    def _benchmark(self):
        """Implementation for benchmarking.

        Runs every requested sharding mode, passes its per-step costs (ms) to _process_numeric_result
        under the metric name ``matmul_sharding_<mode>``, and logs the average cost.

        Return:
            True if every mode ran and its result was processed, otherwise False.
        """
        M = self._args.m
        K = self._args.k
        N = self._args.n
        for mode in self._args.mode:
            if mode == ShardingMode.NOSHARDING:
                elapse_times = self.__matmul_nosharding(M, K, N)
            elif mode == ShardingMode.ALLREDUCE:
                elapse_times = self.__matmul_allreduce(M, K, N)
            elif mode == ShardingMode.ALLGATHER:
                elapse_times = self.__matmul_allgather(M, K, N)
            else:
                logger.error('Unknown sharding mode - benchmark: {}, mode: {}.'.format(self._name, mode))
                return False

            metric = 'matmul_sharding_{}'.format(mode)
            if not self._process_numeric_result(metric, elapse_times):
                return False

            # The shape is reported as (N, K) * (K, M), matching the tensors actually multiplied.
            logger.info(
                'Matmul sharding - round: {0}, name: {1}, shape: ({2}, {3}) * ({3}, {4}), mode: {5}, cost: {6} ms'.
                format(self._curr_run_index, self._name, N, K, M, mode,
                       sum(elapse_times) / len(elapse_times))
            )

        return True


# Register two presets of the same class, in this order: the multi-GPU sharding modes,
# then the single-GPU baseline.
for _preset_name, _preset_params in (
    ('pytorch-sharding-matmul', '--mode allreduce allgather'),
    ('pytorch-matmul', '--mode nosharding'),
):
    BenchmarkRegistry.register_benchmark(_preset_name, ShardingMatmul, parameters=_preset_params)
del _preset_name, _preset_params