transformer.py 54.5 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""Transformer."""
import math
5
from contextlib import nullcontext
6
import torch
7
import torch.nn.functional as F
8

9
from megatron import get_timers, get_args, core, get_num_microbatches
10
from .module import MegatronModule
11
from megatron.core import mpu, tensor_parallel
12
13
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
14
from megatron.model import LayerNorm
15
16
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
Mostofa Patwary's avatar
Mostofa Patwary committed
17
from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
18
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
19

20
21
22
23
24
25
26
27
28
29
try:
    from einops import rearrange
except ImportError:
    rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    flash_attn_unpadded_func = None

30
31
32
33
34
35
36
37
38
39
""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""

45
class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    During training each sample along dimension 1 of the input is either
    zeroed out entirely (with probability ``drop_prob``) or kept and
    rescaled by ``1 / (1 - drop_prob)``. Outside of training, or when
    ``drop_prob`` is 0, the input is returned unchanged.
    """

    def __init__(self, drop_prob=0.):
        """Store the probability with which a sample's path is dropped."""
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        """Apply stochastic depth to ``hidden_state``.

        Arguments:
            hidden_state: tensor whose dimension 1 indexes samples
                (the file's convention is [s, b, h]).

        Returns:
            A tensor shaped like ``hidden_state``.
        """
        if self.drop_prob == 0. or not self.training:
            return hidden_state

        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # Every path is dropped. Dividing by keep_prob == 0 below would
            # produce inf * 0 = NaN, so return explicit zeros instead.
            return torch.zeros_like(hidden_state)

        # work with diff dim tensors, not just 2D ConvNets
        # hidden_state: [s, b, h]; one Bernoulli draw per sample, broadcast
        # over every other dimension.
        mask_shape = (1, hidden_state.shape[1]) + (1,) * (hidden_state.ndim - 2)
        keep_mask = keep_prob + torch.rand(
            mask_shape, dtype=hidden_state.dtype, device=hidden_state.device)
        keep_mask.floor_()  # binarize: 1 with probability keep_prob, else 0
        return hidden_state.div(keep_prob) * keep_mask

67
68
69
70
71
72
73
74
75
76
77
def _args_to_kwargs():
    """Collect the global-argument values shared by every parallel linear layer.

    Returns:
        A new dict of keyword arguments read from ``get_args()``, intended
        to be splatted into the tensor-parallel linear constructors used in
        this file.
    """
    cfg = get_args()
    return dict(
        params_dtype=cfg.params_dtype,
        use_cpu_initialization=cfg.use_cpu_initialization,
        perform_initialization=cfg.perform_initialization,
        gradient_accumulation_fusion=cfg.gradient_accumulation_fusion,
        sequence_parallel_enabled=cfg.sequence_parallel,
    )
78

