transformer.py 27.8 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
19
import torch.nn.functional as F
20

Mohammad's avatar
Mohammad committed
21
from megatron import get_args
22
from megatron import mpu
23
from .module import MegatronModule
24
from megatron.model.enums import AttnMaskType, LayerType, AttnType
25
from megatron.model import LayerNorm
26
27
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
28
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
29

30
31
32
33
34
# flags required to enable jit fusion kernels
# These statements run at import time and mutate process-global TorchScript
# state through private torch._C hooks: the profiling mode and profiling
# executor are switched off, and fusion is allowed on both CPU and GPU.
# NOTE(review): presumably this is what lets the @torch.jit.script
# bias/dropout/add helpers in this file fuse; confirm these private hooks
# still exist in the torch version in use.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
35
36
37
38
39
40
41
42
43
44
45

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of (tensor) model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length (sq / sk: query / key sequence length in attention)
     l: number of layers

    Transformer layers take input of size [s, b, h] and return a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters (read via get_args())
"""

class ParallelMLP(MegatronModule):
    """Tensor-parallel two-layer feed-forward block.

    The hidden state (size h) is projected up to ``args.ffn_hidden_size``
    with a column-parallel linear layer, passed through a GeLU-family
    nonlinearity, and projected back to h with a row-parallel linear layer.
    Both projections use ``skip_bias_add=True``: the first bias is folded
    into the activation, and the second is returned next to the output so
    the caller can fuse it with dropout and the residual add.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Up-projection h -> ffn_hidden_size. The output stays sharded across
        # tensor-parallel ranks (gather_output=False) and the bias is handed
        # back so it can be combined with the activation in forward().
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        # Activation used on the non-fused path: openai_gelu wins over
        # erf_gelu, and torch's gelu is the fallback.
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        else:
            self.activation_func = F.gelu

        # Down-projection ffn_hidden_size -> h. It consumes the sharded
        # intermediate directly (input_is_parallel=True).
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Return ``(output, output_bias)`` for ``hidden_states``.

        ``output_bias`` is the un-added bias of the down-projection.
        """
        # [s, b, 4hp]
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            activated = self.activation_func(projected + projected_bias)
        else:
            activated = bias_gelu_impl(projected, projected_bias)

        # [s, b, h]
        result, result_bias = self.dense_4h_to_h(activated)
        return result, result_bias
101
102


103
class ParallelAttention(MegatronModule):
    """Tensor-parallel self-attention or cross-attention layer.

    Takes hidden states of size [sq, b, h] (sequence first) and returns an
    output of the same size. The output projection uses
    ``skip_bias_add=True``, so ``forward`` returns ``(output, bias)`` with
    the bias left un-added for the caller to fuse. For
    ``AttnType.cross_attn`` the keys and values are computed from
    ``encoder_output`` ([sk, b, h]) instead of ``hidden_states``.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            # Query-key layer scaling always runs the softmax in fp32.
            self.attention_softmax_in_fp32 = True
        # Clamped to >= 1 because it is used as a scaling coefficient below.
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            # Single fused projection producing Q, K and V.
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            # Queries come from hidden_states; fused K/V from encoder_output.
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        # Raw scores are multiplied by 1/norm_factor in forward(). With
        # query-key layer scaling, norm_factor also includes layer_number and
        # the same coefficient is passed to the softmax (presumably so it can
        # undo the extra scaling in fp32).
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        """Compute attention over ``hidden_states``.

        Args:
            hidden_states: [sq, b, h] input.
            attention_mask: mask consumed by ``scale_mask_softmax``; sliced
                here for incremental inference when ``get_key_value`` is set.
            layer_past: optional ``(past_key, past_value)`` concatenated in
                front of the new keys/values along the sequence dim.
            get_key_value: when true, also return the (possibly extended)
                ``(key, value)`` pair.
            encoder_output: [sk, b, h] source of keys/values for
                cross-attention.

        Returns:
            ``(output, bias)``; when ``get_key_value`` is true, ``output`` is
            ``[output, (key_layer, value_layer)]``.
        """
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        # Allocated on the operands' device rather than
        # torch.cuda.current_device(), so the buffer can never land on a
        # different device from the tensors it is combined with.
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device)

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 makes baddbmm ignore the (uninitialized) buffer contents.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    # Incremental step: keep only the mask row of the newest
                    # query position, restored to a length-1 query dim.
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


