# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
40
    Transformer takes input of size [s, b, h] and returns a
41
42
43
44
45
46
47
48
49
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to the FFN
    hidden dimension (``args.ffn_hidden_size``, written "4h" below), perform
    a nonlinear transformation, and project the state back into h hidden
    dimension.

    ``forward`` returns ``(output, output_bias)``: the bias of the second
    projection is NOT added to ``output`` so the caller can fuse it.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h. skip_bias_add=True makes the layer hand back its bias
        # separately so it can be fused with the activation in forward().
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # NOTE: when bias_gelu_fusion is enabled, forward() always uses
        # bias_gelu_impl and the activation chosen below is not used.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Apply the MLP to ``hidden_states`` ([s, b, h]).

        Returns:
            (output, output_bias): output is [s, b, h]; output_bias is the
            un-added bias of the final projection.
        """
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelAttention(MegatronModule):
    """Parallel self- or cross-attention layer.

    Takes hidden states of size [sq, b, h] (and, for cross attention, an
    encoder output of size [sk, b, h]) and returns an output of size
    [sq, b, h] together with the output projection's bias, which is
    returned un-added (skip_bias_add=True) so the caller can fuse it.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Query-key layer scaling forces the softmax to run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            # One fused projection produces Q, K and V.
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            # Q comes from hidden_states, K and V from the encoder output.
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        # With query-key layer scaling the raw scores are additionally divided
        # by layer_number (via norm_factor) and coeff is handed to
        # FusedScaleMaskSoftmax, which presumably multiplies it back in before
        # the (fp32) softmax -- confirm in fused_softmax.
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized key or value cache buffer for inference.

        Returns a tensor of shape [inference_max_sequence_len, batch_size,
        np, hn] with dtype ``params_dtype`` on the current CUDA device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        """Compute attention.

        Args:
            hidden_states: [sq, b, h] input.
            attention_mask: mask passed to the fused scale/mask/softmax
                together with the [b, np, sq, sk] scores.
            encoder_output: [sk, b, h] source of keys/values; only used for
                cross attention.
            inference_params: optional inference state. When given, this
                layer's keys/values are written into a cache stored in
                ``inference_params.key_value_memory_dict[layer_number]`` and
                attention runs over every cached position up to the current
                sequence offset.

        Returns:
            (output, bias): output is [sq, b, h]; bias is the un-added bias
            of the output projection.
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params is not None:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================
        if inference_params is not None:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            # Attend over every cached position up to the current end.
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 makes baddbmm ignore the (uninitialized) contents of
        # matmul_result; alpha applies the 1/norm_factor scaling.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    Dropout with probability ``prob`` is only applied when ``training`` is
    True. The function must stay TorchScript-compatible because the
    ``@torch.jit.script`` wrappers in this module call it.
    """
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    """Build a bias-dropout-add callable with ``training`` bound.

    The returned function takes ``(x, bias, residual, prob)`` and forwards
    to ``bias_dropout_add`` with the captured ``training`` flag. It is the
    non-fused path used when bias-dropout fusion is disabled.
    """
    def bound_bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return bound_bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout enabled."""
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout disabled."""
    return bias_dropout_add(x, bias, residual, prob, False)


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size. Decoder layers (``layer_type ==
    LayerType.decoder``) additionally run cross attention over the
    encoder output.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def _get_bias_dropout_add_func(self):
        """Pick the bias-dropout-add routine for the current train/eval mode."""
        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                return bias_dropout_add_fused_train
            return bias_dropout_add_fused_inference
        return get_bias_dropout_add(self.training)

    def _apply_bias_dropout_add(self, bias_dropout_add_func, output, bias,
                                residual):
        """Return ``residual + dropout(output + bias)`` via the given routine."""
        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            return bias_dropout_add_func(
                output,
                bias.expand_as(residual),
                residual,
                self.hidden_dropout)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Apply the layer.

        Args:
            hidden_states: [s, b, h] input.
            attention_mask: mask for self attention.
            encoder_output: encoder states; only used by decoder layers.
            enc_dec_attn_mask: mask for cross attention; decoder layers only.
            inference_params: optional inference state forwarded to self
                attention.

        Returns:
            Tensor of the same size as ``hidden_states``.
        """
        bias_dropout_add_func = self._get_bias_dropout_add_func()

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = self.self_attention(
            layernorm_output,
            attention_mask,
            inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = self._apply_bias_dropout_add(
            bias_dropout_add_func, attention_output, attention_bias, residual)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            # NOTE: inference_params is not forwarded here, so cross
            # attention keeps no key/value cache.
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            layernorm_input = self._apply_bias_dropout_add(
                bias_dropout_add_func, attention_output, attention_bias,
                residual)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self._apply_bias_dropout_add(
            bias_dropout_add_func, mlp_output, mlp_bias, residual)

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class.

    Holds the ParallelTransformerLayer modules assigned to this pipeline
    stage (and virtual-pipeline chunk, if enabled). When ``pre_process`` is
    set, the [b, s, h] input is transposed to [s, b, h]; otherwise the input
    comes from ``set_input_tensor()``. When ``post_process`` is set, the
    output is transposed back to [b, s, h] and a final layer norm is applied.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    # Decoder stages come after the encoder stages, so count
                    # from the first decoder rank.
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        # Layer numbers are 1-based and global across the pipeline.
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        """Return the local layer at 0-based index ``layer_number``."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            """Build a function that runs local layers [start, end)."""
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            if self.activations_checkpoint_num_layers < 1:
                raise ValueError(
                    "activations_checkpoint_num_layers must be >= 1 for "
                    "the 'uniform' activation checkpoint method.")
            layer_idx = 0
            while layer_idx < self.num_layers:
                # Clamp so a trailing partial chunk does not index past the
                # last local layer.
                chunk_end = min(
                    layer_idx + self.activations_checkpoint_num_layers,
                    self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(layer_idx, chunk_end),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                layer_idx = chunk_end
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method to fully use the device memory removing redundant re-computation.
            for layer_idx in range(self.num_layers):
                if layer_idx < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(layer_idx, layer_idx + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(layer_idx, layer_idx + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run this stage's layers.

        Args:
            hidden_states: [b, s, h] input; ignored (replaced by the tensor
                from set_input_tensor()) when ``pre_process`` is False.
            attention_mask: self-attention mask passed to every layer.
            encoder_output: [b, s, h] encoder states for decoder layers;
                transposed to [s, b, h] before use.
            enc_dec_attn_mask: cross-attention mask for decoder layers.
            inference_params: optional inference state; incompatible with
                activation checkpointing.

        Returns:
            [b, s, h] after the final layer norm when ``post_process`` is
            set, otherwise the [s, b, h] hidden states.
        """
        # Checks.
        if inference_params is not None:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes : [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert for float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # Transpose encoder output.
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        # Forward pass.
        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output