# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch


import math

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

28
29
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
30
from .initialize import get_tensor_model_parallel_group
31
from .mappings import copy_to_tensor_model_parallel_region
32
33
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import gather_from_sequence_parallel_region
34
from .mappings import reduce_from_tensor_model_parallel_region
35
36
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
37

38
39
40
41
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
42
from megatron import get_args
43

mohammad's avatar
mohammad committed
44
45
46
47
# Names and default values of the model-parallel attributes that the helper
# functions in this file attach to tensors.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    'tensor_model_parallel': False,
    'partition_dim': -1,
    'partition_stride': 1,
}

# Module-level tensor slots; both start out empty.
_TOTAL_INPUT = None
_SUB_GRAD_INPUT = None
mohammad's avatar
mohammad committed
50

mohammad's avatar
mohammad committed
51
52
53
54
55
56
def param_is_not_tensor_parallel_duplicate(param):
    """Tell whether `param` should be counted on this rank.

    The result is truthy when `param` carries a truthy
    `tensor_model_parallel` attribute, or otherwise when the current
    tensor-model-parallel rank is 0.
    """
    is_partitioned = hasattr(param, 'tensor_model_parallel') and \
        param.tensor_model_parallel
    # The rank is only queried when the parameter is not partitioned.
    return is_partitioned or get_tensor_model_parallel_rank() == 0


mohammad's avatar
mohammad committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    """Attach the model-parallel attributes to `tensor`.

    Arguments:
        tensor: object that receives the attributes.
        is_parallel: value stored as `tensor_model_parallel`.
        dim: value stored as `partition_dim`.
        stride: value stored as `partition_stride`.

    An assertion fails if any attribute named in
    `_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS` is already present on `tensor`.
    """
    # Refuse to overwrite attributes that were set earlier.
    assert not any(hasattr(tensor, name)
                   for name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS)
    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride


def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    """Fill in missing model-parallel attributes on `tensor`.

    Every attribute named in `_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS` that
    `tensor` does not have yet is set to its default value; attributes that
    already exist are left untouched.
    """
    for name, default in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS.items():
        if not hasattr(tensor, name):
            setattr(tensor, name, default)


def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    """Copy the model-parallel attributes from one tensor to another.

    Only the attributes named in `_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS` that
    are actually present on `source_tensor` are copied; anything missing on
    the source is skipped.
    """
    for name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        if not hasattr(source_tensor, name):
            continue
        setattr(destination_tensor, name, getattr(source_tensor, name))


