# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu


""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     sq, sk: query and key sequence lengths (these may differ in
             cross-attention)
     l: number of layers
    Transformer layers take input of size [s, b, h] and return a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to
    ``args.ffn_hidden_size`` (the "4h" dimension in the comments below),
    perform a nonlinear transformation, and project the state back into
    h hidden dimension.

    The first projection is column-parallel and the second is row-parallel,
    so the intermediate activation stays partitioned across the tensor model
    parallel ranks. Both projections are built with ``skip_bias_add=True``,
    so ``forward`` returns the final bias separately; callers in this file
    add it in their bias-dropout-add step.

    Args:
        init_method: initializer for the h -> 4h projection weights.
        output_layer_init_method: initializer for the 4h -> h projection
            weights.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # NOTE: when bias_gelu_fusion is enabled, forward() uses
        # bias_gelu_impl and activation_func (openai_gelu / erf_gelu
        # selection) is not used.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Apply the MLP.

        Args:
            hidden_states: [s, b, h] input.

        Returns:
            (output, output_bias): output is [s, b, h] without the bias of
            the second projection; output_bias is that bias.
        """
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        # The first projection skipped its bias add, so the bias is applied
        # here together with the activation (fused kernel when enabled).
        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelAttention(MegatronModule):
    """Parallel self-attention or cross-attention layer.

    The query/key/value projections are column-parallel and the output
    projection is row-parallel across the tensor model parallel group, so
    each rank holds ``num_attention_heads / world_size`` attention heads.

    Self-attention takes ``hidden_states`` of size [sq, b, h]. Cross-attention
    additionally takes ``encoder_output`` of size [sk, b, h] as the source of
    keys and values. The output has the same size as ``hidden_states``.

    Args:
        init_method: initializer for the query/key/value projection weights.
        output_layer_init_method: initializer for the output projection.
        layer_number: 1-based index of the enclosing transformer layer
            (values below 1 are clamped to 1). Used as the query-key layer
            scaling coefficient and as the key of the inference key/value
            cache.
        attention_type: ``AttnType.self_attn`` or ``AttnType.cross_attn``.
        attn_mask_type: mask type forwarded to ``FusedScaleMaskSoftmax``.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # With query-key layer scaling, the raw scores are additionally
        # divided by layer_number (see norm_factor below), and layer_number
        # is handed to FusedScaleMaskSoftmax as `coeff`; the softmax is then
        # always computed in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized key or value cache buffer.

        Returns a tensor of shape
        [inference_max_sequence_len, batch_size, np, hn] with dtype
        ``params_dtype`` on the current CUDA device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        """Compute attention.

        Args:
            hidden_states: [sq, b, h] input; source of the queries (and of
                the keys/values for self-attention).
            attention_mask: mask passed to ``FusedScaleMaskSoftmax`` together
                with the [b, np, sq, sk] attention scores.
            encoder_output: [sk, b, h] source of keys/values; only read for
                cross-attention.
            inference_params: optional incremental-decoding state. When
                truthy, this layer's key/value cache is allocated on first
                use in ``inference_params.key_value_memory_dict`` (keyed by
                ``layer_number``). The new keys/values are written starting
                at ``sequence_len_offset`` / ``batch_size_offset``, and
                attention runs over all cached positions up to the new end.

        Returns:
            (output, bias): output is [sq, b, h] without the output
            projection bias; the bias is returned separately.
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # NOTE(review): the cache is keyed only by layer_number, so two
        # attention modules with the same layer_number would share it.
        # ParallelTransformerLayer does not pass inference_params to its
        # cross-attention module.
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            # Attend over every cached position up to the newest one.
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # =====================================
        # Raw attention scores. [b, np, sq, sk]
        # =====================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the uninitialized contents of matmul_result are
        # ignored; it only provides the output buffer.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    The TorchScript-compiled wrappers in this module call this function, so
    its argument types must be visible to the TorchScript compiler; here
    they are given as annotations.

    Args:
        x: input tensor.
        bias: tensor added to ``x`` before dropout (broadcast as in ``+``).
        residual: tensor added to the dropout result.
        prob: dropout probability.
        training: whether dropout is active.
    """
    biased = x + bias
    dropped = torch.nn.functional.dropout(biased, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Bind ``training`` into ``bias_dropout_add``.

    Returns a callable taking ``(x, bias, residual, prob)``, the same
    argument list as the TorchScript-fused variants, so callers can pick
    either implementation interchangeably.
    """
    return lambda x, bias, residual, prob: bias_dropout_add(
        x, bias, residual, prob, training)


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout enabled."""
    dropout_active = True
    return bias_dropout_add(x, bias, residual, prob, dropout_active)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout disabled."""
    dropout_active = False
    return bias_dropout_add(x, bias, residual, prob, dropout_active)


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    Order of sub-blocks: self-attention, then cross-attention (decoder layers
    only), then the MLP. Each sub-block consumes a LayerNorm output and is
    followed by a bias + dropout + residual add. With
    ``apply_residual_connection_post_layernorm`` the residual branch is that
    LayerNorm output instead of the LayerNorm input.

    Args:
        init_method: initializer for input-side projections.
        output_layer_init_method: initializer for output projections.
        layer_number: 1-based layer index, forwarded to the attention modules.
        layer_type: ``LayerType.encoder`` or ``LayerType.decoder``; decoder
            layers add a cross-attention sub-block.
        self_attn_mask_type: mask type for the self-attention softmax.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Not read within this class.
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Apply the layer.

        Args:
            hidden_states: [s, b, h] input.
            attention_mask: mask for self-attention.
            encoder_output: encoder states used as keys/values by the
                cross-attention of decoder layers; ignored otherwise.
            enc_dec_attn_mask: mask for cross-attention (decoder layers only).
            inference_params: forwarded to self-attention only.

        Returns:
            [s, b, h] output tensor.
        """
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class: a stack of ``ParallelTransformerLayer`` modules.

    Holds the layers assigned to this pipeline stage (and, with a virtual
    pipeline, to this model chunk), plus a final LayerNorm when
    ``post_process`` is set.

    Args:
        init_method: initializer for input-side projections.
        output_layer_init_method: initializer for output projections.
        layer_type: type of every layer in the stack (encoder or decoder).
        self_attn_mask_type: mask type for self-attention.
        pre_process: if True, ``forward`` takes ``hidden_states`` as
            [b, s, h] and transposes it to [s, b, h]; otherwise it ignores
            that argument and uses the tensor given to ``set_input_tensor``.
        post_process: if True, ``forward`` transposes the result back to
            [b, s, h] and applies the final LayerNorm; otherwise it returns
            the [s, b, h] activations unchanged.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flags.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)
        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        # Global layer numbers are 1-based.
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        """Return the layer at local (0-based) position ``layer_number``."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        'uniform' checkpoints consecutive chunks of
        ``activations_checkpoint_num_layers`` layers (the last chunk may be
        shorter). 'block' checkpoints each of the first
        ``activations_checkpoint_num_layers`` layers individually and runs
        the remaining layers without checkpointing.

        Raises:
            ValueError: if ``activations_checkpoint_method`` is neither
                'uniform' nor 'block'.
        """
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        def distribute_checkpointed_activations_helper(layer_number):
            """Distribute checkpointed activations across the tensor model
               parallel ranks if `distribute-checkpointed-activations`
               is on and either of the following conditions is met:
                 - it is not the first layer in the pipeline stage.
                   The first layer is used in the pipeline parallelism
                   and changing its shape throws an error in the backward pass.
                 - we are at the first pipeline stage so the input tensor is
                   not used in pipeline parallelism. Note that no pipeline
                   parallelism is a special case of this.
            """
            not_first_layer_in_pipeline_stage = (layer_number > 0)
            is_first_pipeline_stage = (
                mpu.get_pipeline_model_parallel_rank() == 0)
            return self.distribute_checkpointed_activations and \
                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            l = 0
            while l < self.num_layers:
                # Clamp the last chunk so that a layer count which is not a
                # multiple of the chunk size does not index past self.layers.
                end = min(l + self.activations_checkpoint_num_layers,
                          self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(l, end),
                    distribute_checkpointed_activations_helper(l),
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l = end
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        distribute_checkpointed_activations_helper(l),
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run the layer stack.

        Args:
            hidden_states: [b, s, h] input, used only when ``pre_process``
                is set (otherwise the tensor from ``set_input_tensor`` is
                used).
            attention_mask: self-attention mask passed to every layer.
            encoder_output: optional encoder states; transposed with
                ``transpose(0, 1)`` before being passed to the layers
                (presumably [b, sk, h] -> [sk, b, h]; confirm with callers).
            enc_dec_attn_mask: cross-attention mask passed to every layer.
            inference_params: incremental-decoding state; not supported
                together with activation checkpointing.

        Returns:
            [b, s, h] normalized output if ``post_process``; otherwise the
            [s, b, h] activations.
        """
        # Checks.
        if inference_params:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes : [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert for float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output