# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu


""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""


class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample 
    (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        output = x.div(keep_prob) * random_tensor
        return output
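# Minimal DropPath usage sketch (illustrative only; drop_prob=0.1 and the
# tensor names below are assumptions, not taken from this file):
#
#   drop_path = DropPath(drop_prob=0.1)
#   out = residual + drop_path(branch_output)
#
# During training, each sample's branch output is zeroed with probability 0.1
# and the survivors are rescaled by 1 / (1 - 0.1) so the expectation is
# unchanged; in eval mode (or with drop_prob=0) the module is an identity.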


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
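# Tensor-parallel shape sketch for ParallelMLP (illustrative, assumed values):
# with hidden_size h=1024, ffn_hidden_size 4h=4096, and p=8 tensor-parallel
# ranks, dense_h_to_4h is column-parallel and emits a per-rank
# [s, b, 4096/8 = 512] activation without gathering; dense_4h_to_h is
# row-parallel and all-reduces its partial products back to the full
# [s, b, 1024] hidden size. skip_bias_add=True returns the bias separately so
# it can be fused into the subsequent bias-dropout-add.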


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs for different numbers of parallel partitions, but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)


    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())
        

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        # hidden_states: [sq, b, h]


        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]


        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)
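        # Illustrative shapes after the projections (assumed values): with
        # sq=2048, b=4, np=4 heads per partition, and hn=64, query/key/value
        # are each [2048, 4, 4, 64] in the self-attention branch; in the
        # cross-attention branch the key/value sequence length sk comes from
        # encoder_output instead.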


        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]


        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))
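        # Note: with beta=0.0 the uninitialized buffer contents are ignored and
        # baddbmm returns alpha * (Q @ K^T), i.e. dot-product scores scaled by
        # 1 / norm_factor. norm_factor is sqrt(hn), additionally multiplied by
        # the layer number when apply_query_key_layer_scaling is on; that extra
        # factor is compensated inside the fused softmax via `coeff`.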

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)


        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)
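# Minimal usage sketch for the bias-dropout-add helpers (illustrative; the
# tensor names and the 0.1 dropout probability are assumptions):
#
#   out = bias_dropout_add_fused_train(
#       attention_output, bias.expand_as(residual), residual, 0.1)
#
# which computes residual + dropout(attention_output + bias) in one scripted
# call. get_bias_dropout_add(training) returns the unfused equivalent used when
# bias_dropout_fusion is disabled.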


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type
        self.drop_path_rate = drop_path_rate

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate)

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path_rate == 0.0:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)
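        # Both branches compute residual + dropout(attention_output + bias); the
        # fused branch uses the jit-scripted helpers above, while the drop-path
        # branch stays unfused so stochastic depth can drop or rescale the whole
        # attention branch per sample.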

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path_rate == 0.0:
            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flags.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, args.num_layers)]
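        # Stochastic-depth schedule: drop probabilities grow linearly with depth.
        # Illustrative example (assumed values): drop_path_rate=0.1 with
        # num_layers=4 gives dpr = [0.0, 0.0333, 0.0667, 0.1], so the first
        # layer is never dropped and the last one is dropped most often.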

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.dpr[layer_number - 1])
        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
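            # Illustrative offset check (assumed values): with num_layers=8,
            # 2 pipeline stages, and virtual_pipeline_model_parallel_size=2,
            # each model chunk holds 8 / 2 / 2 = 2 layers; virtual rank 1 on
            # pipeline rank 0 gets offset = 1 * (8 // 2) + 0 * 2 = 4, i.e. the
            # chunk [4, 5] in the second layout sketched above.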
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])
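        # Note: build_layer receives 1-based global layer numbers
        # (offset + 1 .. offset + num_layers); ParallelAttention uses them for
        # query-key layer scaling, and the per-layer drop-path rate is looked up
        # with layer_number - 1.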

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # This reduces memory usage further by keeping only the chunk-boundary
            # activations (the checkpoints) and recomputing the rest in backward.
            l = 0
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # This makes fuller use of the available device memory by skipping
            # redundant re-computation for the layers that are not checkpointed.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states
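    # Illustrative checkpointing arithmetic (assumed values): with num_layers=12
    # and activations_checkpoint_num_layers=3, the 'uniform' method checkpoints
    # the inputs of four 3-layer chunks and recomputes the rest during backward,
    # while the 'block' method checkpoints only the first 3 layers and runs the
    # remaining 9 without recomputation.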

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):

        # Checks.
        if inference_params:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    inference_params=inference_params)


        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        
        return output
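# Minimal end-to-end usage sketch (illustrative; it assumes the usual Megatron
# setup -- parsed args, initialized mpu state, and CUDA inputs -- is already in
# place):
#
#   transformer = ParallelTransformer(init_method, output_layer_init_method)
#   output = transformer(hidden_states, attention_mask)  # [b, s, h] -> [b, s, h]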