transformer.py 29.5 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
19
import torch.nn.functional as F
20

Mohammad's avatar
Mohammad committed
21
from megatron import get_args
22
from megatron import mpu
23
from .module import MegatronModule
24
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
25
from megatron.model import LayerNorm
26
27
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
28
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
29
30
31
32
33
34
35
36
37
38
39
40


""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

class ParallelMLP(MegatronModule):
    """Two-layer feed-forward block.

    The input (hidden size ``args.hidden_size``) is projected up to
    ``args.ffn_hidden_size``, passed through a GELU-style activation and
    projected back down to ``args.hidden_size``.
    """

    def __init__(self, init_method, output_layer_init_method):
        """Build the up-projection, activation and down-projection.

        Arguments:
            init_method: initializer handed to the up-projection.
            output_layer_init_method: initializer handed to the
                down-projection.
        """
        super(ParallelMLP, self).__init__()
        cfg = get_args()

        # Up-projection: hidden_size -> ffn_hidden_size.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            cfg.hidden_size, cfg.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # Activation used on the non-fused path; openai_gelu takes
        # precedence over onnx_safe, plain F.gelu is the fallback.
        self.bias_gelu_fusion = cfg.bias_gelu_fusion
        if cfg.openai_gelu:
            activation = openai_gelu
        elif cfg.onnx_safe:
            activation = erf_gelu
        else:
            activation = F.gelu
        self.activation_func = activation

        # Down-projection: ffn_hidden_size -> hidden_size.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            cfg.ffn_hidden_size, cfg.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Apply the MLP.

        Returns the ``(output, output_bias)`` pair produced by the
        down-projection; the bias is handed back to the caller rather
        than being added here.
        """
        # [s, b, 4hp]
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            activated = self.activation_func(projected + projected_bias)
        else:
            # Fused kernel: bias add and GELU in a single call.
            activated = bias_gelu_impl(projected, projected_bias)

        # [s, b, h]
        return self.dense_4h_to_h(activated)
96
97


98
class ParallelAttention(MegatronModule):
    """Parallel attention layer (self-attention or cross-attention).

    The layer takes input of size [s, b, h] and returns an output of
    the same size, together with the bias returned by the output
    projection.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        """Build projections, softmax, dropout and inference state.

        Arguments:
            init_method: initializer for the query/key/value projections.
            output_layer_init_method: initializer for the output
                projection.
            layer_number: layer index; values below 1 are clamped to 1.
            attention_type: AttnType.self_attn or AttnType.cross_attn.
            attn_mask_type: mask type forwarded to the fused softmax.
        """
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Query-key layer scaling forces the softmax to run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            # One fused projection producing query, key and value.
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            # Query comes from hidden_states, key/value from encoder_output.
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        # Scores are divided by sqrt(hn); with query-key layer scaling
        # they are additionally divided by the layer number, and the
        # same coefficient is handed to the softmax module.
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

        # Inference key-value memory
        self.inference_key_memory = None
        self.inference_value_memory = None
        self.inference_current_sequence_len = 0

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Return an uninitialized [max_s, b, np, hn] cache tensor.

        The tensor uses ``self.params_dtype`` and lives on the current
        CUDA device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None,
                set_inference_key_value_memory=False,
                inference_max_sequence_len=None):
        """Compute attention.

        Arguments:
            hidden_states: [sq, b, h] input.
            attention_mask: mask handed to the scale-mask-softmax; it is
                sliced to the current window during incremental
                inference.
            encoder_output: source of keys/values for cross-attention.
            set_inference_key_value_memory: when True, (re)allocate the
                key/value cache and reset the cached length to 0.
            inference_max_sequence_len: cache capacity; a falsy value
                disables (and clears) the cache.

        Returns:
            ``(output, bias)`` as returned by the output projection.

        Raises:
            AssertionError: if the cache is used before being allocated,
                if its capacity disagrees with
                ``inference_max_sequence_len``, or if the new keys would
                not fit into it.
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if set_inference_key_value_memory:
            assert inference_max_sequence_len and inference_max_sequence_len > 0
            self.inference_key_memory = self._allocate_memory(
                inference_max_sequence_len, hidden_states.size(1))
            self.inference_value_memory = self._allocate_memory(
                inference_max_sequence_len, hidden_states.size(1))
            self.inference_current_sequence_len = 0

        if inference_max_sequence_len:
            # Some consistency check. Fail with a clear message instead
            # of an AttributeError on None when the cache was never set.
            assert self.inference_key_memory is not None and \
                self.inference_value_memory is not None, \
                'inference key/value memory is not allocated; call ' \
                'forward with set_inference_key_value_memory=True first'
            assert self.inference_current_sequence_len < \
                self.inference_key_memory.size(0)
            assert inference_max_sequence_len == \
                self.inference_key_memory.size(0)
        else:
            # This is added for safety. In case inference_max_sequence_len
            # is not provided, make sure there is no potential memory left
            # from previous inference.
            self.inference_key_memory = None
            self.inference_value_memory = None

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ===================================================
        # Adjust key, value, and attention mask for inference
        # ===================================================

        if inference_max_sequence_len:
            # Adjust the range variables. The bound is verified before
            # any state is mutated so that an overflowing step cannot
            # leave the cached length inconsistent.
            start = self.inference_current_sequence_len
            end = start + key_layer.size(0)
            assert end <= self.inference_key_memory.size(0), \
                'inference key/value memory overflow: {} > {}'.format(
                    end, self.inference_key_memory.size(0))
            self.inference_current_sequence_len = end
            # Copy key and values.
            self.inference_key_memory[start:end, ...] = key_layer
            self.inference_value_memory[start:end, ...] = value_layer
            key_layer = self.inference_key_memory[:end, ...]
            value_layer = self.inference_value_memory[:end, ...]
            # Adjust attention mask
            attention_mask = attention_mask[..., start:end, :end]

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # Preallocating result tensor: [b * np, sq, sk]. It is placed on
        # the query's device, which is the only placement baddbmm accepts.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device)

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the (uninitialized) input buffer is ignored.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


381
def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    Arguments:
        x: tensor the bias is added to.
        bias: bias tensor, added to ``x`` before dropout.
        residual: tensor added to the dropout result.
        prob: dropout probability.
        training: dropout is only applied when True.
    """
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    """Return a 4-argument bias-dropout-add with `training` bound.

    The returned callable takes (x, bias, residual, prob) and forwards
    them, plus the captured `training` flag, to bias_dropout_add.
    """
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob,
                                training=training)

    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """Scripted bias-dropout-add with the training flag fixed to True."""
    fused = bias_dropout_add(x, bias, residual, prob, True)
    return fused


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """Scripted bias-dropout-add with the training flag fixed to False."""
    fused = bias_dropout_add(x, bias, residual, prob, False)
    return fused
