transformer.py 53.4 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""Transformer."""
import math
5
from contextlib import nullcontext
6
import torch
7
import torch.nn.functional as F
8

9
from megatron import get_timers, get_args, core, get_num_microbatches
10
from .module import MegatronModule
11
from megatron.core import mpu, tensor_parallel
12
from megatron.core.enums import ModelType
Mostofa Patwary's avatar
Mostofa Patwary committed
13
from megatron.model import LayerNorm
14
from megatron.model.enums import AttnMaskType, LayerType, AttnType
15
16
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
17
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
18

19
20
21
22
23
24
25
26
27
28
try:
    from einops import rearrange
except ImportError:
    rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    flash_attn_unpadded_func = None

29
30
31
32
33
34
35
36
37
38
""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of tensor-model-parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     sq: query sequence length
     sk: key/value sequence length
     l: number of layers

    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

44
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    During training, each sample along dim 1 of the input is either zeroed
    out entirely (with probability ``drop_prob``) or kept and rescaled by
    ``1 / (1 - drop_prob)`` so its expected value is unchanged. In eval mode,
    or when ``drop_prob`` is 0, the input is returned unchanged.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        """Apply per-sample stochastic depth.

        Args:
            hidden_state: tensor laid out as [s, b, ...]. One keep/drop
                decision is drawn per index of dim 1 and broadcast over
                every other dimension.

        Returns:
            Tensor with the same shape as ``hidden_state``.
        """
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # Every path is dropped. Dividing by a zero keep_prob below would
            # produce inf/NaN (inf * 0) instead of the intended zeros.
            return torch.zeros_like(hidden_state)
        # work with diff dim tensors, not just 2D ConvNets
        # hidden_state: [s, b, h] -> mask shape (1, b, 1, ...)
        shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
        output = hidden_state.div(keep_prob) * random_tensor
        return output

66
67
68
69
70
71
72
73
74
75
76
def _args_to_kwargs():
    """Collect the keyword arguments shared by the tensor-parallel linear layers.

    Reads the global Megatron argument namespace and returns a new dict that
    is splatted into every ``ColumnParallelLinear`` / ``RowParallelLinear``
    constructed in this file.
    """
    args = get_args()
    return dict(
        params_dtype=args.params_dtype,
        use_cpu_initialization=args.use_cpu_initialization,
        perform_initialization=args.perform_initialization,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        sequence_parallel_enabled=args.sequence_parallel,
    )
77

78
79
80
81
82
class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to a wider
    hidden dimension (``args.ffn_hidden_size``, the "4h" of the comments
    below), perform nonlinear transformation, and project the state back
    into h hidden dimension.

    With SwiGLU the first projection is twice as wide, because the gate and
    the value halves come out of the same matmul. Both projections are built
    with ``skip_bias_add=True``. ``forward`` adds the first projection's bias
    itself (or fuses it into GELU). It returns the second projection's bias
    to the caller alongside the output.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        self.add_bias = args.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size * 2 if args.swiglu else args.ffn_hidden_size,
            bias=self.add_bias,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
            **_args_to_kwargs())

        self.bias_gelu_fusion = False
        self.activation_func = None
        self.swiglu = args.swiglu

        # Pick the activation. Only the default (plain GELU) branch may use the
        # fused bias+GELU kernel; every other choice adds the bias and applies
        # the activation as separate ops in forward().
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        elif args.swiglu:
            def swiglu(x):
                # Split the doubled projection into two halves: silu(first) * second.
                x = torch.chunk(x, 2, dim=-1)
                return F.silu(x[0]) * x[1]
            self.activation_func = swiglu
        elif args.squared_relu:
            def squared_relu(x):
                return torch.pow(F.relu(x), 2)
            self.activation_func = squared_relu
        else:
            self.bias_gelu_fusion = args.bias_gelu_fusion
            self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            bias=self.add_bias,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

    def forward(self, hidden_states):
        """Apply the MLP.

        Args:
            hidden_states: [s, b, h] input.

        Returns:
            ``(output, output_bias)`` as returned by ``dense_4h_to_h``;
            output is [s, b, h]. The bias is returned separately so the
            caller can fuse it with the following dropout/residual add.
        """
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            # The fused kernel requires a real bias tensor and plain GELU.
            assert self.add_bias is True
            assert self.activation_func == F.gelu
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
151

rprenger's avatar
rprenger committed
152
153
154
155
class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"

    A linear router scores every token and sends it to the single expert
    with the highest softmax probability (top-1 routing). That expert's
    output and bias are then scaled by the winning probability.
    """
    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        # Same flag the experts (ParallelMLP) use for their linear biases.
        # Without biases they return None as their bias, which forward()
        # must not try to expand or scale.
        self.add_bias = args.add_bias_linear
        self.experts = torch.nn.ModuleList()
        for _ in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        """Route each token to its top-1 expert.

        Args:
            hidden_states: [s, b, h] input.

        Returns:
            ``(output, output_bias)``, both [s, b, h] and scaled by the
            router probability of the chosen expert. ``output_bias`` is
            ``None`` when linear biases are disabled.
        """
        # hidden_states: [s, b, h]
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [s b 1]

        # TODO (rprenger) TODO this could be made easier to read
        # Converting [s, b, h] to [s*b, h].
        # Each vector could be routed differently
        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [s*b h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [s*b 1]
        max_ind = max_ind.view(-1)  # [s*b]

        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states) if self.add_bias else None
        # TODO (rprenger) This does each expert in serial, but it could be parallelized

        for expert_num, expert in enumerate(self.experts):
            # nonzero() on a 1-D mask gives [k, 1] indices, so `hidden` is
            # [k, 1, h] and the expert sees it as a sequence of k tokens with
            # batch size 1.
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_total[local_indices, :] = output
            if output_bias_total is not None:
                output_bias = output_bias.expand_as(output)
                output_bias_total[local_indices, :] = output_bias

        output_total = output_total * max_prob
        output_total = output_total.view(s, b, h)
        if output_bias_total is not None:
            output_bias_total = output_bias_total * max_prob
            output_bias_total = output_bias_total.view(s, b, h)

        return output_total, output_bias_total
