"csrc/multi_tensor_lamb_mp.cu" did not exist on "79906517e1e99c0cfc348ed2e6e716e7b8bd608c"
transformer.py 29.1 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""
import math
import torch
19
import torch.nn.functional as F
20

Mohammad's avatar
Mohammad committed
21
from megatron import get_args
22
from megatron import mpu
23
from .module import MegatronModule
24
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
25
from megatron.model import LayerNorm
26
27
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
28
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
29

30
31
32
# >>>
from megatron.mpu.random import make_viewless_tensor
# <<<
33
34
35
36
37
38
39
40
41
42
43

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to
    ``args.ffn_hidden_size`` (4*h in the file's notation), perform a
    nonlinear transformation, and project the state back into h hidden
    dimension.

    Both projections are created with ``skip_bias_add=True`` and are used
    here as returning an ``(output, bias)`` pair: the first bias is folded
    into the activation, the second is handed back to the caller unadded.
    """

    def __init__(self, init_method, output_layer_init_method):
        """Build the up- and down-projection layers.

        Args:
            init_method: weight initializer passed to the h -> 4h projection.
            output_layer_init_method: weight initializer passed to the
                4h -> h projection.
        """
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        # Activation selection: plain GELU unless one of the alternative
        # implementations is requested through the global arguments.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        """Apply the MLP.

        Args:
            hidden_states: input activations, [s, b, h].

        Returns:
            Tuple ``(output, output_bias)`` as produced by the 4h -> h
            projection; the bias has not been added to ``output``.
        """
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            # Fused path: bias add and GELU are done in a single call.
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
99
100


101
class ParallelAttention(MegatronModule):
    """Parallel attention layer (self-attention or cross-attention).

    The layer takes input with size [sq, b, h] and returns an output of the
    same size together with the bias of the output projection. That
    projection is built with ``skip_bias_add=True``, so the caller is
    expected to add the bias.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        """Build the projections, softmax and dropout for one attention layer.

        Args:
            init_method: weight initializer for the query/key/value
                projections.
            output_layer_init_method: weight initializer for the output
                projection.
            layer_number: layer index; clamped to at least 1. It is used as
                the query-key scaling coefficient and as the key into the
                inference key/value cache.
            attention_type: ``AttnType.self_attn`` or ``AttnType.cross_attn``.
            attn_mask_type: mask type forwarded to the fused softmax.
        """
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Query-key layer scaling forces the softmax to run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            # One fused projection producing query, key and value.
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            # Query comes from the hidden states; key and value are produced
            # by one fused projection of the encoder output.
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        # Scores are divided by sqrt(hn); with query-key layer scaling they
        # are additionally divided by the layer number, and the same
        # coefficient is handed to the softmax.
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized [max_seq_len, batch, np, hn] cache tensor.

        The tensor uses ``self.params_dtype`` and lives on the current CUDA
        device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        """Compute attention.

        Args:
            hidden_states: [sq, b, h] input; source of the queries (and of
                keys/values for self-attention).
            attention_mask: mask handed to the fused scale-mask-softmax.
            encoder_output: [sk, b, h] source of keys/values; only read on
                the cross-attention path.
            inference_params: optional object carrying the key/value cache
                (``key_value_memory_dict``), its maximum sizes and the
                current batch/sequence offsets.

        Returns:
            Tuple ``(output, bias)`` from the output projection; ``output``
            is [sq, b, h] and the bias has not been added.
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            # Attend over everything cached so far, not only the new tokens.
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the (uninitialized) input tensor does not contribute
        # to the result; alpha applies the 1/norm_factor scaling.
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


375
def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    Args:
        x: layer output without its bias.
        bias: bias to add to ``x`` before dropout.
        residual: tensor added after dropout.
        prob: dropout probability.
        training: when False, dropout is a no-op.
    """
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    """Return a ``(x, bias, residual, prob)`` callable with `training` fixed.

    The returned callable forwards its arguments to ``bias_dropout_add``,
    supplying the captured ``training`` flag as the last argument.
    """
    return lambda x, bias, residual, prob: bias_dropout_add(
        x, bias, residual, prob, training)


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with ``training=True``."""
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with ``training=False``."""
    return bias_dropout_add(x, bias, residual, prob, False)
402
403
404
405
406


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    The layer consists of self-attention, an optional cross-attention
    sub-layer (only for ``LayerType.decoder``) and an MLP. Each sub-layer is
    preceded by a layer norm and followed by a bias-dropout-add residual
    connection.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        """Build the sub-layers.

        Args:
            init_method: weight initializer for input-side projections.
            output_layer_init_method: weight initializer for output-side
                projections.
            layer_number: index of this layer, forwarded to the attention
                sub-layers.
            layer_type: ``LayerType.encoder`` or ``LayerType.decoder``; the
                decoder additionally gets a cross-attention sub-layer.
            self_attn_mask_type: mask type used by the self-attention.
        """
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run the layer.

        Args:
            hidden_states: [s, b, h] input activations.
            attention_mask: mask for the self-attention.
            encoder_output: encoder activations for the cross-attention;
                only read when this is a decoder layer.
            enc_dec_attn_mask: mask for the cross-attention; only read when
                this is a decoder layer.
            inference_params: optional inference state, forwarded to the
                self-attention only.

        Returns:
            Output activations with the same shape as ``hidden_states``.
        """
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class.

    Holds the stack of ``ParallelTransformerLayer`` modules assigned to the
    current pipeline stage (and virtual pipeline chunk, if any), plus the
    final layer norm when ``post_process`` is set.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        """Build the layers owned by this pipeline stage.

        Args:
            init_method: weight initializer for input-side projections.
            output_layer_init_method: weight initializer for output-side
                projections.
            layer_type: ``LayerType.encoder`` or ``LayerType.decoder``.
            self_attn_mask_type: mask type used by every layer's
                self-attention.
            pre_process: if True, ``forward`` transposes its input from
                [b, s, h] to [s, b, h]; otherwise the tensor given to
                ``set_input_tensor`` is used as input.
            post_process: if True, ``forward`` transposes back to [b, s, h]
                and applies the final layer norm.
        """
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.activations_checkpoint_method = args.activations_checkpoint_method
        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations

        # Number of layers.
        self.num_layers = mpu.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)
        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    # Decoder ranks are numbered after the encoder ranks.
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        # Layer numbers passed to the layers are 1-based.
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        """Return the local layer at 0-based index ``layer_number``."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        Raises:
            ValueError: if ``self.activations_checkpoint_method`` is neither
                ``'uniform'`` nor ``'block'``.
        """
        def custom(start, end):
            # Build a function that runs local layers [start, end).
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.activations_checkpoint_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            chunk_start = 0
            while chunk_start < self.num_layers:
                # Clamp the last chunk so it never indexes past the local
                # layers when num_layers is not a multiple of the chunk size.
                chunk_end = min(
                    chunk_start + self.activations_checkpoint_num_layers,
                    self.num_layers)
                hidden_states = mpu.checkpoint(
                    custom(chunk_start, chunk_end),
                    self.distribute_checkpointed_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                chunk_start += self.activations_checkpoint_num_layers
        elif self.activations_checkpoint_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for layer_index in range(self.num_layers):
                if layer_index < self.activations_checkpoint_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(layer_index, layer_index + 1),
                        self.distribute_checkpointed_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(layer_index, layer_index + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run the local stack of transformer layers.

        Args:
            hidden_states: [b, s, h] input; ignored when ``pre_process`` is
                False (the tensor from ``set_input_tensor`` is used instead).
            attention_mask: self-attention mask passed to every layer.
            encoder_output: optional [b, s, h] encoder activations; it is
                transposed to [s, b, h] before being passed to the layers.
            enc_dec_attn_mask: optional cross-attention mask.
            inference_params: optional inference state; not supported
                together with activation checkpointing.

        Returns:
            [b, s, h] layer-normed output when ``post_process`` is True,
            otherwise the [s, b, h] hidden states of the last local layer.
        """
        # Checks.
        if inference_params:
            assert self.activations_checkpoint_method is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes : [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert for float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # NOTE(review): presumably guarantees hidden_states is not a view of
        # another tensor; confirm against megatron.mpu.random.
        hidden_states = make_viewless_tensor(hidden_states)

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.activations_checkpoint_method is not None:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states = layer(
                    hidden_states,
                    attention_mask,
                    encoder_output=encoder_output,
                    enc_dec_attn_mask=enc_dec_attn_mask,
                    inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output