79
80
81
82
83
class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to the
    ``ffn_hidden_size`` hidden dimension, perform nonlinear transformation,
    and project the state back into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
        """Build the two tensor-parallel projections and pick the activation.

        Arguments:
            init_method: weight initializer for the up-projection.
            output_layer_init_method: weight initializer for the
                down-projection.
        """
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to the FFN width. The bias is returned separately
        # (skip_bias_add=True) so it can be fused with the activation.
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
            **_args_to_kwargs())

        self.bias_gelu_fusion = args.bias_gelu_fusion
        # Activation used on the unfused path; openai_gelu takes precedence
        # over onnx_safe when both flags are set.
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

    def forward(self, hidden_states):
        """Run the MLP.

        Returns:
            A tuple ``(output, output_bias)`` as produced by the
            down-projection; the bias is not added to ``output`` here.
        """
        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = bias_gelu_impl(intermediate_parallel,
                                                   bias_parallel)
        else:
            intermediate_parallel = self.activation_func(
                intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
133

rprenger's avatar
rprenger committed
134
135
136
137
class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"

    Each token is sent to the single expert with the highest router
    probability, and that expert's output is scaled by the probability.
    """

    def __init__(self, init_method, output_layer_init_method):
        """Create the router and ``args.num_experts`` ParallelMLP experts."""
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for _ in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        """Route every token to its top-1 expert.

        Arguments:
            hidden_states: [s, b, h]

        Returns:
            A tuple ``(output, output_bias)``, both viewed as [s, b, h] and
            both multiplied by the winning router probability.
        """
        # hidden_states: [s, b, h]
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)

        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [s b 1]

        # TODO (rprenger) TODO this could be made easier to read
        # Converting [s, b, h] to [s*b, h].
        # Each vector could be routed differently
        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [s*b h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [s*b 1]
        max_ind = max_ind.view(-1)  # [s*b]

        # Every row is written by exactly one expert in the loop below
        # (each token has exactly one arg-max), so uninitialized storage
        # is sufficient here.
        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)

        # TODO (rprenger) This does each expert in serial, but it could be
        # parallelized
        for expert_num, expert in enumerate(self.experts):
            # nonzero() yields a [k, 1] index tensor, so the gathered
            # tokens (and the expert outputs) are [k, 1, h].
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total[local_indices, :] = output
            output_bias_total[local_indices, :] = output_bias

        output_total = output_total * max_prob
        output_bias_total = output_bias_total * max_prob
        output_total = output_total.view(s, b, h)
        output_bias_total = output_bias_total.view(s, b, h)

        return output_total, output_bias_total
181

182
183

class CoreAttention(MegatronModule):
    """Unfused attention core: scores, masked softmax, dropout, context.

    Takes already-projected query/key/value tensors laid out as
    [s, b, np, hn] and returns the context tensor as [sq, b, hp].
    """

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        """Read attention hyperparameters from the global args.

        Arguments:
            layer_number: layer index; clamped to at least 1 and used as the
                scaling coefficient when query-key layer scaling is enabled.
            attn_mask_type: mask type forwarded to the fused softmax.
        """
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Layer scaling forces the softmax to run in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = args.sequence_parallel

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                           world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)

        # Scores are divided by norm_factor in forward(); with layer scaling
        # the same coefficient is also handed to the softmax.
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Compute attention.

        Arguments:
            query_layer: [sq, b, np, hn]
            key_layer: [sk, b, np, hn]
            value_layer: [sk, b, np, hn]
            attention_mask: mask passed through to the fused softmax.

        Returns:
            Context tensor of shape [sq, b, hp].
        """
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))
        batch_heads = output_size[0] * output_size[1]

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2], batch_heads, -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3], batch_heads, -1)

        # preallocating input tensor: [b * np, sq, sk]
        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
            (batch_heads, output_size[2], output_size[3]),
            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 means the buffer's contents are ignored; it only supplies
        # the output shape/dtype.
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        if not self.sequence_parallel:
            # Without sequence parallelism the dropout runs under the forked
            # tensor-parallel CUDA RNG tracker.
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
                                                      'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)

        The query and the key/value may have different sequence lengths
        (e.g. a single new query token attending to a longer key/value
        cache); each gets its own cumulative-length vector.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]

        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                                    dtype=torch.int32, device=q.device)
        if seqlen_k == seqlen_q:
            cu_seqlens_k = cu_seqlens_q
        else:
            # K/V are flattened with their own per-sample length, so their
            # sample boundaries must be computed from seqlen_k, not seqlen_q.
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                                        dtype=torch.int32, device=q.device)

        # The kernel's causal mask is only meaningful here when Q and K cover
        # the same positions. With a shorter Q (decoding against a cache) the
        # new tokens must see the whole cache, so the mask is switched off.
        # NOTE(review): a multi-token Q that is shorter than K would need
        # causality within the new chunk, which this path does not provide.
        is_causal = self.causal and seqlen_q == seqlen_k

        output = flash_attn_unpadded_func(
            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
            self.dropout_p if self.training else 0.0,
            softmax_scale=self.softmax_scale, causal=is_causal
        )
        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output