408
409
410
411
412


class ParallelTransformerLayer(MegatronModule):
    """One transformer layer.

    The layer receives hidden states laid out as [s, b, h] (the layout
    the attention module operates on) and returns a tensor of the same
    size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        """Create layer norms, attention block(s) and the MLP.

        Arguments:
            init_method: initializer for input-side projections.
            output_layer_init_method: initializer for output projections.
            layer_number: index handed to the attention modules.
            layer_type: a LayerType.decoder layer additionally gets a
                cross-attention block with its own layer norm.
            self_attn_mask_type: mask type for the self-attention block.
        """
        cfg = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm = \
            cfg.apply_residual_connection_post_layernorm
        self.bf16 = cfg.bf16
        self.fp32_residual_connection = cfg.fp32_residual_connection
        self.hidden_dropout = cfg.hidden_dropout
        self.bias_dropout_fusion = cfg.bias_dropout_fusion

        # Sub-modules are created in the same order as they are used.
        # Layer norm applied to the layer input.
        self.input_layernorm = LayerNorm(cfg.hidden_size,
                                         eps=cfg.layernorm_epsilon)

        # Self-attention block.
        self.self_attention = ParallelAttention(
            init_method, output_layer_init_method, layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)

        # Layer norm applied after the self-attention residual.
        self.post_attention_layernorm = LayerNorm(cfg.hidden_size,
                                                  eps=cfg.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            # Cross-attention block and the layer norm that follows it.
            self.inter_attention = ParallelAttention(
                init_method, output_layer_init_method, layer_number,
                attention_type=AttnType.cross_attn)
            self.post_inter_attention_layernorm = LayerNorm(
                cfg.hidden_size, eps=cfg.layernorm_epsilon)

        # Feed-forward block.
        self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None,
                enc_dec_attn_mask=None,
                set_inference_key_value_memory=False,
                inference_max_sequence_len=None):
        """Run self-attention, optional cross-attention and the MLP.

        Each sub-block is followed by bias + dropout + residual add.
        The residual is the layer-norm output when
        ``apply_residual_connection_post_layernorm`` is set, otherwise
        the input of that layer norm.
        """
        residual_from_norm = self.apply_residual_connection_post_layernorm

        # Layer norm at the start of the layer, then self-attention.
        normed = self.input_layernorm(hidden_states)
        branch_out, branch_bias = self.self_attention(
            normed,
            attention_mask,
            set_inference_key_value_memory=set_inference_key_value_memory,
            inference_max_sequence_len=inference_max_sequence_len)

        # jit scripting of an nn.Module (with dropout) does not trigger
        # the fusion kernel, so two scripted functions cover the
        # training and inference dropout semantics.
        if not self.bias_dropout_fusion:
            dropout_add = get_bias_dropout_add(self.training)
        elif self.training:
            dropout_add = bias_dropout_add_fused_train
        else:
            dropout_add = bias_dropout_add_fused_inference

        def add_residual(out, bias, normed_input, skip_input):
            # Pick the residual source, then apply bias-dropout-add.
            residual = normed_input if residual_from_norm else skip_input
            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                return dropout_add(out,
                                   bias.expand_as(residual),
                                   residual,
                                   self.hidden_dropout)

        # First residual connection, then layer norm.
        summed = add_residual(branch_out, branch_bias, normed, hidden_states)
        normed = self.post_attention_layernorm(summed)

        if self.layer_type == LayerType.decoder:
            # Cross-attention over the encoder output.
            branch_out, branch_bias = self.inter_attention(
                normed,
                enc_dec_attn_mask,
                encoder_output=encoder_output)
            summed = add_residual(branch_out, branch_bias, normed, summed)
            # Layer norm after the decoder attention.
            normed = self.post_inter_attention_layernorm(summed)

        # MLP and the final residual connection.
        branch_out, branch_bias = self.mlp(normed)
        output = add_residual(branch_out, branch_bias, normed, summed)

        return output


class ParallelTransformer(MegatronModule):
    """Stack of transformer layers held by one pipeline stage (or chunk).

    With ``pre_process`` the forward input is transposed from
    [b, s, h] to [s, b, h]; otherwise the tensor registered through
    ``set_input_tensor`` is used. With ``post_process`` the result is
    transposed back to [b, s, h] and passed through a final layer norm.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        """Build the layers assigned to this rank.

        Arguments:
            init_method: initializer forwarded to every layer.
            output_layer_init_method: output initializer forwarded to
                every layer.
            layer_type: LayerType of the layers to build.
            self_attn_mask_type: mask type for self-attention.
            pre_process: whether forward() consumes its own input.
            post_process: whether to add and apply the final layer norm.

        Raises:
            AssertionError: if the layer count is not divisible by the
                virtual pipeline size.
        """
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = \
            args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = \
            args.distribute_checkpointed_activations

        # Number of layers.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)

        virtual_size = args.virtual_pipeline_model_parallel_size
        if virtual_size is not None:
            assert args.num_layers % virtual_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # The floor division below would silently drop layers if the
            # per-stage count itself were not divisible.
            assert self.num_layers % virtual_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers
            # in the stage, divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // virtual_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an
            # assignment of layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an
            # assignment of layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // virtual_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        # Layer numbers handed to the layers are 1-based.
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        """Return the layer at local (0-based) index `layer_number`."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        Raises:
            ValueError: on an unknown checkpoint method, or when the
                'uniform' method is used with a non-positive chunk size.
        """
        def custom(start, end):
            # Run local layers [start, end) on the given inputs.
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output,
                               enc_dec_attn_mask)
                return x_
            return custom_forward

        def distribute_checkpointed_activations_helper(layer_number):
            """Distribute checkpointed activations across the tensor model
               parallel ranks if `distribute-checkpointed-activations`
               is on and either of the following conditions is met:
                 - it is not the first layer in the pipeline stage.
                   The first layer is used in the pipeline parallelism
                   and changing its shape throws error in the backward pass.
                 - we are at the first pipeline stage so the input tensor is
                   not used in pipeline parallelism. Note that no pipeline
                   parallelism is a special case of this.
            """
            not_first_layer_in_pipeline_stage = (layer_number > 0)
            is_first_pipeline_stage = (
                mpu.get_pipeline_model_parallel_rank() == 0)
            return self.distribute_checkpointed_activations and \
                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            chunk_size = self.activations_checkpoint_num_layers
            if self.num_layers > 0 and chunk_size < 1:
                # A non-positive step would never advance the loop below.
                raise ValueError(
                    'activations_checkpoint_num_layers must be positive.')
            start = 0
            while start < self.num_layers:
                # Clamp the last chunk so it never indexes past the final
                # layer when num_layers is not a multiple of chunk_size.
                end = min(start + chunk_size, self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(start, end),
                    distribute_checkpointed_activations_helper(start),
                    hidden_states, attention_mask, encoder_output,
                    enc_dec_attn_mask)
                start += chunk_size
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of
            # individual Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant
            # re-computation.
            for index in range(self.num_layers):
                if index < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(index, index + 1),
                        distribute_checkpointed_activations_helper(index),
                        hidden_states, attention_mask, encoder_output,
                        enc_dec_attn_mask)
                else:
                    hidden_states = custom(index, index + 1)(
                        hidden_states, attention_mask, encoder_output,
                        enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None,
                enc_dec_attn_mask=None,
                set_inference_key_value_memory=False,
                inference_max_sequence_len=None):
        """Run the layer stack.

        Arguments:
            hidden_states: [b, s, h] input; ignored when `pre_process`
                is False (the tensor from set_input_tensor is used).
            attention_mask: mask forwarded to every layer.
            encoder_output: optional encoder output; it is transposed
                (dims 0 and 1) before being handed to the layers.
            enc_dec_attn_mask: mask forwarded to every layer.
            set_inference_key_value_memory, inference_max_sequence_len:
                forwarded to the layers on the non-checkpointed path.

        Raises:
            AssertionError: if inference is requested while activation
                checkpointing is enabled.
        """
        # Checks.
        if inference_max_sequence_len:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes:
            # [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set,
            # convert to float.
            if self.fp32_residual_connection:
                hidden_states = \
                    hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    set_inference_key_value_memory=set_inference_key_value_memory,
                    inference_max_sequence_len=inference_max_sequence_len)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output