transformer.py 39.4 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""Transformer."""
import math
5
from contextlib import nullcontext
6
import torch
7
import torch.nn.functional as F
8

9
10
11
from megatron import get_timers, get_args
from megatron.core import get_global_memory_buffer
from megatron import core
12
from .module import MegatronModule
13
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
14
from megatron.model import LayerNorm
15
16
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
17
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
18

19

20
21
22
23
24
25
26
27
28
29
""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

35
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    During training, each sample's residual-branch output is zeroed with
    probability ``drop_prob`` and the surviving samples are rescaled by
    ``1 / (1 - drop_prob)``. In eval mode, or when ``drop_prob == 0``, the
    input is returned unchanged.

    Args:
        drop_prob: probability of dropping a sample's path.
        batch_dim: axis that indexes samples. The transformer in this file
            uses the [s, b, h] layout, so the default is 1. Pass 0 for
            batch-first inputs.
    """

    def __init__(self, drop_prob=0., batch_dim=1):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.batch_dim = batch_dim

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # The mask is 1 on every axis except the sample axis. The same
        # keep/drop decision is therefore broadcast over the whole sample.
        # Using dim 0 here (as before) would drop individual sequence
        # positions of an [s, b, h] tensor rather than whole samples.
        shape = [1] * hidden_state.ndim
        shape[self.batch_dim] = hidden_state.shape[self.batch_dim]
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize: 1 with prob keep_prob, else 0
        output = hidden_state.div(keep_prob) * random_tensor
        return output

56
57
58
59
60
61
62
63
64
65
66
def _args_to_kwargs():
    """Build the keyword arguments shared by every tensor-parallel linear layer.

    All values are read from the global Megatron arguments (``get_args()``).
    Note that ``args.sequence_parallel`` is forwarded under the key
    ``sequence_parallel_enabled``.
    """
    args = get_args()
    return dict(
        params_dtype=args.params_dtype,
        use_cpu_initialization=args.use_cpu_initialization,
        perform_initialization=args.perform_initialization,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        sequence_parallel_enabled=args.sequence_parallel,
    )
67

68
69
70
71
72
class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward (MLP) block.

    The input is projected from the hidden size h to ``args.ffn_hidden_size``
    with a column-parallel linear layer, passed through a GeLU-style
    activation, and projected back to h with a row-parallel linear layer.
    The output projection is built with ``skip_bias_add=True``, and
    ``forward`` returns its ``(output, output_bias)`` pair unchanged.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Expand h -> ffn_hidden_size. The result stays partitioned across
        # tensor-parallel ranks (gather_output=False).
        self.dense_h_to_4h = core.tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
            **_args_to_kwargs())

        # Activation choice: openai_gelu wins over onnx_safe, with F.gelu as
        # the fallback. When bias_gelu_fusion is on, forward() uses the fused
        # bias+GeLU kernel and activation_func is not consulted.
        self.bias_gelu_fusion = args.bias_gelu_fusion
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        else:
            self.activation_func = F.gelu

        # Contract ffn_hidden_size -> h; the input is already partitioned.
        self.dense_4h_to_h = core.tensor_parallel.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

    def forward(self, hidden_states):
        # [s, b, 4hp] plus the (unapplied) projection bias.
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if not self.bias_gelu_fusion:
            projected = self.activation_func(projected + projected_bias)
        else:
            projected = bias_gelu_impl(projected, projected_bias)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(projected)
        return output, output_bias
122

rprenger's avatar
rprenger committed
123
124
125
126
class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"

    A linear router scores every token. Each token is sent to the single
    expert with the highest softmax probability (top-1 routing). That
    expert's output and bias are scaled by the winning probability.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList(
            ParallelMLP(init_method, output_layer_init_method)
            for _ in range(args.num_experts))

    def forward(self, hidden_states):
        # hidden_states: [s, b, h]
        seq_len, batch, hidden = hidden_states.size()

        gate_probs = torch.nn.functional.softmax(self.router(hidden_states), dim=2)
        # Winning probability and expert index per token, both [s, b, 1].
        gate, expert_ids = torch.max(gate_probs, dim=2, keepdim=True)

        # Flatten tokens so each one can be routed independently.
        tokens = hidden_states.view(-1, hidden)  # [s*b, h]
        gate = gate.view(-1, 1)                  # [s*b, 1]
        expert_ids = expert_ids.view(-1)         # [s*b]

        combined = torch.empty_like(tokens)
        combined_bias = torch.empty_like(tokens)
        # TODO (rprenger): experts run serially; this could be parallelized.
        for idx, expert in enumerate(self.experts):
            rows = (expert_ids == idx).nonzero()  # [k, 1] token indices
            expert_out, expert_bias = expert(tokens[rows, :])
            expert_bias = expert_bias.expand_as(expert_out)
            combined[rows, :] = expert_out
            combined_bias[rows, :] = expert_bias

        combined = combined * gate
        combined_bias = combined_bias * gate
        return (combined.view(seq_len, batch, hidden),
                combined_bias.view(seq_len, batch, hidden))
