# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from megatron.mpu import LayerNorm
from megatron.module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu

# Flags required to enable the jit fusion kernels used by the
# @torch.jit.script helpers in this file: disable the profiling
# mode/executor and allow the fuser to run on both CPU and GPU.
# Applied once, in this order, at import time.
for _jit_flag_setter, _jit_flag_value in (
        (torch._C._jit_set_profiling_mode, False),
        (torch._C._jit_set_profiling_executor, False),
        (torch._C._jit_override_can_fuse_on_cpu, True),
        (torch._C._jit_override_can_fuse_on_gpu, True)):
    _jit_flag_setter(_jit_flag_value)
del _jit_flag_setter, _jit_flag_value

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     sq: query sequence length (length of the current input)
     sk: key/value sequence length (includes cached past keys during
         inference)
     l: number of layers
    Transformer layers take input of size [s, b, h] and return a
    tensor of the same size (ParallelTransformer itself accepts [b, s, h]
    and transposes internally). We use the following arguments:
        hyperparameters: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
               masked-attention-scores = attention_mask_func(
                                     unmasked-attention-scores, attention-mask)
"""

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.

    Dropout is NOT applied here. Both projections are built with
    ``skip_bias_add=True``. The first bias is added together with the
    activation in ``forward``; the second is returned un-added, so the
    caller can fuse bias-add, dropout and the residual connection.

    Args:
        init_method: weight initializer for the h -> 4h projection.
        output_layer_init_method: weight initializer for the 4h -> h
            projection.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h. The bias is returned instead of added so that it
        # can be combined with the activation in forward().
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            4 * args.hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        # Activation used only when bias_gelu_fusion is disabled.
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Return ``(output, output_bias)``; ``output_bias`` is not yet added."""
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = bias_gelu_impl(intermediate_parallel,
                                                   bias_parallel)
        else:
            intermediate_parallel = self.activation_func(
                intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelSelfAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.

    ``forward`` returns ``(output, bias)``. The output projection is built
    with ``skip_bias_add=True``, so ``bias`` has not been added to
    ``output`` yet. When ``get_key_value`` is set, ``output`` is replaced
    by ``[output, (key_layer, value_layer)]``. Those keys/values can be
    passed back as ``layer_past`` on a later call, where they are
    concatenated in front of the new keys/values.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Layer scaling divides the scores by layer_number before softmax
        # (see norm_factor below), so the softmax must run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        # Per attention head and per partition values.
        world_size = mpu.get_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            args.hidden_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            gather_output=False,
            init_method=init_method)

        # With layer scaling, scores are additionally divided by
        # layer_number here and `coeff` is handed to the softmax.
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            args.scaled_upper_triang_masked_softmax_fusion,
            args.scaled_masked_softmax_fusion,
            self.attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _transpose_last_dim(self, mixed_layer):
        """[s, b, 3 * hp] -->(view) [s, b, 3, hp] -->(transpose)
        [s, b, hp, 3] -->(view) [s, b, 3 * hp] """

        input_shape = mixed_layer.size()
        last_dim = input_shape[-1]
        assert last_dim % 3 == 0, "expected QKV dimension"
        last_dim_split = last_dim // 3

        intermediate_shape = input_shape[:-1] + \
            (3, last_dim_split)
        mixed_layer = mixed_layer.view(*intermediate_shape)
        mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)

        return mixed_layer

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, hp * 3]
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        # Checkpoint version 0 lays the fused QKV output out as [3 * hp]
        # (query block, then key block, then value block). The split below
        # expects them interleaved as [hp * 3]. An unknown version (None)
        # compares unequal to 0 and is left untouched.
        if get_checkpoint_version() == 0:
            # [sq, b, 3 * hp] --> [sq, b, hp * 3]
            mixed_x_layer = self._transpose_last_dim(mixed_x_layer)

        # [sq, b, hp * 3] --> [sq, b, np, hn, 3]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head, 3)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, hn, 3] --> 3 [sq, b, np, hn]
        query_layer = mixed_x_layer[:, :, :, :, 0]
        key_layer = mixed_x_layer[:, :, :, :, 1]
        value_layer = mixed_x_layer[:, :, :, :, 2]

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # Prepend cached keys/values along the sequence dimension (dim 0).
        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # Preallocating result tensor: [b * np, sq, sk].
        # Allocated on the query's device rather than the current CUDA
        # device, so the layer does not depend on global device state.
        # Its contents are never read: baddbmm with beta=0 ignores `input`.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device)

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ====================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ====================================================

        # With a cache, only the mask row of the newest position (sk - 1)
        # is kept. NOTE(review): this presumably assumes a single new query
        # token per step when layer_past is given -- confirm with callers.
        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    """Return ``residual + dropout(x + bias)``.

    ``prob`` is the dropout probability and ``training`` selects whether
    dropout is active. Only TorchScript-compatible constructs are used,
    because the scripted fused variants in this module call this function.
    """
    return residual + torch.nn.functional.dropout(
        x + bias, p=prob, training=training)


def get_bias_dropout_add(training):
    """Bind ``training`` into a ``(x, bias, residual, prob)`` callable.

    This is the non-fused path: the returned callable simply forwards to
    ``bias_dropout_add`` with the captured dropout mode.
    """
    return lambda x, bias, residual, prob: bias_dropout_add(
        x, bias, residual, prob, training)


