# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu


""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""


class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample 
    (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
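        # Dividing by keep_prob rescales the surviving samples so that the
        # expected value of the output matches the input (inverted-dropout convention).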
        output = hidden_state.div(keep_prob) * random_tensor
        return output


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform a nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)
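        # dense_h_to_4h is column-parallel (each tensor-parallel rank holds a
        # slice of the 4h dimension) and dense_4h_to_h below is row-parallel, so
        # the intermediate activation stays partitioned and the only communication
        # is the all-reduce inside the row-parallel layer's forward.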

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
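        # With skip_bias_add=True the bias is returned separately so it can be
        # folded into the fused bias+GeLU kernel (or added explicitly below).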

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
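            # Dividing the raw scores by an extra factor of layer_number here and
            # scaling them back inside the softmax (via `coeff` below) keeps the
            # fp16 matmul well scaled without changing the attention output.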

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)


    def _allocate_memory(self, inference_max_sequence_len, batch_size):
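        # Per-layer key/value cache for incremental decoding, laid out as
        # [max_seq_len, batch, num_heads_per_partition, head_dim].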
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())
        

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        # hidden_states: [sq, b, h]


        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]


        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)


        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]


        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
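        # baddbmm computes beta * input + alpha * (batch1 @ batch2); with beta=0.0
        # the preallocated buffer only supplies the output shape, dtype, and
        # device, while alpha folds in the 1/norm_factor scaling.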
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)


        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)
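        # FusedScaleMaskSoftmax applies the optional `coeff` scaling, the mask,
        # and the softmax in one pass when the fused kernel supports the input,
        # falling back to a plain torch implementation otherwise.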

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add
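
# The jit-scripted variants below bake the `training` flag in as a constant so
# TorchScript can fuse the bias add, dropout, and residual add.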


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.Module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flags.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
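        # The drop-path rate grows linearly with depth: 0 for the first layer up
        # to drop_path_rate for the last, the usual stochastic-depth schedule.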

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])
        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
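                # With a split pipeline, encoder stages count layers from pipeline
                # rank 0, while decoder stages count from the first decoder rank
                # (pipeline_model_parallel_split_rank).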
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # This further reduces memory usage at the cost of extra re-computation.
            l = 0
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.activations_checkpoint_num_layers),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # This makes full use of device memory while avoiding redundant re-computation.
            for l in range(self.num_layers):
                if l < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):

        # Checks.
        if inference_params:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    inference_params=inference_params)


        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        
        return output