"research/domain_adaptation/README.md" did not exist on "25dff8e4a1408c30fee1ac04f217e884008e5d2d"
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
            with size [b, np, s, s] and an `attention-mask` and will apply
            the masking. The function should return a masked score of the
            same size [b, np, s, s].
               masked-attention-scores = attention_mask_func(
                                     unmasked-attention-scores, attention-mask)
"""

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden size, project it to 4*h
    hidden dimension, perform a nonlinear transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
    applied.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

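        # The bias from dense_h_to_4h is returned separately (skip_bias_add)
        # so that, when bias_gelu_fusion is enabled, the bias add and the GeLU
        # can run in a single fused jit kernel instead of two elementwise ops.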
        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16

        self.attention_mask_func = attention_mask_func
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
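        # For self-attention, Q, K and V for the local heads come out of a
        # single fused column-parallel GEMM (hence 3 * projection_size below);
        # cross-attention uses a separate Q projection plus a fused KV one.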
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

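        # Dividing by norm_factor scales the raw scores by 1/sqrt(hn). With
        # query-key layer scaling enabled, the scores are further divided by
        # the layer number here and the fused softmax scales them back up via
        # `coeff`, keeping fp16 intermediates in a safer numerical range.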
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            self.attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
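        # Checkpoints written with older code lay out the fused projection's
        # last dimension in a different order; this helper reorders it to the
        # [np, num_splits, hn] layout that the view/split in forward() expects.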
        input_shape = mixed_layer.size()
        if num_splits_first:
            """[s, b, num_splits * np * hn]
            -->(view) [s, b, num_splits, np, hn]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (num_splits, self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            """[s, b, np * hn * num_splits]
            -->(view) [s, b, np, hn, num_splits]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] +\
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head, num_splits)

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)

        return mixed_layer

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
                    mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            checkpoint_version = get_checkpoint_version()
            if checkpoint_version is not None:
                if checkpoint_version == 0:
                    # [s, b, (2 * np * hn)] --> [s, b, (np * 2 * hn)]
                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
                elif checkpoint_version == 1.0:
                    # [s, b, (np * hn * 2)] --> [s, b, (np * 2 * hn)]
                    mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, False)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================
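        # For incremental decoding, keys and values cached from previous steps
        # (layer_past) are concatenated with the new ones, and the combined
        # cache is returned as `present` when get_key_value is set.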

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
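        # baddbmm with beta=0.0 ignores whatever is in this buffer and writes
        # the scaled scores directly into it, avoiding an extra allocation.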
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type
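
        # With the default pre-LN setup the residual branch starts from the
        # layer input; when apply_residual_connection_post_layernorm is set,
        # it starts from the layernorm output instead (post-LN style).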
        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        # Layernorm on the input data.
        LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            attention_mask_func,
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                attention_mask_func,
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                layer_past=None, get_key_value=False):
        # hidden_states: [b, s, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output,
                                attention_mask,
                                layer_past=layer_past,
                                get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for an nn.Module (with dropout) does not
        # trigger the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, attention_mask_func,
                 init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.fp32_residual_connection = args.fp32_residual_connection

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                attention_mask_func, init_method,
                output_layer_init_method, layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)
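        # Offset layer numbers by the pipeline stage so that layer_number
        # (used for query-key layer scaling) stays globally consistent
        # across pipeline-parallel ranks.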
        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

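        # Only the last pipeline stage owns the final layernorm; earlier
        # stages return their hidden states as-is for the next stage.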
        if mpu.is_pipeline_last_stage():
            # Final layer norm before output.
            LayerNorm = import_layernorm(args.fp32_residual_connection)
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
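        # Checkpoint the layers in chunks of checkpoint_num_layers: only each
        # chunk's inputs are stored, and the chunk is re-run in backward to
        # rematerialize its activations.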
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
            l += self.checkpoint_num_layers

        return hidden_states

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if mpu.is_pipeline_first_stage():
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()
          
        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if mpu.is_pipeline_last_stage():
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        if get_key_value:
            output = [output, presents]

        return output