# TorchScript-compiled residual + dropout(x + bias) with dropout active
# (training=True). A scripted nn.Module with dropout does not trigger the
# fusion kernel, so the training and inference dropout semantics each get
# their own scripted function, with the mode fixed as a literal argument.
@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob) :
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


# Inference counterpart of bias_dropout_add_fused_train: same computation
# with training=False, so dropout is a no-op.
@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob) :
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
371
372
373
374
375
376
377
378


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size. When ``get_key_value`` is set, it returns
    ``[output, presents]`` instead, where ``presents`` is the key/value
    pair produced by the self-attention sub-layer.

    Data flow (each "add" is bias + dropout + residual):
        input_layernorm -> self attention -> add ->
        post_attention_layernorm -> MLP -> add
    By default, the residual of each add is the corresponding layernorm's
    input. When ``apply_residual_connection_post_layernorm`` is set, it is
    that layernorm's output instead.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.attention = ParallelSelfAttention(attention_mask_func, init_method,
                                               output_layer_init_method,
                                               layer_number)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention block output (input of the MLP).
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def _get_bias_dropout_add_func(self):
        """Pick the bias + dropout + residual-add implementation.

        jit scripting for a nn.module (with dropout) is not triggering the
        fusion kernel. With fusion enabled, two separately scripted
        functions therefore cover the training and inference dropout
        semantics. Without fusion, a plain callable bound to
        ``self.training`` is used.
        """
        if self.bias_dropout_fusion:
            if self.training:
                return bias_dropout_add_fused_train
            return bias_dropout_add_fused_inference
        return get_bias_dropout_add(self.training)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.attention(layernorm_output,
                           attention_mask,
                           layer_past=layer_past,
                           get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        bias_dropout_add_func = self._get_bias_dropout_add_func()

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class.

    Applies ``num_layers`` ParallelTransformerLayer steps followed by a
    final LayerNorm. Input and output are [b, s, h]; the layers themselves
    run on [s, b, h]. When ``num_unique_layers < num_layers``, only
    ``num_unique_layers`` layers are built. Their parameters are reused
    according to ``param_sharing_style`` ('grouped' or 'spaced').
    """

    def __init__(self, attention_mask_func,
                 init_method, output_layer_init_method):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers:
        self.num_layers = args.num_layers
        self.num_unique_layers = args.num_unique_layers
        if self.num_unique_layers is None:
            self.num_unique_layers = self.num_layers
        assert self.num_layers % self.num_unique_layers == 0, \
            'number of layers should be divisible by number of unique layers'
        self.param_sharing_style = args.param_sharing_style

        # Transformer layers (layer numbers are 1-based).
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                attention_mask_func, init_method,
                output_layer_init_method, layer_number)
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1) for i in range(self.num_unique_layers)])

        # Print layer ordering.
        if self.num_layers != self.num_unique_layers:
            if torch.distributed.get_rank() == 0:
                print('> will be using the following layer ordering:')
                for i in range(self.num_layers):
                    print('   layer id: {:3d} --> unique layer id: '
                          '{:3d}'.format(i, self._get_layer_index(i)),
                          flush=True)

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

    def _get_layer_index(self, layer_number):
        """Map a 0-based layer position to the index of its unique layer.

        Raises:
            ValueError: if ``param_sharing_style`` is not recognized. An
                explicit raise is used because an ``assert`` would be
                stripped under ``python -O`` and silently return None.
        """
        if self.param_sharing_style == 'grouped':
            # 0, 1, ..., u-1, 0, 1, ..., u-1, ...
            return layer_number % self.num_unique_layers
        if self.param_sharing_style == 'spaced':
            # 0, ..., 0, 1, ..., 1, ..., u-1, ..., u-1
            return layer_number // (self.num_layers // self.num_unique_layers)
        raise ValueError('should not be here: unknown param_sharing_style '
                         '{!r}'.format(self.param_sharing_style))

    def _get_layer(self, layer_number):
        return self.layers[self._get_layer_index(layer_number)]

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing.

        Layers are run in consecutive chunks of ``checkpoint_num_layers``,
        each chunk passed through ``mpu.checkpoint``. The last chunk is
        truncated at ``num_layers`` so that no position past the final
        layer is ever run.
        """
        assert self.checkpoint_num_layers > 0, \
            'checkpoint_num_layers must be positive'

        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, inputs[1])
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        start = 0
        while start < self.num_layers:
            # Clamp the trailing chunk when num_layers is not a multiple
            # of checkpoint_num_layers.
            end = min(start + self.checkpoint_num_layers, self.num_layers)
            hidden_states = mpu.checkpoint(
                custom(start, end),
                hidden_states, attention_mask)
            start = end

        return hidden_states

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        """Run the layer stack.

        ``hidden_states`` is [b, s, h]. Returns the final-layernormed
        [b, s, h] tensor. When ``get_key_value`` is set, returns
        ``[output, presents]`` instead, where ``presents`` holds each
        layer's key/value entry in layer order. ``layer_past``, if given,
        is indexed per layer position.
        """
        # Checks
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        # data format change to avoid explicit transposes : [b s h] --> [s b h]
        hidden_states = hidden_states.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # reverting data format change [s b h] --> [b s h]
        hidden_states = hidden_states.transpose(0, 1).contiguous()

        # Final layer norm.
        output = self.final_layernorm(hidden_states)
        if get_key_value:
            output = [output, presents]

        return output