199

200
201

class CoreAttention(MegatronModule):
    """Scaled dot-product attention over already-projected query/key/value.

    Computes softmax(Q K^T / norm_factor) V for the attention heads owned by
    this tensor-model-parallel rank. It applies the attention mask (through
    ``FusedScaleMaskSoftmax``) and attention dropout. Inputs are
    [s, b, np, hn]; the output is [sq, b, hp].
    """

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            # Layer scaling folds an extra layer_number factor into
            # norm_factor below and hands it to the softmax as `coeff`;
            # force the softmax to run in fp32 in that case.
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = args.sequence_parallel

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                           world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Compute attention for this rank's heads.

        Args:
            query_layer: [sq, b, np, hn] queries.
            key_layer: [sk, b, np, hn] keys.
            value_layer: [sk, b, np, hn] values.
            attention_mask: mask passed to ``self.scale_mask_softmax``.

        Returns:
            Context tensor of shape [sq, b, hp], where hp = np * hn is this
            rank's share of the projection size.
        """

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating input tensor: [b * np, sq, sk]
        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0: the buffer only supplies storage; its old contents are ignored.
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # NOTE(review): without sequence parallelism the dropout runs inside
        # the tensor-parallel CUDA RNG tracker fork, presumably so ranks that
        # hold different heads draw different masks — confirm in
        # tensor_parallel.random.
        if not self.sequence_parallel:
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        causal: Whether to apply a causal (autoregressive) mask.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
                                                      'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The query tensor, (B, Sq, H, D).
            k, v: The key and value tensors, (B, Sk, H, D). Sk may exceed Sq
                during incremental decoding, when k/v also contain the cached
                past positions.
        Returns
        -------
            Tensor of shape (B, Sq, H, D).
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]
        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        # Every sequence in the batch has the same length, so the cumulative
        # sequence offsets are an arithmetic progression.
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                                    dtype=torch.int32, device=q.device)
        if seqlen_k == seqlen_q:
            cu_seqlens_k = cu_seqlens_q
            causal = self.causal
        else:
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                                        dtype=torch.int32, device=k.device)
            # flash-attn 1.x aligns its causal mask to the top-left corner.
            # Here the queries are the *last* Sq of Sk positions, so that mask
            # would hide keys they are allowed to see. For a single new token
            # (the usual decoding step) no mask is needed at all.
            # NOTE(review): for 1 < Sq < Sk this lets a query also see keys of
            # later queries in the same chunk.
            causal = False
        output = flash_attn_unpadded_func(
            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
            self.dropout_p if self.training else 0.0,
            softmax_scale=self.softmax_scale, causal=causal
        )
        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output


379
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.

    The class handles two attention types:

    - Self-attention: a single fused QKV projection of ``hidden_states``.
    - Cross-attention: queries come from ``hidden_states``; keys and values
      come from ``encoder_output``.

    Heads are split across tensor-model-parallel ranks, and each rank
    computes ``num_attention_heads_per_partition`` of them. Optional
    features:

    - FlashAttention (self-attention with a causal mask only).
    - Selective recomputation of the core attention.
    - A per-layer key/value cache for inference.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype
        self.sequence_parallel = args.sequence_parallel

        self.use_flash_attn = args.use_flash_attn
        if self.use_flash_attn:
            if flash_attn_unpadded_func is None:
                raise ImportError('FlashAttention is not installed, please install with '
                                  'pip install flash-attn')
            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                bias=args.add_bias_linear,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())
        else:
            assert attention_type == AttnType.cross_attn
            self.query = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                bias=args.add_bias_linear,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())

            self.key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                bias=args.add_bias_linear,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())

        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        if self.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=args.attention_dropout
            )

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            bias=args.add_bias_linear,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask):
        """Forward method with activation checkpointing.

        Runs ``self.core_attention`` under ``tensor_parallel.checkpoint`` so
        its activations are recomputed in the backward pass instead of being
        stored.
        """
        def custom_forward(*inputs):
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        # NOTE(review): the positional `False` is a flag of
        # tensor_parallel.checkpoint (presumably whether to distribute saved
        # activations) — confirm against its signature.
        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask)

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized KV-cache buffer.

        Returns a [inference_max_sequence_len, batch_size, np, hn] tensor of
        ``self.params_dtype`` on the current CUDA device.
        """
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: [sq, b, h] input.
            attention_mask: mask forwarded to the core attention. It is
                unused on the FlashAttention path, which applies its own
                causal mask.
            encoder_output: [sk, b, h] source of keys/values; used only for
                cross-attention.
            inference_params: when truthy, enables the per-layer key/value
                cache kept in ``inference_params.key_value_memory_dict``.
                Its ``batch_size_offset`` and ``sequence_len_offset`` select
                where this step's keys/values are written.

        Returns:
            ``(output, bias)`` from the output projection ``self.dense``;
            output is [sq, b, h].
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
        # The projections return (output, bias); the bias value is discarded here.
        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================
        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            # Attend over every cached position up to and including this step.
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # ==================================
        # core attention computation
        # ==================================
        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
        else:
            # FlashAttention expects batch-first [b, s, np, hn].
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            if not self.sequence_parallel:
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(q, k, v)
            else:
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


