# dist_inference.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the distributed inference benchmark."""

import os
import time

import torch
import torch.nn.functional as F
import torch.distributed as dist

from superbench.common.utils import logger
from superbench.benchmarks import DistributedImpl, DistributedBackend, BenchmarkRegistry, ReturnCode, Precision
from superbench.benchmarks.micro_benchmarks import MicroBenchmark
from superbench.benchmarks.context import Enum
17
from superbench.benchmarks.reducer import ReduceType
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124


class ComputationKernelType(Enum):
    """The Enum class representing different computation kernel type.

    Used as the argparse `type` of the `--computation_kernel` option and compared
    against in `DistInferenceModel` to pick the per-layer computation callable.
    """
    # Selects torch.addmm(bias, x, weights).
    ADDMM = 'addmm'
    # Selects torch.matmul(x, weights).
    MATMUL = 'matmul'
    # Selects torch.mul(x, x), i.e. an element-wise square of the input.
    MUL = 'mul'


class CommunicationKernelType(Enum):
    """The Enum class representing different communication kernel type.

    Used as the argparse `type` of the `--communication_kernel` option and compared
    against in `DistInferenceModel` to pick the per-layer collective wrapper.
    """
    # Selects the wrapper around dist.all_gather_into_tensor.
    ALLGATHER = 'allgather'
    # Selects the wrapper around dist.all_reduce (in place on its input).
    ALLREDUCE = 'allreduce'
    # Selects the wrapper around dist.all_to_all_single.
    ALLTOALL = 'alltoall'


class ActivationKernelType(Enum):
    """The Enum class representing different activation kernel type.

    Used as the argparse `type` of the `--activation_kernel` option and compared
    against in `DistInferenceModel` to pick the per-layer activation function.
    """
    # Selects torch.nn.functional.relu.
    RELU = 'relu'
    # Selects torch.nn.functional.sigmoid.
    SIGMOID = 'sigmoid'
    # Selects torch.nn.functional.tanh.
    TANH = 'tanh'


