# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
17
import enum
18
19
import math
import torch
20
import torch.nn.functional as F
21

Mohammad's avatar
Mohammad committed
22
from megatron import get_args
23
from megatron import mpu
24
from .module import MegatronModule
Vijay Korthikanti's avatar
Vijay Korthikanti committed
25
from megatron.checkpointing import get_checkpoint_version
26
from megatron.model.enums import AttnMaskType, LayerType, AttnType
27
from megatron.model import import_layernorm
28
29
30
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu
31

32
33
34
35
36
# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
37
38
39
40
41
42
43
44
45
46
47

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
48
    Transformer takes input of size [s, b, h] and returns a
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
        attention_mask_func: a function that takes `unmaksed-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
               masked-attention-scores = attention_mask_func(
                                     unmaksed-attention-scores, attention-mask)
"""

class ParallelMLP(MegatronModule):
    """Tensor-model-parallel two-layer feed-forward block.

    Takes the h-sized hidden state, projects it up to the FFN dimension
    (4*h by default) with a column-parallel linear layer, applies a GeLU
    non-linearity (optionally fused with the bias add), and projects back
    down to h with a row-parallel linear layer.  Both linear layers defer
    their bias add (skip_bias_add=True) so the bias can be fused into a
    downstream kernel by the caller.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Up-projection: h -> ffn_hidden_size, split over columns so each
        # partition holds a slice of the FFN dimension.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # Activation selection.  The fused path performs bias add + GeLU in
        # a single jit kernel; the fallback paths use an explicit add.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        else:
            self.activation_func = F.gelu

        # Down-projection: ffn_hidden_size -> h, split over rows; the
        # row-parallel layer reduces partial sums across partitions.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Apply the MLP; returns (output, output_bias), both un-added."""
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = bias_gelu_impl(
                intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = self.activation_func(
                intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
110
111


112
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h] (sequence-first;
    see the forward() shape comments) and returns output of the same size.
    Supports both self-attention and cross-attention (keys/values taken
    from `encoder_output`), selected via `attention_type`.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Layer scaling divides the QK^T scores by the layer number for
        # fp16 range safety, so the softmax must then run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        # Layer numbers are 1-based; clamp so the scaling coeff is never 0.
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer: QKV (or KV + Q for cross-attention) are
        # produced by a single fused projection and split afterwards.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            # Extra 1/layer_number scaling is applied to the raw scores; the
            # same coeff is handed to the fused softmax, which presumably
            # compensates for it inside the kernel — see fused_softmax.py.
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            self.attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
        """Reorder the fused QKV/KV last dimension between checkpoint layouts.

        Older checkpoints stored the fused projection with the split factor
        (3 for QKV, 2 for KV) in a different position of the last dimension;
        this permutes it into the current [np * num_splits * hn] layout.
        """
        input_shape = mixed_layer.size()
        if num_splits_first:
            """[s, b, num_splits * np * hn]
            -->(view) [s, b, num_splits, np, hn]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (num_splits, self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            """[s, b, np * hn * num_splits]
            -->(view) [s, b, np, hn, num_splits]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head, num_splits)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)

        return mixed_layer

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        """Compute attention.

        Shapes (sq = query seq len, sk = key seq len, b = batch,
        np = heads per partition, hn = hidden size per head):
            hidden_states: [sq, b, h]
            encoder_output: [sk, b, h]; only read for cross-attention.
            layer_past: optional (past_key, past_value) for incremental
                decoding; concatenated in front of the new keys/values.
            get_key_value: if True, output is [output, (key, value)].

        Returns (output [sq, b, h], bias): the dense bias is returned
        separately (skip_bias_add) so callers can fuse the add.
        """
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        # NOTE: allocated on the current CUDA device — this path assumes
        # GPU execution.
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the preallocated buffer's contents are ignored;
        # alpha folds in the 1/sqrt(hn) (and optional layer) scaling.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    # Single-token decode step: keep only the mask row for
                    # the newest query position.
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # The fork keeps dropout RNG consistent across tensor-parallel ranks.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    """Add *bias* to *x*, apply dropout with probability *prob*, then add
    the *residual* stream.

    The type comment is kept because the fused variants below wrap this
    function with @torch.jit.script.
    """
    dropped = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Return a bias-dropout-add callable with the train/eval flag bound.

    Used on the non-fused code path, where the training flag is fixed per
    call site instead of being baked into a jit-scripted kernel.
    """
    def _bound(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bound


# Jit-scripted training-mode variant: dropout always active.  Kept as a
# separate function (rather than a flag) because TorchScript needs the
# constant to fuse the kernel; the type comment is required by the scripter.
@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob) :
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


# Jit-scripted inference-mode variant: dropout disabled (identity).
@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob) :
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
435
436
437
438
439


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] (sequence-first; see
    the forward() flow) and returns an output of the same size.  Pre-LN
    layout: layernorm -> attention -> residual -> layernorm -> MLP ->
    residual, with an extra cross-attention sub-block for decoder layers.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            attention_mask_func,
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Decoder layers additionally cross-attend to the encoder output.
        # NOTE(review): inter_attention uses the default padding mask type.
        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                attention_mask_func,
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                layer_past=None, get_key_value=False):
        """Run one transformer layer.

        Args:
            hidden_states: [s, b, h] input activations.
            attention_mask: mask for self-attention.
            encoder_output / enc_dec_attn_mask: inputs for the decoder's
                cross-attention sub-block (decoder layers only).
            layer_past: optional cached (key, value) for incremental decode.
            get_key_value: if True, returns [output, presents].
        """
        # hidden_states: [b, s, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output,
                                attention_mask,
                                layer_past=layer_past,
                                get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class.

    Stack of ParallelTransformerLayer modules for this pipeline stage, with
    optional activation checkpointing and a final layernorm on the last
    pipeline stage.  Input/output is [b, s, h] at the pipeline boundaries;
    internally the stack runs sequence-first [s, b, h].
    """

    def __init__(self, attention_mask_func,
                 init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.fp32_residual_connection = args.fp32_residual_connection

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers owned by this pipeline stage.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.  Layer numbers are global (1-based across all
        # pipeline stages), hence the rank-dependent offset.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                attention_mask_func, init_method,
                output_layer_init_method, layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)
        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if mpu.is_pipeline_last_stage():
            # Final layer norm before output.
            LayerNorm = import_layernorm(args.fp32_residual_connection)
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        # layer_number here is the 0-based local index into self.layers.
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        Runs the layers in chunks of checkpoint_num_layers; activations
        inside each chunk are recomputed during backward.
        """
        def custom(start, end):
            # Build a closure running local layers [start, end).
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
            l += self.checkpoint_num_layers

        return hidden_states

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
        """Run the transformer stack for this pipeline stage.

        Args:
            hidden_states: [b, s, h] on the first stage; transposed to
                [s, b, h] internally.
            layer_past: per-layer cached (key, value); requires
                get_key_value=True.
            get_key_value: collect presents from each layer and return
                [output, presents]; incompatible with activation
                checkpointing.
        """

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if mpu.is_pipeline_first_stage():
            # Data format change to avoid explicit transposes : [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert for float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if mpu.is_pipeline_last_stage():
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        if get_key_value:
            output = [output, presents]

        return output