84
85
86
87
def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize an affine weight for model parallel on GPU.

    Arguments:
        weight: tensor to tag and initialize in place.
        init_method: callable invoked as `init_method(weight)`.
        partition_dim: stored on the weight as `partition_dim`.
        stride: stored on the weight as `partition_stride`.
    """
    # Mark the weight as tensor-model-parallel before initializing it.
    set_tensor_model_parallel_attributes(weight, True, partition_dim, stride)

    # Run the initializer inside the CUDA RNG tracker's forked context.
    rng_tracker = get_cuda_rng_tracker()
    with rng_tracker.fork():
        init_method(weight)


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Initialize an affine weight for model parallel on CPU.

    The full (master) weight is built on every process and the chunks that
    belong to this rank are copied into `weight`.

    Arguments:
        weight: destination tensor for this rank's partition.
        output_size: first dimension of the master weight.
        input_size: second dimension of the master weight.
        per_partition_size: size of this rank's partition along
                            `partition_dim`.
        partition_dim: dimension along which the master weight is split.
        init_method: callable invoked as `init_method(master_weight)`.
        stride: number of chunks each rank's partition is made of.
        return_master_weight: if true, return the master weight.

    Returns:
        The master weight when `return_master_weight` is true, else None.
    """
    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Initialize the master weight in fp32, then cast to the configured
    # parameter dtype.
    full_weight = torch.empty(output_size, input_size,
                              dtype=torch.float,
                              requires_grad=False)
    init_method(full_weight)
    full_weight = full_weight.to(dtype=get_args().params_dtype)

    # Split into chunks and keep every world_size-th chunk starting at
    # this rank's index.
    chunk_size = divide(per_partition_size, stride)
    chunks = torch.split(full_weight, chunk_size, dim=partition_dim)
    tp_rank = get_tensor_model_parallel_rank()
    tp_world_size = get_tensor_model_parallel_world_size()
    local_chunks = chunks[tp_rank::tp_world_size]

    # Write the concatenated local chunks straight into `weight`.
    with torch.no_grad():
        torch.cat(local_chunks, dim=partition_dim, out=weight)

    return full_weight if return_master_weight else None


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super().__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the defaults for compatibility with torch.nn.Embedding.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()

        # Divide the weight matrix along the vocabulary dimension.
        vocab_range = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, get_tensor_model_parallel_rank(),
            self.tensor_model_parallel_size)
        self.vocab_start_index, self.vocab_end_index = vocab_range
        self.num_embeddings_per_partition = \
            self.vocab_end_index - self.vocab_start_index

        # Allocate weights and initialize.
        args = get_args()
        local_shape = (self.num_embeddings_per_partition, self.embedding_dim)
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                *local_shape, dtype=args.params_dtype))
            _initialize_affine_weight_cpu(
                self.weight, self.num_embeddings, self.embedding_dim,
                self.num_embeddings_per_partition, 0, init_method)
        else:
            self.weight = Parameter(torch.empty(
                *local_shape,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=1)

    def forward(self, input_):
        """Look up the embeddings of `input_` and reduce across partitions."""
        partitioned = self.tensor_model_parallel_size > 1
        if partitioned:
            # Ids outside this partition's [start, end) vocabulary range.
            out_of_range = (input_ < self.vocab_start_index) | \
                           (input_ >= self.vocab_end_index)
            # Shift ids to local indices and point foreign ids at row 0.
            local_ids = input_.clone() - self.vocab_start_index
            local_ids[out_of_range] = 0
        else:
            local_ids = input_

        # Get the embeddings.
        output_parallel = F.embedding(local_ids, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Zero the rows that were looked up for foreign ids.
        if partitioned:
            output_parallel[out_of_range, :] = 0.0

        # Reduce across all the model parallel GPUs.
        return reduce_from_tensor_model_parallel_region(output_parallel)


206
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    """
    Linear layer execution with asynchronous communication and gradient accumulation
    fusion in backprop.

    Forward computes `matmul(total_input, weight.t())` (plus `bias` when it is
    not None), where `total_input` is `input` itself or, when
    `model_parallel_memory_opt` is set, the all-gather of `input` along its
    first dimension over the tensor-model-parallel group.

    Backward returns gradients for `input`, `weight` and `bias` followed by
    three `None`s for the flag arguments. It expects `grad_output` and the
    (gathered) input to be 3D, since both are flattened to 2D with `view`.
    """

    @staticmethod
    def _gathered_input_buffer(input):
        """Return a tensor that can hold the all-gather of `input`.

        The module-level `_TOTAL_INPUT` buffer is reused when it exists and
        matches the required shape, dtype and device; otherwise a fresh,
        correctly sized tensor is allocated so the gather never writes into
        a missing or mismatched buffer.
        """
        world_size = get_tensor_model_parallel_world_size()
        dim_size = list(input.size())
        dim_size[0] = dim_size[0] * world_size

        shared = _TOTAL_INPUT
        if (shared is not None
                and list(shared.size()) == dim_size
                and shared.dtype == input.dtype
                and shared.device == input.device):
            return shared
        return torch.empty(dim_size, dtype=input.dtype,
                           device=input.device,
                           requires_grad=False)

    @staticmethod
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                async_grad_allreduce, model_parallel_memory_opt):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.model_parallel_memory_opt = model_parallel_memory_opt

        if model_parallel_memory_opt:
            # Gather the first-dimension shards of the input from all ranks.
            total_input = LinearWithGradAccumulationAndAsyncCommunication \
                ._gathered_input_buffer(input)
            torch.distributed._all_gather_base(
                total_input, input,
                group=get_tensor_model_parallel_group())
        else:
            total_input = input

        output = torch.matmul(total_input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        if ctx.model_parallel_memory_opt:
            # Only the local shard was saved, so gather the full input again.
            total_input = LinearWithGradAccumulationAndAsyncCommunication \
                ._gathered_input_buffer(input)
            handle = torch.distributed._all_gather_base(
                total_input, input,
                group=get_tensor_model_parallel_group(), async_op=True)
            # Delay the start of input gradient computation shortly (3us) to
            # have gather scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1
        else:
            total_input = input
        grad_input = grad_output.matmul(weight)

        if ctx.model_parallel_memory_opt:
            handle.wait()

        # Convert the tensor shapes to 2D for execution compatibility
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
                                       grad_output.shape[2])
        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
                                       total_input.shape[2])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.model_parallel_memory_opt:
            assert not ctx.async_grad_allreduce
            dim_size = list(input.size())
            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
                                         device=torch.cuda.current_device(),
                                         requires_grad=False)
            # reduce_scatter
            handle = torch.distributed._reduce_scatter_base(
                sub_grad_input, grad_input,
                group=get_tensor_model_parallel_group(),
                async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # reduce scatter scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.gradient_accumulation_fusion:
            # Import lazily: the extension is only needed on the fused path,
            # so backward must not fail when it is absent and fusion is off.
            import fused_dense_cuda
            fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output,
                                                   weight.main_grad)
            grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.model_parallel_memory_opt:
            handle.wait()
            return sub_grad_input, grad_weight, grad_bias, None, None, None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None
314
315


316
317
318
319
320
321
322
323
324
325
class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other elementwise operations. We
                       skip adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=args.params_dtype))
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.output_size_per_partition, 0, init_method,
                stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=stride)

        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        # Both communication optimizations only apply with more than one
        # tensor-model-parallel rank, and they are mutually exclusive.
        self.async_tensor_model_parallel_allreduce = (
                args.async_tensor_model_parallel_allreduce and
                world_size > 1)
        self.model_parallel_memory_opt = (
                args.model_parallel_memory_opt and
                world_size > 1)
        assert not self.async_tensor_model_parallel_allreduce or \
            not self.model_parallel_memory_opt
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

        # Allocate the shared all-gather buffer once, and only when the
        # memory optimization that reads it is enabled. It uses the
        # configured parameter dtype instead of a hard-coded bfloat16 so it
        # also matches fp16/fp32 runs.
        # NOTE(review): the buffer shape assumes the gathered input is
        # (seq_length, micro_batch_size, hidden_size) -- confirm for layers
        # whose input size differs from args.hidden_size.
        global _TOTAL_INPUT
        if self.model_parallel_memory_opt and _TOTAL_INPUT is None:
            _TOTAL_INPUT = torch.empty(
                (args.seq_length, args.micro_batch_size, args.hidden_size),
                dtype=args.params_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False)

    def forward(self, input_):
        """Apply the column-parallel linear layer.

        Returns:
            A tuple `(output, output_bias)`. `output_bias` is the bias
            parameter when `skip_bias_add` is set (the bias is then not
            added to `output`), otherwise None.
        """
        bias = self.bias if not self.skip_bias_add else None

        if self.async_tensor_model_parallel_allreduce or \
                self.model_parallel_memory_opt:
            # The autograd function handles the communication itself.
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
            input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
            self.async_tensor_model_parallel_allreduce, self.model_parallel_memory_opt)
        if self.gather_output:
            # All-gather across the partitions.
            assert not self.model_parallel_memory_opt
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimization where bias
                       can be fused with other elementwise operations. We skip
                       adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super().__init__()

        # Remember the constructor arguments.
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Each rank owns an equal slice of the input dimension.
        tp_world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, tp_world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        args = get_args()
        on_cpu = args.use_cpu_initialization

        # Allocate and initialize the weight.
        if on_cpu:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                dtype=args.params_dtype))
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.input_size_per_partition, 1, init_method,
                stride=stride,
                return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=torch.cuda.current_device(),
                dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=1, stride=stride)

        if not bias:
            self.register_parameter('bias', None)
        else:
            if on_cpu:
                self.bias = Parameter(torch.empty(
                    self.output_size, dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size,
                    device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            # Tag the bias with the memory-optimization flag.
            self.bias.sequence_parallel = args.model_parallel_memory_opt
            # The bias always starts at zero.
            with torch.no_grad():
                self.bias.zero_()

        self.model_parallel_memory_opt = args.model_parallel_memory_opt
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
        """Apply the row-parallel linear layer.

        Returns:
            A tuple `(output, output_bias)`. When `skip_bias_add` is set the
            bias is not added and is returned as `output_bias`; otherwise
            `output_bias` is None.
        """
        # Split the input across ranks unless the caller already did.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            assert not self.model_parallel_memory_opt
            input_parallel = scatter_to_tensor_model_parallel_region(input_)

        # Matrix multiply; the bias is handled after the reduction below.
        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
            input_parallel, self.weight, None,
            self.gradient_accumulation_fusion, None, None)

        # Combine the partial results from all the partitions.
        if self.model_parallel_memory_opt:
            reduced = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            reduced = reduce_from_tensor_model_parallel_region(output_parallel)

        if self.skip_bias_add:
            return reduced, self.bias
        if self.bias is None:
            return reduced, None
        return reduced + self.bias, None