class DistInferenceModel(torch.nn.Module):
    """The model class for distributed inference benchmark."""
    def __init__(
        self, input_size, hidden_size, num_layers, computation, communication, activation, precision, num_ranks, device
    ):
        """Create the random weights/bias and bind the three per-layer kernels.

        Args:
            input_size (int): input data dimension.
            hidden_size (int): hidden layer dimension.
            num_layers (int): number of layers in the model.
            computation (ComputationKernelType): type of computation kernel of this model.
            communication (CommunicationKernelType): type of communication kernel of this model.
            activation (ActivationKernelType): type of activation kernel of this model.
            precision (Precision): data type of this model.
            num_ranks (int): number of ranks in this model runs.
            device (torch.device): device this model runs on.
        """
        super().__init__()

        # `precision.value` names a torch dtype attribute; resolve it once.
        tensor_dtype = getattr(torch, precision.value)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Plain tensors (not Parameters or buffers), created directly with the
        # requested dtype on the requested device.
        self.weights = torch.rand(input_size, hidden_size, dtype=tensor_dtype, device=device)
        self.bias = torch.rand(hidden_size, dtype=tensor_dtype, device=device)

        self.num_ranks = num_ranks
        self.step_times = []

        # Each helper stores a callable (or None) on the instance for forward() to use.
        self.__init_computation_kernels(computation)
        self.__init_communication_kernels(communication)
        self.__init_activation_kernels(activation)

    def __init_computation_kernels(self, computation):
        """Bind `self.computation_kernel` to the callable matching the requested type.

        The attribute is left as None when `computation` matches none of the known types.

        Args:
            computation (ComputationKernelType): the type of the computation kernel to run.
        """
        def addmm_kernel(x):
            # bias + x @ weights
            return torch.addmm(self.bias, x, self.weights)

        def matmul_kernel(x):
            return torch.matmul(x, self.weights)

        def mul_kernel(x):
            # Element-wise product of the input with itself; weights are not used.
            return torch.mul(x, x)

        candidates = (
            (ComputationKernelType.ADDMM, addmm_kernel),
            (ComputationKernelType.MATMUL, matmul_kernel),
            (ComputationKernelType.MUL, mul_kernel),
        )

        self.computation_kernel = None
        for kernel_type, kernel in candidates:
            if computation == kernel_type:
                self.computation_kernel = kernel
                break

    def __init_communication_kernels(self, communication):
        """Bind `self.communication_kernel` to the collective wrapper matching the requested type.

        The attribute is left as None when `communication` matches none of the known types.

        Args:
            communication (CommunicationKernelType): the type of the communication kernel to run.
        """
        candidates = (
            (CommunicationKernelType.ALLGATHER, self.__all_gather_wrapper),
            (CommunicationKernelType.ALLREDUCE, self.__all_reduce_wrapper),
            (CommunicationKernelType.ALLTOALL, self.__all_to_all_wrapper),
        )

        self.communication_kernel = None
        for kernel_type, wrapper in candidates:
            if communication == kernel_type:
                self.communication_kernel = wrapper
                break

    def __init_activation_kernels(self, activation):
        """Bind `self.activation_kernel` to the functional activation matching the requested type.

        The attribute is left as None when `activation` matches none of the known types.

        Args:
            activation (ActivationKernelType): the type of the activation kernel to run.
        """
        candidates = (
            (ActivationKernelType.RELU, F.relu),
            (ActivationKernelType.SIGMOID, F.sigmoid),
            (ActivationKernelType.TANH, F.tanh),
        )

        self.activation_kernel = None
        for kernel_type, function in candidates:
            if activation == kernel_type:
                self.activation_kernel = function
                break

    def __all_gather_wrapper(self, x):
        """All-gather wrapper with output initialization.

        Allocates an uninitialized output whose first dimension is `num_ranks` times
        that of `x` (remaining dimensions, dtype and device are copied from `x`),
        then lets the collective fill it.

        Args:
            x (Tensor): input; must have at least one dimension.

        Return:
            Tensor after all-gather.
        """
        # NOTE(review): the output is sized with self.num_ranks, so this presumably
        # has to equal the world size of the default process group - confirm.
        gathered_shape = [x.shape[0] * self.num_ranks] + list(x.shape[1:])
        output = torch.empty(gathered_shape, dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(output, x)
        return output

    def __all_reduce_wrapper(self, x):
        """All-reduce wrapper.

        The reduction is applied in place, so the returned tensor is the very same
        object that was passed in.

        Args:
            x (Tensor): input, overwritten with the reduced values.

        Return:
            Tensor after all-reduce.
        """
        # SUM is the default reduction of dist.all_reduce; spelled out for clarity.
        dist.all_reduce(x, op=dist.ReduceOp.SUM)
        return x

    def __all_to_all_wrapper(self, x):
        """All-to-all wrapper with output initialization.

        Args:
            x (Tensor): input.

        Return:
            Tensor after all-to-all.
        """
        # Receive buffer: uninitialized, same shape/dtype/device as the input.
        received = torch.empty_like(x)
        dist.all_to_all_single(received, x)
        return received

    def forward(self, x):
        """Do forward loops.

        Args:
            x (Tensor): input.

        Return:
            Tensor after the whole inference process.
        """
        activation_out = None
        for i in range(self.num_layers):
            computation_out = self.computation_kernel(x)
            communication_out = self.communication_kernel(computation_out)
            activation_out = self.activation_kernel(communication_out)
        return activation_out


class DistInference(MicroBenchmark):
    """The distributed inference micro-benchmark class."""
    def __init__(self, name, parameters=''):
        """Initialize the benchmark with single-process, CPU-only defaults.

        The distributed and device fields set here are overwritten in `_preprocess()`.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)

        # Process-wide switch: enable the cuDNN benchmark mode.
        torch.backends.cudnn.benchmark = True

        # Distributed topology placeholders.
        self.__world_size = 1
        self.__local_rank = 0

        # Device placeholders.
        self.__device = None
        self.__cuda_available = False

    def __timer(self):
        """Return a timestamp taken after all previously queued CUDA work has finished.

        When CUDA is in use the device is synchronized first, so that asynchronously
        launched kernels are included in the measured interval. The timestamp comes
        from `time.perf_counter()`, a monotonic high-resolution clock: its absolute
        value is meaningless, only the difference between two calls is.

        Return:
            Timestamp in seconds, valid only for computing elapsed time.
        """
        if self.__cuda_available:
            torch.cuda.synchronize()
        # A monotonic clock avoids negative or skewed step times when the system
        # wall clock is adjusted during the run.
        return time.perf_counter()

    def add_parser_arguments(self):
        """Register this benchmark's command-line options on top of the inherited ones.

        All options are optional; each row below is (flag, type, default, help text),
        registered in the listed order.
        """
        super().add_parser_arguments()

        option_table = (
            ('--batch_size', int, 64, 'Batch size.'),
            ('--input_size', int, 1024, 'Input dimension size.'),
            ('--hidden_size', int, 1024, 'Hidden size.'),
            ('--num_layers', int, 1, 'Number of compute-communicate-activate layers.'),
            (
                '--computation_kernel',
                ComputationKernelType,
                ComputationKernelType.MATMUL,
                'Computation kernel type. E.g. {}.'.format(' '.join(ComputationKernelType.get_values())),
            ),
            (
                '--communication_kernel',
                CommunicationKernelType,
                CommunicationKernelType.ALLREDUCE,
                'Communication kernel type. E.g. {}.'.format(' '.join(CommunicationKernelType.get_values())),
            ),
            (
                '--activation_kernel',
                ActivationKernelType,
                ActivationKernelType.RELU,
                'Activation kernel type. E.g. {}.'.format(' '.join(ActivationKernelType.get_values())),
            ),
            (
                '--precision',
                Precision,
                Precision.FLOAT32,
                'Model precision. E.g. {}.'.format(' '.join(Precision.get_values())),
            ),
            ('--num_warmup', int, 50, 'Number of warmup steps.'),
            ('--num_steps', int, 10000, 'Number of test steps.'),
            (
                '--distributed_impl',
                DistributedImpl,
                DistributedImpl.DDP,
                'Distributed implementations. E.g. {}.'.format(' '.join(DistributedImpl.get_values())),
            ),
            (
                '--distributed_backend',
                DistributedBackend,
                DistributedBackend.NCCL,
                'Distributed backends. E.g. {}.'.format(' '.join(DistributedBackend.get_values())),
            ),
        )

        for flag, value_type, default_value, help_text in option_table:
            self._parser.add_argument(
                flag,
                type=value_type,
                default=default_value,
                required=False,
                help=help_text,
            )

    def _preprocess(self):
        """Preprocess/preparation operations before the benchmarking.

        Validates the distributed implementation, initializes the default process
        group, reads WORLD_SIZE / LOCAL_RANK from the environment and selects the
        device for this rank.

        Return:
            True if _preprocess() succeed.
        """
        if not super()._preprocess():
            return False

        # Only the DDP-style launch is supported by this benchmark.
        if self._args.distributed_impl != DistributedImpl.DDP:
            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
            logger.error(
                'Unsupported distributed implementation - model: {}, distributed implementation: {}.'.format(
                    self._name, self._args.distributed_impl
                )
            )
            return False

        try:
            torch.distributed.init_process_group(backend=self._args.distributed_backend.value)
            self.__world_size = int(os.environ['WORLD_SIZE'])
            self.__local_rank = int(os.environ['LOCAL_RANK'])
        except BaseException as e:
            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_INIT_FAILURE)
            # Log first so the root cause is recorded even if the teardown below misbehaves.
            logger.error('Initialize distributed env failed - benchmark: {}, message: {}.'.format(self._name, str(e)))
            # The failure may have happened inside init_process_group itself, in which
            # case there is no group to destroy; destroying unconditionally would raise
            # from within this handler and mask the original error.
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                torch.distributed.destroy_process_group()
            return False

        if torch.cuda.is_available():
            # Bind this process to the GPU matching its local rank.
            torch.cuda.set_device(self.__local_rank)
            self.__device = torch.device('cuda:{}'.format(self.__local_rank))
            self.__cuda_available = True
        else:
            self.__device = torch.device('cpu:{}'.format(self.__local_rank))
            self.__cuda_available = False

        return True

    def _prepare_model(
        self, input_size, hidden_size, num_layers, computation, communication, activation, precision, num_ranks
    ):
        """Build the inference model on this rank's device in the requested precision.

        Args:
            input_size (int): input data dimension.
            hidden_size (int): hidden layer dimension.
            num_layers (int): number of layers in the model.
            computation (ComputationKernelType): type of computation kernel of this model.
            communication (CommunicationKernelType): type of communication kernel of this model.
            activation (ActivationKernelType): type of activation kernel of this model.
            precision (Precision): data type of this model.
            num_ranks (int): number of ranks in this model runs.

        Return:
            Model prepared.
        """
        target_dtype = getattr(torch, precision.value)
        prepared = DistInferenceModel(
            input_size, hidden_size, num_layers, computation, communication, activation, precision, num_ranks,
            self.__device
        ).to(dtype=target_dtype)

        if not self.__cuda_available:
            return prepared
        return prepared.cuda()

    def _run_model(self, model, batch_size, input_size, precision, device, num_warmup, num_steps):
        """Run model and collect step times.

        A single random input batch is generated once and reused for every warm-up
        and measured step.

        Args:
            model (torch.nn.Module): model to run.
            batch_size (int): batch size of input data.
            input_size (int): input data dimension.
            precision (Precision): data type of this model.
            device (torch.device): device this model runs on. Falls back to the device
                selected in `_preprocess()` when None.
            num_warmup (int): number of warm-up runs.
            num_steps (int): number of test runs.

        Return:
            Model step times collected, in milliseconds, one entry per test run.
        """
        # Honor the caller-supplied device; keep the previous behavior as a fallback.
        if device is None:
            device = self.__device
        data = torch.rand(batch_size, input_size, dtype=getattr(torch, precision.value), device=device)

        # warm up
        for _ in range(num_warmup):
            model(data)

        # run and collect results; the loop bound is the `num_steps` argument so the
        # result always has exactly one entry per measured step.
        step_times = []
        for _ in range(num_steps):
            start = self.__timer()
            model(data)
            end = self.__timer()
            # seconds -> milliseconds
            step_times.append((end - start) * 1000)

        return step_times

    def _process_data(self, step_times):
        """Process data.

        Hands the collected step times to `_process_numeric_result` under the metric
        name 'step_times', requesting MAX as the reduce type and percentile calculation.

        Args:
            step_times (List[float]): Model step times collected.

        Return:
            True if _process_data succeeds.
        """
        if not self._process_numeric_result('step_times', step_times, reduce_type=ReduceType.MAX, cal_percentile=True):
            return False
        return True

    def _benchmark(self):
        """Implementation for benchmarking.

        Logs the configuration (on local rank 0 only), builds the model, runs it and
        passes the measured step times on for result processing.

        Return:
            True if _benchmark succeeds.
        """
        args = self._args
        batch_size, input_size, hidden_size = args.batch_size, args.input_size, args.hidden_size
        num_layers = args.num_layers
        computation, communication, activation = (
            args.computation_kernel, args.communication_kernel, args.activation_kernel
        )
        precision = args.precision
        num_warmup, num_steps = args.num_warmup, args.num_steps

        # Avoid duplicated log lines from every rank that shares this node.
        if self.__local_rank == 0:
            summary = (
                'Distributed Inference - using {} GPUs: '
                'batch_size={}, input_size={}, hidden_size={}, num_layers={}, '
                'computation_kernel={}, communication_kernel={}, activation_kernel={}, precision={}, '
                'num_warmup={} num_steps={}'.format(
                    self.__world_size, batch_size, input_size, hidden_size, num_layers, computation, communication,
                    activation, precision, num_warmup, num_steps
                )
            )
            logger.info(summary)

        # Build, then time, then report.
        model = self._prepare_model(
            input_size, hidden_size, num_layers, computation, communication, activation, precision, self.__world_size
        )
        step_times = self._run_model(model, batch_size, input_size, precision, self.__device, num_warmup, num_steps)
        return self._process_data(step_times)

    def _postprocess(self):
        """Postprocess/cleanup operations after the benchmarking.

        Runs the inherited post-processing and then destroys the process group
        created in `_preprocess()`.

        Return:
            True if _postprocess succeeds.
        """
        parent_ok = super()._postprocess()
        if not parent_ok:
            return False

        try:
            torch.distributed.destroy_process_group()
        except BaseException as e:
            # Record the teardown failure on the result and report it instead of raising.
            self._result.set_return_code(ReturnCode.DISTRIBUTED_SETTING_DESTROY_FAILURE)
            logger.error('Post process failed - benchmark: {}, message: {}.'.format(self._name, str(e)))
            return False
        else:
            return True


# Import-time side effect: register DistInference under the name
# 'pytorch-dist-inference' with an empty default parameter string.
BenchmarkRegistry.register_benchmark('pytorch-dist-inference', DistInference, parameters='')