170

171
172

class CoreAttention(MegatronModule):
    """Scaled dot-product attention over per-partition attention heads.

    This module holds no projection weights: it receives query/key/value
    tensors that have already been split into heads and returns the context
    layer. The shapes (see the view() calls in ``forward``) are:
        query_layer: [sq, b, np, hn]
        key_layer / value_layer: [sk, b, np, hn]
        return value: [sq, b, hp]
    """

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Query-key layer scaling forces the softmax to run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        # Clamped to >= 1 because it is used as a scaling coefficient below.
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = args.sequence_parallel

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = core.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                           world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)

        # Scores are divided by norm_factor in forward(). With layer scaling
        # the extra division by layer_number is paired with passing coeff to
        # the softmax. Presumably FusedScaleMaskSoftmax multiplies coeff back
        # in before the softmax (not visible here) — TODO confirm.
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Compute softmax(Q K^T / norm_factor) V for the local heads.

        Args:
            query_layer: [sq, b, np, hn] query heads.
            key_layer: [sk, b, np, hn] key heads.
            value_layer: [sk, b, np, hn] value heads.
            attention_mask: forwarded unchanged to ``self.scale_mask_softmax``.

        Returns:
            Context layer of shape [sq, b, hp].
        """

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating input tensor: [b * np, sq, sk]
        # The buffer comes from the global memory buffer under the "mpu" key.
        matmul_input_buffer = get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        # With beta=0.0, torch.baddbmm ignores the values stored in
        # matmul_input_buffer, so its stale contents cannot leak into the
        # scores. alpha applies the 1/norm_factor scaling inside the matmul.
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.

        # Without sequence parallelism, dropout runs inside the forked
        # tensor-parallel CUDA RNG state. Presumably this lets each rank draw
        # independent masks for its own heads — TODO confirm against the RNG
        # tracker implementation.
        if not self.sequence_parallel:
            with core.tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        # (query_layer is [sq, b * np, hn] at this point, so size(0) is sq.)
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


309
class ParallelAttention(MegatronModule):
    """Tensor-parallel self- or cross-attention layer.

    Takes input with size [s, b, h] and returns output of the same size,
    together with the output projection's bias (returned separately because
    the output projection uses ``skip_bias_add=True``).

    With ``attention_type == AttnType.self_attn``, one fused projection
    produces Q, K and V from ``hidden_states``. With ``AttnType.cross_attn``,
    Q is projected from ``hidden_states`` and K/V from ``encoder_output``.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        # Clamped to >= 1; also used as the key into the inference KV cache.
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        # dtype of the inference KV-cache buffers (see _allocate_memory).
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = core.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            # Fused Q/K/V projection: h -> 3 * projection_size.
            self.query_key_value = core.tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())
        else:
            assert attention_type == AttnType.cross_attn
            # Query projection (applied to hidden_states).
            self.query = core.tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())

            # Fused K/V projection (applied to encoder_output).
            self.key_value = core.tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())

        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        # With 'selective' recomputation, only the core attention is
        # activation-checkpointed (see _checkpointed_attention_forward).
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        # Output.
        self.dense = core.tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask):
        """Forward method with activation checkpointing.

        Runs ``self.core_attention`` through
        ``core.tensor_parallel.checkpoint``. The second argument, ``False``,
        is a positional flag of that helper — its meaning is not visible here
        (TODO confirm).
        """
        def custom_forward(*inputs):
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        hidden_states = core.tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask)

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized KV-cache buffer.

        The buffer has shape [max_seq_len, batch, np, hn], uses
        ``self.params_dtype``, and lives on the current CUDA device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        """Run attention and the output projection.

        Args:
            hidden_states: [sq, b, h] input.
            attention_mask: forwarded to the core attention softmax.
            encoder_output: [sk, b, h] source of K/V for cross-attention;
                unused for self-attention.
            inference_params: when truthy, enables the per-layer KV cache
                stored in ``inference_params.key_value_memory_dict``.

        Returns:
            ``(output, bias)`` from ``self.dense``; output is [sq, b, h].
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # The cache for this layer is created on first use at
        # [max_sequence_len, max_batch_size, np, hn] and then reused.
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            # (the projection's second return value is discarded)
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params:
            # Write this step's K/V into the cache at the current batch and
            # sequence offsets. Then attend over everything cached so far
            # (positions [0, sequence_end)).
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # ==================================
        # core attention computation
        # ==================================

        if self.checkpoint_core_attention:
            context_layer = self._checkpointed_attention_forward(
                query_layer, key_layer, value_layer, attention_mask)
        else:
            context_layer = self.core_attention(
                query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


506
def bias_dropout_add(x: torch.Tensor,
                     bias: torch.Tensor,
                     residual: torch.Tensor,
                     prob: float,
                     training: bool) -> torch.Tensor:
    """Return ``residual + dropout(x + bias)``.

    ``prob`` is the dropout probability. Dropout is only active when
    ``training`` is True. The annotations give TorchScript the argument types
    when this function is compiled from a ``torch.jit.script`` caller.
    """
    dropped = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Return ``f(x, bias, residual, prob)`` that calls ``bias_dropout_add``
    with ``training`` bound to the value given here."""
    return lambda x, bias, residual, prob: bias_dropout_add(
        x, bias, residual, prob, training)


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``residual + dropout(x + bias)`` with dropout
    always active (training mode)."""
    return residual + torch.nn.functional.dropout(x + bias, p=prob, training=True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``residual + dropout(x + bias)`` with dropout
    disabled (inference mode), so ``prob`` has no effect on the result."""
    return residual + torch.nn.functional.dropout(x + bias, p=prob, training=False)
