# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

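# A worked example of the notation above (illustrative values, not Megatron
# defaults): with h = 1024, n = 16 and p = 2 tensor-model-parallel partitions,
# np = 8, hp = 512 and hn = 64; each partition sees the full [s, b, 1024]
# activations but computes attention over its own [s, b, np=8, hn=64] slice
# before the partial results are reduced across partitions.
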
class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
    applied.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)


    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
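
        # A worked example (illustrative): with hn = 64 the standard scaling is
        # norm_factor = sqrt(64) = 8, i.e. scores are divided by sqrt(hn). With
        # query/key layer scaling enabled at, say, layer_number = 4, the scores
        # are further divided by coeff = 4 here, and coeff is passed to the
        # fused softmax below so the scores can be scaled back up (in fp32)
        # before the softmax, which helps keep the fp16 matmul in range.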

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
        input_shape = mixed_layer.size()
        if num_splits_first:
            """[s, b, num_splits * np * hn]
            -->(view) [s, b, num_splits, np, hn]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (num_splits, self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            """[s, b, np * hn * num_splits]
            -->(view) [s, b, np, hn, num_splits]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head, num_splits)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)

        return mixed_layer
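
    # Illustration of the fix-up above with hypothetical sizes np = 2, hn = 2,
    # num_splits = 3 (q/k/v): the target per-head layout used by forward() is
    # [q q k k v v | q q k k v v]. A checkpoint-version-0 tensor stores the same
    # 12 values as [q q q q | k k k k | v v v v] (num_splits first) and a
    # version-1.0 tensor as [q k v q k v | q k v q k v] (splits last); either
    # one is reordered into the target layout by the views/transposes above.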

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
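

# A usage sketch of the helpers above (illustrative only; mirrors how the
# transformer layer calls them):
#
#   x = torch.randn(32, 4, 1024)          # e.g. attention output [s, b, h]
#   bias = torch.randn(1024)              # bias returned by the linear layer
#   residual = torch.randn(32, 4, 1024)
#   out = bias_dropout_add_fused_inference(x, bias.expand_as(residual),
#                                          residual, 0.1)
#   # With dropout disabled at inference time this equals residual + x + bias;
#   # during training the JIT-scripted variant applies dropout to x + bias
#   # before adding the residual.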


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                layer_past=None, get_key_value=False):
        # hidden_states: [s, b, h]
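
        # Summary (added for clarity): this is a pre-LayerNorm block:
        #   LN -> self-attention -> bias/dropout/add,
        #   LN -> cross-attention -> bias/dropout/add   (decoder layers only),
        #   LN -> MLP -> bias/dropout/add,
        # where each residual starts from the LayerNorm output when
        # apply_residual_connection_post_layernorm is set, and from the
        # LayerNorm input otherwise.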

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output,
                                attention_mask,
                                layer_past=layer_past,
                                get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for an nn.Module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.fp32_residual_connection = args.fp32_residual_connection

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)
        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if mpu.is_pipeline_last_stage():
            # Final layer norm before output.
            LayerNorm = import_layernorm(args.fp32_residual_connection)
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
            l += self.checkpoint_num_layers

        return hidden_states
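
    # Example (illustrative): with self.num_layers = 24 on this pipeline stage
    # and checkpoint_num_layers = 4, the loop above wraps layers 0-3, 4-7, ...
    # into six mpu.checkpoint calls; only the inputs at those chunk boundaries
    # are kept, and each chunk's activations are recomputed during backward.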

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if mpu.is_pipeline_first_stage():
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if mpu.is_pipeline_last_stage():
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        if get_key_value:
            output = [output, presents]

        return output