transformer.py 42.2 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""Transformer."""
import math
5
from contextlib import nullcontext
6
import torch
7
import torch.nn.functional as F
8

9
from megatron import get_timers, get_args, core
10
from .module import MegatronModule
11
from megatron.core import mpu, tensor_parallel
12
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
13
from megatron.model import LayerNorm
14
15
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
16
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
17

18

19
20
21
22
23
24
25
26
27
28
""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     sq, sk: query and key sequence lengths (used in attention)
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

34
class DropPath(MegatronModule):
    """Drop paths (stochastic depth) per sample.

    Meant to be applied to the main path of residual blocks. In training
    mode every sample is either kept (probability ``1 - drop_prob``) and
    rescaled by ``1 / (1 - drop_prob)`` so the expected value is unchanged,
    or zeroed entirely. In eval mode, or when ``drop_prob == 0``, the input
    is returned unchanged.

    Activations in this file are laid out as [s, b, h], so the sample
    (batch) dimension is dim 1.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # Every sample is dropped. The generic path below would compute
            # (x / 0) * 0 = NaN here instead of zeros.
            return torch.zeros_like(hidden_state)
        # hidden_state: [s, b, h]. Draw one keep/drop decision per sample
        # (dim 1) and broadcast it over the sequence and hidden dims, so a
        # whole sample's residual branch is dropped rather than individual
        # token positions.
        shape = (1, hidden_state.shape[1]) + (1,) * (hidden_state.ndim - 2)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        output = hidden_state.div(keep_prob) * random_tensor
        return output

55
56
57
58
59
60
61
62
63
64
65
def _args_to_kwargs():
    """Build the keyword arguments shared by the tensor-parallel linear layers.

    Reads the global Megatron args and returns a new dict with the keys
    ``params_dtype``, ``use_cpu_initialization``, ``perform_initialization``,
    ``gradient_accumulation_fusion`` and ``sequence_parallel_enabled``. The
    last key is taken from ``args.sequence_parallel``.
    """
    megatron_args = get_args()
    return dict(
        params_dtype=megatron_args.params_dtype,
        use_cpu_initialization=megatron_args.use_cpu_initialization,
        perform_initialization=megatron_args.perform_initialization,
        gradient_accumulation_fusion=megatron_args.gradient_accumulation_fusion,
        sequence_parallel_enabled=megatron_args.sequence_parallel,
    )
66

67
68
69
70
71
class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward (MLP) block.

    Projects the h-dimensional input up to ``args.ffn_hidden_size`` with a
    column-parallel linear layer and applies a GeLU-family non-linearity. It
    then projects back down to h with a row-parallel linear layer. Both linear
    layers skip their own bias add. The first bias is folded into the
    activation; the second is returned to the caller with the output.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Up-projection h -> ffn_hidden_size; each rank keeps its column shard.
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
            **_args_to_kwargs())

        # Activation selection: openai_gelu wins over erf_gelu (onnx_safe),
        # which wins over the default F.gelu. When bias_gelu_fusion is on,
        # forward() uses bias_gelu_impl instead of activation_func.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        else:
            self.activation_func = F.gelu

        # Down-projection ffn_hidden_size -> h; the input is already sharded.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

    def forward(self, hidden_states):
        # [s, b, h] -> [s, b, ffn_hidden_size / p]; the bias comes back separately.
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            activated = bias_gelu_impl(projected, projected_bias)
        else:
            activated = self.activation_func(projected + projected_bias)

        # [s, b, ffn_hidden_size / p] -> [s, b, h]
        result, result_bias = self.dense_4h_to_h(activated)
        return result, result_bias
121

