# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch


import math

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
from megatron import get_args


_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                      'partition_dim': -1,
                                      'partition_stride': 1}


def param_is_not_tensor_parallel_duplicate(param):
    return (hasattr(param, 'tensor_model_parallel') and
            param.tensor_model_parallel) or (
                get_tensor_model_parallel_rank() == 0)


def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    # Make sure the attributes are not set.
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attribute)
    # Set the attributes.
    setattr(tensor, 'tensor_model_parallel', is_parallel)
    setattr(tensor, 'partition_dim', dim)
    setattr(tensor, 'partition_stride', stride)


def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    def maybe_set(attribute, value):
        if not hasattr(tensor, attribute):
            setattr(tensor, attribute, value)
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])


def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    def maybe_copy(attribute):
        if hasattr(source_tensor, attribute):
            setattr(destination_tensor, attribute,
                    getattr(source_tensor, attribute))
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_copy(attribute)
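
# Illustrative sketch (not part of the module) of how the attribute helpers
# above behave: a freshly created parameter carries no parallel metadata, so
# consumers first apply the defaults and can then copy metadata between
# tensors:
#
#     p = Parameter(torch.empty(4, 4))
#     set_defaults_if_not_set_tensor_model_parallel_attributes(p)
#     p.tensor_model_parallel    # -> False (not partitioned)
#     p.partition_dim            # -> -1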


def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

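    # Draw the initialization from the model-parallel branch of the CUDA RNG
    # tracker so each tensor-parallel rank gets different, reproducible values.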
    with get_cuda_rng_tracker().fork():
        init_method(weight)


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    args = get_args()
    master_weight = master_weight.to(dtype=args.params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None
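
# Illustrative sketch of the strided split above (hypothetical sizes): with
# per_partition_size=4, stride=2, world_size=2, the master weight is cut into
# four chunks of size 2 along partition_dim, and weight_list[rank::world_size]
# gives rank 0 chunks [0, 2] and rank 1 chunks [1, 3].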


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

        # Allocate weights and initialize.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_cpu(
                    self.weight, self.num_embeddings, self.embedding_dim,
                    self.num_embeddings_per_partition, 0, init_method)
        else:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=0, stride=1)

    def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output
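
# Usage sketch (hypothetical sizes; assumes the tensor-model-parallel group
# has been initialized):
#
#     embedding = VocabParallelEmbedding(50304, 1024)
#     tokens = torch.randint(0, 50304, (8, 512), device='cuda')
#     hidden = embedding(tokens)  # [8, 512, 1024], reduced across TP ranks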


class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
    """
    Column-parallel linear layer with asynchronous all-reduce
    of the input gradient in backprop.
    """
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        output = torch.matmul(input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
        grad_input = grad_output.matmul(weight)
        # Asynchronous all-reduce
        handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
        # Delay the start of the weight gradient computation briefly (~3 us) so
        # the all-reduce is scheduled first and gets its GPU resources allocated.
        _ = torch.empty(1, device=grad_output.device) + 1
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None
        handle.wait()
        return grad_input, grad_weight, grad_bias
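
# Note: ColumnParallelLinearWithAsyncAllreduce.apply expects a 2D input so
# that backward can compute grad_weight = grad_output^T @ input with a single
# GEMM; ColumnParallelLinear.forward below flattens the [sequence, batch,
# hidden] activation around the call and restores its shape afterwards.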


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs; otherwise, every GPU will have its own
                       output, which is Y_i = XA_i.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other elementwise operations. We
                       skip adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=args.params_dtype))
            if args.perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.output_size_per_partition, 0, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=0, stride=stride)

        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        self.async_tensor_model_parallel_allreduce = (
                not args.no_async_tensor_model_parallel_allreduce and
                world_size > 1)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        if self.async_tensor_model_parallel_allreduce:
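            # Flatten the leading dimensions to 2D: the custom autograd
            # function computes grad_weight with a single 2D GEMM in backward.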
            input_shape = input_.shape
            input_ = input_.view(
                    input_shape[0] * input_shape[1], input_shape[2])
            # Matrix multiply with asynchronous all-reduce execution
            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
                    input_, self.weight, bias)
            output_parallel = output_parallel.view(
                    input_shape[0], input_shape[1], output_parallel.shape[1])
        else:
            # Set up backprop all-reduce.
            input_parallel = copy_to_tensor_model_parallel_region(input_)

            # Matrix multiply.
            output_parallel = F.linear(input_parallel, self.weight, bias)

        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
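
# Usage sketch (hypothetical sizes; assumes the tensor-model-parallel group
# has been initialized):
#
#     linear = ColumnParallelLinear(1024, 4096, gather_output=False,
#                                   skip_bias_add=True)
#     y_partition, bias = linear(x)  # bias is returned for later fusion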


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimization where bias
                       can be fused with other elementwise operations. We skip
                       adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=args.params_dtype))
            if args.perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.input_size_per_partition, 1, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=1, stride=stride)
        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size, device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight)
        # All-reduce across all the partitions.
        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
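
# Usage sketch (hypothetical sizes): the canonical pairing in a parallel MLP
# keeps activations partitioned between the two GEMMs:
#
#     h_to_4h = ColumnParallelLinear(1024, 4096, gather_output=False)
#     four_h_to_h = RowParallelLinear(4096, 1024, input_is_parallel=True)
#     output, _ = four_h_to_h(F.gelu(h_to_4h(hidden_states)[0]))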