609
def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    """Compute ``residual + dropout(x + bias)``.

    Args:
        x: tensor to which the bias and dropout are applied.
        bias: added to ``x`` before dropout; skipped when ``None``.
            NOTE: the TorchScript type comment declares a Tensor, so ``None``
            is only accepted when this function runs eagerly, not through a
            scripted caller.
        residual: tensor added to the dropout output.
        prob: dropout probability.
        training: dropout is applied only when True.

    Returns:
        ``residual + dropout(x [+ bias])``.
    """
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    """Build a bias-dropout-add callable with the ``training`` flag fixed.

    The returned callable takes ``(x, bias, residual, prob)`` and forwards
    to ``bias_dropout_add(x, bias, residual, prob, training)``.
    """
    return lambda x, bias, residual, prob: bias_dropout_add(
        x, bias, residual, prob, training)


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout enabled.

    Returns ``residual + dropout(x + bias)`` with ``training=True``. ``bias``
    is annotated as a Tensor, so unlike the eager helper it cannot be
    ``None`` here.
    """
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with dropout disabled.

    Returns ``residual + (x + bias)``: with ``training=False`` the dropout is
    a no-op whatever ``prob`` is. ``bias`` is annotated as a Tensor, so it
    cannot be ``None`` here.
    """
    return bias_dropout_add(x, bias, residual, prob, False)
