# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
40
    Transformer takes input of size [s, b, h] and returns a
41
42
43
44
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    Activations in this file are laid out as [s, b, h]. One keep/drop
    decision is drawn per sample, i.e. per index along dim 1. That
    decision is broadcast over the sequence and hidden dimensions.
    Surviving samples are rescaled by ``1 / keep_prob`` so the expected
    activation is unchanged.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        # Probability of dropping a sample's whole residual branch.
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        """Apply stochastic depth to ``hidden_state`` ([s, b, ...]).

        This is the identity when ``drop_prob`` is 0 or the module is in
        eval mode.
        """
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # The mask has one entry per sample. Batch-first implementations
        # (e.g. ConvNets) use dim 0, but the sample dimension here is dim 1
        # because activations are [s, b, ...]. All other dims broadcast.
        shape = (1, hidden_state.shape[1]) + (1,) * (hidden_state.ndim - 2)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device)
        # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else 0.
        random_tensor.floor_()
        output = hidden_state.div(keep_prob) * random_tensor
        return output


class ParallelMLP(MegatronModule):
    """Tensor-parallel MLP.

    Projects an input with hidden size h to ``args.ffn_hidden_size``
    (conventionally 4*h), applies a GeLU nonlinearity, and projects the
    result back to hidden size h.

    Both projections are built with ``skip_bias_add=True``. The second
    projection's bias is returned alongside the output so the caller can
    fuse it into the following bias-dropout-add.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h. gather_output=False keeps the result partitioned
        # across tensor-parallel ranks ([s, b, 4hp]).
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        # NOTE: when bias_gelu_fusion is enabled, forward() uses
        # bias_gelu_impl, and the openai_gelu / onnx_safe choice below is
        # not used.
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h. The input is already partitioned, matching the
        # column-parallel layer above.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Run the MLP on ``hidden_states`` ([s, b, h]).

        Returns:
            ``(output, output_bias)``. ``output`` is [s, b, h] and
            ``output_bias`` is the second projection's un-added bias.
        """
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelAttention(MegatronModule):
    """Tensor-parallel multi-head attention (self- or cross-attention).

    Takes ``hidden_states`` of size [sq, b, h] and returns a tuple
    ``(output, bias)``, where ``output`` is [sq, b, h]. The output
    projection is built with ``skip_bias_add=True``, so its bias is returned
    separately for the caller to fuse.

    With ``AttnType.self_attn``, queries, keys and values all come from
    ``hidden_states``. With ``AttnType.cross_attn``, keys and values come
    from ``encoder_output`` ([sk, b, h]).
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Query-key layer scaling forces the softmax to run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        # Clamp to >= 1 so the layer-scaling coefficient below is never 0.
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        # Scores are divided by norm_factor = sqrt(hn) (times layer_number
        # when layer scaling is on). The same coefficient is handed to the
        # softmax as its scale; presumably it multiplies the scores back
        # there (NOTE(review): confirm in FusedScaleMaskSoftmax).
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized key or value cache for inference.

        Returns a tensor of shape
        [inference_max_sequence_len, batch_size, np, hn] with dtype
        ``self.params_dtype`` on the current CUDA device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        """Compute attention.

        Args:
            hidden_states: [sq, b, h] query source; also the key/value
                source for self-attention.
            attention_mask: mask passed to ``self.scale_mask_softmax``.
            encoder_output: [sk, b, h] key/value source. Only used for
                cross-attention.
            inference_params: optional. When given, this layer's keys and
                values are written into a cache stored in
                ``inference_params.key_value_memory_dict`` (keyed by layer
                number). Attention then runs over the cached prefix
                ``[0, sequence_len_offset + new_len)``.

        Returns:
            ``(output, bias)``. ``output`` is [sq, b, h] and ``bias`` is the
            output projection's un-added bias.
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            # Attend over every cached position up to sequence_end.
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # =====================================
        # Raw attention scores. [b, np, sq, sk]
        # =====================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # Preallocate the result tensor: [b * np, sq, sk]. Its contents are
        # never read, because baddbmm is called with beta=0 below. It is
        # allocated on the queries' own device.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device)

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # ==========================
        # Context layer. [sq, b, hp]
        # ==========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # ==================
        # Output. [sq, b, h]
        # ==================

        output, bias = self.dense(context_layer)

        return output, bias


