# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
    applied.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)
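        # Both branches compute gelu(x + bias); when bias_gelu_fusion is set,
        # the jit-scripted bias_gelu_impl kernel is used so the bias add and
        # the GeLU can be fused by the jit into one elementwise pass.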

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
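        # With query/key layer scaling, the raw scores computed in forward()
        # are divided by an extra factor of layer_number; the fused softmax
        # below receives the same coeff so it can scale the scores back up
        # (in fp32), which is meant to keep the fp16 baddbmm output in range.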

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)
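        # Shape sketch for the self-attention path (hypothetical sizes):
        # with sq = 1024, b = 2, np = 8, hn = 64, mixed_x_layer is viewed as
        # [1024, 2, 8, 192] and then split into three [1024, 2, 8, 64]
        # tensors for query, key, and value.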

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))
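        # Shape sketch (hypothetical sizes): with b = 2, np = 8, sq = sk = 1024,
        # hn = 64, the operands are [16, 1024, 64] and [16, 64, 1024], and
        # baddbmm writes the [16, 1024, 1024] scores into matmul_result.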

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out
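# bias_dropout_add folds the bias that the skip_bias_add=True linear layers
# return separately into the dropout + residual update, i.e.
# out = residual + dropout(x + bias). The jit-scripted variants below bake in
# the training flag so the elementwise ops can be fused by the jit.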


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                layer_past=None, get_key_value=False):
        # hidden_states: [b, s, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output,
                                attention_mask,
                                layer_past=layer_past,
                                get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)
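        # i.e. layernorm_input = residual + dropout(attention_output + bias),
        # evaluated in a single jit-scripted call when bias_dropout_fusion
        # is enabled.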

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)
        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
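            # Concretely (illustrative numbers): with 8 layers, a pipeline of
            # 2 stages and 4 virtual chunks, self.num_layers is 1 per chunk
            # and the offset below works out to 2 * virtual_rank +
            # pipeline_rank, reproducing the first assignment above.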
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                    args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if mpu.is_pipeline_last_stage():
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
            l += self.checkpoint_num_layers
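        # Each mpu.checkpoint call keeps (roughly) only its chunk's inputs and
        # re-runs custom(l, l + checkpoint_num_layers) during the backward
        # pass to rebuild the intermediate activations it needs.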

        return hidden_states

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if mpu.is_pipeline_first_stage():
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert for float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if mpu.is_pipeline_last_stage():
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        if get_key_value:
            output = [output, presents]

        return output