533
534
535
536
537


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    Structure: input layernorm -> self-attention -> bias+dropout+residual
    -> layernorm. Decoder layers then add cross-attention ->
    bias+dropout+residual -> layernorm. Every layer ends with
    MLP -> bias+dropout+residual. When ``drop_path_rate > 0``, the
    self-attention and MLP residual branches also go through stochastic depth.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

        # Set bias+dropout+add fusion grad_enable execution handler.
        # For torch >= 1.10 (nvfuser available) a no-op context is used.
        # Older versions wrap the call in torch.enable_grad.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

    def _get_bias_dropout_add_func(self):
        """Pick the bias+dropout+add implementation for the current mode."""
        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                return bias_dropout_add_fused_train
            return bias_dropout_add_fused_inference
        return get_bias_dropout_add(self.training)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Select the function unconditionally: the decoder cross-attention
        # residual below always uses it. Previously it was bound only when
        # drop_path was None, so a decoder layer with drop_path_rate > 0
        # raised UnboundLocalError.
        bias_dropout_add_func = self._get_bias_dropout_add_func()

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            output = core.utils.make_viewless_tensor(inp=output,
                                                     requires_grad=output.requires_grad,
                                                     keep_graph=True)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


714
715
716
class NoopTransformerLayer(MegatronModule):
    """Transformer layer that performs no computation.

    This layer exists for configurations with a standalone embedding stage
    (``args.standalone_embedding_stage == True``). In that setup, pipeline
    rank 0 is assigned zero transformer layers. For virtual pipeline ranks
    >= 1, no model parameters are created at all, because virtual rank 0
    holds the input embedding.

    Without any layer, such a model chunk would hand back its input tensor
    as its output. Certain memory optimizations on the output tensor (for
    example deallocating it) then fail. This layer breaks that aliasing by
    returning a clone of its input.

    Ranks that host a no-op layer are generally under-utilized in both
    compute and memory, so the extra copy is not a performance concern.

    Args:
        layer_number: index of this layer, kept as ``self.layer_number``.
    """

    def __init__(self, layer_number):
        super(NoopTransformerLayer, self).__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Return a fresh copy of ``hidden_states``.

        The remaining arguments mirror the regular transformer layer's
        signature so this layer can be called the same way; they are ignored.
        """
        # A clone (not the input itself) so output and input are distinct tensors.
        copied_states = hidden_states.clone()
        return copied_states


740
741
742
class ParallelTransformer(MegatronModule):
    """Transformer class.

    Holds the transformer layers assigned to this pipeline stage (or virtual
    pipeline chunk). On the last stage it can also hold a final layer norm.

    Args:
        init_method: initializer forwarded to every ParallelTransformerLayer.
        output_layer_init_method: output-layer initializer forwarded to every
            ParallelTransformerLayer.
        layer_type: LayerType of the layers built here (encoder by default).
        self_attn_mask_type: AttnMaskType forwarded to every layer.
        post_layer_norm: whether to build and apply the final layer norm. It
            only takes effect when ``post_process`` is also True.
        pre_process: when False, ``forward`` ignores its ``hidden_states``
            argument and uses the tensor given to ``set_input_tensor()``.
        post_process: whether this module may own the final layer norm.
        drop_path_rate: largest stochastic-depth rate. Per-layer rates are
            spaced linearly from 0 to this value over ``args.num_layers``.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_layer_norm=True,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flags.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        # Distributed saved activations are turned off under sequence
        # parallelism.
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Number of layers on this pipeline stage.
        self.num_layers = core.get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        # One stochastic-depth rate per global layer, rising linearly from 0.
        self.drop_path_rates = [
            rate.item()
            for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

        # Transformer layers.
        def build_layer(layer_number):
            # layer_number is 1-based and global (offset included), hence
            # the -1 when indexing the per-layer drop-path rates.
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # The per-chunk layer count below uses floor division, so the
            # per-stage count must also divide evenly. Otherwise trailing
            # layers of each stage would be silently dropped.
            assert self.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = core.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (core.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    core.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = core.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    # Decoder numbering restarts after the encoder ranks.
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = core.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # With a standalone embedding stage (for example,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 get zero transformer layers. The model's
            # input and output tensors would then be the same tensor, which
            # breaks certain output-tensor optimizations (for example,
            # pipeline output deallocation). To avoid this, these ranks get a
            # single 'no-op' layer that disconnects the input tensor from the
            # output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

    def _get_layer(self, layer_number):
        """Return the local layer at 0-based position ``layer_number``."""
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing.

        Runs every local layer and recomputes activations according to
        ``self.recompute_method``:

        * ``'uniform'``: splits the layers into consecutive chunks of
          ``self.recompute_num_layers`` layers (the last chunk may be
          shorter) and checkpoints the input of each chunk.
        * ``'block'``: checkpoints only the first ``self.recompute_num_layers``
          layers, one at a time, and runs the rest without checkpointing.

        Raises:
            ValueError: if ``recompute_method`` is unknown, or if
                ``'uniform'`` is used with a chunk size that is not positive.
        """
        def custom(start, end):
            # Closure over layers [start, end), in the form the checkpoint
            # function expects.
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            chunk_size = self.recompute_num_layers
            if chunk_size is None or chunk_size < 1:
                # A zero chunk size would never advance the loop below.
                raise ValueError(
                    "recompute_num_layers must be a positive integer for the "
                    "'uniform' activation recompute method.")
            start = 0
            while start < self.num_layers:
                # Clamp so a final partial chunk does not index past the last layer.
                end = min(start + chunk_size, self.num_layers)
                hidden_states = core.tensor_parallel.checkpoint(
                    custom(start, end),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                start = end

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for index in range(self.num_layers):
                if index < self.recompute_num_layers:
                    hidden_states = core.tensor_parallel.checkpoint(
                        custom(index, index + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(index, index + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Run all local transformer layers, plus the final layer norm if owned.

        Args:
            hidden_states: input hidden states, [s, b, h]. Ignored when
                ``pre_process`` is False; the tensor given to
                ``set_input_tensor()`` is used instead.
            attention_mask: forwarded to every layer.
            encoder_output: forwarded to every layer.
            enc_dec_attn_mask: forwarded to every layer.
            inference_params: forwarded to every layer on the
                non-checkpointed path. Must not be combined with activation
                recomputation.

        Returns:
            Output hidden states. The final layer norm is applied when
            ``post_process`` and ``post_layer_norm`` are both True.
        """
        # hidden_states: [s, b, h]

        # Checks.
        if inference_params is not None:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - The incoming hidden_states may be a view of another tensor, for
        #   example when an upstream '.contiguous()' is a pass-through (as
        #   with micro batch size 1). A view tensor can be rejected if it is
        #   later saved in the checkpoint function context.
        # - make_viewless_tensor() has negligible overhead when its input is
        #   already viewless (tensors received via p2p_communication.py
        #   should already be viewless). It is therefore applied
        #   unconditionally, to be future-proof and corner-case-proof.
        hidden_states = core.utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        if self.sequence_parallel:
            rng_context = core.tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        with rng_context:
            # Forward pass.
            if self.recompute_granularity == 'full':
                hidden_states = self._checkpointed_forward(hidden_states,
                                                           attention_mask,
                                                           encoder_output,
                                                           enc_dec_attn_mask)
            else:
                for index in range(self.num_layers):
                    layer = self._get_layer(index)
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states