638
639
640
641
642


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    Computation order:
      input_layernorm -> self-attention -> bias/dropout/residual add
      -> post_attention_layernorm
      -> [decoder layers only] cross-attention -> bias/dropout/residual add
         -> post_inter_attention_layernorm
      -> MLP -> bias/dropout/residual add.

    When ``drop_path_rate > 0`` the self-attention and MLP residual adds go
    through ``DropPath`` instead of the bias-dropout-add helpers; the
    decoder cross-attention residual always uses the helpers.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        """Build the layer's sub-modules from the global args.

        Args:
            init_method: initializer passed to the attention and MLP modules.
            output_layer_init_method: output-projection initializer passed to
                the attention and MLP modules.
            layer_number: layer index, passed to ParallelAttention.
            layer_type: LayerType.encoder or LayerType.decoder; decoder
                layers additionally build cross-attention and its LayerNorm.
            self_attn_mask_type: AttnMaskType for the self-attention block.
            drop_path_rate: DropPath (stochastic depth) rate; values <= 0
                disable DropPath.
        """
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the cross-attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel,
                apply_layernorm_1p=args.apply_layernorm_1p)

        # MLP: SwitchMLP when args.num_experts is set, else ParallelMLP.
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

        # Set bias+dropout+add fusion grad_enable execution handler:
        # on torch >= 1.10 the helpers run in the caller's grad mode,
        # on older versions they run under torch.enable_grad().
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Apply the layer to ``hidden_states`` ([s, b, h]).

        Args:
            hidden_states: input activations, [s, b, h].
            attention_mask: mask for self-attention.
            encoder_output: encoder activations for cross-attention
                (decoder layers only).
            enc_dec_attn_mask: mask for cross-attention (decoder layers only).
            inference_params: forwarded to self-attention.

        Returns:
            Output activations, same size as ``hidden_states``.
        """
        # Select the bias+dropout+add implementation once, up front. It must
        # exist before the drop-path branches below because the decoder
        # cross-attention residual always uses it, even when DropPath
        # handles the self-attention and MLP residuals.
        #
        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            if attention_bias is not None:
                attention_bias = attention_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias,
                    residual,
                    self.hidden_dropout)
        else:
            # The bias may be absent; only add it when present (mirrors the
            # MLP drop-path branch below).
            if attention_bias is not None:
                attention_output = attention_output + attention_bias
            out = torch.nn.functional.dropout(attention_output,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            if attention_bias is not None:
                attention_bias = attention_bias.expand_as(residual)

            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias,
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            if mlp_bias is not None:
                mlp_bias = mlp_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias,
                    residual,
                    self.hidden_dropout)

            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            output = core.utils.make_viewless_tensor(inp=output,
                                                     requires_grad=output.requires_grad,
                                                     keep_graph=True)

        else:
            if mlp_bias is not None:
                mlp_output = mlp_output + mlp_bias
            out = torch.nn.functional.dropout(mlp_output,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


831
832
833
class NoopTransformerLayer(MegatronModule):
    """Placeholder transformer layer that does no computation.

    It exists for the standalone-embedding-stage setup
    (``args.standalone_embedding_stage == True``). There, pipeline rank 0 is
    assigned zero transformer layers. When the virtual pipeline rank is >= 1
    it also owns zero model parameters, because virtual rank 0 holds the
    input embedding. Without a layer, such a model chunk would return its
    input tensor object as its output. That breaks memory optimizations
    applied to the output tensor, such as deallocating it.

    Returning a clone separates the output from the input. Ranks that carry
    this layer are generally under-utilized in compute and memory, so the
    extra copy is not a performance concern.
    """

    def __init__(self, layer_number):
        super(NoopTransformerLayer, self).__init__()
        # Kept for parity with real layers; unused by forward().
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # All arguments except hidden_states exist only to match the call
        # signature of real transformer layers; they are ignored.
        detached_copy = hidden_states.clone()
        return detached_copy


Jared Casper's avatar
Jared Casper committed
857
def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank.

    Args:
        args: global arguments namespace.
        is_encoder_and_decoder_model: whether layers are split between an
            encoder and a decoder across pipeline ranks.
        is_decoder: without pipeline parallelism, choose the decoder (True)
            or encoder (False) layer count.

    Returns:
        Number of layers for this rank. This is 0 on pipeline rank 0 when a
        standalone embedding stage is used.
    """
    # No pipeline parallelism: this rank holds the whole stack.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        return args.decoder_num_layers if is_decoder else args.encoder_num_layers

    if not is_encoder_and_decoder_model:
        assert args.num_layers == args.encoder_num_layers
        pp_size = args.transformer_pipeline_model_parallel_size
        assert args.num_layers % pp_size == 0, \
            'num_layers must be divisible by transformer_pipeline_model_parallel_size'

        # With a standalone embedding stage, the transformer layers are
        # spread over pipeline ranks >= 1. Pipeline rank 0 holds either the
        # input embedding (virtual pp rank 0) or nothing (virtual pp rank >= 1).
        if args.standalone_embedding_stage and \
                mpu.get_pipeline_model_parallel_rank() == 0:
            return 0
        return args.num_layers // pp_size

    split_rank = args.pipeline_model_parallel_split_rank
    assert split_rank is not None

    # A standalone embedding stage borrows one of the encoder's ranks for
    # the encoder embedding. This keeps the rank named by the 'split rank'
    # the same with or without a standalone embedding stage.
    if args.standalone_embedding_stage:
        num_ranks_in_encoder = split_rank - 1
    else:
        num_ranks_in_encoder = split_rank
    num_ranks_in_decoder = \
        args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder

    assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
            'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
    assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
            'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)

    # Ranks after the split hold decoder layers.
    if not mpu.is_pipeline_stage_before_split():
        return args.decoder_num_layers // num_ranks_in_decoder

    # Ranks before the split hold encoder layers. The embedding-only rank
    # holds none.
    if args.standalone_embedding_stage and \
            mpu.get_pipeline_model_parallel_rank() == 0:
        return 0
    return args.encoder_num_layers // num_ranks_in_encoder


