# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch


import math

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

28
29
30
31
32
33
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
    # Importing alone is not treated as sufficient: build a throwaway
    # instance so that a failure while constructing FusedLayerNorm is caught
    # here and the torch fallback below is selected instead.
    _ = LayerNorm(8, eps=1e-5)

except Exception:
    # Deliberately broad: any failure above (import or construction) means
    # the fused implementation cannot be used, so fall back to torch.nn.
    print('WARNING: APEX is not installed, using torch.nn.LayerNorm '
          'instead of apex.normalization.FusedLayerNorm!')
    from torch.nn import LayerNorm
37

38
39
40
41
42
43
from .initialize import get_intra_layer_model_parallel_rank
from .initialize import get_intra_layer_model_parallel_world_size
from .mappings import copy_to_intra_layer_model_parallel_region
from .mappings import gather_from_intra_layer_model_parallel_region
from .mappings import reduce_from_intra_layer_model_parallel_region
from .mappings import scatter_to_intra_layer_model_parallel_region
44
45
46
47
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
48
from megatron import get_args
49

50
51
52
53
def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Tag ``weight`` as model parallel and initialize it in place on GPU.

    Arguments:
        weight: tensor/parameter holding this rank's partition; it is
                annotated and then filled by ``init_method``.
        init_method: callable invoked as ``init_method(weight)``.
        partition_dim: dimension along which the full weight is partitioned;
                       stored on the tensor as ``partition_dim``.
        stride: stored on the tensor as ``partition_stride``.
    """
    # Attach the partitioning metadata directly to the tensor.
    partition_attributes = (
        ('intra_layer_model_parallel', True),
        ('partition_dim', partition_dim),
        ('partition_stride', stride),
    )
    for attr_name, attr_value in partition_attributes:
        setattr(weight, attr_name, attr_value)

    # Run the initializer inside the forked CUDA RNG tracker context.
    # NOTE(review): presumably this makes each model-parallel rank draw from
    # its own RNG stream -- confirm against get_cuda_rng_tracker().
    rng_tracker = get_cuda_rng_tracker()
    with rng_tracker.fork():
        init_method(weight)


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Initialize affine weight for model parallel.

    Build the full (master) weight on every process, then copy the
    chunk(s) belonging to this rank into ``weight``.

    Arguments:
        weight: destination tensor/parameter for this rank's partition; it is
                annotated with partitioning metadata and overwritten in place.
        output_size: first dimension of the master weight.
        input_size: second dimension of the master weight.
        per_partition_size: size of this rank's partition along
                            ``partition_dim``.
        partition_dim: dimension along which the master weight is split.
        init_method: callable invoked as ``init_method(master_weight)``.
        stride: number of interleaved chunks each rank owns; the master
                weight is split into pieces of size
                ``per_partition_size / stride``.
        return_master_weight: if true, return the master weight (already
                              cast to ``args.params_dtype``).

    Returns:
        The master weight when ``return_master_weight`` is true, else None.
    """

    weight.intra_layer_model_parallel = True
    weight.partition_dim = partition_dim
    weight.partition_stride = stride

    # Initialize master weight in fp32, then cast to the parameter dtype.
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    args = get_args()
    master_weight = master_weight.to(dtype=args.params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    # Use the intra-layer rank so it matches the intra-layer world size
    # below (and the name this module actually imports).
    rank = get_intra_layer_model_parallel_rank()
    world_size = get_intra_layer_model_parallel_world_size()
    # With stride > 1 this rank owns every world_size-th chunk.
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding table split across ranks along the vocabulary dimension.

    Adapted from torch.nn.Embedding; the optional embedding settings are
    fixed to that module's defaults.

    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
        # Remember the global (unpartitioned) dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Fixed defaults, kept for compatibility with F.embedding.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.intra_layer_model_parallel_size = \
            get_intra_layer_model_parallel_world_size()

        # Work out which slice of the vocabulary this rank owns.
        vocab_range = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings, get_intra_layer_model_parallel_rank(),
            self.intra_layer_model_parallel_size)
        self.vocab_start_index, self.vocab_end_index = vocab_range
        self.num_embeddings_per_partition = (
            self.vocab_end_index - self.vocab_start_index)

        # Allocate the local shard and initialize it.
        args = get_args()
        cpu_init = args.use_cpu_initialization
        alloc_kwargs = {'dtype': args.params_dtype}
        if not cpu_init:
            alloc_kwargs['device'] = torch.cuda.current_device()

        self.weight = Parameter(torch.empty(
            self.num_embeddings_per_partition, self.embedding_dim,
            **alloc_kwargs))
        if cpu_init:
            _initialize_affine_weight_cpu(
                self.weight, self.num_embeddings, self.embedding_dim,
                self.num_embeddings_per_partition, 0, init_method)
        else:
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=1)

    def forward(self, input_):
        """Look up embeddings for ``input_`` and reduce across partitions.

        Ids outside this rank's vocabulary range are looked up at local
        index 0 and their output rows are zeroed before the reduction.
        """
        partitioned = self.intra_layer_model_parallel_size > 1
        if partitioned:
            # True where the id does not belong to this rank's slice.
            out_of_range = (input_ < self.vocab_start_index) | \
                           (input_ >= self.vocab_end_index)
            # Shift to local indices; park foreign ids on a valid row.
            local_ids = input_.clone() - self.vocab_start_index
            local_ids[out_of_range] = 0
        else:
            local_ids = input_

        # Local embedding lookup.
        output_parallel = F.embedding(local_ids, self.weight,
                                      padding_idx=self.padding_idx,
                                      max_norm=self.max_norm,
                                      norm_type=self.norm_type,
                                      scale_grad_by_freq=self.scale_grad_by_freq,
                                      sparse=self.sparse)
        if partitioned:
            # Rows looked up for foreign ids must not contribute.
            output_parallel[out_of_range, :] = 0.0

        # Combine the per-rank results.
        return reduce_from_intra_layer_model_parallel_region(output_parallel)


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. When true and CPU
                                     initialization is used, the master weight
                                     used for initialization is kept in
                                     ``self.master_weight``.
        skip_bias_add: This was added to enable performance optimizations
                       where bias can be fused with other elementwise
                       operations. We skip adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = get_intra_layer_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=args.params_dtype))
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.output_size_per_partition, 0, init_method,
                stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=stride)

        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            # Tag the bias with the same partition metadata attributes the
            # weight initializers set. The stride is stored as
            # ``partition_stride``: assigning to ``stride`` would shadow the
            # built-in ``Tensor.stride()`` method on this parameter.
            self.bias.intra_layer_model_parallel = True
            self.bias.partition_dim = 0
            self.bias.partition_stride = stride
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        """Apply the partitioned linear layer.

        Returns:
            A tuple ``(output, output_bias)``. ``output_bias`` is
            ``self.bias`` when ``skip_bias_add`` is set (the bias is then not
            added to ``output``), otherwise None.
        """
        # Set up backprop all-reduce.
        input_parallel = copy_to_intra_layer_model_parallel_region(input_)
        # Matrix multiply.

        # When the bias add is skipped, the bias is handed back to the caller
        # instead of being applied here.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = F.linear(input_parallel, self.weight, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_intra_layer_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    Computes Y = XA + b where A is split along its first dimension and X
    along its second dimension:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, the input is taken to be already split
                           across the GPUs and is not split again.
        init_method: method to initialize weights. Note that bias is always
                     set to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: Testing aid that should normally be
                                     False. When true and CPU initialization
                                     is used, the master weight used for
                                     initialization is kept in
                                     ``self.master_weight``.
        skip_bias_add: When true the bias is not added to the output but is
                       returned alongside it, so it can be fused with other
                       elementwise operations.
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(RowParallelLinear, self).__init__()

        # Record the constructor arguments.
        self.input_size, self.output_size = input_size, output_size
        self.input_is_parallel = input_is_parallel
        # Each rank holds an equal slice of the input dimension.
        self.input_size_per_partition = divide(
            input_size, get_intra_layer_model_parallel_world_size())
        self.skip_bias_add = skip_bias_add

        # F.linear computes X W^T + b, so the weight is stored transposed:
        # (output_size, input_size_per_partition).
        args = get_args()
        cpu_init = args.use_cpu_initialization
        alloc_kwargs = {'dtype': args.params_dtype}
        if not cpu_init:
            alloc_kwargs['device'] = torch.cuda.current_device()

        self.weight = Parameter(torch.empty(
            self.output_size, self.input_size_per_partition, **alloc_kwargs))
        if cpu_init:
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.input_size_per_partition, 1, init_method,
                stride=stride,
                return_master_weight=keep_master_weight_for_test)
        else:
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=1, stride=stride)

        if not bias:
            self.register_parameter('bias', None)
        else:
            # The bias covers the full output dimension and starts at zero.
            self.bias = Parameter(torch.empty(self.output_size,
                                              **alloc_kwargs))
            with torch.no_grad():
                self.bias.zero_()

    def forward(self, input_):
        """Apply the partitioned linear layer.

        Returns:
            A tuple ``(output, output_bias)``. With ``skip_bias_add`` the
            bias is returned un-applied as ``output_bias``; otherwise it is
            added to ``output`` (when present) and ``output_bias`` is None.
        """
        # Split the input across ranks unless the caller already did.
        input_parallel = (
            input_ if self.input_is_parallel
            else scatter_to_intra_layer_model_parallel_region(input_))

        # Local partial product, then reduction across the partitions.
        partial_output = F.linear(input_parallel, self.weight)
        reduced = reduce_from_intra_layer_model_parallel_region(partial_output)

        if self.skip_bias_add:
            return reduced, self.bias
        if self.bias is None:
            return reduced, None
        return reduced + self.bias, None