rprenger's avatar
rprenger committed
122
123
124
125
class SwitchMLP(MegatronModule):
    """Top-1 mixture-of-experts MLP.

    A linear router scores each token against ``args.num_experts`` experts
    (each a ``ParallelMLP``). Every token is processed by its highest-scoring
    expert. The expert's output and its bias are both scaled by the router's
    softmax probability for that expert.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList(
            [ParallelMLP(init_method, output_layer_init_method)
             for _ in range(args.num_experts)])

    def forward(self, hidden_states):
        # hidden_states: [s, b, h]
        seq_len = hidden_states.size(0)
        batch_size = hidden_states.size(1)
        hidden_size = hidden_states.size(2)

        # Router probabilities over experts, then the top-1 pick per token.
        probs = torch.nn.functional.softmax(self.router(hidden_states), dim=2)
        top_prob, top_expert = torch.max(probs, dim=2)   # each [s, b]

        # Flatten tokens so each can be routed independently.
        flat_tokens = hidden_states.view(-1, hidden_states.size(2))   # [s*b, h]
        top_prob = torch.unsqueeze(top_prob, 2).view(-1, 1)           # [s*b, 1]
        top_expert = top_expert.view(-1)                              # [s*b]

        combined_output = torch.empty_like(flat_tokens)
        combined_bias = torch.empty_like(flat_tokens)

        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_idx, expert in enumerate(self.experts):
            token_idx = (top_expert == expert_idx).nonzero()
            expert_out, expert_bias = expert(flat_tokens[token_idx, :])
            combined_output[token_idx, :] = expert_out
            combined_bias[token_idx, :] = expert_bias.expand_as(expert_out)

        # Scale by the router probability of the chosen expert.
        combined_output = combined_output * top_prob
        combined_bias = combined_bias * top_prob

        return (combined_output.view(seq_len, batch_size, hidden_size),
                combined_bias.view(seq_len, batch_size, hidden_size))
169

170
171

class CoreAttention(MegatronModule):
    """Scaled dot-product attention over already-projected, per-partition heads.

    ``forward`` takes query ``[sq, b, np, hn]`` and key/value ``[sk, b, np, hn]``
    tensors plus an attention mask. It returns the context layer ``[sq, b, hp]``.
    """

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        # Query-key layer scaling forces the softmax to run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        else:
            self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = args.sequence_parallel

        # Per-attention-head and per-partition sizes.
        projection_size = args.kv_channels * args.num_attention_heads
        tp_world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                           tp_world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, tp_world_size)

        # Scores are divided by sqrt(hn). With query-key layer scaling they are
        # additionally divided by layer_number, and the same coefficient is
        # handed to FusedScaleMaskSoftmax (presumably to undo it inside the
        # fp32 softmax -- confirm in fused_softmax).
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) * coeff
        else:
            coeff = None
            self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. For a single iteration this layer produces different
        # outputs for different numbers of parallel partitions, but on
        # average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        batch = query_layer.size(1)
        heads = query_layer.size(2)
        q_len = query_layer.size(0)
        k_len = key_layer.size(0)
        scores_shape = (batch, heads, q_len, k_len)   # [b, np, sq, sk]

        # Fold batch and heads together: [s, b, np, hn] -> [s, b * np, hn].
        query_2d = query_layer.view(q_len, batch * heads, -1)
        key_2d = key_layer.view(k_len, batch * heads, -1)

        # Output buffer for baddbmm, [b * np, sq, sk], from the shared global
        # memory buffer. beta=0.0 below, so its prior contents are ignored.
        buffer = mpu.get_global_memory_buffer().get_tensor(
            (batch * heads, q_len, k_len),
            query_2d.dtype, "mpu")

        # Raw scores (Q @ K^T) / norm_factor: [b * np, sq, sk].
        raw_scores = torch.baddbmm(
            buffer,
            query_2d.transpose(0, 1),                 # [b * np, sq, hn]
            key_2d.transpose(0, 1).transpose(1, 2),   # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))
        attention_scores = raw_scores.view(*scores_shape)

        # Masked, scaled softmax: [b, np, sq, sk].
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # Dropping attention probabilities drops whole tokens to attend to,
        # as in the original Transformer paper. Outside sequence-parallel mode
        # the dropout runs inside the CUDA RNG tracker fork.
        if self.sequence_parallel:
            attention_probs = self.attention_dropout(attention_probs)
        else:
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)

        # Context: value [sk, b, np, hn] -> context [b, np, sq, hn].
        context_shape = (value_layer.size(1),
                         value_layer.size(2),
                         q_len,
                         value_layer.size(3))
        batch_heads = context_shape[0] * context_shape[1]
        value_2d = value_layer.view(value_layer.size(0), batch_heads, -1)     # [sk, b * np, hn]
        probs_3d = attention_probs.view(batch_heads, context_shape[2], -1)    # [b * np, sq, sk]
        context = torch.bmm(probs_3d, value_2d.transpose(0, 1))               # [b * np, sq, hn]
        context = context.view(*context_shape)                                # [b, np, sq, hn]

        # [b, np, sq, hn] -> [sq, b, np, hn] -> [sq, b, hp]
        context = context.permute(2, 0, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.hidden_size_per_partition,)
        return context.view(*merged_shape)


308
class ParallelAttention(MegatronModule):
    """Tensor-parallel self- or cross-attention layer.

    Takes ``hidden_states`` of shape [s, b, h] and returns an output of the
    same shape together with the output projection's (un-added) bias. For
    cross attention, keys and values come from ``encoder_output``.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per-attention-head and per-partition sizes.
        tp_world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, tp_world_size)

        def _column_parallel(out_features):
            # Column-parallel projection from h; output stays sharded per rank.
            return tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                out_features,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())

        if attention_type == AttnType.self_attn:
            # One fused projection produces Q, K and V.
            self.query_key_value = _column_parallel(3 * projection_size)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = _column_parallel(projection_size)
            self.key_value = _column_parallel(2 * projection_size)

        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        # Output projection back to h; bias is returned, not added.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask):
        """Run core attention under activation checkpointing."""
        def custom_forward(*inputs):
            q, k, v, mask = inputs[0], inputs[1], inputs[2], inputs[3]
            return self.core_attention(q, k, v, mask)

        # The positional False is presumably tensor_parallel.checkpoint's
        # distribute_saved_activations flag -- confirm in tensor_parallel.
        return tensor_parallel.checkpoint(
            custom_forward, False,
            query_layer, key_layer, value_layer, attention_mask)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized [max_seq, batch, np, hn] cache tensor."""
        cache_shape = (inference_max_sequence_len,
                       batch_size,
                       self.num_attention_heads_per_partition,
                       self.hidden_size_per_attention_head)
        return torch.empty(*cache_shape,
                           dtype=self.params_dtype,
                           device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        # hidden_states: [sq, b, h]

        # Fetch this layer's key/value cache for inference, allocating it on
        # first use.
        if inference_params:
            kv_cache = inference_params.key_value_memory_dict
            if self.layer_number in kv_cache:
                inference_key_memory, inference_value_memory = \
                    kv_cache[self.layer_number]
            else:
                max_len = inference_params.max_sequence_len
                max_batch = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(max_len, max_batch)
                inference_value_memory = self._allocate_memory(max_len, max_batch)
                kv_cache[self.layer_number] = (
                    inference_key_memory, inference_value_memory)

        heads = self.num_attention_heads_per_partition
        head_dim = self.hidden_size_per_attention_head

        if self.attention_type == AttnType.self_attn:
            # [sq, b, h] -> [sq, b, np * 3 * hn] -> [sq, b, np, 3 * hn]
            mixed_x_layer, _ = self.query_key_value(hidden_states)
            mixed_x_layer = mixed_x_layer.view(
                *(mixed_x_layer.size()[:-1] + (heads, 3 * head_dim)))
            # -> three [sq, b, np, hn] tensors
            query_layer, key_layer, value_layer = \
                tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Keys/values from the encoder: [sk, b, h] -> [sk, b, np, 2 * hn]
            mixed_kv_layer, _ = self.key_value(encoder_output)
            mixed_kv_layer = mixed_kv_layer.view(
                *(mixed_kv_layer.size()[:-1] + (heads, 2 * head_dim)))
            # -> two [sk, b, np, hn] tensors
            key_layer, value_layer = \
                tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Queries from hidden_states: [sq, b, h] -> [sq, b, np, hn]
            query_layer, _ = self.query(hidden_states)
            query_layer = query_layer.view(
                *(query_layer.size()[:-1] + (heads, head_dim)))

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)

            batch_span = slice(batch_start, batch_end)
            new_span = slice(sequence_start, sequence_end)
            # Write this step's keys/values into the cache...
            inference_key_memory[new_span, batch_span, ...] = key_layer
            inference_value_memory[new_span, batch_span, ...] = value_layer
            # ...then attend over every cached position up to sequence_end.
            key_layer = inference_key_memory[:sequence_end, batch_span, ...]
            value_layer = inference_value_memory[:sequence_end, batch_span, ...]

        if self.checkpoint_core_attention:
            attend = self._checkpointed_attention_forward
        else:
            attend = self.core_attention
        context_layer = attend(query_layer, key_layer, value_layer, attention_mask)

        # [sq, b, hp] -> [sq, b, h]
        output, bias = self.dense(context_layer)
        return output, bias


505
def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    Dropout uses probability ``prob`` and is only active when ``training``
    is true. The annotations keep the function compilable by
    ``torch.jit.script``.
    """
    dropped = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Return ``bias_dropout_add`` with the ``training`` flag pre-bound.

    The returned callable takes ``(x, bias, residual, prob)``.
    """
    def _bound_bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return _bound_bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout enabled."""
    is_training = True
    return bias_dropout_add(x, bias, residual, prob, is_training)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout disabled."""
    is_training = False
    return bias_dropout_add(x, bias, residual, prob, is_training)
532
533
534
535
536


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size. Decoder layers additionally run cross attention
    over ``encoder_output`` between self attention and the MLP.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        # Stochastic depth on the residual branches; None disables it.
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        if self.layer_type == LayerType.decoder:
            # Cross attention over the encoder output (decoder layers only).
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

        # MLP: top-1 mixture of experts when num_experts is set, else dense.
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

        # Set bias+dropout+add fusion grad_enable execution handler.
        # From torch 1.10 on (nvfuser) no special context is used; older
        # versions run the op under torch.enable_grad.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

    def _select_bias_dropout_add(self):
        """Return the bias + dropout + residual-add function for the current mode."""
        # jit scripting for a nn.module (with dropout) is not triggering the
        # fusion kernel. For now, we use two different scripted functions to
        # account for varying dropout semantics during training and inference.
        if self.bias_dropout_fusion:
            if self.training:
                return bias_dropout_add_fused_train
            return bias_dropout_add_fused_inference
        return get_bias_dropout_add(self.training)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Chosen up front: the decoder cross-attention residual below needs
        # it even when drop_path is enabled.
        bias_dropout_add_func = self._select_bias_dropout_add()

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            output = core.utils.make_viewless_tensor(inp = output,
                                                     requires_grad = output.requires_grad,
                                                     keep_graph = True)

        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


713
714
715
class NoopTransformerLayer(MegatronModule):
    """Placeholder transformer layer that performs no computation.

    This layer is used only with a standalone embedding stage
    (``args.standalone_embedding_stage == True``). In that configuration,
    pipeline rank 0 is assigned zero transformer layers. On virtual pipeline
    rank >= 1 it also holds no model parameters at all, because virtual
    rank 0 contains the input embedding.

    With no layers, the model's input and output would be the very same
    tensor. That breaks memory optimizations applied to the output tensor
    (for example, deallocating it). Returning a clone gives the output its
    own identity.

    Ranks that carry a no-op layer are generally under-utilized in both
    compute and memory, so the extra copy is not a performance concern.
    """

    def __init__(self, layer_number):
        super(NoopTransformerLayer, self).__init__()
        # Recorded for reference only; forward() does not read it.
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Return a copy of ``hidden_states``; all other arguments are ignored.

        The signature mirrors a regular transformer layer so this module can
        be dropped into the same layer list.
        """
        # A distinct tensor object, so input and output no longer alias.
        output = hidden_states.clone()
        return output