909
910
911
class ParallelTransformer(MegatronModule):
    """Transformer class."""

912
    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_layer_norm=True,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        """Build the transformer layers owned by this (virtual) pipeline rank.

        Args:
            init_method: initializer passed to every layer.
            output_layer_init_method: output-projection initializer passed to
                every layer.
            layer_type: LayerType of the layers (encoder or decoder).
            self_attn_mask_type: self-attention AttnMaskType for every layer.
            post_layer_norm: together with ``post_process``, controls whether
                ``final_layernorm`` is built.
            pre_process: when False, forward() consumes the tensor supplied
                via set_input_tensor() instead of its ``hidden_states``.
            post_process: see ``post_layer_norm``.
            drop_path_rate: largest per-layer drop-path rate. Per-layer rates
                are spaced linearly from 0 to this value over
                ``args.num_layers`` layers.
        """
        # Module-level binding: the optional Transformer Engine import below
        # must be visible to build_layer() and to the other methods.
        global transformer_engine

        super(ParallelTransformer, self).__init__()
        args = get_args()

        # Constructor arguments.
        self.layer_type = layer_type
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.drop_path_rate = drop_path_rate
        self.input_tensor = None

        # Settings mirrored from the global args.
        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.transformer_impl = args.transformer_impl
        self.sequence_parallel = args.sequence_parallel

        # Activation recomputation (checkpointing) configuration.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        # Transformer Engine / FP8 setup.
        if self.transformer_impl == 'transformer_engine':
            import transformer_engine
        self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
        self.fp8_recipe = None
        self.fp8_group = mpu.get_data_parallel_group()
        if self.use_fp8:
            te_recipe = transformer_engine.common.recipe
            # use_fp8 guarantees that one of the two flags is set;
            # fp8_e4m3 takes precedence.
            fp8_format = te_recipe.Format.E4M3 if args.fp8_e4m3 \
                else te_recipe.Format.HYBRID
            self.fp8_recipe = te_recipe.DelayedScaling(
                margin=args.fp8_margin,
                interval=args.fp8_interval,
                fp8_format=fp8_format,
                amax_history_len=args.fp8_amax_history_len,
                amax_compute_algo=args.fp8_amax_compute_algo,
                override_linear_precision=(False, False, not args.fp8_wgrad),
            )

        # Microbatch bookkeeping used by forward().
        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        # How many layers live on this pipeline rank.
        is_encoder_and_decoder = args.model_type == ModelType.encoder_and_decoder
        self.num_layers = _get_num_layers(
            args, is_encoder_and_decoder, layer_type == LayerType.decoder)

        # Per-layer drop-path rates, indexed by (global layer number - 1).
        self.drop_path_rates = torch.linspace(
            0, self.drop_path_rate, args.num_layers).tolist()

        def build_layer(layer_number):
            """Instantiate the layer whose 1-based global index is layer_number."""
            if args.transformer_impl != 'local':
                return transformer_engine.pytorch.TransformerLayer(
                    args.hidden_size,
                    args.ffn_hidden_size,
                    args.num_attention_heads,
                    layernorm_epsilon=args.layernorm_epsilon,
                    hidden_dropout=args.hidden_dropout,
                    attention_dropout=args.attention_dropout,
                    init_method=init_method,
                    output_layer_init_method=output_layer_init_method,
                    layer_number=layer_number,
                    kv_channels=args.kv_channels,
                    self_attn_mask_type=self_attn_mask_type.name,
                    tp_group=mpu.get_tensor_model_parallel_group(),
                    get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
                    fuse_wgrad_accumulation=args.gradient_accumulation_fusion,
                    apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
                    attention_softmax_in_fp32=args.attention_softmax_in_fp32,
                    seq_length=args.seq_length,
                    micro_batch_size=args.micro_batch_size,
                    sequence_parallel=args.sequence_parallel,
                    params_dtype=args.params_dtype,
                    apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm,
                    output_layernorm=False,
                    layer_type="encoder",
                    drop_path_rate=self.drop_path_rates[layer_number - 1],
                    set_parallel_mode=True,
                    fuse_qkv_params=True)
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])

        # Global index offset of this rank's first layer.
        vpp_size = args.virtual_pipeline_model_parallel_size
        if vpp_size is not None:
            assert args.num_layers % vpp_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert not is_encoder_and_decoder
            # Layers per model chunk = layers per stage / chunks per stage.
            self.num_layers = self.num_layers // vpp_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            chunk_stride = args.num_layers // vpp_size
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * chunk_stride \
                + mpu.get_pipeline_model_parallel_rank() * self.num_layers
        elif is_encoder_and_decoder and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            # Contiguous layers per stage, counted separately for the
            # encoder (from rank 0) and the decoder (from the split rank).
            pipeline_rank = mpu.get_pipeline_model_parallel_rank()
            if layer_type == LayerType.encoder:
                first_rank = 0
            else:
                first_rank = args.pipeline_model_parallel_split_rank
            offset = (pipeline_rank - first_rank) * self.num_layers
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # With a standalone embedding stage, virtual pipeline ranks on
            # pipeline rank 0 get zero transformer layers. The model's output
            # would then be its input tensor, which breaks output-tensor
            # optimizations (e.g. pipeline output deallocation). A single
            # no-op layer disconnects the output from the input.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                build_layer(offset + position)
                for position in range(1, self.num_layers + 1))

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel,
                apply_layernorm_1p=args.apply_layernorm_1p)