361
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size (together with the output
    projection's bias, which is not added here).
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        """Build the QKV projections, the attention core(s) and the output
        projection.

        Arguments:
            init_method: weight initializer for the Q/K/V projections.
            output_layer_init_method: weight initializer for the output
                projection.
            layer_number: layer index (clamped to at least 1); also the key
                under which this layer's inference KV cache is stored.
            attention_type: AttnType.self_attn or AttnType.cross_attn.
            attn_mask_type: mask type handed to CoreAttention.

        Raises:
            ImportError: if flash attention is requested but flash-attn or
                einops is not installed.
        """
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype
        self.sequence_parallel = args.sequence_parallel

        self.use_flash_attn = args.use_flash_attn
        if self.use_flash_attn:
            if flash_attn_unpadded_func is None:
                raise ImportError('FlashAttention is not installed, please install with '
                                  'pip install flash-attn')
            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())
        else:
            assert attention_type == AttnType.cross_attn
            self.query = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())

            self.key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method,
                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
                **_args_to_kwargs())

        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        if self.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=args.attention_dropout
            )

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask,
                                        rotary_pos_emb=None):
        """Forward method with activation checkpointing."""
        def custom_forward(*inputs):
            # Only the first four inputs are consumed; the rotary embeddings
            # passed below are carried through the checkpoint call but are
            # not used by the core attention.
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \
            else rotary_pos_emb

        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask,
            q_pos_emb, k_pos_emb)

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate an uninitialized [max_s, b, np, hn] KV-cache buffer on
        the current CUDA device, in the parameter dtype."""
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        """Run attention over ``hidden_states``.

        Arguments:
            hidden_states: [sq, b, h]
            attention_mask: mask forwarded to the core attention (unused on
                the flash-attention path).
            encoder_output: source of keys/values for cross-attention.
            inference_params: when truthy, enables the per-layer key/value
                cache stored in ``inference_params.key_value_memory_dict``.
            rotary_pos_emb: a single rotary embedding (used for both query
                and key) or a ``(q_pos_emb, k_pos_emb)`` tuple.

        Returns:
            A tuple ``(output, bias)`` from the output projection.
        """
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            # Attend over everything cached so far, not just the new chunk.
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # The query covers exactly the positions just written to the
                # cache. This one slice handles the whole-prefix first pass
                # (sequence_start == 0), single-token decoding
                # (sequence_end - 1 == sequence_start), and multi-token
                # chunks appended to an existing cache.
                q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
                # The keys span the full cached prefix.
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
        else:
            # The flash path takes batch-first tensors.
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            if not self.sequence_parallel:
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(q, k, v)
            else:
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


633
def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    """Return ``residual + dropout(x + bias)``.

    ``prob`` is the dropout probability and ``training`` toggles whether
    dropout is actually applied (it is passed straight through to
    ``torch.nn.functional.dropout``).
    """
    biased = x + bias
    dropped = torch.nn.functional.dropout(biased, p=prob, training=training)
    return residual + dropped


def get_bias_dropout_add(training):
    """Build a 4-argument bias+dropout+add callable with ``training`` fixed.

    The returned function has the same ``(x, bias, residual, prob)``
    signature as the fused variants, so callers can use either
    interchangeably.
    """
    def _bias_dropout_add(x, bias, residual, prob):
        # Delegate to the unfused implementation with the captured mode.
        return bias_dropout_add(x, bias, residual, prob, training=training)

    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """Jit-scripted bias+dropout+add with dropout enabled (training mode)."""
    out = bias_dropout_add(x, bias, residual, prob, True)
    return out


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """Jit-scripted bias+dropout+add with dropout disabled (inference mode)."""
    out = bias_dropout_add(x, bias, residual, prob, False)
    return out
660
661
662
663
664


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        """Build the sub-modules of the layer.

        Arguments:
            init_method: forwarded to the attention and MLP sub-modules.
            output_layer_init_method: forwarded to the attention and MLP
                sub-modules.
            layer_number: forwarded to the attention sub-modules.
            layer_type: when LayerType.decoder, a cross-attention block and
                its layernorm are created in addition to self attention.
            self_attn_mask_type: mask type used by the self attention.
            drop_path_rate: if > 0, a DropPath module is applied to the
                self-attention and MLP branches.
        """
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

        # Set bias+dropout+add fusion grad_enable execution handler.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
                nullcontext if use_nvfuser else torch.enable_grad

    def _select_bias_dropout_add(self):
        """Pick the bias+dropout+add implementation for the current mode.

        jit scripting for a nn.module (with dropout) is not triggering the
        fusion kernel. For now, we use two different nn.functional routines
        to account for varying dropout semantics during training and
        inference phases.
        """
        if not self.bias_dropout_fusion:
            return get_bias_dropout_add(self.training)
        if self.training:
            return bias_dropout_add_fused_train
        return bias_dropout_add_fused_inference

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None, rotary_pos_emb=None):
        """Run self attention, optional cross attention, and the MLP.

        Arguments:
            hidden_states: layer input, [s, b, h].
            attention_mask: mask passed to the self attention.
            encoder_output: passed to the cross attention (decoder layers
                only).
            enc_dec_attn_mask: mask passed to the cross attention (decoder
                layers only).
            inference_params: forwarded to the self attention.
            rotary_pos_emb: forwarded to the self attention.

        Returns:
            The layer output tensor.
        """
        # hidden_states: [s, b, h]

        # Resolve the bias+dropout+add routine up front. It is needed by the
        # decoder cross-attention branch regardless of whether drop_path is
        # enabled, so it must not be bound only inside a drop_path branch.
        bias_dropout_add_func = self._select_bias_dropout_add()

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # NOTE: drop_path is not applied to the cross-attention branch.
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            output = core.utils.make_viewless_tensor(inp = output,
                                                     requires_grad = output.requires_grad,
                                                     keep_graph = True)

        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


842
843
844
class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    The sole purpose of this layer is for when a standalone embedding layer
    is used (i.e., args.standalone_embedding_stage == True). In this case,
    zero transformer layers are assigned when pipeline rank == 0. Additionally,
    when virtual pipeline rank >= 1, zero total model parameters are created
    (virtual rank 0 contains the input embedding). This results in the model's
    input and output tensors being the same, which causes an error when
    performing certain memory optimizations on the output tensor (e.g.,
    deallocating it). Thus, this layer disconnects the input from the output
    via a clone. Since ranks containing a no-op layer are generally under-
    utilized (both compute and memory), there's no worry of any performance
    degradation.
    """

    def __init__(self, layer_number):
        super().__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None, rotary_pos_emb=None):
        """Return a clone of ``hidden_states``; all other arguments are ignored.

        The signature mirrors ParallelTransformerLayer.forward (including
        ``rotary_pos_emb``) so that this layer can be called the same way as
        a regular transformer layer.
        """
        return hidden_states.clone()