def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank.

    Args:
        args: parsed Megatron arguments. The fields read here are
            ``pipeline_model_parallel_split_rank``,
            ``standalone_embedding_stage``,
            ``transformer_pipeline_model_parallel_size``, ``num_layers``,
            ``encoder_num_layers`` and ``decoder_num_layers``.
        is_encoder_and_decoder_model: True for models with separate encoder
            and decoder stacks. Under pipeline parallelism these stacks are
            split across ranks at ``pipeline_model_parallel_split_rank``.
        is_decoder: only consulted when pipeline parallelism is disabled.
            It selects ``decoder_num_layers`` instead of
            ``encoder_num_layers``.

    Returns:
        The number of transformer layers to build on this rank. This is 0 on
        pipeline rank 0 when a standalone embedding stage is used.

    Raises:
        AssertionError: if the split rank is missing for an encoder-decoder
            model, or if a layer count does not divide evenly among the
            pipeline ranks assigned to it.
    """
    # The pipeline-state queries live on ``mpu`` (imported at file top); the
    # bare names are not in scope in this module.
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None

            # When a standalone embedding stage is used, a rank is taken from
            # the encoder's ranks, to be used for the encoder's embedding
            # layer. This way, the rank referenced by the 'split rank' remains
            # the same whether or not a standalone embedding stage is used.
            num_ranks_in_encoder = (
                args.pipeline_model_parallel_split_rank - 1
                if args.standalone_embedding_stage else
                args.pipeline_model_parallel_split_rank
            )
            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
            assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
                'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
            assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
                'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
            if mpu.is_pipeline_stage_before_split():
                # Encoder side: pipeline rank 0 holds only the embedding when
                # a standalone embedding stage is used.
                num_layers = (
                    0
                    if args.standalone_embedding_stage
                    and mpu.get_pipeline_model_parallel_rank() == 0 else
                    args.encoder_num_layers // num_ranks_in_encoder
                )
            else:
                num_layers = args.decoder_num_layers // num_ranks_in_decoder
        else:
            assert args.num_layers == args.encoder_num_layers
            assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                'num_layers must be divisible by transformer_pipeline_model_parallel_size'

            # When a standalone embedding stage is used, all transformer layers
            # are divided among pipeline rank >= 1, while on pipeline rank 0,
            # ranks either contain the input embedding layer (virtual pp rank 0),
            # or no layers at all (virtual pp rank >= 1).
            num_layers = (
                0
                if args.standalone_embedding_stage
                and mpu.get_pipeline_model_parallel_rank() == 0 else
                args.num_layers // args.transformer_pipeline_model_parallel_size
            )
    else:
        # No pipeline parallelism: this rank holds the whole stack.
        if not is_decoder:
            num_layers = args.encoder_num_layers
        else:
            num_layers = args.decoder_num_layers
    return num_layers


class ParallelTransformer(MegatronModule):
    """Transformer class.

    Builds the stack of ``ParallelTransformerLayer`` modules resident on the
    current pipeline (and virtual pipeline) rank. When both ``post_process``
    and ``post_layer_norm`` are set, the stack is followed by a final
    LayerNorm.

    Args:
        init_method: initializer forwarded to every layer.
        output_layer_init_method: output-projection initializer forwarded
            to every layer.
        layer_type: ``LayerType`` of this stack. It influences how many
            layers this rank holds and their global numbering.
        self_attn_mask_type: forwarded to every layer.
        post_layer_norm: whether to build/apply the final LayerNorm (only
            takes effect when ``post_process`` is True).
        pre_process: if False, ``forward`` ignores its ``hidden_states``
            argument and uses the tensor passed to ``set_input_tensor``.
        post_process: together with ``post_layer_norm``, enables the final
            LayerNorm.
        drop_path_rate: maximum stochastic-depth rate. Per-layer rates grow
            linearly from 0 across ``args.num_layers`` layers.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_layer_norm=True,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flags.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        # Distributing saved activations is disabled under sequence
        # parallelism.
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Number of layers.
        self.num_layers = _get_num_layers(
            args,
            args.model_type == ModelType.encoder_and_decoder,
            layer_type == LayerType.decoder)

        # One stochastic-depth rate per layer of the full model, rising
        # linearly from 0 to drop_path_rate; indexed by global layer number.
        self.drop_path_rates = [
            rate.item() for rate in
            torch.linspace(0, self.drop_path_rate, args.num_layers)]

        # Transformer layers.
        def build_layer(layer_number):
            # layer_number is 1-based and global (it includes `offset`).
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    # Decoder numbering restarts after the encoder's ranks.
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors to be
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

    def _get_layer(self, layer_number):
        # layer_number is a 0-based index into the layers held on this rank.
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        The strategy is chosen by ``self.recompute_method``:

        * ``'uniform'``: the local layers are split into consecutive chunks of
          ``recompute_num_layers`` and each chunk is checkpointed. The final
          chunk may be shorter.
        * ``'block'``: only the first ``recompute_num_layers`` layers are
          checkpointed individually; the rest run normally.

        Raises:
            ValueError: for any other recompute method.
        """
        def custom(start, end):
            # Run local layers [start, end) as one checkpointable function.
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            chunk_start = 0
            while chunk_start < self.num_layers:
                # Clamp so that a trailing chunk shorter than
                # recompute_num_layers does not index past the last layer.
                chunk_end = min(chunk_start + self.recompute_num_layers,
                                self.num_layers)
                hidden_states = tensor_parallel.checkpoint(
                    custom(chunk_start, chunk_end),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                chunk_start = chunk_end

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for layer_idx in range(self.num_layers):
                if layer_idx < self.recompute_num_layers:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(layer_idx, layer_idx + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(layer_idx, layer_idx + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run the local layer stack and, if enabled, the final LayerNorm.

        When ``pre_process`` is False, ``hidden_states`` is replaced by the
        tensor stored via ``set_input_tensor``. Activation recomputation is
        used when ``recompute_granularity == 'full'``. In that case
        ``inference_params`` must not be given (asserted).
        """
        # hidden_states: [s, b, h]

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - hidden_states may be a 'view' tensor, and the MPU checkpoint
        #   function context rejects view tensors. Historically this arose
        #   with micro batch size == 1, where a transpose followed by
        #   '.contiguous()' still yields a view. Rather than detecting that
        #   case, make_viewless_tensor() is called unconditionally because
        #   its overhead is negligible when the input is already viewless.
        # - When pre_process is False the tensor comes from
        #   set_input_tensor(). p2p_communication.py (the likely originator)
        #   already creates viewless tensors, so the call is likely redundant
        #   there but is kept to be future-proof and corner-case-proof.
        hidden_states = core.utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # Under sequence parallelism, run the layers inside a fork of the
        # tensor-parallel CUDA RNG tracker.
        if self.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        with rng_context:
            # Forward pass.
            if self.recompute_granularity == 'full':
                hidden_states = self._checkpointed_forward(hidden_states,
                                                           attention_mask,
                                                           encoder_output,
                                                           enc_dec_attn_mask)
            else:
                for index in range(self.num_layers):
                    layer = self._get_layer(index)
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states