layers.py 22 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14


# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch


import math

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

15
16
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
17
from .initialize import get_tensor_model_parallel_group
18
19
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
20
from .mappings import gather_from_sequence_parallel_region
21
22
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
23
from .mappings import reduce_scatter_to_sequence_parallel_region
24

25
26
27
28
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
29
from megatron import get_args, get_global_memory_buffer
mohammad's avatar
mohammad committed
30
31
32
33
34

# Attributes attached to every parameter to describe its tensor-model-parallel
# sharding, with the values used for non-partitioned (replicated) parameters:
# not parallel, no partition dimension, unit stride.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                      'partition_dim': -1,
                                      'partition_stride': 1}

mohammad's avatar
mohammad committed
35
36
37
38
39
40
def param_is_not_tensor_parallel_duplicate(param):
    """Return a truthy value iff this rank holds a unique copy of *param*.

    A parameter is non-duplicate when it is explicitly tagged as
    tensor-model-parallel (each rank owns a distinct shard), or when this
    process is tensor-parallel rank 0 (the canonical owner of replicated
    parameters).
    """
    is_sharded = hasattr(param, 'tensor_model_parallel') and \
        param.tensor_model_parallel
    # Short-circuit: rank is only queried for untagged/replicated params.
    return is_sharded or (get_tensor_model_parallel_rank() == 0)


mohammad's avatar
mohammad committed
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    """Attach tensor-model-parallel sharding metadata to *tensor*.

    Records whether the tensor is partitioned, along which dimension, and
    with what stride. Refuses (via assert) to overwrite attributes that are
    already present.
    """
    # None of the bookkeeping attributes may have been set before.
    for attr_name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attr_name)
    # Record the sharding description.
    new_attributes = {'tensor_model_parallel': is_parallel,
                      'partition_dim': dim,
                      'partition_stride': stride}
    for attr_name, attr_value in new_attributes.items():
        setattr(tensor, attr_name, attr_value)


def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    """Fill in default tensor-model-parallel attributes where missing.

    Attributes already present on *tensor* are left untouched, so this is
    safe to call on tensors that were tagged explicitly.
    """
    for attr_name, default in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS.items():
        if not hasattr(tensor, attr_name):
            setattr(tensor, attr_name, default)


def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    """Copy tensor-model-parallel metadata from one tensor to another.

    Only the attributes actually present on *source_tensor* are transferred;
    missing ones are simply skipped.
    """
    for attr_name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        if hasattr(source_tensor, attr_name):
            setattr(destination_tensor, attr_name,
                    getattr(source_tensor, attr_name))


68
69
70
71
def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU.

    Tags *weight* with its sharding metadata, then runs *init_method* on it
    inside the tracked model-parallel CUDA RNG fork so that each rank draws
    from its own, reproducible random stream.
    """
    set_tensor_model_parallel_attributes(
        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)

    rng_tracker = get_cuda_rng_tracker()
    with rng_tracker.fork():
        init_method(weight)


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Build and initialize the full (unpartitioned) weight in fp32 on the
    # CPU, then cast it to the configured training dtype.
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=get_args().params_dtype)

    # Split the master weight into per-stride chunks along the partition
    # dimension; this rank owns every world_size-th chunk starting at rank.
    chunk_size = divide(per_partition_size, stride)
    chunks = torch.split(master_weight, chunk_size, dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    local_chunks = chunks[rank::world_size]

    # Concatenate this rank's chunks directly into the output parameter.
    with torch.no_grad():
        torch.cat(local_chunks, dim=partition_dim, out=weight)

    return master_weight if return_master_weight else None


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    Each tensor-parallel rank stores rows [vocab_start_index,
    vocab_end_index) of the embedding table; the forward pass looks up
    locally-owned tokens, zeroes the rest, and all-reduces so every rank
    ends up with the full embedding output.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the defaults for compatibility with torch.nn.Embedding;
        # these are passed straight through to F.embedding in forward().
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension:
        # this rank owns global rows [vocab_start_index, vocab_end_index).
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

        # Allocate weights and initialize. Local shard has shape
        # (num_embeddings_per_partition, embedding_dim).
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=args.params_dtype))
            if args.perform_initialization:
                # Build the full table on CPU and keep only our row slice.
                _initialize_affine_weight_cpu(
                    self.weight, self.num_embeddings, self.embedding_dim,
                    self.num_embeddings_per_partition, 0, init_method)
        else:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=0, stride=1)

    def forward(self, input_):
        # input_: integer token ids (global vocabulary indices).
        if self.tensor_model_parallel_size > 1:
            # Build the mask of tokens that fall outside this rank's shard.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Shift to local row indices; out-of-range tokens are clamped to
            # row 0 (their embeddings get zeroed below).
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding: rows for tokens owned by other ranks
        # contribute zero to the all-reduce.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output


192
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    """
    Linear layer execution with asynchronous communication and gradient accumulation
    fusion in backprop.

    forward computes matmul(input, weight.t()) (+ bias). backward can overlap
    the tensor-parallel collectives (all-gather of the input, all-reduce or
    reduce-scatter of the input gradient) with the gradient GEMMs, and can
    fuse weight-gradient accumulation into a custom CUDA kernel
    (fused_dense_cuda) that accumulates into weight.main_grad.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                async_grad_allreduce, sequence_parallel):
        # Stash tensors and flags needed by backward.
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.sequence_parallel = sequence_parallel

        if sequence_parallel:
            # With sequence parallelism the input is sharded along dim 0
            # across tensor-parallel ranks; gather the full tensor first.
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            # Reuse a preallocated global buffer ("mpu") instead of
            # allocating the gathered activation every step.
            all_gather_buffer = \
                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=get_tensor_model_parallel_group())
            total_input = all_gather_buffer
        else:
            total_input = input

        output = torch.matmul(total_input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        if ctx.sequence_parallel:
            # Re-gather the sharded input (needed for the weight gradient),
            # asynchronously so it overlaps with the grad_input GEMM below.
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = \
                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            handle = torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=get_tensor_model_parallel_group(), async_op=True)

            # Delay the start of input gradient computation shortly (3us) to have
            # gather scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1
            total_input = all_gather_buffer
        else:
            total_input = input
        grad_input = grad_output.matmul(weight)

        if ctx.sequence_parallel:
            # The gathered input is required before the wgrad GEMM below.
            handle.wait()

        # Convert the tensor shapes to 2D for execution compatibility.
        # NOTE(review): assumes grad_output/total_input are 3-D
        # (e.g. [seq, batch, hidden]) — flattens the first two dims.
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
                                       grad_output.shape[2])
        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
                                       total_input.shape[2])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce of grad_input, overlapped with the
            # weight/bias gradient computation below.
            handle = torch.distributed.all_reduce(
                    grad_input, group=get_tensor_model_parallel_group(), async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.sequence_parallel:
            # Sequence parallelism replaces the all-reduce with a
            # reduce-scatter back to the per-rank input shard.
            assert not ctx.async_grad_allreduce
            dim_size = list(input.size())
            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
                                         device=torch.cuda.current_device(),
                                         requires_grad=False)
            # reduce_scatter
            handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, 
                                                            group=get_tensor_model_parallel_group(),
                                                            async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # reduce scatter scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.gradient_accumulation_fusion:
            # Fused kernel accumulates the weight gradient directly into
            # weight.main_grad, so no grad_weight tensor is returned.
            import fused_dense_cuda
            fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
            grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.sequence_parallel:
            handle.wait()
            # Return the reduce-scattered (per-rank) input gradient; the
            # trailing Nones match the three non-tensor forward arguments.
            return sub_grad_input, grad_weight, grad_bias, None, None, None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None