1068

Mohammad's avatar
Mohammad committed
1069
    def _get_layer(self, layer_number):
        """Return the layer at local position ``layer_number`` in self.layers.

        Callers pass indices in ``range(self.num_layers)``, i.e. 0-based
        positions within this rank's ModuleList, not global layer numbers.
        """
        selected = self.layers[layer_number]
        return selected
Mohammad's avatar
Mohammad committed
1071

1072
    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask, is_first_microbatch):
        """Forward method with activation checkpointing.

        Runs every layer on this rank. Activations are recomputed in the
        backward pass according to ``self.recompute_method``:

        * ``'uniform'``: layers are grouped into consecutive chunks of
          ``self.recompute_num_layers``, and only the input of each chunk is
          checkpointed. If the layer count is not a multiple of the chunk
          size, the last chunk is shorter (previously it indexed past the
          end of ``self.layers``).
        * ``'block'``: each of the first ``self.recompute_num_layers`` layers
          is checkpointed individually; the rest run without checkpointing.

        Args:
            hidden_states: input to the first layer.
            attention_mask: self-attention mask, passed to every layer.
            encoder_output: passed to every layer.
            enc_dec_attn_mask: passed to every layer.
            is_first_microbatch: passed as a keyword to Transformer Engine
                layers only.

        Returns:
            Output of the last layer.

        Raises:
            ValueError: for an unknown recompute method, or for a missing or
                non-positive ``recompute_num_layers`` with ``'uniform'``.
        """
        use_te = self.transformer_impl == 'transformer_engine'

        def custom(start, end, is_transformer_engine=False):
            """Return a callable that runs layers [start, end) in sequence."""
            def custom_forward(*inputs, **kwargs):
                x_, *rest = inputs
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, *rest, **kwargs)
                return x_

            def custom_forward_transformer_engine(*inputs, **kwargs):
                return custom_forward(*inputs, is_first_microbatch=is_first_microbatch, **kwargs)

            if is_transformer_engine:
                return custom_forward_transformer_engine
            return custom_forward

        def run_checkpointed(start, end, hidden_states):
            """Run layers [start, end) under activation recomputation."""
            if use_te:
                return transformer_engine.pytorch.distributed.checkpoint(
                    custom(start, end, is_transformer_engine=True),
                    self.distribute_saved_activations,
                    tensor_parallel.get_cuda_rng_tracker,
                    mpu.get_tensor_model_parallel_group(),
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
            return tensor_parallel.checkpoint(
                custom(start, end),
                self.distribute_saved_activations,
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            chunk_size = self.recompute_num_layers
            if chunk_size is None or chunk_size < 1:
                # A zero step would loop forever.
                raise ValueError(
                    "recompute_num_layers must be a positive integer for the "
                    "'uniform' activation recompute method.")
            for start in range(0, self.num_layers, chunk_size):
                end = min(start + chunk_size, self.num_layers)
                hidden_states = run_checkpointed(start, end, hidden_states)

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for index in range(self.num_layers):
                if index < self.recompute_num_layers:
                    hidden_states = run_checkpointed(index, index + 1, hidden_states)
                else:
                    hidden_states = custom(index, index + 1, is_transformer_engine=use_te)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

1140
    def set_input_tensor(self, input_tensor):
        """Store the tensor that forward() uses in place of its input.

        With pipeline parallelism, a stage's input comes from the previous
        stage over communication rather than from the forward_step_func, so
        forward() never receives it as an argument. Internal code calls this
        hook to hand the tensor over. forward() then reads it whenever
        ``pre_process`` is False.
        """
        # Consumed by forward() in place of its hidden_states argument.
        self.input_tensor = input_tensor

1150
    def forward(self, hidden_states, attention_mask,
mshoeybi's avatar
mshoeybi committed
1151
1152
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
Vijay Korthikanti's avatar
Vijay Korthikanti committed
1153
1154
        # hidden_states: [s, b, h]

1155
        # Checks.
mshoeybi's avatar
mshoeybi committed
1156
        if inference_params:
Vijay Korthikanti's avatar
Vijay Korthikanti committed
1157
            assert self.recompute_granularity is None, \
1158
                'inference does not work with activation checkpointing'
1159

1160
        if not self.pre_process:
1161
            # See set_input_tensor()
1162
            hidden_states = self.input_tensor
1163

1164
1165
        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
1166
1167
1168
1169
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
1170
1171
1172
1173
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
1174
        #
1175
1176
1177
1178
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
1179
        hidden_states = core.utils.make_viewless_tensor(
1180
            hidden_states,
1181
1182
            requires_grad=True,
            keep_graph=True,
1183
1184
        )

Vijay Korthikanti's avatar
Vijay Korthikanti committed
1185
        if self.sequence_parallel:
1186
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
1187
        else:
Vijay Korthikanti's avatar
Vijay Korthikanti committed
1188
            rng_context = nullcontext()
Vijay Korthikanti's avatar
Vijay Korthikanti committed
1189
1190

        with rng_context:
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
            # The fp8_autocast context manager is a no-op when enabled=True
            # The if...else serves to short circuit name resolution for fp8_autocast
            with transformer_engine.pytorch.fp8_autocast(
                enabled=self.use_fp8,
                fp8_recipe=self.fp8_recipe,
                fp8_group=self.fp8_group
            ) if self.use_fp8 else nullcontext():
                # Determine if the current iteration is first microbatch
                if self.num_microbatches_in_previous_step != get_num_microbatches():
                    self.microbatch_count = 0 # Reset count on new batch size rampup interval
                self.num_microbatches_in_previous_step = get_num_microbatches()
                is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0

                # Forward pass.
                if self.recompute_granularity == 'full':
                    hidden_states = self._checkpointed_forward(hidden_states,
                                                               attention_mask,
                                                               encoder_output,
                                                               enc_dec_attn_mask,
                                                               is_first_microbatch)
                else:
                    forward_kwargs = {
                        'encoder_output': encoder_output,
                        'enc_dec_attn_mask': enc_dec_attn_mask,
                        'inference_params': inference_params,
                    }

                    if self.transformer_impl == 'transformer_engine':
                        forward_kwargs['is_first_microbatch'] = is_first_microbatch
                        forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention

                    for index in range(self.num_layers):
                        layer = self._get_layer(index)

                        hidden_states = layer(
                            hidden_states,
                            attention_mask,
                            **forward_kwargs)

                # Skip counter update for eval and activation checkpointing
                if torch.is_grad_enabled() and self.training:
                    self.microbatch_count += 1
mshoeybi's avatar
mshoeybi committed
1233

1234
        # Final layer norm.
1235
        if self.post_process and self.post_layer_norm:
1236
1237
            hidden_states = self.final_layernorm(hidden_states)

1238
        return hidden_states