# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu


""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
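

# Minimal reference sketch (added for illustration; the name below is not part of
# Megatron): the two branches in ParallelMLP.forward compute the same result, the
# fused bias_gelu_impl kernel simply merges the bias add and the GeLU.
def _bias_gelu_reference(x, bias):
    """x: [s, b, 4hp] activations, bias: [4hp]; matches the unfused path above."""
    return F.gelu(x + bias)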


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """
Neel Kant's avatar
Neel Kant committed
104

105
    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
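        # Note (added for clarity): with query-key layer scaling enabled, the raw
        # scores below are divided by sqrt(hn) * layer_number; FusedScaleMaskSoftmax
        # is handed `coeff` so it can scale the scores back up before the (fp32)
        # softmax, keeping the fp16 matmul output small for numerical stability.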

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)
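        # Added note: at this point query_layer is [sq, b, np, hn] and the key/value
        # layers are [sk, b, np, hn]; for self-attention sk == sq, and each tensor
        # parallel rank only holds its own np = n/p heads.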

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))
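        # Added note: baddbmm computes beta * input + alpha * (batch1 @ batch2), so
        # with beta=0.0 the preallocated matmul_result only fixes the output shape,
        # dtype and device, and the result is (Q @ K^T) / norm_factor.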

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
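        # Added note: the dropout mask is drawn under the tensor-model-parallel RNG
        # tracker, so each tensor parallel rank uses its own deterministically seeded
        # stream for its partition of the attention probabilities.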
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias
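

# Self-contained reference sketch (added for illustration; not used by the classes
# above): the math ParallelAttention implements with baddbmm/bmm, written as plain
# batched attention. Shapes follow the file's notation with p=1, so np=n and hp=h.
def _reference_attention(query, key, value, mask=None):
    """query/key/value: [b, n, s, hn]; mask: broadcastable bool, True means masked."""
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(query.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask, -10000.0)
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, value)  # [b, n, s, hn]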


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
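

# Usage note (added for illustration): a transformer layer picks one of the callables
# above and applies it as
#     out = fn(attn_output, attn_bias.expand_as(residual), residual, p_drop)
# which computes residual + dropout(attn_output + attn_bias); the @torch.jit.script
# variants exist so that the add + dropout + residual-add can be fused.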


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
384
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                layer_past=None, get_key_value=False):
        # hidden_states: [b, s, h]
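        # Schematic of this (pre-LayerNorm) layer, added for clarity:
        #   x1 = x  + dropout(self_attn(LN(x)) + attn_bias)
        #   y  = x1 + dropout(mlp(LN(x1))      + mlp_bias)
        # (plus a cross-attention block between the two for decoder layers). With
        # --apply-residual-connection-post-layernorm the residuals are taken from
        # the layernorm outputs instead of x / x1.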

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output,
                                attention_mask,
                                layer_past=layer_past,
                                get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # JIT scripting for an nn.Module (with dropout) does not
        # trigger the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flags.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)
        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
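            # Added worked example: with num_layers=8, pipeline size 2 and virtual
            # size 2, self.num_layers is 2 per chunk, and pipeline rank 1 / virtual
            # rank 1 gets offset = 1 * (8 // 2) + 1 * 2 = 6, i.e. layers [6, 7].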
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        def distribute_checkpointed_activations_helper(layer_number):
            """Distribute checkpointed activations across the tensor model
               Parallel ranks if the `distribute-checkpointed-activations
               is on and either of the following conditions is met:
                 - it is not the first layer in the in the pipeline stage.
                   The first layer is used in the pipeline parallelism 
                   and changing its shape throws error in the backward pass.
                 - we are at the first pipline stage so the input tensor is
                   not used in pipeline parallelism. Note that no pipeline
                   parallelism is a special case of this.
            """
            not_first_layer_in_pipeline_stage = (layer_number > 0)
            is_first_pipeline_stage = (
                mpu.get_pipeline_model_parallel_rank() == 0)
            return self.distribute_checkpointed_activations and \
                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage by reducing the number
            # of checkpoints.
            l = 0
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    distribute_checkpointed_activations_helper(l),
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method that fully uses the device memory by removing redundant
            # re-computation.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        distribute_checkpointed_activations_helper(l),
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert self.activations_checkpoint_method is None, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        if get_key_value:
            output = [output, presents]

        return output
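

# Stand-alone illustration (added; not part of Megatron) of the 'uniform'
# checkpointing loop in _checkpointed_forward, written with plain
# torch.utils.checkpoint instead of mpu.checkpoint. The real code additionally
# handles RNG state and optional redistribution of the checkpointed activations
# across tensor model parallel ranks.
def _uniform_checkpoint_sketch(layers, hidden_states, chunk_size):
    import torch.utils.checkpoint as ckpt

    def run(start, end):
        # Re-runs layers [start, end) during the backward pass instead of storing
        # their intermediate activations.
        def forward(x):
            for layer in layers[start:end]:
                x = layer(x)
            return x
        return forward

    l = 0
    while l < len(layers):
        hidden_states = ckpt.checkpoint(run(l, l + chunk_size), hidden_states)
        l += chunk_size
    return hidden_states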