300
301


302
303
304
305
306
307
308
309
310
311
class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimations where bias
                       can be fused with other elementwise operations. we skip
                       adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension: each rank owns
        # output_size / world_size output columns.
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose: weight has shape
        # (output_size_per_partition, input_size).
        # Initialize weight.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=args.params_dtype))
            if args.perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.output_size_per_partition, 0, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=0, stride=stride)

        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            # The bias is sharded the same way as the weight's output dim.
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        # Only overlap the grad all-reduce / use sequence parallelism when
        # there is more than one tensor-parallel rank.
        self.async_tensor_model_parallel_allreduce = (
                args.async_tensor_model_parallel_allreduce and
                world_size > 1)
        self.sequence_parallel = (
                args.sequence_parallel and
                world_size > 1)
        # The two communication-overlap modes are mutually exclusive.
        assert not self.async_tensor_model_parallel_allreduce or \
            not self.sequence_parallel
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
        """Run the column-parallel GEMM; returns (output, output_bias).

        output_bias is None unless skip_bias_add is set, in which case the
        bias is returned unapplied so the caller can fuse it downstream.
        """
        bias = self.bias if not self.skip_bias_add else None

        if self.async_tensor_model_parallel_allreduce or \
                self.sequence_parallel:
            # The autograd function handles the backward communication
            # (all-reduce / reduce-scatter) itself in these modes.
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
            input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
            self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
        if self.gather_output:
            # All-gather across the partitions.
            assert not self.sequence_parallel
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimization where bias
                       can be fused with other elementwise operations. We skip
                       adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension: each rank owns
        # input_size / world_size input rows of A.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose: weight has shape
        # (output_size, input_size_per_partition).
        # Initialize weight.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=args.params_dtype))
            if args.perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.input_size_per_partition, 1, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=1, stride=stride)
        if bias:
            # Bias is replicated (not partitioned) on every rank.
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size, device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            # Mark the bias for sequence-parallel gradient handling.
            setattr(self.bias, 'sequence_parallel', args.sequence_parallel)

            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        self.sequence_parallel = args.sequence_parallel
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
        """Run the row-parallel GEMM; returns (output, output_bias).

        output_bias is None unless skip_bias_add is set, in which case the
        bias is returned unapplied so the caller can fuse it downstream.
        """
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            # Sequence parallelism requires pre-split input.
            assert not self.sequence_parallel
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply. No bias here and no async/sequence handling in
        # the autograd function — the reduction is done explicitly below.
        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
            input_parallel, self.weight, None,
            self.gradient_accumulation_fusion, None, None)
        # All-reduce across all the partitions (reduce-scatter when
        # sequence parallelism shards activations along the sequence dim).
        if self.sequence_parallel:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias