layers.py 22.3 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch


import math

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

28
29
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
30
from .initialize import get_tensor_model_parallel_group
31
from .mappings import copy_to_tensor_model_parallel_region
32
33
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import gather_from_sequence_parallel_region
34
from .mappings import reduce_from_tensor_model_parallel_region
35
36
from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
37

38
39
40
41
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
42
from megatron import get_args
43

mohammad's avatar
mohammad committed
44
45
46
47
# Default value for each tensor-model-parallel attribute that may be attached
# to a tensor. The helpers in this module iterate over these keys in order.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = dict(
    tensor_model_parallel=False,
    partition_dim=-1,
    partition_stride=1,
)

mohammad's avatar
mohammad committed
48
49
50
51
52
53
def param_is_not_tensor_parallel_duplicate(param):
    """Return a truthy value if `param` is not a tensor-parallel duplicate.

    If `param` carries a truthy `tensor_model_parallel` attribute, that value
    is returned. Otherwise the result is whether this process is
    tensor-model-parallel rank 0.
    """
    marked_parallel = getattr(param, 'tensor_model_parallel', False)
    return marked_parallel or get_tensor_model_parallel_rank() == 0


mohammad's avatar
mohammad committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    """Attach tensor-model-parallel metadata to `tensor`.

    Asserts that none of the keys of `_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS` is
    already an attribute of `tensor`. It then sets, in order,
    `tensor_model_parallel`, `partition_dim` and `partition_stride`.
    """
    # Never silently overwrite metadata attached earlier.
    for attr_name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attr_name)
    new_values = (
        ('tensor_model_parallel', is_parallel),
        ('partition_dim', dim),
        ('partition_stride', stride),
    )
    for attr_name, attr_value in new_values:
        setattr(tensor, attr_name, attr_value)


def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    """Give `tensor` the default value of every missing tensor-parallel attribute.

    Attributes that already exist on `tensor` are left unchanged.
    """
    for attr_name, default in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS.items():
        if hasattr(tensor, attr_name):
            continue
        setattr(tensor, attr_name, default)


def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    """Copy the tensor-parallel attributes present on `source_tensor` to
    `destination_tensor`.

    Attributes that `source_tensor` does not have are skipped.
    """
    for attr_name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        if not hasattr(source_tensor, attr_name):
            continue
        setattr(destination_tensor, attr_name,
                getattr(source_tensor, attr_name))


81
82
83
84
def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize a model-parallel affine weight in place.

    First tags `weight` as tensor-model-parallel along `partition_dim` with
    the given `stride`. Then runs `init_method(weight)` inside a fork of the
    CUDA RNG tracker.
    """
    set_tensor_model_parallel_attributes(
        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)

    # Draw the random values under the tracked RNG state (presumably so the
    # per-rank streams stay reproducible -- confirm in .random).
    tracker = get_cuda_rng_tracker()
    with tracker.fork():
        init_method(weight)


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Initialize a model-parallel affine weight from a full master copy.

    Every process builds the full float32 master weight of shape
    (output_size, input_size). It applies `init_method`, casts the result to
    `get_args().params_dtype`, and copies this rank's chunks into `weight`.

    Arguments:
        weight: destination tensor for this rank's partition (filled in place).
        output_size: first dimension of the master weight.
        input_size: second dimension of the master weight.
        per_partition_size: size of this rank's partition along
                            `partition_dim`.
        partition_dim: dimension along which the master weight is split.
        init_method: callable that initializes a tensor in place.
        stride: number of chunks each partition is made of; chunks are
                interleaved across ranks.
        return_master_weight: if True, return the full master weight.

    Returns:
        The (cast) master weight when `return_master_weight` is True,
        otherwise None.
    """
    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Build and initialize the full master weight, then cast it.
    master = torch.empty(output_size, input_size,
                         dtype=torch.float, requires_grad=False)
    init_method(master)
    master = master.to(dtype=get_args().params_dtype)

    # Cut the master weight into chunks of per_partition_size / stride rows
    # (or columns). Take every world_size-th chunk, starting at this rank.
    chunk_size = divide(per_partition_size, stride)
    chunks = torch.split(master, chunk_size, dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    own_chunks = chunks[rank::world_size]

    with torch.no_grad():
        torch.cat(own_chunks, dim=partition_dim, out=weight)
    return master if return_master_weight else None


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding layer sharded along the vocabulary dimension.

    Adapted from torch.nn.Embedding, keeping all of its default lookup
    options. Each tensor-model-parallel rank holds the rows
    [vocab_start_index, vocab_end_index) of the full table. Token ids outside
    that range produce zero rows locally. The per-rank results are then
    combined with reduce_from_tensor_model_parallel_region.

    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
        # Remember the full (unsharded) table dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # torch.nn.Embedding's default options, kept for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()

        # Determine the slice of the vocabulary owned by this rank.
        rank = get_tensor_model_parallel_rank()
        begin, end = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, rank, self.tensor_model_parallel_size)
        self.vocab_start_index = begin
        self.vocab_end_index = end
        self.num_embeddings_per_partition = end - begin

        # Allocate this rank's shard of the table and initialize it.
        args = get_args()
        shard_rows = self.num_embeddings_per_partition
        if not args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                shard_rows, self.embedding_dim,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=1)
        else:
            self.weight = Parameter(torch.empty(
                shard_rows, self.embedding_dim, dtype=args.params_dtype))
            _initialize_affine_weight_cpu(
                self.weight, self.num_embeddings, self.embedding_dim,
                shard_rows, 0, init_method)

    def forward(self, input_):
        sharded = self.tensor_model_parallel_size > 1
        if sharded:
            # Token ids that belong to another rank's shard.
            foreign = (input_ < self.vocab_start_index) | \
                      (input_ >= self.vocab_end_index)
            # Shift ids to local row indices; foreign ids are pointed at row 0.
            local_ids = input_.clone() - self.vocab_start_index
            local_ids[foreign] = 0
        else:
            local_ids = input_

        # Look up the local embeddings.
        output_parallel = F.embedding(local_ids, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Zero out the rows that were looked up for foreign ids.
        if sharded:
            output_parallel[foreign, :] = 0.0
        # Reduce across all the model parallel GPUs.
        return reduce_from_tensor_model_parallel_region(output_parallel)


203
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    """
    Linear layer execution with asynchronous communication and gradient
    accumulation fusion in backprop.

    forward(input, weight, bias, gradient_accumulation_fusion,
            async_grad_allreduce, sequence_parallel) returns
    ``input @ weight.t()``, plus ``bias`` when it is not None.

    * sequence_parallel: the input is all-gathered along dim 0 across the
      tensor-model-parallel group before the matmul. In backward the input is
      re-gathered, and the input gradient is reduce-scattered back along dim 0.
    * async_grad_allreduce: in backward the input gradient is all-reduced
      asynchronously while the weight gradient is computed. Must not be
      combined with sequence_parallel (asserted in backward).
    * gradient_accumulation_fusion: the weight gradient is handed to
      ``fused_dense_cuda.wgrad_gemm_accum_fp32`` together with
      ``weight.main_grad`` (accumulation implied by the kernel name --
      confirm), and None is returned as the weight gradient.

    Inputs of any rank >= 1 are accepted. All leading dimensions are
    flattened when computing the weight and bias gradients.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                async_grad_allreduce, sequence_parallel):
        # Only the local (un-gathered) input is saved; backward re-gathers it.
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.sequence_parallel = sequence_parallel

        if sequence_parallel:
            # Gather the sequence-parallel shards along dim 0.
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = \
                torch.empty(dim_size, dtype=input.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)
            torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=get_tensor_model_parallel_group())
            total_input = all_gather_buffer
        else:
            total_input = input

        output = torch.matmul(total_input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        if ctx.sequence_parallel:
            # Re-gather the full input asynchronously; it is only needed for
            # the weight gradient, which is computed after handle.wait().
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = \
                torch.empty(dim_size, dtype=input.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)
            handle = torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=get_tensor_model_parallel_group(), async_op=True)

            # Delay the start of input gradient computation shortly (3us) to have
            # gather scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1
            total_input = all_gather_buffer
        else:
            total_input = input
        grad_input = grad_output.matmul(weight)

        if ctx.sequence_parallel:
            handle.wait()

        # Flatten every leading dimension so the weight/bias gradients are
        # plain 2D GEMMs. This works for inputs of any rank, not only 3D.
        # grad_output or the saved input may be non-contiguous (e.g. produced
        # by a transpose, or an expanded gradient), which .view() rejects.
        # Make them contiguous first; this is a no-op for contiguous tensors.
        grad_output = grad_output.contiguous()
        grad_output = grad_output.view(-1, grad_output.shape[-1])
        total_input = total_input.contiguous()
        total_input = total_input.view(-1, total_input.shape[-1])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = torch.distributed.all_reduce(
                    grad_input, group=get_tensor_model_parallel_group(), async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.sequence_parallel:
            assert not ctx.async_grad_allreduce
            dim_size = list(input.size())
            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
                                         device=torch.cuda.current_device(),
                                         requires_grad=False)
            # reduce_scatter
            handle = torch.distributed._reduce_scatter_base(
                sub_grad_input, grad_input,
                group=get_tensor_model_parallel_group(),
                async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # reduce scatter scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.gradient_accumulation_fusion:
            import fused_dense_cuda
            fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
            grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.sequence_parallel:
            handle.wait()
            return sub_grad_input, grad_weight, grad_bias, None, None, None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None
316
317


318
319
320
321
322
323
324
325
326
327
class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p]. Each rank stores the
    transpose of its A_i, with shape (output_size_per_partition, input_size).

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, all-gather the output so that Y is available
                       on all GPUs; otherwise every GPU keeps its own
                       Y_i = XA_i.
        init_method: method to initialize weights. The bias is always
                     initialized to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: Added for testing and should be False.
                                     With CPU initialization, True makes
                                     self.master_weight hold the full master
                                     weight used for initialization.
        skip_bias_add: Added to enable performance optimizations where the
                       bias can be fused with other elementwise operations.
                       The bias is not added here but returned instead.

    forward(input_) returns a tuple (output, output_bias).
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # The output dimension is split evenly across tensor-parallel ranks.
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        args = get_args()
        on_cpu = args.use_cpu_initialization

        def allocate(*shape):
            # Uninitialized tensor in params_dtype, on the current CUDA
            # device unless CPU initialization was requested.
            if on_cpu:
                return torch.empty(*shape, dtype=args.params_dtype)
            return torch.empty(*shape, device=torch.cuda.current_device(),
                               dtype=args.params_dtype)

        # torch.nn.functional.linear computes XA^T + b, so the transpose of
        # this rank's slice of A is allocated.
        self.weight = Parameter(allocate(self.output_size_per_partition,
                                         self.input_size))
        if on_cpu:
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.output_size_per_partition, 0, init_method,
                stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=stride)

        # The bias is split like the output dimension and starts at zero.
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(allocate(self.output_size_per_partition))
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            with torch.no_grad():
                self.bias.zero_()

        multi_rank = world_size > 1
        self.async_tensor_model_parallel_allreduce = (
            args.async_tensor_model_parallel_allreduce and multi_rank)
        self.sequence_parallel = args.sequence_parallel and multi_rank
        # The two communication modes cannot be enabled together.
        assert not self.async_tensor_model_parallel_allreduce or \
            not self.sequence_parallel
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
        # With skip_bias_add the bias is returned to the caller instead of
        # being added inside the matmul.
        if self.skip_bias_add:
            bias_to_apply, bias_to_return = None, self.bias
        else:
            bias_to_apply, bias_to_return = self.bias, None

        # In async-allreduce and sequence-parallel modes the autograd
        # function does the required communication itself, so the input is
        # passed through unchanged.
        if self.async_tensor_model_parallel_allreduce or self.sequence_parallel:
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)

        # Matrix multiply.
        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
            input_parallel, self.weight, bias_to_apply,
            self.gradient_accumulation_fusion,
            self.async_tensor_model_parallel_allreduce, self.sequence_parallel)

        if not self.gather_output:
            return output_parallel, bias_to_return

        # All-gather across the partitions.
        assert not self.sequence_parallel
        gathered = gather_from_tensor_model_parallel_region(output_parallel)
        return gathered, bias_to_return
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, the input is assumed to be already split
                           across the GPUs and is not split again.
        init_method: method to initialize weights. The bias is always
                     initialized to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: Added for testing and should be False.
                                     With CPU initialization, True makes
                                     self.master_weight hold the full master
                                     weight used for initialization.
        skip_bias_add: Added to enable performance optimization where the
                       bias can be fused with other elementwise operations.
                       The bias is not added here but returned instead.

    forward(input_) returns a tuple (output, output_bias).
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # The input dimension is split evenly across tensor-parallel ranks.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add

        args = get_args()
        on_cpu = args.use_cpu_initialization

        def allocate(*shape):
            # Uninitialized tensor in params_dtype, on the current CUDA
            # device unless CPU initialization was requested.
            if on_cpu:
                return torch.empty(*shape, dtype=args.params_dtype)
            return torch.empty(*shape, device=torch.cuda.current_device(),
                               dtype=args.params_dtype)

        # torch.nn.functional.linear computes XA^T + b, so the transpose of
        # this rank's slice of A is allocated.
        self.weight = Parameter(allocate(self.output_size,
                                         self.input_size_per_partition))
        if on_cpu:
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.input_size_per_partition, 1, init_method,
                stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=1, stride=stride)

        # The bias is full-sized (not split) and starts at zero.
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(allocate(self.output_size))
            # Tag the bias with the sequence-parallel setting; presumably read
            # by gradient-reduction code outside this file -- confirm.
            self.bias.sequence_parallel = args.sequence_parallel
            with torch.no_grad():
                self.bias.zero_()

        self.sequence_parallel = args.sequence_parallel
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
        # Split the input across ranks unless the caller already did.
        if not self.input_is_parallel:
            assert not self.sequence_parallel
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        else:
            input_parallel = input_

        # Local matmul: no bias, no async all-reduce, no sequence gather.
        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
            input_parallel, self.weight, None,
            self.gradient_accumulation_fusion, None, None)

        # Combine the partial products across the partitions.
        if self.sequence_parallel:
            combined = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            combined = reduce_from_tensor_model_parallel_region(output_parallel)

        if self.skip_bias_add:
            return combined, self.bias
        if self.bias is None:
            return combined, None
        return combined + self.bias, None