355
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor,
                     residual: torch.Tensor, prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    Dropout with probability ``prob`` is applied only when ``training`` is
    true. The annotations double as the TorchScript signature, because the
    @torch.jit.script wrappers in this module compile this function too.
    """
    dropped = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Return ``bias_dropout_add`` with its ``training`` flag pre-bound.

    The returned callable takes ``(x, bias, residual, prob)``. It is the
    non-fused path selected by ParallelTransformerLayer when
    ``bias_dropout_fusion`` is off.
    """
    return lambda x, bias, residual, prob: bias_dropout_add(
        x, bias, residual, prob, training)


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout enabled."""
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout disabled."""
    return bias_dropout_add(x, bias, residual, prob, False)
378
379
380
381
382


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Takes hidden states of size [s, b, h] (the sequence-first layout used
    inside ParallelTransformer and ParallelAttention) and returns an output
    of the same size. Encoder layers run self-attention followed by the MLP.
    Decoder layers additionally run cross-attention over ``encoder_output``
    between the two.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def _select_residual(self, layernorm_output, layernorm_input):
        """Return the tensor the next residual connection adds onto."""
        if self.apply_residual_connection_post_layernorm:
            return layernorm_output
        return layernorm_input

    def _get_bias_dropout_add_func(self):
        """Return the bias + dropout + residual function for the current mode."""
        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                return bias_dropout_add_fused_train
            return bias_dropout_add_fused_inference
        return get_bias_dropout_add(self.training)

    def _apply_bias_dropout_add(self, func, x, bias, residual):
        """Compute ``residual + dropout(x + bias)`` with ``func``."""
        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            return func(x,
                        bias.expand_as(residual),
                        residual,
                        self.hidden_dropout)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                layer_past=None, get_key_value=False):
        """Run the layer on ``hidden_states`` ([s, b, h]).

        Returns the [s, b, h] output. When ``get_key_value`` is true, returns
        ``[output, presents]``, where ``presents`` is the self-attention
        ``(key, value)`` pair.
        """
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output,
                                attention_mask,
                                layer_past=layer_past,
                                get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        bias_dropout_add_func = self._get_bias_dropout_add_func()

        # Residual connection around self attention.
        residual = self._select_residual(layernorm_output, hidden_states)
        layernorm_input = self._apply_bias_dropout_add(
            bias_dropout_add_func, attention_output, attention_bias, residual)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # Residual connection around cross attention.
            residual = self._select_residual(layernorm_output, layernorm_input)
            layernorm_input = self._apply_bias_dropout_add(
                bias_dropout_add_func, attention_output, attention_bias,
                residual)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        residual = self._select_residual(layernorm_output, layernorm_input)
        output = self._apply_bias_dropout_add(
            bias_dropout_add_func, mlp_output, mlp_bias, residual)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class.

    Holds the ParallelTransformerLayer modules assigned to this pipeline
    stage (and, with a virtual pipeline, to this model chunk). Optionally
    applies activation checkpointing and, on the last stage, a final
    LayerNorm.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flags.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)

        if args.virtual_pipeline_model_parallel_size is not None:
            # This stage's layers are split evenly across its model chunks,
            # so the per-stage count (self.num_layers) must be divisible.
            # Checking the global num_layers instead lets configs such as
            # 12 layers / pp=4 / vpp=2 through, after which the floor
            # division below silently drops layers.
            assert self.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        """Return the local layer at 0-based index ``layer_number``."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        Raises:
            ValueError: if the checkpoint method is unknown, or if it is
                'uniform' with a chunk size smaller than 1 (which would
                otherwise loop forever).
        """
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            chunk_size = self.activations_checkpoint_num_layers
            if chunk_size < 1:
                raise ValueError(
                    'activations_checkpoint_num_layers must be at least 1 '
                    'for the uniform activation checkpoint method.')
            start = 0
            while start < self.num_layers:
                # Clamp the last chunk so a num_layers that is not a multiple
                # of chunk_size does not index past self.layers.
                end = min(start + chunk_size, self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(start, end),
                    self.distribute_checkpointed_activations and (
                        (start > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                start = end
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for index in range(self.num_layers):
                if index < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(index, index + 1),
                        self.distribute_checkpointed_activations and (
                            (index > 0) or (mpu.get_pipeline_model_parallel_rank() == 0)),
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(index, index + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
        """Run this stage's layers.

        On the first stage (``pre_process``), ``hidden_states`` is [b, s, h]
        and is transposed to [s, b, h]. Otherwise the tensor passed to
        ``set_input_tensor`` is used. On the last stage (``post_process``),
        the result is transposed back to [b, s, h] and layer-normed. When
        ``get_key_value`` is true, returns ``[output, presents]`` with one
        ``(key, value)`` pair per layer.
        """
        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert self.activations_checkpoint_method is None, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes : [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert for float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        if get_key_value:
            output = [output, presents]

        return output