# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from megatron.mpu import LayerNorm
from megatron.module import MegatronModule

from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu
29

30
31
32
33
34
# Flags required to enable jit fusion kernels: disable the profiling mode and
# the profiling executor, and allow kernel fusion on both CPU and GPU. The
# setters are invoked in the same order as individual calls would be; the loop
# names are deleted afterwards so they do not leak into the module namespace.
for _jit_flag_setter, _jit_flag_value in (
        (torch._C._jit_set_profiling_mode, False),
        (torch._C._jit_set_profiling_executor, False),
        (torch._C._jit_override_can_fuse_on_cpu, True),
        (torch._C._jit_override_can_fuse_on_gpu, True)):
    _jit_flag_setter(_jit_flag_value)
del _jit_flag_setter, _jit_flag_value
35
36
37
38
39
40
41
42
43
44
45

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
46
    Transformer takes input of size [s, b, h] and returns a
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
        attention_mask_func: a function that takes `unmaksed-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
               masked-attention-scores = attention_mask_func(
                                     unmaksed-attention-scores, attention-mask)
"""

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.

    Dropout is NOT applied here. Both projections are built with
    skip_bias_add=True, and forward() returns (output, output_bias) with the
    final bias not yet added. The caller adds the bias and applies dropout
    and the residual connection.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h. The bias is returned separately (skip_bias_add=True)
        # so that forward() can combine it with the activation.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            4 * args.hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # Activation selection. When bias_gelu_fusion is enabled, forward()
        # uses the fused bias+GeLU kernel and activation_func is not consulted.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h. The bias is again returned separately so the
        # caller can handle bias add, dropout and residual add together.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Apply the MLP.

        Args:
            hidden_states: [s, b, h] input.

        Returns:
            (output, output_bias): output is [s, b, h]. output_bias is the
            bias of dense_4h_to_h, not yet added to output.
        """
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = bias_gelu_impl(intermediate_parallel,
                                                   bias_parallel)
        else:
            intermediate_parallel = self.activation_func(
                intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelSelfAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h] and returns output
    of the same size. The output projection is built with skip_bias_add=True,
    so forward() returns (output, bias) with the bias not yet added.

    Shape notation used in forward(): sq is the query length and sk the key
    length. sk includes any cached keys passed via layer_past, so sk == sq
    when layer_past is None.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Query-key layer scaling always runs the softmax in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        # Clamped to >= 1 because layer_number is used below as a scaling
        # coefficient (a 0 would make norm_factor zero).
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            gather_output=False,
            init_method=init_method)

        # Raw scores are divided by norm_factor = sqrt(hn), and additionally
        # by layer_number when query-key layer scaling is on. In that case
        # layer_number is also passed as `coeff` to FusedScaleMaskSoftmax,
        # presumably to undo the extra division inside the fp32 softmax.
        # NOTE(review): confirm against fused_softmax.
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            args.scaled_upper_triang_masked_softmax_fusion,
            self.attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        """Run model-parallel self-attention.

        Args:
            hidden_states: [sq, b, h] input.
            attention_mask: mask handed to the scale/mask/softmax op. When
                get_key_value is set, it is first sliced along its last two
                dims to match the score tensor.
            layer_past: optional (past_key, past_value) pair, each
                [s_past, b, np, hn]. They are concatenated in front of the
                current keys/values along the sequence dim.
            get_key_value: if True, also return the (key, value) pair that
                was actually attended over.

        Returns:
            (output, bias): output is [sq, b, h], or [output, present] when
            get_key_value is set. bias is the output projection's bias, not
            yet added.
        """
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, hp] --> [sq, b, 3 * hp]
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        # [sq, b, 3 * hp] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer,
         key_layer,
         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # =====================================
        # Raw attention scores. [b, np, sq, sk]
        # =====================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # Preallocate the result tensor: [b * np, sq, sk]. Its (uninitialized)
        # contents are ignored because baddbmm is called with beta=0.0. It is
        # allocated on the inputs' device rather than assuming the current
        # CUDA device.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device)

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),                 # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),   # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ====================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ====================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    # With cached keys, only the mask row of the last key
                    # position is kept (unsqueezed to a single query row).
                    # NOTE(review): assumes one new query token per step.
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # Dropout runs under a fork of mpu's CUDA RNG tracker.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # ==========================
        # Context layer. [sq, b, hp]
        # ==========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # Context layer shape: [b, np, sq, hn]. The sequence dim must be the
        # QUERY length; value_layer.size(0) is sk, which differs from sq
        # whenever layer_past is given. query_layer is [sq, b * np, hn] here.
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # ==================
        # Output. [sq, b, h]
        # ==================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    """Return ``residual + dropout(x + bias)``.

    The ``# type:`` comment above is the TorchScript signature. TorchScript
    compiles this function when the ``@torch.jit.script`` wrappers that call
    it are defined.

    Args:
        x: tensor to which ``bias`` is added.
        bias: bias added to ``x`` (must broadcast against it).
        residual: tensor added after dropout.
        prob: dropout probability.
        training: whether dropout is active.
    """
    return residual + torch.nn.functional.dropout(
        x + bias, p=prob, training=training)


def get_bias_dropout_add(training):
    """Return a non-scripted bias-dropout-add callable with ``training`` bound.

    ParallelTransformerLayer uses this when bias-dropout fusion is disabled.
    The returned callable takes ``(x, bias, residual, prob)``, the same
    arguments as the jit-scripted train/inference variants.
    """
    is_training = training

    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, is_training)

    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob) :
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    """TorchScript-compiled bias_dropout_add with dropout active (training=True).

    Computes ``residual + dropout(x + bias, p=prob)``. ParallelTransformerLayer
    selects it when bias_dropout_fusion is enabled and the layer is training.
    """
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob) :
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    """TorchScript-compiled bias_dropout_add with dropout disabled (training=False).

    Computes ``residual + x + bias``; ``prob`` is accepted only for signature
    parity with the training variant. ParallelTransformerLayer selects it when
    bias_dropout_fusion is enabled and the layer is not training.
    """
    return bias_dropout_add(x, bias, residual, prob, False)
348
349
350
351
352
353
354
355


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size. ParallelTransformer transposes its [b, s, h]
    input to [s, b, h] before calling the layers.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.attention = ParallelSelfAttention(attention_mask_func, init_method,
                                               output_layer_init_method,
                                               layer_number)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output (input of the MLP).
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        """Apply attention and MLP sub-blocks, each with bias-dropout-residual.

        Args:
            hidden_states: [s, b, h] input.
            attention_mask: forwarded to the self-attention layer.
            layer_past: optional cached (key, value) pair for the attention.
            get_key_value: if True, also return the attention's (key, value).

        Returns:
            [s, b, h] output, or [output, presents] when get_key_value is set.
        """
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.attention(layernorm_output,
                           attention_mask,
                           layer_past=layer_past,
                           get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class.

    Runs ``num_layers`` ParallelTransformerLayer applications followed by a
    final LayerNorm. Only ``num_unique_layers`` layers own parameters.
    Positions in the stack are mapped onto them by ``param_sharing_style``
    ('grouped' or 'spaced'). Input and output are [b, s, h]; the layers run
    on [s, b, h] internally.
    """

    def __init__(self, attention_mask_func,
                 init_method, output_layer_init_method):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers:
        self.num_layers = args.num_layers
        self.num_unique_layers = args.num_unique_layers
        if self.num_unique_layers is None:
            self.num_unique_layers = self.num_layers
        assert self.num_layers % self.num_unique_layers == 0, \
            'number of layers should be divisible by number of unique layers'
        self.param_sharing_style = args.param_sharing_style

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                attention_mask_func, init_method,
                output_layer_init_method, layer_number)
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1) for i in range(self.num_unique_layers)])

        # Print layer ordering.
        if self.num_layers != self.num_unique_layers:
            if torch.distributed.get_rank() == 0:
                print('> will be using the following layer ordering:')
                for i in range(self.num_layers):
                    print('   layer id: {:3d} --> unique layer id: '
                          '{:3d}'.format(i, self._get_layer_index(i)),
                          flush=True)

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

    def _get_layer_index(self, layer_number):
        """Map a 0-based position in the stack to an index into self.layers.

        Raises:
            ValueError: if param_sharing_style is neither 'grouped' nor
                'spaced'. An explicit raise is used instead of `assert`, which
                is stripped under `python -O`.
        """
        if self.param_sharing_style == 'grouped':
            return layer_number % self.num_unique_layers
        if self.param_sharing_style == 'spaced':
            return layer_number // (self.num_layers // self.num_unique_layers)
        raise ValueError('unknown param_sharing_style: {!r}'.format(
            self.param_sharing_style))

    def _get_layer(self, layer_number):
        """Return the (possibly shared) layer module for a stack position."""
        return self.layers[self._get_layer_index(layer_number)]

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing.

        Runs the stack in chunks of checkpoint_num_layers, each wrapped in
        mpu.checkpoint. The last chunk is truncated at num_layers. Without the
        truncation, a non-multiple num_layers would run positions past the end
        of the stack: 'grouped' sharing silently re-runs layers, and 'spaced'
        sharing indexes past self.layers.
        """
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, inputs[1])
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        chunk_start = 0
        while chunk_start < self.num_layers:
            chunk_end = min(chunk_start + self.checkpoint_num_layers,
                            self.num_layers)
            hidden_states = mpu.checkpoint(
                custom(chunk_start, chunk_end),
                hidden_states, attention_mask)
            chunk_start = chunk_end

        return hidden_states

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        """Run the transformer stack.

        Args:
            hidden_states: [b, s, h] input.
            attention_mask: forwarded to every layer.
            layer_past: optional per-position list of cached (key, value)
                pairs; requires get_key_value.
            get_key_value: if True, also return each layer's (key, value).
                Not supported together with activation checkpointing.

        Returns:
            [b, s, h] output after the final LayerNorm, or [output, presents]
            when get_key_value is set.
        """
        # Checks
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        # data format change to avoid explicit transposes : [b s h] --> [s b h]
        hidden_states = hidden_states.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # reverting data format change [s b h] --> [b s h]
        hidden_states = hidden_states.transpose(0, 1).contiguous()

        # Final layer norm.
        output = self.final_layernorm(hidden_states)
        if get_key_value:
            output = [output, presents]

        return output