# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform a nonlinear transformation, and project the
    state back into h hidden dimension. The dropout at the end is applied
    by the enclosing transformer layer, fused with the bias add.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
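
        # With query-key layer scaling, the raw scores are divided here by
        # sqrt(hn) * layer_number; the fused softmax below receives the same
        # layer-number factor as `coeff` and scales the scores back up (in
        # fp32) before the softmax. The math is the standard 1/sqrt(hn)
        # scaling, but the intermediate fp16 matmul output stays smaller and
        # is less prone to overflow.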

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs for different numbers of parallel partitions, but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
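        # Reorders the last dimension of the fused QKV / KV projection output
        # when loading older checkpoints: weights saved with layout
        # [num_splits, np, hn] (checkpoint version 0) or [np, hn, num_splits]
        # (version 1.0) produce outputs that are permuted here to the current
        # [np, num_splits, hn] layout before the per-head split.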
        input_shape = mixed_layer.size()
        if num_splits_first:
            """[s, b, num_splits * np * hn] 
191
192
            -->(view) [s, b, num_splits, np, hn]
            -->(tranpose) [s, b, np, num_splits, hn]
Vijay Korthikanti's avatar
Vijay Korthikanti committed
193
194
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (num_splits, self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            """[s, b, np * hn * num_splits] 
203
204
            -->(view) [s, b, np, hn, num_splits]
            -->(tranpose) [s, b, np, num_splits, hn]
Vijay Korthikanti's avatar
Vijay Korthikanti committed
205
206
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head, num_splits)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)

        return mixed_layer

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))
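
        # torch.baddbmm computes beta * input + alpha * (batch1 @ batch2);
        # with beta=0.0 the contents of the preallocated tensor are ignored,
        # and alpha folds the 1/norm_factor scaling into the batched matmul.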

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
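
# Usage sketch (mirrors ParallelTransformerLayer.forward below; names here
# are illustrative): the fused, jit-scripted variants are selected based on
# the module's training flag, e.g.
#   func = (bias_dropout_add_fused_train if training
#           else bias_dropout_add_fused_inference)
#   out = func(attn_output, attn_bias.expand_as(residual), residual, dropout_p)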


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                layer_past=None, get_key_value=False):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output,
                                attention_mask,
                                layer_past=layer_past,
                                get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # JIT scripting for an nn.Module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.fp32_residual_connection = args.fp32_residual_connection

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)
        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])
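
        # Example of the layer numbering (hypothetical sizes): with
        # args.num_layers = 24 and a pipeline-model-parallel size of 4, each
        # stage builds 6 layers and pipeline rank r owns global layers
        # 6 * r + 1 through 6 * (r + 1); layer_number is 1-based since it is
        # also used as the query-key layer scaling coefficient.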

        if mpu.is_pipeline_last_stage():
            # Final layer norm before output.
            LayerNorm = import_layernorm(args.fp32_residual_connection)
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
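        # Checkpoint the layers in chunks of checkpoint_num_layers. For
        # example (hypothetical sizes), with 6 layers on this pipeline stage
        # and checkpoint_num_layers = 2, mpu.checkpoint is invoked three
        # times and each 2-layer chunk recomputes its activations during the
        # backward pass instead of storing them.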
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
            l += self.checkpoint_num_layers

        return hidden_states

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if mpu.is_pipeline_first_stage():
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()
          
        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if mpu.is_pipeline_last_stage():
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        if get_key_value:
            output = [output, presents]

        return output