Jared Casper's avatar
Jared Casper committed
868
def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank."""
    # Without pipeline parallelism every layer of the requested stack lives
    # on this rank.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        if is_decoder:
            return args.decoder_num_layers
        return args.encoder_num_layers

    if not is_encoder_and_decoder_model:
        assert args.num_layers == args.encoder_num_layers
        assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
            'num_layers must be divisible by transformer_pipeline_model_parallel_size'

        # When a standalone embedding stage is used, all transformer layers
        # are divided among pipeline rank >= 1, while on pipeline rank 0,
        # ranks either contain the input embedding layer (virtual pp rank 0),
        # or no layers at all (virtual pp rank >= 1).
        if args.standalone_embedding_stage \
                and mpu.get_pipeline_model_parallel_rank() == 0:
            return 0
        return args.num_layers // args.transformer_pipeline_model_parallel_size

    # Encoder-and-decoder model split across pipeline ranks.
    assert args.pipeline_model_parallel_split_rank is not None

    # When a standalone embedding stage is used, a rank is taken from
    # the encoder's ranks, to be used for the encoder's embedding
    # layer. This way, the rank referenced by the 'split rank' remains
    # the same whether or not a standalone embedding stage is used.
    if args.standalone_embedding_stage:
        num_ranks_in_encoder = args.pipeline_model_parallel_split_rank - 1
    else:
        num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
    num_ranks_in_decoder = \
        args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder

    assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
        'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
    assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
        'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)

    if not mpu.is_pipeline_stage_before_split():
        return args.decoder_num_layers // num_ranks_in_decoder

    # Encoder side: rank 0 holds no transformer layers when it is the
    # standalone embedding stage.
    if args.standalone_embedding_stage \
            and mpu.get_pipeline_model_parallel_rank() == 0:
        return 0
    return args.encoder_num_layers // num_ranks_in_encoder


920
921
922
class ParallelTransformer(MegatronModule):
    """Transformer class."""

923
    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_layer_norm=True,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        """Build the stack of transformer layers resident on this rank.

        Arguments:
            init_method: forwarded to every layer.
            output_layer_init_method: forwarded to every layer.
            layer_type: LayerType of the layers to build; also selects the
                encoder/decoder layer count and offset.
            self_attn_mask_type: self-attention mask type for every layer.
            post_layer_norm: together with ``post_process``, controls
                whether the final layernorm is created.
            pre_process: stored on the module as ``self.pre_process``.
            post_process: see ``post_layer_norm``.
            drop_path_rate: upper end of the per-layer drop-path rates,
                which are linearly spaced from 0 over ``args.num_layers``.
        """
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate
        self.transformer_impl = args.transformer_impl

        # Store activation checkpointing flag.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Transformer Engine Init.
        if self.transformer_impl == 'transformer_engine':
            global transformer_engine
            import transformer_engine
        self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
        self.fp8_recipe = None
        self.fp8_group = mpu.get_data_parallel_group()
        if self.use_fp8:
            # transformer_engine is only imported for the transformer_engine
            # implementation; fail with a clear message instead of a
            # NameError on the recipe lookup below.
            assert self.transformer_impl == 'transformer_engine', \
                'fp8 (fp8_e4m3 / fp8_hybrid) requires ' \
                'transformer_impl == transformer_engine'
            if args.fp8_e4m3:
                fp8_format = transformer_engine.common.recipe.Format.E4M3
            elif args.fp8_hybrid:
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
            self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
                margin=args.fp8_margin,
                interval=args.fp8_interval,
                fp8_format=fp8_format,
                amax_history_len=args.fp8_amax_history_len,
                amax_compute_algo=args.fp8_amax_compute_algo,
                override_linear_precision=(False, False, not args.fp8_wgrad),
            )

        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        # Number of layers.
        self.num_layers = _get_num_layers(
            args,
            args.model_type == ModelType.encoder_and_decoder,
            layer_type == LayerType.decoder)

        # One drop-path rate per (global) layer, linearly increasing.
        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

        # Transformer layers.
        def build_layer(layer_number):
            # layer_number is 1-based, hence the "- 1" when indexing rates.
            if args.transformer_impl == 'local':
                return ParallelTransformerLayer(
                    init_method,
                    output_layer_init_method,
                    layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
                return transformer_engine.pytorch.TransformerLayer(
                    args.hidden_size,
                    args.ffn_hidden_size,
                    args.num_attention_heads,
                    layernorm_epsilon=args.layernorm_epsilon,
                    hidden_dropout=args.hidden_dropout,
                    attention_dropout=args.attention_dropout,
                    init_method=init_method,
                    output_layer_init_method=output_layer_init_method,
                    layer_number=layer_number,
                    kv_channels=args.kv_channels,
                    self_attn_mask_type=self_attn_mask_type.name,
                    tp_group=mpu.get_tensor_model_parallel_group(),
                    get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
                    fuse_wgrad_accumulation=args.gradient_accumulation_fusion,
                    apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
                    attention_softmax_in_fp32=args.attention_softmax_in_fp32,
                    seq_length=args.seq_length,
                    micro_batch_size=args.micro_batch_size,
                    sequence_parallel=args.sequence_parallel,
                    params_dtype=args.params_dtype,
                    apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm,
                    output_layernorm=False,
                    layer_type="encoder",
                    drop_path_rate=self.drop_path_rates[layer_number - 1],
                    set_parallel_mode=True,
                    fuse_qkv_params=True)

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors to be
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)
1078

Mohammad's avatar
Mohammad committed
1079
    def _get_layer(self, layer_number):
1080
        return self.layers[layer_number]
Mohammad's avatar
Mohammad committed
1081

1082
    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask,
                              rotary_pos_emb, is_first_microbatch):
        """Forward method with activation checkpointing.

        Runs all ``self.num_layers`` layers, checkpointing according to
        ``self.recompute_method``:

        * 'uniform': layers are grouped into chunks of
          ``self.recompute_num_layers`` and every chunk is checkpointed.
        * 'block': only the first ``self.recompute_num_layers`` layers are
          checkpointed (one layer per checkpoint); the rest run normally.

        Raises:
            ValueError: if ``self.recompute_method`` is neither of the above.
        """
        use_te = self.transformer_impl == 'transformer_engine'

        def custom(start, end, is_transformer_engine=False):
            def custom_forward(*args):
                # Positional layout: (hidden_states, attention_mask,
                # encoder_output, enc_dec_attn_mask, rotary_pos_emb).
                x_, *layer_args, rotary = args
                for index in range(start, end):
                    layer = self._get_layer(index)
                    # Feed each layer the previous layer's output, and pass
                    # rotary_pos_emb by keyword: positionally it would land
                    # in the layer's `inference_params` slot. It is omitted
                    # for the no-op layer, which does not use it.
                    if rotary is None or isinstance(layer, NoopTransformerLayer):
                        x_ = layer(x_, *layer_args)
                    else:
                        x_ = layer(x_, *layer_args, rotary_pos_emb=rotary)
                return x_

            def custom_forward_transformer_engine(*args, **kwargs):
                # NOTE(review): the positional arguments (including the
                # trailing rotary_pos_emb) are forwarded unchanged to the
                # Transformer Engine layer, as before; confirm this matches
                # the installed TransformerLayer.forward signature.
                x_, *layer_args = args
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, *layer_args,
                               is_first_microbatch=is_first_microbatch,
                               **kwargs)
                return x_

            if not is_transformer_engine:
                return custom_forward
            else:
                return custom_forward_transformer_engine

        layer_inputs = (attention_mask, encoder_output,
                        enc_dec_attn_mask, rotary_pos_emb)

        def run_layers(start, end, hidden, checkpointed):
            # Run layers [start, end) on `hidden`, optionally under
            # activation checkpointing.
            forward_fn = custom(start, end, is_transformer_engine=use_te)
            if not checkpointed:
                return forward_fn(hidden, *layer_inputs)
            if use_te:
                return transformer_engine.pytorch.distributed.checkpoint(
                    forward_fn,
                    self.distribute_saved_activations,
                    tensor_parallel.get_cuda_rng_tracker,
                    mpu.get_tensor_model_parallel_group(),
                    hidden, *layer_inputs)
            return tensor_parallel.checkpoint(
                forward_fn,
                self.distribute_saved_activations,
                hidden, *layer_inputs)

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            l = 0
            while l < self.num_layers:
                # Clamp so the last chunk never indexes past the layer list
                # when recompute_num_layers does not divide num_layers.
                chunk_end = min(l + self.recompute_num_layers, self.num_layers)
                hidden_states = run_layers(l, chunk_end, hidden_states, True)
                l += self.recompute_num_layers

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            for l in range(self.num_layers):
                hidden_states = run_layers(
                    l, l + 1, hidden_states,
                    l < self.recompute_num_layers)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

1156
    def set_input_tensor(self, input_tensor):
        """Record the tensor that forward() should use as its input.

        Under pipeline parallelism the activations of the previous stage
        arrive through communication rather than through the arguments that
        forward_step_func passes in. Internal code calls this setter so the
        communicated tensor can replace the forward_step_func-provided
        input."""
        setattr(self, 'input_tensor', input_tensor)

1166
    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None, rotary_pos_emb=None):
        """Run the stack of transformer layers over ``hidden_states``.

        Args:
            hidden_states: Input activations, [s, b, h]. Ignored when
                ``self.pre_process`` is false; ``self.input_tensor`` (see
                set_input_tensor()) is used instead.
            attention_mask: Mask forwarded unchanged to every layer.
            encoder_output: Forwarded unchanged to every layer.
            enc_dec_attn_mask: Forwarded unchanged to every layer.
            inference_params: Forwarded to every layer on the
                non-checkpointed path. Must be falsy when activation
                recomputation is configured.
            rotary_pos_emb: Forwarded unchanged to every layer.

        Returns:
            The output of the last layer, passed through
            ``self.final_layernorm`` when both ``self.post_process`` and
            ``self.post_layer_norm`` are set.
        """
        # hidden_states: [s, b, h]

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        #
        # NOTE(review): this method contains no transpose() and no 'else'
        # branch above this point, so the two bullets appear to describe an
        # earlier revision -- confirm and update.
        hidden_states = core.utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        if self.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        with rng_context:
            # The fp8_autocast context manager is a no-op when enabled=False.
            # The if...else serves to short circuit name resolution for
            # fp8_autocast: `transformer_engine` is only looked up when
            # self.use_fp8 is set.
            with transformer_engine.pytorch.fp8_autocast(
                enabled=self.use_fp8,
                fp8_recipe=self.fp8_recipe,
                fp8_group=self.fp8_group
            ) if self.use_fp8 else nullcontext():
                # Determine if the current iteration is first microbatch.
                # Read the microbatch count once so every use below sees the
                # same value.
                num_microbatches = get_num_microbatches()
                if self.num_microbatches_in_previous_step != num_microbatches:
                    # Reset count on new batch size rampup interval
                    self.microbatch_count = 0
                self.num_microbatches_in_previous_step = num_microbatches
                is_first_microbatch = \
                    self.microbatch_count % num_microbatches == 0

                # Forward pass.
                if self.recompute_granularity == 'full':
                    hidden_states = self._checkpointed_forward(hidden_states,
                                                               attention_mask,
                                                               encoder_output,
                                                               enc_dec_attn_mask,
                                                               rotary_pos_emb,
                                                               is_first_microbatch)
                else:
                    forward_kwargs = {
                        'encoder_output': encoder_output,
                        'enc_dec_attn_mask': enc_dec_attn_mask,
                        'inference_params': inference_params,
                        'rotary_pos_emb': rotary_pos_emb,
                    }

                    # These two keywords are only passed to layers when the
                    # transformer_engine implementation is selected.
                    if self.transformer_impl == 'transformer_engine':
                        forward_kwargs['is_first_microbatch'] = is_first_microbatch
                        forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention

                    for index in range(self.num_layers):
                        layer = self._get_layer(index)

                        hidden_states = layer(
                            hidden_states,
                            attention_mask,
                            **forward_kwargs)

                # Skip counter update for eval and activation checkpointing
                if torch.is_grad_enabled() and self.training:
                    self.microbatch_count += 1

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states