def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    Dropout uses probability ``prob`` and is only active when ``training``
    is True. The annotations let TorchScript compile this function when
    the ``@torch.jit.script`` wrappers in this module call it.
    """
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    """Build a bias-dropout-add callable with the dropout mode fixed.

    The returned function takes ``(x, bias, residual, prob)`` and forwards
    to ``bias_dropout_add`` with the ``training`` flag captured here. It is
    the non-fused fallback used when bias-dropout fusion is disabled.
    """
    def _with_fixed_mode(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return _with_fixed_mode


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout enabled.

    Training mode is a compile-time constant here. This gives a separate
    scripted function from the inference variant.
    """
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout disabled.

    The result is ``residual + x + bias``; ``prob`` is accepted only so the
    signature matches the training variant.
    """
    return bias_dropout_add(x, bias, residual, prob, False)


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    Data flow is pre-layernorm: layernorm -> self-attention ->
    bias/dropout/residual -> layernorm. Decoder layers then add
    cross-attention -> bias/dropout/residual -> layernorm. Finally:
    MLP -> bias/dropout/residual.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        # Stochastic depth is only instantiated for a positive rate.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run the layer on ``hidden_states`` ([s, b, h]).

        Args:
            hidden_states: [s, b, h] input.
            attention_mask: mask for self-attention.
            encoder_output: [sk, b, h] encoder states. Only used by decoder
                layers (cross-attention).
            enc_dec_attn_mask: mask for cross-attention (decoder only).
            inference_params: forwarded to self-attention for KV caching.

        Returns:
            [s, b, h] output tensor.
        """
        # hidden_states: [s, b, h]

        # Pick the bias-dropout-add implementation once, before any branch.
        # The decoder cross-attention path below always needs it, even when
        # drop_path is active. Previously it was bound only inside the
        # drop_path-is-None branch, so a decoder layer with
        # drop_path_rate > 0 raised NameError.
        #
        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class: the stack of transformer layers owned by one
    pipeline stage.

    Only the layers assigned to the current pipeline-parallel rank (and,
    when virtual pipelining is enabled, the current virtual rank) are
    built. The forward pass optionally uses activation checkpointing, and
    the last stage (``post_process``) applies a final LayerNorm.

    Args:
        init_method: weight initializer forwarded to every layer.
        output_layer_init_method: output-layer initializer forwarded to
            every layer.
        layer_type: ``LayerType`` of every layer built here (encoder or
            decoder); also used to compute the layer offset for
            encoder-decoder models split across pipeline stages.
        self_attn_mask_type: ``AttnMaskType`` forwarded to every layer.
        pre_process: if True, ``forward`` transposes its ``hidden_states``
            input from [b s h] to [s b h]; otherwise the input is taken from
            the tensor registered with ``set_input_tensor``.
        post_process: if True, ``forward`` transposes the result back to
            [b s h] and applies the final LayerNorm.
        drop_path_rate: largest stochastic-depth rate; per-layer rates rise
            linearly from 0 to this value across ``args.num_layers`` layers.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        # Filled in by set_input_tensor() when pre_process is False.
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flags.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers in this pipeline stage.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # One stochastic-depth rate per layer of the whole model (not just
        # this stage), increasing linearly from 0 to drop_path_rate.
        self.drop_path_rates = [
            rate.item()
            for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

        # Transformer layers.
        def build_layer(layer_number):
            # layer_number is 1-based and global (offset is added by the
            # caller below), so it indexes the full-model drop_path_rates.
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    # Decoder stages start after the encoder's ranks, so
                    # numbering restarts from the split rank.
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        """Return the local layer at 0-based index ``layer_number``."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        The strategy is chosen by ``self.activations_checkpoint_method``:
        'uniform' checkpoints consecutive chunks of
        ``activations_checkpoint_num_layers`` layers; 'block' checkpoints
        only the first ``activations_checkpoint_num_layers`` layers
        individually and runs the rest normally.

        Raises:
            ValueError: if the checkpoint method is not recognised.
        """
        def custom(start, end):
            # Build a callable that runs local layers [start, end) in order.
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the Transformer layers into chunks and
            # checkpoint the input activation of each chunk, further
            # reducing memory by keeping fewer checkpoints. The final chunk
            # is clipped to num_layers so that a layer count not divisible
            # by the chunk size does not index past the end of self.layers.
            chunk_start = 0
            while chunk_start < self.num_layers:
                chunk_end = min(
                    chunk_start + self.activations_checkpoint_num_layers,
                    self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(chunk_start, chunk_end),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                chunk_start = chunk_end
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of
            # individual Transformer layers and skip the rest, using spare
            # device memory to avoid redundant re-computation.
            for index in range(self.num_layers):
                if index < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(index, index + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(index, index + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run this stage's transformer layers.

        Args:
            hidden_states: stage input in [b s h] layout when
                ``pre_process`` is True; ignored otherwise, in which case
                the tensor from ``set_input_tensor`` is used instead.
            attention_mask: self-attention mask passed to every layer.
            encoder_output: optional encoder output, transposed
                [b s h] --> [s b h] before being passed to the layers.
            enc_dec_attn_mask: optional encoder-decoder attention mask.
            inference_params: optional inference state forwarded to each
                layer; incompatible with activation checkpointing.

        Returns:
            The final-LayerNorm output in [b s h] layout if
            ``post_process`` is True, otherwise the raw [s b h] hidden
            states for the next pipeline stage.
        """
        # Checks.
        if inference_params is not None:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # Transpose encoder output.
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        # Forward pass.
        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output